Commit 9bffa289 authored by Gregory Ashton's avatar Gregory Ashton

Overhaul to using pickles and signals

parent 57a23737
This diff is collapsed.
from __future__ import absolute_import, division, print_function
import os
from collections import namedtuple
import numpy as np
from ..utils import (
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
from ..utils import logger, get_progress_bar
from . import Emcee
from .base_sampler import SamplerError
......@@ -31,12 +27,11 @@ class Ptemcee(Emcee):
The number of temperatures used by ptemcee
"""
# Default keyword arguments for ptemcee; split elsewhere into
# sampler-init kwargs and sampler-function kwargs.
# NOTE: the scraped diff assigned this attribute twice (old and new
# versions); only the final definition is kept.
default_kwargs = dict(
    ntemps=2, nwalkers=500, Tmax=None, betas=None, threads=1, pool=None,
    a=2.0, loglargs=[], logpargs=[], loglkwargs={}, logpkwargs={},
    adaptation_lag=10000, adaptation_time=100, random=None, iterations=100,
    thin=1, storechain=True, adapt=True, swap_ratios=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
......@@ -61,120 +56,88 @@ class Ptemcee(Emcee):
if key not in self.sampler_function_kwargs}
@property
def checkpoint_info(self):
    """Checkpoint file locations and the per-row output template.

    Ensures the checkpoint directory exists and writes the tab-separated
    header line on first use of the chain file.

    Returns
    -------
    CheckpointInfo: namedtuple
        Fields `last_pos_file`, `chain_file` and `template` (a format
        string for one walker row: index, parameters, log_l, log_p).
    """
    out_dir = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label))
    check_directory_exists_and_if_not_mkdir(out_dir)
    chain_file = os.path.join(out_dir, 'chain.dat')
    last_pos_file = os.path.join(out_dir, 'last_pos.npy')
    if not os.path.isfile(chain_file):
        header = 'walker\t{}\tlog_l\tlog_p\n'.format(
            '\t'.join(self.search_parameter_keys))
        with open(chain_file, "w") as ff:
            ff.write(header)
    # One integer slot for the walker index, then one float slot per
    # parameter plus two for log_l and log_p
    template = (
        '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n')
    CheckpointInfo = namedtuple(
        'CheckpointInfo', ['last_pos_file', 'chain_file', 'template'])
    return CheckpointInfo(
        last_pos_file=last_pos_file, chain_file=chain_file,
        template=template)
@property
def ntemps(self):
    """int: Number of parallel-tempering temperatures (from kwargs).

    Declared as a property since it is read attribute-style elsewhere
    (e.g. in `_pos0_shape`); the scrape appears to have dropped the
    decorator.
    """
    return self.kwargs['ntemps']
def _draw_pos0_from_prior(self):
# for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim
return [[self.get_random_draw_from_prior()
for _ in range(self.nwalkers)]
for _ in range(self.kwargs['ntemps'])]
@property
def _old_chain(self):
    """Samples loaded from a previous (resumed) run, or None.

    The stored array is truncated so its row count is an exact multiple
    of nwalkers, allowing a clean reshape into per-walker chains.
    """
    try:
        old_chain = self.__old_chain
        n = old_chain.shape[0]
        # Drop any trailing partial iteration
        idx = n - np.mod(n, self.nwalkers)
        return old_chain[:idx]
    except AttributeError:
        # No old chain has been set: fresh run
        return None

@_old_chain.setter
def _old_chain(self, old_chain):
    # Name-mangled private storage read back by the getter above
    self.__old_chain = old_chain
def _set_pos0_for_resume(self):
self.pos0 = None
@property
def stored_chain(self):
    """Read the checkpointed chain file as a structured numpy array."""
    # names=True consumes the header written by checkpoint_info, so
    # columns are addressable by parameter name, 'log_l' and 'log_p'
    return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
@property
def stored_samples(self):
return self.stored_chain[self.search_parameter_keys]
This is used when loading in a sampler from a pickle file to figure out
how much of the run has already been completed
"""
return self.sampler.time
@property
def stored_loglike(self):
    """Log-likelihood column of the checkpointed chain."""
    chain = self.stored_chain
    return chain['log_l']
@property
def sampler_chain(self):
    """The sampler's chain truncated to the iterations actually run.

    Declared as a property for consistency with the sibling accessors
    (stored_loglike, stored_logprior); the scrape appears to have
    attached the decorator to the wrong line.
    """
    nsteps = self._previous_iterations
    return self.sampler.chain[:, :, :nsteps, :]
@property
def stored_logprior(self):
    """Log-prior column of the checkpointed chain."""
    chain = self.stored_chain
    return chain['log_p']
def load_old_chain(self):
    """Attempt to resume from the on-disk checkpoint.

    Loads the last walker positions into pos0 and the previously stored
    samples into _old_chain; on any failure falls back to drawing a
    fresh starting position.
    """
    try:
        last_pos = np.load(self.checkpoint_info.last_pos_file)
        self.pos0 = last_pos
        self._old_chain = self.stored_samples
        logger.info(
            'Resuming from {} with {} iterations'.format(
                self.checkpoint_info.chain_file,
                self._previous_iterations))
    except Exception:
        # Broad catch is deliberate: resuming is best-effort and a
        # missing/corrupt checkpoint must not abort the run
        logger.info('Unable to resume')
        self._set_pos0()
def _pos0_shape(self):
return (self.ntemps, self.nwalkers, self.ndim)
def _initialise_sampler(self):
    """Instantiate the ptemcee sampler from the init kwargs.

    The scraped diff interleaved the superseded run_sampler set-up with
    this method; only the current `_initialise_sampler` is kept (the
    current run_sampler follows below).
    """
    import ptemcee
    self._sampler = ptemcee.Sampler(
        dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
        **self.sampler_init_kwargs)
def print_tswap_acceptance_fraction(self):
    """Log the per-chain temperature-swap acceptance fraction."""
    fraction = self.sampler.tswap_acceptance_fraction
    logger.info("Sampler per-chain tswap acceptance fraction = {}".format(
        fraction))
def write_chains_to_file(self, pos, loglike, logpost):
    """Append the zero-temperature chain state to the checkpoint file.

    Parameters
    ----------
    pos: array_like
        Current walker positions, shape (ntemps, nwalkers, ndim)
    loglike, logpost: array_like
        Log-likelihood and log-posterior values, shape (ntemps, nwalkers)
    """
    checkpoint = self.checkpoint_info
    # Only the zero-temperature (beta=1) chain is checkpointed
    loglike = np.squeeze(loglike[0, :])
    logprior = np.squeeze(logpost[0, :]) - loglike
    with open(checkpoint.chain_file, "a") as ff:
        for ii, (point, logl, logp) in enumerate(
                zip(pos[0, :, :], loglike, logprior)):
            line = np.concatenate((point, [logl, logp]))
            # Fix: CheckpointInfo declares the field `template`;
            # `chain_template` does not exist and raised AttributeError
            ff.write(checkpoint.template.format(ii, *line))
def run_sampler(self):
    """Run the ptemcee sampler, checkpointing every iteration.

    Iterates the sampler, saving the last walker positions and
    appending the zero-temperature chain to the checkpoint file each
    step, then populates and returns self.result.

    The scraped diff interleaved the superseded loop body and result
    assignments with the current ones; the coherent current version is
    reconstructed here.
    """
    tqdm = get_progress_bar()
    sampler_function_kwargs = self.sampler_function_kwargs
    iterations = sampler_function_kwargs.pop('iterations')
    # Subtract any iterations already completed by a resumed run
    iterations -= self._previous_iterations

    # main iteration loop
    for pos, logpost, loglike in tqdm(
            self.sampler.sample(self.pos0, iterations=iterations,
                                **sampler_function_kwargs),
            total=iterations):
        # Persist walker state each iteration so the run can resume
        np.save(self.checkpoint_info.last_pos_file, pos)
        self.write_chains_to_file(pos, loglike, logpost)

    self.calculate_autocorrelation(
        self.sampler.chain.reshape((-1, self.ndim)))
    self.result.sampler_output = np.nan
    self.print_nburn_logging_info()
    self.print_tswap_acceptance_fraction()
    self.result.nburn = self.nburn
    if self.result.nburn > self.nsteps:
        raise SamplerError(
            "The run has finished, but the chain is not burned in: "
            "`nburn < nsteps`. Try increasing the number of steps.")
    # Zero-temperature chain only: chain[0] has shape
    # (nwalkers, nsteps, ndim)
    self.result.samples = self.sampler.chain[0, :, self.nburn:, :].reshape(
        (-1, self.ndim))
    self.result.walkers = self.sampler.chain[0, :, :, :]
    n_samples = self.nwalkers * self.nburn
    self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:]
    self.result.log_prior_evaluations = self.stored_logprior[n_samples:]
    self.result.betas = self.sampler.betas
    self.result.log_evidence, self.result.log_evidence_err = \
        self.sampler.log_evidence_estimate(
            self.sampler.loglikelihood, self.nburn / self.nsteps)
    return self.result
......@@ -6,3 +6,4 @@ matplotlib>=2.0
scipy>=0.16
pandas
mock
dill
......@@ -79,6 +79,7 @@ setup(name='bilby',
'future',
'dynesty',
'corner',
'dill',
'numpy>=1.9',
'matplotlib>=2.0',
'pandas',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment