Skip to content
Snippets Groups Projects
Commit 5b4355b9 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Standardize methods ordering

parent 9bffa289
No related branches found
No related tags found
1 merge request!423Improvements to checkpointing for emcee/ptemcee
Pipeline #56010 passed
......@@ -45,11 +45,11 @@ class Emcee(MCMCSampler):
"""
default_kwargs = dict(nwalkers=500, a=2, args=[], kwargs={},
postargs=None, pool=None, live_dangerously=False,
runtime_sortingfn=None, lnprob0=None, rstate0=None,
blobs0=None, iterations=100, thin=1, storechain=True,
mh_proposal=None)
default_kwargs = dict(
nwalkers=500, a=2, args=[], kwargs={}, postargs=None, pool=None,
live_dangerously=False, runtime_sortingfn=None, lnprob0=None,
rstate0=None, blobs0=None, iterations=100, thin=1, storechain=True,
mh_proposal=None)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
......@@ -142,6 +142,14 @@ class Emcee(MCMCSampler):
return init_kwargs
def lnpostfn(self, theta):
log_prior = self.log_prior(theta)
if np.isinf(log_prior):
return -np.inf, [np.nan, np.nan]
else:
log_likelihood = self.log_likelihood(theta)
return log_likelihood + log_prior, [log_likelihood, log_prior]
@property
def nburn(self):
if type(self.__nburn) in [float, int]:
......@@ -206,7 +214,8 @@ class Emcee(MCMCSampler):
"""
out_dir = os.path.join(
self.outdir, '{}_{}'.format(self.__class__.__name__, self.label))
self.outdir, '{}_{}'.format(self.__class__.__name__.lower(),
self.label))
check_directory_exists_and_if_not_mkdir(out_dir)
sampler_file = os.path.join(out_dir, 'sampler.pickle')
......@@ -253,9 +262,6 @@ class Emcee(MCMCSampler):
import emcee
self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
def _set_pos0_for_resume(self):
self.pos0 = self.sampler.chain[:, -1, :]
@property
def sampler(self):
""" Returns the ptemcee sampler object
......@@ -285,42 +291,6 @@ class Emcee(MCMCSampler):
for ii, point in enumerate(points):
ff.write(self.checkpoint_info.chain_template.format(ii, *point))
def run_sampler(self):
tqdm = get_progress_bar()
sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations
print('pos0', self.pos0)
sampler_function_kwargs['p0'] = self.pos0
for sample in tqdm(
self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
total=iterations):
self.write_chains_to_file(sample)
self.result.sampler_output = np.nan
blobs_flat = np.array(self.sampler.blobs).reshape((-1, 2))
log_likelihoods, log_priors = blobs_flat.T
chain = self.sampler.chain.reshape((-1, self.ndim))
log_ls = log_likelihoods
log_ps = log_priors
self.calculate_autocorrelation(chain)
self.print_nburn_logging_info()
self.result.nburn = self.nburn
n_samples = self.nwalkers * self.nburn
if self.result.nburn > self.nsteps:
raise SamplerError(
"The run has finished, but the chain is not burned in: "
"`nburn < nsteps`. Try increasing the number of steps.")
self.result.samples = chain[n_samples:, :]
self.result.log_likelihood_evaluations = log_ls[n_samples:]
self.result.log_prior_evaluations = log_ps[n_samples:]
self.result.walkers = self.sampler.chain
self.result.log_evidence = np.nan
self.result.log_evidence_err = np.nan
return self.result
@property
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
......@@ -356,10 +326,41 @@ class Emcee(MCMCSampler):
logger.debug("Generating initial walker positions from prior")
self.pos0 = self._draw_pos0_from_prior()
def lnpostfn(self, theta):
log_prior = self.log_prior(theta)
if np.isinf(log_prior):
return -np.inf, [np.nan, np.nan]
else:
log_likelihood = self.log_likelihood(theta)
return log_likelihood + log_prior, [log_likelihood, log_prior]
def _set_pos0_for_resume(self):
self.pos0 = self.sampler.chain[:, -1, :]
def run_sampler(self):
tqdm = get_progress_bar()
sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations
print('pos0', self.pos0)
sampler_function_kwargs['p0'] = self.pos0
for sample in tqdm(
self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
total=iterations):
self.write_chains_to_file(sample)
self.result.sampler_output = np.nan
blobs_flat = np.array(self.sampler.blobs).reshape((-1, 2))
log_likelihoods, log_priors = blobs_flat.T
chain = self.sampler.chain.reshape((-1, self.ndim))
log_ls = log_likelihoods
log_ps = log_priors
self.calculate_autocorrelation(chain)
self.print_nburn_logging_info()
self.result.nburn = self.nburn
n_samples = self.nwalkers * self.nburn
if self.result.nburn > self.nsteps:
raise SamplerError(
"The run has finished, but the chain is not burned in: "
"`nburn < nsteps`. Try increasing the number of steps.")
self.result.samples = chain[n_samples:, :]
self.result.log_likelihood_evaluations = log_ls[n_samples:]
self.result.log_prior_evaluations = log_ps[n_samples:]
self.result.walkers = self.sampler.chain
self.result.log_evidence = np.nan
self.result.log_evidence_err = np.nan
return self.result
......@@ -59,33 +59,11 @@ class Ptemcee(Emcee):
def ntemps(self):
return self.kwargs['ntemps']
def _draw_pos0_from_prior(self):
# for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim
return [[self.get_random_draw_from_prior()
for _ in range(self.nwalkers)]
for _ in range(self.kwargs['ntemps'])]
def _set_pos0_for_resume(self):
self.pos0 = None
@property
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
This is used when loading in a sampler from a pickle file to figure out
how much of the run has already been completed
"""
return self.sampler.time
@property
def sampler_chain(self):
nsteps = self._previous_iterations
return self.sampler.chain[:, :, :nsteps, :]
@property
def _pos0_shape(self):
return (self.ntemps, self.nwalkers, self.ndim)
def _initialise_sampler(self):
import ptemcee
self._sampler = ptemcee.Sampler(
......@@ -104,6 +82,28 @@ class Ptemcee(Emcee):
line = np.concatenate((point, [logl, logp]))
ff.write(self.checkpoint_info.chain_template.format(ii, *line))
@property
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
This is used when loading in a sampler from a pickle file to figure out
how much of the run has already been completed
"""
return self.sampler.time
def _draw_pos0_from_prior(self):
# for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim
return [[self.get_random_draw_from_prior()
for _ in range(self.nwalkers)]
for _ in range(self.kwargs['ntemps'])]
@property
def _pos0_shape(self):
return (self.ntemps, self.nwalkers, self.ndim)
def _set_pos0_for_resume(self):
self.pos0 = None
def run_sampler(self):
tqdm = get_progress_bar()
sampler_function_kwargs = self.sampler_function_kwargs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment