Skip to content
Snippets Groups Projects

Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
Compare and Show latest version
1 file
+ 220
117
Compare changes
  • Side-by-side
  • Inline
+ 220
117
@@ -13,7 +13,7 @@ import pandas as pd
@@ -13,7 +13,7 @@ import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from ..utils import logger
from ..utils import logger
from .base_sampler import MCMCSampler
from .base_sampler import SamplerError, MCMCSampler
class Ptemcee(MCMCSampler):
class Ptemcee(MCMCSampler):
@@ -41,15 +41,15 @@ class Ptemcee(MCMCSampler):
@@ -41,15 +41,15 @@ class Ptemcee(MCMCSampler):
ntemps=20, nwalkers=200, Tmax=None, betas=None,
ntemps=20, nwalkers=200, Tmax=None, betas=None,
a=2.0, loglargs=[], logpargs=[], loglkwargs={},
a=2.0, loglargs=[], logpargs=[], loglkwargs={},
logpkwargs={}, adaptation_lag=10000, adaptation_time=100, random=None,
logpkwargs={}, adaptation_lag=10000, adaptation_time=100, random=None,
iterations=1000, thin=1, storechain=True, adapt=False,
adapt=False, swap_ratios=False)
swap_ratios=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
use_ratio=False, check_point_plot=True, skip_import_verification=False,
resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
autocorr_c=5, safety=1, ncheck=50, nfrac=5, frac_threshold=0.01,
autocorr_c=5, safety=1, frac_threshold=0.01,
autocorr_tol=50, min_tau=1, check_point_deltaT=600,
autocorr_tol=50, autocorr_tau=5, min_tau=1, check_point_deltaT=600,
threads=1, **kwargs):
threads=1, exit_code=77, plot=False, store_walkers=False,
 
ignore_keys_for_tau="recalib", pos0="prior", **kwargs):
super(Ptemcee, self).__init__(
super(Ptemcee, self).__init__(
likelihood=likelihood, priors=priors, outdir=outdir,
likelihood=likelihood, priors=priors, outdir=outdir,
label=label, use_ratio=use_ratio, plot=plot,
label=label, use_ratio=use_ratio, plot=plot,
@@ -60,25 +60,29 @@ class Ptemcee(MCMCSampler):
@@ -60,25 +60,29 @@ class Ptemcee(MCMCSampler):
signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
self.resume = resume
self.resume = resume
self.ncheck = ncheck
self.autocorr_c = autocorr_c
self.autocorr_c = autocorr_c
self.safety = safety
self.safety = safety
self.burn_in_nact = burn_in_nact
self.burn_in_nact = burn_in_nact
self.thin_by_nact = thin_by_nact
self.thin_by_nact = thin_by_nact
self.nfrac = nfrac
self.frac_threshold = frac_threshold
self.frac_threshold = frac_threshold
self.nsamples = nsamples
self.nsamples = nsamples
self.autocorr_tol = autocorr_tol
self.autocorr_tol = autocorr_tol
 
self.autocorr_tau = autocorr_tau
self.min_tau = min_tau
self.min_tau = min_tau
self.check_point_deltaT = check_point_deltaT
self.check_point_deltaT = check_point_deltaT
self.threads = threads
self.threads = threads
 
self.store_walkers = store_walkers
 
self.ignore_keys_for_tau = ignore_keys_for_tau
 
self.pos0 = pos0
 
self.check_point_plot = check_point_plot
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)
self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)
 
self.exit_code = exit_code
@property
@property
def sampler_function_kwargs(self):
def sampler_function_kwargs(self):
keys = ['iterations', 'thin', 'storechain', 'adapt', 'swap_ratios']
keys = ['adapt', 'swap_ratios']
return {key: self.kwargs[key] for key in keys}
return {key: self.kwargs[key] for key in keys}
@property
@property
@@ -94,62 +98,145 @@ class Ptemcee(MCMCSampler):
@@ -94,62 +98,145 @@ class Ptemcee(MCMCSampler):
for _ in range(self.sampler_init_kwargs["nwalkers"])]
for _ in range(self.sampler_init_kwargs["nwalkers"])]
for _ in range(self.kwargs['ntemps'])]
for _ in range(self.kwargs['ntemps'])]
def get_sampler(self):
def get_pos0_from_minimize(self, minimize_list=None):
 
