diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py index ef28f22ddbb2ce099d481c3adebaae1ef1d3b0cd..b54aad112b070af43c56aa534c10c5f0b9087695 100644 --- a/bilby/core/sampler/dynamic_dynesty.py +++ b/bilby/core/sampler/dynamic_dynesty.py @@ -1,10 +1,4 @@ -import datetime - -import numpy as np - -from ..utils import logger -from .base_sampler import Sampler, signal_wrapper -from .dynesty import Dynesty, _log_likelihood_wrapper, _prior_transform_wrapper +from .dynesty import Dynesty class DynamicDynesty(Dynesty): @@ -14,221 +8,36 @@ class DynamicDynesty(Dynesty): All positional and keyword arguments (i.e., the args and kwargs) passed to `run_sampler` will be propagated to `dynesty.DynamicNestedSampler`, see - documentation for that class for further help. Under Other Parameter below, - we list commonly all kwargs and the bilby defaults. + documentation for that class for further help. - Parameters - ========== - likelihood: likelihood.Likelihood - A object with a log_l method - priors: bilby.core.prior.PriorDict, dict - Priors to be used in the search. - This has attributes for each parameter to be sampled. - outdir: str, optional - Name of the output directory - label: str, optional - Naming scheme of the output files - use_ratio: bool, optional - Switch to set whether or not you want to use the log-likelihood ratio - or just the log-likelihood - plot: bool, optional - Switch to set whether or not you want to create traceplots - skip_import_verification: bool - Skips the check if the sampler is installed if true. This is - only advisable for testing environments - - Other Parameters - ------========== - bound: {'none', 'single', 'multi', 'balls', 'cubes'}, ('multi') - Method used to select new points - sample: {'unif', 'rwalk', 'slice', 'rslice', 'hslice'}, ('rwalk') - Method used to sample uniformly within the likelihood constraints, - conditioned on the provided bounds - walks: int - Number of walks taken if using `sample='rwalk'`, defaults to `ndim * 5` - verbose: Bool - If true, print information information about the convergence during - check_point: bool, - If true, use check pointing. - check_point_delta_t: float (600) - The approximate checkpoint period (in seconds). Should the run be - interrupted, it can be resumed from the last checkpoint. Set to - `None` to turn-off check pointing - n_check_point: int, optional (None) - The number of steps to take before check pointing (override - check_point_delta_t). - resume: bool - If true, resume run from checkpoint (if available) + For additional documentation see bilby.core.sampler.Dynesty. """ - default_kwargs = dict( - bound="multi", - sample="rwalk", - verbose=True, - check_point_delta_t=600, - first_update=None, - npdim=None, - rstate=None, - queue_size=None, - pool=None, - use_pool=None, - logl_args=None, - logl_kwargs=None, - ptform_args=None, - ptform_kwargs=None, - enlarge=None, - bootstrap=None, - vol_dec=0.5, - vol_check=2.0, - facc=0.5, - slices=5, - walks=None, - update_interval=0.6, - nlive_init=500, - maxiter_init=None, - maxcall_init=None, - dlogz_init=0.01, - logl_max_init=np.inf, - nlive_batch=500, - wt_function=None, - wt_kwargs=None, - maxiter_batch=None, - maxcall_batch=None, - maxiter=None, - maxcall=None, - maxbatch=None, - stop_function=None, - stop_kwargs=None, - use_stop=True, - save_bounds=True, - print_progress=True, - print_func=None, - live_points=None, - ) - - def __init__( - self, - likelihood, - priors, - outdir="outdir", - label="label", - use_ratio=False, - plot=False, - skip_import_verification=False, - check_point=True, - n_check_point=None, - check_point_delta_t=600, - resume=True, - **kwargs, - ): - super(DynamicDynesty, self).__init__( - likelihood=likelihood, - priors=priors, - outdir=outdir, - label=label, - use_ratio=use_ratio, - plot=plot, - skip_import_verification=skip_import_verification, - **kwargs, - ) - self.n_check_point = n_check_point - self.check_point = check_point - self.resume = resume - if self.n_check_point is None: - # If the log_likelihood_eval_time is not calculable then - # check_point is set to False. - if np.isnan(self._log_likelihood_eval_time): - self.check_point = False - n_check_point_raw = check_point_delta_t / self._log_likelihood_eval_time - n_check_point_rnd = int(float(f"{n_check_point_raw:1.0g}")) - self.n_check_point = n_check_point_rnd - - self.resume_file = f"{self.outdir}/{self.label}_resume.pickle" + external_sampler_name = "dynesty" @property - def external_sampler_name(self): - return "dynesty" + def nlive(self): + return self.kwargs["nlive_init"] @property - def sampler_function_kwargs(self): - keys = [ - "nlive_init", - "maxiter_init", - "maxcall_init", - "dlogz_init", - "logl_max_init", - "nlive_batch", - "wt_function", - "wt_kwargs", - "maxiter_batch", - "maxcall_batch", - "maxiter", - "maxcall", - "maxbatch", - "stop_function", - "stop_kwargs", - "use_stop", - "save_bounds", - "print_progress", - "print_func", - "live_points", - ] - return {key: self.kwargs[key] for key in keys} - - @signal_wrapper - def run_sampler(self): - import dynesty + def sampler_init(self): + from dynesty import DynamicNestedSampler - self._setup_pool() - self.sampler = dynesty.DynamicNestedSampler( - loglikelihood=_log_likelihood_wrapper, - prior_transform=_prior_transform_wrapper, - ndim=self.ndim, - **self.sampler_init_kwargs, - ) - - if self.check_point: - out = self._run_external_sampler_with_checkpointing() - else: - out = self._run_external_sampler_without_checkpointing() - self._close_pool() + return DynamicNestedSampler - # Flushes the output to force a line break - if self.kwargs["verbose"]: - print("") - - # self.result.sampler_output = out - self._generate_result(out) - if self.plot: - self.generate_trace_plots(out) - - return self.result - - def _run_external_sampler_with_checkpointing(self): - logger.debug("Running sampler with checkpointing") - if self.resume: - resume = self.read_saved_state(continuing=True) - if resume: - logger.info("Resuming from previous run.") - - old_ncall = self.sampler.ncall - sampler_kwargs = self.sampler_function_kwargs.copy() - sampler_kwargs["maxcall"] = self.n_check_point - self.start_time = datetime.datetime.now() - while True: - sampler_kwargs["maxcall"] += self.n_check_point - self.sampler.run_nested(**sampler_kwargs) - if self.sampler.ncall == old_ncall: - break - old_ncall = self.sampler.ncall - - self.write_current_state() + @property + def sampler_class(self): + from dynesty.dynamicsampler import DynamicSampler - self._remove_checkpoint() - return self.sampler.results + return DynamicSampler - def write_current_state_and_exit(self, signum=None, frame=None): - Sampler.write_current_state_and_exit(self=self, signum=signum, frame=frame) + def finalize_sampler_kwargs(self, sampler_kwargs): + sampler_kwargs["maxcall"] = self.sampler.ncall + self.n_check_point - def _verify_kwargs_against_default_kwargs(self): - Sampler._verify_kwargs_against_default_kwargs(self) + def read_saved_state(self, continuing=False): + resume = super(DynamicDynesty, self).read_saved_state(continuing=continuing) + if not resume: + return resume + else: + self.sampler.loglikelihood.pool = self.pool + return resume diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 82da609e77a17613f90c248b6b04b7b1c2854bf9..d9c6ab6c01180226ce3bb91e58e0928df9cc00d7 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -1,4 +1,5 @@ import datetime +import inspect import os import sys import time @@ -15,7 +16,7 @@ from ..utils import ( reflect, safe_file_dump, ) -from .base_sampler import NestedSampler, Sampler, signal_wrapper +from .base_sampler import NestedSampler, Sampler, _SamplingContainer, signal_wrapper def _prior_transform_wrapper(theta): @@ -57,8 +58,8 @@ class Dynesty(NestedSampler): All positional and keyword arguments (i.e., the args and kwargs) passed to `run_sampler` will be propagated to `dynesty.NestedSampler`, see - documentation for that class for further help. Under Other Parameter below, - we list commonly all kwargs and the bilby defaults. + documentation for that class for further help. Under Other Parameters below, + we list commonly used kwargs and the Bilby defaults. Parameters ========== @@ -79,12 +80,37 @@ class Dynesty(NestedSampler): skip_import_verification: bool Skips the check if the sampler is installed if true. This is only advisable for testing environments + print_method: str ('tqdm') + The method to use for printing. The options are: + - 'tqdm': use a `tqdm` `pbar`, this is the default. + - 'interval-$TIME': print to `stdout` every `$TIME` seconds, + e.g., 'interval-10' prints every ten seconds, this does not print every iteration + - else: print to `stdout` at every iteration + exit_code: int + The code which the same exits on if it hasn't finished sampling + check_point: bool, + If true, use check pointing. + check_point_plot: bool, + If true, generate a trace plot along with the check-point + check_point_delta_t: float (600) + The minimum checkpoint period (in seconds). Should the run be + interrupted, it can be resumed from the last checkpoint. + n_check_point: int, optional (None) + The number of steps to take before checking whether to check_point. + resume: bool + If true, resume run from checkpoint (if available) + maxmcmc: int (5000) + The maximum length of the MCMC exploration to find a new point + nact: int (5) + The number of "autocorrelation" times to continue the MCMC for. + Note that this is a very poor approximation to the true ACT and should + be interpreted very loosely. Other Parameters ================ - npoints: int, (1000) + nlive: int, (1000) The number of live points, note this can also equivalently be given as - one of [nlive, nlives, n_live_points] + one of [nlive, nlives, n_live_points, npoints] bound: {'none', 'single', 'multi', 'balls', 'cubes'}, ('multi') Method used to select new points sample: {'unif', 'rwalk', 'slice', 'rslice', 'hslice'}, ('rwalk') @@ -96,69 +122,44 @@ class Dynesty(NestedSampler): `ndim * 10` can be a reasonable rule of thumb for new problems. dlogz: float, (0.1) Stopping criteria - print_progress: Bool - If true, print information information about the convergence during. - `verbose` has the same effect. - check_point: bool, - If true, use check pointing. - check_point_plot: bool, - If true, generate a trace plot along with the check-point - check_point_delta_t: float (600) - The minimum checkpoint period (in seconds). Should the run be - interrupted, it can be resumed from the last checkpoint. - n_check_point: int, optional (None) - The number of steps to take before checking whether to check_point. - resume: bool - If true, resume run from checkpoint (if available) - exit_code: int - The code which the same exits on if it hasn't finished sampling - print_method: str ('tqdm') - The method to use for printing. The options are: - - 'tqdm': use a `tqdm` `pbar`, this is the default. - - 'interval-$TIME': print to `stdout` every `$TIME` seconds, - e.g., 'interval-10' prints every ten seconds, this does not print every iteration - - else: print to `stdout` at every iteration + facc: float, (0.2) + The target acceptance fraction for the rwalk evolution. The proposal + scale is tuned to meet this fraction. + save_bounds: bool, (False) + Whether to save the dynesty bounding ellipse objects. This is disabled + by default as it can lead to extremely large memory usage. """ - default_kwargs = dict( - bound="multi", - sample="rwalk", - print_progress=True, - periodic=None, - reflective=None, - check_point_delta_t=1800, - nlive=1000, - first_update=None, - walks=100, - npdim=None, - rstate=None, - queue_size=1, - pool=None, - use_pool=None, - live_points=None, - logl_args=None, - logl_kwargs=None, - ptform_args=None, - ptform_kwargs=None, - enlarge=1.5, - bootstrap=None, - vol_dec=0.5, - vol_check=8.0, - facc=0.2, - slices=5, - update_interval=None, - print_func=None, - dlogz=0.1, - maxiter=None, - maxcall=None, - logl_max=np.inf, - add_live=True, - save_bounds=False, - n_effective=None, - maxmcmc=5000, - nact=5, - print_method="tqdm", - ) + @property + def _dynesty_init_kwargs(self): + params = inspect.signature(self.sampler_init).parameters + kwargs = { + key: param.default + for key, param in params.items() + if param.default != param.empty + } + kwargs["sample"] = "rwalk" + kwargs["facc"] = 0.2 + return kwargs + + @property + def _dynesty_sampler_kwargs(self): + params = inspect.signature(self.sampler_class.run_nested).parameters + kwargs = { + key: param.default + for key, param in params.items() + if param.default != param.empty + } + kwargs["save_bounds"] = False + if "dlogz" in kwargs: + kwargs["dlogz"] = 0.1 + return kwargs + + @property + def default_kwargs(self): + kwargs = self._dynesty_init_kwargs + kwargs.update(self._dynesty_sampler_kwargs) + return kwargs def __init__( self, @@ -176,8 +177,14 @@ class Dynesty(NestedSampler): resume=True, nestcheck=False, exit_code=130, + print_method="tqdm", + maxmcmc=5000, + nact=5, **kwargs, ): + _SamplingContainer.maxmcmc = maxmcmc + _SamplingContainer.nact = nact + self.print_method = print_method self._translate_kwargs(kwargs) super(Dynesty, self).__init__( likelihood=likelihood, @@ -194,9 +201,8 @@ class Dynesty(NestedSampler): self.check_point = check_point self.check_point_plot = check_point_plot self.resume = resume - self._periodic = list() - self._reflective = list() - self._apply_dynesty_boundaries() + self._apply_dynesty_boundaries("periodic") + self._apply_dynesty_boundaries("reflective") self.nestcheck = nestcheck @@ -207,36 +213,15 @@ class Dynesty(NestedSampler): self.resume_file = f"{self.outdir}/{self.label}_resume.pickle" self.sampling_time = datetime.timedelta() - - def __getstate__(self): - """For pickle: remove external_sampler, which can be an unpicklable "module" """ - state = self.__dict__.copy() - if "external_sampler" in state: - del state["external_sampler"] - return state + self.pbar = None @property def sampler_function_kwargs(self): - keys = [ - "dlogz", - "print_progress", - "print_func", - "maxiter", - "maxcall", - "logl_max", - "add_live", - "save_bounds", - "n_effective", - ] - return {key: self.kwargs[key] for key in keys} + return {key: self.kwargs[key] for key in self._dynesty_sampler_kwargs} @property def sampler_init_kwargs(self): - return { - key: value - for key, value in self.kwargs.items() - if key not in self.sampler_function_kwargs - } + return {key: self.kwargs[key] for key in self._dynesty_init_kwargs} def _translate_kwargs(self, kwargs): kwargs = super()._translate_kwargs(kwargs) @@ -257,27 +242,32 @@ class Dynesty(NestedSampler): kwargs["queue_size"] = kwargs.pop(equiv) def _verify_kwargs_against_default_kwargs(self): - from tqdm.auto import tqdm - if not self.kwargs["walks"]: self.kwargs["walks"] = 100 - if not self.kwargs["update_interval"]: - self.kwargs["update_interval"] = int(0.6 * self.kwargs["nlive"]) if self.kwargs["print_func"] is None: self.kwargs["print_func"] = self._print_func - print_method = self.kwargs["print_method"] - if print_method == "tqdm" and self.kwargs["print_progress"]: - self.pbar = tqdm(file=sys.stdout) - elif "interval" in print_method: + if "interval" in self.print_method: self._last_print_time = datetime.datetime.now() self._print_interval = datetime.timedelta( - seconds=float(print_method.split("-")[1]) + seconds=float(self.print_method.split("-")[1]) ) Sampler._verify_kwargs_against_default_kwargs(self) - def _print_func(self, results, niter, ncall=None, dlogz=None, *args, **kwargs): + def _print_func( + self, + results, + niter, + ncall=None, + dlogz=None, + stop_val=None, + nbatch=None, + logl_min=-np.inf, + logl_max=np.inf, + *args, + **kwargs, + ): """Replacing status update for dynesty.result.print_func""" - if "interval" in self.kwargs["print_method"]: + if "interval" in self.print_method: _time = datetime.datetime.now() if _time - self._last_print_time < self._print_interval: return @@ -291,23 +281,13 @@ class Dynesty(NestedSampler): total_time_str = str(total_time).split(".")[0] # Extract results at the current iteration. - ( - worst, - ustar, - vstar, - loglstar, - logvol, - logwt, - logz, - logzvar, - h, - nc, - worst_it, - boundidx, - bounditer, - eff, - delta_logz, - ) = results + loglstar = results.loglstar + delta_logz = results.delta_logz + logz = results.logz + logzvar = results.logzvar + nc = results.nc + bounditer = results.bounditer + eff = results.eff # Adjusting outputs for printing. if delta_logz > 1e6: @@ -333,34 +313,34 @@ class Dynesty(NestedSampler): string.append(f"ncall:{ncall:.1e}") string.append(f"eff:{eff:0.1f}%") string.append(f"{key}={logz:0.2f}+/-{logzerr:0.2f}") - string.append(f"dlogz:{delta_logz:0.3f}>{dlogz:0.2g}") + if nbatch is not None: + string.append(f"batch:{nbatch}") + if logl_min > -np.inf: + string.append(f"logl:{logl_min:.1f} < {loglstar:.1f} < {logl_max:.1f}") + if dlogz is not None: + string.append(f"dlogz:{delta_logz:0.3f}>{dlogz:0.2g}") + else: + string.append(f"stop:{stop_val:6.3f}") + string = " ".join(string) - if self.kwargs["print_method"] == "tqdm": - self.pbar.set_postfix_str(" ".join(string), refresh=False) + if self.print_method == "tqdm": + self.pbar.set_postfix_str(string, refresh=False) self.pbar.update(niter - self.pbar.n) - elif "interval" in self.kwargs["print_method"]: - formatted = " ".join([total_time_str] + string) - print(f"{niter}it [{formatted}]", file=sys.stdout, flush=True) else: - formatted = " ".join([total_time_str] + string) - print(f"{niter}it [{formatted}]", file=sys.stdout, flush=True) - - def _apply_dynesty_boundaries(self): - self._periodic = list() - self._reflective = list() - for ii, key in enumerate(self.search_parameter_keys): - if self.priors[key].boundary == "periodic": - logger.debug(f"Setting periodic boundary for {key}") - self._periodic.append(ii) - elif self.priors[key].boundary == "reflective": - logger.debug(f"Setting reflective boundary for {key}") - self._reflective.append(ii) + print(f"{niter}it [{total_time_str} {string}]", file=sys.stdout, flush=True) + def _apply_dynesty_boundaries(self, key): # The periodic kwargs passed into dynesty allows the parameters to # wander out of the bounds, this includes both periodic and reflective. # these are then handled in the prior_transform - self.kwargs["periodic"] = self._periodic - self.kwargs["reflective"] = self._reflective + selected = list() + for ii, param in enumerate(self.search_parameter_keys): + if self.priors[param].boundary == key: + logger.debug(f"Setting {key} boundary for {param}") + selected.append(ii) + if len(selected) == 0: + selected = None + self.kwargs[key] = selected def nestcheck_data(self, out_file): import pickle @@ -372,6 +352,22 @@ class Dynesty(NestedSampler): with open(nestcheck_result, "wb") as file_nest: pickle.dump(ns_run, file_nest) + @property + def nlive(self): + return self.kwargs["nlive"] + + @property + def sampler_init(self): + from dynesty import NestedSampler + + return NestedSampler + + @property + def sampler_class(self): + from dynesty.sampler import Sampler + + return Sampler + @signal_wrapper def run_sampler(self): import dill @@ -385,9 +381,9 @@ class Dynesty(NestedSampler): ) dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby - if self.kwargs.get("walks") > self.kwargs.get("maxmcmc"): + if self.kwargs["walks"] > _SamplingContainer.maxmcmc: raise DynestySetupError("You have maxmcmc > walks (minimum mcmc)") - if self.kwargs.get("nact", 5) < 1: + if _SamplingContainer.nact < 1: raise DynestySetupError("Unable to run with nact < 1") elif self.kwargs.get("sample") == "rwalk_dynesty": self._kwargs["sample"] = "rwalk" @@ -406,14 +402,19 @@ class Dynesty(NestedSampler): else: if self.kwargs["live_points"] is None: self.kwargs["live_points"] = self.get_initial_points_from_prior( - self.kwargs["nlive"] + self.nlive ) - self.sampler = dynesty.NestedSampler( + self.kwargs["live_points"] = (*self.kwargs["live_points"], None) + self.sampler = self.sampler_init( loglikelihood=_log_likelihood_wrapper, prior_transform=_prior_transform_wrapper, ndim=self.ndim, **self.sampler_init_kwargs, ) + if self.print_method == "tqdm" and self.kwargs["print_progress"]: + from tqdm.auto import tqdm + + self.pbar = tqdm(file=sys.stdout, initial=self.sampler.it) self.start_time = datetime.datetime.now() if self.check_point: @@ -425,8 +426,8 @@ class Dynesty(NestedSampler): self._close_pool() # Flushes the output to force a line break - if self.kwargs["print_progress"] and self.kwargs["print_method"] == "tqdm": - self.pbar.close() + if self.pbar is not None: + self.pbar = self.pbar.close() print("") check_directory_exists_and_if_not_mkdir(self.outdir) @@ -441,9 +442,6 @@ class Dynesty(NestedSampler): self._generate_result(out) self.result.sampling_time = self.sampling_time - if self.plot: - self.generate_trace_plots(out) - return self.result def _generate_result(self, out): @@ -481,39 +479,23 @@ class Dynesty(NestedSampler): self.sampling_time += end_time - self.start_time self.start_time = end_time - def _run_nested_wrapper(self, kwargs): - """Wrapper function to run_nested - - This wrapper catches exceptions related to different versions of - dynesty accepting different arguments. - - Parameters - ========== - kwargs: dict - The dictionary of kwargs to pass to run_nested - - """ - logger.debug(f"Calling run_nested with sampler_function_kwargs {kwargs}") - try: - self.sampler.run_nested(**kwargs) - except TypeError: - kwargs.pop("n_effective") - self.sampler.run_nested(**kwargs) - def _run_external_sampler_without_checkpointing(self): logger.debug("Running sampler without checkpointing") - self._run_nested_wrapper(self.sampler_function_kwargs) + self.sampler.run_nested(**self.sampler_function_kwargs) return self.sampler.results + def finalize_sampler_kwargs(self, sampler_kwargs): + sampler_kwargs["maxcall"] = self.n_check_point + sampler_kwargs["add_live"] = True + def _run_external_sampler_with_checkpointing(self): logger.debug("Running sampler with checkpointing") old_ncall = self.sampler.ncall sampler_kwargs = self.sampler_function_kwargs.copy() - sampler_kwargs["maxcall"] = self.n_check_point - sampler_kwargs["add_live"] = True while True: - self._run_nested_wrapper(sampler_kwargs) + self.finalize_sampler_kwargs(sampler_kwargs) + self.sampler.run_nested(**sampler_kwargs) if self.sampler.ncall == old_ncall: break old_ncall = self.sampler.ncall @@ -527,11 +509,10 @@ class Dynesty(NestedSampler): if last_checkpoint_s > self.check_point_delta_t: self.write_current_state() self.plot_current_state() - if self.sampler.added_live: + if getattr(self.sampler, "added_live", False): self.sampler._remove_live_points() - sampler_kwargs["add_live"] = True - self._run_nested_wrapper(sampler_kwargs) + self.sampler.run_nested(**sampler_kwargs) self.write_current_state() self.plot_current_state() return self.sampler.results @@ -595,7 +576,6 @@ class Dynesty(NestedSampler): if getattr(self.sampler, "added_live", False) and continuing: self.sampler._remove_live_points() self.sampler.nqueue = -1 - self.sampler.rstate = np.random self.start_time = self.sampler.kwargs.pop("start_time") self.sampling_time = self.sampler.kwargs.pop("sampling_time") self.sampler.pool = self.pool @@ -609,8 +589,8 @@ class Dynesty(NestedSampler): return False def write_current_state_and_exit(self, signum=None, frame=None): - if self.kwargs["print_method"] == "tqdm": - self.pbar.close() + if self.pbar is not None: + self.pbar = self.pbar.close() super(Dynesty, self).write_current_state_and_exit(signum=signum, frame=frame) def write_current_state(self): @@ -699,8 +679,10 @@ class Dynesty(NestedSampler): filename = f"{self.outdir}/{self.label}_checkpoint_trace_unit.png" from copy import deepcopy + from dynesty.utils import results_substitute + temp = deepcopy(self.sampler.results) - temp["samples"] = temp["samples_u"] + temp = results_substitute(temp, dict(samples=temp["samples_u"])) fig = dyplot.traceplot(temp, labels=labels)[0] fig.tight_layout() fig.savefig(filename) @@ -735,26 +717,15 @@ class Dynesty(NestedSampler): except (RuntimeError, ValueError) as e: logger.warning(e) logger.warning("Failed to create dynesty stats plot at checkpoint") + except DynestySetupError: + logger.debug("Cannot create Dynesty stats plot with dynamic sampler.") finally: plt.close("all") - def generate_trace_plots(self, dynesty_results): - check_directory_exists_and_if_not_mkdir(self.outdir) - filename = f"{self.outdir}/{self.label}_trace.png" - logger.debug(f"Writing trace plot to {filename}") - from dynesty import plotting as dyplot - - fig, axes = dyplot.traceplot( - dynesty_results, labels=self.result.parameter_labels - ) - fig.tight_layout() - fig.savefig(filename) - def _run_test(self): - import dynesty import pandas as pd - self.sampler = dynesty.NestedSampler( + self.sampler = self.sampler_class( loglikelihood=self.log_likelihood, prior_transform=self.prior_transform, ndim=self.ndim, @@ -794,23 +765,25 @@ class Dynesty(NestedSampler): def sample_rwalk_bilby(args): """Modified bilby-implemented version of dynesty.sampling.sample_rwalk""" - from dynesty.utils import unitcheck + from dynesty.utils import get_random_generator, unitcheck # Unzipping. - (u, loglstar, axes, scale, prior_transform, loglikelihood, kwargs) = args - rstate = np.random + (u, loglstar, axes, scale, prior_transform, loglikelihood, rseed, kwargs) = args + rstate = get_random_generator(rseed) # Bounds nonbounded = kwargs.get("nonbounded", None) + if nonbounded is not None and sum(nonbounded) == 0: + nonbounded = None periodic = kwargs.get("periodic", None) reflective = kwargs.get("reflective", None) # Setup. n = len(u) walks = kwargs.get("walks", 100) # minimum number of steps - maxmcmc = kwargs.get("maxmcmc", 5000) # Maximum number of steps - nact = kwargs.get("nact", 5) # Number of ACT - old_act = kwargs.get("old_act", walks) + maxmcmc = _SamplingContainer.maxmcmc + nact = _SamplingContainer.nact + old_act = getattr(_SamplingContainer, "old_act", walks) # Initialize internal variables accept = 0 @@ -826,11 +799,11 @@ def sample_rwalk_bilby(args): ii += 1 # Propose a direction on the unit n-sphere. - drhat = rstate.randn(n) + drhat = rstate.normal(0, 1, n) drhat /= np.linalg.norm(drhat) # Scale based on dimensionality. - dr = drhat * rstate.rand() ** (1.0 / n) + dr = drhat * rstate.uniform(0, 1) ** (1.0 / n) # Transform to proposal distribution. du = np.dot(axes, dr) @@ -903,7 +876,7 @@ def sample_rwalk_bilby(args): logl = loglikelihood(v) blob = {"accept": accept, "reject": reject, "fail": nfail, "scale": scale} - kwargs["old_act"] = act + _SamplingContainer.old_act = act ncall = accept + reject return u, v, logl, ncall, blob @@ -941,7 +914,7 @@ def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None): 2.0 / accept_ratio - 1.0 ) Nmcmc_exact = float(min(Nmcmc_exact, maxmcmc)) - return max(safety, int(Nmcmc_exact)) + return max(safety, Nmcmc_exact) @latex_plot_format @@ -973,14 +946,17 @@ def dynesty_stats_plot(sampler): from scipy.stats import geom, ks_1samp fig, axs = plt.subplots(nrows=4, figsize=(8, 8)) + data = sampler.saved_run.D for ax, name in zip(axs, ["nc", "scale"]): - ax.plot(getattr(sampler, f"saved_{name}"), color="blue") + ax.plot(data[name], color="blue") ax.set_ylabel(name.title()) - lifetimes = np.arange(len(sampler.saved_it)) - sampler.saved_it + lifetimes = np.arange(len(data["it"])) - data["it"] axs[-2].set_ylabel("Lifetime") + if not hasattr(sampler, "nlive"): + raise DynestySetupError("Cannot make stats plot for dynamic sampler.") nlive = sampler.nlive burn = int(geom(p=1 / nlive).isf(1 / 2 / nlive)) - if len(sampler.saved_it) > burn + sampler.nlive: + if len(data["it"]) > burn + sampler.nlive: axs[-2].plot(np.arange(0, burn), lifetimes[:burn], color="grey") axs[-2].plot( np.arange(burn, len(lifetimes) - nlive), diff --git a/requirements.txt b/requirements.txt index b69a3c7ce373c9df519b9d75a35b3e7260bf43cc..f199d07f19bf96a5dde54b8c5d73657ff8ae4c5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ bilby.cython>=0.3.0 -dynesty<1.1 +dynesty>=2 emcee corner numpy diff --git a/test/core/sampler/dynamic_dynesty_test.py b/test/core/sampler/dynamic_dynesty_test.py index 9b4a926f802e562fb325c5744d9437ef8ea8fae8..f837644157f7463a8ae5bbbe408384c280d57aeb 100644 --- a/test/core/sampler/dynamic_dynesty_test.py +++ b/test/core/sampler/dynamic_dynesty_test.py @@ -1,4 +1,35 @@ import unittest +from unittest.mock import MagicMock + +import bilby + + +class TestDynamicDynesty(unittest.TestCase): + def setUp(self): + self.likelihood = MagicMock() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) + ) + self.sampler = bilby.core.sampler.DynamicDynesty( + self.likelihood, + self.priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + skip_import_verification=True, + ) + + def tearDown(self): + del self.likelihood + del self.priors + del self.sampler + + def test_default_kwargs(self): + """Only test the kwargs where we specify different defaults to dynesty""" + expected = dict(sample="rwalk", facc=0.2, save_bounds=False) + for key in expected: + self.assertEqual(expected[key], self.sampler.kwargs[key]) if __name__ == "__main__": diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py index 8cbab6d318896dde43e524812e662c44d148f1bf..c78f1a4fc6af4af32f2f10f631da3a23d271160c 100644 --- a/test/core/sampler/dynesty_test.py +++ b/test/core/sampler/dynesty_test.py @@ -1,8 +1,7 @@ import unittest +from copy import deepcopy from unittest.mock import MagicMock -import numpy as np - import bilby @@ -28,111 +27,19 @@ class TestDynesty(unittest.TestCase): del self.sampler def test_default_kwargs(self): - expected = dict( - bound="multi", - sample="rwalk", - periodic=None, - reflective=None, - check_point_delta_t=1800, - nlive=1000, - first_update=None, - npdim=None, - rstate=None, - queue_size=1, - pool=None, - use_pool=None, - live_points=None, - logl_args=None, - logl_kwargs=None, - ptform_args=None, - ptform_kwargs=None, - enlarge=1.5, - bootstrap=None, - vol_dec=0.5, - vol_check=8.0, - facc=0.2, - slices=5, - dlogz=0.1, - maxiter=None, - maxcall=None, - logl_max=np.inf, - add_live=True, - print_progress=True, - save_bounds=False, - walks=100, - update_interval=600, - print_func="func", - n_effective=None, - maxmcmc=5000, - nact=5, - print_method="tqdm", - ) - self.sampler.kwargs[ - "print_func" - ] = "func" # set this manually as this is not testable otherwise - # DictEqual can't handle lists so we check these separately - self.assertEqual([], self.sampler.kwargs["periodic"]) - self.assertEqual([], self.sampler.kwargs["reflective"]) - self.sampler.kwargs["periodic"] = expected["periodic"] - self.sampler.kwargs["reflective"] = expected["reflective"] - for key in self.sampler.kwargs.keys(): - print( - "key={}, expected={}, actual={}".format( - key, expected[key], self.sampler.kwargs[key] - ) - ) - self.assertDictEqual(expected, self.sampler.kwargs) + """Only test the kwargs where we specify different defaults to dynesty""" + expected = dict(sample="rwalk", facc=0.2, save_bounds=False, dlogz=0.1) + for key in expected: + self.assertEqual(expected[key], self.sampler.kwargs[key]) def test_translate_kwargs(self): - expected = dict( - bound="multi", - sample="rwalk", - periodic=[], - reflective=[], - check_point_delta_t=1800, - nlive=1000, - first_update=None, - npdim=None, - rstate=None, - queue_size=1, - pool=None, - use_pool=None, - live_points=None, - logl_args=None, - logl_kwargs=None, - ptform_args=None, - ptform_kwargs=None, - enlarge=1.5, - bootstrap=None, - vol_dec=0.5, - vol_check=8.0, - facc=0.2, - slices=5, - dlogz=0.1, - maxiter=None, - maxcall=None, - logl_max=np.inf, - add_live=True, - print_progress=True, - save_bounds=False, - walks=100, - update_interval=600, - print_func="func", - n_effective=None, - maxmcmc=5000, - nact=5, - print_method="tqdm", - ) - + expected = 1000 for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: - new_kwargs = self.sampler.kwargs.copy() + new_kwargs = deepcopy(self.sampler.kwargs) del new_kwargs["nlive"] - new_kwargs[equiv] = 1000 - self.sampler.kwargs = new_kwargs - self.sampler.kwargs[ - "print_func" - ] = "func" # set this manually as this is not testable otherwise - self.assertDictEqual(expected, self.sampler.kwargs) + new_kwargs[equiv] = expected + self.sampler._translate_kwargs(new_kwargs) + self.assertEqual(new_kwargs["nlive"], expected) def test_prior_boundary(self): self.priors["a"] = bilby.core.prior.Prior(boundary="periodic") @@ -150,9 +57,7 @@ class TestDynesty(unittest.TestCase): skip_import_verification=True, ) self.assertEqual([0, 4], self.sampler.kwargs["periodic"]) - self.assertEqual(self.sampler._periodic, self.sampler.kwargs["periodic"]) self.assertEqual([1, 3], self.sampler.kwargs["reflective"]) - self.assertEqual(self.sampler._reflective, self.sampler.kwargs["reflective"]) if __name__ == "__main__":