Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • john-veitch/bilby
  • duncanmmacleod/bilby
  • colm.talbot/bilby
  • lscsoft/bilby
  • matthew-pitkin/bilby
  • salvatore-vitale/tupak
  • charlie.hoy/bilby
  • bfarr/bilby
  • virginia.demilio/bilby
  • vivien/bilby
  • eric-howell/bilby
  • sebastian-khan/bilby
  • rhys.green/bilby
  • moritz.huebner/bilby
  • joseph.mills/bilby
  • scott.coughlin/bilby
  • matthew.carney/bilby
  • hyungwon.lee/bilby
  • monica.rizzo/bilby
  • christopher-berry/bilby
  • lindsay.demarchi/bilby
  • kaushik.rao/bilby
  • charles.kimball/bilby
  • andrew.matas/bilby
  • juan.calderonbustillo/bilby
  • patrick-meyers/bilby
  • hannah.middleton/bilby
  • eve.chase/bilby
  • grant.meadors/bilby
  • khun.phukon/bilby
  • sumeet.kulkarni/bilby
  • daniel.reardon/bilby
  • cjhaster/bilby
  • sylvia.biscoveanu/bilby
  • james-clark/bilby
  • meg.millhouse/bilby
  • joshua.willis/bilby
  • nikhil.sarin/bilby
  • paul.easter/bilby
  • youngmin/bilby
  • daniel-williams/bilby
  • shanika.galaudage/bilby
  • bruce.edelman/bilby
  • avi.vajpeyi/bilby
  • isobel.romero-shaw/bilby
  • andrew.kim/bilby
  • dominika.zieba/bilby
  • jonathan.davies/bilby
  • marc.arene/bilby
  • srishti.tiwari/bilby-tidal-heating-eccentric
  • aditya.vijaykumar/bilby
  • michael.williams/bilby
  • cecilio.garcia-quiros/bilby
  • rory-smith/bilby
  • maite.mateu-lucena/bilby
  • wushichao/bilby
  • kaylee.desoto/bilby
  • brandon.piotrzkowski/bilby
  • rossella.gamba/bilby
  • hunter.gabbard/bilby
  • deep.chatterjee/bilby
  • tathagata.ghosh/bilby
  • arunava.mukherjee/bilby
  • philip.relton/bilby
  • reed.essick/bilby
  • pawan.gupta/bilby
  • francisco.hernandez/bilby
  • rhiannon.udall/bilby
  • leo.tsukada/bilby
  • will-farr/bilby
  • vijay.varma/bilby
  • jeremy.baier/bilby
  • joshua.brandt/bilby
  • ethan.payne/bilby
  • ka-lok.lo/bilby
  • antoni.ramos-buades/bilby
  • oliviastephany.wilk/bilby
  • jack.heinzel/bilby
  • samson.leong/bilby-psi4
  • viviana.caceres/bilby
  • nadia.qutob/bilby
  • michael-coughlin/bilby
  • hemantakumar.phurailatpam/bilby
  • boris.goncharov/bilby
  • sama.al-shammari/bilby
  • siqi.zhong/bilby
  • jocelyn-read/bilby
  • marc.penuliar/bilby
  • stephanie.letourneau/bilby
  • alexandresebastien.goettel/bilby
  • alec.gunny/bilby
  • serguei.ossokine/bilby
  • pratyusava.baral/bilby
  • sophie.hourihane/bilby
  • eunsub/bilby
  • james.hart/bilby
  • pratyusava.baral/bilby-tg
  • zhaozc/bilby
  • pratyusava.baral/bilby_SoG
  • tomasz.baka/bilby
  • nicogerardo.bers/bilby
  • soumen.roy/bilby
  • isaac.mcmahon/healpix-redundancy
  • asamakai.baker/bilby-frequency-dependent-antenna-pattern-functions
  • anna.puecher/bilby
  • pratyusava.baral/bilby-x-g
  • thibeau.wouters/bilby
  • christian.adamcewicz/bilby
  • raffi.enficiaud/bilby
109 results
Show changes
Showing
with 1917 additions and 1581 deletions
......@@ -22,14 +22,29 @@ from .pymultinest import Pymultinest
from .ultranest import Ultranest
from .fake_sampler import FakeSampler
from .dnest4 import DNest4
from .zeus import Zeus
from bilby.bilby_mcmc import Bilby_MCMC
from . import proposal
IMPLEMENTED_SAMPLERS = {
'cpnest': Cpnest, 'dnest4': DNest4, 'dynamic_dynesty': DynamicDynesty,
'dynesty': Dynesty, 'emcee': Emcee,'kombine': Kombine, 'nessai': Nessai,
'nestle': Nestle, 'ptemcee': Ptemcee, 'ptmcmcsampler': PTMCMCSampler,
'pymc3': Pymc3, 'pymultinest': Pymultinest, 'pypolychord': PyPolyChord,
'ultranest': Ultranest, 'fake_sampler': FakeSampler}
"bilby_mcmc": Bilby_MCMC,
"cpnest": Cpnest,
"dnest4": DNest4,
"dynamic_dynesty": DynamicDynesty,
"dynesty": Dynesty,
"emcee": Emcee,
"kombine": Kombine,
"nessai": Nessai,
"nestle": Nestle,
"ptemcee": Ptemcee,
"ptmcmcsampler": PTMCMCSampler,
"pymc3": Pymc3,
"pymultinest": Pymultinest,
"pypolychord": PyPolyChord,
"ultranest": Ultranest,
"zeus": Zeus,
"fake_sampler": FakeSampler,
}
if command_line_args.sampler_help:
sampler = command_line_args.sampler_help
......@@ -39,20 +54,36 @@ if command_line_args.sampler_help:
print(sampler_class.__doc__)
else:
if sampler == "None":
print('For help with a specific sampler, call sampler-help with '
'the name of the sampler')
print(
"For help with a specific sampler, call sampler-help with "
"the name of the sampler"
)
else:
print('Requested sampler {} not implemented'.format(sampler))
print('Available samplers = {}'.format(IMPLEMENTED_SAMPLERS))
print("Requested sampler {} not implemented".format(sampler))
print("Available samplers = {}".format(IMPLEMENTED_SAMPLERS))
sys.exit()
def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler='dynesty', use_ratio=None, injection_parameters=None,
conversion_function=None, plot=False, default_priors_file=None,
clean=None, meta_data=None, save=True, gzip=False,
result_class=None, npool=1, **kwargs):
def run_sampler(
likelihood,
priors=None,
label="label",
outdir="outdir",
sampler="dynesty",
use_ratio=None,
injection_parameters=None,
conversion_function=None,
plot=False,
default_priors_file=None,
clean=None,
meta_data=None,
save=True,
gzip=False,
result_class=None,
npool=1,
**kwargs
):
"""
The primary interface to easy parameter estimation
......@@ -115,13 +146,13 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
"""
logger.info(
"Running for label '{}', output will be saved to '{}'".format(
label, outdir))
"Running for label '{}', output will be saved to '{}'".format(label, outdir)
)
if clean:
command_line_args.clean = clean
if command_line_args.clean:
kwargs['resume'] = False
kwargs["resume"] = False
from . import IMPLEMENTED_SAMPLERS
......@@ -142,11 +173,14 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
# Generate the meta-data if not given and append the likelihood meta_data
if meta_data is None:
meta_data = dict()
likelihood.label = label
likelihood.outdir = outdir
meta_data['likelihood'] = likelihood.meta_data
meta_data["loaded_modules"] = loaded_modules_dict()
if command_line_args.bilby_zero_likelihood_mode:
from bilby.core.likelihood import ZeroLikelihood
likelihood = ZeroLikelihood(likelihood)
if isinstance(sampler, Sampler):
......@@ -155,64 +189,88 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if sampler.lower() in IMPLEMENTED_SAMPLERS:
sampler_class = IMPLEMENTED_SAMPLERS[sampler.lower()]
sampler = sampler_class(
likelihood, priors=priors, outdir=outdir, label=label,
injection_parameters=injection_parameters, meta_data=meta_data,
use_ratio=use_ratio, plot=plot, result_class=result_class,
npool=npool, **kwargs)
likelihood,
priors=priors,
outdir=outdir,
label=label,
injection_parameters=injection_parameters,
meta_data=meta_data,
use_ratio=use_ratio,
plot=plot,
result_class=result_class,
npool=npool,
**kwargs
)
else:
print(IMPLEMENTED_SAMPLERS)
raise ValueError(
"Sampler {} not yet implemented".format(sampler))
raise ValueError("Sampler {} not yet implemented".format(sampler))
elif inspect.isclass(sampler):
sampler = sampler.__init__(
likelihood, priors=priors,
outdir=outdir, label=label, use_ratio=use_ratio, plot=plot,
injection_parameters=injection_parameters, meta_data=meta_data,
npool=npool, **kwargs)
likelihood,
priors=priors,
outdir=outdir,
label=label,
use_ratio=use_ratio,
plot=plot,
injection_parameters=injection_parameters,
meta_data=meta_data,
npool=npool,
**kwargs
)
else:
raise ValueError(
"Provided sampler should be a Sampler object or name of a known "
"sampler: {}.".format(', '.join(IMPLEMENTED_SAMPLERS.keys())))
"sampler: {}.".format(", ".join(IMPLEMENTED_SAMPLERS.keys()))
)
if sampler.cached_result:
logger.warning("Using cached result")
return sampler.cached_result
start_time = datetime.datetime.now()
if command_line_args.bilby_test_mode:
result = sampler._run_test()
else:
result = sampler.run_sampler()
end_time = datetime.datetime.now()
# Some samplers calculate the sampling time internally
if result.sampling_time is None:
result.sampling_time = end_time - start_time
logger.info('Sampling time: {}'.format(result.sampling_time))
# Convert sampling time into seconds
result.sampling_time = result.sampling_time.total_seconds()
if sampler.use_ratio:
result.log_noise_evidence = likelihood.noise_log_likelihood()
result.log_bayes_factor = result.log_evidence
result.log_evidence = \
result.log_bayes_factor + result.log_noise_evidence
result = sampler.cached_result
else:
result.log_noise_evidence = likelihood.noise_log_likelihood()
result.log_bayes_factor = \
result.log_evidence - result.log_noise_evidence
# Run the sampler
start_time = datetime.datetime.now()
if command_line_args.bilby_test_mode:
result = sampler._run_test()
else:
result = sampler.run_sampler()
end_time = datetime.datetime.now()
# Initial save of the sampler in case of failure in post-processing
if save:
result.save_to_file(extension=save, gzip=gzip)
# Some samplers calculate the sampling time internally
if result.sampling_time is None:
result.sampling_time = end_time - start_time
elif isinstance(result.sampling_time, (float, int)):
result.sampling_time = datetime.timedelta(result.sampling_time)
logger.info('Sampling time: {}'.format(result.sampling_time))
# Convert sampling time into seconds
result.sampling_time = result.sampling_time.total_seconds()
if sampler.use_ratio:
result.log_noise_evidence = likelihood.noise_log_likelihood()
result.log_bayes_factor = result.log_evidence
result.log_evidence = \
result.log_bayes_factor + result.log_noise_evidence
else:
result.log_noise_evidence = likelihood.noise_log_likelihood()
result.log_bayes_factor = \
result.log_evidence - result.log_noise_evidence
if None not in [result.injection_parameters, conversion_function]:
result.injection_parameters = conversion_function(
result.injection_parameters)
# Initial save of the sampler in case of failure in samples_to_posterior
if save:
result.save_to_file(extension=save, gzip=gzip)
if None not in [result.injection_parameters, conversion_function]:
result.injection_parameters = conversion_function(
result.injection_parameters)
result.injection_parameters = conversion_function(result.injection_parameters)
result.samples_to_posterior(likelihood=likelihood, priors=result.priors,
conversion_function=conversion_function,
npool=npool)
# Check if the posterior has already been created
if getattr(result, "_posterior", None) is None:
result.samples_to_posterior(likelihood=likelihood, priors=result.priors,
conversion_function=conversion_function,
npool=npool)
if save:
# The overwrite here ensures we overwrite the initially stored data
......@@ -229,5 +287,7 @@ def _check_marginalized_parameters_not_sampled(likelihood, priors):
if key in priors:
if not isinstance(priors[key], (float, DeltaFunction)):
raise SamplingMarginalisedParameterError(
"Likelihood is {} marginalized but you are trying to sample in {}. "
.format(key, key))
"Likelihood is {} marginalized but you are trying to sample in {}. ".format(
key, key
)
)
......@@ -446,6 +446,8 @@ class Sampler(object):
==========
theta: array_like
Parameter values at which to evaluate likelihood
warning: bool
Whether or not to print a warning
Returns
=======
......@@ -453,14 +455,19 @@ class Sampler(object):
True if the likelihood and prior are finite, false otherwise
"""
bad_values = [np.inf, np.nan_to_num(np.inf), np.nan]
if abs(self.log_prior(theta)) in bad_values:
log_p = self.log_prior(theta)
log_l = self.log_likelihood(theta)
return \
self._check_bad_value(val=log_p, warning=warning, theta=theta, label='prior') and \
self._check_bad_value(val=log_l, warning=warning, theta=theta, label='likelihood')
@staticmethod
def _check_bad_value(val, warning, theta, label):
val = np.abs(val)
bad_values = [np.inf, np.nan_to_num(np.inf)]
if val in bad_values or np.isnan(val):
if warning:
logger.warning('Prior draw {} has inf prior'.format(theta))
return False
if abs(self.log_likelihood(theta)) in bad_values:
if warning:
logger.warning('Prior draw {} has inf likelihood'.format(theta))
logger.warning(f'Prior draw {theta} has inf {label}')
return False
return True
......@@ -498,14 +505,15 @@ class Sampler(object):
logger.debug("Checking cached data")
if self.cached_result:
check_keys = ['search_parameter_keys', 'fixed_parameter_keys',
'kwargs']
check_keys = ['search_parameter_keys', 'fixed_parameter_keys']
use_cache = True
for key in check_keys:
if self.cached_result._check_attribute_match_to_other_object(
key, self) is False:
logger.debug("Cached value {} is unmatched".format(key))
use_cache = False
if self.meta_data["likelihood"] != self.cached_result.meta_data["likelihood"]:
use_cache = False
if use_cache is False:
self.cached_result = None
......@@ -661,6 +669,10 @@ class SamplerError(Error):
""" Base class for Error related to samplers in this module """
class ResumeError(Error):
""" Class for errors arising from resuming runs """
class SamplerNotInstalledError(SamplerError):
""" Base class for Error raised by not installed samplers """
......
......@@ -13,6 +13,7 @@ from ..utils import (
check_directory_exists_and_if_not_mkdir,
reflect,
safe_file_dump,
latex_plot_format,
)
from .base_sampler import Sampler, NestedSampler
from ..result import rejection_sample
......@@ -91,7 +92,7 @@ class Dynesty(NestedSampler):
only advisable for testing environments
Other Parameters
------==========
================
npoints: int, (1000)
The number of live points, note this can also equivalently be given as
one of [nlive, nlives, n_live_points]
......@@ -101,7 +102,9 @@ class Dynesty(NestedSampler):
Method used to sample uniformly within the likelihood constraints,
conditioned on the provided bounds
walks: int
Number of walks taken if using `sample='rwalk'`, defaults to `ndim * 10`
Number of walks taken if using `sample='rwalk'`, defaults to `100`.
Note that the default `walks` in dynesty itself is 25, although using
`ndim * 10` can be a reasonable rule of thumb for new problems.
dlogz: float, (0.1)
Stopping criteria
verbose: Bool
......@@ -119,6 +122,12 @@ class Dynesty(NestedSampler):
If true, resume run from checkpoint (if available)
exit_code: int
The code which the same exits on if it hasn't finished sampling
print_method: str ('tqdm')
The method to use for printing. The options are:
- 'tqdm': use a `tqdm` `pbar`, this is the default.
- 'interval-$TIME': print to `stdout` every `$TIME` seconds,
e.g., 'interval-10' prints every ten seconds, this does not print every iteration
- else: print to `stdout` at every iteration
"""
default_kwargs = dict(bound='multi', sample='rwalk',
verbose=True, periodic=None, reflective=None,
......@@ -134,12 +143,13 @@ class Dynesty(NestedSampler):
dlogz=0.1, maxiter=None, maxcall=None,
logl_max=np.inf, add_live=True, print_progress=True,
save_bounds=False, n_effective=None,
maxmcmc=5000, nact=5)
maxmcmc=5000, nact=5, print_method="tqdm")
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
check_point=True, check_point_plot=True, n_check_point=None,
check_point_delta_t=600, resume=True, exit_code=130, **kwargs):
check_point_delta_t=600, resume=True, nestcheck=False, exit_code=130, **kwargs):
super(Dynesty, self).__init__(likelihood=likelihood, priors=priors,
outdir=outdir, label=label, use_ratio=use_ratio,
plot=plot, skip_import_verification=skip_import_verification,
......@@ -153,6 +163,8 @@ class Dynesty(NestedSampler):
self._reflective = list()
self._apply_dynesty_boundaries()
self.nestcheck = nestcheck
if self.n_check_point is None:
self.n_check_point = 1000
self.check_point_delta_t = check_point_delta_t
......@@ -212,16 +224,27 @@ class Dynesty(NestedSampler):
def _verify_kwargs_against_default_kwargs(self):
from tqdm.auto import tqdm
if not self.kwargs['walks']:
self.kwargs['walks'] = self.ndim * 10
self.kwargs['walks'] = 100
if not self.kwargs['update_interval']:
self.kwargs['update_interval'] = int(0.6 * self.kwargs['nlive'])
if self.kwargs['print_func'] is None:
self.kwargs['print_func'] = self._print_func
self.pbar = tqdm(file=sys.stdout)
print_method = self.kwargs["print_method"]
if print_method == "tqdm":
self.pbar = tqdm(file=sys.stdout)
elif "interval" in print_method:
self._last_print_time = datetime.datetime.now()
self._print_interval = datetime.timedelta(seconds=float(print_method.split("-")[1]))
Sampler._verify_kwargs_against_default_kwargs(self)
def _print_func(self, results, niter, ncall=None, dlogz=None, *args, **kwargs):
""" Replacing status update for dynesty.result.print_func """
if "interval" in self.kwargs["print_method"]:
_time = datetime.datetime.now()
if _time - self._last_print_time < self._print_interval:
return
else:
self._last_print_time = _time
# Extract results at the current iteration.
(worst, ustar, vstar, loglstar, logvol, logwt,
......@@ -246,7 +269,7 @@ class Dynesty(NestedSampler):
key = 'logz'
# Constructing output.
string = []
string = list()
string.append("bound:{:d}".format(bounditer))
string.append("nc:{:3d}".format(nc))
string.append("ncall:{:.1e}".format(ncall))
......@@ -254,8 +277,16 @@ class Dynesty(NestedSampler):
string.append("{}={:0.2f}+/-{:0.2f}".format(key, logz, logzerr))
string.append("dlogz:{:0.3f}>{:0.2g}".format(delta_logz, dlogz))
self.pbar.set_postfix_str(" ".join(string), refresh=False)
self.pbar.update(niter - self.pbar.n)
if self.kwargs["print_method"] == "tqdm":
self.pbar.set_postfix_str(" ".join(string), refresh=False)
self.pbar.update(niter - self.pbar.n)
elif "interval" in self.kwargs["print_method"]:
formatted = " ".join([str(_time - self.start_time)] + string)
print("{}it [{}]".format(niter, formatted), file=sys.stdout)
else:
_time = datetime.datetime.now()
formatted = " ".join([str(_time - self.start_time)] + string)
print("{}it [{}]".format(niter, formatted), file=sys.stdout)
def _apply_dynesty_boundaries(self):
self._periodic = list()
......@@ -274,6 +305,14 @@ class Dynesty(NestedSampler):
self.kwargs["periodic"] = self._periodic
self.kwargs["reflective"] = self._reflective
def nestcheck_data(self, out_file):
import nestcheck.data_processing
import pickle
ns_run = nestcheck.data_processing.process_dynesty_run(out_file)
nestcheck_result = "{}/{}_nestcheck.pickle".format(self.outdir, self.label)
with open(nestcheck_result, 'wb') as file_nest:
pickle.dump(ns_run, file_nest)
def _setup_pool(self):
if self.kwargs["pool"] is not None:
logger.info("Using user defined pool.")
......@@ -324,7 +363,7 @@ class Dynesty(NestedSampler):
"Using the bilby-implemented rwalk sample method with ACT estimated walks")
dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby
dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
if self.kwargs.get("walks", 25) > self.kwargs.get("maxmcmc"):
if self.kwargs.get("walks") > self.kwargs.get("maxmcmc"):
raise DynestySetupError("You have maxmcmc > walks (minimum mcmc)")
if self.kwargs.get("nact", 5) < 1:
raise DynestySetupError("Unable to run with nact < 1")
......@@ -363,17 +402,20 @@ class Dynesty(NestedSampler):
self._close_pool()
# Flushes the output to force a line break
if self.kwargs["verbose"]:
if self.kwargs["verbose"] and self.kwargs["print_method"] == "tqdm":
self.pbar.close()
print("")
check_directory_exists_and_if_not_mkdir(self.outdir)
if self.nestcheck:
self.nestcheck_data(out)
dynesty_result = "{}/{}_dynesty.pickle".format(self.outdir, self.label)
with open(dynesty_result, 'wb') as file:
dill.dump(out, file)
self._generate_result(out)
self.calc_likelihood_count()
self.result.sampling_time = self.sampling_time
if self.plot:
......@@ -383,7 +425,9 @@ class Dynesty(NestedSampler):
def _generate_result(self, out):
import dynesty
weights = np.exp(out['logwt'] - out['logz'][-1])
from scipy.special import logsumexp
logwts = out["logwt"]
weights = np.exp(logwts - out['logz'][-1])
nested_samples = DataFrame(
out.samples, columns=self.search_parameter_keys)
nested_samples['weights'] = weights
......@@ -396,6 +440,16 @@ class Dynesty(NestedSampler):
self.result.log_evidence = out.logz[-1]
self.result.log_evidence_err = out.logzerr[-1]
self.result.information_gain = out.information[-1]
self.result.num_likelihood_evaluations = getattr(self.sampler, 'ncall', 0)
logneff = logsumexp(logwts) * 2 - logsumexp(logwts * 2)
neffsamples = int(np.exp(logneff))
self.result.meta_data["run_statistics"] = dict(
nlikelihood=self.result.num_likelihood_evaluations,
neffsamples=neffsamples,
sampling_time_s=self.sampling_time.seconds,
ncores=self.kwargs.get("queue_size", 1)
)
def _run_nested_wrapper(self, kwargs):
""" Wrapper function to run_nested
......@@ -556,6 +610,11 @@ class Dynesty(NestedSampler):
from ... import __version__ as bilby_version
from dynesty import __version__ as dynesty_version
import dill
if getattr(self, "sampler", None) is None:
# Sampler not initialized, not able to write current state
return
check_directory_exists_and_if_not_mkdir(self.outdir)
end_time = datetime.datetime.now()
if hasattr(self, 'start_time'):
......@@ -641,11 +700,7 @@ class Dynesty(NestedSampler):
plt.close('all')
try:
filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
fig, axs = plt.subplots(nrows=3, sharex=True)
for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
ax.plot(getattr(self.sampler, "saved_{}".format(name)), color="C0")
ax.set_ylabel(name)
axs[-1].set_xlabel("iteration")
fig, axs = dynesty_stats_plot(self.sampler)
fig.tight_layout()
plt.savefig(filename)
except (RuntimeError, ValueError) as e:
......@@ -701,16 +756,6 @@ class Dynesty(NestedSampler):
"""
return self.priors.rescale(self._search_parameter_keys, theta)
def calc_likelihood_count(self):
if self.likelihood_benchmark:
if hasattr(self, 'sampler'):
self.result.num_likelihood_evaluations = \
getattr(self.sampler, 'ncall', 0)
else:
self.result.num_likelihood_evaluations = 0
else:
return None
def sample_rwalk_bilby(args):
""" Modified bilby-implemented version of dynesty.sampling.sample_rwalk """
......@@ -728,8 +773,8 @@ def sample_rwalk_bilby(args):
# Setup.
n = len(u)
walks = kwargs.get('walks', 25) # minimum number of steps
maxmcmc = kwargs.get('maxmcmc', 2000) # Maximum number of steps
walks = kwargs.get('walks', 100) # minimum number of steps
maxmcmc = kwargs.get('maxmcmc', 5000) # Maximum number of steps
nact = kwargs.get('nact', 5) # Number of ACT
old_act = kwargs.get('old_act', walks)
......@@ -831,7 +876,7 @@ def sample_rwalk_bilby(args):
def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
""" Estimate autocorrelation length of chain using acceptance fraction
Using ACL = (2/acc) - 1 multiplied by a safety margin. Code adapated from CPNest:
Using ACL = (2/acc) - 1 multiplied by a safety margin. Code adapted from CPNest:
- https://github.com/johnveitch/cpnest/blob/master/cpnest/sampler.py
- http://github.com/farr/Ensemble.jl
......@@ -864,5 +909,72 @@ def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
return max(safety, int(Nmcmc_exact))
@latex_plot_format
def dynesty_stats_plot(sampler):
"""
Plot diagnostic statistics from a dynesty run
The plotted quantities per iteration are:
- nc: the number of likelihood calls
- scale: the scale applied to the MCMC steps
- lifetime: the number of iterations a point stays in the live set
There is also a histogram of the lifetime compared with the theoretical
distribution. To avoid edge effects, we discard the first 6 * nlive
Parameters
----------
sampler
Returns
-------
fig: matplotlib.pyplot.figure.Figure
Figure handle for the new plot
axs: matplotlib.pyplot.axes.Axes
Axes handles for the new plot
"""
import matplotlib.pyplot as plt
from scipy.stats import geom, ks_1samp
fig, axs = plt.subplots(nrows=4, figsize=(8, 8))
for ax, name in zip(axs, ["nc", "scale"]):
ax.plot(getattr(sampler, "saved_{}".format(name)), color="blue")
ax.set_ylabel(name.title())
lifetimes = np.arange(len(sampler.saved_it)) - sampler.saved_it
axs[-2].set_ylabel("Lifetime")
nlive = sampler.nlive
burn = int(geom(p=1 / nlive).isf(1 / 2 / nlive))
if len(sampler.saved_it) > burn + sampler.nlive:
axs[-2].plot(np.arange(0, burn), lifetimes[:burn], color="grey")
axs[-2].plot(np.arange(burn, len(lifetimes) - nlive), lifetimes[burn: -nlive], color="blue")
axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
lifetimes = lifetimes[burn: -nlive]
ks_result = ks_1samp(lifetimes, geom(p=1 / nlive).cdf)
axs[-1].hist(
lifetimes,
bins=np.linspace(0, 6 * nlive, 60),
histtype="step",
density=True,
color="blue",
label=f"p value = {ks_result.pvalue:.3f}"
)
axs[-1].plot(
np.arange(1, 6 * nlive),
geom(p=1 / nlive).pmf(np.arange(1, 6 * nlive)),
color="red"
)
axs[-1].set_xlim(0, 6 * nlive)
axs[-1].legend()
axs[-1].set_yscale("log")
else:
axs[-2].plot(np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey")
axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
axs[-2].set_yscale("log")
axs[-2].set_xlabel("Iteration")
axs[-1].set_xlabel("Lifetime")
return fig, axs
class DynestySetupError(Exception):
pass
......@@ -23,7 +23,7 @@ class Emcee(MCMCSampler):
Parameters
==========
nwalkers: int, (100)
nwalkers: int, (500)
The number of walkers
nsteps: int, (100)
The number of steps
......@@ -263,7 +263,7 @@ class Emcee(MCMCSampler):
dill.dump(self._sampler, f)
def checkpoint_and_exit(self, signum, frame):
logger.info("Recieved signal {}".format(signum))
logger.info("Received signal {}".format(signum))
self.checkpoint()
sys.exit()
......@@ -273,9 +273,9 @@ class Emcee(MCMCSampler):
@property
def sampler(self):
""" Returns the ptemcee sampler object
""" Returns the emcee sampler object
If, alrady initialized, returns the stored _sampler value. Otherwise,
If, already initialized, returns the stored _sampler value. Otherwise,
first checks if there is a pickle file from which to load. If there is
not, then initialize the sampler and set the initial random draw
......
......@@ -8,7 +8,7 @@ from ..utils import logger, check_directory_exists_and_if_not_mkdir, load_json
class Nessai(NestedSampler):
"""bilby wrapper of nessai (https://github.com/mj-will/nessai)
All positional and keyword arguments passed to `run_sampler` are propogated
All positional and keyword arguments passed to `run_sampler` are propagated
to `nessai.flowsampler.FlowSampler`
See the documentation for an explanation of the different kwargs.
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from .calculus import *
from .cmd import *
from .colors import *
from .constants import *
from .conversion import *
from .counter import *
from .docs import *
from .introspection import *
from .io import *
from .log import *
from .plotting import *
from .progress import *
from .samples import *
from .series import *
# Instantiate the default argument parser at runtime
command_line_args, command_line_parser = set_up_command_line_arguments()
# Instantiate the default logging
setup_logger(print_version=False, log_level=command_line_args.log_level)
This diff is collapsed.
This diff is collapsed.
class tcolors:
KEY = '\033[93m'
VALUE = '\033[91m'
HIGHLIGHT = '\033[95m'
END = '\033[0m'
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.