From e2e043f1af7bf12baba7bc4306082473636122c5 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 11 Nov 2019 22:57:58 -0600
Subject: [PATCH] Adding a custom dynesty sampler

---
 bilby/core/sampler/dynesty.py | 124 +++++++++++++++++++++++++++++++---
 1 file changed, 114 insertions(+), 10 deletions(-)

diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 102436ce4..982f0cba0 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -14,6 +14,11 @@ from pandas import DataFrame
 from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect
 from .base_sampler import Sampler, NestedSampler
 
+from numpy import linalg
+from dynesty.utils import unitcheck
+import warnings
+import math
+
 
 class Dynesty(NestedSampler):
     """
@@ -225,6 +230,10 @@ class Dynesty(NestedSampler):
             self.kwargs['live_points'] = (
                 self.get_initial_points_from_prior(
                     self.kwargs['nlive']))
+
+        dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby
+        dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
+
         self.sampler = dynesty.NestedSampler(
             loglikelihood=self.log_likelihood,
             prior_transform=self.prior_transform,
@@ -510,15 +519,110 @@ class Dynesty(NestedSampler):
         -------
         list: Properly rescaled sampled values
 
-        Notes
-        -----
-        Since dynesty allows periodic parameters to wander outside the unit,
-        here we transform them depending of if they should be periodic or
-        reflective. For reflective boundaries, theta < 0 you shift to |theta|
-        and when theta > 1 you return 2 - theta. For periodic boundaries,
-        if theta < 0, you shift to 1-|theta| and when theta > 1 you shift to
-        |theta| - 1 (i.e. wrap around).
-
         """
-        theta[self._reflective] = reflect(theta[self._reflective])
         return self.priors.rescale(self._search_parameter_keys, theta)
+
+
+def sample_rwalk_bilby(args):
+    """
+    Modified version of dynesty.sampling.sample_rwalk
+
+    """
+
+    # Unzipping.
+    (u, loglstar, axes, scale,
+     prior_transform, loglikelihood, kwargs) = args
+    rstate = np.random
+
+    # Bounds
+    nonbounded = kwargs.get('nonbounded', None)
+    periodic = kwargs.get('periodic', None)
+    reflective = kwargs.get('reflective', None)
+
+    # Setup.
+    n = len(u)
+    walks = kwargs.get('walks', 25)  # number of steps
+    accept = 0
+    reject = 0
+    fail = 0
+    nfail = 0
+    nc = 0
+    ncall = 0
+
+    drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan
+    while nc + nfail < walks or accept == 0:
+        while True:
+
+            # Check scale-factor.
+            if scale == 0.:
+                raise RuntimeError("The random walk sampling is stuck! "
+                                   "Some useful output quantities:\n"
+                                   "u: {0}\n"
+                                   "drhat: {1}\n"
+                                   "dr: {2}\n"
+                                   "du: {3}\n"
+                                   "u_prop: {4}\n"
+                                   "loglstar: {5}\n"
+                                   "logl_prop: {6}\n"
+                                   "axes: {7}\n"
+                                   "scale: {8}."
+                                   .format(u, drhat, dr, du, u_prop,
+                                           loglstar, logl_prop, axes, scale))
+
+            # Propose a direction on the unit n-sphere.
+            drhat = rstate.randn(n)
+            drhat /= linalg.norm(drhat)
+
+            # Scale based on dimensionality.
+            dr = drhat * rstate.rand()**(1. / n)
+
+            # Transform to proposal distribution.
+            du = np.dot(axes, dr)
+            u_prop = u + scale * du
+
+            # Wrap periodic parameters
+            if periodic is not None:
+                u_prop[periodic] = np.mod(u_prop[periodic], 1)
+            # Reflect
+            if reflective is not None:
+                u_prop[reflective] = reflect(u_prop[reflective])
+
+            # Check unit cube constraints.
+            if unitcheck(u_prop, nonbounded):
+                break
+            else:
+                fail += 1
+                nfail += 1
+
+            # Check if we're stuck generating bad numbers.
+            if fail > 100 * walks:
+                warnings.warn("Random number generation appears to be "
+                              "extremely inefficient. Adjusting the "
+                              "scale-factor accordingly.")
+                fail = 0
+                scale *= math.exp(-1. / n)
+
+        # Check proposed point.
+        v_prop = prior_transform(np.array(u_prop))
+        logl_prop = loglikelihood(np.array(v_prop))
+        if logl_prop >= loglstar:
+            u = u_prop
+            v = v_prop
+            logl = logl_prop
+            accept += 1
+        else:
+            reject += 1
+        nc += 1
+        ncall += 1
+
+        # Check if we're stuck generating bad points.
+        if nc > 50 * walks:
+            scale *= math.exp(-1. / n)
+            warnings.warn("Random walk proposals appear to be "
+                          "extremely inefficient. Adjusting the "
+                          "scale-factor accordingly.")
+            nc, accept, reject = 0, 0, 0  # reset values
+
+    blob = {'accept': accept, 'reject': reject, 'fail': nfail, 'scale': scale}
+
+    return u, v, logl, ncall, blob
-- 
GitLab