
Improvements to checkpointing to align it with bilby!746

Merged Gregory Ashton requested to merge improve-checkpoint-behaviour into master
Files changed: 5 (+156 −175)
@@ -8,9 +8,12 @@ import logging
 import os
 import pickle
 import sys
+import time

 import bilby
+import dill
 import dynesty
+import dynesty.plotting as dyplot
 import matplotlib.pyplot as plt
 import mpi4py
 import numpy as np
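
Reviewer note: `dill` comes in because the standard `pickle` module cannot serialize the sampler once it holds closures and bound methods, and `dill.pickles()` gives the pre-flight check used in `write_current_state` below. A minimal standalone illustration (not part of this MR):

```python
import dill

f = lambda x: x ** 2

# pickle refuses lambdas/closures; dill handles them, and dill.pickles()
# reports whether an object can be serialized without actually writing it
assert dill.pickles(f)
with open("checkpoint.pickle", "wb") as file:
    dill.dump(f, file)
```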
@@ -18,14 +21,18 @@ import pandas as pd
 from bilby.core.utils import reflect
 from bilby.gw import conversion
 from dynesty import NestedSampler
-from dynesty.plotting import traceplot
-from dynesty.utils import resample_equal, unitcheck
+from dynesty.utils import unitcheck
 from numpy import linalg
 from pandas import DataFrame
 from schwimmbad import MPIPool

 from .parser import create_analysis_parser
-from .utils import fill_sample, get_cli_args, get_initial_points_from_prior
+from .utils import (
+    fill_sample,
+    get_cli_args,
+    get_initial_points_from_prior,
+    safe_file_dump,
+)

 mpi4py.rc.threads = False
 mpi4py.rc.recv_mprobe = False
@@ -197,9 +204,7 @@ def reorder_loglikelihoods(unsorted_loglikelihoods, unsorted_samples, sorted_sam
     return unsorted_loglikelihoods[idxs]


-def write_checkpoint(
-    sampler, resume_file, sampling_time, search_parameter_keys, no_plot=False
-):
+def write_current_state(sampler, resume_file, sampling_time):
     """ Writes a checkpoint file

     Parameters
@@ -210,81 +215,63 @@ def write_checkpoint(
         The name of the resume/checkpoint file to use
     sampling_time: float
         The total sampling time in seconds
-    search_parameter_keys: list
-        A list of the search parameter keys used in sampling (used for
-        constructing checkpoint plots and pre-results)
-    no_plot: bool
-        If true, don't create a check point plot

     """
     print("")
     logger.info("Writing checkpoint file {}".format(resume_file))

-    current_state = dict(
-        unit_cube_samples=sampler.saved_u,
-        physical_samples=sampler.saved_v,
-        sample_likelihoods=sampler.saved_logl,
-        sample_log_volume=sampler.saved_logvol,
-        sample_log_weights=sampler.saved_logwt,
-        cumulative_log_evidence=sampler.saved_logz,
-        cumulative_log_evidence_error=sampler.saved_logzvar,
-        cumulative_information=sampler.saved_h,
-        id=sampler.saved_id,
-        it=sampler.saved_it,
-        nc=sampler.saved_nc,
-        boundidx=sampler.saved_boundidx,
-        bounditer=sampler.saved_bounditer,
-        scale=sampler.saved_scale,
-        sampling_time=sampling_time,
-    )
-
-    current_state.update(
-        ncall=sampler.ncall,
-        live_logl=sampler.live_logl,
-        iteration=sampler.it - 1,
-        live_u=sampler.live_u,
-        live_v=sampler.live_v,
-        nlive=sampler.nlive,
-        live_bound=sampler.live_bound,
-        live_it=sampler.live_it,
-        added_live=sampler.added_live,
-    )
-
-    if input_args.do_not_save_bounds_in_resume:
-        pass
-    else:
-        current_state["bound"] = sampler.bound
-        current_state["nbound"] = sampler.nbound
-
-    # Try to save a set of current posterior samples
-    try:
-        weights = np.exp(
-            current_state["sample_log_weights"]
-            - current_state["cumulative_log_evidence"][-1]
-        )
-        current_state["posterior"] = resample_equal(
-            np.array(current_state["physical_samples"]), weights
-        )
-        current_state["search_parameter_keys"] = search_parameter_keys
-    except ValueError:
-        logger.debug("Unable to create posterior")
-
-    with open(resume_file, "wb") as file:
-        pickle.dump(current_state, file)
-
-    # Try to create a checkpoint traceplot
-    if no_plot is False:
-        try:
-            fig = traceplot(sampler.results, labels=sampling_keys)[0]
-            fig.tight_layout()
-            fig.savefig(filename_trace)
-            plt.close("all")
-        except Exception:
-            pass
+    sampler.kwargs["sampling_time"] = sampling_time
+    if dill.pickles(sampler):
+        safe_file_dump(sampler, resume_file, dill)
+        logger.info("Written checkpoint file {}".format(resume_file))
+    else:
+        logger.warning(
+            "Cannot write pickle resume file! "
+            "Job will not resume if interrupted."
+        )


-def read_saved_state(resume_file, sampler):
+def plot_current_state(sampler, search_parameter_keys, outdir, label):
+    labels = [label.replace("_", " ") for label in search_parameter_keys]
+    try:
+        filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
+        fig = dyplot.traceplot(sampler.results, labels=labels)[0]
+        fig.tight_layout()
+        fig.savefig(filename)
+    except (
+        AssertionError,
+        RuntimeError,
+        np.linalg.linalg.LinAlgError,
+        ValueError,
+    ) as e:
+        logger.warning(e)
+        logger.warning("Failed to create dynesty state plot at checkpoint")
+    finally:
+        plt.close("all")
+
+    try:
+        filename = "{}/{}_checkpoint_run.png".format(outdir, label)
+        fig, axs = dyplot.runplot(sampler.results)
+        fig.tight_layout()
+        plt.savefig(filename)
+    except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
+        logger.warning(e)
+        logger.warning("Failed to create dynesty run plot at checkpoint")
+    finally:
+        plt.close("all")
+
+    try:
+        filename = "{}/{}_checkpoint_stats.png".format(outdir, label)
+        fig, axs = plt.subplots(nrows=3, sharex=True)
+        for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
+            ax.plot(getattr(sampler, f"saved_{name}"), color="C0")
+            ax.set_ylabel(name)
+        axs[-1].set_xlabel("iteration")
+        fig.tight_layout()
+        plt.savefig(filename)
+    except (RuntimeError, ValueError) as e:
+        logger.warning(e)
+        logger.warning("Failed to create dynesty stats plot at checkpoint")
+    finally:
+        plt.close("all")
+
+
+def read_saved_state(resume_file, continuing=True):
"""
Read a saved state of the sampler to disk.
@@ -295,8 +282,6 @@ def read_saved_state(resume_file, sampler):
     ----------
     resume_file: str
         The path to the resume file to read
-    sampler: `dynesty.NestedSampler`
-        NestedSampler instance to reconstruct from the saved state.

     Returns
     -------
@@ -310,49 +295,16 @@ def read_saved_state(resume_file, sampler):
     if os.path.isfile(resume_file):
         logger.info("Reading resume file {}".format(resume_file))
-        try:
-            with open(resume_file, "rb") as file:
-                saved = pickle.load(file)
-            logger.info("Successfully read resume file {}".format(resume_file))
-        except EOFError as e:
-            logger.warning("Resume file reading failed with error {}".format(e))
-            return False, 0
-
-        sampler.saved_u = list(saved["unit_cube_samples"])
-        sampler.saved_v = list(saved["physical_samples"])
-        sampler.saved_logl = list(saved["sample_likelihoods"])
-        sampler.saved_logvol = list(saved["sample_log_volume"])
-        sampler.saved_logwt = list(saved["sample_log_weights"])
-        sampler.saved_logz = list(saved["cumulative_log_evidence"])
-        sampler.saved_logzvar = list(saved["cumulative_log_evidence_error"])
-        sampler.saved_id = list(saved["id"])
-        sampler.saved_it = list(saved["it"])
-        sampler.saved_nc = list(saved["nc"])
-        sampler.saved_boundidx = list(saved["boundidx"])
-        sampler.saved_bounditer = list(saved["bounditer"])
-        sampler.saved_scale = list(saved["scale"])
-        sampler.saved_h = list(saved["cumulative_information"])
-        sampler.ncall = saved["ncall"]
-        sampler.live_logl = list(saved["live_logl"])
-        sampler.it = saved["iteration"] + 1
-        sampler.live_u = saved["live_u"]
-        sampler.live_v = saved["live_v"]
-        sampler.nlive = saved["nlive"]
-        sampler.live_bound = saved["live_bound"]
-        sampler.live_it = saved["live_it"]
-        sampler.added_live = saved["added_live"]
-        try:
-            sampler.bound = saved["bound"]
-            sampler.nbound = saved["nbound"]
-        except KeyError:
-            logger.info("No bounds saved in resume")
-        sampling_time = datetime.timedelta(
-            seconds=saved["sampling_time"]
-        ).total_seconds()
+        with open(resume_file, "rb") as file:
+            sampler = dill.load(file)
+            if sampler.added_live and continuing:
+                sampler._remove_live_points()
+            sampler.nqueue = -1
+            sampler.rstate = np.random
+            sampling_time = sampler.kwargs.pop("sampling_time")
         return sampler, sampling_time
     else:
-        logger.debug("No resume file {}".format(resume_file))
+        logger.info("Resume file {} does not exist.".format(resume_file))
         return False, 0
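
Taken together, the write/read pair now round-trips the whole sampler object through dill rather than rebuilding it attribute by attribute. A sketch of the intended usage (the resume-file path here is hypothetical):

```python
# Checkpoint: stash the cumulative sampling time inside the sampler, dump it all
write_current_state(sampler, "outdir/label_checkpoint_resume.pickle", sampling_time)

# Resume (possibly in a fresh process)
sampler, sampling_time = read_saved_state("outdir/label_checkpoint_resume.pickle")
if sampler is not False:
    # The MPI pool is not picklable, so it must be reattached after loading
    sampler.pool = pool
    sampler.M = pool.map
```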
@@ -510,77 +462,106 @@ with MPIPool() as pool:
         save_bounds=False,
     )

-    logger.info(
-        "Initialize NestedSampler with {}".format(
-            json.dumps(init_sampler_kwargs, indent=1, sort_keys=True)
-        )
-    )
-
     ndim = len(sampling_keys)
-    logger.info(f"Initializing sampling points with pool size={POOL_SIZE}")
-    live_points = get_initial_points_from_prior(
-        ndim,
-        nlive,
-        prior_transform_function,
-        log_prior_function,
-        log_likelihood_function,
-        pool,
-    )
-
-    sampler = NestedSampler(
-        log_likelihood_function,
-        prior_transform_function,
-        ndim,
-        pool=pool,
-        queue_size=POOL_SIZE,
-        print_func=dynesty.results.print_fn_fallback,
-        periodic=periodic,
-        reflective=reflective,
-        live_points=live_points,
-        use_pool=dict(
-            update_bound=True,
-            propose_point=True,
-            prior_transform=True,
-            loglikelihood=True,
-        ),
-        **init_sampler_kwargs,
-    )

     if os.path.isfile(resume_file) and input_args.clean is False:
-        resume_sampler, sampling_time = read_saved_state(resume_file, sampler)
-        if resume_sampler is not False:
-            sampler = resume_sampler
+        sampler, sampling_time = read_saved_state(resume_file)
+
+    if sampler is False:
+        logger.info(f"Initializing sampling points with pool size={POOL_SIZE}")
+        live_points = get_initial_points_from_prior(
+            ndim,
+            nlive,
+            prior_transform_function,
+            log_prior_function,
+            log_likelihood_function,
+            pool,
+        )
+        logger.info(
+            "Initialize NestedSampler with {}".format(
+                json.dumps(init_sampler_kwargs, indent=1, sort_keys=True)
+            )
+        )
+        sampler = NestedSampler(
+            log_likelihood_function,
+            prior_transform_function,
+            ndim,
+            pool=pool,
+            queue_size=POOL_SIZE,
+            print_func=dynesty.results.print_fn_fallback,
+            periodic=periodic,
+            reflective=reflective,
+            live_points=live_points,
+            use_pool=dict(
+                update_bound=True,
+                propose_point=True,
+                prior_transform=True,
+                loglikelihood=True,
+            ),
+            **init_sampler_kwargs,
+        )
+    else:
+        # Reinstate the pool and map (not saved in the pickle)
+        sampler.pool = pool
+        sampler.M = pool.map

     logger.info(
         f"Starting sampling for job {label}, with pool size={POOL_SIZE} "
-        f"and n_check_point={input_args.n_check_point}"
+        f"and check_point_deltaT={input_args.check_point_deltaT}"
     )

     old_ncall = sampler.ncall
     sampler_kwargs = dict(
         print_progress=True,
         maxcall=input_args.n_check_point,
         n_effective=input_args.n_effective,
         dlogz=input_args.dlogz,
+        save_bounds=not input_args.do_not_save_bounds_in_resume,
     )
+    logger.info("Run criteria: {}".format(json.dumps(sampler_kwargs)))

-    while True:
-        sampler_kwargs["add_live"] = False
-        sampler.run_nested(**sampler_kwargs)
-        if sampler.ncall == old_ncall:
-            break
-        old_ncall = sampler.ncall
+    for it, res in enumerate(sampler.sample(**sampler_kwargs)):
+        (
+            worst,
+            ustar,
+            vstar,
+            loglstar,
+            logvol,
+            logwt,
+            logz,
+            logzvar,
+            h,
+            nc,
+            worst_it,
+            boundidx,
+            bounditer,
+            eff,
+            delta_logz,
+        ) = res
+
+        i = it - 1
+        dynesty.results.print_fn_fallback(res, i, sampler.ncall, dlogz=input_args.dlogz)
+
+        if it == 0 or it % input_args.n_check_point != 0:
+            continue

         sampling_time += (datetime.datetime.now() - t0).total_seconds()
         t0 = datetime.datetime.now()
-        write_checkpoint(
-            sampler,
-            resume_file,
-            sampling_time,
-            sampling_keys,
-            no_plot=input_args.no_plot,
-        )
-
-    sampler_kwargs["add_live"] = True
+
+        if os.path.isfile(resume_file):
+            last_checkpoint_s = time.time() - os.path.getmtime(resume_file)
+        else:
+            last_checkpoint_s = np.inf
+        if last_checkpoint_s > input_args.check_point_deltaT:
+            write_current_state(sampler, resume_file, sampling_time)
+            if input_args.no_plot is False:
+                plot_current_state(sampler, sampling_keys, outdir, label)
+
+    # Adding the final set of live points.
+    for it_final, res in enumerate(sampler.add_live_points()):
+        pass
+
+    # Create a final checkpoint and set of plots
+    write_current_state(sampler, resume_file, sampling_time)
+    if input_args.no_plot is False:
+        plot_current_state(sampler, sampling_keys, outdir, label)

     sampling_time += (datetime.datetime.now() - t0).total_seconds()

     out = sampler.results
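
For context on the control-flow change: dynesty's `run_nested` is essentially a loop over the `sample()` generator followed by `add_live_points()`. Driving the generator directly, as above, lets the script interleave its own time-based checkpoint logic between iterations instead of restarting `run_nested` in a `while` loop. A minimal sketch of the same pattern in isolation (assuming an existing `sampler`):

```python
for it, res in enumerate(sampler.sample(dlogz=0.1, save_bounds=False)):
    # `res` carries the per-iteration state (logz, delta_logz, ncall, ...);
    # checkpoint here as often as needed
    pass

# Once the stopping criterion is met, fold in the remaining live points
for res in sampler.add_live_points():
    pass
```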