Commit 97a650ae authored by Daniel Wysocki's avatar Daniel Wysocki

Further updates to acorr plotting

parent 8760f568
Pipeline #153172 failed with stage
in 1 minute
......@@ -35,16 +35,20 @@ def acorr(X, lag, axis=0):
return result
_supported_acorr_methods = ("default", "max_likelihood")
def acorr_time(
chains,
tolerance=0.1,
initial_grid_size=100, max_retries=10,
method="default",
return_acorr=False,
):
import numpy
from ..utils import bisect_int
n_samples, n_walkers, n_dim = numpy.shape(chains)
# Initialize array that will hold the auto-correlation for each lag.
......
......@@ -5,7 +5,8 @@ from pop_models.posterior import (
H5CleanedPosteriorSamples,
)
from .max_posterior import compute_burnin, plot_burnins
from .acorr import full_acorr_analysis, plot_acorrs
#from .acorr import full_acorr_analysis, plot_acorrs
from . import integrated_acorr
def make_parser():
import argparse
......@@ -24,15 +25,15 @@ def make_parser():
)
acorr_group.add_argument(
"--acorr-plot",
help="Plot autocorrelation lengths to this file.",
help="Plot integrated autocorrelations to this file.",
)
parser.add_argument(
"--acorr-tolerance",
type=float, default=0.05,
help="Maximum auto-correlation to allow. Lag will be taken as the "
"shortest number of samples for which all parameters fall below "
"this auto-correlation. Default is 0.05.",
"--acorr-safety-factor",
metavar="M",
type=float, default=5.0,
help="Integrated auto-correlation safety factor to use if --acorr-plot "
"is provided.",
)
parser.add_argument(
"--lag-max",
......@@ -108,24 +109,38 @@ def main(raw_args=None):
# Determine auto-correlation and appropriate lag time.
if cli_args.acorr_plot is not None:
lags, acorrs, lag_tols = full_acorr_analysis(
f.get_samples(slice(n_burnin, None)),
lag_max=cli_args.lag_max, tolerance=cli_args.acorr_tolerance,
samples_post_burnin = f.get_samples(slice(n_burnin,None))
# posterior_log_prob_post_burnin = (
# f.get_posterior_log_prob(slice(n_burnin,None))
# )
(
n_acorr,
(acorrs, acorr_errs),
(acorr_times, acorr_time_errs),
) = (
integrated_acorr.acorr_time(
samples_post_burnin, retall=True,
m=cli_args.acorr_safety_factor,
)
)
n_thinning = numpy.max(lag_tols)
print(len(f.variable_names), len(lag_tols))
# delta_n_acorr = numpy.tile(n_acorr, (f.n_walkers, f.n_dim))
print(
"Lags for each parameter are:",
dict(zip(f.variable_names, lag_tols)),
"Auto-correlation time is {dn} for the worst parameter."
.format(dn=n_acorr)
)
print("Thinning is:", n_thinning)
fig, ax = plot_acorrs(
lags, acorrs, lag_tols=lag_tols,
tolerance=cli_args.acorr_tolerance,
param_names=f.variable_names,
print(
"Auto-correlation time for all parameters are",
", ".join([str(x) for x in acorr_times[n_acorr]]),
)
# Take the most auto-correlated walker and parameter throughout.
n_thinning = numpy.max(n_acorr)
print("Thinning is", n_thinning)
# Plot the autocorrelation diagnostics.
print(acorrs.shape)
integrated_acorr.plot_integrated_auto_correlation(
cli_args.acorr_plot,
acorr_times, None,
)
fig.savefig(cli_args.acorr_plot)
plt.close(fig)
elif cli_args.fixed_thinning is not None:
n_thinning = cli_args.fixed_thinning
else:
......
......@@ -151,3 +151,119 @@ def acorr_time(chains, m=5, retall=False):
)
else:
return acorr_worst
color_cycle = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf",
]
linestyle_cycle = [
"solid",
"dashed",
"dashdot",
"dotted",
]
def plot_auto_correlation(
filename,
acorrs, acorr_errs,
):
import warnings
import itertools
import numpy
import matplotlib as mpl
import matplotlib.pyplot as plt
n_lags, n_dim = numpy.shape(acorrs)
lags = numpy.arange(n_lags)
fig, ax = plt.subplots(figsize=(8,4))
for d, color in zip(range(n_dim), itertools.cycle(color_cycle)):
ac = acorrs[...,d]
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
lo = numpy.nanmin(ac, axis=1)
hi = numpy.nanmax(ac, axis=1)
med = numpy.nanmedian(ac, axis=1)
no_nans = ~(numpy.isnan(lo) | numpy.isnan(hi) | numpy.isnan(med))
ax.fill_between(
lags[no_nans], lo[no_nans], hi[no_nans],
color=color, alpha=0.1,
)
ax.plot(
lags[no_nans], med[no_nans],
color=color,
label="Param #{i}".format(i=d+1),
)
ax.set_xlabel("lag")
ax.set_ylabel("auto-correlation")
ax.set_xscale("log")
# ax.set_yscale("log")
ax.legend(loc="best")
fig.tight_layout()
fig.savefig(filename)
def plot_integrated_auto_correlation(filename, acorr_times, acorr_time_errs):
import itertools
import numpy
import matplotlib as mpl
import matplotlib.pyplot as plt
K_final_plus_one, n_dim = numpy.shape(acorr_times)
K_final = K_final_plus_one - 1
Ks = numpy.arange(K_final_plus_one)
fig, ax = plt.subplots(figsize=(8,4))
for d, color in zip(range(n_dim), itertools.cycle(color_cycle)):
ac = acorr_times[...,d]
ax.plot(
Ks, ac,
color=color,
label="Param #{i}".format(i=d+1),
)
if acorr_time_errs is not None:
ac_err = acorr_time_errs[...,d]
ax.fill_between(
Ks, ac-ac_err, ac+ac_err,
color=color,
alpha=0.1,
)
ax.fill_between(
Ks, ac-2.0*ac_err, ac+2.0*ac_err,
color=color,
alpha=0.1,
)
ax.set_xlabel(r"$K$")
ax.set_ylabel(r"$\tau_K$")
ax.set_xscale("log")
# ax.set_yscale("log")
ax.legend(loc="best")
fig.tight_layout()
fig.savefig(filename)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment