Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
@@ -7,6 +7,7 @@ import signal
import sys
import time
import dill
from collections import namedtuple
import numpy as np
import pandas as pd
@@ -16,6 +17,24 @@ from ..utils import logger
from .base_sampler import SamplerError, MCMCSampler

ConvergenceInputs = namedtuple(
    "ConvergenceInputs",
    [
        "autocorr_c",
        "autocorr_tol",
        "autocorr_tau",
        "safety",
        "burn_in_nact",
        "thin_by_nact",
        "frac_threshold",
        "nsamples",
        "ignore_keys_for_tau",
        "min_tau",
        "niterations_per_check",
    ],
)


class Ptemcee(MCMCSampler):
"""bilby wrapper ptemcee (https://github.com/willvousden/ptemcee)
"""bilby wrapper ptemcee (https://github.com/willvousden/ptemcee)
@@ -24,101 +43,246 @@ class Ptemcee(MCMCSampler):
@@ -24,101 +43,246 @@ class Ptemcee(MCMCSampler):
documentation for that class for further help. Under Other Parameters, we
documentation for that class for further help. Under Other Parameters, we
list commonly used kwargs and the bilby defaults.
list commonly used kwargs and the bilby defaults.
 
Parameters
 
----------
 
nsamples: int, (5000)
 
The requested number of samples. Note, in cases where the
 
autocorrelation parameter is difficult to measure, it is possible to
 
end up with more than nsamples.
 
burn_in_act, thin_by_nact: int, (50, 1)
 
The number of burn-in autocorrelation times to discard and the thin-by
 
factor. Increasing burn_in_act increases the time required for burn-in.
 
Increasing thin_by_nact increases the time required to obtain nsamples.
 
autocorr_tol: int, (50)
 
The minimum number of autocorrelation times needed to trust the
 
estimate of the autocorrelation time.
 
autocorr_c: int, (5)
 
The step size for the window search used by emcee.autocorr.integrated_time
 
safety: int, (1)
 
A multiplicitive factor for the estimated autocorrelation. Useful for
 
cases where non-convergence can be observed by eye but the automated
 
tools are failing.
 
autocorr_tau:
 
The number of autocorrelation times to use in assessing if the
 
autocorrelation time is stable.
 
frac_threshold: float, (0.01)
 
The maximum fractional change in the autocorrelation for the last
 
autocorr_tau steps. If the fractional change exceeds this value,
 
sampling will continue until the estimate of the autocorrelation time
 
can be trusted.
 
min_tau: int, (1)
 
A minimum tau (autocorrelation time) to accept.
 
check_point_deltaT: float, (600)
 
The period with which to checkpoint (in seconds).
 
threads: int, (1)
 
If threads > 1, a MultiPool object is setup and used.
 
exit_code: int, (77)
 
The code on which the sampler exits.
 
store_walkers: bool (False)
 
If true, store the unthinned, unburnt chaines in the result. Note, this
 
is not recommended for cases where tau is large.
 
ignore_keys_for_tau: str
 
A pattern used to ignore keys in estimating the autocorrelation time.
 
pos0: str, list ("prior")
 
If a string, one of "prior" or "minimize". For "prior", the initial
 
positions of the sampler are drawn from the sampler. If "minimize",
 
a scipy.optimize step is applied to all parameters a number of times.
 
The walkers are then initialized from the range of values obtained.
 
If a list, for the keys in the list the optimization step is applied,
 
otherwise the initial points are drawn from the prior.
 
 
Other Parameters
Other Parameters
----------------
----------------
nwalkers: int, (100)
nwalkers: int, (200)
The number of walkers
The number of walkers
nsteps: int, (100)
nsteps: int, (100)
The number of steps to take
The number of steps to take
nburn: int (50)
The fixed number of steps to discard as burn-in
ntemps: int (2)
ntemps: int (2)
The number of temperatures used by ptemcee
The number of temperatures used by ptemcee
 
Tmax: float
 
The maximum temperature
"""
"""
 
    # Arguments used by ptemcee
    default_kwargs = dict(
        ntemps=20,
        nwalkers=200,
        Tmax=None,
        betas=None,
        a=2.0,
        adaptation_lag=10000,
        adaptation_time=100,
        random=None,
        adapt=True,
        swap_ratios=False,
    )

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        check_point_plot=True,
        skip_import_verification=False,
        resume=True,
        nsamples=5000,
        burn_in_nact=50,
        thin_by_nact=1,
        autocorr_tol=50,
        autocorr_c=5,
        safety=1,
        autocorr_tau=5,
        frac_threshold=0.01,
        min_tau=1,
        check_point_deltaT=600,
        threads=1,
        exit_code=77,
        plot=False,
        store_walkers=False,
        ignore_keys_for_tau=None,
        pos0="prior",
        niterations_per_check=10,
        **kwargs
    ):
        super(Ptemcee, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            **kwargs
        )

        self.nwalkers = self.sampler_init_kwargs["nwalkers"]
        self.ntemps = self.sampler_init_kwargs["ntemps"]
        self.max_steps = 500
 
        # Set up signal handling
        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)

        # Checkpointing inputs
        self.exit_code = exit_code
        self.resume = resume
        self.check_point_deltaT = check_point_deltaT
        self.check_point_plot = check_point_plot
        self.resume_file = "{}/{}_checkpoint_resume.pickle".format(
            self.outdir, self.label
        )

        # Store convergence checking inputs in a named tuple
        convergence_inputs_dict = dict(
            autocorr_c=autocorr_c,
            autocorr_tol=autocorr_tol,
            autocorr_tau=autocorr_tau,
            safety=safety,
            burn_in_nact=burn_in_nact,
            thin_by_nact=thin_by_nact,
            frac_threshold=frac_threshold,
            nsamples=nsamples,
            ignore_keys_for_tau=ignore_keys_for_tau,
            min_tau=min_tau,
            niterations_per_check=niterations_per_check,
        )
        self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)

        # MultiProcessing inputs
        self.threads = threads

        # Misc inputs
        self.store_walkers = store_walkers
        self.pos0 = pos0
    @property
    def sampler_function_kwargs(self):
        """ Kwargs passed to sampler.sample() """
        keys = ["adapt", "swap_ratios"]
        return {key: self.kwargs[key] for key in keys}

    @property
    def sampler_init_kwargs(self):
        """ Kwargs passed to initialize ptemcee.Sampler() """
        return {
            key: value
            for key, value in self.kwargs.items()
            if key not in self.sampler_function_kwargs
        }

    def _translate_kwargs(self, kwargs):
        """ Translate kwargs """
        if "nwalkers" not in kwargs:
            for equiv in self.nwalkers_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["nwalkers"] = kwargs.pop(equiv)
    def get_pos0_from_prior(self):
        """ Draw the initial positions from the prior

        Returns
        -------
        pos0: list
            The initial positions of the walkers, with shape (ntemps, nwalkers, ndim)

        """
        logger.info("Generating pos0 samples")
        return [
            [
                self.get_random_draw_from_prior()
                for _ in range(self.nwalkers)
            ]
            for _ in range(self.kwargs["ntemps"])
        ]
    def get_pos0_from_minimize(self, minimize_list=None):
        """ Draw the initial positions using an initial minimization step

        See pos0 in the class initialization for details.

        Returns
        -------
        pos0: list
            The initial positions of the walkers, with shape (ntemps, nwalkers, ndim)

        """
        from scipy.optimize import minimize

        # Set up the minimize list: keys not in this list will have initial
        # positions drawn from the prior
        if minimize_list is None:
            minimize_list = self.search_parameter_keys
            pos0 = np.zeros((self.kwargs["ntemps"], self.kwargs["nwalkers"], self.ndim))
        else:
            pos0 = np.array(self.get_pos0_from_prior())

        logger.info("Attempting to set pos0 for {} from minimize".format(minimize_list))

        likelihood_copy = copy.copy(self.likelihood)

        def neg_log_like(params):
            """ Internal function to minimize """
            likelihood_copy.parameters.update(
                {key: val for key, val in zip(minimize_list, params)}
            )
            try:
                return -likelihood_copy.log_likelihood()
            except RuntimeError:
                return +np.inf

        # Bounds used in the minimization
        bounds = [
            (self.priors[key].minimum, self.priors[key].maximum)
            for key in minimize_list
        ]

        # Run the minimization step several times to get a range of values
        trials = 0
        success = []
        while True:
@@ -126,7 +290,8 @@ class Ptemcee(MCMCSampler):
            likelihood_copy.parameters.update(draw)
            x0 = [draw[key] for key in minimize_list]
            res = minimize(
                neg_log_like, x0, bounds=bounds, method="L-BFGS-B", tol=1e-15
            )
            if res.success:
                success.append(res.x)
            if trials > 100:
@@ -134,62 +299,94 @@ class Ptemcee(MCMCSampler):
            if len(success) >= 10:
                break

        # Initialize positions from the range of values
        success = np.array(success)
        for i, key in enumerate(minimize_list):
            pos0_min = np.min(success[:, i])
            pos0_max = np.max(success[:, i])
            logger.info(
                "Initialize {} walkers from {}->{}".format(key, pos0_min, pos0_max)
            )
            j = self.search_parameter_keys.index(key)
            pos0[:, :, j] = np.random.uniform(
                pos0_min,
                pos0_max,
                size=(self.kwargs["ntemps"], self.kwargs["nwalkers"]),
            )
        return pos0
    def setup_sampler(self):
        """ Either initialize the sampler or read in the resume file """
        import ptemcee

        if os.path.isfile(self.resume_file) and self.resume is True:
            logger.info("Resume data {} found".format(self.resume_file))
            with open(self.resume_file, "rb") as file:
                data = dill.load(file)

            # Extract the check-point data
            self.sampler = data["sampler"]
            self.iteration = data["iteration"]
            self.chain_array = data["chain_array"]
            self.log_likelihood_array = data["log_likelihood_array"]
            self.pos0 = data["pos0"]
            self.beta_list = data["beta_list"]
            self.sampler._betas = np.array(self.beta_list[-1])
            self.tau_list = data["tau_list"]
            self.tau_list_n = data["tau_list_n"]
            self.time_per_check = data["time_per_check"]

            # Initialize the pool
            self.sampler.pool = self.pool
            self.sampler.threads = self.threads

            logger.info(
                "Resuming from previous run with time={}".format(self.iteration)
            )

        else:
            # Initialize the PTSampler
            if self.threads == 1:
                self.sampler = ptemcee.Sampler(
                    dim=self.ndim,
                    logl=self.log_likelihood,
                    logp=self.log_prior,
                    **self.sampler_init_kwargs
                )
            else:
                self.sampler = ptemcee.Sampler(
                    dim=self.ndim,
                    logl=do_nothing_function,
                    logp=do_nothing_function,
                    pool=self.pool,
                    threads=self.threads,
                    **self.sampler_init_kwargs
                )
            self.sampler._likeprior = LikePriorEvaluator(
                self.search_parameter_keys, use_ratio=self.use_ratio
            )

            # Initialize storing results
            self.iteration = 0
            self.chain_array = self.get_zero_chain_array()
            self.log_likelihood_array = self.get_zero_log_likelihood_array()
            self.beta_list = []
            self.tau_list = []
            self.tau_list_n = []
            self.time_per_check = []
            self.pos0 = self.get_pos0()

        return self.sampler

    def get_zero_chain_array(self):
        return np.zeros((self.nwalkers, self.max_steps, self.ndim))

    def get_zero_log_likelihood_array(self):
        return np.zeros((self.ntemps, self.nwalkers, self.max_steps))
    def get_pos0(self):
        """ Master logic for setting pos0 """
        if isinstance(self.pos0, str) and self.pos0.lower() == "prior":
            return self.get_pos0_from_prior()
        elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize":
@@ -200,108 +397,65 @@ class Ptemcee(MCMCSampler):
            raise SamplerError("pos0={} not implemented".format(self.pos0))
    def setup_pool(self):
        """ If threads > 1, setup a MultiPool, else run in serial mode """
        if self.threads > 1:
            import schwimmbad

            logger.info("Creating MultiPool with {} processes".format(self.threads))
            self.pool = schwimmbad.MultiPool(
                self.threads, initializer=init, initargs=(self.likelihood, self.priors)
            )
        else:
            self.pool = None
    def run_sampler(self):
        self.setup_pool()
        sampler = self.setup_sampler()
        t0 = datetime.datetime.now()
        logger.info("Starting to sample")

        while True:
            for (pos0, log_posterior, log_likelihood) in sampler.sample(
                    self.pos0, storechain=False,
                    iterations=self.convergence_inputs.niterations_per_check,
                    **self.sampler_function_kwargs):
                pass

            if self.iteration == self.chain_array.shape[1]:
                self.chain_array = np.concatenate((
                    self.chain_array, self.get_zero_chain_array()), axis=1)
                self.log_likelihood_array = np.concatenate((
                    self.log_likelihood_array, self.get_zero_log_likelihood_array()),
                    axis=2)

            self.pos0 = pos0
            self.chain_array[:, self.iteration, :] = pos0[0, :, :]
            self.log_likelihood_array[:, :, self.iteration] = log_likelihood

            # Calculate time per iteration
            self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
            t0 = datetime.datetime.now()

            self.iteration += 1

            (
                stop,
                self.nburn,
                self.thin,
                self.tau_int,
                self.nsamples_effective,
            ) = check_iteration(
                self.chain_array[:, :self.iteration + 1, :],
                sampler,
                self.convergence_inputs,
                self.search_parameter_keys,
                self.time_per_check,
                self.beta_list,
                self.tau_list,
                self.tau_list_n,
            )

            if stop:
                logger.info("Finished sampling")
                break
@@ -318,11 +472,11 @@ class Ptemcee(MCMCSampler):
        self.write_current_state(plot=self.check_point_plot)

        # Get 0-likelihood samples and store in the result
        self.result.samples = self.chain_array[
            :, self.nburn : self.iteration : self.thin, :
        ].reshape((-1, self.ndim))
        loglikelihood = self.log_likelihood_array[
            0, :, self.nburn : self.iteration : self.thin
        ]  # nwalkers, nsteps
        self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
@@ -331,51 +485,195 @@ class Ptemcee(MCMCSampler):
        self.result.nburn = self.nburn
        log_evidence, log_evidence_err = compute_evidence(
            sampler, self.log_likelihood_array, self.outdir, self.label, self.nburn,
            self.thin, self.iteration,
        )
        self.result.log_evidence = log_evidence
        self.result.log_evidence_err = log_evidence_err
        self.result.sampling_time = datetime.timedelta(
            seconds=np.sum(self.time_per_check)
        )

        if self.pool:
            self.pool.close()

        return self.result
    def write_current_state_and_exit(self, signum=None, frame=None):
        logger.warning("Run terminated with signal {}".format(signum))
        if getattr(self, "pool", None) or self.threads == 1:
            self.write_current_state(plot=False)
        if getattr(self, "pool", None):
            logger.info("Closing pool")
            self.pool.close()
        logger.info("Exit on signal {}".format(self.exit_code))
        sys.exit(self.exit_code)

    def write_current_state(self, plot=True):
        checkpoint(
            self.iteration,
            self.outdir,
            self.label,
            self.nsamples_effective,
            self.sampler,
            self.nburn,
            self.thin,
            self.search_parameter_keys,
            self.resume_file,
            self.log_likelihood_array,
            self.chain_array,
            self.pos0,
            self.beta_list,
            self.tau_list,
            self.tau_list_n,
            self.time_per_check,
        )

        if plot and not np.isnan(self.nburn):
            # Generate the walkers plot diagnostic
            plot_walkers(
                self.chain_array[:, : self.iteration, :],
                self.nburn,
                self.thin,
                self.search_parameter_keys,
                self.outdir,
                self.label,
            )

            # Generate the tau plot diagnostic
            plot_tau(
                self.tau_list_n,
                self.tau_list,
                self.search_parameter_keys,
                self.outdir,
                self.label,
                self.tau_int,
                self.convergence_inputs.autocorr_tau,
            )
 
 
 
def check_iteration(
    samples,
    sampler,
    convergence_inputs,
    search_parameter_keys,
    time_per_check,
    beta_list,
    tau_list,
    tau_list_n,
):
    """ Per-iteration logic to calculate the convergence check

    Parameters
    ----------
    convergence_inputs: bilby.core.sampler.ptemcee.ConvergenceInputs
        A named tuple of the convergence checking inputs
    search_parameter_keys: list
        A list of the search parameter keys
    time_per_check, tau_list, tau_list_n: list
        Lists used for tracking the run

    Returns
    -------
    stop: bool
        A boolean flag, True if the stopping criteria have been met
    burn: int
        The number of burn-in steps to discard
    thin: int
        The thin-by factor to apply
    tau_int: int
        The integer estimated ACT
    nsamples_effective: int
        The effective number of samples after burning and thinning
    """
    import emcee

    ci = convergence_inputs
    nwalkers, iteration, ndim = samples.shape

    # Compute ACT tau for 0-temperature chains
    tau_array = np.zeros((nwalkers, ndim))
    for ii in range(nwalkers):
        for jj, key in enumerate(search_parameter_keys):
            if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key:
                continue
            try:
                tau_array[ii, jj] = emcee.autocorr.integrated_time(
                    samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0]
            except emcee.autocorr.AutocorrError:
                tau_array[ii, jj] = np.inf

    # Maximum over parameters, mean over walkers
    tau = np.max(np.mean(tau_array, axis=0))

    # Apply multiplicative safety factor
    tau = ci.safety * tau

    # Store for convergence checking and plotting
    beta_list.append(list(sampler.betas))
    tau_list.append(list(np.mean(tau_array, axis=0)))
    tau_list_n.append(iteration)

    # Convert to an integer
    tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau

    if np.isnan(tau_int) or np.isinf(tau_int):
        print_progress(
            iteration, sampler, time_per_check, np.nan, np.nan,
            np.nan, np.nan, False, convergence_inputs,
        )
        return False, np.nan, np.nan, np.nan, np.nan

    # Calculate the effective number of samples available
    nburn = int(ci.burn_in_nact * tau_int)
    thin = int(np.max([1, ci.thin_by_nact * tau_int]))
    samples_per_check = nwalkers / thin
    nsamples_effective = int(nwalkers * (iteration - nburn) / thin)

    # Calculate convergence boolean
    converged = ci.nsamples < nsamples_effective

    # Calculate fractional change in tau from previous iterations
    check_taus = np.array(tau_list[-tau_int * ci.autocorr_tau :])
    taus_per_parameter = check_taus[-1, :]
    if not np.any(np.isnan(check_taus)):
        frac = (taus_per_parameter - check_taus) / taus_per_parameter
        max_frac = np.max(frac)
        tau_usable = np.all(frac < ci.frac_threshold)
    else:
        max_frac = np.nan
        tau_usable = False

    if iteration < tau_int * ci.autocorr_tol or tau_int < ci.min_tau:
        tau_usable = False

    # Print an update on the progress
    print_progress(
        iteration,
        sampler,
        time_per_check,
        nsamples_effective,
        samples_per_check,
        tau_int,
        max_frac,
        tau_usable,
        convergence_inputs,
    )
    stop = converged and tau_usable
    return stop, nburn, thin, tau_int, nsamples_effective
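(Review note: the tau estimate above is built on emcee's integrated autocorrelation
time. A standalone toy illustration of that building block on a synthetic AR(1)
chain; nothing here is bilby-specific:)

    import numpy as np
    import emcee

    # Build a correlated chain with a known-ish autocorrelation scale
    rng = np.random.default_rng(42)
    chain = np.zeros(5000)
    for i in range(1, len(chain)):
        chain[i] = 0.9 * chain[i - 1] + rng.normal()

    # tol=0 disables emcee's internal chain-length check so an estimate is always
    # returned; the wrapper then applies its own autocorr_tol cut afterwards
    tau = emcee.autocorr.integrated_time(chain, c=5, tol=0)[0]
    print("estimated integrated autocorrelation time:", tau)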
def print_progress(
    iteration,
    sampler,
    time_per_check,
    nsamples_effective,
    samples_per_check,
    tau_int,
    max_frac,
    tau_usable,
    convergence_inputs,
):
    # Setup acceptance string
    acceptance = sampler.acceptance_fraction[0, :]
@@ -388,9 +686,7 @@ def print_progress(
    )
    ave_time_per_check = np.mean(time_per_check[-3:])
    time_left = (convergence_inputs.nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
    if time_left > 0:
        time_left = str(datetime.timedelta(seconds=int(time_left)))
    else:
@@ -398,27 +694,31 @@ def print_progress(
    sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))

    if max_frac >= 0:
        tau_str = "{}(+{:0.1f})".format(tau_int, max_frac)
    else:
        tau_str = "{}({:0.1f})".format(tau_int, max_frac)
    if tau_usable:
        tau_str = "={}".format(tau_str)
    else:
        tau_str = "!{}".format(tau_str)

    evals_per_check = sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check
    ncalls = "{:1.1e}".format(
        convergence_inputs.niterations_per_check * iteration * sampler.nwalkers * sampler.ntemps)
    eval_timing = "{:1.2f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
    samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)

    print(
        "{}| {}| nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
            iteration,
            str(sampling_time).split(".")[0],
            ncalls,
            acceptance_str,
            tswap_acceptance_str,
            nsamples_effective,
            convergence_inputs.nsamples,
            tau_str,
            eval_timing,
            samp_timing,
@@ -427,16 +727,31 @@ def print_progress(
    )
def checkpoint(
    iteration,
    outdir,
    label,
    nsamples_effective,
    sampler,
    nburn,
    thin,
    search_parameter_keys,
    resume_file,
    log_likelihood_array,
    chain_array,
    pos0,
    beta_list,
    tau_list,
    tau_list_n,
    time_per_check,
):
    logger.info("Writing checkpoint and diagnostics")
    ndim = sampler.dim

    # Store the samples if possible
    if nsamples_effective > 0:
        filename = "{}/{}_samples.txt".format(outdir, label)
        samples = np.array(chain_array)[:, nburn : iteration : thin, :].reshape(
            (-1, ndim)
        )
        df = pd.DataFrame(samples, columns=search_parameter_keys)
@@ -445,14 +760,18 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
    # Pickle the resume artefacts
    sampler_copy = copy.copy(sampler)
    del sampler_copy.pool

    data = dict(
        iteration=iteration,
        sampler=sampler_copy,
        beta_list=beta_list,
        tau_list=tau_list,
        tau_list_n=tau_list_n,
        time_per_check=time_per_check,
        log_likelihood_array=log_likelihood_array,
        chain_array=chain_array,
        pos0=pos0,
    )

    with open(resume_file, "wb") as file:
        dill.dump(data, file, protocol=4)
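(Review note: because the checkpoint is a plain dill pickle, it can be inspected
offline while a run is in progress. A usage sketch, with the path pattern and keys
taken from the data dict above; not itself part of the MR:)

    import dill

    with open("outdir/label_checkpoint_resume.pickle", "rb") as file:
        data = dill.load(file)

    print("iteration:", data["iteration"])
    print("chain_array shape:", data["chain_array"].shape)
    print("latest per-parameter taus:", data["tau_list"][-1])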
@@ -465,16 +784,24 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
    nwalkers, nsteps, ndim = walkers.shape
    idxs = np.arange(nsteps)
    fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
    scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05)

    # Plot the burn-in
    for i, (ax, axh) in enumerate(axes):
        ax.plot(
            idxs[: nburn + 1],
            walkers[:, : nburn + 1, i].T,
            color="C1",
            **scatter_kwargs
        )

    # Plot the thinned posterior samples
    for i, (ax, axh) in enumerate(axes):
        ax.plot(
            idxs[nburn::thin],
            walkers[:, nburn::thin, i].T,
            color="C0",
            **scatter_kwargs
        )
        axh.hist(walkers[:, nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
        axh.set_xlabel(parameter_labels[i])
        ax.set_ylabel(parameter_labels[i])
@@ -485,24 +812,26 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
    plt.close(fig)


def plot_tau(
    tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau
):
    fig, ax = plt.subplots()
    for i, key in enumerate(search_parameter_keys):
        ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
    ax.axvline(tau_list_n[-1] - tau * autocorr_tau)
    ax.set_xlabel("Iteration")
    ax.set_ylabel(r"$\langle \tau \rangle$")
    ax.legend()
    fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
    plt.close(fig)
def compute_evidence(sampler, log_likelihood_array, outdir, label, nburn, thin,
                     iteration, make_plots=True):
    """ Computes the evidence using thermodynamic integration """
    betas = sampler.betas
    # We compute the evidence without the burn-in samples, but we do not thin
    lnlike = log_likelihood_array[:, :, nburn : iteration]
    mean_lnlikes = np.mean(np.mean(lnlike, axis=1), axis=1)
    mean_lnlikes = mean_lnlikes[::-1]
@@ -535,8 +864,7 @@ def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
        ax2.semilogx(min_betas, evidence, "-o")
        ax2.set_ylabel(
            r"$\int_{\beta_{min}}^{\beta=1}" + r"\langle \log(\mathcal{L})\rangle d\beta$",
            size=16,
        )
        ax2.set_xlabel(r"$\beta_{min}$")
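(Review note: the quantity computed here is the thermodynamic-integration evidence,
ln Z = \int_0^1 <ln L>_beta d(beta). A minimal standalone sketch of the core numerical
step, assuming a (ntemps, nwalkers, nsteps) log-likelihood array and the matching beta
ladder; the function and variable names are illustrative, not the bilby internals:)

    import numpy as np

    def thermodynamic_log_evidence(log_likelihood_array, betas):
        """Rough ln(Z): integrate the mean log-likelihood over inverse temperature."""
        # Average over walkers and steps at each temperature
        mean_lnlikes = log_likelihood_array.mean(axis=(1, 2))
        # Sort by beta so the trapezoid rule integrates from the smallest beta
        # (hottest chain) up to beta=1 (the target chain); no extrapolation to beta=0
        order = np.argsort(betas)
        return np.trapz(mean_lnlikes[order], betas[order])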
@@ -590,14 +918,14 @@ class LikePriorEvaluator(object):
    def __call__(self, x):
        lp = self.logp(x)
        if np.isnan(lp):
            raise ValueError("Prior function returned NaN.")
        if lp == float("-inf"):
            # Can't return -inf, since this messes with beta=0 behaviour.
            ll = 0
        else:
            ll = self.logl(x)
            if np.isnan(ll).any():
                raise ValueError("Log likelihood function returned NaN.")
        return ll, lp