diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index eb7dff914332d0e25129dd799ee92ea52cf7b7bf..adc28f84922f5ef32b7a7b3a8fc78bb24a2bd04a 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -92,7 +92,7 @@ class Dynesty(NestedSampler): only advisable for testing environments Other Parameters - ------========== + ================ npoints: int, (1000) The number of live points, note this can also equivalently be given as one of [nlive, nlives, n_live_points] @@ -122,6 +122,12 @@ class Dynesty(NestedSampler): If true, resume run from checkpoint (if available) exit_code: int The code which the same exits on if it hasn't finished sampling + print_method: str ('tqdm') + The method to use for printing. The options are: + - 'tqdm': use a `tqdm` `pbar`, this is the default. + - 'interval-$TIME': print to `stdout` every `$TIME` seconds, + e.g., 'interval-10' prints every ten seconds, this does not print every iteration + - else: print to `stdout` at every iteration """ default_kwargs = dict(bound='multi', sample='rwalk', verbose=True, periodic=None, reflective=None, @@ -137,7 +143,7 @@ class Dynesty(NestedSampler): dlogz=0.1, maxiter=None, maxcall=None, logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False, n_effective=None, - maxmcmc=5000, nact=5) + maxmcmc=5000, nact=5, print_method="tqdm") def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, skip_import_verification=False, @@ -220,11 +226,22 @@ class Dynesty(NestedSampler): self.kwargs['update_interval'] = int(0.6 * self.kwargs['nlive']) if self.kwargs['print_func'] is None: self.kwargs['print_func'] = self._print_func - self.pbar = tqdm(file=sys.stdout) + print_method = self.kwargs["print_method"] + if print_method == "tqdm": + self.pbar = tqdm(file=sys.stdout) + elif "interval" in print_method: + self._last_print_time = datetime.datetime.now() + self._print_interval = datetime.timedelta(seconds=float(print_method.split("-")[1])) Sampler._verify_kwargs_against_default_kwargs(self) def _print_func(self, results, niter, ncall=None, dlogz=None, *args, **kwargs): """ Replacing status update for dynesty.result.print_func """ + if "interval" in self.kwargs["print_method"]: + _time = datetime.datetime.now() + if _time - self._last_print_time < self._print_interval: + return + else: + self._last_print_time = _time # Extract results at the current iteration. (worst, ustar, vstar, loglstar, logvol, logwt, @@ -249,7 +266,7 @@ class Dynesty(NestedSampler): key = 'logz' # Constructing output. - string = [] + string = list() string.append("bound:{:d}".format(bounditer)) string.append("nc:{:3d}".format(nc)) string.append("ncall:{:.1e}".format(ncall)) @@ -257,8 +274,16 @@ class Dynesty(NestedSampler): string.append("{}={:0.2f}+/-{:0.2f}".format(key, logz, logzerr)) string.append("dlogz:{:0.3f}>{:0.2g}".format(delta_logz, dlogz)) - self.pbar.set_postfix_str(" ".join(string), refresh=False) - self.pbar.update(niter - self.pbar.n) + if self.kwargs["print_method"] == "tqdm": + self.pbar.set_postfix_str(" ".join(string), refresh=False) + self.pbar.update(niter - self.pbar.n) + elif "interval" in self.kwargs["print_method"]: + formatted = " ".join([str(_time - self.start_time)] + string) + print("{}it [{}]".format(niter, formatted), file=sys.stdout) + else: + _time = datetime.datetime.now() + formatted = " ".join([str(_time - self.start_time)] + string) + print("{}it [{}]".format(niter, formatted), file=sys.stdout) def _apply_dynesty_boundaries(self): self._periodic = list() @@ -366,7 +391,7 @@ class Dynesty(NestedSampler): self._close_pool() # Flushes the output to force a line break - if self.kwargs["verbose"]: + if self.kwargs["verbose"] and self.kwargs["print_method"] == "tqdm": self.pbar.close() print("") diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py index ec8e2274f0582fce8b512edd50ae94b7760e4ba3..2b5ae1f0771c2bf7b8c1d4532af5d0d3f07b8830 100644 --- a/test/core/sampler/dynesty_test.py +++ b/test/core/sampler/dynesty_test.py @@ -66,6 +66,7 @@ class TestDynesty(unittest.TestCase): n_effective=None, maxmcmc=5000, nact=5, + print_method="tqdm", ) self.sampler.kwargs[ "print_func" @@ -122,6 +123,7 @@ class TestDynesty(unittest.TestCase): n_effective=None, maxmcmc=5000, nact=5, + print_method="tqdm", ) for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: