From ac05c0f6fb28376572bd429dac9d71655920d590 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Mon, 13 Mar 2023 22:24:38 +0000
Subject: [PATCH] Merge branch 'fix-act-walk-caching' into 'master'

BUGFIX: fix act-walk caching

See merge request lscsoft/bilby!1225

(cherry picked from commit e019b5a210252bb884a854bf450f6e10ba712837)

78a23a7a BUGFIX: fix act-walk caching
3bd9e531 FORMAT: fix formatting issues
76e813ee BUGFIX: typo in _cache attribute
06985f91 BGUFIX: track act for rwalk method
dff3a72c TYPO: add missing blank line
f6017840 TEST: bugfix in dynesty test
---
 bilby/core/sampler/dynesty.py       | 16 ++++++++--------
 bilby/core/sampler/dynesty_utils.py | 16 +++++++++-------
 test/core/sampler/dynesty_test.py   |  1 +
 3 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index dc748792a..498a0fffb 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -450,32 +450,32 @@ class Dynesty(NestedSampler):
                 f"An average of {2 * self.nact} steps will be accepted up to chain length "
                 f"{self.maxmcmc}."
             )
-            from .dynesty_utils import sample_rwalk_bilby
+            from .dynesty_utils import AcceptanceTrackingRWalk
 
             if self.kwargs["walks"] > self.maxmcmc:
                 raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)")
             if self.nact < 1:
                 raise DynestySetupError("Unable to run with nact < 1")
-            dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
+            AcceptanceTrackingRWalk.old_act = None
+            dynesty.nestedsamplers._SAMPLING["rwalk"] = AcceptanceTrackingRWalk()
         elif sample == "acceptance-walk":
             logger.info(
                 "Using the bilby-implemented rwalk sampling with an average of "
                 f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}"
             )
-            from .dynesty_utils import fixed_length_rwalk_bilby
+            from .dynesty_utils import FixedRWalk
 
-            dynesty.nestedsamplers._SAMPLING[
-                "acceptance-walk"
-            ] = fixed_length_rwalk_bilby
+            dynesty.nestedsamplers._SAMPLING["acceptance-walk"] = FixedRWalk()
         elif sample == "act-walk":
             logger.info(
                 "Using the bilby-implemented rwalk sampling tracking the "
                 f"autocorrelation function and thinning by "
                 f"{self.nact} with maximum length {self.nact * self.maxmcmc}"
             )
-            from .dynesty_utils import bilby_act_rwalk
+            from .dynesty_utils import ACTTrackingRWalk
 
-            dynesty.nestedsamplers._SAMPLING["act-walk"] = bilby_act_rwalk
+            ACTTrackingRWalk._cache = list()
+            dynesty.nestedsamplers._SAMPLING["act-walk"] = ACTTrackingRWalk()
         elif sample == "rwalk_dynesty":
             sample = sample.strip("_dynesty")
             self.kwargs["sample"] = sample
diff --git a/bilby/core/sampler/dynesty_utils.py b/bilby/core/sampler/dynesty_utils.py
index a13647427..1d9599112 100644
--- a/bilby/core/sampler/dynesty_utils.py
+++ b/bilby/core/sampler/dynesty_utils.py
@@ -181,8 +181,11 @@ class ACTTrackingRWalk:
     parallel process.
     """
 
+    # the _cache is a class level variable to avoid being forgotten at every
+    # iteration when using multiprocessing
+    _cache = list()
+
     def __init__(self):
-        self._cache = list()
         self.act = 1
         self.thin = getattr(_SamplingContainer, "nact", 2)
         self.maxmcmc = getattr(_SamplingContainer, "maxmcmc", 5000) * 50
@@ -367,10 +370,13 @@ class AcceptanceTrackingRWalk:
     corresponds to specifying :code:`sample="rwalk"`
     """
 
+    # to retain state between calls to pool.Map, this needs to be a class
+    # level attribute
+    old_act = None
+
     def __init__(self, old_act=None):
         self.maxmcmc = getattr(_SamplingContainer, "maxmcmc", 5000)
         self.nact = getattr(_SamplingContainer, "nact", 40)
-        self.old_act = old_act
 
     def __call__(self, args):
         rstate = get_random_generator(args.rseed)
@@ -437,7 +443,7 @@ class AcceptanceTrackingRWalk:
             logl = args.loglikelihood(v)
 
         blob = {"accept": accept, "reject": reject + nfail, "scale": args.scale}
-        self.old_act = act
+        AcceptanceTrackingRWalk.old_act = act
 
         ncall = accept + reject
         return u, v, logl, ncall, blob
@@ -708,7 +714,3 @@ def apply_boundaries_(u_prop, periodic, reflective):
 
 
 proposal_funcs = dict(diff=propose_differetial_evolution, volumetric=propose_volumetric)
-
-fixed_length_rwalk_bilby = FixedRWalk()
-bilby_act_rwalk = ACTTrackingRWalk()
-sample_rwalk_bilby = AcceptanceTrackingRWalk()
diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py
index fdd1edcff..0abb6907a 100644
--- a/test/core/sampler/dynesty_test.py
+++ b/test/core/sampler/dynesty_test.py
@@ -242,6 +242,7 @@ class TestEstimateNMCMC(unittest.TestCase):
         safety * (2 / accept_ratio - 1)
         """
         sampler = dynesty_utils.AcceptanceTrackingRWalk()
+        dynesty_utils.AcceptanceTrackingRWalk.old_act = None
         for _ in range(10):
             accept_ratio = np.random.uniform()
             safety = np.random.randint(2, 8)
-- 
GitLab