Code owners
Assign users and groups as approvers for specific file changes. Learn more.
dynesty.py 18.54 KiB
from __future__ import absolute_import
import os
import sys
import pickle
import signal
import numpy as np
from pandas import DataFrame
from ..utils import logger, check_directory_exists_and_if_not_mkdir
from .base_sampler import Sampler, NestedSampler
class Dynesty(NestedSampler):
    """
    bilby wrapper of `dynesty.NestedSampler`
    (https://dynesty.readthedocs.io/en/latest/)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `dynesty.NestedSampler`, see
    documentation for that class for further help. Under Other Parameters
    below, we list the commonly used kwargs and the bilby defaults.

    Parameters
    ----------
    likelihood: likelihood.Likelihood
        An object with a log_l method
    priors: bilby.core.prior.PriorDict, dict
        Priors to be used in the search.
        This has attributes for each parameter to be sampled.
    outdir: str, optional
        Name of the output directory
    label: str, optional
        Naming scheme of the output files
    use_ratio: bool, optional
        Switch to set whether or not you want to use the log-likelihood ratio
        or just the log-likelihood
    plot: bool, optional
        Switch to set whether or not you want to create traceplots
    skip_import_verification: bool
        Skips the check if the sampler is installed if true. This is
        only advisable for testing environments

    Other Parameters
    ----------------
    npoints: int, (1000)
        The number of live points, note this can also equivalently be given as
        one of [nlive, nlives, n_live_points]
    bound: {'none', 'single', 'multi', 'balls', 'cubes'}, ('multi')
        Method used to select new points
    sample: {'unif', 'rwalk', 'slice', 'rslice', 'hslice'}, ('rwalk')
        Method used to sample uniformly within the likelihood constraints,
        conditioned on the provided bounds
    walks: int
        Number of walks taken if using `sample='rwalk'`, defaults to
        `ndim * 10`
    dlogz: float, (0.1)
        Stopping criteria
    verbose: Bool
        If true, print information about the convergence during the run
    check_point: bool,
        If true, use check pointing.
    check_point_plot: bool,
        If true, generate a trace plot along with the check-point
    check_point_delta_t: float (600)
        The approximate checkpoint period (in seconds). Should the run be
        interrupted, it can be resumed from the last checkpoint. Set to
        `None` to turn-off check pointing.
        NOTE(review): when `n_check_point` is None the constructor divides
        this value by the measured likelihood evaluation time; confirm that
        `None` is handled there before relying on it (`check_point=False`
        also disables check pointing).
    n_check_point: int, optional (None)
        The number of steps to take before check pointing (override
        check_point_delta_t).
    resume: bool
        If true, resume run from checkpoint (if available)
    """
    # Default value for every keyword this wrapper knows about. The keys
    # named in `sampler_function_kwargs` are passed to `run_nested`; all
    # remaining keys are passed to the `dynesty.NestedSampler` constructor.
    default_kwargs = dict(bound='multi', sample='rwalk',
                          verbose=True, periodic=None,
                          check_point_delta_t=600, nlive=1000,
                          first_update=None, walks=None,
                          npdim=None, rstate=None, queue_size=None, pool=None,
                          use_pool=None, live_points=None,
                          logl_args=None, logl_kwargs=None,
                          ptform_args=None, ptform_kwargs=None,
                          enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0,
                          facc=0.5, slices=5,
                          update_interval=None, print_func=None,
                          dlogz=0.1, maxiter=None, maxcall=None,
                          logl_max=np.inf, add_live=True, print_progress=True,
                          save_bounds=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
             use_ratio=False, plot=False, skip_import_verification=False,
             check_point=True, check_point_plot=False, n_check_point=None,
             check_point_delta_t=600, resume=True, **kwargs):
    """Set up the dynesty wrapper and its check-pointing configuration.

    Parameters
    ----------
    likelihood, priors, outdir, label, use_ratio, plot,
    skip_import_verification, **kwargs:
        Forwarded unchanged to `NestedSampler.__init__`.
    check_point: bool
        If true, use check pointing.
    check_point_plot: bool
        If true, generate a trace plot along with the check-point.
    n_check_point: int, optional
        Number of likelihood calls between check points. When None it is
        derived from `check_point_delta_t`.
    check_point_delta_t: float, optional
        Approximate check-point period in seconds. `None` turns check
        pointing off (only consulted when `n_check_point` is None).
    resume: bool
        If true, resume the run from the resume file when it exists.

    Notes
    -----
    Registers `write_current_state_and_exit` as the handler for SIGTERM
    and SIGINT.
    """
    NestedSampler.__init__(self, likelihood=likelihood, priors=priors,
                           outdir=outdir, label=label, use_ratio=use_ratio,
                           plot=plot,
                           skip_import_verification=skip_import_verification,
                           **kwargs)
    self.n_check_point = n_check_point
    self.check_point = check_point
    self.check_point_plot = check_point_plot
    self.resume = resume
    self._periodic = list()
    self._reflective = list()
    self._apply_dynesty_boundaries()

    if self.n_check_point is None:
        eval_time = self._log_likelihood_eval_time
        # The number of calls per check point can only be derived from a
        # period and a usable (finite, positive) single-evaluation time.
        # Otherwise check pointing is switched off instead of attempting
        # the division/rounding, which would raise on None, NaN or zero.
        can_derive = (check_point_delta_t is not None and
                      np.isfinite(eval_time) and eval_time > 0)
        if can_derive:
            n_check_point_raw = check_point_delta_t / eval_time
            # "{:1.0g}" rounds the estimate to one significant figure.
            n_check_point_rnd = int(
                float("{:1.0g}".format(n_check_point_raw)))
            self.n_check_point = n_check_point_rnd
        else:
            self.check_point = False

    if self.check_point:
        logger.info("Checkpoint every n_check_point = {}".format(
            self.n_check_point))
    else:
        logger.info("Checkpointing is disabled")

    self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)

    # Write a check point before exiting when the job is terminated.
    signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
    signal.signal(signal.SIGINT, self.write_current_state_and_exit)
