From ac05c0f6fb28376572bd429dac9d71655920d590 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Mon, 13 Mar 2023 22:24:38 +0000 Subject: [PATCH] Merge branch 'fix-act-walk-caching' into 'master' BUGFIX: fix act-walk caching See merge request lscsoft/bilby!1225 (cherry picked from commit e019b5a210252bb884a854bf450f6e10ba712837) 78a23a7a BUGFIX: fix act-walk caching 3bd9e531 FORMAT: fix formatting issues 76e813ee BUGFIX: typo in _cache attribute 06985f91 BGUFIX: track act for rwalk method dff3a72c TYPO: add missing blank line f6017840 TEST: bugfix in dynesty test --- bilby/core/sampler/dynesty.py | 16 ++++++++-------- bilby/core/sampler/dynesty_utils.py | 16 +++++++++------- test/core/sampler/dynesty_test.py | 1 + 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index dc748792a..498a0fffb 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -450,32 +450,32 @@ class Dynesty(NestedSampler): f"An average of {2 * self.nact} steps will be accepted up to chain length " f"{self.maxmcmc}." ) - from .dynesty_utils import sample_rwalk_bilby + from .dynesty_utils import AcceptanceTrackingRWalk if self.kwargs["walks"] > self.maxmcmc: raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)") if self.nact < 1: raise DynestySetupError("Unable to run with nact < 1") - dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby + AcceptanceTrackingRWalk.old_act = None + dynesty.nestedsamplers._SAMPLING["rwalk"] = AcceptanceTrackingRWalk() elif sample == "acceptance-walk": logger.info( "Using the bilby-implemented rwalk sampling with an average of " f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}" ) - from .dynesty_utils import fixed_length_rwalk_bilby + from .dynesty_utils import FixedRWalk - dynesty.nestedsamplers._SAMPLING[ - "acceptance-walk" - ] = fixed_length_rwalk_bilby + dynesty.nestedsamplers._SAMPLING["acceptance-walk"] = FixedRWalk() elif sample == "act-walk": logger.info( "Using the bilby-implemented rwalk sampling tracking the " f"autocorrelation function and thinning by " f"{self.nact} with maximum length {self.nact * self.maxmcmc}" ) - from .dynesty_utils import bilby_act_rwalk + from .dynesty_utils import ACTTrackingRWalk - dynesty.nestedsamplers._SAMPLING["act-walk"] = bilby_act_rwalk + ACTTrackingRWalk._cache = list() + dynesty.nestedsamplers._SAMPLING["act-walk"] = ACTTrackingRWalk() elif sample == "rwalk_dynesty": sample = sample.strip("_dynesty") self.kwargs["sample"] = sample diff --git a/bilby/core/sampler/dynesty_utils.py b/bilby/core/sampler/dynesty_utils.py index a13647427..1d9599112 100644 --- a/bilby/core/sampler/dynesty_utils.py +++ b/bilby/core/sampler/dynesty_utils.py @@ -181,8 +181,11 @@ class ACTTrackingRWalk: parallel process. """ + # the _cache is a class level variable to avoid being forgotten at every + # iteration when using multiprocessing + _cache = list() + def __init__(self): - self._cache = list() self.act = 1 self.thin = getattr(_SamplingContainer, "nact", 2) self.maxmcmc = getattr(_SamplingContainer, "maxmcmc", 5000) * 50 @@ -367,10 +370,13 @@ class AcceptanceTrackingRWalk: corresponds to specifying :code:`sample="rwalk"` """ + # to retain state between calls to pool.Map, this needs to be a class + # level attribute + old_act = None + def __init__(self, old_act=None): self.maxmcmc = getattr(_SamplingContainer, "maxmcmc", 5000) self.nact = getattr(_SamplingContainer, "nact", 40) - self.old_act = old_act def __call__(self, args): rstate = get_random_generator(args.rseed) @@ -437,7 +443,7 @@ class AcceptanceTrackingRWalk: logl = args.loglikelihood(v) blob = {"accept": accept, "reject": reject + nfail, "scale": args.scale} - self.old_act = act + AcceptanceTrackingRWalk.old_act = act ncall = accept + reject return u, v, logl, ncall, blob @@ -708,7 +714,3 @@ def apply_boundaries_(u_prop, periodic, reflective): proposal_funcs = dict(diff=propose_differetial_evolution, volumetric=propose_volumetric) - -fixed_length_rwalk_bilby = FixedRWalk() -bilby_act_rwalk = ACTTrackingRWalk() -sample_rwalk_bilby = AcceptanceTrackingRWalk() diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py index fdd1edcff..0abb6907a 100644 --- a/test/core/sampler/dynesty_test.py +++ b/test/core/sampler/dynesty_test.py @@ -242,6 +242,7 @@ class TestEstimateNMCMC(unittest.TestCase): safety * (2 / accept_ratio - 1) """ sampler = dynesty_utils.AcceptanceTrackingRWalk() + dynesty_utils.AcceptanceTrackingRWalk.old_act = None for _ in range(10): accept_ratio = np.random.uniform() safety = np.random.randint(2, 8) -- GitLab