Skip to content
Snippets Groups Projects

Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
Compare and Show latest version
1 file
+ 30
17
Compare changes
  • Side-by-side
  • Inline
@@ -352,27 +352,32 @@ class Ptemcee(MCMCSampler):
samples = sampler.chain[0, :, : sampler.time, :]
taus = []
for ii in range(sampler.nwalkers):
tau_ii = []
for jj, key in enumerate(self.search_parameter_keys):
if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key:
continue
try:
taus.append(
tau_ii.append(
emcee.autocorr.integrated_time(
samples[ii, :, jj], c=self.autocorr_c, tol=0
)[0]
)
except emcee.autocorr.AutocorrError:
taus.append(np.inf)
taus.append(tau_ii)
tau = np.max(np.mean(taus, axis=0))
# Apply multiplicitive safety factor
tau = self.safety * np.mean(taus)
tau = self.safety * tau
# Store for convergence checking and plotting
self.tau_list.append(tau)
self.tau_list.append(np.mean(taus, axis=0))
self.tau_list_n.append(sampler.time)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
self.tau_int = tau_int
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
@@ -382,7 +387,7 @@ class Ptemcee(MCMCSampler):
np.nan,
np.nan,
np.nan,
[np.nan],
np.nan,
False,
)
continue
@@ -399,11 +404,14 @@ class Ptemcee(MCMCSampler):
converged = self.nsamples < self.nsamples_effective
# Calculate fractional change in tau from previous iterations
check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau :])
check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau:])
taus_per_parameter = np.array(taus)[-1, :]
if not np.any(np.isnan(check_taus)):
frac = (tau - check_taus) / tau
frac = (taus_per_parameter - check_taus) / taus_per_parameter
max_frac = np.max(frac)
tau_usable = np.all(frac < self.frac_threshold)
else:
max_frac = np.nan
tau_usable = False
if sampler.time < tau_int * self.autocorr_tol or tau_int < self.min_tau:
@@ -417,7 +425,7 @@ class Ptemcee(MCMCSampler):
self.nsamples_effective,
samples_per_check,
tau_int,
check_taus,
max_frac,
tau_usable,
)
@@ -465,9 +473,10 @@ class Ptemcee(MCMCSampler):
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
if getattr(self, "pool", None):
if getattr(self, "pool", None) or self.threads == 1:
self.write_current_state(plot=False)
logger.warning("Closing pool")
if getattr(self, "pool", None):
logger.info("Closing pool")
self.pool.close()
sys.exit(self.exit_code)
@@ -501,8 +510,10 @@ class Ptemcee(MCMCSampler):
plot_tau(
self.tau_list_n,
self.tau_list,
self.search_parameter_keys,
self.outdir,
self.label,
self.tau_int,
self.autocorr_tau,
)
@@ -514,7 +525,7 @@ def print_progress(
nsamples_effective,
samples_per_check,
tau_int,
tau_list,
max_frac,
tau_usable,
):
# Setup acceptance string
@@ -536,7 +547,10 @@ def print_progress(
sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
tau_str = "{}:{:0.1f}->{:0.1f}".format(tau_int, np.min(tau_list), np.max(tau_list))
if max_frac >= 0:
tau_str = "{}(+{:0.1f})".format(tau_int, max_frac)
else:
tau_str = "{}({:0.1f})".format(tau_int, max_frac)
if tau_usable:
tau_str = "={}".format(tau_str)
else:
@@ -644,15 +658,14 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
plt.close(fig)
def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
def plot_tau(tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau):
fig, ax = plt.subplots()
ax.plot(tau_list_n, tau_list, "-", color="C1")
check_tau_idx = -int(tau_list[-1] * autocorr_tau)
check_taus = tau_list[check_tau_idx:]
check_taus_n = tau_list_n[check_tau_idx:]
ax.plot(check_taus_n, check_taus, "-", color="C0")
for i, key in enumerate(search_parameter_keys):
ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
ax.axvline(tau_list_n[-1] - tau * autocorr_tau)
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
ax.legend()
fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
plt.close(fig)
Loading