Skip to content
Snippets Groups Projects

WIP: Improve emcee

Closed Colm Talbot requested to merge improve-emcee into master
5 files
+ 493
123
Compare changes
  • Side-by-side
  • Inline
Files
5
@@ -88,7 +88,7 @@ class Sampler(object):
"""
default_kwargs = dict()
npool_equiv_kwargs = ['queue_size', 'threads', 'nthreads', 'npool']
npool_equiv_kwargs = ['queue_size', 'threads', 'nthreads', 'npool', 'processes']
def __init__(
self, likelihood, priors, outdir='outdir', label='label',
@@ -394,7 +394,7 @@ class Sampler(object):
else:
return self.likelihood.log_likelihood()
def get_random_draw_from_prior(self):
def get_random_draw_from_prior(self, ln_l_min=-np.inf):
""" Get a random draw from the prior distribution
Returns
@@ -407,10 +407,10 @@ class Sampler(object):
new_sample = self.priors.sample()
draw = np.array(list(new_sample[key]
for key in self._search_parameter_keys))
self.check_draw(draw)
self.check_draw(draw, ln_l_min=ln_l_min)
return draw
def get_initial_points_from_prior(self, npoints=1):
def get_initial_points_from_prior(self, npoints=1, ln_l_min=-np.inf):
""" Method to draw a set of live points from the prior
This iterates over draws from the prior until all the samples have a
@@ -434,17 +434,22 @@ class Sampler(object):
unit_cube = []
parameters = []
likelihood = []
from tqdm import tqdm
bar = tqdm(total=npoints)
while len(unit_cube) < npoints:
unit = np.random.rand(self.ndim)
theta = self.prior_transform(unit)
if self.check_draw(theta, warning=False):
if self.check_draw(theta, warning=False, ln_l_min=ln_l_min):
unit_cube.append(unit)
parameters.append(theta)
likelihood.append(self.log_likelihood(theta))
bar.update(n=1)
bar.close()
return np.array(unit_cube), np.array(parameters), np.array(likelihood)
def check_draw(self, theta, warning=True):
def check_draw(self, theta, warning=True, ln_l_min=-np.inf):
"""
Checks if the draw will generate an infinite prior or likelihood
@@ -466,10 +471,13 @@ class Sampler(object):
if warning:
logger.warning('Prior draw {} has inf prior'.format(theta))
return False
if abs(self.log_likelihood(theta)) in bad_values:
ln_l = self.log_likelihood(theta)
if abs(ln_l) in bad_values:
if warning:
logger.warning('Prior draw {} has inf likelihood'.format(theta))
return False
elif ln_l < ln_l_min:
return False
return True
def run_sampler(self):
@@ -539,6 +547,21 @@ class Sampler(object):
else:
return None
@property
def map(self):
if getattr(self, "pool", None) is not None:
return self.pool.map
else:
return map
@property
def npool(self):
npool = 1
for key in self.npool_equiv_kwargs:
if self.kwargs.get(key, None) is not None:
npool = self.kwargs[key]
return npool
class NestedSampler(Sampler):
npoints_equiv_kwargs = ['nlive', 'nlives', 'n_live_points', 'npoints', 'npoint', 'Nlive', 'num_live_points']
@@ -603,7 +626,7 @@ class NestedSampler(Sampler):
class MCMCSampler(Sampler):
nwalkers_equiv_kwargs = ['nwalker', 'nwalkers', 'draws', 'Niter']
nwalkers_equiv_kwargs = ['nwalker', 'nwalkers', 'draws', 'Niter', 'walkers']
nburn_equiv_kwargs = ['burn', 'nburn']
def print_nburn_logging_info(self):
@@ -630,8 +653,9 @@ class MCMCSampler(Sampler):
"""
import emcee
try:
self.result.max_autocorrelation_time = int(np.max(
emcee.autocorr.integrated_time(samples, c=c)))
self.result.max_autocorrelation_time = int(np.ceil(np.max(
emcee.autocorr.integrated_time(samples, c=c, quiet=True, tol=10)
)))
logger.info("Max autocorr time = {}".format(
self.result.max_autocorrelation_time))
except emcee.autocorr.AutocorrError as e:
Loading