@property
def sampler_function_kwargs(self):
    """dict: the entries of `self.kwargs` that are passed to `run_nested`.

    Raises KeyError if one of the expected keys is absent from
    `self.kwargs`.
    """
    run_nested_keys = (
        'dlogz', 'print_progress', 'print_func', 'maxiter',
        'maxcall', 'logl_max', 'add_live', 'save_bounds',
    )
    selected = {}
    for name in run_nested_keys:
        selected[name] = self.kwargs[name]
    return selected
@property
def sampler_init_kwargs(self):
    """dict: the entries of `self.kwargs` not consumed by `run_nested`.

    These are the keywords handed to the `dynesty.NestedSampler`
    constructor, i.e. everything in `self.kwargs` whose key is not in
    `sampler_function_kwargs`.
    """
    # Evaluate the property once; it builds a new dict on every access, so
    # looking it up inside the comprehension would rebuild it per key.
    function_kwargs = self.sampler_function_kwargs
    return {key: value
            for key, value in self.kwargs.items()
            if key not in function_kwargs}
def _translate_kwargs(self, kwargs):
if 'nlive' not in kwargs:
for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs:
kwargs['nlive'] = kwargs.pop(equiv)
if 'print_progress' not in kwargs:
if 'verbose' in kwargs:
kwargs['print_progress'] = kwargs.pop('verbose')
def _verify_kwargs_against_default_kwargs(self):
    """Fill in unset sampler options, then run the generic verification.

    An option counts as unset when its value in `self.kwargs` is falsy:
    'walks' becomes `ndim * 10`, 'update_interval' becomes
    `int(0.6 * nlive)` and 'print_func' becomes `self._print_func`.
    Afterwards `Sampler._verify_kwargs_against_default_kwargs` is called.
    """
    # Each fallback is computed lazily so that it is only evaluated when the
    # option is actually unset.
    fallbacks = (
        ('walks', lambda: self.ndim * 10),
        ('update_interval', lambda: int(0.6 * self.kwargs['nlive'])),
        ('print_func', lambda: self._print_func),
    )
    for option, make_default in fallbacks:
        if not self.kwargs[option]:
            self.kwargs[option] = make_default()
    Sampler._verify_kwargs_against_default_kwargs(self)
def _print_func(self, results, niter, ncall, dlogz, *args, **kwargs):
""" Replacing status update for dynesty.result.print_func """
# Extract results at the current iteration.
(worst, ustar, vstar, loglstar, logvol, logwt,
logz, logzvar, h, nc, worst_it, boundidx, bounditer,
eff, delta_logz) = results
# Adjusting outputs for printing.
if delta_logz > 1e6:
delta_logz = np.inf
if 0. <= logzvar <= 1e6:
logzerr = np.sqrt(logzvar)
else:
logzerr = np.nan
if logz <= -1e6:
logz = -np.inf
if loglstar <= -1e6:
loglstar = -np.inf
if self.use_ratio:
key = 'logz ratio'
else:
key = 'logz'
# Constructing output.
raw_string = "\r {}| {}={:6.3f} +/- {:6.3f} | dlogz: {:6.3f} > {:6.3f}"
print_str = raw_string.format(
niter, key, logz, logzerr, delta_logz, dlogz)
# Printing.
sys.stderr.write(print_str)
sys.stderr.flush()
def _apply_dynesty_boundaries(self):
    """Derive the sampler's 'periodic' indices from the prior boundaries.

    Nothing happens when `self.kwargs['periodic']` is already set.
    Otherwise the index of every search parameter whose prior has a
    'periodic' or 'reflective' boundary is appended to
    `self.kwargs['periodic']`, and additionally recorded in
    `self._periodic` or `self._reflective` respectively.
    """
    if self.kwargs['periodic'] is not None:
        return

    logger.debug("Setting periodic boundaries for keys:")
    self.kwargs['periodic'] = []
    self._periodic = list()
    self._reflective = list()
    for index, name in enumerate(self.search_parameter_keys):
        boundary = self.priors[name].boundary
        if boundary not in ['periodic', 'reflective']:
            continue
        # Both kinds are registered as 'periodic' with the sampler; the two
        # private lists keep track of which kind each index really is.
        self.kwargs['periodic'].append(index)
        logger.debug(" {}".format(name))
        bucket = self._periodic if boundary == 'periodic' else self._reflective
        bucket.append(index)
def run_sampler(self):
    """Run dynesty and store its output on `self.result`.

    A `dynesty.NestedSampler` is built from `sampler_init_kwargs` and run
    with or without check pointing depending on `self.check_point`. The
    raw output is pickled to '<outdir>/<label>_dynesty.pickle', and the
    samples, nested samples, log-likelihood evaluations and log-evidence
    (with error) are written to `self.result`. A trace plot is produced
    when `self.plot` is set.

    Returns
    -------
    The populated `self.result`.
    """
    import dynesty
    self.sampler = dynesty.NestedSampler(
        loglikelihood=self.log_likelihood,
        prior_transform=self.prior_transform,
        ndim=self.ndim, **self.sampler_init_kwargs)

    if self.check_point:
        runner = self._run_external_sampler_with_checkpointing
    else:
        runner = self._run_external_sampler_without_checkpointing
    out = runner()

    # Flushes the output to force a line break
    if self.kwargs["verbose"]:
        print("")

    # Keep the raw dynesty output on disk next to the other run products.
    check_directory_exists_and_if_not_mkdir(self.outdir)
    pickle_path = "{}/{}_dynesty.pickle".format(self.outdir, self.label)
    with open(pickle_path, 'wb') as handle:
        pickle.dump(out, handle)

    # Weight of each nested sample relative to the final evidence.
    final_logz = out['logz'][-1]
    weights = np.exp(out['logwt'] - final_logz)

    nested_samples = DataFrame(
        out.samples, columns=self.search_parameter_keys)
    nested_samples['weights'] = weights
    nested_samples['log_likelihood'] = out.logl

    self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
    self.result.nested_samples = nested_samples
    self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
        unsorted_loglikelihoods=out.logl,
        unsorted_samples=out.samples,
        sorted_samples=self.result.samples)
    self.result.log_evidence = out.logz[-1]
    self.result.log_evidence_err = out.logzerr[-1]

    if self.plot:
        self.generate_trace_plots(out)

    return self.result
def _run_external_sampler_without_checkpointing(self):
    """Run the sampler in a single `run_nested` call and return its results."""
    logger.debug("Running sampler without checkpointing")
    run_options = self.sampler_function_kwargs
    self.sampler.run_nested(**run_options)
    return self.sampler.results
def _run_external_sampler_with_checkpointing(self):
    """Run the sampler in chunks, writing a check point after each chunk.

    The sampler is run repeatedly with `add_live=False` and a 'maxcall'
    budget that grows by `self.n_check_point` on every pass; the state is
    written to the resume file after each pass that changed
    `self.sampler.ncall`. Once a pass leaves `ncall` unchanged, the state
    is reloaded from the resume file and the sampler is run one final time
    with `add_live=True`.

    Returns
    -------
    `self.sampler.results` after the final run.
    """
    logger.debug("Running sampler with checkpointing")
    if self.resume:
        # Restore the sampler from the resume file when one exists;
        # read_saved_state returns False if there is nothing to load.
        resume = self.read_saved_state(continuing=True)
        if resume:
            logger.info('Resuming from previous run.')

    old_ncall = self.sampler.ncall
    # Work on a copy so the stored sampler kwargs are not modified.
    sampler_kwargs = self.sampler_function_kwargs.copy()
    # The budget starts at n_check_point and is raised before the first
    # pass, so the first pass runs with maxcall = 2 * n_check_point.
    sampler_kwargs['maxcall'] = self.n_check_point
    sampler_kwargs['add_live'] = False
    while True:
        sampler_kwargs['maxcall'] += self.n_check_point
        self.sampler.run_nested(**sampler_kwargs)
        # No new likelihood calls in this pass: stop chunking.
        # NOTE(review): presumably this is how the sampler signals that its
        # stopping criterion was reached - confirm against dynesty.
        if self.sampler.ncall == old_ncall:
            break
        old_ncall = self.sampler.ncall

        self.write_current_state()

    # Reload the last written state before the final pass.
    self.read_saved_state()

    # Final pass with add_live=True (the chunked passes all used False).
    sampler_kwargs['add_live'] = True
    self.sampler.run_nested(**sampler_kwargs)
    return self.sampler.results
