Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
2 files changed: +333 −110
@@ -13,7 +13,7 @@ import pandas as pd
 import matplotlib.pyplot as plt
 from ..utils import logger
-from .base_sampler import MCMCSampler
+from .base_sampler import SamplerError, MCMCSampler

 class Ptemcee(MCMCSampler):
@@ -24,37 +24,124 @@ class Ptemcee(MCMCSampler):
     documentation for that class for further help. Under Other Parameters, we
     list commonly used kwargs and the bilby defaults.

+    Parameters
+    ----------
+    nsamples: int, (5000)
+        The requested number of samples. Note, in cases where the
+        autocorrelation parameter is difficult to measure, it is possible to
+        end up with more than nsamples.
+    burn_in_nact, thin_by_nact: int, (50, 1)
+        The number of burn-in autocorrelation times to discard and the thin-by
+        factor. Increasing burn_in_nact increases the time required for
+        burn-in. Increasing thin_by_nact increases the time required to
+        obtain nsamples.
+    autocorr_tol: int, (50)
+        The minimum number of autocorrelation times needed to trust the
+        estimate of the autocorrelation time.
+    autocorr_c: int, (5)
+        The step size for the window search used by
+        emcee.autocorr.integrated_time.
+    safety: int, (1)
+        A multiplicative factor for the estimated autocorrelation. Useful for
+        cases where non-convergence can be observed by eye but the automated
+        tools are failing.
+    autocorr_tau: int, (5)
+        The number of autocorrelation times to use in assessing if the
+        autocorrelation time is stable.
+    frac_threshold: float, (0.01)
+        The maximum fractional change in the autocorrelation for the last
+        autocorr_tau steps. If the fractional change exceeds this value,
+        sampling will continue until the estimate of the autocorrelation time
+        can be trusted.
+    min_tau: int, (1)
+        A minimum tau (autocorrelation time) to accept.
+    check_point_deltaT: float, (600)
+        The period with which to checkpoint (in seconds).
+    threads: int, (1)
+        If threads > 1, a MultiPool object is set up and used.
+    exit_code: int, (77)
+        The code on which the sampler exits.
+    store_walkers: bool (False)
+        If true, store the unthinned, unburnt chains in the result. Note, this
+        is not recommended for cases where tau is large.
+    ignore_keys_for_tau: str
+        A pattern used to ignore keys in estimating the autocorrelation time.
+    pos0: str, list ("prior")
+        If a string, one of "prior" or "minimize". For "prior", the initial
+        positions of the walkers are drawn from the prior. If "minimize",
+        a scipy.optimize step is applied to all parameters a number of times.
+        The walkers are then initialized from the range of values obtained.
+        If a list, for the keys in the list the optimization step is applied,
+        otherwise the initial points are drawn from the prior.
+
     Other Parameters
     ----------------
-    nwalkers: int, (100)
+    nwalkers: int, (200)
         The number of walkers
     nsteps: int, (100)
         The number of steps to take
-    nburn: int (50)
-        The fixed number of steps to discard as burn-in
     ntemps: int (2)
         The number of temperatures used by ptemcee
+    Tmax: float
+        The maximum temperature

     """

     # Arguments used by ptemcee
     default_kwargs = dict(
-        ntemps=20, nwalkers=200, Tmax=None, betas=None,
-        a=2.0, loglargs=[], logpargs=[], loglkwargs={},
-        logpkwargs={}, adaptation_lag=10000, adaptation_time=100, random=None,
-        iterations=1000, thin=1, storechain=True, adapt=False,
-        swap_ratios=False)
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, check_point_plot=True, skip_import_verification=False,
-                 resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
-                 autocorr_c=5, safety=1, frac_threshold=0.01,
-                 autocorr_tol=50, min_tau=1, check_point_deltaT=600,
-                 threads=1, exit_code=77, plot=False, store_walkers=False,
-                 **kwargs):
+        ntemps=20,
+        nwalkers=200,
+        Tmax=None,
+        betas=None,
+        a=2.0,
+        loglargs=[],
+        logpargs=[],
+        loglkwargs={},
+        logpkwargs={},
+        adaptation_lag=10000,
+        adaptation_time=100,
+        random=None,
+        adapt=False,
+        swap_ratios=False,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        check_point_plot=True,
+        skip_import_verification=False,
+        resume=True,
+        nsamples=5000,
+        burn_in_nact=50,
+        thin_by_nact=1,
+        autocorr_tol=50,
+        autocorr_c=5,
+        safety=1,
+        autocorr_tau=5,
+        frac_threshold=0.01,
+        min_tau=1,
+        check_point_deltaT=600,
+        threads=1,
+        exit_code=77,
+        plot=False,
+        store_walkers=False,
+        ignore_keys_for_tau=None,
+        pos0="prior",
+        **kwargs
+    ):
         super(Ptemcee, self).__init__(
-            likelihood=likelihood, priors=priors, outdir=outdir,
-            label=label, use_ratio=use_ratio, plot=plot,
-            skip_import_verification=skip_import_verification, **kwargs)
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs
+        )

         signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
         signal.signal(signal.SIGINT, self.write_current_state_and_exit)
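
For illustration (not part of the MR): the options documented in the new docstring are passed straight through bilby's run_sampler interface. A minimal sketch, assuming a hypothetical one-parameter linear model:

    import numpy as np
    import bilby

    # Toy data: a straight line with known Gaussian noise (illustrative only)
    x = np.linspace(0, 1, 100)
    y = 2.5 * x + np.random.normal(0, 0.1, len(x))

    def model(x, m):
        return m * x

    likelihood = bilby.core.likelihood.GaussianLikelihood(x, y, model, sigma=0.1)
    priors = dict(m=bilby.core.prior.Uniform(0, 5, "m"))

    result = bilby.run_sampler(
        likelihood=likelihood,
        priors=priors,
        sampler="ptemcee",
        nsamples=5000,    # requested number of independent posterior samples
        burn_in_nact=50,  # burn-in length, in units of the autocorrelation time
        thin_by_nact=1,   # thinning, in units of the autocorrelation time
        ntemps=20,
        nwalkers=200,
        pos0="prior",     # or "minimize", or a list of keys to minimize over
    )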
@@ -68,36 +155,104 @@ class Ptemcee(MCMCSampler):
         self.frac_threshold = frac_threshold
         self.nsamples = nsamples
         self.autocorr_tol = autocorr_tol
+        self.autocorr_tau = autocorr_tau
         self.min_tau = min_tau
         self.check_point_deltaT = check_point_deltaT
         self.threads = threads
         self.store_walkers = store_walkers
+        self.ignore_keys_for_tau = ignore_keys_for_tau
+        self.pos0 = pos0
         self.check_point_plot = check_point_plot
-        self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)
+        self.resume_file = "{}/{}_checkpoint_resume.pickle".format(
+            self.outdir, self.label
+        )
         self.exit_code = exit_code

     @property
     def sampler_function_kwargs(self):
-        keys = ['iterations', 'thin', 'storechain', 'adapt', 'swap_ratios']
+        keys = ["adapt", "swap_ratios"]
         return {key: self.kwargs[key] for key in keys}

     @property
     def sampler_init_kwargs(self):
-        return {key: value
-                for key, value in self.kwargs.items()
-                if key not in self.sampler_function_kwargs}
+        return {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+        }

     def get_pos0_from_prior(self):
         """ for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim """
         logger.info("Generating pos0 samples")
-        return [[self.get_random_draw_from_prior()
-                 for _ in range(self.sampler_init_kwargs["nwalkers"])]
-                for _ in range(self.kwargs['ntemps'])]
+        return [
+            [
+                self.get_random_draw_from_prior()
+                for _ in range(self.sampler_init_kwargs["nwalkers"])
+            ]
+            for _ in range(self.kwargs["ntemps"])
+        ]
+
+    def get_pos0_from_minimize(self, minimize_list=None):
+        logger.info("Attempting to set pos0 from minimize")
+        from scipy.optimize import minimize
+
+        if minimize_list is None:
+            minimize_list = self.search_parameter_keys
+            pos0 = np.zeros((self.kwargs["ntemps"], self.kwargs["nwalkers"], self.ndim))
+        else:
+            pos0 = np.array(self.get_pos0_from_prior())
+
+        likelihood_copy = copy.copy(self.likelihood)
+
+        def neg_log_like(params):
+            likelihood_copy.parameters.update(
+                {key: val for key, val in zip(minimize_list, params)}
+            )
+            try:
+                return -likelihood_copy.log_likelihood()
+            except RuntimeError:
+                return +np.inf
+
+        bounds = [
+            (self.priors[key].minimum, self.priors[key].maximum)
+            for key in minimize_list
+        ]
+        trials = 0
+        success = []
+        while True:
+            draw = self.priors.sample()
+            likelihood_copy.parameters.update(draw)
+            x0 = [draw[key] for key in minimize_list]
+            res = minimize(
+                neg_log_like, x0, bounds=bounds, method="L-BFGS-B", tol=1e-15
+            )
+            trials += 1  # count attempts so the failure guard below can trigger
+            if res.success:
+                success.append(res.x)
+            if trials > 100:
+                raise SamplerError("Unable to set pos0 from minimize")
+            if len(success) >= 10:
+                break
+
+        success = np.array(success)
+        for i, key in enumerate(minimize_list):
+            pos0_min = np.min(success[:, i])
+            pos0_max = np.max(success[:, i])
+            logger.info(
+                "Initialize {} walkers from {}->{}".format(key, pos0_min, pos0_max)
+            )
+            j = self.search_parameter_keys.index(key)
+            pos0[:, :, j] = np.random.uniform(
+                pos0_min,
+                pos0_max,
+                size=(self.kwargs["ntemps"], self.kwargs["nwalkers"]),
+            )
+        return pos0

     def setup_sampler(self):
         import ptemcee

         if os.path.isfile(self.resume_file) and self.resume is True:
             logger.info("Resume data {} found".format(self.resume_file))
             with open(self.resume_file, "rb") as file:
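
For illustration (not part of the diff): get_pos0_from_minimize repeatedly minimizes the negative log-likelihood starting from random prior draws, then scatters the walkers uniformly over the range spanned by the successful optima. The same idea as a self-contained sketch with a hypothetical one-parameter likelihood:

    import numpy as np
    from scipy.optimize import minimize

    ntemps, nwalkers = 4, 16
    bounds = [(-10, 10)]  # prior bounds for the single parameter

    def toy_log_likelihood(params):
        return -0.5 * (params[0] - 3.2) ** 2  # peaked at 3.2

    successes = []
    for _ in range(20):  # restart the optimizer from random draws
        x0 = np.random.uniform(bounds[0][0], bounds[0][1], size=1)
        res = minimize(
            lambda p: -toy_log_likelihood(p), x0, bounds=bounds, method="L-BFGS-B"
        )
        if res.success:
            successes.append(res.x)

    successes = np.array(successes)
    # Initialize every walker (at every temperature) uniformly within the
    # range spanned by the optima found above
    low, high = successes[:, 0].min(), successes[:, 0].max()
    pos0 = np.random.uniform(low, high, size=(ntemps, nwalkers, 1))
    print(pos0.shape)  # (ntemps, nwalkers, ndim)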
@@ -113,17 +268,32 @@ class Ptemcee(MCMCSampler):
             pos0 = None
-            logger.info("Resuming from previous run with time={}".format(self.sampler.time))
+            logger.info(
+                "Resuming from previous run with time={}".format(self.sampler.time)
+            )
         else:
             # Initialize the PTSampler
-            self.sampler = ptemcee.Sampler(
-                dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
-                pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
+            if self.threads == 1:
+                self.sampler = ptemcee.Sampler(
+                    dim=self.ndim,
+                    logl=self.log_likelihood,
+                    logp=self.log_prior,
+                    **self.sampler_init_kwargs
+                )
+            else:
+                self.sampler = ptemcee.Sampler(
+                    dim=self.ndim,
+                    logl=do_nothing_function,
+                    logp=do_nothing_function,
+                    pool=self.pool,
+                    threads=self.threads,
+                    **self.sampler_init_kwargs
+                )
+            # Overwrite the _likeprior to improve performance with threads > 1
             self.sampler._likeprior = LikePriorEvaluator(
-                self.search_parameter_keys, use_ratio=self.use_ratio)
+                self.search_parameter_keys, use_ratio=self.use_ratio
+            )

             # Set up empty lists
             self.tau_list = []
@@ -131,18 +301,28 @@ class Ptemcee(MCMCSampler):
             self.time_per_check = []

             # Initialize the walker positions
-            pos0 = self.get_pos0_from_prior()
+            pos0 = self.get_pos0()

         return self.sampler, pos0

+    def get_pos0(self):
+        if isinstance(self.pos0, str) and self.pos0.lower() == "prior":
+            return self.get_pos0_from_prior()
+        elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize":
+            return self.get_pos0_from_minimize()
+        elif isinstance(self.pos0, list):
+            return self.get_pos0_from_minimize(minimize_list=self.pos0)
+        else:
+            raise SamplerError("pos0={} not implemented".format(self.pos0))
+
     def setup_pool(self):
         if self.threads > 1:
             import schwimmbad

             logger.info("Creating MultiPool with {} processes".format(self.threads))
             self.pool = schwimmbad.MultiPool(
-                self.threads,
-                initializer=init,
-                initargs=(self.likelihood, self.priors))
+                self.threads, initializer=init, initargs=(self.likelihood, self.priors)
+            )
         else:
             self.pool = None
@@ -155,12 +335,14 @@ class Ptemcee(MCMCSampler):
     def run_sampler_internal(self):
         import emcee

         sampler, pos0 = self.setup_sampler()
         t0 = datetime.datetime.now()
         logger.info("Starting to sample")
-        for (pos0, lnprob, lnlike) in sampler.sample(
-                pos0, **self.sampler_function_kwargs):
+        while True:
+            for (pos0, _, _) in sampler.sample(pos0, **self.sampler_function_kwargs):
+                pass

             # Calculate time per iteration
             self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
@@ -171,7 +353,7 @@ class Ptemcee(MCMCSampler):
             taus = []
             for ii in range(sampler.nwalkers):
                 for jj, key in enumerate(self.search_parameter_keys):
-                    if "recalib" in key:
+                    if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key:
                         continue
                     try:
                         taus.append(
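
For illustration (not part of the diff): the taus collected in this loop come from emcee's autocorrelation tools (the append call is cut off by the diff view). A sketch of the estimator on a synthetic correlated series; the AR(1) chain and the exact call are illustrative:

    import numpy as np
    import emcee

    # AR(1) chain with known autocorrelation time (1 + 0.9) / (1 - 0.9) = 19
    rng = np.random.default_rng(42)
    chain = np.zeros(5000)
    for t in range(1, len(chain)):
        chain[t] = 0.9 * chain[t - 1] + rng.normal()

    # c is the window step size (the autocorr_c kwarg); tol sets how many
    # autocorrelation times are demanded before the estimate is trusted
    # (the autocorr_tol kwarg). The safety kwarg would multiply the result.
    tau = emcee.autocorr.integrated_time(chain, c=5, tol=50)[0]
    print("estimated tau ~ {:.1f}".format(tau))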
@@ -190,9 +372,9 @@ class Ptemcee(MCMCSampler):
             self.tau_list_n.append(sampler.time)

             # Convert to an integer
-            tau = int(np.floor(tau)) if not np.isnan(tau) else tau
+            tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau

-            if np.isnan(tau) or np.isinf(tau):
+            if np.isnan(tau_int) or np.isinf(tau_int):
                 print_progress(
                     self.sampler,
                     self.time_per_check,
@@ -200,27 +382,31 @@ class Ptemcee(MCMCSampler):
                     np.nan,
                     np.nan,
                     np.nan,
-                    False)
+                    [np.nan],
+                    False,
+                )
                 continue

             # Calculate the effective number of samples available
-            self.nburn = int(self.burn_in_nact * tau)
-            self.thin = int(np.max([1, self.thin_by_nact * tau]))
+            self.nburn = int(self.burn_in_nact * tau_int)
+            self.thin = int(np.max([1, self.thin_by_nact * tau_int]))
             samples_per_check = sampler.nwalkers / self.thin
-            self.nsamples_effective = int(sampler.nwalkers * (sampler.time - self.nburn) / self.thin)
+            self.nsamples_effective = int(
+                sampler.nwalkers * (sampler.time - self.nburn) / self.thin
+            )

             # Calculate convergence boolean
             converged = self.nsamples < self.nsamples_effective

-            # Calculate fractional change in tau from previous iteration
-            check_taus = np.array(self.tau_list[-tau * self.autocorr_tol:])
+            # Calculate fractional change in tau from previous iterations
+            check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau :])
             if not np.any(np.isnan(check_taus)):
                 frac = (tau - check_taus) / tau
                 tau_usable = np.all(frac < self.frac_threshold)
             else:
                 tau_usable = False

-            if sampler.time < tau * self.autocorr_tol or tau < self.min_tau:
+            if sampler.time < tau_int * self.autocorr_tol or tau_int < self.min_tau:
                 tau_usable = False

             # Print an update on the progress
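
For illustration (not part of the diff): the convergence bookkeeping above reduces to a few lines of arithmetic. A worked sketch with made-up numbers:

    import numpy as np

    nwalkers, time_steps, tau_int = 200, 4000, 20
    burn_in_nact, thin_by_nact, nsamples = 50, 1, 5000

    nburn = int(burn_in_nact * tau_int)                               # 1000 steps
    thin = int(np.max([1, thin_by_nact * tau_int]))                   # keep every 20th
    nsamples_effective = int(nwalkers * (time_steps - nburn) / thin)  # 30000
    converged = nsamples < nsamples_effective                         # True

    # tau is only usable once its recent history is stable and enough
    # autocorrelation times have elapsed
    autocorr_tol, autocorr_tau, frac_threshold, min_tau = 50, 5, 0.01, 1
    tau = 20.3
    recent_taus = np.array([20.1, 20.2, 20.25, 20.3])  # last few estimates
    frac = (tau - recent_taus) / tau
    tau_usable = (
        np.all(frac < frac_threshold)
        and time_steps >= tau_int * autocorr_tol
        and tau_int >= min_tau
    )
    print(converged, tau_usable)  # True True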
@@ -230,7 +416,8 @@ class Ptemcee(MCMCSampler):
                 self.nsamples,
                 self.nsamples_effective,
                 samples_per_check,
-                tau,
+                tau_int,
+                check_taus,
                 tau_usable,
             )
@@ -247,23 +434,16 @@ class Ptemcee(MCMCSampler):
             if last_checkpoint_s > self.check_point_deltaT:
                 self.write_current_state(plot=self.check_point_plot)

-            # Check if we reached the end without converging
-            if sampler.time == self.sampler_function_kwargs["iterations"]:
-                raise ValueError(
-                    "Failed to reach convergence by iterations={}".format(
-                        self.sampler_function_kwargs["iterations"]
-                    )
-                )
-
         # Run a final checkpoint to update the plots and samples
         self.write_current_state(plot=self.check_point_plot)

         # Get the zero-temperature samples and store them in the result
         samples = sampler.chain[0, :, :, :]  # nwalkers, nsteps, ndim
-        self.result.samples = (
-            samples[:, self.nburn: sampler.time:self.thin, :].reshape((-1, self.ndim)))
+        self.result.samples = samples[
+            :, self.nburn : sampler.time : self.thin, :
+        ].reshape((-1, self.ndim))
         loglikelihood = sampler.loglikelihood[
-            0, :, self.nburn:sampler.time:self.thin
+            0, :, self.nburn : sampler.time : self.thin
         ]  # nwalkers, nsteps
         self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
@@ -277,34 +457,54 @@ class Ptemcee(MCMCSampler):
         self.result.log_evidence = log_evidence
         self.result.log_evidence_err = log_evidence_err
-        self.result.sampling_time = datetime.timedelta(seconds=np.sum(self.time_per_check))
+        self.result.sampling_time = datetime.timedelta(
+            seconds=np.sum(self.time_per_check)
+        )

         return self.result

     def write_current_state_and_exit(self, signum=None, frame=None):
         logger.warning("Run terminated with signal {}".format(signum))
-        if getattr(self, 'pool', None):
+        if getattr(self, "pool", None):
             self.write_current_state(plot=False)
             logger.warning("Closing pool")
             self.pool.close()
         sys.exit(self.exit_code)

     def write_current_state(self, plot=True):
-        checkpoint(self.outdir, self.label, self.nsamples_effective,
-                   self.sampler, self.nburn, self.thin,
-                   self.search_parameter_keys, self.resume_file, self.tau_list,
-                   self.tau_list_n, self.time_per_check)
+        checkpoint(
+            self.outdir,
+            self.label,
+            self.nsamples_effective,
+            self.sampler,
+            self.nburn,
+            self.thin,
+            self.search_parameter_keys,
+            self.resume_file,
+            self.tau_list,
+            self.tau_list_n,
+            self.time_per_check,
+        )

-        if plot:
+        if plot and not np.isnan(self.nburn):
             # Generate the walkers plot diagnostic
             plot_walkers(
                 self.sampler.chain[0, :, : self.sampler.time, :],
-                self.nburn, self.search_parameter_keys, self.outdir, self.label
+                self.nburn,
+                self.thin,
+                self.search_parameter_keys,
+                self.outdir,
+                self.label,
             )

             # Generate the tau plot diagnostic
-            plot_tau(self.tau_list_n, self.tau_list, self.outdir, self.label,
-                     self.autocorr_tol)
+            plot_tau(
+                self.tau_list_n,
+                self.tau_list,
+                self.outdir,
+                self.label,
+                self.autocorr_tau,
+            )
@@ -313,7 +513,8 @@ def print_progress(
     nsamples,
     nsamples_effective,
     samples_per_check,
-    tau,
+    tau_int,
+    tau_list,
     tau_usable,
 ):
     # Setup acceptance string
@@ -327,9 +528,7 @@ def print_progress(
     )
     ave_time_per_check = np.mean(time_per_check[-3:])
-    time_left = (
-        (nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
-    )
+    time_left = (nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
     if time_left > 0:
         time_left = str(datetime.timedelta(seconds=int(time_left)))
     else:
@@ -337,17 +536,17 @@ def print_progress(
     sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))

-    tau_str = str(tau)
-    if tau_usable is False:
-        tau_str = "!{}".format(tau_str)
-    else:
-        tau_str = "={}".format(tau_str)
+    tau_str = "{}:{:0.1f}->{:0.1f}".format(tau_int, np.min(tau_list), np.max(tau_list))
+    if tau_usable:
+        tau_str = "={}".format(tau_str)
+    else:
+        tau_str = "!{}".format(tau_str)

     evals_per_check = sampler.nwalkers * sampler.ntemps
     ncalls = "{:1.1e}".format(sampler.time * sampler.nwalkers * sampler.ntemps)
-    eval_timing = "{:1.1f}ms/evl".format(1e3 * ave_time_per_check / evals_per_check)
-    samp_timing = "{:1.2f}ms/smp".format(1e3 * ave_time_per_check / samples_per_check)
+    eval_timing = "{:1.1f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
+    samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)

     print(
         "{}| {} | nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
@@ -366,16 +565,26 @@ def print_progress(
     )


-def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
-               search_parameter_keys, resume_file, tau_list, tau_list_n,
-               time_per_check):
+def checkpoint(
+    outdir,
+    label,
+    nsamples_effective,
+    sampler,
+    nburn,
+    thin,
+    search_parameter_keys,
+    resume_file,
+    tau_list,
+    tau_list_n,
+    time_per_check,
+):
     logger.info("Writing checkpoint and diagnostics")
     ndim = sampler.dim

     # Store the samples if possible
     if nsamples_effective > 0:
         filename = "{}/{}_samples.txt".format(outdir, label)
-        samples = sampler.chain[0, :, nburn:sampler.time:thin, :].reshape(
+        samples = sampler.chain[0, :, nburn : sampler.time : thin, :].reshape(
             (-1, ndim)
         )
         df = pd.DataFrame(samples, columns=search_parameter_keys)
@@ -390,48 +599,61 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
     sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
     data = dict(
-        sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n,
-        time_per_check=time_per_check)
+        sampler=sampler_copy,
+        tau_list=tau_list,
+        tau_list_n=tau_list_n,
+        time_per_check=time_per_check,
+    )

     with open(resume_file, "wb") as file:
         dill.dump(data, file, protocol=4)
     del data, sampler_copy
     logger.info("Finished writing checkpoint")


-def plot_walkers(walkers, nburn, parameter_labels, outdir, label):
+def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
     """ Method to plot the trace of the walkers in an ensemble MCMC plot """
     nwalkers, nsteps, ndim = walkers.shape
     idxs = np.arange(nsteps)
-    fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3 * ndim))
-    scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05)
-    for i, ax in enumerate(axes):
+    fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
+    scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05,)
+
+    # Plot the burn-in
+    for i, (ax, axh) in enumerate(axes):
         ax.plot(
-            idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="r", **scatter_kwargs
+            idxs[: nburn + 1],
+            walkers[:, : nburn + 1, i].T,
+            color="C1",
+            **scatter_kwargs
         )
-        ax.set_ylabel(parameter_labels[i])
-    for i, ax in enumerate(axes):
-        ax.plot(idxs[nburn:], walkers[:, nburn:, i].T, color="k", **scatter_kwargs)
+
+    # Plot the thinned posterior samples
+    for i, (ax, axh) in enumerate(axes):
+        ax.plot(
+            idxs[nburn::thin],
+            walkers[:, nburn::thin, i].T,
+            color="C0",
+            **scatter_kwargs
+        )
+        axh.hist(walkers[:, nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
+        axh.set_xlabel(parameter_labels[i])
+        ax.set_ylabel(parameter_labels[i])

     fig.tight_layout()
-    filename = "{}/{}_traceplot.png".format(outdir, label)
+    filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
     fig.savefig(filename)
     plt.close(fig)


-def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tol):
+def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
     fig, ax = plt.subplots()
-    ax.plot(tau_list_n, tau_list, "-")
-    check_tau_idx = -int(tau_list[-1] * autocorr_tol)
+    ax.plot(tau_list_n, tau_list, "-", color="C1")
+    check_tau_idx = -int(tau_list[-1] * autocorr_tau)
     check_taus = tau_list[check_tau_idx:]
     check_taus_n = tau_list_n[check_tau_idx:]
-    ax.plot(check_taus_n, check_taus, "--")
+    ax.plot(check_taus_n, check_taus, "-", color="C0")
     ax.set_xlabel("Iteration")
     ax.set_ylabel(r"$\langle \tau \rangle$")
-    fig.savefig("{}/{}_tau.png".format(outdir, label))
+    fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
     plt.close(fig)
@@ -439,7 +661,7 @@ def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
     """ Computes the evidence using thermodynamic integration """
     betas = sampler.betas
     # We compute the evidence without the burn-in samples, but we do not thin
-    lnlike = sampler.loglikelihood[:, :, nburn:sampler.time]
+    lnlike = sampler.loglikelihood[:, :, nburn : sampler.time]
     mean_lnlikes = np.mean(np.mean(lnlike, axis=1), axis=1)
     mean_lnlikes = mean_lnlikes[::-1]
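
For illustration (not part of the diff): compute_evidence performs thermodynamic integration, ln Z = integral from beta=0 to 1 of <ln L>_beta d(beta), where <ln L>_beta is the mean log-likelihood of the chain at inverse temperature beta. A self-contained sketch; the beta ladder and the mean log-likelihood curve are synthetic stand-ins:

    import numpy as np

    # Synthetic inverse-temperature ladder running from 1 towards 0
    betas = np.logspace(0, -4, 20)
    mean_lnlikes = -50.0 / np.sqrt(betas)  # hotter chains see lower <ln L>

    # Trapezoid-rule integral over beta, sorted into increasing order
    order = np.argsort(betas)
    log_evidence = np.trapz(mean_lnlikes[order], betas[order])
    print("ln-evidence estimate: {:.1f}".format(log_evidence))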
@@ -527,14 +749,14 @@ class LikePriorEvaluator(object):
     def __call__(self, x):
         lp = self.logp(x)
         if np.isnan(lp):
-            raise ValueError('Prior function returned NaN.')
+            raise ValueError("Prior function returned NaN.")

-        if lp == float('-inf'):
+        if lp == float("-inf"):
             # Can't return -inf, since this messes with beta=0 behaviour.
             ll = 0
         else:
             ll = self.logl(x)
             if np.isnan(ll).any():
-                raise ValueError('Log likelihood function returned NaN.')
+                raise ValueError("Log likelihood function returned NaN.")

         return ll, lp