
Improve ptemcee

Merged: Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
1 file changed: +204 -110
@@ -13,7 +13,7 @@ import pandas as pd
import matplotlib.pyplot as plt
from ..utils import logger
from .base_sampler import MCMCSampler
from .base_sampler import SamplerError, MCMCSampler
class Ptemcee(MCMCSampler):
@@ -41,15 +41,15 @@ class Ptemcee(MCMCSampler):
ntemps=20, nwalkers=200, Tmax=None, betas=None,
a=2.0, loglargs=[], logpargs=[], loglkwargs={},
logpkwargs={}, adaptation_lag=10000, adaptation_time=100, random=None,
iterations=1000, thin=1, storechain=True, adapt=False,
swap_ratios=False)
adapt=False, swap_ratios=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
use_ratio=False, check_point_plot=True, skip_import_verification=False,
resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
autocorr_c=5, safety=1, ncheck=50, nfrac=5, frac_threshold=0.01,
autocorr_tol=50, min_tau=1, check_point_deltaT=600,
threads=1, **kwargs):
autocorr_c=5, safety=1, frac_threshold=0.01,
autocorr_tol=50, autocorr_tau=5, min_tau=1, check_point_deltaT=600,
threads=1, exit_code=77, plot=False, store_walkers=False,
ignore_keys_for_tau="recalib", pos0="prior", **kwargs):
super(Ptemcee, self).__init__(
likelihood=likelihood, priors=priors, outdir=outdir,
label=label, use_ratio=use_ratio, plot=plot,
@@ -60,25 +60,29 @@ class Ptemcee(MCMCSampler):
signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
self.resume = resume
self.ncheck = ncheck
self.autocorr_c = autocorr_c
self.safety = safety
self.burn_in_nact = burn_in_nact
self.thin_by_nact = thin_by_nact
self.nfrac = nfrac
self.frac_threshold = frac_threshold
self.nsamples = nsamples
self.autocorr_tol = autocorr_tol
self.autocorr_tau = autocorr_tau
self.min_tau = min_tau
self.check_point_deltaT = check_point_deltaT
self.threads = threads
self.store_walkers = store_walkers
self.ignore_keys_for_tau = ignore_keys_for_tau
self.pos0 = pos0
self.check_point_plot = check_point_plot
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)
self.exit_code = exit_code
@property
def sampler_function_kwargs(self):
keys = ['iterations', 'thin', 'storechain', 'adapt', 'swap_ratios']
keys = ['adapt', 'swap_ratios']
return {key: self.kwargs[key] for key in keys}
@property
@@ -94,30 +98,107 @@ class Ptemcee(MCMCSampler):
for _ in range(self.sampler_init_kwargs["nwalkers"])]
for _ in range(self.kwargs['ntemps'])]
def get_pos0_from_minimize(self, minimize_list=None):
logger.info("Attempting to set pos0 from minimize")
from scipy.optimize import minimize
if minimize_list is None:
minimize_list = self.search_parameter_keys
pos0 = np.zeros((self.kwargs["ntemps"], self.kwargs["nwalkers"], self.ndim))
else:
pos0 = np.array(self.get_pos0_from_prior())
likelihood_copy = copy.copy(self.likelihood)
def neg_log_like(params):
likelihood_copy.parameters.update(
{key: val for key, val in zip(minimize_list, params)})
try:
return -likelihood_copy.log_likelihood()
except RuntimeError:
return +np.inf
bounds = [(self.priors[key].minimum, self.priors[key].maximum)
for key in minimize_list]
trials = 0
success = []
while True:
trials += 1
draw = self.priors.sample()
likelihood_copy.parameters.update(draw)
x0 = [draw[key] for key in minimize_list]
res = minimize(
neg_log_like, x0, bounds=bounds, method='L-BFGS-B')
if res.success:
success.append(res.x)
if trials > 100:
raise SamplerError("Unable to set pos0 from minimize")
if len(success) >= 3:
break
success = np.array(success)
for i, key in enumerate(minimize_list):
pos0_min = np.min(success[:, i])
pos0_max = np.max(success[:, i])
logger.info("Initialize {} walkers from {}->{}"
.format(key, pos0_min, pos0_max))
j = self.search_parameter_keys.index(key)
pos0[:, :, j] = np.random.uniform(
pos0_min, pos0_max,
size=(self.kwargs["ntemps"], self.kwargs["nwalkers"]))
return pos0
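
The scheme here: run L-BFGS-B from a handful of random prior draws, keep the successful minima, and scatter every walker at every temperature uniformly inside the box spanned by those minima. A minimal standalone sketch of that idea, with a toy Gaussian negative log-likelihood and fixed bounds standing in for the bilby likelihood and priors (all names below are illustrative):

import numpy as np
from scipy.optimize import minimize

ntemps, nwalkers, ndim = 4, 32, 2
bounds = [(-10, 10)] * ndim  # stand-in for the prior minimum/maximum

def neg_log_like(params):
    # toy Gaussian likelihood centred on (1, -2)
    return 0.5 * np.sum((params - np.array([1.0, -2.0])) ** 2)

success = []
for _ in range(100):  # cap the number of attempts
    x0 = np.random.uniform(-10, 10, ndim)  # stand-in for a prior draw
    res = minimize(neg_log_like, x0, bounds=bounds, method="L-BFGS-B")
    if res.success:
        success.append(res.x)
    if len(success) >= 3:
        break

success = np.array(success)
pos0 = np.zeros((ntemps, nwalkers, ndim))
for j in range(ndim):
    lo, hi = success[:, j].min(), success[:, j].max()
    pos0[:, :, j] = np.random.uniform(lo, hi, size=(ntemps, nwalkers))
print(pos0.shape)  # (ntemps, nwalkers, ndim)

Note that only the spread of the minima is used, so for a multimodal likelihood the box may straddle several modes and the walkers sort that out during burn-in.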
def setup_sampler(self):
import ptemcee
if os.path.isfile(self.resume_file) and self.resume is True:
logger.info("Resume data {} found".format(self.resume_file))
with open(self.resume_file, "rb") as file:
data = dill.load(file)
self.sampler = data["sampler"]
self.sampler.pool = self.pool
self.sampler.threads = self.threads
self.tau_list = data["tau_list"]
self.tau_list_n = data["tau_list_n"]
self.time_per_check = data["time_per_check"]
self.sampler.pool = self.pool
self.sampler.threads = self.threads
pos0 = None
logger.info("Resuming from previous run with time={}".format(self.sampler.time))
else:
self.sampler = ptemcee.Sampler(
dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
self.sampler._likeprior = LikePriorEvaluator(
self.search_parameter_keys, use_ratio=self.use_ratio)
pos0 = self.get_pos0_from_prior()
# Initialize the PTSampler
if self.threads == 1:
self.sampler = ptemcee.Sampler(
dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
**self.sampler_init_kwargs)
else:
self.sampler = ptemcee.Sampler(
dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
self.sampler._likeprior = LikePriorEvaluator(
self.search_parameter_keys, use_ratio=self.use_ratio)
# Set up empty lists
self.tau_list = []
self.tau_list_n = []
self.time_per_check = []
# Initialize the walker positions
pos0 = self.get_pos0()
return self.sampler, pos0
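
For reference, the single-threaded branch above hands the bound log-likelihood and log-prior directly to ptemcee. Below is a minimal standalone sketch of the ptemcee calls being wrapped, on a toy Gaussian problem; the keyword names match the defaults listed in this diff, everything else is illustrative:

import numpy as np
import ptemcee

ntemps, nwalkers, ndim = 4, 32, 2

def logl(x):
    # toy unit-Gaussian log-likelihood
    return -0.5 * np.sum(x ** 2)

def logp(x):
    # flat prior on [-10, 10]^ndim
    return 0.0 if np.all(np.abs(x) < 10) else -np.inf

sampler = ptemcee.Sampler(
    dim=ndim, logl=logl, logp=logp, ntemps=ntemps, nwalkers=nwalkers)
pos0 = np.random.uniform(-1, 1, (ntemps, nwalkers, ndim))

# sample() yields (positions, log-posterior, log-likelihood) per iteration,
# the same tuple unpacked in run_sampler_internal below
for pos0, lnprob, lnlike in sampler.sample(pos0, iterations=100, adapt=False):
    pass
print(sampler.chain.shape)  # (ntemps, nwalkers, nsteps, ndim)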
def get_pos0(self):
if isinstance(self.pos0, str) and self.pos0.lower() == "prior":
return self.get_pos0_from_prior()
elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize":
return self.get_pos0_from_minimize()
elif isinstance(self.pos0, list):
return self.get_pos0_from_minimize(minimize_list=self.pos0)
else:
raise SamplerError("pos0={} not implemented".format(self.pos0))
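
Usage-wise, the new pos0 keyword accepts "prior" (the default), "minimize", or a list of parameter names to minimise over, and is passed through run_sampler like any other sampler argument. A sketch on a self-contained toy regression problem (the settings and values are illustrative, not recommendations):

import numpy as np
import bilby

# self-contained toy linear regression
x = np.linspace(0, 1, 100)
y = 2.0 * x + 1.0 + np.random.normal(0, 0.1, len(x))

def model(x, m, c):
    return m * x + c

likelihood = bilby.core.likelihood.GaussianLikelihood(x, y, model, sigma=0.1)
priors = dict(
    m=bilby.core.prior.Uniform(0, 5, "m"),
    c=bilby.core.prior.Uniform(-5, 5, "c"))

result = bilby.run_sampler(
    likelihood=likelihood, priors=priors, sampler="ptemcee",
    ntemps=5, nwalkers=50, nsamples=1000,
    pos0="minimize",  # or "prior" (default), or a list such as ["m"]
    outdir="outdir", label="ptemcee_pos0_demo")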
def setup_pool(self):
if self.threads > 1:
import schwimmbad
@@ -139,24 +220,23 @@ class Ptemcee(MCMCSampler):
def run_sampler_internal(self):
import emcee
sampler, pos0 = self.setup_sampler()
self.time_per_check = []
self.tau_list = []
self.tau_list_n = []
t0 = datetime.datetime.now()
logger.info("Starting to sample")
for (pos0, lnprob, lnlike) in sampler.sample(
pos0, **self.sampler_function_kwargs):
# Only check convergence every ncheck steps
if sampler.time % self.ncheck:
continue
while True:
for (pos0, _, _) in sampler.sample(pos0, **self.sampler_function_kwargs):
pass
# Calculate time per iteration
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
t0 = datetime.datetime.now()
# Compute ACT tau for 0-temperature chains
samples = sampler.chain[0, :, : sampler.time, :]
taus = []
for ii in range(sampler.nwalkers):
for jj, key in enumerate(self.search_parameter_keys):
if "recalib" in key:
if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key:
continue
try:
taus.append(
@@ -170,52 +250,58 @@ class Ptemcee(MCMCSampler):
# Apply multiplicative safety factor
tau = self.safety * np.mean(taus)
if np.isnan(tau) or np.isinf(tau):
print("{} | Unable to use tau={}".format(sampler.time, tau), flush=True)
continue
# Convert to an integer and store for plotting
tau = int(tau)
# Store for convergence checking and plotting
self.tau_list.append(tau)
self.tau_list_n.append(sampler.time)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
self.sampler,
self.time_per_check,
self.nsamples,
np.nan,
np.nan,
np.nan,
[np.nan],
False)
continue
# Calculate the effective number of samples available
self.nburn = int(self.burn_in_nact * tau)
self.thin = int(np.max([1, self.thin_by_nact * tau]))
samples_per_check = self.ncheck * sampler.nwalkers / self.thin
self.nburn = int(self.burn_in_nact * tau_int)
self.thin = int(np.max([1, self.thin_by_nact * tau_int]))
samples_per_check = sampler.nwalkers / self.thin
self.nsamples_effective = int(sampler.nwalkers * (sampler.time - self.nburn) / self.thin)
# Calculate fractional change in tau from previous iteration
frac = (tau - np.array(self.tau_list)[-self.nfrac - 1: -1]) / tau
passes = frac < self.frac_threshold
# Calculate convergence boolean
converged = self.nsamples < self.nsamples_effective
converged &= np.all(passes)
if sampler.time < tau * self.autocorr_tol or tau < self.min_tau:
converged = False
tau_pass = False
# Calculate fractional change in tau from previous iterations
check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau:])
if not np.any(np.isnan(check_taus)):
frac = (tau - check_taus) / tau
tau_usable = np.all(frac < self.frac_threshold)
else:
tau_pass = True
tau_usable = False
# Calculate time per iteration
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
t0 = datetime.datetime.now()
if sampler.time < tau_int * self.autocorr_tol or tau_int < self.min_tau:
tau_usable = False
# Print an update on the progress
print_progress(
self.sampler,
self.ncheck,
self.time_per_check,
self.nsamples,
self.nsamples_effective,
samples_per_check,
passes,
tau,
tau_pass,
tau_int,
check_taus,
tau_usable,
)
if converged:
if converged and tau_usable:
logger.info("Finished sampling")
break
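
The stopping rule above combines two requirements: enough effective samples once the chain is burnt in (burn_in_nact autocorrelation times) and thinned (thin_by_nact), and a tau estimate that has stopped drifting over roughly the last autocorr_tau * tau checks. A condensed sketch of that logic as a standalone function; the default values are taken from this diff, while the emcee.autocorr.integrated_time call is an assumption about the exact form of the truncated taus.append(...) line:

import numpy as np
import emcee

def check_convergence(chain, tau_list, nsamples=5000, safety=1,
                      burn_in_nact=50, thin_by_nact=1, frac_threshold=0.01,
                      autocorr_tol=50, autocorr_tau=5, autocorr_c=5, min_tau=1):
    """chain: zero-temperature samples, shape (nwalkers, nsteps, ndim)."""
    nwalkers, nsteps, ndim = chain.shape

    # mean ACT over walkers and parameters, inflated by the safety factor
    taus = [emcee.autocorr.integrated_time(chain[ii, :, jj], c=autocorr_c, tol=0)[0]
            for ii in range(nwalkers) for jj in range(ndim)]
    tau = safety * np.mean(taus)
    if np.isnan(tau) or np.isinf(tau):
        return False, tau
    tau_int = int(np.ceil(tau))

    # effective samples after burn-in and thinning
    nburn = int(burn_in_nact * tau_int)
    thin = max(1, int(thin_by_nact * tau_int))
    nsamples_effective = int(nwalkers * (nsteps - nburn) / thin)

    # tau must have stabilised over the recent checks, and the chain must be
    # at least autocorr_tol ACTs long
    check_taus = np.array(tau_list[-tau_int * autocorr_tau:] + [tau])
    tau_usable = np.all((tau - check_taus) / tau < frac_threshold)
    if nsteps < tau_int * autocorr_tol or tau_int < min_tau:
        tau_usable = False

    return (nsamples < nsamples_effective) and tau_usable, tau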
@@ -226,22 +312,13 @@ class Ptemcee(MCMCSampler):
last_checkpoint_s = np.sum(self.time_per_check)
if last_checkpoint_s > self.check_point_deltaT:
self.write_current_state()
# Check if we reached the end without converging
if sampler.time == self.sampler_function_kwargs["iterations"]:
raise ValueError(
"Failed to reach convergence by iterations={}".format(
self.sampler_function_kwargs["iterations"]
)
)
self.write_current_state(plot=self.check_point_plot)
# Run a final checkpoint to update the plots and samples
self.write_current_state()
self.write_current_state(plot=self.check_point_plot)
# Get 0-likelihood samples and store in the result
samples = sampler.chain[0, :, :, :] # nwalkers, nsteps, ndim
self.result.walkers = samples[:, :sampler.time:, :]
self.result.samples = (
samples[:, self.nburn: sampler.time:self.thin, :].reshape((-1, self.ndim)))
loglikelihood = sampler.loglikelihood[
@@ -249,7 +326,8 @@ class Ptemcee(MCMCSampler):
] # nwalkers, nsteps
self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
self.result.walkers = self.sampler.chain
if self.store_walkers:
self.result.walkers = self.sampler.chain
self.result.nburn = self.nburn
log_evidence, log_evidence_err = compute_evidence(
@@ -264,28 +342,40 @@ class Ptemcee(MCMCSampler):
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
if self.pool:
if getattr(self, 'pool', None):
self.write_current_state(plot=False)
logger.warning("Closing pool")
self.pool.close()
self.write_current_state()
sys.exit(77)
sys.exit(self.exit_code)
def write_current_state(self):
def write_current_state(self, plot=True):
checkpoint(self.outdir, self.label, self.nsamples_effective,
self.sampler, self.nburn, self.thin,
self.search_parameter_keys, self.resume_file, self.tau_list,
self.tau_list_n)
self.tau_list_n, self.time_per_check)
if plot and not np.isnan(self.nburn):
# Generate the walkers plot diagnostic
plot_walkers(
self.sampler.chain[0, :, : self.sampler.time, :],
self.nburn, self.thin, self.search_parameter_keys, self.outdir,
self.label
)
# Generate the tau plot diagnostic
plot_tau(self.tau_list_n, self.tau_list, self.outdir, self.label,
self.autocorr_tau)
def print_progress(
sampler,
ncheck,
time_per_check,
nsamples,
nsamples_effective,
samples_per_check,
passes,
tau,
tau_pass,
tau_int,
tau_list,
tau_usable,
):
# Set up the acceptance string
acceptance = sampler.acceptance_fraction[0, :]
@@ -299,30 +389,31 @@ def print_progress(
ave_time_per_check = np.mean(time_per_check[-3:])
time_left = (
(nsamples - nsamples_effective)
* ave_time_per_check
/ samples_per_check
(nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
)
if time_left > 0:
time_left = str(datetime.timedelta(seconds=int(time_left)))
else:
time_left = "waiting on convergence"
convergence = "".join([["F", "T"][i] for i in passes])
sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
tau_str = str(tau)
if tau_pass is False:
tau_str = tau_str + "(F)"
tau_str = "{}:{:0.1f}->{:0.1f}".format(tau_int, np.min(tau_list), np.max(tau_list))
if tau_usable:
tau_str = "={}".format(tau_str)
else:
tau_str = "!{}".format(tau_str)
evals_per_check = sampler.nwalkers * sampler.ntemps * ncheck
evals_per_check = sampler.nwalkers * sampler.ntemps
ncalls = "{:1.1e}".format(sampler.time * sampler.nwalkers * sampler.ntemps)
eval_timing = "{:1.1f}ms/evl".format(1e3 * ave_time_per_check / evals_per_check)
samp_timing = "{:1.2f}ms/smp".format(1e3 * ave_time_per_check / samples_per_check)
eval_timing = "{:1.1f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)
print(
"{}| nc:{}| a0:{}| swp:{}| n:{}<{}| tau:{}| {}| {}| {}".format(
"{}| {} | nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
sampler.time,
str(sampling_time).split(".")[0],
ncalls,
acceptance_str,
tswap_acceptance_str,
@@ -331,14 +422,14 @@ def print_progress(
tau_str,
eval_timing,
samp_timing,
convergence,
),
flush=True,
)
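
The remaining-time estimate in print_progress is simple arithmetic: each check advances the chain by one iteration, which contributes nwalkers / thin thinned posterior samples, so the shortfall in effective samples multiplied by the recent per-check wall time gives an ETA. A worked example with illustrative numbers:

import datetime
import numpy as np

nwalkers, thin = 200, 20
nsamples, nsamples_effective = 5000, 3200
time_per_check = [4.1, 4.0, 4.2]  # seconds taken by the last few iterations

samples_per_check = nwalkers / thin  # 10 thinned samples per iteration
ave_time_per_check = np.mean(time_per_check[-3:])
time_left = (nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
print(datetime.timedelta(seconds=int(time_left)))  # 0:12:18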
def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
search_parameter_keys, resume_file, tau_list, tau_list_n):
search_parameter_keys, resume_file, tau_list, tau_list_n,
time_per_check):
logger.info("Writing checkpoint and diagnostics")
ndim = sampler.dim
@@ -358,49 +449,52 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
sampler_copy._logposterior = sampler._logposterior[:, :, : sampler.time]
sampler_copy._loglikelihood = sampler._loglikelihood[:, :, : sampler.time]
sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
data = dict(sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n)
data = dict(
sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n,
time_per_check=time_per_check)
with open(resume_file, "wb") as file:
dill.dump(data, file, protocol=4)
del data, sampler_copy
logger.info("Finished writing checkpoint")
# Generate the walkers plot diagnostic
plot_walkers(
sampler.chain[0, :, : sampler.time, :], nburn, search_parameter_keys, outdir, label
)
# Generate the tau plot diagnostic
plot_tau(tau_list_n, tau_list, outdir, label)
logger.info("Finished writing checkpoint and diagnostics")
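
Because the checkpoint now stores tau_list, tau_list_n and time_per_check alongside the trimmed sampler, the resume pickle can also be inspected offline. A small sketch, assuming dill is installed and a run with the default outdir and label has produced outdir/label_checkpoint_resume.pickle:

import dill
import numpy as np

with open("outdir/label_checkpoint_resume.pickle", "rb") as f:
    data = dill.load(f)

sampler = data["sampler"]
print("iterations completed:", sampler.time)
print("zero-temperature chain shape:", sampler.chain[0].shape)  # (nwalkers, nsteps, ndim)
print("latest tau estimate:", data["tau_list"][-1])
print("total sampling time (s):", np.sum(data["time_per_check"]))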
def plot_walkers(walkers, nburn, parameter_labels, outdir, label):
def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
""" Method to plot the trace of the walkers in an ensemble MCMC plot """
nwalkers, nsteps, ndim = walkers.shape
idxs = np.arange(nsteps)
fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3 * ndim))
scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05)
for i, ax in enumerate(axes):
fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
scatter_kwargs = dict(lw=0, marker="o", markersize=1)
# Plot the burn-in
for i, (ax, axh) in enumerate(axes):
ax.plot(
idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="r", **scatter_kwargs
idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="C1", **scatter_kwargs
)
ax.set_ylabel(parameter_labels[i])
for i, ax in enumerate(axes):
ax.plot(idxs[nburn:], walkers[:, nburn:, i].T, color="k", **scatter_kwargs)
# Plot the thinned posterior samples
for i, (ax, axh) in enumerate(axes):
ax.plot(idxs[nburn::thin], walkers[:, nburn::thin, i].T, color="C0", **scatter_kwargs)
axh.hist(walkers[:, nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
axh.set_xlabel(parameter_labels[i])
ax.set_ylabel(parameter_labels[i])
fig.tight_layout()
filename = "{}/{}_traceplot.png".format(outdir, label)
filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
fig.savefig(filename)
plt.close(fig)
def plot_tau(tau_list_n, tau_list, outdir, label):
def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
fig, ax = plt.subplots()
ax.plot(tau_list_n, tau_list, "-x")
ax.plot(tau_list_n, tau_list, "-", color='C1')
check_tau_idx = -int(tau_list[-1] * autocorr_tau)
check_taus = tau_list[check_tau_idx:]
check_taus_n = tau_list_n[check_tau_idx:]
ax.plot(check_taus_n, check_taus, "-", color='C0')
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
fig.savefig("{}/{}_tau.png".format(outdir, label))
fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
plt.close(fig)
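
Both diagnostic plots operate on plain numpy arrays, so they can be exercised in isolation. A sketch with synthetic walkers, assuming the functions are importable from bilby.core.sampler.ptemcee (the data and settings are illustrative):

import numpy as np
from bilby.core.sampler.ptemcee import plot_walkers, plot_tau

nwalkers, nsteps, ndim = 50, 2000, 2
walkers = np.random.normal(0, 1, (nwalkers, nsteps, ndim))
walkers[:, :200, :] += np.linspace(3, 0, 200)[None, :, None]  # fake burn-in drift

plot_walkers(walkers, nburn=200, thin=10,
             parameter_labels=["a", "b"], outdir=".", label="demo")

tau_list_n = list(range(50, 2001, 50))
tau_list = list(np.linspace(40, 20, len(tau_list_n)))
plot_tau(tau_list_n, tau_list, outdir=".", label="demo", autocorr_tau=5)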