Skip to content
Snippets Groups Projects

Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
Compare and Show latest version
2 files
+ 237
77
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 235
76
@@ -24,36 +24,124 @@ class Ptemcee(MCMCSampler):
documentation for that class for further help. Under Other Parameters, we
list commonly used kwargs and the bilby defaults.
Parameters
----------
nsamples: int, (5000)
The requested number of samples. Note, in cases where the
autocorrelation parameter is difficult to measure, it is possible to
end up with more than nsamples.
burn_in_act, thin_by_nact: int, (50, 1)
The number of burn-in autocorrelation times to discard and the thin-by
factor. Increasing burn_in_act increases the time required for burn-in.
Increasing thin_by_nact increases the time required to obtain nsamples.
autocorr_tol: int, (50)
The minimum number of autocorrelation times needed to trust the
estimate of the autocorrelation time.
autocorr_c: int, (5)
The step size for the window search used by emcee.autocorr.integrated_time
safety: int, (1)
A multiplicitive factor for the estimated autocorrelation. Useful for
cases where non-convergence can be observed by eye but the automated
tools are failing.
autocorr_tau:
The number of autocorrelation times to use in assessing if the
autocorrelation time is stable.
frac_threshold: float, (0.01)
The maximum fractional change in the autocorrelation for the last
autocorr_tau steps. If the fractional change exceeds this value,
sampling will continue until the estimate of the autocorrelation time
can be trusted.
min_tau: int, (1)
A minimum tau (autocorrelation time) to accept.
check_point_deltaT: float, (600)
The period with which to checkpoint (in seconds).
threads: int, (1)
If threads > 1, a MultiPool object is setup and used.
exit_code: int, (77)
The code on which the sampler exits.
store_walkers: bool (False)
If true, store the unthinned, unburnt chaines in the result. Note, this
is not recommended for cases where tau is large.
ignore_keys_for_tau: str
A pattern used to ignore keys in estimating the autocorrelation time.
pos0: str, list ("prior")
If a string, one of "prior" or "minimize". For "prior", the initial
positions of the sampler are drawn from the sampler. If "minimize",
a scipy.optimize step is applied to all parameters a number of times.
The walkers are then initialized from the range of values obtained.
If a list, for the keys in the list the optimization step is applied,
otherwise the initial points are drawn from the prior.
Other Parameters
----------------
nwalkers: int, (100)
nwalkers: int, (200)
The number of walkers
nsteps: int, (100)
The number of steps to take
nburn: int (50)
The fixed number of steps to discard as burn-in
ntemps: int (2)
The number of temperatures used by ptemcee
Tmax: float
The maximum temperature
"""
# Arguments used by ptemcee
default_kwargs = dict(
ntemps=20, nwalkers=200, Tmax=None, betas=None,
a=2.0, loglargs=[], logpargs=[], loglkwargs={},
logpkwargs={}, adaptation_lag=10000, adaptation_time=100, random=None,
adapt=False, swap_ratios=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, check_point_plot=True, skip_import_verification=False,
resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
autocorr_c=5, safety=1, frac_threshold=0.01,
autocorr_tol=50, autocorr_tau=5, min_tau=1, check_point_deltaT=600,
threads=1, exit_code=77, plot=False, store_walkers=False,
ignore_keys_for_tau="recalib", pos0="prior", **kwargs):
ntemps=20,
nwalkers=200,
Tmax=None,
betas=None,
a=2.0,
loglargs=[],
logpargs=[],
loglkwargs={},
logpkwargs={},
adaptation_lag=10000,
adaptation_time=100,
random=None,
adapt=False,
swap_ratios=False,
)
def __init__(
self,
likelihood,
priors,
outdir="outdir",
label="label",
use_ratio=False,
check_point_plot=True,
skip_import_verification=False,
resume=True,
nsamples=5000,
burn_in_nact=50,
thin_by_nact=1,
autocorr_tol=50,
autocorr_c=5,
safety=1,
autocorr_tau=5,
frac_threshold=0.01,
min_tau=1,
check_point_deltaT=600,
threads=1,
exit_code=77,
plot=False,
store_walkers=False,
ignore_keys_for_tau="recalib",
pos0="prior",
**kwargs
):
super(Ptemcee, self).__init__(
likelihood=likelihood, priors=priors, outdir=outdir,
label=label, use_ratio=use_ratio, plot=plot,
skip_import_verification=skip_import_verification, **kwargs)
likelihood=likelihood,
priors=priors,
outdir=outdir,
label=label,
use_ratio=use_ratio,
plot=plot,
skip_import_verification=skip_import_verification,
**kwargs
)
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
@@ -77,26 +165,34 @@ class Ptemcee(MCMCSampler):
self.pos0 = pos0
self.check_point_plot = check_point_plot
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(
self.outdir, self.label
)
self.exit_code = exit_code
@property
def sampler_function_kwargs(self):
keys = ['adapt', 'swap_ratios']
keys = ["adapt", "swap_ratios"]
return {key: self.kwargs[key] for key in keys}
@property
def sampler_init_kwargs(self):
return {key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs}
return {
key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs
}
def get_pos0_from_prior(self):
""" for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim """
logger.info("Generating pos0 samples")
return [[self.get_random_draw_from_prior()
for _ in range(self.sampler_init_kwargs["nwalkers"])]
for _ in range(self.kwargs['ntemps'])]
return [
[
self.get_random_draw_from_prior()
for _ in range(self.sampler_init_kwargs["nwalkers"])
]
for _ in range(self.kwargs["ntemps"])
]
def get_pos0_from_minimize(self, minimize_list=None):
logger.info("Attempting to set pos0 from minimize")
@@ -112,13 +208,17 @@ class Ptemcee(MCMCSampler):
def neg_log_like(params):
likelihood_copy.parameters.update(
{key: val for key, val in zip(minimize_list, params)})
{key: val for key, val in zip(minimize_list, params)}
)
try:
return -likelihood_copy.log_likelihood()
except RuntimeError:
return +np.inf
bounds = [(self.priors[key].minimum, self.priors[key].maximum)
for key in minimize_list]
bounds = [
(self.priors[key].minimum, self.priors[key].maximum)
for key in minimize_list
]
trials = 0
success = []
while True:
@@ -126,7 +226,8 @@ class Ptemcee(MCMCSampler):
likelihood_copy.parameters.update(draw)
x0 = [draw[key] for key in minimize_list]
res = minimize(
neg_log_like, x0, bounds=bounds, method='L-BFGS-B', tol=1e-15)
neg_log_like, x0, bounds=bounds, method="L-BFGS-B", tol=1e-15
)
if res.success:
success.append(res.x)
if trials > 100:
@@ -138,16 +239,20 @@ class Ptemcee(MCMCSampler):
for i, key in enumerate(minimize_list):
pos0_min = np.min(success[:, i])
pos0_max = np.max(success[:, i])
logger.info("Initialize {} walkers from {}->{}"
.format(key, pos0_min, pos0_max))
logger.info(
"Initialize {} walkers from {}->{}".format(key, pos0_min, pos0_max)
)
j = self.search_parameter_keys.index(key)
pos0[:, :, j] = np.random.uniform(
pos0_min, pos0_max,
size=(self.kwargs["ntemps"], self.kwargs["nwalkers"]))
pos0_min,
pos0_max,
size=(self.kwargs["ntemps"], self.kwargs["nwalkers"]),
)
return pos0
def setup_sampler(self):
import ptemcee
if os.path.isfile(self.resume_file) and self.resume is True:
logger.info("Resume data {} found".format(self.resume_file))
with open(self.resume_file, "rb") as file:
@@ -163,21 +268,32 @@ class Ptemcee(MCMCSampler):
pos0 = None
logger.info("Resuming from previous run with time={}".format(self.sampler.time))
logger.info(
"Resuming from previous run with time={}".format(self.sampler.time)
)
else:
# Initialize the PTSampler
if self.threads == 1:
self.sampler = ptemcee.Sampler(
dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
**self.sampler_init_kwargs)
dim=self.ndim,
logl=self.log_likelihood,
logp=self.log_prior,
**self.sampler_init_kwargs
)
else:
self.sampler = ptemcee.Sampler(
dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
dim=self.ndim,
logl=do_nothing_function,
logp=do_nothing_function,
pool=self.pool,
threads=self.threads,
**self.sampler_init_kwargs
)
self.sampler._likeprior = LikePriorEvaluator(
self.search_parameter_keys, use_ratio=self.use_ratio)
self.search_parameter_keys, use_ratio=self.use_ratio
)
# Set up empty lists
self.tau_list = []
@@ -202,11 +318,11 @@ class Ptemcee(MCMCSampler):
def setup_pool(self):
if self.threads > 1:
import schwimmbad
logger.info("Creating MultiPool with {} processes".format(self.threads))
self.pool = schwimmbad.MultiPool(
self.threads,
initializer=init,
initargs=(self.likelihood, self.priors))
self.threads, initializer=init, initargs=(self.likelihood, self.priors)
)
else:
self.pool = None
@@ -219,6 +335,7 @@ class Ptemcee(MCMCSampler):
def run_sampler_internal(self):
import emcee
sampler, pos0 = self.setup_sampler()
t0 = datetime.datetime.now()
@@ -266,20 +383,23 @@ class Ptemcee(MCMCSampler):
np.nan,
np.nan,
[np.nan],
False)
False,
)
continue
# Calculate the effective number of samples available
self.nburn = int(self.burn_in_nact * tau_int)
self.thin = int(np.max([1, self.thin_by_nact * tau_int]))
samples_per_check = sampler.nwalkers / self.thin
self.nsamples_effective = int(sampler.nwalkers * (sampler.time - self.nburn) / self.thin)
self.nsamples_effective = int(
sampler.nwalkers * (sampler.time - self.nburn) / self.thin
)
# Calculate convergence boolean
converged = self.nsamples < self.nsamples_effective
# Calculate fractional change in tau from previous iterations
check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau:])
check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau :])
if not np.any(np.isnan(check_taus)):
frac = (tau - check_taus) / tau
tau_usable = np.all(frac < self.frac_threshold)
@@ -319,10 +439,11 @@ class Ptemcee(MCMCSampler):
# Get 0-likelihood samples and store in the result
samples = sampler.chain[0, :, :, :] # nwalkers, nsteps, ndim
self.result.samples = (
samples[:, self.nburn: sampler.time:self.thin, :].reshape((-1, self.ndim)))
self.result.samples = samples[
:, self.nburn : sampler.time : self.thin, :
].reshape((-1, self.ndim))
loglikelihood = sampler.loglikelihood[
0, :, self.nburn:sampler.time:self.thin
0, :, self.nburn : sampler.time : self.thin
] # nwalkers, nsteps
self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
@@ -336,35 +457,54 @@ class Ptemcee(MCMCSampler):
self.result.log_evidence = log_evidence
self.result.log_evidence_err = log_evidence_err
self.result.sampling_time = datetime.timedelta(seconds=np.sum(self.time_per_check))
self.result.sampling_time = datetime.timedelta(
seconds=np.sum(self.time_per_check)
)
return self.result
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
if getattr(self, 'pool', None):
if getattr(self, "pool", None):
self.write_current_state(plot=False)
logger.warning("Closing pool")
self.pool.close()
sys.exit(self.exit_code)
def write_current_state(self, plot=True):
checkpoint(self.outdir, self.label, self.nsamples_effective,
self.sampler, self.nburn, self.thin,
self.search_parameter_keys, self.resume_file, self.tau_list,
self.tau_list_n, self.time_per_check)
checkpoint(
self.outdir,
self.label,
self.nsamples_effective,
self.sampler,
self.nburn,
self.thin,
self.search_parameter_keys,
self.resume_file,
self.tau_list,
self.tau_list_n,
self.time_per_check,
)
if plot and not np.isnan(self.nburn):
# Generate the walkers plot diagnostic
plot_walkers(
self.sampler.chain[0, :, : self.sampler.time, :],
self.nburn, self.thin, self.search_parameter_keys, self.outdir,
self.label
self.nburn,
self.thin,
self.search_parameter_keys,
self.outdir,
self.label,
)
# Generate the tau plot diagnostic
plot_tau(self.tau_list_n, self.tau_list, self.outdir, self.label,
self.autocorr_tau)
plot_tau(
self.tau_list_n,
self.tau_list,
self.outdir,
self.label,
self.autocorr_tau,
)
def print_progress(
@@ -388,9 +528,7 @@ def print_progress(
)
ave_time_per_check = np.mean(time_per_check[-3:])
time_left = (
(nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
)
time_left = (nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
if time_left > 0:
time_left = str(datetime.timedelta(seconds=int(time_left)))
else:
@@ -427,16 +565,26 @@ def print_progress(
)
def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
search_parameter_keys, resume_file, tau_list, tau_list_n,
time_per_check):
def checkpoint(
outdir,
label,
nsamples_effective,
sampler,
nburn,
thin,
search_parameter_keys,
resume_file,
tau_list,
tau_list_n,
time_per_check,
):
logger.info("Writing checkpoint and diagnostics")
ndim = sampler.dim
# Store the samples if possible
if nsamples_effective > 0:
filename = "{}/{}_samples.txt".format(outdir, label)
samples = sampler.chain[0, :, nburn:sampler.time:thin, :].reshape(
samples = sampler.chain[0, :, nburn : sampler.time : thin, :].reshape(
(-1, ndim)
)
df = pd.DataFrame(samples, columns=search_parameter_keys)
@@ -451,8 +599,11 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
data = dict(
sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n,
time_per_check=time_per_check)
sampler=sampler_copy,
tau_list=tau_list,
tau_list_n=tau_list_n,
time_per_check=time_per_check,
)
with open(resume_file, "wb") as file:
dill.dump(data, file, protocol=4)
@@ -469,12 +620,20 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
# Plot the burn-in
for i, (ax, axh) in enumerate(axes):
ax.plot(
idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="C1", **scatter_kwargs
idxs[: nburn + 1],
walkers[:, : nburn + 1, i].T,
color="C1",
**scatter_kwargs
)
# Plot the thinned posterior samples
for i, (ax, axh) in enumerate(axes):
ax.plot(idxs[nburn::thin], walkers[:, nburn::thin, i].T, color="C0", **scatter_kwargs)
ax.plot(
idxs[nburn::thin],
walkers[:, nburn::thin, i].T,
color="C0",
**scatter_kwargs
)
axh.hist(walkers[:, nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
axh.set_xlabel(parameter_labels[i])
ax.set_ylabel(parameter_labels[i])
@@ -487,11 +646,11 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
fig, ax = plt.subplots()
ax.plot(tau_list_n, tau_list, "-", color='C1')
ax.plot(tau_list_n, tau_list, "-", color="C1")
check_tau_idx = -int(tau_list[-1] * autocorr_tau)
check_taus = tau_list[check_tau_idx:]
check_taus_n = tau_list_n[check_tau_idx:]
ax.plot(check_taus_n, check_taus, "-", color='C0')
ax.plot(check_taus_n, check_taus, "-", color="C0")
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
@@ -502,7 +661,7 @@ def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
""" Computes the evidence using thermodynamic integration """
betas = sampler.betas
# We compute the evidence without the burnin samples, but we do not thin
lnlike = sampler.loglikelihood[:, :, nburn:sampler.time]
lnlike = sampler.loglikelihood[:, :, nburn : sampler.time]
mean_lnlikes = np.mean(np.mean(lnlike, axis=1), axis=1)
mean_lnlikes = mean_lnlikes[::-1]
@@ -590,14 +749,14 @@ class LikePriorEvaluator(object):
def __call__(self, x):
lp = self.logp(x)
if np.isnan(lp):
raise ValueError('Prior function returned NaN.')
raise ValueError("Prior function returned NaN.")
if lp == float('-inf'):
if lp == float("-inf"):
# Can't return -inf, since this messes with beta=0 behaviour.
ll = 0
else:
ll = self.logl(x)
if np.isnan(ll).any():
raise ValueError('Log likelihood function returned NaN.')
raise ValueError("Log likelihood function returned NaN.")
return ll, lp
Loading