From 9c9b7096c87480667d5c3328daaf3322bfbddf64 Mon Sep 17 00:00:00 2001
From: RorySmith <rory.smith@caltech.edu>
Date: Mon, 4 May 2020 09:25:49 +1000
Subject: [PATCH] reverted to old rwalk

---
 src/analysis.py | 53 +++++++++++++++++++++++++++----------------------
 1 file changed, 29 insertions(+), 24 deletions(-)

diff --git a/src/analysis.py b/src/analysis.py
index b17d20e..75f881d 100644
--- a/src/analysis.py
+++ b/src/analysis.py
@@ -43,8 +43,6 @@ logger = bilby.core.utils.logger
 def main():
     """ Do nothing function to play nicely with MPI """
     pass
-
-
 def sample_rwalk_parallel_with_act(args):
     """ A dynesty sampling method optimised for parallel_bilby
 
@@ -59,6 +57,7 @@ def sample_rwalk_parallel_with_act(args):
     reflective = kwargs.get("reflective", None)
 
     # Setup.
+
     n = len(u)
     walks = kwargs.get("walks", 50)  # minimum number of steps
     maxmcmc = kwargs.get("maxmcmc", 10000)  # maximum number of steps
@@ -73,7 +72,9 @@ def sample_rwalk_parallel_with_act(args):
     logl_list = []
 
     drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan
-    while len(u_list) < nact * act:
+    i = 0
+    while i < nact * act:
+        i += 1
         # Propose a direction on the unit n-sphere.
         drhat = rstate.randn(n)
         drhat /= linalg.norm(drhat)
@@ -98,6 +99,7 @@ def sample_rwalk_parallel_with_act(args):
         else:
             nfail += 1
             if accept > 0:
+
                 u_list.append(u_list[-1])
                 v_list.append(v_list[-1])
                 logl_list.append(logl_list[-1])
@@ -132,35 +134,38 @@ def sample_rwalk_parallel_with_act(args):
 
         # If we've taken too many likelihood evaluations then break
         if accept + reject > maxmcmc:
-            logger.warning(
-                "Hit maximum number of walks {} with accept={}, reject={}, "
-                "nfail={}, and act={}. Try increasing maxmcmc".format(
-                    maxmcmc, accept, reject, nfail, act
-                )
-            )
+            #logger.warning(
+            #    "Hit maximum number of walks {} with accept={}, reject={}, "
+            #    "nfail={}, and act={}. Try increasing maxmcmc".format(
+            #        maxmcmc, accept, reject, nfail, act
+            #    )
+            #)
             break
 
-    # If the act is finite, pick randomly from within the chain
-    factor = 0.1
-    if len(u_list) == 0:
-        logger.warning("No accepted points: returning -inf")
-        u = u
-        v = prior_transform(u)
-        logl = -np.inf
-    elif np.isfinite(act) and int(factor * nact * act) < len(u_list):
-        idx = np.random.randint(int(factor * nact * act), len(u_list))
+        # If the act is finite, pick randomly from within the chain
+    if np.isfinite(act) and int(0.5 * nact * act) < len(u_list):
+        idx = np.random.randint(int(0.5 * nact * act), len(u_list))
         u = u_list[idx]
         v = v_list[idx]
         logl = logl_list[idx]
-    else:
-        logger.warning(
-            "len(u_list)={}<{}: returning the last point in the chain".format(
-                len(u_list), int(factor * nact * act)
-            )
-        )
+    elif len(u_list) <= 2 and len(u_list) > 0:
+        #logger.warning("Returning the only point in the chain")
         u = u_list[-1]
         v = v_list[-1]
         logl = logl_list[-1]
+    elif len(u_list) == 0:
+        #logger.warning("No accepted points: returning a random draw")
+        u = np.random.uniform(size=du.shape[0])
+        v = prior_transform(u)
+        logl = loglikelihood(v)
+    else:
+        idx = np.random.randint(int(len(u_list) / 2), len(u_list))
+        #logger.warning("Returning random point in second half of the chain")
+        u = u_list[idx]
+        v = v_list[idx]
+
+
+        logl = logl_list[idx]
 
     blob = {"accept": accept, "reject": reject, "fail": nfail, "scale": scale}
 
-- 
GitLab