Skip to content
Snippets Groups Projects
Commit 16160d9c authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improve plot and tau printing

parent fcceee09
No related branches found
No related tags found
1 merge request!750Improve ptemcee
......@@ -47,7 +47,7 @@ class Ptemcee(MCMCSampler):
use_ratio=False, check_point_plot=True, skip_import_verification=False,
resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
autocorr_c=5, safety=1, frac_threshold=0.01,
autocorr_tol=50, min_tau=1, check_point_deltaT=600,
autocorr_tol=50, autocorr_tau=5, min_tau=1, check_point_deltaT=600,
threads=1, exit_code=77, plot=False, store_walkers=False,
**kwargs):
super(Ptemcee, self).__init__(
......@@ -67,6 +67,7 @@ class Ptemcee(MCMCSampler):
self.frac_threshold = frac_threshold
self.nsamples = nsamples
self.autocorr_tol = autocorr_tol
self.autocorr_tau = autocorr_tau
self.min_tau = min_tau
self.check_point_deltaT = check_point_deltaT
......@@ -190,9 +191,9 @@ class Ptemcee(MCMCSampler):
self.tau_list_n.append(sampler.time)
# Convert to an integer
tau = int(np.floor(tau)) if not np.isnan(tau) else tau
tau_int = int(np.floor(tau)) if not np.isnan(tau) else tau
if np.isnan(tau) or np.isinf(tau):
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
self.sampler,
self.time_per_check,
......@@ -200,27 +201,28 @@ class Ptemcee(MCMCSampler):
np.nan,
np.nan,
np.nan,
[np.nan],
False)
continue
# Calculate the effective number of samples available
self.nburn = int(self.burn_in_nact * tau)
self.thin = int(np.max([1, self.thin_by_nact * tau]))
self.nburn = int(self.burn_in_nact * tau_int)
self.thin = int(np.max([1, self.thin_by_nact * tau_int]))
samples_per_check = sampler.nwalkers / self.thin
self.nsamples_effective = int(sampler.nwalkers * (sampler.time - self.nburn) / self.thin)
# Calculate convergence boolean
converged = self.nsamples < self.nsamples_effective
# Calculate fractional change in tau from previous iteration
check_taus = np.array(self.tau_list[-tau * self.autocorr_tol:])
# Calculate fractional change in tau from previous iterations
check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau:])
if not np.any(np.isnan(check_taus)):
frac = (tau - check_taus) / tau
tau_usable = np.all(frac < self.frac_threshold)
else:
tau_usable = False
if sampler.time < tau * self.autocorr_tol or tau < self.min_tau:
if sampler.time < tau_int * self.autocorr_tol or tau_int < self.min_tau:
tau_usable = False
# Print an update on the progress
......@@ -230,7 +232,8 @@ class Ptemcee(MCMCSampler):
self.nsamples,
self.nsamples_effective,
samples_per_check,
tau,
tau_int,
check_taus,
tau_usable,
)
......@@ -291,12 +294,13 @@ class Ptemcee(MCMCSampler):
# Generate the walkers plot diagnostic
plot_walkers(
self.sampler.chain[0, :, : self.sampler.time, :],
self.nburn, self.search_parameter_keys, self.outdir, self.label
self.nburn, self.thin, self.search_parameter_keys, self.outdir,
self.label
)
# Generate the tau plot diagnostic
plot_tau(self.tau_list_n, self.tau_list, self.outdir, self.label,
self.autocorr_tol)
self.autocorr_tau)
def print_progress(
......@@ -305,7 +309,8 @@ def print_progress(
nsamples,
nsamples_effective,
samples_per_check,
tau,
tau_int,
tau_list,
tau_usable,
):
# Setup acceptance string
......@@ -329,7 +334,7 @@ def print_progress(
sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
tau_str = str(tau)
tau_str = "{}:{:0.1f}->{:0.1f}".format(tau_int, np.min(tau_list), np.max(tau_list))
if tau_usable:
tau_str = "={}".format(tau_str)
else:
......@@ -393,37 +398,41 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
logger.info("Finished writing checkpoint")
def plot_walkers(walkers, nburn, parameter_labels, outdir, label):
def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
""" Method to plot the trace of the walkers in an ensemble MCMC plot """
nwalkers, nsteps, ndim = walkers.shape
idxs = np.arange(nsteps)
fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3 * ndim))
scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05)
for i, ax in enumerate(axes):
fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
scatter_kwargs = dict(lw=0, marker="o", markersize=1)
# Plot the burn-in
for i, (ax, axh) in enumerate(axes):
ax.plot(
idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="r", **scatter_kwargs
idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="C1", **scatter_kwargs
)
ax.set_ylabel(parameter_labels[i])
for i, ax in enumerate(axes):
ax.plot(idxs[nburn:], walkers[:, nburn:, i].T, color="k", **scatter_kwargs)
# Plot the thinned posterior samples
for i, (ax, axh) in enumerate(axes):
ax.plot(idxs[nburn::thin], walkers[:, nburn::thin, i].T, color="C0", **scatter_kwargs)
axh.hist(walkers[:, nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
axh.set_xlabel(parameter_labels[i])
ax.set_ylabel(parameter_labels[i])
fig.tight_layout()
filename = "{}/{}_traceplot.png".format(outdir, label)
filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
fig.savefig(filename)
plt.close(fig)
def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tol):
def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
fig, ax = plt.subplots()
ax.plot(tau_list_n, tau_list, "-")
check_tau_idx = -int(tau_list[-1] * autocorr_tol)
ax.plot(tau_list_n, tau_list, "-", color='C1')
check_tau_idx = -int(tau_list[-1] * autocorr_tau)
check_taus = tau_list[check_tau_idx:]
check_taus_n = tau_list_n[check_tau_idx:]
ax.plot(check_taus_n, check_taus, "--")
ax.plot(check_taus_n, check_taus, "-", color='C0')
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
fig.savefig("{}/{}_tau.png".format(outdir, label))
fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
plt.close(fig)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment