Skip to content
Snippets Groups Projects
Commit 08b2a15f authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'dynesty-print' into 'master'

Add an option to use print rather than tqdm for dynesty

Closes #543

See merge request lscsoft/bilby!937
parents e5e9dfa5 a2653421
No related branches found
No related tags found
No related merge requests found
......@@ -92,7 +92,7 @@ class Dynesty(NestedSampler):
only advisable for testing environments
Other Parameters
------==========
================
npoints: int, (1000)
The number of live points, note this can also equivalently be given as
one of [nlive, nlives, n_live_points]
......@@ -122,6 +122,12 @@ class Dynesty(NestedSampler):
If true, resume run from checkpoint (if available)
exit_code: int
The code which the same exits on if it hasn't finished sampling
print_method: str ('tqdm')
The method to use for printing. The options are:
- 'tqdm': use a `tqdm` `pbar`, this is the default.
- 'interval-$TIME': print to `stdout` every `$TIME` seconds,
e.g., 'interval-10' prints every ten seconds, this does not print every iteration
- else: print to `stdout` at every iteration
"""
default_kwargs = dict(bound='multi', sample='rwalk',
verbose=True, periodic=None, reflective=None,
......@@ -137,7 +143,7 @@ class Dynesty(NestedSampler):
dlogz=0.1, maxiter=None, maxcall=None,
logl_max=np.inf, add_live=True, print_progress=True,
save_bounds=False, n_effective=None,
maxmcmc=5000, nact=5)
maxmcmc=5000, nact=5, print_method="tqdm")
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
......@@ -220,11 +226,22 @@ class Dynesty(NestedSampler):
self.kwargs['update_interval'] = int(0.6 * self.kwargs['nlive'])
if self.kwargs['print_func'] is None:
self.kwargs['print_func'] = self._print_func
self.pbar = tqdm(file=sys.stdout)
print_method = self.kwargs["print_method"]
if print_method == "tqdm":
self.pbar = tqdm(file=sys.stdout)
elif "interval" in print_method:
self._last_print_time = datetime.datetime.now()
self._print_interval = datetime.timedelta(seconds=float(print_method.split("-")[1]))
Sampler._verify_kwargs_against_default_kwargs(self)
def _print_func(self, results, niter, ncall=None, dlogz=None, *args, **kwargs):
""" Replacing status update for dynesty.result.print_func """
if "interval" in self.kwargs["print_method"]:
_time = datetime.datetime.now()
if _time - self._last_print_time < self._print_interval:
return
else:
self._last_print_time = _time
# Extract results at the current iteration.
(worst, ustar, vstar, loglstar, logvol, logwt,
......@@ -249,7 +266,7 @@ class Dynesty(NestedSampler):
key = 'logz'
# Constructing output.
string = []
string = list()
string.append("bound:{:d}".format(bounditer))
string.append("nc:{:3d}".format(nc))
string.append("ncall:{:.1e}".format(ncall))
......@@ -257,8 +274,16 @@ class Dynesty(NestedSampler):
string.append("{}={:0.2f}+/-{:0.2f}".format(key, logz, logzerr))
string.append("dlogz:{:0.3f}>{:0.2g}".format(delta_logz, dlogz))
self.pbar.set_postfix_str(" ".join(string), refresh=False)
self.pbar.update(niter - self.pbar.n)
if self.kwargs["print_method"] == "tqdm":
self.pbar.set_postfix_str(" ".join(string), refresh=False)
self.pbar.update(niter - self.pbar.n)
elif "interval" in self.kwargs["print_method"]:
formatted = " ".join([str(_time - self.start_time)] + string)
print("{}it [{}]".format(niter, formatted), file=sys.stdout)
else:
_time = datetime.datetime.now()
formatted = " ".join([str(_time - self.start_time)] + string)
print("{}it [{}]".format(niter, formatted), file=sys.stdout)
def _apply_dynesty_boundaries(self):
self._periodic = list()
......@@ -366,7 +391,7 @@ class Dynesty(NestedSampler):
self._close_pool()
# Flushes the output to force a line break
if self.kwargs["verbose"]:
if self.kwargs["verbose"] and self.kwargs["print_method"] == "tqdm":
self.pbar.close()
print("")
......
......@@ -66,6 +66,7 @@ class TestDynesty(unittest.TestCase):
n_effective=None,
maxmcmc=5000,
nact=5,
print_method="tqdm",
)
self.sampler.kwargs[
"print_func"
......@@ -122,6 +123,7 @@ class TestDynesty(unittest.TestCase):
n_effective=None,
maxmcmc=5000,
nact=5,
print_method="tqdm",
)
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment