Skip to content
Snippets Groups Projects

DEV: update custom dynesty proposals

Merged Colm Talbot requested to merge dynesty-differential into master
Compare and
10 files
+ 1200
253
Compare changes
  • Side-by-side
  • Inline
Files
10
+ 187
195
@@ -3,7 +3,6 @@ import inspect
import os
import sys
import time
import warnings
import numpy as np
from pandas import DataFrame
@@ -13,12 +12,19 @@ from ..utils import (
check_directory_exists_and_if_not_mkdir,
latex_plot_format,
logger,
reflect,
safe_file_dump,
)
from .base_sampler import NestedSampler, Sampler, _SamplingContainer, signal_wrapper
def _set_sampling_kwargs(args):
nact, maxmcmc, proposals, naccept = args
_SamplingContainer.nact = nact
_SamplingContainer.maxmcmc = maxmcmc
_SamplingContainer.proposals = proposals
_SamplingContainer.naccept = naccept
def _prior_transform_wrapper(theta):
"""Wrapper to the prior transformation. Needed for multiprocessing."""
from .base_sampler import _sampling_convenience_dump
@@ -101,38 +107,42 @@ class Dynesty(NestedSampler):
If true, resume run from checkpoint (if available)
maxmcmc: int (5000)
The maximum length of the MCMC exploration to find a new point
nact: int (5)
The number of "autocorrelation" times to continue the MCMC for.
Note that this is a very poor approximation to the true ACT and should
be interpreted very loosely.
rejection_sample_posterior: bool
nact: int (2)
The number of autocorrelation lengths for MCMC exploration.
For use with the :code:`act-walk` and :code:`rwalk` sample methods.
See the dynesty guide in the Bilby docs for more details.
naccept: int (60)
The expected number of accepted steps for MCMC exploration when using
the :code:`acceptance-walk` sampling method.
rejection_sample_posterior: bool (True)
Whether to form the posterior by rejection sampling the nested samples.
If False, the nested samples are resampled with repetition. This was
the default behaviour in :code:`Bilby<=1.4.1` and leads to
non-independent samples being produced.
proposals: iterable (None)
The proposal methods to use during MCMC. This can be some combination
of :code:`"diff", "volumetric"`. See the dynesty guide in the Bilby docs
for more details. default=:code:`["diff"]`.
Other Parameters
================
nlive: int, (1000)
The number of live points, note this can also equivalently be given as
one of [nlive, nlives, n_live_points, npoints]
bound: {'none', 'single', 'multi', 'balls', 'cubes'}, ('multi')
bound: {'live', 'live-multi', 'none', 'single', 'multi', 'balls', 'cubes'}, ('live')
Method used to select new points
sample: {'unif', 'rwalk', 'slice', 'rslice', 'hslice'}, ('rwalk')
sample: {'act-walk', 'acceptance-walk', 'unif', 'rwalk', 'slice',
'rslice', 'hslice', 'rwalk_dynesty'}, ('act-walk')
Method used to sample uniformly within the likelihood constraints,
conditioned on the provided bounds
walks: int
Number of walks taken if using `sample='rwalk'`, defaults to `100`.
walks: int (100)
Number of walks taken if using the dynesty implemented sample methods
Note that the default `walks` in dynesty itself is 25, although using
`ndim * 10` can be a reasonable rule of thumb for new problems.
For :code:`sample="act-walk"` and :code:`sample="rwalk"` this parameter
has no impact on the sampling.
dlogz: float, (0.1)
Stopping criteria
facc: float, (0.2)
The target acceptance fraction for the rwalk evolution. The proposal
scale is tuned to meet this fraction.
save_bounds: bool, (False)
Whether to save the dynesty bounding ellipse objects. This is disabled
by default as it can lead to extremely large memory usage.
"""
@property
@@ -143,7 +153,9 @@ class Dynesty(NestedSampler):
for key, param in params.items()
if param.default != param.empty
}
kwargs["sample"] = "rwalk"
kwargs["sample"] = "act-walk"
kwargs["bound"] = "live"
kwargs["update_interval"] = 600
kwargs["facc"] = 0.2
return kwargs
@@ -184,12 +196,16 @@ class Dynesty(NestedSampler):
exit_code=130,
print_method="tqdm",
maxmcmc=5000,
nact=5,
nact=2,
naccept=60,
rejection_sample_posterior=True,
proposals=None,
**kwargs,
):
_SamplingContainer.maxmcmc = maxmcmc
_SamplingContainer.nact = nact
self.nact = nact
self.naccept = naccept
self.maxmcmc = maxmcmc
self.proposals = proposals
self.print_method = print_method
self._translate_kwargs(kwargs)
super(Dynesty, self).__init__(
@@ -214,7 +230,9 @@ class Dynesty(NestedSampler):
self.nestcheck = nestcheck
if self.n_check_point is None:
self.n_check_point = 1000
self.n_check_point = max(
int(check_point_delta_t / self._log_likelihood_eval_time / 10), 10
)
self.check_point_delta_t = check_point_delta_t
logger.info(f"Checkpoint every check_point_delta_t = {check_point_delta_t}s")
@@ -372,29 +390,104 @@ class Dynesty(NestedSampler):
return Sampler
@signal_wrapper
def run_sampler(self):
def _set_sampling_method(self):
"""
Resolve the sampling method and sampler to use from the provided
:code:`bound` and :code:`sample` arguments.
This requires registering the :code:`bilby` specific methods in the
appropriate locations within :code:`dynesty`.
Additionally, some combinations of bound/sample/proposals are not
compatible and so we either warn the user or raise an error.
"""
import dynesty
logger.info(f"Using dynesty version {dynesty.__version__}")
_set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept))
sample = self.kwargs["sample"]
bound = self.kwargs["bound"]
if self.kwargs.get("sample", "rwalk") == "rwalk":
if sample not in ["rwalk", "act-walk", "acceptance-walk"] and bound in [
"live",
"live-multi",
]:
logger.info(
"Using the bilby-implemented rwalk sample method with ACT estimated walks"
"Live-point based bound method requested with dynesty sample "
f"'{sample}', overwriting to 'multi'"
)
dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby
dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
if self.kwargs["walks"] > _SamplingContainer.maxmcmc:
raise DynestySetupError("You have maxmcmc > walks (minimum mcmc)")
if _SamplingContainer.nact < 1:
self.kwargs["bound"] = "multi"
elif bound == "live":
from .dynesty_utils import LivePointSampler
dynesty.dynamicsampler._SAMPLERS["live"] = LivePointSampler
elif bound == "live-multi":
from .dynesty_utils import MultiEllipsoidLivePointSampler
dynesty.dynamicsampler._SAMPLERS[
"live-multi"
] = MultiEllipsoidLivePointSampler
elif sample == "acceptance-walk":
raise DynestySetupError(
"bound must be set to live or live-multi for sample=acceptance-walk"
)
elif self.proposals is None:
logger.warning(
"No proposals specified using dynesty sampling, defaulting "
"to 'volumetric'."
)
self.proposals = ["volumetric"]
_SamplingContainer.proposals = self.proposals
elif "diff" in self.proposals:
raise DynestySetupError(
"bound must be set to live or live-multi to use differential "
"evolution proposals"
)
if sample == "rwalk":
logger.info(
"Using the bilby-implemented rwalk sample method with ACT estimated walks. "
f"An average of {2 * self.nact} steps will be accepted up to chain length "
f"{self.maxmcmc}."
)
from .dynesty_utils import sample_rwalk_bilby
if self.kwargs["walks"] > self.maxmcmc:
raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)")
if self.nact < 1:
raise DynestySetupError("Unable to run with nact < 1")
elif self.kwargs.get("sample") == "rwalk_dynesty":
self._kwargs["sample"] = "rwalk"
logger.info("Using the dynesty-implemented rwalk sample method")
elif self.kwargs.get("sample") == "rstagger_dynesty":
self._kwargs["sample"] = "rstagger"
logger.info("Using the dynesty-implemented rstagger sample method")
dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
elif sample == "acceptance-walk":
logger.info(
"Using the bilby-implemented rwalk sampling with an average of "
f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}"
)
from .dynesty_utils import fixed_length_rwalk_bilby
dynesty.nestedsamplers._SAMPLING[
"acceptance-walk"
] = fixed_length_rwalk_bilby
elif sample == "act-walk":
logger.info(
"Using the bilby-implemented rwalk sampling tracking the "
f"autocorrelation function and thinning by "
f"{self.nact} with maximum length {self.nact * self.maxmcmc}"
)
from .dynesty_utils import bilby_act_rwalk
dynesty.nestedsamplers._SAMPLING["act-walk"] = bilby_act_rwalk
elif sample == "rwalk_dynesty":
sample = sample.strip("_dynesty")
self.kwargs["sample"] = sample
logger.info(f"Using the dynesty-implemented {sample} sample method")
@signal_wrapper
def run_sampler(self):
import dynesty
logger.info(f"Using dynesty version {dynesty.__version__}")
self._set_sampling_method()
self._setup_pool()
if self.resume:
@@ -446,7 +539,32 @@ class Dynesty(NestedSampler):
return self.result
def _setup_pool(self):
"""
In addition to the usual steps, we need to set the sampling kwargs on
every process. To make sure we get every process, run the kwarg setting
more times than we have processes.
"""
super(Dynesty, self)._setup_pool()
if self.pool is not None:
args = (
[(self.nact, self.maxmcmc, self.proposals, self.naccept)]
* self.npool
* 10
)
self.pool.map(_set_sampling_kwargs, args)
def _generate_result(self, out):
"""
Extract the information we need from the dynesty output. This includes
the evidence, nested samples, run statistics. In addition, we generate
the posterior samples from the nested samples.
Parameters
==========
out: dynesty.result.Result
The dynesty output.
"""
import dynesty
from scipy.special import logsumexp
@@ -500,6 +618,14 @@ class Dynesty(NestedSampler):
sampler_kwargs["add_live"] = True
def _run_external_sampler_with_checkpointing(self):
"""
In order to access the checkpointing, we run the sampler for short
periods of time (less than the checkpoint time) and if sufficient
time has passed, write a checkpoint before continuing. To get the most
informative checkpoint plots, the current live points are added to the
chain of nested samples within dynesty and have to be removed before
restarting the sampler.
"""
logger.debug("Running sampler with checkpointing")
old_ncall = self.sampler.ncall
@@ -646,6 +772,11 @@ class Dynesty(NestedSampler):
self.sampler.M = self.sampler.pool.map
def dump_samples_to_dat(self):
"""
Save the current posterior samples to a space-separated plain-text
file. These are unbiased posterior samples, however, there will not
be many of them until the analysis is nearly over.
"""
sampler = self.sampler
ln_weights = sampler.saved_logwt - sampler.saved_logz[-1]
@@ -664,10 +795,20 @@ class Dynesty(NestedSampler):
df.to_csv(filename, index=False, header=True, sep=" ")
def plot_current_state(self):
import matplotlib.pyplot as plt
"""
Make diagonstic plots of the history and current state of the sampler.
These plots are a mixture of :code:`dynesty` implemented run and trace
plots and our custom stats plot. We also make a copy of the trace plot
using the unit hypercube samples to reflect the internal state of the
sampler.
Any errors during plotting should be handled so that sampling can
continue.
"""
if self.check_point_plot:
import dynesty.plotting as dyplot
import matplotlib.pyplot as plt
labels = [label.replace("_", " ") for label in self.search_parameter_keys]
try:
@@ -757,6 +898,7 @@ class Dynesty(NestedSampler):
plt.close("all")
def _run_test(self):
"""Run the sampler very briefly as a sanity test that it works."""
import pandas as pd
self.sampler = self.sampler_init(
@@ -804,168 +946,17 @@ class Dynesty(NestedSampler):
return self.priors.rescale(self._search_parameter_keys, theta)
def sample_rwalk_bilby(args):
"""Modified bilby-implemented version of dynesty.sampling.sample_rwalk"""
from dynesty.utils import get_random_generator, unitcheck
# Unzipping.
(u, loglstar, axes, scale, prior_transform, loglikelihood, rseed, kwargs) = args
rstate = get_random_generator(rseed)
# Bounds
nonbounded = kwargs.get("nonbounded", None)
if nonbounded is not None and sum(nonbounded) == 0:
nonbounded = None
periodic = kwargs.get("periodic", None)
reflective = kwargs.get("reflective", None)
# Setup.
n = len(u)
walks = kwargs.get("walks", 100) # minimum number of steps
maxmcmc = _SamplingContainer.maxmcmc
nact = _SamplingContainer.nact
old_act = getattr(_SamplingContainer, "old_act", walks)
# Initialize internal variables
accept = 0
reject = 0
nfail = 0
act = np.inf
u_list = []
v_list = []
logl_list = []
ii = 0
while ii < nact * act:
ii += 1
# Propose a direction on the unit n-sphere.
drhat = rstate.normal(0, 1, n)
drhat /= np.linalg.norm(drhat)
# Scale based on dimensionality.
dr = drhat * rstate.uniform(0, 1) ** (1.0 / n)
# Transform to proposal distribution.
du = np.dot(axes, dr)
u_prop = u + scale * du
# Wrap periodic parameters
if periodic is not None:
u_prop[periodic] = np.mod(u_prop[periodic], 1)
# Reflect
if reflective is not None:
u_prop[reflective] = reflect(u_prop[reflective])
# Check unit cube constraints.
if unitcheck(u_prop, nonbounded):
pass
else:
nfail += 1
# Only start appending to the chain once a single jump is made
if accept > 0:
u_list.append(u_list[-1])
v_list.append(v_list[-1])
logl_list.append(logl_list[-1])
continue
# Check proposed point.
v_prop = prior_transform(np.array(u_prop))
logl_prop = loglikelihood(np.array(v_prop))
if logl_prop > loglstar:
u = u_prop
v = v_prop
logl = logl_prop
accept += 1
u_list.append(u)
v_list.append(v)
logl_list.append(logl)
else:
reject += 1
# Only start appending to the chain once a single jump is made
if accept > 0:
u_list.append(u_list[-1])
v_list.append(v_list[-1])
logl_list.append(logl_list[-1])
# If we've taken the minimum number of steps, calculate the ACT
if accept + reject > walks:
act = estimate_nmcmc(
accept_ratio=accept / (accept + reject + nfail),
old_act=old_act,
maxmcmc=maxmcmc,
)
# If we've taken too many likelihood evaluations then break
if accept + reject > maxmcmc:
warnings.warn(
f"Hit maximum number of walks {maxmcmc} with accept={accept},"
f" reject={reject}, and nfail={nfail} try increasing maxmcmc"
)
break
# If the act is finite, pick randomly from within the chain
if np.isfinite(act) and int(0.5 * nact * act) < len(u_list):
idx = np.random.randint(int(0.5 * nact * act), len(u_list))
u = u_list[idx]
v = v_list[idx]
logl = logl_list[idx]
else:
logger.debug("Unable to find a new point using walk: returning a random point")
u = np.random.uniform(size=n)
v = prior_transform(u)
logl = loglikelihood(v)
blob = {"accept": accept, "reject": reject, "fail": nfail, "scale": scale}
_SamplingContainer.old_act = act
ncall = accept + reject
return u, v, logl, ncall, blob
def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
"""Estimate autocorrelation length of chain using acceptance fraction
Using ACL = (2/acc) - 1 multiplied by a safety margin. Code adapted from CPNest:
- https://github.com/johnveitch/cpnest/blob/master/cpnest/sampler.py
- http://github.com/farr/Ensemble.jl
Parameters
==========
accept_ratio: float [0, 1]
Ratio of the number of accepted points to the total number of points
old_act: int
The ACT of the last iteration
maxmcmc: int
The maximum length of the MCMC chain to use
safety: int
A safety factor applied in the calculation
tau: int (optional)
The ACT, if given, otherwise estimated.
"""
if tau is None:
tau = maxmcmc / safety
if accept_ratio == 0.0:
Nmcmc_exact = (1 + 1 / tau) * old_act
else:
Nmcmc_exact = (1.0 - 1.0 / tau) * old_act + (safety / tau) * (
2.0 / accept_ratio - 1.0
)
Nmcmc_exact = float(min(Nmcmc_exact, maxmcmc))
return max(safety, Nmcmc_exact)
@latex_plot_format
def dynesty_stats_plot(sampler):
"""
Plot diagnostic statistics from a dynesty run
The plotted quantities per iteration are:
- nc: the number of likelihood calls
- scale: the scale applied to the MCMC steps
- scale: the number of accepted MCMC steps if using :code:`bound="live"`
or :code:`bound="live-multi"`, otherwise, the scale applied to the MCMC
steps
- lifetime: the number of iterations a point stays in the live set
There is also a histogram of the lifetime compared with the theoretical
@@ -973,7 +964,8 @@ def dynesty_stats_plot(sampler):
Parameters
----------
sampler
sampler: dynesty.sampler.Sampler
The sampler object containing the run history.
Returns
-------
Loading