
Add a mean-log-likelihood method to improve the ACT estimation

Merged: Gregory Ashton requested to merge add-mean-log-like-to-ptemcee into master
All threads resolved!
1 file changed: +59 −18
@@ -12,6 +12,7 @@ from collections import namedtuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
+import scipy.signal
from ..utils import logger
from .base_sampler import SamplerError, MCMCSampler
@@ -23,10 +24,11 @@ ConvergenceInputs = namedtuple(
"autocorr_c",
"autocorr_tol",
"autocorr_tau",
"gradient_tau",
"Q_tol",
"safety",
"burn_in_nact",
"thin_by_nact",
"frac_threshold",
"nsamples",
"ignore_keys_for_tau",
"min_tau",
@@ -65,11 +67,12 @@ class Ptemcee(MCMCSampler):
autocorr_tau:
The number of autocorrelation times to use in assessing if the
autocorrelation time is stable.
-frac_threshold: float, (0.01)
-The maximum fractional change in the autocorrelation for the last
-autocorr_tau steps. If the fractional change exceeds this value,
-sampling will continue until the estimate of the autocorrelation time
-can be trusted.
+gradient_tau: float, (0.1)
+The maximum (smoothed) local gradient of the ACT estimate to allow.
+This ensures the ACT estimate is stable before finishing sampling.
+Q_tol: float, (1.02)
+The maximum allowed between-chain to within-chain variance ratio (akin
+to the Gelman-Rubin statistic).
min_tau: int, (1)
A minimum tau (autocorrelation time) to accept.
check_point_deltaT: float, (600)
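
Read together, the two new criteria gate the stopping decision roughly as sketched below (a paraphrase of the check_iteration logic further down this diff, using its variable names):

converged = (Q < ci.Q_tol) and (ci.nsamples < nsamples_effective)
stop = converged and tau_usable  # tau_usable: smoothed ACT gradient < ci.gradient_tau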
@@ -114,14 +117,14 @@ class Ptemcee(MCMCSampler):
# Arguments used by ptemcee
default_kwargs = dict(
ntemps=10,
-nwalkers=200,
+nwalkers=100,
Tmax=None,
betas=None,
a=2.0,
adaptation_lag=10000,
adaptation_time=100,
random=None,
-adapt=True,
+adapt=False,
swap_ratios=False,
)
@@ -137,12 +140,13 @@ class Ptemcee(MCMCSampler):
resume=True,
nsamples=5000,
burn_in_nact=50,
-thin_by_nact=1,
+thin_by_nact=0.5,
autocorr_tol=50,
autocorr_c=5,
safety=1,
autocorr_tau=50,
-frac_threshold=0.01,
+gradient_tau=0.1,
+Q_tol=1.02,
min_tau=1,
check_point_deltaT=600,
threads=1,
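
For context, a hedged sketch of how these defaults and the new convergence settings would be passed through bilby's standard run_sampler interface; the likelihood and priors objects are assumed to exist and are not part of this MR:

import bilby

# likelihood and priors built as in any bilby analysis (construction elided)
result = bilby.run_sampler(
    likelihood=likelihood,
    priors=priors,
    sampler="ptemcee",
    ntemps=10,
    nwalkers=100,      # new default in this MR (was 200)
    adapt=False,       # temperature adaptation now off by default
    nsamples=5000,
    thin_by_nact=0.5,  # new default (was 1)
    gradient_tau=0.1,  # replaces frac_threshold
    Q_tol=1.02,        # new Gelman-Rubin-style threshold
)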
@@ -191,7 +195,8 @@ class Ptemcee(MCMCSampler):
safety=safety,
burn_in_nact=burn_in_nact,
thin_by_nact=thin_by_nact,
-frac_threshold=frac_threshold,
+gradient_tau=gradient_tau,
+Q_tol=Q_tol,
nsamples=nsamples,
ignore_keys_for_tau=ignore_keys_for_tau,
min_tau=min_tau,
@@ -347,6 +352,7 @@ class Ptemcee(MCMCSampler):
self.sampler._betas = np.array(self.beta_list[-1])
self.tau_list = data["tau_list"]
self.tau_list_n = data["tau_list_n"]
+self.Q_list = data["Q_list"]
self.time_per_check = data["time_per_check"]
# Initialize the pool
@@ -387,6 +393,7 @@ class Ptemcee(MCMCSampler):
self.beta_list = []
self.tau_list = []
self.tau_list_n = []
+self.Q_list = []
self.time_per_check = []
self.pos0 = self.get_pos0()
@@ -471,6 +478,7 @@ class Ptemcee(MCMCSampler):
self.beta_list,
self.tau_list,
self.tau_list_n,
+self.Q_list,
)
if stop:
@@ -545,6 +553,7 @@ class Ptemcee(MCMCSampler):
self.beta_list,
self.tau_list,
self.tau_list_n,
+self.Q_list,
self.time_per_check,
)
@@ -600,6 +609,7 @@ def check_iteration(
beta_list,
tau_list,
tau_list_n,
+Q_list,
):
""" Per-iteration logic to calculate the convergence check
@@ -653,6 +663,9 @@ def check_iteration(
tau_list.append(list(np.mean(tau_array, axis=0)))
tau_list_n.append(iteration)
+Q = get_Q_convergence(samples)
+Q_list.append(Q)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
@@ -670,13 +683,23 @@ def check_iteration(
nsamples_effective = int(nwalkers * (iteration - nburn) / thin)
# Calculate convergence boolean
-converged = ci.nsamples < nsamples_effective
+converged = Q < ci.Q_tol and ci.nsamples < nsamples_effective
# Calculate change in tau from previous iterations
-check_taus = np.array(tau_list[-tau_int * ci.autocorr_tau :])
-if not np.any(np.isnan(check_taus)) and check_taus.shape[0] > 2:
-    grad = np.max(np.gradient(check_taus, axis=0))
-    tau_usable = grad < ci.frac_threshold
+lower_tau_index = np.max([0, len(tau_list) - tau_int * ci.autocorr_tau])
+check_taus = np.array(tau_list[lower_tau_index:])
+GRAD_WINDOW_LENGTH = 11
+if not np.any(np.isnan(check_taus)) and check_taus.shape[0] > GRAD_WINDOW_LENGTH:
+    # Estimate the maximum gradient
+    grad = np.max(scipy.signal.savgol_filter(
+        check_taus, axis=0, window_length=GRAD_WINDOW_LENGTH, polyorder=2, deriv=1))
+    if grad < ci.gradient_tau:
+        logger.debug("tau usable as grad < gradient_tau={}".format(ci.gradient_tau))
+        tau_usable = True
+    else:
+        logger.debug("tau not usable as grad > gradient_tau={}".format(ci.gradient_tau))
+        tau_usable = False
else:
    logger.debug("ACT is nan")
    grad = np.nan
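
As a standalone illustration (not part of the diff), here is how the smoothed-gradient check behaves on a synthetic ACT series that grows and then flattens; the window length, polyorder, and 0.1 threshold mirror the values above:

import numpy as np
import scipy.signal

# Synthetic ACT estimates: rise towards 10, then plateau (one column per dimension)
n = np.arange(400)
taus = 10 * (1 - np.exp(-n / 50.0))
check_taus = taus[-50:, None]  # recent window, shape (50, 1)

GRAD_WINDOW_LENGTH = 11
if check_taus.shape[0] > GRAD_WINDOW_LENGTH:
    # deriv=1 returns the first derivative of the local quadratic (polyorder=2) fit
    grad = np.max(scipy.signal.savgol_filter(
        check_taus, axis=0, window_length=GRAD_WINDOW_LENGTH, polyorder=2, deriv=1))
    print(grad < 0.1)  # True: the plateau's smoothed gradient is below gradient_tau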
@@ -700,11 +723,23 @@ def check_iteration(
grad,
tau_usable,
convergence_inputs,
+Q
)
stop = converged and tau_usable
return stop, nburn, thin, tau_int, nsamples_effective
+def get_Q_convergence(samples):
+    nwalkers, nsteps, ndim = samples.shape
+    W = np.mean(np.var(samples, axis=1), axis=0)
+    per_walker_mean = np.mean(samples, axis=1)
+    mean = np.mean(per_walker_mean, axis=0)
+    B = nsteps / (nwalkers - 1.) * np.sum((per_walker_mean - mean)**2, axis=0)
+    Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B
+    Q_per_dim = np.sqrt(Vhat / W)
+    return np.max(Q_per_dim)
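
A quick sanity check of get_Q_convergence (illustrative only, assuming the function above is in scope): independent, identically distributed walkers give Q close to 1, while giving each walker its own offset inflates the between-chain variance B and pushes Q well above a tolerance like Q_tol=1.02:

import numpy as np

rng = np.random.default_rng(42)
nwalkers, nsteps, ndim = 100, 500, 2

# Well mixed: every walker samples the same distribution, so Q is close to 1
samples = rng.normal(size=(nwalkers, nsteps, ndim))
print(get_Q_convergence(samples))  # close to 1.0, passes Q < Q_tol

# Poorly mixed: per-walker offsets dominate the between-chain variance B
offsets = rng.normal(scale=2.0, size=(nwalkers, 1, ndim))
print(get_Q_convergence(samples + offsets))  # well above Q_tol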
def print_progress(
iteration,
sampler,
@@ -715,6 +750,7 @@ def print_progress(
grad,
tau_usable,
convergence_inputs,
+Q,
):
# Setup acceptance string
acceptance = sampler.acceptance_fraction[0, :]
@@ -739,11 +775,14 @@ def print_progress(
tau_str = "{}(+{:0.2f})".format(tau_int, grad)
else:
tau_str = "{}({:0.2f})".format(tau_int, grad)
if tau_usable:
tau_str = "={}".format(tau_str)
else:
tau_str = "!{}".format(tau_str)
+Q_str = "{:0.2f}".format(Q)
evals_per_check = sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check
ncalls = "{:1.1e}".format(
@@ -752,7 +791,7 @@ def print_progress(
samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)
print(
"{}| {}| nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
"{}|{}|nc:{}|a0:{}|swp:{}|n:{}<{}|t{}|q:{}|{}|{}".format(
iteration,
str(sampling_time).split(".")[0],
ncalls,
@@ -761,6 +800,7 @@ def print_progress(
nsamples_effective,
convergence_inputs.nsamples,
tau_str,
+Q_str,
eval_timing,
samp_timing,
),
@@ -784,6 +824,7 @@ def checkpoint(
beta_list,
tau_list,
tau_list_n,
+Q_list,
time_per_check,
):
logger.info("Writing checkpoint and diagnostics")
@@ -808,6 +849,7 @@ def checkpoint(
beta_list=beta_list,
tau_list=tau_list,
tau_list_n=tau_list_n,
+Q_list=Q_list,
time_per_check=time_per_check,
log_likelihood_array=log_likelihood_array,
chain_array=chain_array,
@@ -859,7 +901,6 @@ def plot_tau(
fig, ax = plt.subplots()
for i, key in enumerate(search_parameter_keys):
ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
-ax.axvline(tau_list_n[-1] - tau * autocorr_tau)
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
ax.legend()