
Improve ptemcee

Merged: Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
8 files changed: +117 −32
@@ -13,7 +13,7 @@ import pandas as pd
 import matplotlib.pyplot as plt
 from ..utils import logger
-from .base_sampler import SamplerError, MCMCSampler
+from .base_sampler import MCMCSampler


 class Ptemcee(MCMCSampler):
@@ -39,7 +39,7 @@ class Ptemcee(MCMCSampler):
     # Arguments used by ptemcee
     default_kwargs = dict(
         ntemps=20, nwalkers=200, Tmax=None, betas=None,
-        threads=1, pool=None, a=2.0, loglargs=[], logpargs=[], loglkwargs={},
+        a=2.0, loglargs=[], logpargs=[], loglkwargs={},
         logpkwargs={}, adaptation_lag=10000, adaptation_time=100, random=None,
         iterations=1000, thin=1, storechain=True, adapt=False,
         swap_ratios=False)
@@ -48,7 +48,8 @@ class Ptemcee(MCMCSampler):
use_ratio=False, plot=False, skip_import_verification=False,
resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
autocorr_c=5, safety=1, ncheck=50, nfrac=5, frac_threshold=0.01,
autocorr_tol=50, min_tau=1, check_point_deltaT=600, **kwargs):
autocorr_tol=50, min_tau=1, check_point_deltaT=600,
threads=1, **kwargs):
super(Ptemcee, self).__init__(
likelihood=likelihood, priors=priors, outdir=outdir,
label=label, use_ratio=use_ratio, plot=plot,
@@ -71,6 +72,8 @@ class Ptemcee(MCMCSampler):
         self.min_tau = min_tau
         self.check_point_deltaT = check_point_deltaT
+        self.threads = threads
+        self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)

     @property
@@ -86,18 +89,20 @@ class Ptemcee(MCMCSampler):
     def get_pos0_from_prior(self):
         """ for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim """
         logger.info("Generating pos0 samples")
         return [[self.get_random_draw_from_prior()
                  for _ in range(self.sampler_init_kwargs["nwalkers"])]
                 for _ in range(self.kwargs['ntemps'])]

-    def get_sampler(self):
+    def setup_sampler(self):
         import ptemcee

         if os.path.isfile(self.resume_file) and self.resume is True:
             logger.info("Resume data {} found".format(self.resume_file))
             with open(self.resume_file, "rb") as file:
                 data = dill.load(file)
             self.sampler = data["sampler"]
-            self.sampler.pool = None
+            self.sampler.pool = self.pool
+            self.sampler.threads = self.threads
             self.tau_list = data["tau_list"]
             self.tau_list_n = data["tau_list_n"]
             pos0 = None
@@ -105,21 +110,41 @@
         else:
             self.sampler = ptemcee.Sampler(
-                dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
-                **self.sampler_init_kwargs)
+                dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
+                pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
+            self.sampler._likeprior = LikePriorEvaluator(
+                self.search_parameter_keys, use_ratio=self.use_ratio)
             pos0 = self.get_pos0_from_prior()
         return self.sampler, pos0

+    def setup_pool(self):
+        if self.threads > 1:
+            import schwimmbad
+            logger.info("Creating MultiPool with {} processes".format(self.threads))
+            self.pool = schwimmbad.MultiPool(
+                self.threads,
+                initializer=init,
+                initargs=(self.likelihood, self.priors))
+        else:
+            self.pool = None
+
     def run_sampler(self):
-        import emcee
-        sampler, pos0 = self.get_sampler()
+        self.setup_pool()
+        out = self.run_sampler_internal()
+        if self.pool:
+            self.pool.close()
+        return out
+
+    def run_sampler_internal(self):
+        import emcee
+        sampler, pos0 = self.setup_sampler()
         self.time_per_check = []
         self.tau_list = []
         self.tau_list_n = []

         t0 = datetime.datetime.now()
         logger.info("Starting to sample")
         for (pos0, lnprob, lnlike) in sampler.sample(
                 pos0, **self.sampler_function_kwargs):
             # Only check convergence every ncheck steps
@@ -146,7 +171,7 @@
             tau = self.safety * np.mean(taus)

             if np.isnan(tau) or np.isinf(tau):
-                logger.info("{} | Unable to use tau={}".format(sampler.time, tau))
+                print("{} | Unable to use tau={}".format(sampler.time, tau), flush=True)
                 continue

             # Convert to an integer and store for plotting
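Note on the quantity being checked here: the convergence logic relies on an estimate of the integrated autocorrelation time tau. The short, self-contained sketch below is illustrative only and not part of this merge request; it shows how such an estimate behaves for a synthetic AR(1) chain using emcee.autocorr.integrated_time, with c=5 and tol=50 mirroring the autocorr_c and autocorr_tol defaults above. The AR(1) setup is invented for the example.

# Illustrative only: estimate the integrated autocorrelation time of a
# synthetic AR(1) chain; c=5 and tol=50 mirror autocorr_c / autocorr_tol above.
import numpy as np
import emcee

rng = np.random.default_rng(42)
n_steps, rho = 20000, 0.9
chain = np.zeros(n_steps)
for i in range(1, n_steps):
    chain[i] = rho * chain[i - 1] + rng.normal()

tau = emcee.autocorr.integrated_time(chain, c=5, tol=50, quiet=True)
print(tau)  # for an AR(1) process, the expected value is (1 + rho) / (1 - rho), about 19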
@@ -198,7 +223,8 @@
             if os.path.isfile(self.resume_file):
                 last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file)
             else:
-                last_checkpoint_s = np.inf
+                last_checkpoint_s = np.sum(self.time_per_check)

             if last_checkpoint_s > self.check_point_deltaT:
                 self.write_current_state()
@@ -238,6 +264,8 @@
     def write_current_state_and_exit(self, signum=None, frame=None):
         logger.warning("Run terminated with signal {}".format(signum))
+        if self.pool:
+            self.pool.close()
         self.write_current_state()
         sys.exit(77)
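The two added lines close the worker pool before the existing checkpoint-and-exit path runs. As a hedged sketch of the overall pattern this method belongs to (not taken from bilby; the CheckpointableRun class, its attributes, and the resume_file argument are invented for illustration, and only the close-pool / dump-state / sys.exit(77) sequence mirrors the diff above), a minimal signal-driven checkpoint could look like this:

# Minimal sketch of a checkpoint-on-signal pattern, assuming dill is installed.
import signal
import sys

import dill


class CheckpointableRun(object):
    def __init__(self, resume_file):
        self.resume_file = resume_file
        self.state = {"iteration": 0}
        self.pool = None  # a multiprocessing pool, if one was set up
        # Dump resumable state and exit cleanly when the job is terminated
        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
        signal.signal(signal.SIGINT, self.write_current_state_and_exit)

    def write_current_state(self):
        with open(self.resume_file, "wb") as file:
            dill.dump(self.state, file)

    def write_current_state_and_exit(self, signum=None, frame=None):
        if self.pool:
            self.pool.close()
        self.write_current_state()
        sys.exit(77)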
@@ -424,3 +452,58 @@ def compute_evidence(sampler, outdir, label, nburn, thin, make_plots=True):
     return lnZ, lnZerr


+def do_nothing_function():
+    """ This is a do-nothing function; the likelihood and prior are overwritten elsewhere """
+    pass
+
+
+likelihood = None
+priors = None
+
+
+def init(likelihood_in, priors_in):
+    global likelihood
+    global priors
+    likelihood = likelihood_in
+    priors = priors_in
+
+
+class LikePriorEvaluator(object):
+    """
+    An overwrite of ptemcee.LikePriorEvaluator to use the bilby likelihood and priors
+    """
+
+    def __init__(self, search_parameter_keys, use_ratio=False):
+        self.search_parameter_keys = search_parameter_keys
+        self.use_ratio = use_ratio
+
+    def logl(self, v_array):
+        parameters = {key: v for key, v in zip(self.search_parameter_keys, v_array)}
+        if priors.evaluate_constraints(parameters) > 0:
+            likelihood.parameters.update(parameters)
+            if self.use_ratio:
+                return likelihood.log_likelihood() - likelihood.noise_log_likelihood()
+            else:
+                return likelihood.log_likelihood()
+        else:
+            return np.nan_to_num(-np.inf)
+
+    def logp(self, v_array):
+        params = {key: t for key, t in zip(self.search_parameter_keys, v_array)}
+        return priors.ln_prob(params)
+
+    def __call__(self, x):
+        lp = self.logp(x)
+        if np.isnan(lp):
+            raise ValueError('Prior function returned NaN.')
+
+        if lp == float('-inf'):
+            # Can't return -inf, since this messes with beta=0 behaviour.
+            ll = 0
+        else:
+            ll = self.logl(x)
+            if np.isnan(ll).any():
+                raise ValueError('Log likelihood function returned NaN.')
+
+        return ll, lp
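The module-level likelihood and priors globals, init, and LikePriorEvaluator added above follow a common multiprocessing pattern: the expensive objects are handed to each worker once, via the pool initializer, instead of being pickled with every task. The standalone sketch below is illustrative only; the ToyLikelihood class and its data are invented, and plain multiprocessing.Pool stands in for schwimmbad.MultiPool, which accepts the same processes/initializer/initargs arguments in the diff above.

# Illustrative sketch of the pool-initializer / module-global pattern used by
# setup_pool, init, and LikePriorEvaluator. ToyLikelihood is a stand-in for a
# bilby likelihood and is not part of the merge request.
import multiprocessing

likelihood = None  # populated in each worker by init()


def init(likelihood_in):
    global likelihood
    likelihood = likelihood_in


class ToyLikelihood(object):
    def __init__(self, data):
        self.data = data

    def log_likelihood(self, mu):
        return -0.5 * sum((x - mu) ** 2 for x in self.data)


def evaluate(mu):
    # Runs in a worker process and uses the global set by init()
    return likelihood.log_likelihood(mu)


if __name__ == "__main__":
    toy = ToyLikelihood(data=[0.1, -0.2, 0.3])
    with multiprocessing.Pool(2, initializer=init, initargs=(toy,)) as pool:
        print(pool.map(evaluate, [0.0, 0.5, 1.0]))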