logger.info("Attempting to set pos0 from minimize")
 
from scipy.optimize import minimize
 
 
if minimize_list is None:
 
minimize_list = self.search_parameter_keys
 
pos0 = np.zeros((self.kwargs["ntemps"], self.kwargs["nwalkers"], self.ndim))
 
else:
 
pos0 = np.array(self.get_pos0_from_prior())
 
 
likelihood_copy = copy.copy(self.likelihood)
 
 
def neg_log_like(params):
 
likelihood_copy.parameters.update(
 
{key: val for key, val in zip(minimize_list, params)})
 
try:
 
return -likelihood_copy.log_likelihood()
 
except RuntimeError:
 
return +np.inf
 
bounds = [(self.priors[key].minimum, self.priors[key].maximum)
 
for key in minimize_list]
 
trials = 0
 
success = []
 
while True:
 
draw = self.priors.sample()
 
likelihood_copy.parameters.update(draw)
 
x0 = [draw[key] for key in minimize_list]
 
res = minimize(
 
neg_log_like, x0, bounds=bounds, method='L-BFGS-B', tol=1e-15)
 
if res.success:
 
success.append(res.x)
 
if trials > 100:
 
raise SamplerError("Unable to set pos0 from minimize")
 
if len(success) >= 10:
 
break
 
 
success = np.array(success)
 
for i, key in enumerate(minimize_list):
 
pos0_min = np.min(success[:, i])
 
pos0_max = np.max(success[:, i])
 
logger.info("Initialize {} walkers from {}->{}"
 
.format(key, pos0_min, pos0_max))
 
j = self.search_parameter_keys.index(key)
 
pos0[:, :, j] = np.random.uniform(
 
pos0_min, pos0_max,
 
size=(self.kwargs["ntemps"], self.kwargs["nwalkers"]))
 
return pos0
 
 
def setup_sampler(self):
import ptemcee
import ptemcee
if os.path.isfile(self.resume_file) and self.resume is True:
if os.path.isfile(self.resume_file) and self.resume is True:
logger.info("Resume data {} found".format(self.resume_file))
logger.info("Resume data {} found".format(self.resume_file))
with open(self.resume_file, "rb") as file:
with open(self.resume_file, "rb") as file:
data = dill.load(file)
data = dill.load(file)
 
self.sampler = data["sampler"]
self.sampler = data["sampler"]
self.sampler.pool = None
self.tau_list = data["tau_list"]
self.tau_list = data["tau_list"]
self.tau_list_n = data["tau_list_n"]
self.tau_list_n = data["tau_list_n"]
 
self.time_per_check = data["time_per_check"]
 
 
self.sampler.pool = self.pool
 
self.sampler.threads = self.threads
 
pos0 = None
pos0 = None
 