def _remove_checkpoint(self):
"""Remove checkpointed state"""
if os.path.isfile(self.resume_file):
os.remove(self.resume_file)
def read_saved_state(self, continuing=False):
    """
    Read a saved state of the sampler from disk.

    The information required to reconstruct the state of the run is read
    from the pickle file `self.resume_file` and copied onto
    `self.sampler`. This currently adds the whole chain to the sampler.

    FIXME: Load only the necessary quantities, rather than read/write?

    Parameters
    ----------
    continuing: bool
        Whether the run is continuing or terminating; if True, the loaded
        state is written back to disk (without a plot) after loading.

    Returns
    -------
    bool: True if the resume file existed and was loaded, False otherwise.
    """
    logger.debug("Reading resume file {}".format(self.resume_file))
    if not os.path.isfile(self.resume_file):
        logger.debug(
            "Failed to read resume file {}".format(self.resume_file))
        return False

    # NOTE: unpickling can execute arbitrary code, so the resume file must
    # come from a trusted source.
    with open(self.resume_file, 'rb') as handle:
        saved = pickle.load(handle)
    logger.debug(
        "Succesfuly read resume file {}".format(self.resume_file))

    def unchanged(value):
        return value

    def next_iteration(value):
        # The file stores `it - 1`; undo that offset.
        return value + 1

    # (sampler attribute, key in the saved state, conversion), applied in
    # this order.
    restore_plan = (
        ('saved_u', 'unit_cube_samples', list),
        ('saved_v', 'physical_samples', list),
        ('saved_logl', 'sample_likelihoods', list),
        ('saved_logvol', 'sample_log_volume', list),
        ('saved_logwt', 'sample_log_weights', list),
        ('saved_logz', 'cumulative_log_evidence', list),
        ('saved_logzvar', 'cumulative_log_evidence_error', list),
        ('saved_id', 'id', list),
        ('saved_it', 'it', list),
        ('saved_nc', 'nc', list),
        ('saved_boundidx', 'boundidx', list),
        ('saved_bounditer', 'bounditer', list),
        ('saved_scale', 'scale', list),
        ('saved_h', 'cumulative_information', list),
        ('ncall', 'ncall', unchanged),
        ('live_logl', 'live_logl', list),
        ('it', 'iteration', next_iteration),
        ('live_u', 'live_u', unchanged),
        ('live_v', 'live_v', unchanged),
        ('nlive', 'nlive', unchanged),
        ('live_bound', 'live_bound', unchanged),
        ('live_it', 'live_it', unchanged),
        ('added_live', 'added_live', unchanged),
    )
    for attribute, key, convert in restore_plan:
        setattr(self.sampler, attribute, convert(saved[key]))

    if continuing:
        self.write_current_state(plot=False)
    return True
def write_current_state_and_exit(self, signum=None, frame=None):
    """Signal handler: write a check point (no plot) and exit with code 130.

    Parameters
    ----------
    signum: int, optional
        Number of the received signal; only used in the log message.
    frame: optional
        Unused; present to match the signal-handler call signature.
    """
    message = "Run terminated with signal {}".format(signum)
    logger.warning(message)
    self.write_current_state(plot=False)
    sys.exit(130)
def write_current_state(self, plot=True):
    """
    Write the current state of the sampler to disk.

    The information required to reconstruct the state of the run is
    written to the pickle file `self.resume_file` (the output directory is
    created if needed). When a posterior can be built from the current
    chain it is stored in the file under the 'posterior' key; failing to
    build it never prevents the check point from being written.

    Parameters
    ----------
    plot: bool
        If true and `self.check_point_plot` is set, also save a trace plot
        to '<outdir>/<label>_checkpoint_trace.png'.
    """
    check_directory_exists_and_if_not_mkdir(self.outdir)
    logger.info("Writing checkpoint file {}".format(self.resume_file))

    # The chain accumulated so far.
    current_state = dict(
        unit_cube_samples=self.sampler.saved_u,
        physical_samples=self.sampler.saved_v,
        sample_likelihoods=self.sampler.saved_logl,
        sample_log_volume=self.sampler.saved_logvol,
        sample_log_weights=self.sampler.saved_logwt,
        cumulative_log_evidence=self.sampler.saved_logz,
        cumulative_log_evidence_error=self.sampler.saved_logzvar,
        cumulative_information=self.sampler.saved_h,
        id=self.sampler.saved_id,
        it=self.sampler.saved_it,
        nc=self.sampler.saved_nc,
        boundidx=self.sampler.saved_boundidx,
        bounditer=self.sampler.saved_bounditer,
        scale=self.sampler.saved_scale,
    )

    # The live points and counters needed to continue the run.
    current_state.update(
        ncall=self.sampler.ncall, live_logl=self.sampler.live_logl,
        iteration=self.sampler.it - 1, live_u=self.sampler.live_u,
        live_v=self.sampler.live_v, nlive=self.sampler.nlive,
        live_bound=self.sampler.live_bound, live_it=self.sampler.live_it,
        added_live=self.sampler.added_live
    )

    # Best-effort posterior from the chain so far.
    try:
        # The chain quantities may be plain lists; convert explicitly so
        # the subtraction works whatever the type of the evidence value.
        log_weights = np.asarray(current_state['sample_log_weights'])
        # `[-1]` raises IndexError while the chain is still empty, e.g.
        # when a termination signal arrives before the first iteration.
        final_logz = current_state['cumulative_log_evidence'][-1]
        weights = np.exp(log_weights - final_logz)
        from dynesty.utils import resample_equal
        current_state['posterior'] = resample_equal(
            np.array(current_state['physical_samples']), weights)
    except (ValueError, IndexError):
        logger.debug("Unable to create posterior")

    with open(self.resume_file, 'wb') as handle:
        pickle.dump(current_state, handle)

    if plot and self.check_point_plot:
        import dynesty.plotting as dyplot
        labels = self.search_parameter_keys
        fn = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
        fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
        fig.tight_layout()
        fig.savefig(fn)
def generate_trace_plots(self, dynesty_results):
    """Save a dynesty trace plot to '<outdir>/<label>_trace.png'.

    Parameters
    ----------
    dynesty_results:
        The results object handed to `dynesty.plotting.traceplot`.
    """
    check_directory_exists_and_if_not_mkdir(self.outdir)
    filename = '{}/{}_trace.png'.format(self.outdir, self.label)
    logger.debug("Writing trace plot to {}".format(filename))

    from dynesty import plotting as dyplot
    figure, _axes = dyplot.traceplot(
        dynesty_results, labels=self.result.parameter_labels)
    figure.tight_layout()
    figure.savefig(filename)
def _run_test(self):
    """Run a two-iteration sampler pass and return a placeholder result.

    The sampler is constructed as in a real run but limited to
    `maxiter=2`. The result's samples are 100 draws from the priors
    (restricted to the search parameters) and the log-evidence and its
    error are set to NaN.

    Returns
    -------
    The populated `self.result`.
    """
    import dynesty
    import pandas as pd

    self.sampler = dynesty.NestedSampler(
        loglikelihood=self.log_likelihood,
        prior_transform=self.prior_transform,
        ndim=self.ndim, **self.sampler_init_kwargs)

    # Copy before overriding so the stored kwargs are left untouched.
    short_run_kwargs = self.sampler_function_kwargs.copy()
    short_run_kwargs['maxiter'] = 2
    self.sampler.run_nested(**short_run_kwargs)

    prior_draws = pd.DataFrame(self.priors.sample(100))
    self.result.samples = prior_draws[self.search_parameter_keys].values
    self.result.log_evidence = np.nan
    self.result.log_evidence_err = np.nan
    return self.result
def prior_transform(self, theta):
    """ Prior transform method that is passed into the external sampler.

    Parameters
    ----------
    theta: numpy.ndarray
        Sampled values on a unit interval. Entries at the periodic and
        reflective indices are overwritten in place with their wrapped
        values.

    Returns
    -------
    The properly rescaled sampled values, as returned by
    `self.priors.rescale`.

    Notes
    -----
    Since dynesty allows periodic parameters to wander outside the unit
    cube we map this back to [0, 1].

    We also allow parameters with reflective boundaries to wander outside
    the unit cube. The logic ensures that when theta < 0 you shift to
    |theta| and when theta > 1 you return 2 - theta; values that lie more
    than one unit outside the cube are reflected repeatedly until they
    fall inside [0, 1].
    """
    theta[self._periodic] = np.mod(theta[self._periodic], 1)

    # Reflecting off both walls is symmetric about 0 and repeats with
    # period 2, so fold |theta| into [0, 2) and mirror the upper half.
    # For -1 <= theta <= 2 this gives exactly |theta| or 2 - theta.
    folded = np.mod(np.abs(theta[self._reflective]), 2)
    theta[self._reflective] = np.minimum(folded, 2 - folded)

    return self.priors.rescale(self._search_parameter_keys, theta)