
Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
@@ -7,6 +7,7 @@ import signal
import sys
import time
import dill
from collections import namedtuple
import numpy as np
import pandas as pd
@@ -16,6 +17,23 @@ from ..utils import logger
from .base_sampler import SamplerError, MCMCSampler
ConvergenceInputs = namedtuple(
"ConvergenceInputs",
[
"autocorr_c",
"autocorr_tol",
"autocorr_tau",
"safety",
"burn_in_nact",
"thin_by_nact",
"frac_threshold",
"nsamples",
"ignore_keys_for_tau",
"min_tau",
],
)
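As a usage sketch (values here simply mirror the __init__ defaults listed below), the named tuple bundles all the convergence settings into one immutable object that can be handed to the module-level helpers:

    ci = ConvergenceInputs(
        autocorr_c=5, autocorr_tol=50, autocorr_tau=5, safety=1,
        burn_in_nact=50, thin_by_nact=1, frac_threshold=0.01,
        nsamples=5000, ignore_keys_for_tau=None, min_tau=1,
    )
    assert ci.nsamples == 5000  # fields are read by name, e.g. in check_iteration()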
class Ptemcee(MCMCSampler):
"""bilby wrapper ptemcee (https://github.com/willvousden/ptemcee)
@@ -24,101 +42,244 @@ class Ptemcee(MCMCSampler):
documentation for that class for further help. Under Other Parameters, we
list commonly used kwargs and the bilby defaults.
Parameters
----------
nsamples: int, (5000)
The requested number of samples. Note, in cases where the
autocorrelation time is difficult to measure, it is possible to
end up with more than nsamples.
burn_in_nact, thin_by_nact: int, (50, 1)
The number of burn-in autocorrelation times to discard and the thin-by
factor. Increasing burn_in_nact increases the time required for burn-in.
Increasing thin_by_nact increases the time required to obtain nsamples.
autocorr_tol: int, (50)
The minimum number of autocorrelation times needed to trust the
estimate of the autocorrelation time.
autocorr_c: int, (5)
The step size for the window search used by emcee.autocorr.integrated_time
safety: int, (1)
A multiplicative factor for the estimated autocorrelation. Useful for
cases where non-convergence can be observed by eye but the automated
tools are failing.
autocorr_tau: int, (5)
The number of autocorrelation times to use in assessing if the
autocorrelation time is stable.
frac_threshold: float, (0.01)
The maximum fractional change in the autocorrelation for the last
autocorr_tau steps. If the fractional change exceeds this value,
sampling will continue until the estimate of the autocorrelation time
can be trusted.
min_tau: int, (1)
A minimum tau (autocorrelation time) to accept.
check_point_deltaT: float, (600)
The period with which to checkpoint (in seconds).
threads: int, (1)
If threads > 1, a MultiPool object is set up and used.
exit_code: int, (77)
The code on which the sampler exits.
store_walkers: bool (False)
If true, store the unthinned, unburnt chains in the result. Note, this
is not recommended for cases where tau is large.
ignore_keys_for_tau: str
A pattern used to ignore keys when estimating the autocorrelation time:
any search parameter key containing this substring is skipped.
pos0: str, list ("prior")
If a string, one of "prior" or "minimize". For "prior", the initial
positions of the walkers are drawn from the prior. If "minimize",
a scipy.optimize step is applied to all parameters a number of times.
The walkers are then initialized from the range of values obtained.
If a list, the optimization step is applied to the keys in the list;
the remaining initial points are drawn from the prior (see the usage
sketch after this docstring).
Other Parameters
----------------
nwalkers: int, (100)
nwalkers: int, (200)
The number of walkers
nsteps: int, (100)
The number of steps to take
nburn: int (50)
The fixed number of steps to discard as burn-in
ntemps: int (2)
The number of temperatures used by ptemcee
Tmax: float
The maximum temperature
"""
# Arguments used by ptemcee
default_kwargs = dict(
ntemps=20, nwalkers=200, Tmax=None, betas=None,
a=2.0, loglargs=[], logpargs=[], loglkwargs={},
logpkwargs={}, adaptation_lag=10000, adaptation_time=100, random=None,
adapt=False, swap_ratios=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, check_point_plot=True, skip_import_verification=False,
resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
autocorr_c=5, safety=1, frac_threshold=0.01,
autocorr_tol=50, autocorr_tau=5, min_tau=1, check_point_deltaT=600,
threads=1, exit_code=77, plot=False, store_walkers=False,
ignore_keys_for_tau="recalib", pos0="prior", **kwargs):
ntemps=20,
nwalkers=200,
Tmax=None,
betas=None,
a=2.0,
adaptation_lag=10000,
adaptation_time=100,
random=None,
adapt=True,
swap_ratios=False,
)
def __init__(
self,
likelihood,
priors,
outdir="outdir",
label="label",
use_ratio=False,
check_point_plot=True,
skip_import_verification=False,
resume=True,
nsamples=5000,
burn_in_nact=50,
thin_by_nact=1,
autocorr_tol=50,
autocorr_c=5,
safety=1,
autocorr_tau=5,
frac_threshold=0.01,
min_tau=1,
check_point_deltaT=600,
threads=1,
exit_code=77,
plot=False,
store_walkers=False,
ignore_keys_for_tau=None,
pos0="prior",
**kwargs
):
super(Ptemcee, self).__init__(
likelihood=likelihood, priors=priors, outdir=outdir,
label=label, use_ratio=use_ratio, plot=plot,
skip_import_verification=skip_import_verification, **kwargs)
likelihood=likelihood,
priors=priors,
outdir=outdir,
label=label,
use_ratio=use_ratio,
plot=plot,
skip_import_verification=skip_import_verification,
**kwargs
)
self.nwalkers = self.sampler_init_kwargs["nwalkers"]
self.ntemps = self.sampler_init_kwargs["ntemps"]
self.max_steps = 500
# Set up signal handling
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
# Checkpointing inputs
self.exit_code = exit_code
self.resume = resume
self.autocorr_c = autocorr_c
self.safety = safety
self.burn_in_nact = burn_in_nact
self.thin_by_nact = thin_by_nact
self.frac_threshold = frac_threshold
self.nsamples = nsamples
self.autocorr_tol = autocorr_tol
self.autocorr_tau = autocorr_tau
self.min_tau = min_tau
self.check_point_deltaT = check_point_deltaT
self.check_point_plot = check_point_plot
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(
self.outdir, self.label
)
# Store convergence checking inputs in a named tuple
convergence_inputs_dict = dict(
autocorr_c=autocorr_c,
autocorr_tol=autocorr_tol,
autocorr_tau=autocorr_tau,
safety=safety,
burn_in_nact=burn_in_nact,
thin_by_nact=thin_by_nact,
frac_threshold=frac_threshold,
nsamples=nsamples,
ignore_keys_for_tau=ignore_keys_for_tau,
min_tau=min_tau,
)
self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)
# MultiProcessing inputs
self.threads = threads
# Misc inputs
self.store_walkers = store_walkers
self.ignore_keys_for_tau = ignore_keys_for_tau
self.pos0 = pos0
self.check_point_plot = check_point_plot
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)
self.exit_code = exit_code
@property
def sampler_function_kwargs(self):
keys = ['adapt', 'swap_ratios']
""" Kwargs passed to samper.sampler() """
keys = ["adapt", "swap_ratios"]
return {key: self.kwargs[key] for key in keys}
@property
def sampler_init_kwargs(self):
return {key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs}
""" Kwargs passed to initialize ptemcee.Sampler() """
return {
key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs
}
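To illustrate the split these two properties implement (a sketch assuming the default kwargs above): `adapt` and `swap_ratios` are consumed by sampler.sample(), and everything else initializes ptemcee.Sampler():

    kwargs = dict(ntemps=20, nwalkers=200, adapt=True, swap_ratios=False)
    function_keys = ["adapt", "swap_ratios"]
    function_kwargs = {key: kwargs[key] for key in function_keys}
    init_kwargs = {k: v for k, v in kwargs.items() if k not in function_kwargs}
    assert init_kwargs == {"ntemps": 20, "nwalkers": 200}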
def _translate_kwargs(self, kwargs):
""" Translate kwargs """
if "nwalkers" not in kwargs:
for equiv in self.nwalkers_equiv_kwargs:
if equiv in kwargs:
kwargs["nwalkers"] = kwargs.pop(equiv)
def get_pos0_from_prior(self):
""" for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim """
""" Draw the initial positions from the prior
Returns
-------
pos0: list
The initial positions of the walkers, with shape (ntemps, nwalkers, ndim)
"""
logger.info("Generating pos0 samples")
return [[self.get_random_draw_from_prior()
for _ in range(self.sampler_init_kwargs["nwalkers"])]
for _ in range(self.kwargs['ntemps'])]
return [
[
self.get_random_draw_from_prior()
for _ in range(self.nwalkers)
]
for _ in range(self.kwargs["ntemps"])
]
def get_pos0_from_minimize(self, minimize_list=None):
logger.info("Attempting to set pos0 from minimize")
""" Draw the initial positions using an initial minimization step
See pos0 in the class initialization for details.
Returns
-------
pos0: list
The initial positions of the walkers, with shape (ntemps, nwalkers, ndim)
"""
from scipy.optimize import minimize
# Set up the minimize list: keys not in this list will have initial
# positions drawn from the prior
if minimize_list is None:
minimize_list = self.search_parameter_keys
pos0 = np.zeros((self.kwargs["ntemps"], self.kwargs["nwalkers"], self.ndim))
else:
pos0 = np.array(self.get_pos0_from_prior())
logger.info("Attempting to set pos0 for {} from minimize".format(minimize_list))
likelihood_copy = copy.copy(self.likelihood)
def neg_log_like(params):
""" Internal function to minimize """
likelihood_copy.parameters.update(
{key: val for key, val in zip(minimize_list, params)})
{key: val for key, val in zip(minimize_list, params)}
)
try:
return -likelihood_copy.log_likelihood()
except RuntimeError:
return +np.inf
bounds = [(self.priors[key].minimum, self.priors[key].maximum)
for key in minimize_list]
# Bounds used in the minimization
bounds = [
(self.priors[key].minimum, self.priors[key].maximum)
for key in minimize_list
]
# Run the minimization step several times to get a range of values
trials = 0
success = []
while True:
@@ -126,7 +287,8 @@ class Ptemcee(MCMCSampler):
likelihood_copy.parameters.update(draw)
x0 = [draw[key] for key in minimize_list]
res = minimize(
neg_log_like, x0, bounds=bounds, method='L-BFGS-B', tol=1e-15)
neg_log_like, x0, bounds=bounds, method="L-BFGS-B", tol=1e-15
)
if res.success:
success.append(res.x)
if trials > 100:
@@ -134,62 +296,91 @@ class Ptemcee(MCMCSampler):
if len(success) >= 10:
break
# Initialize positions from the range of values
success = np.array(success)
for i, key in enumerate(minimize_list):
pos0_min = np.min(success[:, i])
pos0_max = np.max(success[:, i])
logger.info("Initialize {} walkers from {}->{}"
.format(key, pos0_min, pos0_max))
logger.info(
"Initialize {} walkers from {}->{}".format(key, pos0_min, pos0_max)
)
j = self.search_parameter_keys.index(key)
pos0[:, :, j] = np.random.uniform(
pos0_min, pos0_max,
size=(self.kwargs["ntemps"], self.kwargs["nwalkers"]))
pos0_min,
pos0_max,
size=(self.kwargs["ntemps"], self.kwargs["nwalkers"]),
)
return pos0
def setup_sampler(self):
""" Either initialize the sampelr or read in the resume file """
import ptemcee
if os.path.isfile(self.resume_file) and self.resume is True:
logger.info("Resume data {} found".format(self.resume_file))
with open(self.resume_file, "rb") as file:
data = dill.load(file)
# Extract the check-point data
self.sampler = data["sampler"]
self.iteration = data["iteration"]
self.chain_array = data["chain_array"]
self.log_likelihood_array = data["log_likelihood_array"]
self.pos0 = data["pos0"]
self.tau_list = data["tau_list"]
self.tau_list_n = data["tau_list_n"]
self.time_per_check = data["time_per_check"]
# Initialize the pool
self.sampler.pool = self.pool
self.sampler.threads = self.threads
pos0 = None
logger.info("Resuming from previous run with time={}".format(self.sampler.time))
logger.info(
"Resuming from previous run with time={}".format(self.iteration)
)
else:
# Initialize the PTSampler
if self.threads == 1:
self.sampler = ptemcee.Sampler(
dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
**self.sampler_init_kwargs)
dim=self.ndim,
logl=self.log_likelihood,
logp=self.log_prior,
**self.sampler_init_kwargs
)
else:
self.sampler = ptemcee.Sampler(
dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
dim=self.ndim,
logl=do_nothing_function,
logp=do_nothing_function,
pool=self.pool,
threads=self.threads,
**self.sampler_init_kwargs
)
self.sampler._likeprior = LikePriorEvaluator(
self.search_parameter_keys, use_ratio=self.use_ratio)
self.search_parameter_keys, use_ratio=self.use_ratio
)
# Set up empty lists
# Initialize storing results
self.iteration = 0
self.chain_array = self.get_zero_chain_array()
self.log_likelihood_array = self.get_zero_log_likelihood_array()
self.tau_list = []
self.tau_list_n = []
self.time_per_check = []
self.pos0 = self.get_pos0()
return self.sampler
# Initialize the walker positions
pos0 = self.get_pos0()
def get_zero_chain_array(self):
return np.zeros((self.nwalkers, self.max_steps, self.ndim))
return self.sampler, pos0
def get_zero_log_likelihood_array(self):
return np.zeros((self.ntemps, self.nwalkers, self.max_steps))
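These helpers pre-allocate chain storage in blocks of max_steps; when a block fills, run_sampler_internal() grows it by concatenating a fresh zero block. A standalone sketch of that strategy (variable names are illustrative, not from the diff):

    import numpy as np
    nwalkers, max_steps, ndim = 200, 500, 4
    chain = np.zeros((nwalkers, max_steps, ndim))
    iteration = max_steps  # the current block is full
    if iteration == chain.shape[1]:
        chain = np.concatenate((chain, np.zeros((nwalkers, max_steps, ndim))), axis=1)
    assert chain.shape == (200, 1000, 4)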
def get_pos0(self):
""" Master logic for setting pos0 """
if isinstance(self.pos0, str) and self.pos0.lower() == "prior":
return self.get_pos0_from_prior()
elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize":
@@ -200,108 +391,61 @@ class Ptemcee(MCMCSampler):
raise SamplerError("pos0={} not implemented".format(self.pos0))
def setup_pool(self):
""" If threads > 1, setup a MultiPool, else run in serial mode """
if self.threads > 1:
import schwimmbad
logger.info("Creating MultiPool with {} processes".format(self.threads))
self.pool = schwimmbad.MultiPool(
self.threads,
initializer=init,
initargs=(self.likelihood, self.priors))
self.threads, initializer=init, initargs=(self.likelihood, self.priors)
)
else:
self.pool = None
def run_sampler(self):
self.setup_pool()
out = self.run_sampler_internal()
if self.pool:
self.pool.close()
return out
def run_sampler_internal(self):
import emcee
sampler, pos0 = self.setup_sampler()
sampler = self.setup_sampler()
t0 = datetime.datetime.now()
logger.info("Starting to sample")
while True:
for (pos0, _, _) in sampler.sample(pos0, **self.sampler_function_kwargs):
for (pos0, log_posterior, log_likelihood) in sampler.sample(
self.pos0, storechain=False, **self.sampler_function_kwargs):
pass
if self.iteration == self.chain_array.shape[1]:
self.chain_array = np.concatenate((
self.chain_array, self.get_zero_chain_array()), axis=1)
self.log_likelihood_array = np.concatenate((
self.log_likelihood_array, self.get_zero_log_likelihood_array()),
axis=2)
self.pos0 = pos0
self.chain_array[:, self.iteration, :] = pos0[0, :, :]
self.log_likelihood_array[:, :, self.iteration] = log_likelihood
# Calculate time per iteration
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
t0 = datetime.datetime.now()
# Compute ACT tau for 0-temperature chains
samples = sampler.chain[0, :, : sampler.time, :]
taus = []
for ii in range(sampler.nwalkers):
for jj, key in enumerate(self.search_parameter_keys):
if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key:
continue
try:
taus.append(
emcee.autocorr.integrated_time(
samples[ii, :, jj], c=self.autocorr_c, tol=0
)[0]
)
except emcee.autocorr.AutocorrError:
taus.append(np.inf)
# Apply multiplicative safety factor
tau = self.safety * np.mean(taus)
# Store for convergence checking and plotting
self.tau_list.append(tau)
self.tau_list_n.append(sampler.time)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
self.sampler,
self.time_per_check,
self.nsamples,
np.nan,
np.nan,
np.nan,
[np.nan],
False)
continue
# Calculate the effective number of samples available
self.nburn = int(self.burn_in_nact * tau_int)
self.thin = int(np.max([1, self.thin_by_nact * tau_int]))
samples_per_check = sampler.nwalkers / self.thin
self.nsamples_effective = int(sampler.nwalkers * (sampler.time - self.nburn) / self.thin)
# Calculate convergence boolean
converged = self.nsamples < self.nsamples_effective
# Calculate fractional change in tau from previous iterations
check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau:])
if not np.any(np.isnan(check_taus)):
frac = (tau - check_taus) / tau
tau_usable = np.all(frac < self.frac_threshold)
else:
tau_usable = False
if sampler.time < tau_int * self.autocorr_tol or tau_int < self.min_tau:
tau_usable = False
# Print an update on the progress
print_progress(
self.sampler,
self.time_per_check,
self.nsamples,
(
stop,
self.nburn,
self.thin,
self.tau_int,
self.nsamples_effective,
samples_per_check,
tau_int,
check_taus,
tau_usable,
) = check_iteration(
self.chain_array[:, :self.iteration + 1, :],
sampler,
self.convergence_inputs,
self.search_parameter_keys,
self.time_per_check,
self.tau_list,
self.tau_list_n,
)
if converged and tau_usable:
self.iteration += 1
if stop:
logger.info("Finished sampling")
break
@@ -318,11 +462,11 @@ class Ptemcee(MCMCSampler):
self.write_current_state(plot=self.check_point_plot)
# Get zero-temperature samples and store in the result
samples = sampler.chain[0, :, :, :] # nwalkers, nsteps, ndim
self.result.samples = (
samples[:, self.nburn: sampler.time:self.thin, :].reshape((-1, self.ndim)))
loglikelihood = sampler.loglikelihood[
0, :, self.nburn:sampler.time:self.thin
self.result.samples = self.chain_array[
:, self.nburn : self.iteration : self.thin, :
].reshape((-1, self.ndim))
loglikelihood = self.log_likelihood_array[
0, :, self.nburn : self.iteration : self.thin
] # nwalkers, nsteps
self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
@@ -331,50 +475,189 @@ class Ptemcee(MCMCSampler):
self.result.nburn = self.nburn
log_evidence, log_evidence_err = compute_evidence(
sampler, self.outdir, self.label, self.nburn, self.thin
sampler, self.log_likelihood_array, self.outdir, self.label, self.nburn,
self.thin, self.iteration,
)
self.result.log_evidence = log_evidence
self.result.log_evidence_err = log_evidence_err
self.result.sampling_time = datetime.timedelta(seconds=np.sum(self.time_per_check))
self.result.sampling_time = datetime.timedelta(
seconds=np.sum(self.time_per_check)
)
if self.pool:
self.pool.close()
return self.result
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
if getattr(self, 'pool', None):
if getattr(self, "pool", None) or self.threads == 1:
self.write_current_state(plot=False)
logger.warning("Closing pool")
if getattr(self, "pool", None):
logger.info("Closing pool")
self.pool.close()
sys.exit(self.exit_code)
def write_current_state(self, plot=True):
checkpoint(self.outdir, self.label, self.nsamples_effective,
self.sampler, self.nburn, self.thin,
self.search_parameter_keys, self.resume_file, self.tau_list,
self.tau_list_n, self.time_per_check)
checkpoint(
self.iteration,
self.outdir,
self.label,
self.nsamples_effective,
self.sampler,
self.nburn,
self.thin,
self.search_parameter_keys,
self.resume_file,
self.log_likelihood_array,
self.chain_array,
self.pos0,
self.tau_list,
self.tau_list_n,
self.time_per_check,
)
if plot and not np.isnan(self.nburn):
# Generate the walkers plot diagnostic
plot_walkers(
self.sampler.chain[0, :, : self.sampler.time, :],
self.nburn, self.thin, self.search_parameter_keys, self.outdir,
self.label
self.chain_array[:, : self.iteration, :],
self.nburn,
self.thin,
self.search_parameter_keys,
self.outdir,
self.label,
)
# Generate the tau plot diagnostic
plot_tau(self.tau_list_n, self.tau_list, self.outdir, self.label,
self.autocorr_tau)
plot_tau(
self.tau_list_n,
self.tau_list,
self.search_parameter_keys,
self.outdir,
self.label,
self.tau_int,
self.convergence_inputs.autocorr_tau,
)
def check_iteration(
samples,
sampler,
convergence_inputs,
search_parameter_keys,
time_per_check,
tau_list,
tau_list_n,
):
""" Per-iteration logic to calculate the convergence check
Parameters
----------
convergence_inputs: bilby.core.sampler.ptemcee.ConvergenceInputs
A named tuple of the convergence checking inputs
search_parameter_keys: list
A list of the search parameter keys
time_per_check, tau_list, tau_list_n: list
Lists used for tracking the run
Returns
-------
stop: bool
A boolean flag, True if the stopping criterion has been met
burn: int
The number of burn-in steps to discard
thin: int
The thin-by factor to apply
tau_int: int
The integer estimated ACT
nsamples_effective: int
The effective number of samples after burning and thinning
"""
import emcee
ci = convergence_inputs
nwalkers, iteration, ndim = samples.shape
# Compute ACT tau for 0-temperature chains
tau_array = np.zeros((nwalkers, ndim))
for ii in range(nwalkers):
for jj, key in enumerate(search_parameter_keys):
if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key:
continue
try:
tau_array[ii, jj] = emcee.autocorr.integrated_time(
samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0]
except emcee.autocorr.AutocorrError:
tau_array[ii, jj] = np.inf
# Maximum over parameters, mean over walkers
tau = np.max(np.mean(tau_array, axis=0))
# Apply multiplicative safety factor
tau = ci.safety * tau
# Store for convergence checking and plotting
tau_list.append(list(np.mean(tau_array, axis=0)))
tau_list_n.append(iteration)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
iteration, sampler, time_per_check, ci.nsamples, np.nan, np.nan, np.nan, np.nan, False,
)
return False, np.nan, np.nan, np.nan, np.nan
# Calculate the effective number of samples available
nburn = int(ci.burn_in_nact * tau_int)
thin = int(np.max([1, ci.thin_by_nact * tau_int]))
samples_per_check = nwalkers / thin
nsamples_effective = int(nwalkers * (iteration - nburn) / thin)
# Calculate convergence boolean
converged = ci.nsamples < nsamples_effective
# Calculate fractional change in tau from previous iteration
check_taus = np.array(tau_list[-tau_int * ci.autocorr_tau :])
taus_per_parameter = check_taus[-1, :]
if not np.any(np.isnan(check_taus)):
frac = (taus_per_parameter - check_taus) / taus_per_parameter
max_frac = np.max(frac)
tau_usable = np.all(frac < ci.frac_threshold)
else:
max_frac = np.nan
tau_usable = False
if iteration < tau_int * ci.autocorr_tol or tau_int < ci.min_tau:
tau_usable = False
# Print an update on the progress
print_progress(
iteration,
sampler,
time_per_check,
ci.nsamples,
nsamples_effective,
samples_per_check,
tau_int,
max_frac,
tau_usable,
)
stop = converged and tau_usable
return stop, nburn, thin, tau_int, nsamples_effective
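A worked sketch of the arithmetic above with toy numbers: tau_int = 10, burn_in_nact = 50, thin_by_nact = 1, nwalkers = 200 and iteration = 1000 give nburn = 500, thin = 10 and nsamples_effective = 200 * (1000 - 500) / 10 = 10000. The tau estimate itself comes from emcee, for example:

    import emcee
    import numpy as np
    chain = np.cumsum(np.random.randn(5000))  # toy correlated series (random walk)
    tau = emcee.autocorr.integrated_time(chain, c=5, tol=0)[0]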
def print_progress(
iteration,
sampler,
time_per_check,
nsamples,
nsamples_effective,
samples_per_check,
tau_int,
tau_list,
max_frac,
tau_usable,
):
# Setup acceptance string
@@ -388,9 +671,7 @@ def print_progress(
)
ave_time_per_check = np.mean(time_per_check[-3:])
time_left = (
(nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
)
time_left = (nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
if time_left > 0:
time_left = str(datetime.timedelta(seconds=int(time_left)))
else:
@@ -398,7 +679,10 @@ def print_progress(
sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
tau_str = "{}:{:0.1f}->{:0.1f}".format(tau_int, np.min(tau_list), np.max(tau_list))
if max_frac >= 0:
tau_str = "{}(+{:0.1f})".format(tau_int, max_frac)
else:
tau_str = "{}({:0.1f})".format(tau_int, max_frac)
if tau_usable:
tau_str = "={}".format(tau_str)
else:
@@ -406,13 +690,13 @@ def print_progress(
evals_per_check = sampler.nwalkers * sampler.ntemps
ncalls = "{:1.1e}".format(sampler.time * sampler.nwalkers * sampler.ntemps)
ncalls = "{:1.1e}".format(iteration * sampler.nwalkers * sampler.ntemps)
eval_timing = "{:1.1f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)
print(
"{}| {} | nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
sampler.time,
iteration,
str(sampling_time).split(".")[0],
ncalls,
acceptance_str,
@@ -427,16 +711,30 @@ def print_progress(
)
def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
search_parameter_keys, resume_file, tau_list, tau_list_n,
time_per_check):
def checkpoint(
iteration,
outdir,
label,
nsamples_effective,
sampler,
nburn,
thin,
search_parameter_keys,
resume_file,
log_likelihood_array,
chain_array,
pos0,
tau_list,
tau_list_n,
time_per_check,
):
logger.info("Writing checkpoint and diagnostics")
ndim = sampler.dim
# Store the samples if possible
if nsamples_effective > 0:
filename = "{}/{}_samples.txt".format(outdir, label)
samples = sampler.chain[0, :, nburn:sampler.time:thin, :].reshape(
samples = np.array(chain_array)[:, nburn : iteration : thin, :].reshape(
(-1, ndim)
)
df = pd.DataFrame(samples, columns=search_parameter_keys)
@@ -445,14 +743,17 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
# Pickle the resume artefacts
sampler_copy = copy.copy(sampler)
del sampler_copy.pool
sampler_copy._chain = sampler._chain[:, :, : sampler.time, :]
sampler_copy._logposterior = sampler._logposterior[:, :, : sampler.time]
sampler_copy._loglikelihood = sampler._loglikelihood[:, :, : sampler.time]
sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
data = dict(
sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n,
time_per_check=time_per_check)
iteration=iteration,
sampler=sampler_copy,
tau_list=tau_list,
tau_list_n=tau_list_n,
time_per_check=time_per_check,
log_likelihood_array=log_likelihood_array,
chain_array=chain_array,
pos0=pos0,
)
with open(resume_file, "wb") as file:
dill.dump(data, file, protocol=4)
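Reading the checkpoint back mirrors the resume branch of setup_sampler(); a minimal sketch, assuming the file-name pattern set in __init__:

    import dill
    with open("outdir/label_checkpoint_resume.pickle", "rb") as file:
        data = dill.load(file)
    iteration = data["iteration"]
    chain_array = data["chain_array"]  # shape (nwalkers, max_steps, ndim)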
@@ -465,16 +766,24 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
nwalkers, nsteps, ndim = walkers.shape
idxs = np.arange(nsteps)
fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
scatter_kwargs = dict(lw=0, marker="o", markersize=1)
scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05,)
# Plot the burn-in
for i, (ax, axh) in enumerate(axes):
ax.plot(
idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="C1", **scatter_kwargs
idxs[: nburn + 1],
walkers[:, : nburn + 1, i].T,
color="C1",
**scatter_kwargs
)
# Plot the thinned posterior samples
for i, (ax, axh) in enumerate(axes):
ax.plot(idxs[nburn::thin], walkers[:, nburn::thin, i].T, color="C0", **scatter_kwargs)
ax.plot(
idxs[nburn::thin],
walkers[:, nburn::thin, i].T,
color="C0",
**scatter_kwargs
)
axh.hist(walkers[:, nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
axh.set_xlabel(parameter_labels[i])
ax.set_ylabel(parameter_labels[i])
@@ -485,24 +794,26 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
plt.close(fig)
def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
def plot_tau(
tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau
):
fig, ax = plt.subplots()
ax.plot(tau_list_n, tau_list, "-", color='C1')
check_tau_idx = -int(tau_list[-1] * autocorr_tau)
check_taus = tau_list[check_tau_idx:]
check_taus_n = tau_list_n[check_tau_idx:]
ax.plot(check_taus_n, check_taus, "-", color='C0')
for i, key in enumerate(search_parameter_keys):
ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
ax.axvline(tau_list_n[-1] - tau * autocorr_tau)
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
ax.legend()
fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
plt.close(fig)
def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
def compute_evidence(sampler, log_likelihood_array, outdir, label, nburn, thin,
iteration, make_plots=True):
""" Computes the evidence using thermodynamic integration """
betas = sampler.betas
# We compute the evidence without the burn-in samples, but we do not thin
lnlike = sampler.loglikelihood[:, :, nburn:sampler.time]
lnlike = log_likelihood_array[:, :, nburn : iteration]
mean_lnlikes = np.mean(np.mean(lnlike, axis=1), axis=1)
mean_lnlikes = mean_lnlikes[::-1]
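The quantity being built here is the thermodynamic-integration estimate ln Z = integral from beta=0 to 1 of <ln L>_beta d(beta). A toy sketch of the quadrature over inverse temperatures (values invented purely for illustration):

    import numpy as np
    betas_ascending = np.array([0.0, 0.25, 0.5, 1.0])
    mean_lnlikes_ascending = np.array([-120.0, -80.0, -60.0, -50.0])  # toy values
    log_evidence = np.trapz(mean_lnlikes_ascending, betas_ascending)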
@@ -535,8 +846,7 @@ def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
ax2.semilogx(min_betas, evidence, "-o")
ax2.set_ylabel(
r"$\int_{\beta_{min}}^{\beta=1}"
+ r"\langle \log(\mathcal{L})\rangle d\beta$",
r"$\int_{\beta_{min}}^{\beta=1}" + r"\langle \log(\mathcal{L})\rangle d\beta$",
size=16,
)
ax2.set_xlabel(r"$\beta_{min}$")
@@ -590,14 +900,14 @@ class LikePriorEvaluator(object):
def __call__(self, x):
lp = self.logp(x)
if np.isnan(lp):
raise ValueError('Prior function returned NaN.')
raise ValueError("Prior function returned NaN.")
if lp == float('-inf'):
if lp == float("-inf"):
# Can't return -inf, since this messes with beta=0 behaviour.
ll = 0
else:
ll = self.logl(x)
if np.isnan(ll).any():
raise ValueError('Log likelihood function returned NaN.')
raise ValueError("Log likelihood function returned NaN.")
return ll, lp
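The returned pair feeds ptemcee's tempered posterior, log p_beta(x) = beta * logl(x) + logp(x); substituting ll = 0 outside the prior support avoids the indeterminate 0 * (-inf) at beta = 0, where lp = -inf already excludes the point. A sketch (illustrative only; ptemcee combines the two terms internally):

    def tempered_log_posterior(beta, ll, lp):
        # beta = 0 recovers the prior, beta = 1 the full posterior
        return beta * ll + lp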