logger.info("Resuming from previous run with time={}".format(self.sampler.time))
logger.info("Resuming from previous run with time={}".format(self.sampler.time))
else:
else:
self.sampler = ptemcee.Sampler(
# Initialize the PTSampler
dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
if self.threads == 1:
pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
self.sampler = ptemcee.Sampler(
self.sampler._likeprior = LikePriorEvaluator(
dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
self.search_parameter_keys, use_ratio=self.use_ratio)
**self.sampler_init_kwargs)
pos0 = self.get_pos0_from_prior()
else:
 
self.sampler = ptemcee.Sampler(
 
dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
 
pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
 
 
self.sampler._likeprior = LikePriorEvaluator(
 
self.search_parameter_keys, use_ratio=self.use_ratio)
 
 
# Set up empty lists
 
self.tau_list = []
 
self.tau_list_n = []
 
self.time_per_check = []
 
 
# Initialize the walker postitions
 
pos0 = self.get_pos0()
return self.sampler, pos0
return self.sampler, pos0
def run_sampler(self):
def get_pos0(self):
import schwimmbad
if isinstance(self.pos0, str) and self.pos0.lower() == "prior":
 
return self.get_pos0_from_prior()
 
elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize":
 
return self.get_pos0_from_minimize()
 
elif isinstance(self.pos0, list):
 
return self.get_pos0_from_minimize(minimize_list=self.pos0)
 
else:
 
raise SamplerError("pos0={} not implemented".format(self.pos0))
 
 
def setup_pool(self):
if self.threads > 1:
if self.threads > 1:
 
import schwimmbad
logger.info("Creating MultiPool with {} processes".format(self.threads))
logger.info("Creating MultiPool with {} processes".format(self.threads))
with schwimmbad.MultiPool(self.threads, initializer=init,
self.pool = schwimmbad.MultiPool(
initargs=(self.likelihood, self.priors)) as pool:
self.threads,
self.pool = pool
initializer=init,
return self.run_sampler_internal()
initargs=(self.likelihood, self.priors))
else:
else:
self.pool = None
self.pool = None
return self.run_sampler_internal()
 
def run_sampler(self):
 
self.setup_pool()
 
out = self.run_sampler_internal()
 
if self.pool:
 
self.pool.close()
 
return out
def run_sampler_internal(self):
def run_sampler_internal(self):
import emcee
import emcee
sampler, pos0 = self.get_sampler()
sampler, pos0 = self.setup_sampler()
self.time_per_check = []
self.tau_list = []
self.tau_list_n = []
t0 = datetime.datetime.now()
t0 = datetime.datetime.now()
logger.info("Starting to sample")
logger.info("Starting to sample")
for (pos0, lnprob, lnlike) in sampler.sample(
while True:
pos0, **self.sampler_function_kwargs):
for (pos0, _, _) in sampler.sample(pos0, **self.sampler_function_kwargs):
# Only check convergence every ncheck steps
pass
if sampler.time % self.ncheck:
continue
# Calculate time per iteration
 
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
 
t0 = datetime.datetime.now()
# Compute ACT tau for 0-temperature chains
# Compute ACT tau for 0-temperature chains
samples = sampler.chain[0, :, : sampler.time, :]
samples = sampler.chain[0, :, : sampler.time, :]
taus = []
taus = []
for ii in range(sampler.nwalkers):
for ii in range(sampler.nwalkers):
for jj, key in enumerate(self.search_parameter_keys):
for jj, key in enumerate(self.search_parameter_keys):
if "recalib" in key:
if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key:
continue
continue
try:
try:
taus.append(
taus.append(
@@ -163,52 +250,58 @@ class Ptemcee(MCMCSampler):
@@ -163,52 +250,58 @@ class Ptemcee(MCMCSampler):
# Apply multiplicitive safety factor
# Apply multiplicitive safety factor
tau = self.safety * np.mean(taus)
tau = self.safety * np.mean(taus)
if np.isnan(tau) or np.isinf(tau):
# Store for convergence checking and plotting
print("{} | Unable to use tau={}".format(sampler.time, tau), flush=True)
continue
# Convert to an integer and store for plotting
tau = int(tau)
self.tau_list.append(tau)
self.tau_list.append(tau)
self.tau_list_n.append(sampler.time)
self.tau_list_n.append(sampler.time)
 
# Convert to an integer
 
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
 
 
if np.isnan(tau_int) or np.isinf(tau_int):
 
print_progress(
 
self.sampler,
 
self.time_per_check,
 
self.nsamples,
 
np.nan,
 
np.nan,
 
np.nan,
 
[np.nan],
 
False)
 
continue
 
# Calculate the effective number of samples available
# Calculate the effective number of samples available
self.nburn = int(self.burn_in_nact * tau)
self.nburn = int(self.burn_in_nact * tau_int)
self.thin = int(np.max([1, self.thin_by_nact * tau]))
self.thin = int(np.max([1, self.thin_by_nact * tau_int]))
samples_per_check = self.ncheck * sampler.nwalkers / self.thin
samples_per_check = sampler.nwalkers / self.thin
self.nsamples_effective = int(sampler.nwalkers * (sampler.time - self.nburn) / self.thin)
self.nsamples_effective = int(sampler.nwalkers * (sampler.time - self.nburn) / self.thin)
# Calculate fractional change in tau from previous iteration
frac = (tau - np.array(self.tau_list)[-self.nfrac - 1: -1]) / tau
passes = frac < self.frac_threshold
# Calculate convergence boolean
# Calculate convergence boolean
converged = self.nsamples < self.nsamples_effective
converged = self.nsamples < self.nsamples_effective
converged &= np.all(passes)
if sampler.time < tau * self.autocorr_tol or tau < self.min_tau:
# Calculate fractional change in tau from previous iterations
converged = False
check_taus = np.array(self.tau_list[-tau_int * self.autocorr_tau:])
tau_pass = False
if not np.any(np.isnan(check_taus)):
 
frac = (tau - check_taus) / tau
 
tau_usable = np.all(frac < self.frac_threshold)
else:
else:
tau_pass = True
tau_usable = False
# Calculate time per iteration
if sampler.time < tau_int * self.autocorr_tol or tau_int < self.min_tau:
self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
tau_usable = False
t0 = datetime.datetime.now()
# Print an update on the progress
# Print an update on the progress
print_progress(
print_progress(
self.sampler,
self.sampler,
self.ncheck,
self.time_per_check,
self.time_per_check,
self.nsamples,
self.nsamples,
self.nsamples_effective,
self.nsamples_effective,
samples_per_check,
samples_per_check,
passes,
tau_int,
tau,
check_taus,
tau_pass,
tau_usable,
)
)
if converged:
if converged and tau_usable:
logger.info("Finished sampling")
logger.info("Finished sampling")
break
break
@@ -219,22 +312,13 @@ class Ptemcee(MCMCSampler):
@@ -219,22 +312,13 @@ class Ptemcee(MCMCSampler):
last_checkpoint_s = np.sum(self.time_per_check)
last_checkpoint_s = np.sum(self.time_per_check)
if last_checkpoint_s > self.check_point_deltaT:
if last_checkpoint_s > self.check_point_deltaT:
self.write_current_state()
self.write_current_state(plot=self.check_point_plot)
# Check if we reached the end without converging
if sampler.time == self.sampler_function_kwargs["iterations"]:
raise ValueError(
"Failed to reach convergence by iterations={}".format(
self.sampler_function_kwargs["iterations"]
)
)
# Run a final checkpoint to update the plots and samples
# Run a final checkpoint to update the plots and samples
self.write_current_state()
self.write_current_state(plot=self.check_point_plot)
# Get 0-likelihood samples and store in the result
# Get 0-likelihood samples and store in the result
samples = sampler.chain[0, :, :, :] # nwalkers, nsteps, ndim
samples = sampler.chain[0, :, :, :] # nwalkers, nsteps, ndim
self.result.walkers = samples[:, :sampler.time:, :]
self.result.samples = (
self.result.samples = (
samples[:, self.nburn: sampler.time:self.thin, :].reshape((-1, self.ndim)))
samples[:, self.nburn: sampler.time:self.thin, :].reshape((-1, self.ndim)))
loglikelihood = sampler.loglikelihood[
loglikelihood = sampler.loglikelihood[
@@ -242,7 +326,8 @@ class Ptemcee(MCMCSampler):
@@ -242,7 +326,8 @@ class Ptemcee(MCMCSampler):
] # nwalkers, nsteps
] # nwalkers, nsteps
self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
self.result.walkers = self.sampler.chain
if self.store_walkers:
 
self.result.walkers = self.sampler.chain
self.result.nburn = self.nburn
self.result.nburn = self.nburn
log_evidence, log_evidence_err = compute_evidence(
log_evidence, log_evidence_err = compute_evidence(
@@ -257,26 +342,40 @@ class Ptemcee(MCMCSampler):
@@ -257,26 +342,40 @@ class Ptemcee(MCMCSampler):
def write_current_state_and_exit(self, signum=None, frame=None):
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
logger.warning("Run terminated with signal {}".format(signum))
self.write_current_state()
if getattr(self, 'pool', None):
sys.exit(77)
self.write_current_state(plot=False)
 
logger.warning("Closing pool")
 
self.pool.close()
 
sys.exit(self.exit_code)
def write_current_state(self):
def write_current_state(self, plot=True):
checkpoint(self.outdir, self.label, self.nsamples_effective,
checkpoint(self.outdir, self.label, self.nsamples_effective,
self.sampler, self.nburn, self.thin,
self.sampler, self.nburn, self.thin,
self.search_parameter_keys, self.resume_file, self.tau_list,
self.search_parameter_keys, self.resume_file, self.tau_list,
self.tau_list_n)
self.tau_list_n, self.time_per_check)
 
 
if plot and not np.isnan(self.nburn):
 
# Generate the walkers plot diagnostic
 
plot_walkers(
 
self.sampler.chain[0, :, : self.sampler.time, :],
 
self.nburn, self.thin, self.search_parameter_keys, self.outdir,
 
self.label
 
)
 
 
# Generate the tau plot diagnostic
 
plot_tau(self.tau_list_n, self.tau_list, self.outdir, self.label,
 
self.autocorr_tau)
def print_progress(
def print_progress(
sampler,
sampler,
ncheck,
time_per_check,
time_per_check,
nsamples,
nsamples,
nsamples_effective,
nsamples_effective,
samples_per_check,
samples_per_check,
passes,
tau_int,
tau,
tau_list,
tau_pass,
tau_usable,
):
):
# Setup acceptance string
# Setup acceptance string
acceptance = sampler.acceptance_fraction[0, :]
acceptance = sampler.acceptance_fraction[0, :]
@@ -290,30 +389,31 @@ def print_progress(
@@ -290,30 +389,31 @@ def print_progress(
ave_time_per_check = np.mean(time_per_check[-3:])
ave_time_per_check = np.mean(time_per_check[-3:])
time_left = (
time_left = (
(nsamples - nsamples_effective)
(nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
* ave_time_per_check
/ samples_per_check
)
)
if time_left > 0:
if time_left > 0:
time_left = str(datetime.timedelta(seconds=int(time_left)))
time_left = str(datetime.timedelta(seconds=int(time_left)))
else:
else:
time_left = "waiting on convergence"
time_left = "waiting on convergence"
convergence = "".join([["F", "T"][i] for i in passes])
sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
tau_str = str(tau)
tau_str = "{}:{:0.1f}->{:0.1f}".format(tau_int, np.min(tau_list), np.max(tau_list))
if tau_pass is False:
if tau_usable:
tau_str = tau_str + "(F)"
tau_str = "={}".format(tau_str)
 
else:
 
tau_str = "!{}".format(tau_str)
evals_per_check = sampler.nwalkers * sampler.ntemps * ncheck
evals_per_check = sampler.nwalkers * sampler.ntemps
ncalls = "{:1.1e}".format(sampler.time * sampler.nwalkers * sampler.ntemps)
ncalls = "{:1.1e}".format(sampler.time * sampler.nwalkers * sampler.ntemps)
eval_timing = "{:1.1f}ms/evl".format(1e3 * ave_time_per_check / evals_per_check)
eval_timing = "{:1.1f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
samp_timing = "{:1.2f}ms/smp".format(1e3 * ave_time_per_check / samples_per_check)
samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)
print(
print(
"{}| nc:{}| a0:{}| swp:{}| n:{}<{}| tau:{}| {}| {}| {}".format(
"{}| {} | nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
sampler.time,
sampler.time,
 
str(sampling_time).split(".")[0],
ncalls,
ncalls,
acceptance_str,
acceptance_str,
tswap_acceptance_str,
tswap_acceptance_str,
@@ -322,14 +422,14 @@ def print_progress(
@@ -322,14 +422,14 @@ def print_progress(
tau_str,
tau_str,
eval_timing,
eval_timing,
samp_timing,
samp_timing,
convergence,
),
),
flush=True,
flush=True,
)
)
def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
search_parameter_keys, resume_file, tau_list, tau_list_n):
search_parameter_keys, resume_file, tau_list, tau_list_n,
 
time_per_check):
logger.info("Writing checkpoint and diagnostics")
logger.info("Writing checkpoint and diagnostics")
ndim = sampler.dim
ndim = sampler.dim
@@ -349,49 +449,52 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
@@ -349,49 +449,52 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
sampler_copy._logposterior = sampler._logposterior[:, :, : sampler.time]
sampler_copy._logposterior = sampler._logposterior[:, :, : sampler.time]
sampler_copy._loglikelihood = sampler._loglikelihood[:, :, : sampler.time]
sampler_copy._loglikelihood = sampler._loglikelihood[:, :, : sampler.time]
sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
data = dict(sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n)
 
data = dict(
 
sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n,
 
time_per_check=time_per_check)
 
with open(resume_file, "wb") as file:
with open(resume_file, "wb") as file:
dill.dump(data, file, protocol=4)
dill.dump(data, file, protocol=4)
del data, sampler_copy
del data, sampler_copy
logger.info("Finished writing checkpoint")
# Generate the walkers plot diagnostic
plot_walkers(
sampler.chain[0, :, : sampler.time, :], nburn, search_parameter_keys, outdir, label
)
# Generate the tau plot diagnostic
plot_tau(tau_list_n, tau_list, outdir, label)
logger.info("Finished writing checkpoint and diagnostics")
def plot_walkers(walkers, nburn, parameter_labels, outdir, label):
def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
""" Method to plot the trace of the walkers in an ensemble MCMC plot """
""" Method to plot the trace of the walkers in an ensemble MCMC plot """
nwalkers, nsteps, ndim = walkers.shape
nwalkers, nsteps, ndim = walkers.shape
idxs = np.arange(nsteps)
idxs = np.arange(nsteps)
fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3 * ndim))
fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05)
scatter_kwargs = dict(lw=0, marker="o", markersize=1)
for i, ax in enumerate(axes):
# Plot the burn-in
 
for i, (ax, axh) in enumerate(axes):
ax.plot(
ax.plot(
idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="r", **scatter_kwargs
idxs[: nburn + 1], walkers[:, : nburn + 1, i].T, color="C1", **scatter_kwargs
)
)
ax.set_ylabel(parameter_labels[i])
for i, ax in enumerate(axes):
# Plot the thinned posterior samples
ax.plot(idxs[nburn:], walkers[:, nburn:, i].T, color="k", **scatter_kwargs)
for i, (ax, axh) in enumerate(axes):
 
ax.plot(idxs[nburn::thin], walkers[:, nburn::thin, i].T, color="C0", **scatter_kwargs)
 
axh.hist(walkers[:, nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
 
axh.set_xlabel(parameter_labels[i])
 
ax.set_ylabel(parameter_labels[i])
fig.tight_layout()
fig.tight_layout()
filename = "{}/{}_traceplot.png".format(outdir, label)
filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
fig.savefig(filename)
fig.savefig(filename)
plt.close(fig)
plt.close(fig)
def plot_tau(tau_list_n, tau_list, outdir, label):
def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
fig, ax = plt.subplots()
fig, ax = plt.subplots()
ax.plot(tau_list_n, tau_list, "-x")
ax.plot(tau_list_n, tau_list, "-", color='C1')
 
check_tau_idx = -int(tau_list[-1] * autocorr_tau)
 
check_taus = tau_list[check_tau_idx:]
 
check_taus_n = tau_list_n[check_tau_idx:]
 
ax.plot(check_taus_n, check_taus, "-", color='C0')
ax.set_xlabel("Iteration")
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
ax.set_ylabel(r"$\langle \tau \rangle$")
fig.savefig("{}/{}_tau.png".format(outdir, label))
fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
plt.close(fig)
plt.close(fig)
Loading