Skip to content
Snippets Groups Projects
Commit 38c8bc0c authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Update ACT method and other changes

- Use CPNest ACT estimation
- Change scaling
- Add chains starting after accept > 0
parent 50543ac9
No related branches found
No related tags found
1 merge request!643Adding an ACT estimate for the walks
Pipeline #90931 passed
......@@ -6,7 +6,6 @@ import sys
import pickle
import signal
import emcee
import tqdm
import matplotlib.pyplot as plt
import numpy as np
......@@ -571,7 +570,7 @@ def sample_rwalk_bilby(args):
max_walk_warning = True
drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan
while len(u_list) < nact * act or accept == 0 or len(u_list) < walks:
while len(u_list) < nact * act:
if scale == 0.:
raise RuntimeError("The random walk sampling is stuck! "
......@@ -593,7 +592,8 @@ def sample_rwalk_bilby(args):
drhat /= linalg.norm(drhat)
# Scale based on dimensionality.
dr = drhat * rstate.rand()**(1. / n)
# dr = drhat * rstate.rand()**(1. / n) # CHANGED FROM DYNESTY 1.0
dr = drhat * rstate.rand(n)
# Transform to proposal distribution.
du = np.dot(axes, dr)
......@@ -611,9 +611,11 @@ def sample_rwalk_bilby(args):
pass
else:
nfail += 1
u_list.append(u_list[-1])
v_list.append(v_list[-1])
logl_list.append(logl_list[-1])
# Only start appending to the chain once a single jump is made
if accept > 0:
u_list.append(u_list[-1])
v_list.append(v_list[-1])
logl_list.append(logl_list[-1])
continue
# Check if we're stuck generating bad numbers.
......@@ -637,12 +639,18 @@ def sample_rwalk_bilby(args):
logl_list.append(logl)
else:
reject += 1
u_list.append(u_list[-1])
v_list.append(v_list[-1])
logl_list.append(logl_list[-1])
# Only start appending to the chain once a single jump is made
if accept > 0:
u_list.append(u_list[-1])
v_list.append(v_list[-1])
logl_list.append(logl_list[-1])
if accept > walks:
act = np.max([1, autocorr_new(np.array(u_list).T)])
# If we've taken the minimum number of steps, calculate the ACT
if accept + reject > walks:
act = estimate_nmcmc(
accept / (accept + reject + nfail), walks, maxmcmc)
# If we've taken too many likelihood evaluations then break
if accept + reject > maxmcmc and accept > 0:
if max_walk_warning:
warnings.warn(
......@@ -662,11 +670,16 @@ def sample_rwalk_bilby(args):
"scale-factor accordingly.")
# If the act is finite, pick randomly from within the chain
if np.isfinite(act):
if np.isfinite(act) and act < len(u_list):
idx = np.random.randint(act, len(u_list))
u = u_list[idx]
v = v_list[idx]
logl = logl_list[idx]
elif len(u_list) == 1:
logger.warning("Returning the only point in the chain")
u = u_list[-1]
v = v_list[-1]
logl = logl_list[-1]
else:
idx = np.random.randint(int(len(u_list) / 2), len(u_list))
logger.warning("Returning random point in second half of the chain")
......@@ -680,17 +693,43 @@ def sample_rwalk_bilby(args):
return u, v, logl, ncall, blob
def autocorr_new(y, c=10.0):
f = np.zeros(y.shape[1])
for yy in y:
f += emcee.autocorr.function_1d(yy)
f /= len(y)
taus = 2.0 * np.cumsum(f) - 1.0
window = emcee.autocorr.auto_window(taus, c)
act = taus[window]
if np.isnan(act):
return np.inf
return act
def estimate_nmcmc(accept_ratio, minmcmc, maxmcmc, safety=5, tau=None):
""" Estimate autocorrelation length of chain using acceptance fraction
Using ACL = (2/acc) - 1 multiplied by a safety margin. Code adapated from
CPNest:
- https://github.com/johnveitch/cpnest/blob/master/cpnest/sampler.py
- http://github.com/farr/Ensemble.jl
Parameters
----------
accept_ratio: float [0, 1]
Ratio of the number of accepted points to the total number of points
minmcmc: int
The minimum length of the MCMC chain to use
maxmcmc: int
The maximum length of the MCMC chain to use
safety: int
A safety factor applied in the calculation
tau: int (optional)
The ACT, if given, otherwise estimated.
"""
if tau is None:
tau = maxmcmc / safety
if accept_ratio == 0.0:
Nmcmc_exact = (1. + 1. / tau) * minmcmc
else:
Nmcmc_exact = (
(1. - 1. / tau) * minmcmc +
(safety / tau) * (2. / accept_ratio - 1.)
)
Nmcmc_exact = float(min(Nmcmc_exact, maxmcmc))
Nmcmc = max(safety, int(Nmcmc_exact))
return Nmcmc
class DynestySetupError(Exception):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment