Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
3 files changed: +291 −161
@@ -7,6 +7,7 @@ import signal
import sys
import time
import dill
from collections import namedtuple
import numpy as np
import pandas as pd
@@ -16,6 +17,23 @@ from ..utils import logger
from .base_sampler import SamplerError, MCMCSampler
ConvergenceInputs = namedtuple(
"ConvergenceInputs",
[
"autocorr_c",
"autocorr_tol",
"autocorr_tau",
"safety",
"burn_in_nact",
"thin_by_nact",
"frac_threshold",
"nsamples",
"ignore_keys_for_tau",
"min_tau",
],
)
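# A minimal usage sketch (the field values here are illustrative, not
# defaults taken from this MR): the named tuple bundles every convergence
# setting into a single object whose fields read like attributes.
ci_example = ConvergenceInputs(
    autocorr_c=5, autocorr_tol=50, autocorr_tau=50, safety=1,
    burn_in_nact=10, thin_by_nact=0.5, frac_threshold=0.01,
    nsamples=5000, ignore_keys_for_tau=None, min_tau=1,
)
assert ci_example.nsamples == 5000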
class Ptemcee(MCMCSampler):
"""bilby wrapper ptemcee (https://github.com/willvousden/ptemcee)
@@ -93,14 +111,10 @@ class Ptemcee(MCMCSampler):
Tmax=None,
betas=None,
a=2.0,
loglargs=[],
logpargs=[],
loglkwargs={},
logpkwargs={},
adaptation_lag=10000,
adaptation_time=100,
random=None,
adapt=False,
adapt=True,
swap_ratios=False,
)
@@ -143,70 +157,114 @@ class Ptemcee(MCMCSampler):
**kwargs
)
self.nwalkers = self.sampler_init_kwargs["nwalkers"]
self.ntemps = self.sampler_init_kwargs["ntemps"]
self.max_steps = 500
# Set up signal handling
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
# Checkpointing inputs
self.exit_code = exit_code
self.resume = resume
self.autocorr_c = autocorr_c
self.safety = safety
self.burn_in_nact = burn_in_nact
self.thin_by_nact = thin_by_nact
self.frac_threshold = frac_threshold
self.nsamples = nsamples
self.autocorr_tol = autocorr_tol
self.autocorr_tau = autocorr_tau
self.min_tau = min_tau
self.check_point_deltaT = check_point_deltaT
self.threads = threads
self.store_walkers = store_walkers
self.ignore_keys_for_tau = ignore_keys_for_tau
self.pos0 = pos0
self.check_point_plot = check_point_plot
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(
self.outdir, self.label
)
self.exit_code = exit_code
# Store convergence checking inputs in a named tuple
convergence_inputs_dict = dict(
autocorr_c=autocorr_c,
autocorr_tol=autocorr_tol,
autocorr_tau=autocorr_tau,
safety=safety,
burn_in_nact=burn_in_nact,
thin_by_nact=thin_by_nact,
frac_threshold=frac_threshold,
nsamples=nsamples,
ignore_keys_for_tau=ignore_keys_for_tau,
min_tau=min_tau,
)
self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)
# Multiprocessing inputs
self.threads = threads
# Misc inputs
self.store_walkers = store_walkers
self.pos0 = pos0
@property
def sampler_function_kwargs(self):
""" Kwargs passed to samper.sampler() """
keys = ["adapt", "swap_ratios"]
return {key: self.kwargs[key] for key in keys}
@property
def sampler_init_kwargs(self):
""" Kwargs passed to initialize ptemcee.Sampler() """
return {
key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs
}
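# Illustration of the split (hypothetical values): anything consumed by
# sampler.sample() is excluded from the constructor kwargs, so e.g.
#   kwargs = {"ntemps": 4, "nwalkers": 100, "adapt": True, "swap_ratios": False}
# gives
#   sampler_function_kwargs -> {"adapt": True, "swap_ratios": False}
#   sampler_init_kwargs     -> {"ntemps": 4, "nwalkers": 100}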
def _translate_kwargs(self, kwargs):
""" Translate kwargs """
if "nwalkers" not in kwargs:
for equiv in self.nwalkers_equiv_kwargs:
if equiv in kwargs:
kwargs["nwalkers"] = kwargs.pop(equiv)
def get_pos0_from_prior(self):
""" for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim """
""" Draw the initial positions from the prior
Returns
-------
pos0: list
The initial positions of the walkers, with shape (ntemps, nwalkers, ndim)
"""
logger.info("Generating pos0 samples")
return [
[
self.get_random_draw_from_prior()
for _ in range(self.sampler_init_kwargs["nwalkers"])
for _ in range(self.nwalkers)
]
for _ in range(self.kwargs["ntemps"])
]
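# A standalone sketch of the layout this produces (names and sizes below
# are hypothetical): one independent prior draw per walker, per temperature.
#
#     import numpy as np
#     ntemps, nwalkers, ndim = 4, 10, 2
#     pos0 = [
#         [np.random.uniform(size=ndim) for _ in range(nwalkers)]
#         for _ in range(ntemps)
#     ]
#     assert np.shape(pos0) == (ntemps, nwalkers, ndim)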
def get_pos0_from_minimize(self, minimize_list=None):
logger.info("Attempting to set pos0 from minimize")
""" Draw the initial positions using an initial minimization step
See pos0 in the class initialization for details.
Returns
-------
pos0: list
The initial positions of the walkers, with shape (ntemps, nwalkers, ndim)
"""
from scipy.optimize import minimize
# Set up the minimize list: keys not in this list will have initial
# positions drawn from the prior
if minimize_list is None:
minimize_list = self.search_parameter_keys
pos0 = np.zeros((self.kwargs["ntemps"], self.kwargs["nwalkers"], self.ndim))
else:
pos0 = np.array(self.get_pos0_from_prior())
logger.info("Attempting to set pos0 for {} from minimize".format(minimize_list))
likelihood_copy = copy.copy(self.likelihood)
def neg_log_like(params):
""" Internal function to minimize """
likelihood_copy.parameters.update(
{key: val for key, val in zip(minimize_list, params)}
)
@@ -215,10 +273,13 @@ class Ptemcee(MCMCSampler):
except RuntimeError:
return +np.inf
# Bounds used in the minimization
bounds = [
(self.priors[key].minimum, self.priors[key].maximum)
for key in minimize_list
]
# Run the minimization step several times to get a range of values
trials = 0
success = []
while True:
@@ -235,6 +296,7 @@ class Ptemcee(MCMCSampler):
if len(success) >= 10:
break
# Initialize positions from the range of values
success = np.array(success)
for i, key in enumerate(minimize_list):
pos0_min = np.min(success[:, i])
@@ -251,6 +313,7 @@ class Ptemcee(MCMCSampler):
return pos0
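# The pattern above, reduced to a standalone sketch (all names and numbers
# are illustrative): minimize a toy negative log-likelihood from random
# starts until ten successes, then scatter walkers uniformly across the
# range of minima found.
#
#     import numpy as np
#     from scipy.optimize import minimize
#     bounds = [(-5, 5), (-5, 5)]
#     success = []
#     while len(success) < 10:
#         draw = [np.random.uniform(lo, hi) for lo, hi in bounds]
#         res = minimize(lambda x: 0.5 * np.sum((x - 1) ** 2), draw, bounds=bounds)
#         if res.success:
#             success.append(res.x)
#     success = np.array(success)
#     ntemps, nwalkers = 2, 8
#     pos0 = np.random.uniform(
#         success.min(axis=0), success.max(axis=0),
#         size=(ntemps, nwalkers, len(bounds)),
#     )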
def setup_sampler(self):
""" Either initialize the sampelr or read in the resume file """
import ptemcee
if os.path.isfile(self.resume_file) and self.resume is True:
@@ -258,18 +321,22 @@ class Ptemcee(MCMCSampler):
with open(self.resume_file, "rb") as file:
data = dill.load(file)
# Extract the check-point data
self.sampler = data["sampler"]
self.iteration = data["iteration"]
self.chain_array = data["chain_array"]
self.log_likelihood_array = data["log_likelihood_array"]
self.pos0 = data["pos0"]
self.tau_list = data["tau_list"]
self.tau_list_n = data["tau_list_n"]
self.time_per_check = data["time_per_check"]
# Initialize the pool
self.sampler.pool = self.pool
self.sampler.threads = self.threads
pos0 = None
logger.info(
"Resuming from previous run with time={}".format(self.sampler.time)
"Resuming from previous run with time={}".format(self.iteration)
)
else:
@@ -295,17 +362,25 @@ class Ptemcee(MCMCSampler):
self.search_parameter_keys, use_ratio=self.use_ratio
)
# Set up empty lists
# Initialize storage for the results
self.iteration = 0
self.chain_array = self.get_zero_chain_array()
self.log_likelihood_array = self.get_zero_log_likelihood_array()
self.tau_list = []
self.tau_list_n = []
self.time_per_check = []
self.pos0 = self.get_pos0()
return self.sampler
# Initialize the walker positions
pos0 = self.get_pos0()
def get_zero_chain_array(self):
return np.zeros((self.nwalkers, self.max_steps, self.ndim))
return self.sampler, pos0
def get_zero_log_likelihood_array(self):
return np.zeros((self.ntemps, self.nwalkers, self.max_steps))
def get_pos0(self):
""" Master logic for setting pos0 """
if isinstance(self.pos0, str) and self.pos0.lower() == "prior":
return self.get_pos0_from_prior()
elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize":
@@ -316,6 +391,7 @@ class Ptemcee(MCMCSampler):
raise SamplerError("pos0={} not implemented".format(self.pos0))
def setup_pool(self):
""" If threads > 1, setup a MultiPool, else run in serial mode """
if self.threads > 1:
import schwimmbad
@@ -328,105 +404,48 @@ class Ptemcee(MCMCSampler):
def run_sampler(self):
self.setup_pool()
out = self.run_sampler_internal()
if self.pool:
self.pool.close()
return out
def run_sampler_internal(self):
import emcee
sampler, pos0 = self.setup_sampler()
sampler = self.setup_sampler()
t0 = datetime.datetime.now()
logger.info("Starting to sample")
while True:
for (pos0, _, _) in sampler.sample(pos0, **self.sampler_function_kwargs):
for (pos0, log_posterior, log_likelihood) in sampler.sample(
self.pos0, storechain=False, **self.sampler_function_kwargs):
pass
if self.iteration == self.chain_array.shape[1]:
self.chain_array = np.concatenate((
self.chain_array, self.get_zero_chain_array()), axis=1)
self.log_likelihood_array = np.concatenate((
self.log_likelihood_array, self.get_zero_log_likelihood_array()),
axis=2)
self.pos0 = pos0
self.chain_array[:, self.iteration, :] = pos0[0, :, :]
self.log_likelihood_array[:, :, self.iteration] = log_likelihood
# Calculate time per iteration
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
t0 = datetime.datetime.now()
# Compute ACT tau for 0-temperature chains
samples = sampler.chain[0, :, : sampler.time, :]
taus = []
for ii in range(sampler.nwalkers):
tau_ii = []
for jj, key in enumerate(self.search_parameter_keys):
if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key:
continue
try:
tau_ii.append(
emcee.autocorr.integrated_time(
samples[ii, :, jj], c=self.autocorr_c, tol=0
)[0]
)
except emcee.autocorr.AutocorrError:
taus.append(np.inf)
taus.append(tau_ii)
tau = np.max(np.mean(taus, axis=0))
# Apply multiplicative safety factor
tau = self.safety * tau
# Store for convergence checking and plotting
self.tau_list.append(np.mean(taus, axis=0))
self.tau_list_n.append(sampler.time)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
self.tau_int = tau_int
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
self.sampler,
self.time_per_check,
self.nsamples,
np.nan,
np.nan,
np.nan,
[np.nan],
False,
)
continue
# Calculate the effective number of samples available
self.nburn = int(self.burn_in_nact * tau_int)
self.thin = int(np.max([1, self.thin_by_nact * tau_int]))
samples_per_check = sampler.nwalkers / self.thin
self.nsamples_effective = int(
sampler.nwalkers * (sampler.time - self.nburn) / self.thin
)
# Calculate convergence boolean
converged = self.nsamples < self.nsamples_effective
# Calculate fractional change in tau from previous iterations
check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau :])
if not np.any(np.isnan(check_taus)):
frac = (tau - check_taus) / tau
tau_usable = np.all(frac < self.frac_threshold)
else:
tau_usable = False
if sampler.time < tau_int * self.autocorr_tol or tau_int < self.min_tau:
tau_usable = False
# Print an update on the progress
print_progress(
self.sampler,
self.time_per_check,
self.nsamples,
(
stop,
self.nburn,
self.thin,
self.tau_int,
self.nsamples_effective,
samples_per_check,
tau_int,
check_taus,
tau_usable,
) = check_iteration(
self.chain_array[:, :self.iteration + 1, :],
sampler,
self.convergence_inputs,
self.search_parameter_keys,
self.time_per_check,
self.tau_list,
self.tau_list_n,
)
if converged and tau_usable:
self.iteration += 1
if stop:
logger.info("Finished sampling")
break
@@ -443,12 +462,11 @@ class Ptemcee(MCMCSampler):
self.write_current_state(plot=self.check_point_plot)
# Get the zero-temperature samples and store them in the result
samples = sampler.chain[0, :, :, :] # nwalkers, nsteps, ndim
self.result.samples = samples[
:, self.nburn : sampler.time : self.thin, :
self.result.samples = self.chain_array[
:, self.nburn : self.iteration : self.thin, :
].reshape((-1, self.ndim))
loglikelihood = sampler.loglikelihood[
0, :, self.nburn : sampler.time : self.thin
loglikelihood = self.log_likelihood_array[
0, :, self.nburn : self.iteration : self.thin
] # nwalkers, nsteps
self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
@@ -457,7 +475,8 @@ class Ptemcee(MCMCSampler):
self.result.nburn = self.nburn
log_evidence, log_evidence_err = compute_evidence(
sampler, self.outdir, self.label, self.nburn, self.thin
sampler, self.log_likelihood_array, self.outdir, self.label, self.nburn,
self.thin, self.iteration,
)
self.result.log_evidence = log_evidence
self.result.log_evidence_err = log_evidence_err
@@ -466,6 +485,9 @@ class Ptemcee(MCMCSampler):
seconds=np.sum(self.time_per_check)
)
if self.pool:
self.pool.close()
return self.result
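# The chain storage grows in fixed-size blocks: once `iteration` reaches the
# end of the allocated array, another zero block of `max_steps` columns is
# concatenated on (see the top of the sampling loop above). A standalone
# sketch with illustrative sizes:
#
#     import numpy as np
#     nwalkers, max_steps, ndim = 8, 500, 3
#     chain = np.zeros((nwalkers, max_steps, ndim))
#     iteration = max_steps  # the current block is full
#     if iteration == chain.shape[1]:
#         chain = np.concatenate(
#             (chain, np.zeros((nwalkers, max_steps, ndim))), axis=1)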
def write_current_state_and_exit(self, signum=None, frame=None):
@@ -479,6 +501,7 @@ class Ptemcee(MCMCSampler):
def write_current_state(self, plot=True):
checkpoint(
self.iteration,
self.outdir,
self.label,
self.nsamples_effective,
@@ -487,6 +510,9 @@ class Ptemcee(MCMCSampler):
self.thin,
self.search_parameter_keys,
self.resume_file,
self.log_likelihood_array,
self.chain_array,
self.pos0,
self.tau_list,
self.tau_list_n,
self.time_per_check,
@@ -495,7 +521,7 @@ class Ptemcee(MCMCSampler):
if plot and not np.isnan(self.nburn):
# Generate the walkers plot diagnostic
plot_walkers(
self.sampler.chain[0, :, : self.sampler.time, :],
self.chain_array[:, : self.iteration, :],
self.nburn,
self.thin,
self.search_parameter_keys,
@@ -511,18 +537,127 @@ class Ptemcee(MCMCSampler):
self.outdir,
self.label,
self.tau_int,
self.autocorr_tau,
self.convergence_inputs.autocorr_tau,
)
def check_iteration(
samples,
sampler,
convergence_inputs,
search_parameter_keys,
time_per_check,
tau_list,
tau_list_n,
):
""" Per-iteration logic to calculate the convergence check
Parameters
----------
convergence_inputs: bilby.core.sampler.ptemcee.ConvergenceInputs
A named tuple of the convergence checking inputs
search_parameter_keys: list
A list of the search parameter keys
time_per_check, tau_list, tau_list_n: list
Lists used for tracking the run
Returns
-------
stop: bool
A boolean flag, True if the stoping criteria has been met
burn: int
The number of burn-in steps to discard
thin: int
The thin-by factor to apply
tau_int: int
The integer estimated ACT
nsamples_effective: int
The effective number of samples after burning and thinning
"""
import emcee
ci = convergence_inputs
nwalkers, iteration, ndim = samples.shape
# Compute ACT tau for 0-temperature chains
tau_array = np.zeros((nwalkers, ndim))
for ii in range(nwalkers):
for jj, key in enumerate(search_parameter_keys):
if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key:
continue
try:
tau_array[ii, jj] = emcee.autocorr.integrated_time(
samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0]
except emcee.autocorr.AutocorrError:
tau_array[ii, jj] = np.inf
# Maximum over parameters, mean over walkers
tau = np.max(np.mean(tau_array, axis=0))
# Apply multiplicative safety factor
tau = ci.safety * tau
# Store for convergence checking and plotting
tau_list.append(list(np.mean(tau_array, axis=0)))
tau_list_n.append(iteration)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
iteration, sampler, time_per_check, ci.nsamples, np.nan, np.nan, np.nan, np.nan, False,
)
return False, np.nan, np.nan, np.nan, np.nan
# Calculate the effective number of samples available
nburn = int(ci.burn_in_nact * tau_int)
thin = int(np.max([1, ci.thin_by_nact * tau_int]))
samples_per_check = nwalkers / thin
nsamples_effective = int(nwalkers * (iteration - nburn) / thin)
# Calculate convergence boolean
converged = ci.nsamples < nsamples_effective
# Calculate the fractional change in tau over the recent window of estimates
check_taus = np.array(tau_list[-tau_int * ci.autocorr_tau :])
taus_per_parameter = check_taus[-1, :]
if not np.any(np.isnan(check_taus)):
frac = (taus_per_parameter - check_taus) / taus_per_parameter
max_frac = np.max(frac)
tau_usable = np.all(frac < ci.frac_threshold)
else:
max_frac = np.nan
tau_usable = False
if iteration < tau_int * ci.autocorr_tol or tau_int < ci.min_tau:
tau_usable = False
# Print an update on the progress
print_progress(
iteration,
sampler,
time_per_check,
ci.nsamples,
nsamples_effective,
samples_per_check,
tau_int,
max_frac,
tau_usable,
)
stop = converged and tau_usable
return stop, nburn, thin, tau_int, nsamples_effective
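# The tau estimate above, reduced to a standalone sketch (hypothetical
# function name; it mirrors the diff's reduction order of mean over
# walkers, then maximum over parameters):
def estimate_tau_sketch(samples, safety=1.0, autocorr_c=5):
    """samples: zero-temperature chain with shape (nwalkers, nsteps, ndim)."""
    import emcee
    import numpy as np
    nwalkers, _, ndim = samples.shape
    tau_array = np.zeros((nwalkers, ndim))
    for ii in range(nwalkers):
        for jj in range(ndim):
            try:
                tau_array[ii, jj] = emcee.autocorr.integrated_time(
                    samples[ii, :, jj], c=autocorr_c, tol=0)[0]
            except emcee.autocorr.AutocorrError:
                tau_array[ii, jj] = np.inf
    # Mean over walkers first, then the maximum over parameters
    return safety * np.max(np.mean(tau_array, axis=0))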
def print_progress(
iteration,
sampler,
time_per_check,
nsamples,
nsamples_effective,
samples_per_check,
tau_int,
tau_list,
max_frac,
tau_usable,
):
# Set up the acceptance string
@@ -544,7 +679,10 @@ def print_progress(
sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
tau_str = "{}:{:0.1f}->{:0.1f}".format(tau_int, np.min(tau_list), np.max(tau_list))
if max_frac >= 0:
tau_str = "{}(+{:0.1f})".format(tau_int, max_frac)
else:
tau_str = "{}({:0.1f})".format(tau_int, max_frac)
if tau_usable:
tau_str = "={}".format(tau_str)
else:
@@ -552,13 +690,13 @@ def print_progress(
evals_per_check = sampler.nwalkers * sampler.ntemps
ncalls = "{:1.1e}".format(sampler.time * sampler.nwalkers * sampler.ntemps)
ncalls = "{:1.1e}".format(iteration * sampler.nwalkers * sampler.ntemps)
eval_timing = "{:1.1f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)
print(
"{}| {} | nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
sampler.time,
iteration,
str(sampling_time).split(".")[0],
ncalls,
acceptance_str,
@@ -574,6 +712,7 @@ def print_progress(
def checkpoint(
iteration,
outdir,
label,
nsamples_effective,
@@ -582,6 +721,9 @@ def checkpoint(
thin,
search_parameter_keys,
resume_file,
log_likelihood_array,
chain_array,
pos0,
tau_list,
tau_list_n,
time_per_check,
@@ -592,7 +734,7 @@ def checkpoint(
# Store the samples if possible
if nsamples_effective > 0:
filename = "{}/{}_samples.txt".format(outdir, label)
samples = sampler.chain[0, :, nburn : sampler.time : thin, :].reshape(
samples = np.array(chain_array)[:, nburn : iteration : thin, :].reshape(
(-1, ndim)
)
df = pd.DataFrame(samples, columns=search_parameter_keys)
@@ -601,16 +743,16 @@ def checkpoint(
# Pickle the resume artefacts
sampler_copy = copy.copy(sampler)
del sampler_copy.pool
sampler_copy._chain = sampler._chain[:, :, : sampler.time, :]
sampler_copy._logposterior = sampler._logposterior[:, :, : sampler.time]
sampler_copy._loglikelihood = sampler._loglikelihood[:, :, : sampler.time]
sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
data = dict(
iteration=iteration,
sampler=sampler_copy,
tau_list=tau_list,
tau_list_n=tau_list_n,
time_per_check=time_per_check,
log_likelihood_array=log_likelihood_array,
chain_array=chain_array,
pos0=pos0,
)
with open(resume_file, "wb") as file:
@@ -652,7 +794,9 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
plt.close(fig)
def plot_tau(tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau):
def plot_tau(
tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau
):
fig, ax = plt.subplots()
for i, key in enumerate(search_parameter_keys):
ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
@@ -664,11 +808,12 @@ def plot_tau(tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, au
plt.close(fig)
def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
def compute_evidence(sampler, log_likelihood_array, outdir, label, nburn, thin,
iteration, make_plots=True):
""" Computes the evidence using thermodynamic integration """
betas = sampler.betas
# We compute the evidence without the burn-in samples, but we do not thin
lnlike = sampler.loglikelihood[:, :, nburn : sampler.time]
lnlike = log_likelihood_array[:, :, nburn : iteration]
mean_lnlikes = np.mean(np.mean(lnlike, axis=1), axis=1)
mean_lnlikes = mean_lnlikes[::-1]
@@ -701,8 +846,7 @@ def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
ax2.semilogx(min_betas, evidence, "-o")
ax2.set_ylabel(
r"$\int_{\beta_{min}}^{\beta=1}"
+ r"\langle \log(\mathcal{L})\rangle d\beta$",
r"$\int_{\beta_{min}}^{\beta=1}" + r"\langle \log(\mathcal{L})\rangle d\beta$",
size=16,
)
ax2.set_xlabel(r"$\beta_{min}$")
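# Thermodynamic integration estimates the evidence as
# ln Z = \int_0^1 <ln L>_beta d(beta): the mean log-likelihood at each
# inverse temperature, integrated over beta. A minimal sketch of that core
# step (hypothetical name; the full routine also repeats the integral with
# the smallest betas dropped to derive log_evidence_err):
#
#     import numpy as np
#     def log_evidence_sketch(betas, mean_lnlikes):
#         # inputs ordered from beta ~ 0 up to beta = 1
#         return np.trapz(mean_lnlikes, betas)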