From e2e043f1af7bf12baba7bc4306082473636122c5 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Mon, 11 Nov 2019 22:57:58 -0600 Subject: [PATCH] Adding a custom dynesty sampler --- bilby/core/sampler/dynesty.py | 124 +++++++++++++++++++++++++++++++--- 1 file changed, 114 insertions(+), 10 deletions(-) diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 102436ce4..982f0cba0 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -14,6 +14,11 @@ from pandas import DataFrame from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect from .base_sampler import Sampler, NestedSampler +from numpy import linalg +from dynesty.utils import unitcheck +import warnings +import math + class Dynesty(NestedSampler): """ @@ -225,6 +230,10 @@ class Dynesty(NestedSampler): self.kwargs['live_points'] = ( self.get_initial_points_from_prior( self.kwargs['nlive'])) + + dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby + dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby + self.sampler = dynesty.NestedSampler( loglikelihood=self.log_likelihood, prior_transform=self.prior_transform, @@ -510,15 +519,110 @@ class Dynesty(NestedSampler): ------- list: Properly rescaled sampled values - Notes - ----- - Since dynesty allows periodic parameters to wander outside the unit, - here we transform them depending of if they should be periodic or - reflective. For reflective boundaries, theta < 0 you shift to |theta| - and when theta > 1 you return 2 - theta. For periodic boundaries, - if theta < 0, you shift to 1-|theta| and when theta > 1 you shift to - |theta| - 1 (i.e. wrap around). - """ - theta[self._reflective] = reflect(theta[self._reflective]) return self.priors.rescale(self._search_parameter_keys, theta) + + +def sample_rwalk_bilby(args): + """ + Modified version of dynesty.sampling.sample_rwalk + + """ + + # Unzipping. + (u, loglstar, axes, scale, + prior_transform, loglikelihood, kwargs) = args + rstate = np.random + + # Bounds + nonbounded = kwargs.get('nonbounded', None) + periodic = kwargs.get('periodic', None) + reflective = kwargs.get('reflective', None) + + # Setup. + n = len(u) + walks = kwargs.get('walks', 25) # number of steps + accept = 0 + reject = 0 + fail = 0 + nfail = 0 + nc = 0 + ncall = 0 + + drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan + while nc + nfail < walks or accept == 0: + while True: + + # Check scale-factor. + if scale == 0.: + raise RuntimeError("The random walk sampling is stuck! " + "Some useful output quantities:\n" + "u: {0}\n" + "drhat: {1}\n" + "dr: {2}\n" + "du: {3}\n" + "u_prop: {4}\n" + "loglstar: {5}\n" + "logl_prop: {6}\n" + "axes: {7}\n" + "scale: {8}." + .format(u, drhat, dr, du, u_prop, + loglstar, logl_prop, axes, scale)) + + # Propose a direction on the unit n-sphere. + drhat = rstate.randn(n) + drhat /= linalg.norm(drhat) + + # Scale based on dimensionality. + dr = drhat * rstate.rand()**(1. / n) + + # Transform to proposal distribution. + du = np.dot(axes, dr) + u_prop = u + scale * du + + # Wrap periodic parameters + if periodic is not None: + u_prop[periodic] = np.mod(u_prop[periodic], 1) + # Reflect + if reflective is not None: + u_prop[reflective] = reflect(u_prop[reflective]) + + # Check unit cube constraints. + if unitcheck(u_prop, nonbounded): + break + else: + fail += 1 + nfail += 1 + + # Check if we're stuck generating bad numbers. + if fail > 100 * walks: + warnings.warn("Random number generation appears to be " + "extremely inefficient. Adjusting the " + "scale-factor accordingly.") + fail = 0 + scale *= math.exp(-1. / n) + + # Check proposed point. + v_prop = prior_transform(np.array(u_prop)) + logl_prop = loglikelihood(np.array(v_prop)) + if logl_prop >= loglstar: + u = u_prop + v = v_prop + logl = logl_prop + accept += 1 + else: + reject += 1 + nc += 1 + ncall += 1 + + # Check if we're stuck generating bad points. + if nc > 50 * walks: + scale *= math.exp(-1. / n) + warnings.warn("Random walk proposals appear to be " + "extremely inefficient. Adjusting the " + "scale-factor accordingly.") + nc, accept, reject = 0, 0, 0 # reset values + + blob = {'accept': accept, 'reject': reject, 'fail': nfail, 'scale': scale} + + return u, v, logl, ncall, blob -- GitLab