From 9c9b7096c87480667d5c3328daaf3322bfbddf64 Mon Sep 17 00:00:00 2001 From: RorySmith <rory.smith@caltech.edu> Date: Mon, 4 May 2020 09:25:49 +1000 Subject: [PATCH] reverted to old rwalk --- src/analysis.py | 53 +++++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/src/analysis.py b/src/analysis.py index b17d20e..75f881d 100644 --- a/src/analysis.py +++ b/src/analysis.py @@ -43,8 +43,6 @@ logger = bilby.core.utils.logger def main(): """ Do nothing function to play nicely with MPI """ pass - - def sample_rwalk_parallel_with_act(args): """ A dynesty sampling method optimised for parallel_bilby @@ -59,6 +57,7 @@ def sample_rwalk_parallel_with_act(args): reflective = kwargs.get("reflective", None) # Setup. + n = len(u) walks = kwargs.get("walks", 50) # minimum number of steps maxmcmc = kwargs.get("maxmcmc", 10000) # maximum number of steps @@ -73,7 +72,9 @@ def sample_rwalk_parallel_with_act(args): logl_list = [] drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan - while len(u_list) < nact * act: + i = 0 + while i < nact * act: + i += 1 # Propose a direction on the unit n-sphere. drhat = rstate.randn(n) drhat /= linalg.norm(drhat) @@ -98,6 +99,7 @@ def sample_rwalk_parallel_with_act(args): else: nfail += 1 if accept > 0: + u_list.append(u_list[-1]) v_list.append(v_list[-1]) logl_list.append(logl_list[-1]) @@ -132,35 +134,38 @@ def sample_rwalk_parallel_with_act(args): # If we've taken too many likelihood evaluations then break if accept + reject > maxmcmc: - logger.warning( - "Hit maximum number of walks {} with accept={}, reject={}, " - "nfail={}, and act={}. Try increasing maxmcmc".format( - maxmcmc, accept, reject, nfail, act - ) - ) + #logger.warning( + # "Hit maximum number of walks {} with accept={}, reject={}, " + # "nfail={}, and act={}. Try increasing maxmcmc".format( + # maxmcmc, accept, reject, nfail, act + # ) + #) break - # If the act is finite, pick randomly from within the chain - factor = 0.1 - if len(u_list) == 0: - logger.warning("No accepted points: returning -inf") - u = u - v = prior_transform(u) - logl = -np.inf - elif np.isfinite(act) and int(factor * nact * act) < len(u_list): - idx = np.random.randint(int(factor * nact * act), len(u_list)) + # If the act is finite, pick randomly from within the chain + if np.isfinite(act) and int(0.5 * nact * act) < len(u_list): + idx = np.random.randint(int(0.5 * nact * act), len(u_list)) u = u_list[idx] v = v_list[idx] logl = logl_list[idx] - else: - logger.warning( - "len(u_list)={}<{}: returning the last point in the chain".format( - len(u_list), int(factor * nact * act) - ) - ) + elif len(u_list) <= 2 and len(u_list) > 0: + #logger.warning("Returning the only point in the chain") u = u_list[-1] v = v_list[-1] logl = logl_list[-1] + elif len(u_list) == 0: + #logger.warning("No accepted points: returning a random draw") + u = np.random.uniform(size=du.shape[0]) + v = prior_transform(u) + logl = loglikelihood(v) + else: + idx = np.random.randint(int(len(u_list) / 2), len(u_list)) + #logger.warning("Returning random point in second half of the chain") + u = u_list[idx] + v = v_list[idx] + + + logl = logl_list[idx] blob = {"accept": accept, "reject": reject, "fail": nfail, "scale": scale} -- GitLab