Skip to content
Snippets Groups Projects

Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
Compare and Show latest version
1 file
+ 100
69
Compare changes
  • Side-by-side
  • Inline
+ 100
69
@@ -47,7 +47,7 @@ class Ptemcee(MCMCSampler):
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
autocorr_c=5, safety=1, ncheck=50, nfrac=5, frac_threshold=0.01,
autocorr_c=5, safety=1, frac_threshold=0.01,
autocorr_tol=50, min_tau=1, check_point_deltaT=600,
threads=1, **kwargs):
super(Ptemcee, self).__init__(
@@ -60,12 +60,10 @@ class Ptemcee(MCMCSampler):
signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
self.resume = resume
self.ncheck = ncheck
self.autocorr_c = autocorr_c
self.safety = safety
self.burn_in_nact = burn_in_nact
self.thin_by_nact = thin_by_nact
self.nfrac = nfrac
self.frac_threshold = frac_threshold
self.nsamples = nsamples
self.autocorr_tol = autocorr_tol
@@ -94,55 +92,75 @@ class Ptemcee(MCMCSampler):
for _ in range(self.sampler_init_kwargs["nwalkers"])]
for _ in range(self.kwargs['ntemps'])]
def get_sampler(self):
def setup_sampler(self):
import ptemcee
if os.path.isfile(self.resume_file) and self.resume is True:
logger.info("Resume data {} found".format(self.resume_file))
with open(self.resume_file, "rb") as file:
data = dill.load(file)
self.sampler = data["sampler"]
self.sampler.pool = None
self.tau_list = data["tau_list"]
self.tau_list_n = data["tau_list_n"]
self.time_per_check = data["time_per_check"]
self.sampler.pool = self.pool
self.sampler.threads = self.threads
pos0 = None
logger.info("Resuming from previous run with time={}".format(self.sampler.time))
else:
# Initialize the PTSampler
self.sampler = ptemcee.Sampler(
dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
# Overwrite the _likeprior to improve performance with threads > 1
self.sampler._likeprior = LikePriorEvaluator(
self.search_parameter_keys, use_ratio=self.use_ratio)
# Set up empty lists
self.tau_list = []
self.tau_list_n = []
self.time_per_check = []
# Initialize the walker postitions
pos0 = self.get_pos0_from_prior()
return self.sampler, pos0
def run_sampler(self):
import schwimmbad
def setup_pool(self):
if self.threads > 1:
import schwimmbad
logger.info("Creating MultiPool with {} processes".format(self.threads))
with schwimmbad.MultiPool(self.threads, initializer=init,
initargs=(self.likelihood, self.priors)) as pool:
self.pool = pool
return self.run_sampler_internal()
self.pool = schwimmbad.MultiPool(
self.threads,
initializer=init,
initargs=(self.likelihood, self.priors))
else:
self.pool = None
return self.run_sampler_internal()
def run_sampler(self):
self.setup_pool()
out = self.run_sampler_internal()
if self.pool:
self.pool.close()
return out
def run_sampler_internal(self):
import emcee
sampler, pos0 = self.get_sampler()
self.time_per_check = []
self.tau_list = []
self.tau_list_n = []
sampler, pos0 = self.setup_sampler()
t0 = datetime.datetime.now()
logger.info("Starting to sample")
for (pos0, lnprob, lnlike) in sampler.sample(
pos0, **self.sampler_function_kwargs):
# Only check convergence every ncheck steps
if sampler.time % self.ncheck:
continue
# Calculate time per iteration
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
t0 = datetime.datetime.now()
# Compute ACT tau for 0-temperature chains
samples = sampler.chain[0, :, : sampler.time, :]
@@ -163,49 +181,53 @@ class Ptemcee(MCMCSampler):
# Apply multiplicitive safety factor
tau = self.safety * np.mean(taus)
if np.isnan(tau) or np.isinf(tau):
print("{} | Unable to use tau={}".format(sampler.time, tau), flush=True)
continue
# Convert to an integer and store for plotting
tau = int(tau)
# Store for convergence checking and plotting
self.tau_list.append(tau)
self.tau_list_n.append(sampler.time)
# Convert to an integer
tau = int(np.floor(tau)) if not np.isnan(tau) else tau
if np.isnan(tau) or np.isinf(tau):
print_progress(
self.sampler,
self.time_per_check,
self.nsamples,
np.nan,
np.nan,
np.nan,
False)
continue
# Calculate the effective number of samples available
self.nburn = int(self.burn_in_nact * tau)
self.thin = int(np.max([1, self.thin_by_nact * tau]))
samples_per_check = self.ncheck * sampler.nwalkers / self.thin
samples_per_check = sampler.nwalkers / self.thin
self.nsamples_effective = int(sampler.nwalkers * (sampler.time - self.nburn) / self.thin)
# Calculate fractional change in tau from previous iteration
frac = (tau - np.array(self.tau_list)[-self.nfrac - 1: -1]) / tau
passes = frac < self.frac_threshold
# Calculate convergence boolean
converged = self.nsamples < self.nsamples_effective
converged &= np.all(passes)
if sampler.time < tau * self.autocorr_tol or tau < self.min_tau:
converged = False
tau_pass = False
# Calculate fractional change in tau from previous iteration
check_taus = np.array(self.tau_list[-tau * self.autocorr_tol:])
if not np.any(np.isnan(check_taus)):
frac = (tau - check_taus) / tau
tau_usable = np.all(frac < self.frac_threshold)
else:
tau_pass = True
tau_usable = False
# Calculate time per iteration
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
t0 = datetime.datetime.now()
if sampler.time < tau * self.autocorr_tol or tau < self.min_tau:
tau_usable = False
# Print an update on the progress
print_progress(
self.sampler,
self.ncheck,
self.time_per_check,
self.nsamples,
self.nsamples_effective,
samples_per_check,
passes,
tau,
tau_pass,
tau_usable,
)
if converged:
@@ -219,7 +241,7 @@ class Ptemcee(MCMCSampler):
last_checkpoint_s = np.sum(self.time_per_check)
if last_checkpoint_s > self.check_point_deltaT:
self.write_current_state()
self.write_current_state(plot=self.plot)
# Check if we reached the end without converging
if sampler.time == self.sampler_function_kwargs["iterations"]:
@@ -230,7 +252,7 @@ class Ptemcee(MCMCSampler):
)
# Run a final checkpoint to update the plots and samples
self.write_current_state()
self.write_current_state(plot=self.plot)
# Get 0-likelihood samples and store in the result
samples = sampler.chain[0, :, :, :] # nwalkers, nsteps, ndim
@@ -257,26 +279,37 @@ class Ptemcee(MCMCSampler):
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
self.write_current_state()
if getattr(self, 'pool', None):
self.write_current_state(plot=False)
logger.warning("Closing pool")
self.pool.close()
sys.exit(77)
def write_current_state(self):
def write_current_state(self, plot=True):
checkpoint(self.outdir, self.label, self.nsamples_effective,
self.sampler, self.nburn, self.thin,
self.search_parameter_keys, self.resume_file, self.tau_list,
self.tau_list_n)
self.tau_list_n, self.time_per_check)
if plot:
# Generate the walkers plot diagnostic
plot_walkers(
self.sampler.chain[0, :, : self.sampler.time, :],
self.nburn, self.search_parameter_keys, self.outdir, self.label
)
# Generate the tau plot diagnostic
plot_tau(self.tau_list_n, self.tau_list, self.outdir, self.label)
def print_progress(
sampler,
ncheck,
time_per_check,
nsamples,
nsamples_effective,
samples_per_check,
passes,
tau,
tau_pass,
tau_usable,
):
# Setup acceptance string
acceptance = sampler.acceptance_fraction[0, :]
@@ -290,30 +323,31 @@ def print_progress(
ave_time_per_check = np.mean(time_per_check[-3:])
time_left = (
(nsamples - nsamples_effective)
* ave_time_per_check
/ samples_per_check
(nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
)
if time_left > 0:
time_left = str(datetime.timedelta(seconds=int(time_left)))
else:
time_left = "waiting on convergence"
convergence = "".join([["F", "T"][i] for i in passes])
sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
tau_str = str(tau)
if tau_pass is False:
tau_str = tau_str + "(F)"
if tau_usable is False:
tau_str = "!{}".format(tau_str)
else:
tau_str = "={}".format(tau_str)
evals_per_check = sampler.nwalkers * sampler.ntemps * ncheck
evals_per_check = sampler.nwalkers * sampler.ntemps
ncalls = "{:1.1e}".format(sampler.time * sampler.nwalkers * sampler.ntemps)
eval_timing = "{:1.1f}ms/evl".format(1e3 * ave_time_per_check / evals_per_check)
samp_timing = "{:1.2f}ms/smp".format(1e3 * ave_time_per_check / samples_per_check)
print(
"{}| nc:{}| a0:{}| swp:{}| n:{}<{}| tau:{}| {}| {}| {}".format(
"{}| {} | nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
sampler.time,
str(sampling_time).split(".")[0],
ncalls,
acceptance_str,
tswap_acceptance_str,
@@ -322,14 +356,14 @@ def print_progress(
tau_str,
eval_timing,
samp_timing,
convergence,
),
flush=True,
)
def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
search_parameter_keys, resume_file, tau_list, tau_list_n):
search_parameter_keys, resume_file, tau_list, tau_list_n,
time_per_check):
logger.info("Writing checkpoint and diagnostics")
ndim = sampler.dim
@@ -349,20 +383,17 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
sampler_copy._logposterior = sampler._logposterior[:, :, : sampler.time]
sampler_copy._loglikelihood = sampler._loglikelihood[:, :, : sampler.time]
sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
data = dict(sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n)
data = dict(
sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n,
time_per_check=time_per_check)
with open(resume_file, "wb") as file:
dill.dump(data, file, protocol=4)
del data, sampler_copy
# Generate the walkers plot diagnostic
plot_walkers(
sampler.chain[0, :, : sampler.time, :], nburn, search_parameter_keys, outdir, label
)
# Generate the tau plot diagnostic
plot_tau(tau_list_n, tau_list, outdir, label)
logger.info("Finished writing checkpoint and diagnostics")
logger.info("Finished writing checkpoint")
def plot_walkers(walkers, nburn, parameter_labels, outdir, label):
@@ -388,7 +419,7 @@ def plot_walkers(walkers, nburn, parameter_labels, outdir, label):
def plot_tau(tau_list_n, tau_list, outdir, label):
fig, ax = plt.subplots()
ax.plot(tau_list_n, tau_list, "-x")
ax.plot(tau_list_n, tau_list, "-")
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
fig.savefig("{}/{}_tau.png".format(outdir, label))
Loading