
Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
1 file changed: +426 −102
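This MR rewrites the ptemcee wrapper as a standalone `MCMCSampler` (previously it derived from `Emcee`), adding autocorrelation-time based convergence checking, periodic checkpointing with resume support, multiprocessing via `schwimmbad.MultiPool`, and evidence estimation by thermodynamic integration.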
from __future__ import absolute_import, division, print_function

import os
from shutil import copyfile
import datetime
import copy
import signal
import sys
import time

import dill
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from ..utils import logger
from .base_sampler import SamplerError, MCMCSampler


class Ptemcee(MCMCSampler):
"""bilby wrapper ptemcee (https://github.com/willvousden/ptemcee)
All positional and keyword arguments (i.e., the args and kwargs) passed to
@@ -32,27 +36,46 @@ class Ptemcee(Emcee):
The number of temperatures used by ptemcee
"""
# Arguments used by ptemcee
    default_kwargs = dict(
        ntemps=20, nwalkers=200, Tmax=None, betas=None,
        a=2.0, loglargs=[], logpargs=[], loglkwargs={},
        logpkwargs={}, adaptation_lag=10000, adaptation_time=100, random=None,
        iterations=1000, thin=1, storechain=True, adapt=False,
        swap_ratios=False)
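
    # Convergence strategy (see run_sampler_internal below): every `ncheck`
    # iterations the mean autocorrelation time tau of the zero-temperature
    # chains is estimated; the first `burn_in_nact * tau` iterations are
    # discarded as burn-in and the rest thinned by `thin_by_nact * tau`.
    # Sampling stops once `nsamples` effective samples are available and tau
    # has stabilised to within `frac_threshold` over the last `nfrac` checks.
    # For example, with the defaults below and tau = 10, a 200-walker run
    # needs 50 * 10 + 5000 * 10 / 200 = 750 iterations to reach
    # nsamples = 5000.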

    def __init__(self, likelihood, priors, outdir='outdir', label='label',
                 use_ratio=False, plot=False, skip_import_verification=False,
                 resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
                 autocorr_c=5, safety=1, ncheck=50, nfrac=5, frac_threshold=0.01,
                 autocorr_tol=50, min_tau=1, check_point_deltaT=600,
                 threads=1, **kwargs):
        super(Ptemcee, self).__init__(
            likelihood=likelihood, priors=priors, outdir=outdir,
            label=label, use_ratio=use_ratio, plot=plot,
            skip_import_verification=skip_import_verification, **kwargs)

        # Checkpoint-and-exit on common termination signals so that
        # interrupted runs can be resumed
        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)

        self.resume = resume
        self.ncheck = ncheck
        self.autocorr_c = autocorr_c
        self.safety = safety
        self.burn_in_nact = burn_in_nact
        self.thin_by_nact = thin_by_nact
        self.nfrac = nfrac
        self.frac_threshold = frac_threshold
        self.nsamples = nsamples
        self.autocorr_tol = autocorr_tol
        self.min_tau = min_tau
        self.check_point_deltaT = check_point_deltaT
        self.threads = threads

        self.resume_file = "{}/{}_checkpoint_resume.pickle".format(
            self.outdir, self.label)

    @property
    def sampler_function_kwargs(self):
        keys = ['iterations', 'thin', 'storechain', 'adapt', 'swap_ratios']
        return {key: self.kwargs[key] for key in keys}

    @property
    def sampler_init_kwargs(self):
        return {key: value
                for key, value in self.kwargs.items()
                if key not in self.sampler_function_kwargs}

    @property
    def ntemps(self):
        return self.kwargs['ntemps']

    def get_pos0_from_prior(self):
        """ For ptemcee, pos0 has the shape (ntemps, nwalkers, ndim) """
        logger.info("Generating pos0 samples")
        return [[self.get_random_draw_from_prior()
                 for _ in range(self.sampler_init_kwargs["nwalkers"])]
                for _ in range(self.kwargs['ntemps'])]
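
    # NOTE: ptemcee wraps the logl and logp callables in its own internal
    # LikePriorEvaluator; get_sampler below replaces that object wholesale
    # with the bilby-aware version defined at the bottom of this file, so
    # the do_nothing_function placeholders are never actually called.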
    def get_sampler(self):
        import ptemcee

        if os.path.isfile(self.resume_file) and self.resume is True:
            logger.info("Resume data {} found".format(self.resume_file))
            with open(self.resume_file, "rb") as file:
                data = dill.load(file)
            self.sampler = data["sampler"]
            # The pool is stripped out before pickling, so reattach it here
            self.sampler.pool = self.pool
            self.tau_list = data["tau_list"]
            self.tau_list_n = data["tau_list_n"]
            pos0 = None
            logger.info("Resuming from previous run with time={}".format(
                self.sampler.time))
        else:
            self.sampler = ptemcee.Sampler(
                dim=self.ndim, logl=do_nothing_function,
                logp=do_nothing_function, pool=self.pool,
                threads=self.threads, **self.sampler_init_kwargs)
            self.sampler._likeprior = LikePriorEvaluator(
                self.likelihood, self.priors, self.search_parameter_keys,
                use_ratio=self.use_ratio)
            pos0 = self.get_pos0_from_prior()
        return self.sampler, pos0
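
    # `schwimmbad.MultiPool` is a light wrapper around multiprocessing.Pool;
    # using it as a context manager ensures the worker processes are cleaned
    # up even if sampling raises.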
    def run_sampler(self):
        import schwimmbad

        if self.threads > 1:
            logger.info("Creating MultiPool with {} processes".format(
                self.threads))
            with schwimmbad.MultiPool(self.threads) as pool:
                self.pool = pool
                return self.run_sampler_internal()
        else:
            self.pool = None
            return self.run_sampler_internal()

    def run_sampler_internal(self):
        import emcee

        sampler, pos0 = self.get_sampler()

        self.time_per_check = []
        # Avoid clobbering the tau history loaded from the resume file
        if not hasattr(self, "tau_list"):
            self.tau_list = []
            self.tau_list_n = []

        t0 = datetime.datetime.now()
        logger.info("Starting to sample")
        for (pos0, lnprob, lnlike) in sampler.sample(
                pos0, **self.sampler_function_kwargs):
            # Only check convergence every ncheck steps
            if sampler.time % self.ncheck:
                continue

            # Compute the autocorrelation time (ACT) tau for the
            # zero-temperature chains
            samples = sampler.chain[0, :, : sampler.time, :]
            taus = []
            for ii in range(sampler.nwalkers):
                for jj, key in enumerate(self.search_parameter_keys):
                    if "recalib" in key:
                        continue
                    try:
                        taus.append(
                            emcee.autocorr.integrated_time(
                                samples[ii, :, jj], c=self.autocorr_c, tol=0
                            )[0]
                        )
                    except emcee.autocorr.AutocorrError:
                        taus.append(np.inf)
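
            # emcee's integrated_time uses a windowed autocorrelation
            # estimate; tol=0 forces it to return a value even for short
            # chains, so early estimates can be noisy (hence the safety
            # factor applied below).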
            # Apply the multiplicative safety factor
            tau = self.safety * np.mean(taus)

            if np.isnan(tau) or np.isinf(tau):
                print("{} | Unable to use tau={}".format(sampler.time, tau),
                      flush=True)
                continue

            # Convert to an integer and store for plotting
            tau = int(tau)
            self.tau_list.append(tau)
            self.tau_list_n.append(sampler.time)

            # Calculate the effective number of samples available
            self.nburn = int(self.burn_in_nact * tau)
            self.thin = int(np.max([1, self.thin_by_nact * tau]))
            samples_per_check = self.ncheck * sampler.nwalkers / self.thin
            self.nsamples_effective = int(
                sampler.nwalkers * (sampler.time - self.nburn) / self.thin)

            # Calculate the fractional change in tau over the last nfrac checks
            frac = (tau - np.array(self.tau_list)[-self.nfrac - 1: -1]) / tau
            passes = frac < self.frac_threshold

            # Calculate the convergence boolean
            converged = self.nsamples < self.nsamples_effective
            converged &= np.all(passes)
            if sampler.time < tau * self.autocorr_tol or tau < self.min_tau:
                converged = False
                tau_pass = False
            else:
                tau_pass = True

            # Calculate the time per iteration
            self.time_per_check.append(
                (datetime.datetime.now() - t0).total_seconds())
            t0 = datetime.datetime.now()
            # Print an update on the progress
            print_progress(
                self.sampler,
                self.ncheck,
                self.time_per_check,
                self.nsamples,
                self.nsamples_effective,
                samples_per_check,
                passes,
                tau,
                tau_pass,
            )

            if converged:
                logger.info("Finished sampling")
                break

            # If a checkpoint is due, checkpoint
            if os.path.isfile(self.resume_file):
                last_checkpoint_s = time.time() - os.path.getmtime(
                    self.resume_file)
            else:
                last_checkpoint_s = np.sum(self.time_per_check)
            if last_checkpoint_s > self.check_point_deltaT:
                self.write_current_state()

        # Check if we reached the end without converging
        if sampler.time == self.sampler_function_kwargs["iterations"]:
            raise SamplerError(
                "Failed to reach convergence by iterations={}".format(
                    self.sampler_function_kwargs["iterations"]
                )
            )

        # Run a final checkpoint to update the plots and samples
        self.write_current_state()

        # Get the zero-temperature samples and store them in the result
        samples = sampler.chain[0, :, :, :]  # nwalkers, nsteps, ndim
        self.result.walkers = samples[:, : sampler.time, :]
        self.result.samples = samples[
            :, self.nburn: sampler.time: self.thin, :].reshape((-1, self.ndim))

        loglikelihood = sampler.loglikelihood[
            0, :, self.nburn: sampler.time: self.thin
        ]  # nwalkers, nsteps
        self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))

        self.result.betas = self.sampler.betas
        log_evidence, log_evidence_err = compute_evidence(
            sampler, self.outdir, self.label, self.nburn, self.thin
        )
        self.result.log_evidence = log_evidence
        self.result.log_evidence_err = log_evidence_err

        self.result.sampling_time = datetime.timedelta(
            seconds=np.sum(self.time_per_check))

        return self.result

    def write_current_state_and_exit(self, signum=None, frame=None):
        logger.warning("Run terminated with signal {}".format(signum))
        self.write_current_state()
        sys.exit(77)

    def write_current_state(self):
        checkpoint(self.outdir, self.label, self.nsamples_effective,
                   self.sampler, self.nburn, self.thin,
                   self.search_parameter_keys, self.resume_file,
                   self.tau_list, self.tau_list_n)
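
# An illustrative progress line as produced by print_progress below (the
# numbers are made up):
#   5000| nc:2.0e+07| a0:0.15->0.35| swp:0.21->0.79| n:3210<5000| tau:45| 1.2ms/evl| 2.50ms/smp| TTFTT| t:0:12:34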

def print_progress(
    sampler,
    ncheck,
    time_per_check,
    nsamples,
    nsamples_effective,
    samples_per_check,
    passes,
    tau,
    tau_pass,
):
    # Set up the acceptance-fraction string (zero-temperature chains)
    acceptance = sampler.acceptance_fraction[0, :]
    acceptance_str = "{:1.2f}->{:1.2f}".format(np.min(acceptance),
                                               np.max(acceptance))

    # Set up the temperature-swap acceptance string
    tswap_acceptance_fraction = sampler.tswap_acceptance_fraction
    tswap_acceptance_str = "{:1.2f}->{:1.2f}".format(
        np.min(tswap_acceptance_fraction),
        np.max(tswap_acceptance_fraction)
    )

    # Estimate the remaining run time from the last three checks
    ave_time_per_check = np.mean(time_per_check[-3:])
    time_left = (
        (nsamples - nsamples_effective)
        * ave_time_per_check
        / samples_per_check
    )
    if time_left > 0:
        time_left = str(datetime.timedelta(seconds=int(time_left)))
    else:
        time_left = "waiting on convergence"

    convergence = "".join([["F", "T"][i] for i in passes])

    tau_str = str(tau)
    if tau_pass is False:
        tau_str = tau_str + "(F)"

    evals_per_check = sampler.nwalkers * sampler.ntemps * ncheck

    ncalls = "{:1.1e}".format(sampler.time * sampler.nwalkers * sampler.ntemps)
    eval_timing = "{:1.1f}ms/evl".format(
        1e3 * ave_time_per_check / evals_per_check)
    samp_timing = "{:1.2f}ms/smp".format(
        1e3 * ave_time_per_check / samples_per_check)

    print(
        "{}| nc:{}| a0:{}| swp:{}| n:{}<{}| tau:{}| {}| {}| {}| t:{}".format(
            sampler.time,
            ncalls,
            acceptance_str,
            tswap_acceptance_str,
            nsamples_effective,
            nsamples,
            tau_str,
            eval_timing,
            samp_timing,
            convergence,
            time_left,
        ),
        flush=True,
    )
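
# dill, rather than the standard pickle module, is used below because the
# sampler holds references (e.g. the bound bilby likelihood through
# LikePriorEvaluator) that plain pickle cannot always serialise.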

def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
               search_parameter_keys, resume_file, tau_list, tau_list_n):
    logger.info("Writing checkpoint and diagnostics")
    ndim = sampler.dim

    # Store the samples if possible
    if nsamples_effective > 0:
        filename = "{}/{}_samples.txt".format(outdir, label)
        samples = sampler.chain[0, :, nburn: sampler.time: thin, :].reshape(
            (-1, ndim)
        )
        df = pd.DataFrame(samples, columns=search_parameter_keys)
        df.to_csv(filename, index=False, header=True, sep=" ")

    # Pickle the resume artefacts, truncating the stored arrays to the
    # current iteration to keep the file small
    sampler_copy = copy.copy(sampler)
    del sampler_copy.pool
    sampler_copy._chain = sampler._chain[:, :, : sampler.time, :]
    sampler_copy._logposterior = sampler._logposterior[:, :, : sampler.time]
    sampler_copy._loglikelihood = sampler._loglikelihood[:, :, : sampler.time]
    sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
    data = dict(sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n)

    with open(resume_file, "wb") as file:
        dill.dump(data, file, protocol=4)
    del data, sampler_copy

    # Generate the walkers-plot diagnostic
    plot_walkers(
        sampler.chain[0, :, : sampler.time, :], nburn, search_parameter_keys,
        outdir, label
    )

    # Generate the tau-plot diagnostic
    plot_tau(tau_list_n, tau_list, outdir, label)

    logger.info("Finished writing checkpoint and diagnostics")

def plot_walkers(walkers, nburn, parameter_labels, outdir, label):
    """ Plot the trace of the walkers in an ensemble MCMC plot """
    nwalkers, nsteps, ndim = walkers.shape
    idxs = np.arange(nsteps)
    fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3 * ndim))
    scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05)
    # Plot the burn-in steps in red and the post-burn-in steps in black
    for i, ax in enumerate(axes):
        ax.plot(
            idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="r",
            **scatter_kwargs
        )
        ax.set_ylabel(parameter_labels[i])
    for i, ax in enumerate(axes):
        ax.plot(idxs[nburn:], walkers[:, nburn:, i].T, color="k",
                **scatter_kwargs)

    fig.tight_layout()
    filename = "{}/{}_traceplot.png".format(outdir, label)
    fig.savefig(filename)
    plt.close(fig)

def plot_tau(tau_list_n, tau_list, outdir, label):
    fig, ax = plt.subplots()
    ax.plot(tau_list_n, tau_list, "-x")
    ax.set_xlabel("Iteration")
    ax.set_ylabel(r"$\langle \tau \rangle$")
    fig.savefig("{}/{}_tau.png".format(outdir, label))
    plt.close(fig)
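
# Thermodynamic integration: the evidence follows from the identity
#   ln Z = \int_0^1 <ln L>_beta  d(beta),
# where <ln L>_beta is the expectation of the log-likelihood over the chain
# at inverse temperature beta. compute_evidence approximates this integral
# with the trapezoid rule over the sampled betas, and estimates the error by
# repeating the integration using every other beta.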
def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
    """ Computes the evidence using thermodynamic integration """
    betas = sampler.betas
    # We compute the evidence without the burn-in samples, but we do not thin
    lnlike = sampler.loglikelihood[:, :, nburn: sampler.time]
    mean_lnlikes = np.mean(np.mean(lnlike, axis=1), axis=1)

    # Reverse so that beta runs from small values towards 1
    mean_lnlikes = mean_lnlikes[::-1]
    betas = betas[::-1]

    if any(np.isinf(mean_lnlikes)):
        logger.warning(
            "mean_lnlikes contains inf: recalculating without"
            " the {} infs".format(len(betas[np.isinf(mean_lnlikes)]))
        )
        idxs = np.isinf(mean_lnlikes)
        mean_lnlikes = mean_lnlikes[~idxs]
        betas = betas[~idxs]

    lnZ = np.trapz(mean_lnlikes, betas)
    # Estimate the error as the difference from a half-resolution integration
    z2 = np.trapz(mean_lnlikes[::-1][::2][::-1], betas[::-1][::2][::-1])
    lnZerr = np.abs(lnZ - z2)

    if make_plots:
        fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6, 8))
        ax1.semilogx(betas, mean_lnlikes, "-o")
        ax1.set_xlabel(r"$\beta$")
        ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")
        min_betas = []
        evidence = []
        for i in range(int(len(betas) / 2.0)):
            min_betas.append(betas[i])
            evidence.append(np.trapz(mean_lnlikes[i:], betas[i:]))

        ax2.semilogx(min_betas, evidence, "-o")
        ax2.set_ylabel(
            r"$\int_{\beta_{min}}^{\beta=1}"
            + r"\langle \log(\mathcal{L})\rangle d\beta$",
            size=16,
        )
        ax2.set_xlabel(r"$\beta_{min}$")

        plt.tight_layout()
        fig.savefig("{}/{}_beta_lnl.png".format(outdir, label))
        plt.close(fig)

    return lnZ, lnZerr

def do_nothing_function():
    """ A do-nothing placeholder; the likelihood and prior are overwritten
    elsewhere (see Ptemcee.get_sampler) """
    pass

class LikePriorEvaluator(object):
    """
    An overwrite of ptemcee.LikePriorEvaluator that uses the bilby
    likelihood and priors
    """

    def __init__(self, likelihood, priors, search_parameter_keys,
                 use_ratio=False):
        self.likelihood = likelihood
        self.priors = priors
        self.search_parameter_keys = search_parameter_keys
        self.use_ratio = use_ratio

    def logl(self, v_array):
        parameters = {key: v for key, v in
                      zip(self.search_parameter_keys, v_array)}
        if self.priors.evaluate_constraints(parameters) > 0:
            self.likelihood.parameters.update(parameters)
            if self.use_ratio:
                return (self.likelihood.log_likelihood()
                        - self.likelihood.noise_log_likelihood())
            else:
                return self.likelihood.log_likelihood()
        else:
            # Replace -inf with the largest negative finite float
            return np.nan_to_num(-np.inf)

    def logp(self, v_array):
        params = {key: t for key, t in
                  zip(self.search_parameter_keys, v_array)}
        return self.priors.ln_prob(params)

    def __call__(self, x):
        lp = self.logp(x)
        if np.isnan(lp):
            raise ValueError('Prior function returned NaN.')

        if lp == float('-inf'):
            # Can't return -inf, since this messes with beta=0 behaviour.
            ll = 0
        else:
            ll = self.logl(x)
            if np.isnan(ll).any():
                raise ValueError('Log likelihood function returned NaN.')

        return ll, lp
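
# A minimal usage sketch (the toy data, likelihood, and prior below are
# illustrative, not part of this MR):
#
#   import numpy as np
#   import bilby
#
#   x = np.linspace(0, 1, 100)
#   y = np.random.normal(3 * x, 1)
#   likelihood = bilby.core.likelihood.GaussianLikelihood(
#       x, y, func=lambda x, m: m * x, sigma=1)
#   priors = dict(m=bilby.core.prior.Uniform(0, 5, "m"))
#   result = bilby.run_sampler(
#       likelihood=likelihood, priors=priors, sampler="ptemcee",
#       ntemps=10, nwalkers=100, nsamples=2000, threads=2,
#       outdir="outdir", label="ptemcee_demo")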