Add a mean-log-likelihood method to improve the ACT estimation

Merged Gregory Ashton requested to merge add-mean-log-like-to-ptemcee into master
@@ -12,6 +12,7 @@ from collections import namedtuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.signal
from ..utils import logger, check_directory_exists_and_if_not_mkdir
from .base_sampler import SamplerError, MCMCSampler
@@ -23,10 +24,12 @@ ConvergenceInputs = namedtuple(
"autocorr_c",
"autocorr_tol",
"autocorr_tau",
"gradient_tau",
"Q_tol",
"safety",
"burn_in_nact",
"burn_in_fixed_discard",
"thin_by_nact",
"frac_threshold",
"nsamples",
"ignore_keys_for_tau",
"min_tau",
@@ -53,6 +56,8 @@ class Ptemcee(MCMCSampler):
The number of burn-in autocorrelation times to discard and the thin-by
factor. Increasing burn_in_nact increases the time required for burn-in.
Increasing thin_by_nact increases the time required to obtain nsamples.
burn_in_fixed_discard: int (1000)
A fixed number of initial steps to discard as burn-in, applied in
addition to the burn_in_nact-based discard.
autocorr_tol: int, (50)
The minimum number of autocorrelation times needed to trust the
estimate of the autocorrelation time.
@@ -62,14 +67,15 @@ class Ptemcee(MCMCSampler):
A multiplicative factor for the estimated autocorrelation. Useful for
cases where non-convergence can be observed by eye but the automated
tools are failing.
autocorr_tau:
autocorr_tau: int, (1)
The number of autocorrelation times to use in assessing if the
autocorrelation time is stable.
frac_threshold: float, (0.01)
The maximum fractional change in the autocorrelation for the last
autocorr_tau steps. If the fractional change exceeds this value,
sampling will continue until the estimate of the autocorrelation time
can be trusted.
gradient_tau: float, (0.1)
The maximum (smoothed) local gradient of the ACT estimate to allow.
This ensures the ACT estimate is stable before finishing sampling.
Q_tol: float (1.02)
The maximum between-chain to within-chain variance ratio allowed (akin
to the Gelman-Rubin statistic).
min_tau: int, (1)
A minimum tau (autocorrelation time) to accept.
check_point_deltaT: float, (600)
@@ -79,7 +85,7 @@ class Ptemcee(MCMCSampler):
exit_code: int, (77)
The code on which the sampler exits.
store_walkers: bool (False)
If true, store the unthinned, unburnt chaines in the result. Note, this
If true, store the unthinned, unburnt chains in the result. Note, this
is not recommended for cases where tau is large.
ignore_keys_for_tau: str
A pattern used to ignore keys in estimating the autocorrelation time.
@@ -90,6 +96,12 @@ class Ptemcee(MCMCSampler):
The walkers are then initialized from the range of values obtained.
If a list, for the keys in the list the optimization step is applied,
otherwise the initial points are drawn from the prior.
niterations_per_check: int (5)
The number of iteration steps to take before checking the ACT. This
effectively pre-thins the chains. Larger values reduce the per-evaluation
overhead through improved efficiency, but if made too large the
pre-thinning may be overly aggressive, effectively wasting compute time.
If you see tau=1, then niterations_per_check is likely too large.
Other Parameters
@@ -98,7 +110,7 @@ class Ptemcee(MCMCSampler):
The number of walkers
nsteps: int, (100)
The number of steps to take
ntemps: int (2)
ntemps: int (10)
The number of temperatures used by ptemcee
Tmax: float
The maximum temperature
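For context, a minimal sketch of how these convergence kwargs might be passed through bilby's standard `run_sampler` interface; the toy Gaussian model, prior, and parameter values are illustrative only, not part of this MR:

```python
import numpy as np
import bilby

# Toy model: constant signal with unit-variance Gaussian noise
def model(x, mu):
    return mu * np.ones_like(x)

x = np.arange(100)
y = np.random.normal(3, 1, 100)
likelihood = bilby.core.likelihood.GaussianLikelihood(x, y, model, sigma=1)
priors = dict(mu=bilby.core.prior.Uniform(0, 5, "mu"))

result = bilby.run_sampler(
    likelihood=likelihood,
    priors=priors,
    sampler="ptemcee",
    nsamples=5000,               # stop after this many effective samples
    ntemps=10,                   # new default temperature-ladder size
    nwalkers=100,                # new default walkers per temperature
    burn_in_fixed_discard=1000,  # fixed initial discard added in this MR
    gradient_tau=0.1,            # max smoothed gradient of the ACT estimate
    Q_tol=1.02,                  # between- to within-chain tolerance
)
```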
@@ -107,15 +119,15 @@ class Ptemcee(MCMCSampler):
# Arguments used by ptemcee
default_kwargs = dict(
ntemps=20,
nwalkers=200,
ntemps=10,
nwalkers=100,
Tmax=None,
betas=None,
a=2.0,
adaptation_lag=10000,
adaptation_time=100,
random=None,
adapt=True,
adapt=False,
swap_ratios=False,
)
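Since `adapt` now defaults to `False`, the inverse-temperature ladder stays fixed for the whole run. As a sketch (pure numpy; not ptemcee's internal ladder construction, which it applies when `betas=None`), a geometric ladder consistent with `ntemps` and `Tmax` could look like:

```python
import numpy as np

# Geometric inverse-temperature ladder spanning beta=1 down to 1/Tmax
ntemps, Tmax = 10, 1e3
betas = np.logspace(0, -np.log10(Tmax), ntemps)
print(betas)  # 1.0 down to 1e-3
```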
@@ -130,13 +142,15 @@ class Ptemcee(MCMCSampler):
skip_import_verification=False,
resume=True,
nsamples=5000,
burn_in_nact=10,
burn_in_nact=50,
burn_in_fixed_discard=1000,
thin_by_nact=0.5,
autocorr_tol=50,
autocorr_c=5,
safety=1,
autocorr_tau=5,
frac_threshold=0.01,
autocorr_tau=1,
gradient_tau=0.1,
Q_tol=1.02,
min_tau=1,
check_point_deltaT=600,
threads=1,
@@ -145,7 +159,7 @@ class Ptemcee(MCMCSampler):
store_walkers=False,
ignore_keys_for_tau=None,
pos0="prior",
niterations_per_check=10,
niterations_per_check=5,
**kwargs
):
super(Ptemcee, self).__init__(
@@ -184,14 +198,17 @@ class Ptemcee(MCMCSampler):
autocorr_tau=autocorr_tau,
safety=safety,
burn_in_nact=burn_in_nact,
burn_in_fixed_discard=burn_in_fixed_discard,
thin_by_nact=thin_by_nact,
frac_threshold=frac_threshold,
gradient_tau=gradient_tau,
Q_tol=Q_tol,
nsamples=nsamples,
ignore_keys_for_tau=ignore_keys_for_tau,
min_tau=min_tau,
niterations_per_check=niterations_per_check,
)
self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)
logger.info("Using convergence inputs: {}".format(self.convergence_inputs))
# Check if threads was given as an equivalent arg
if threads == 1:
@@ -340,6 +357,7 @@ class Ptemcee(MCMCSampler):
self.sampler._betas = np.array(self.beta_list[-1])
self.tau_list = data["tau_list"]
self.tau_list_n = data["tau_list_n"]
self.Q_list = data["Q_list"]
self.time_per_check = data["time_per_check"]
# Initialize the pool
@@ -380,6 +398,7 @@ class Ptemcee(MCMCSampler):
self.beta_list = []
self.tau_list = []
self.tau_list_n = []
self.Q_list = []
self.time_per_check = []
self.pos0 = self.get_pos0()
@@ -437,6 +456,8 @@ class Ptemcee(MCMCSampler):
self.pos0 = pos0
self.chain_array[:, self.iteration, :] = pos0[0, :, :]
self.log_likelihood_array[:, :, self.iteration] = log_likelihood
self.mean_log_likelihood = np.mean(
    self.log_likelihood_array[:, :, :self.iteration], axis=1)
# Calculate time per iteration
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
@@ -444,6 +465,8 @@ class Ptemcee(MCMCSampler):
self.iteration += 1
min_check_iteration = get_min_iteration_to_check(self.mean_log_likelihood)
(
stop,
self.nburn,
@@ -451,7 +474,8 @@ class Ptemcee(MCMCSampler):
self.tau_int,
self.nsamples_effective,
) = check_iteration(
self.chain_array[:, :self.iteration + 1, :],
self.iteration,
self.chain_array[:, min_check_iteration:self.iteration + 1, :],
sampler,
self.convergence_inputs,
self.search_parameter_keys,
@@ -459,6 +483,7 @@ class Ptemcee(MCMCSampler):
self.beta_list,
self.tau_list,
self.tau_list_n,
self.Q_list,
)
if stop:
@@ -534,6 +559,7 @@ class Ptemcee(MCMCSampler):
self.beta_list,
self.tau_list,
self.tau_list_n,
self.Q_list,
self.time_per_check,
)
@@ -546,6 +572,7 @@ class Ptemcee(MCMCSampler):
self.search_parameter_keys,
self.outdir,
self.label,
self.convergence_inputs.burn_in_fixed_discard
)
# Generate the tau plot diagnostic
@@ -559,8 +586,28 @@ class Ptemcee(MCMCSampler):
self.convergence_inputs.autocorr_tau,
)
plot_mean_log_likelihood(
self.mean_log_likelihood,
self.outdir,
self.label
)
def get_min_iteration_to_check(mean_log_likelihood, frac=0.1):
    """ Approximate the end of the initial likelihood climb.

    Returns the first iteration at which the zero-temperature mean
    log-likelihood is within a fraction frac of its maximum; earlier
    steps are excluded from the ACT estimate.
    """
    nsteps = mean_log_likelihood.shape[1]
    if nsteps > 10:
        zero_chain_mean_log_likelihood = mean_log_likelihood[0, :]
        maxl = np.max(zero_chain_mean_log_likelihood)
        fracdiff = (maxl - zero_chain_mean_log_likelihood) / maxl
        idxs = fracdiff < frac
        if np.sum(idxs) > 0:
            min_it = np.min(np.arange(len(idxs))[idxs])
            return min_it
    return 0
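A quick illustration of `get_min_iteration_to_check` on synthetic data (all values invented for the example, and assuming the function above is in scope): the early climb of the zero-temperature mean log-likelihood is excluded, and checking starts once the curve is within `frac` of its maximum:

```python
import numpy as np

# Synthetic mean log-likelihood, shape (ntemps=2, nsteps=100): a rising
# transient over the first 20 steps followed by a plateau
rise = np.linspace(0, 1000, 20)
plateau = 1000 + np.random.randn(80)
mean_log_likelihood = np.vstack([
    np.concatenate([rise, plateau]),          # zero-temperature chain
    np.concatenate([rise - 50, plateau - 50]),
])

# Steps before the zero-temperature curve reaches ~90% of its maximum
# are discarded from the ACT check
print(get_min_iteration_to_check(mean_log_likelihood))  # ~18, end of climb
```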
def check_iteration(
iteration,
samples,
sampler,
convergence_inputs,
@@ -569,6 +616,7 @@ def check_iteration(
beta_list,
tau_list,
tau_list_n,
Q_list,
):
""" Per-iteration logic to calculate the convergence check
@@ -597,21 +645,26 @@ def check_iteration(
import emcee
ci = convergence_inputs
nwalkers, iteration, ndim = samples.shape
nwalkers, nsteps, ndim = samples.shape
# Compute ACT tau for 0-temperature chains
tau_array = np.zeros((nwalkers, ndim))
for ii in range(nwalkers):
for jj, key in enumerate(search_parameter_keys):
if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key:
continue
try:
tau_array[ii, jj] = emcee.autocorr.integrated_time(
samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0]
except emcee.autocorr.AutocorrError:
tau_array[ii, jj] = np.inf
if nsteps > ci.burn_in_fixed_discard:
for ii in range(nwalkers):
for jj, key in enumerate(search_parameter_keys):
if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key:
continue
try:
# Discard the fixed burn-in steps before estimating the ACT
samples_to_calc = samples[ii, ci.burn_in_fixed_discard:, jj]
tau_array[ii, jj] = emcee.autocorr.integrated_time(
samples_to_calc, c=ci.autocorr_c, tol=0)[0]
except emcee.autocorr.AutocorrError:
tau_array[ii, jj] = np.inf
else:
tau_array += np.inf
# Maximum over paramters, mean over walkers
# Maximum over parameters, mean over walkers
tau = np.max(np.mean(tau_array, axis=0))
# Apply multiplicative safety factor
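For reference, a small self-contained sketch of the emcee ACT estimator used above, applied to an AR(1) series with a known autocorrelation time; `tol=0` disables emcee's internal chain-length check, as in the loop above (that check is instead handled via `autocorr_tol`):

```python
import numpy as np
import emcee

# AR(1) series: integrated ACT is approximately (1 + rho) / (1 - rho)
rho, n = 0.9, 50000
x = np.zeros(n)
for i in range(1, n):
    x[i] = rho * x[i - 1] + np.random.randn()

tau = emcee.autocorr.integrated_time(x, c=5, tol=0)[0]
print(tau)  # ~19 for rho = 0.9
```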
@@ -622,37 +675,56 @@ def check_iteration(
tau_list.append(list(np.mean(tau_array, axis=0)))
tau_list_n.append(iteration)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
Q = get_Q_convergence(samples)
Q_list.append(Q)
if np.isnan(tau_int) or np.isinf(tau_int):
if np.isnan(tau) or np.isinf(tau):
print_progress(
iteration, sampler, time_per_check, np.nan, np.nan,
np.nan, np.nan, False, convergence_inputs,
np.nan, np.nan, False, convergence_inputs, np.nan,
)
return False, np.nan, np.nan, np.nan, np.nan
# Convert to an integer
tau_int = int(np.ceil(tau))
# Calculate the effective number of samples available
nburn = int(ci.burn_in_nact * tau_int)
nburn = ci.burn_in_fixed_discard + int(ci.burn_in_nact * tau_int)
thin = int(np.max([1, ci.thin_by_nact * tau_int]))
samples_per_check = nwalkers / thin
nsamples_effective = int(nwalkers * (iteration - nburn) / thin)
# Calculate convergence boolean
converged = ci.nsamples < nsamples_effective
# Calculate fractional change in tau from previous iteration
check_taus = np.array(tau_list[-tau_int * ci.autocorr_tau :])
taus_per_parameter = check_taus[-1, :]
if not np.any(np.isnan(check_taus)):
frac = (taus_per_parameter - check_taus) / taus_per_parameter
max_frac = np.max(frac)
tau_usable = np.all(frac < ci.frac_threshold)
converged = Q < ci.Q_tol and ci.nsamples < nsamples_effective
logger.debug("Convergence: Q<Q_tol={}, nsamples<nsamples_effective={}"
.format(Q < ci.Q_tol, ci.nsamples < nsamples_effective))
# Calculate change in tau from previous iterations
GRAD_WINDOW_LENGTH = 11
nsteps_to_check = ci.autocorr_tau * np.max([2 * GRAD_WINDOW_LENGTH, tau_int])
lower_tau_index = np.max([0, len(tau_list) - nsteps_to_check])
check_taus = np.array(tau_list[lower_tau_index :])
if not np.any(np.isnan(check_taus)) and check_taus.shape[0] > GRAD_WINDOW_LENGTH:
# Estimate the maximum gradient
grad = np.max(scipy.signal.savgol_filter(
check_taus, axis=0, window_length=GRAD_WINDOW_LENGTH, polyorder=2, deriv=1))
if grad < ci.gradient_tau:
logger.debug("tau usable as grad < gradient_tau={}".format(ci.gradient_tau))
tau_usable = True
else:
logger.debug("tau not usable as grad > gradient_tau={}".format(ci.gradient_tau))
tau_usable = False
else:
max_frac = np.nan
logger.debug("ACT is nan")
grad = np.nan
tau_usable = False
if iteration < tau_int * ci.autocorr_tol or tau_int < ci.min_tau:
if nsteps < tau_int * ci.autocorr_tol:
logger.debug("ACT less than autocorr_tol")
tau_usable = False
elif tau_int < ci.min_tau:
logger.debug("ACT less than min_tau")
tau_usable = False
# Print an update on the progress
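To illustrate the new gradient criterion, a sketch (synthetic tau histories with invented numbers) of the Savitzky-Golay derivative used above: a still-rising ACT estimate fails the `gradient_tau` threshold, while a noisy plateau passes:

```python
import numpy as np
import scipy.signal

GRAD_WINDOW_LENGTH = 11
gradient_tau = 0.1

# Two synthetic ACT histories: one still rising, one plateaued with noise
rising = np.linspace(10, 30, 100) + 0.1 * np.random.randn(100)
flat = 30 + 0.1 * np.random.randn(100)

for name, taus in [("rising", rising), ("flat", flat)]:
    # deriv=1 returns the smoothed per-step slope of the tau series
    grad = np.max(scipy.signal.savgol_filter(
        taus[-50:], window_length=GRAD_WINDOW_LENGTH, polyorder=2, deriv=1))
    print(name, grad < gradient_tau)  # rising -> False, flat -> True
```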
@@ -663,14 +735,26 @@ def check_iteration(
nsamples_effective,
samples_per_check,
tau_int,
max_frac,
grad,
tau_usable,
convergence_inputs,
Q
)
stop = converged and tau_usable
return stop, nburn, thin, tau_int, nsamples_effective
def get_Q_convergence(samples):
    """ Gelman-Rubin-like statistic, maximised over dimensions """
    nwalkers, nsteps, ndim = samples.shape
    # Within-chain variance: per-walker variance, averaged over walkers
    W = np.mean(np.var(samples, axis=1), axis=0)
    # Between-chain variance of the per-walker means
    per_walker_mean = np.mean(samples, axis=1)
    mean = np.mean(per_walker_mean, axis=0)
    B = nsteps / (nwalkers - 1.) * np.sum((per_walker_mean - mean)**2, axis=0)
    # Pooled variance estimate and the resulting between/within ratio
    Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B
    Q_per_dim = np.sqrt(Vhat / W)
    return np.max(Q_per_dim)
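A sanity check of `get_Q_convergence` on synthetic chains (shapes and offsets invented for the example, assuming the function above is in scope): well-mixed walkers give Q near 1, while walkers stuck at different offsets push Q well above the `Q_tol=1.02` default:

```python
import numpy as np

# Well-mixed chains: every walker samples the same distribution
mixed = np.random.randn(100, 2000, 3)              # (nwalkers, nsteps, ndim)
print(get_Q_convergence(mixed))                    # ~1.0, passes Q_tol=1.02

# Poorly mixed chains: walkers offset from one another in one dimension
stuck = mixed.copy()
stuck[:, :, 0] += np.linspace(0, 5, 100)[:, None]  # per-walker offsets
print(get_Q_convergence(stuck))                    # >> 1.02, fails
```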
def print_progress(
iteration,
sampler,
@@ -678,17 +762,18 @@ def print_progress(
nsamples_effective,
samples_per_check,
tau_int,
max_frac,
grad,
tau_usable,
convergence_inputs,
Q,
):
# Setup acceptance string
acceptance = sampler.acceptance_fraction[0, :]
acceptance_str = "{:1.2f}->{:1.2f}".format(np.min(acceptance), np.max(acceptance))
acceptance_str = "{:1.2f}-{:1.2f}".format(np.min(acceptance), np.max(acceptance))
# Setup tswap acceptance string
tswap_acceptance_fraction = sampler.tswap_acceptance_fraction
tswap_acceptance_str = "{:1.2f}->{:1.2f}".format(
tswap_acceptance_str = "{:1.2f}-{:1.2f}".format(
np.min(tswap_acceptance_fraction), np.max(tswap_acceptance_fraction)
)
@@ -701,15 +786,18 @@ def print_progress(
sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
if max_frac >= 0:
tau_str = "{}(+{:0.2f})".format(tau_int, max_frac)
if grad < convergence_inputs.gradient_tau:
tau_str = "{}({:0.2f}<{})".format(tau_int, grad, convergence_inputs.gradient_tau)
else:
tau_str = "{}({:0.2f})".format(tau_int, max_frac)
tau_str = "{}({:0.2f}>{})".format(tau_int, grad, convergence_inputs.gradient_tau)
if tau_usable:
tau_str = "={}".format(tau_str)
else:
tau_str = "!{}".format(tau_str)
Q_str = "{:0.2f}".format(Q)
evals_per_check = sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check
ncalls = "{:1.1e}".format(
@@ -718,7 +806,7 @@ def print_progress(
samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)
print(
"{}| {}| nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
"{}|{}|nc:{}|a0:{}|swp:{}|n:{}<{}|t{}|q:{}|{}|{}".format(
iteration,
str(sampling_time).split(".")[0],
ncalls,
@@ -727,6 +815,7 @@ def print_progress(
nsamples_effective,
convergence_inputs.nsamples,
tau_str,
Q_str,
eval_timing,
samp_timing,
),
@@ -750,6 +839,7 @@ def checkpoint(
beta_list,
tau_list,
tau_list_n,
Q_list,
time_per_check,
):
logger.info("Writing checkpoint and diagnostics")
@@ -774,6 +864,7 @@ def checkpoint(
beta_list=beta_list,
tau_list=tau_list,
tau_list_n=tau_list_n,
Q_list=Q_list,
time_per_check=time_per_check,
log_likelihood_array=log_likelihood_array,
chain_array=chain_array,
@@ -786,17 +877,29 @@ def checkpoint(
logger.info("Finished writing checkpoint")
def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label,
nburn_fixed=0):
""" Method to plot the trace of the walkers in an ensemble MCMC plot """
nwalkers, nsteps, ndim = walkers.shape
idxs = np.arange(nsteps)
fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05,)
# Plot the fixed burn-in
if nburn_fixed > 0:
for i, (ax, axh) in enumerate(axes):
ax.plot(
idxs[: nburn_fixed],
walkers[:, : nburn_fixed, i].T,
color="gray",
**scatter_kwargs
)
# Plot the burn-in
for i, (ax, axh) in enumerate(axes):
ax.plot(
idxs[: nburn + 1],
walkers[:, : nburn + 1, i].T,
idxs[nburn_fixed: nburn + 1],
walkers[:, nburn_fixed: nburn + 1, i].T,
color="C1",
**scatter_kwargs
)
@@ -820,19 +923,33 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
def plot_tau(
tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau
tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau,
):
fig, ax = plt.subplots()
for i, key in enumerate(search_parameter_keys):
ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
ax.axvline(tau_list_n[-1] - tau * autocorr_tau)
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
ax.legend()
fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
plt.close(fig)
def plot_mean_log_likelihood(mean_log_likelihood, outdir, label):
    ntemps, nsteps = mean_log_likelihood.shape
    min_check_iteration = get_min_iteration_to_check(mean_log_likelihood)
    fig, ax = plt.subplots()
    idxs = np.arange(nsteps)
    ax.plot(idxs, mean_log_likelihood.T)
    ax.axvline(min_check_iteration)
    fig.savefig("{}/{}_checkpoint_meanloglike.png".format(outdir, label))
    plt.close(fig)
def compute_evidence(sampler, log_likelihood_array, outdir, label, nburn, thin,
iteration, make_plots=True):
""" Computes the evidence using thermodynamic integration """