Commit db4ba5cc authored by Gregory Ashton

Improvements to checkpointing to align it with bilby!746

See bilby!746

1) Pickle the entire sampler
2) Save every check_point_deltaT seconds
3) Add stats and run plots at each checkpoint
parent 865c0449
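The key enabler for item 1 above is dill: the standard-library pickle refuses several kinds of object a live dynesty sampler can hold (locally defined functions, lambdas, closures), while dill serialises them, and its pickles() predicate lets the writer test an object before committing to a dump. A minimal, self-contained illustration (not from the commit):

import pickle

import dill


def make_counter():
    count = 0

    def counter():  # a closure: the kind of object plain pickle rejects
        nonlocal count
        count += 1
        return count

    return counter


counter = make_counter()
print(dill.pickles(counter))  # True: dill can serialise the closure
try:
    pickle.dumps(counter)
except (pickle.PicklingError, AttributeError) as e:
    print("plain pickle cannot: {}".format(e))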
@@ -13,4 +13,4 @@ addopts = -p no:warnings
line-width=88
multi-line=3
trailing-comma=True
known_third_party=bilby,bilby_pipe,dynesty,emcee,gwpy,matplotlib,mock,mpi4py,numpy,pandas,parallel_bilby,ptemcee,schwimmbad,setuptools,tqdm
known_third_party=bilby,bilby_pipe,dill,dynesty,emcee,gwpy,matplotlib,mock,mpi4py,numpy,pandas,parallel_bilby,ptemcee,schwimmbad,setuptools,tqdm
@@ -8,9 +8,12 @@ import logging
import os
import pickle
import sys
import time
import bilby
import dill
import dynesty
import dynesty.plotting as dyplot
import matplotlib.pyplot as plt
import mpi4py
import numpy as np
@@ -18,14 +21,18 @@ import pandas as pd
from bilby.core.utils import reflect
from bilby.gw import conversion
from dynesty import NestedSampler
from dynesty.plotting import traceplot
from dynesty.utils import resample_equal, unitcheck
from dynesty.utils import unitcheck
from numpy import linalg
from pandas import DataFrame
from schwimmbad import MPIPool
from .parser import create_analysis_parser
from .utils import fill_sample, get_cli_args, get_initial_points_from_prior
from .utils import (
fill_sample,
get_cli_args,
get_initial_points_from_prior,
safe_file_dump,
)
mpi4py.rc.threads = False
mpi4py.rc.recv_mprobe = False
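A note on the two mpi4py.rc lines that close this hunk: these runtime-configuration flags only take effect if they are set before mpi4py.MPI is first imported (MPI is initialised at that import), which is why they sit at module scope right after the imports rather than next to the MPIPool usage further down. The pattern in isolation, with the flag meanings as I understand them:

import mpi4py

# Must be set before the first `from mpi4py import MPI`:
mpi4py.rc.threads = False      # initialise MPI without thread support
mpi4py.rc.recv_mprobe = False  # avoid matched probes, which some MPI stacks mishandle

from mpi4py import MPI  # noqa: E402  (deliberately imported after configuring rc)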
@@ -197,9 +204,7 @@ def reorder_loglikelihoods(unsorted_loglikelihoods, unsorted_samples, sorted_sam
return unsorted_loglikelihoods[idxs]
def write_checkpoint(
sampler, resume_file, sampling_time, search_parameter_keys, no_plot=False
):
def write_current_state(sampler, resume_file, sampling_time):
""" Writes a checkpoint file
Parameters
@@ -210,81 +215,63 @@ def write_checkpoint(
The name of the resume/checkpoint file to use
sampling_time: float
The total sampling time in seconds
search_parameter_keys: list
A list of the search parameter keys used in sampling (used for
constructing checkpoint plots and pre-results)
no_plot: bool
If true, don't create a check point plot
"""
print("")
logger.info("Writing checkpoint file {}".format(resume_file))
current_state = dict(
unit_cube_samples=sampler.saved_u,
physical_samples=sampler.saved_v,
sample_likelihoods=sampler.saved_logl,
sample_log_volume=sampler.saved_logvol,
sample_log_weights=sampler.saved_logwt,
cumulative_log_evidence=sampler.saved_logz,
cumulative_log_evidence_error=sampler.saved_logzvar,
cumulative_information=sampler.saved_h,
id=sampler.saved_id,
it=sampler.saved_it,
nc=sampler.saved_nc,
boundidx=sampler.saved_boundidx,
bounditer=sampler.saved_bounditer,
scale=sampler.saved_scale,
sampling_time=sampling_time,
)
current_state.update(
ncall=sampler.ncall,
live_logl=sampler.live_logl,
iteration=sampler.it - 1,
live_u=sampler.live_u,
live_v=sampler.live_v,
nlive=sampler.nlive,
live_bound=sampler.live_bound,
live_it=sampler.live_it,
added_live=sampler.added_live,
)
    if input_args.do_not_save_bounds_in_resume:
        pass
    else:
        current_state["bound"] = sampler.bound
        current_state["nbound"] = sampler.nbound
    # Try to save a set of current posterior samples
    try:
        weights = np.exp(
            current_state["sample_log_weights"]
            - current_state["cumulative_log_evidence"][-1]
        )
        current_state["posterior"] = resample_equal(
            np.array(current_state["physical_samples"]), weights
        )
        current_state["search_parameter_keys"] = search_parameter_keys
    except ValueError:
        logger.debug("Unable to create posterior")
    with open(resume_file, "wb") as file:
        pickle.dump(current_state, file)
    # Try to create a checkpoint traceplot
    if no_plot is False:
        try:
            fig = traceplot(sampler.results, labels=sampling_keys)[0]
            fig.tight_layout()
            fig.savefig(filename_trace)
            plt.close("all")
        except Exception:
            pass
    sampler.kwargs["sampling_time"] = sampling_time
    if dill.pickles(sampler):
        safe_file_dump(sampler, resume_file, dill)
        logger.info("Written checkpoint file {}".format(resume_file))
    else:
        logger.warning(
            "Cannot write pickle resume file! " "Job will not resume if interrupted."
        )
def read_saved_state(resume_file, sampler):
def plot_current_state(sampler, search_parameter_keys, outdir, label):
labels = [label.replace("_", " ") for label in search_parameter_keys]
try:
filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
fig = dyplot.traceplot(sampler.results, labels=labels)[0]
fig.tight_layout()
fig.savefig(filename)
except (
AssertionError,
RuntimeError,
np.linalg.linalg.LinAlgError,
ValueError,
) as e:
logger.warning(e)
logger.warning("Failed to create dynesty state plot at checkpoint")
finally:
plt.close("all")
try:
filename = "{}/{}_checkpoint_run.png".format(outdir, label)
fig, axs = dyplot.runplot(sampler.results)
fig.tight_layout()
plt.savefig(filename)
except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
logger.warning(e)
logger.warning("Failed to create dynesty run plot at checkpoint")
finally:
plt.close("all")
try:
filename = "{}/{}_checkpoint_stats.png".format(outdir, label)
fig, axs = plt.subplots(nrows=3, sharex=True)
for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
ax.plot(getattr(sampler, f"saved_{name}"), color="C0")
ax.set_ylabel(name)
axs[-1].set_xlabel("iteration")
fig.tight_layout()
plt.savefig(filename)
except (RuntimeError, ValueError) as e:
logger.warning(e)
logger.warning("Failed to create dynesty stats plot at checkpoint")
finally:
plt.close("all")
def read_saved_state(resume_file, continuing=True):
"""
Read a saved state of the sampler to disk.
@@ -295,8 +282,6 @@ def read_saved_state(resume_file, sampler):
----------
resume_file: str
The path to the resume file to read
sampler: `dynesty.NestedSampler`
NestedSampler instance to reconstruct from the saved state.
Returns
-------
@@ -310,49 +295,16 @@ def read_saved_state(resume_file, sampler):
if os.path.isfile(resume_file):
logger.info("Reading resume file {}".format(resume_file))
try:
with open(resume_file, "rb") as file:
saved = pickle.load(file)
logger.info("Successfully read resume file {}".format(resume_file))
except EOFError as e:
logger.warning("Resume file reading failed with error {}".format(e))
return False, 0
sampler.saved_u = list(saved["unit_cube_samples"])
sampler.saved_v = list(saved["physical_samples"])
sampler.saved_logl = list(saved["sample_likelihoods"])
sampler.saved_logvol = list(saved["sample_log_volume"])
sampler.saved_logwt = list(saved["sample_log_weights"])
sampler.saved_logz = list(saved["cumulative_log_evidence"])
sampler.saved_logzvar = list(saved["cumulative_log_evidence_error"])
sampler.saved_id = list(saved["id"])
sampler.saved_it = list(saved["it"])
sampler.saved_nc = list(saved["nc"])
sampler.saved_boundidx = list(saved["boundidx"])
sampler.saved_bounditer = list(saved["bounditer"])
sampler.saved_scale = list(saved["scale"])
sampler.saved_h = list(saved["cumulative_information"])
sampler.ncall = saved["ncall"]
sampler.live_logl = list(saved["live_logl"])
sampler.it = saved["iteration"] + 1
sampler.live_u = saved["live_u"]
sampler.live_v = saved["live_v"]
sampler.nlive = saved["nlive"]
sampler.live_bound = saved["live_bound"]
sampler.live_it = saved["live_it"]
sampler.added_live = saved["added_live"]
try:
sampler.bound = saved["bound"]
sampler.nbound = saved["nbound"]
except KeyError:
logger.info("No bounds saved in resume")
sampling_time = datetime.timedelta(
seconds=saved["sampling_time"]
).total_seconds()
with open(resume_file, "rb") as file:
sampler = dill.load(file)
if sampler.added_live and continuing:
sampler._remove_live_points()
sampler.nqueue = -1
sampler.rstate = np.random
sampling_time = sampler.kwargs.pop("sampling_time")
return sampler, sampling_time
else:
logger.debug("No resume file {}".format(resume_file))
logger.info("Resume file {} does not exist.".format(resume_file))
return False, 0
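With the whole sampler pickled, resuming reduces to dill.load plus the bookkeeping visible above: live points added for a final checkpoint are stripped before sampling continues, the proposal queue is invalidated, and attributes that are not restored from the pickle (the random state here; the pool and map in the main block below) are reattached by hand. A self-contained round trip on a toy problem, assuming the dynesty 1.0-era API this commit targets (file name and numbers illustrative):

import dill
import dynesty
import numpy as np


def log_likelihood(x):
    return -0.5 * np.sum(x ** 2)  # stand-in for the real GW likelihood


def prior_transform(u):
    return 10.0 * u - 5.0  # uniform priors on [-5, 5]


sampler = dynesty.NestedSampler(log_likelihood, prior_transform, 2, nlive=100)
sampler.run_nested(maxiter=200, print_progress=False)

# Dump, as write_current_state does: stash the elapsed time on the sampler.
sampler.kwargs["sampling_time"] = 12.3
with open("checkpoint_resume.pickle", "wb") as f:
    dill.dump(sampler, f)

# Load, as read_saved_state does, undoing the final-live-point bookkeeping.
with open("checkpoint_resume.pickle", "rb") as f:
    sampler = dill.load(f)
if sampler.added_live:
    sampler._remove_live_points()
sampler.nqueue = -1         # force the proposal queue to be rebuilt
sampler.rstate = np.random  # module-level random state is reattached, not restored
sampling_time = sampler.kwargs.pop("sampling_time")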
@@ -510,77 +462,106 @@ with MPIPool() as pool:
save_bounds=False,
)
logger.info(
"Initialize NestedSampler with {}".format(
json.dumps(init_sampler_kwargs, indent=1, sort_keys=True)
)
)
ndim = len(sampling_keys)
logger.info(f"Initializing sampling points with pool size={POOL_SIZE}")
live_points = get_initial_points_from_prior(
ndim,
nlive,
prior_transform_function,
log_prior_function,
log_likelihood_function,
pool,
)
sampler = NestedSampler(
log_likelihood_function,
prior_transform_function,
ndim,
pool=pool,
queue_size=POOL_SIZE,
print_func=dynesty.results.print_fn_fallback,
periodic=periodic,
reflective=reflective,
live_points=live_points,
use_pool=dict(
update_bound=True,
propose_point=True,
prior_transform=True,
loglikelihood=True,
),
**init_sampler_kwargs,
)
if os.path.isfile(resume_file) and input_args.clean is False:
resume_sampler, sampling_time = read_saved_state(resume_file, sampler)
if resume_sampler is not False:
sampler = resume_sampler
sampler, sampling_time = read_saved_state(resume_file)
if sampler is False:
logger.info(f"Initializing sampling points with pool size={POOL_SIZE}")
live_points = get_initial_points_from_prior(
ndim,
nlive,
prior_transform_function,
log_prior_function,
log_likelihood_function,
pool,
)
logger.info(
"Initialize NestedSampler with {}".format(
json.dumps(init_sampler_kwargs, indent=1, sort_keys=True)
)
)
sampler = NestedSampler(
log_likelihood_function,
prior_transform_function,
ndim,
pool=pool,
queue_size=POOL_SIZE,
print_func=dynesty.results.print_fn_fallback,
periodic=periodic,
reflective=reflective,
live_points=live_points,
use_pool=dict(
update_bound=True,
propose_point=True,
prior_transform=True,
loglikelihood=True,
),
**init_sampler_kwargs,
)
else:
# Reinstate the pool and map (not saved in the pickle)
sampler.pool = pool
sampler.M = pool.map
logger.info(
f"Starting sampling for job {label}, with pool size={POOL_SIZE} "
f"and n_check_point={input_args.n_check_point}"
f"and check_point_deltaT={input_args.check_point_deltaT}"
)
old_ncall = sampler.ncall
sampler_kwargs = dict(
print_progress=True,
maxcall=input_args.n_check_point,
n_effective=input_args.n_effective,
dlogz=input_args.dlogz,
save_bounds=not input_args.do_not_save_bounds_in_resume,
)
while True:
sampler_kwargs["add_live"] = False
sampler.run_nested(**sampler_kwargs)
if sampler.ncall == old_ncall:
break
old_ncall = sampler.ncall
logger.info("Run criteria: {}".format(json.dumps(sampler_kwargs)))
for it, res in enumerate(sampler.sample(**sampler_kwargs)):
(
worst,
ustar,
vstar,
loglstar,
logvol,
logwt,
logz,
logzvar,
h,
nc,
worst_it,
boundidx,
bounditer,
eff,
delta_logz,
) = res
i = it - 1
dynesty.results.print_fn_fallback(res, i, sampler.ncall, dlogz=input_args.dlogz)
if it == 0 or it % input_args.n_check_point != 0:
continue
sampling_time += (datetime.datetime.now() - t0).total_seconds()
t0 = datetime.datetime.now()
write_checkpoint(
sampler,
resume_file,
sampling_time,
sampling_keys,
no_plot=input_args.no_plot,
)
sampler_kwargs["add_live"] = True
if os.path.isfile(resume_file):
last_checkpoint_s = time.time() - os.path.getmtime(resume_file)
else:
last_checkpoint_s = np.inf
if last_checkpoint_s > input_args.check_point_deltaT:
write_current_state(sampler, resume_file, sampling_time)
if input_args.no_plot is False:
plot_current_state(sampler, sampling_keys, outdir, label)
# Adding the final set of live points.
for it_final, res in enumerate(sampler.add_live_points()):
pass
# Create a final checkpoint and set of plots
write_current_state(sampler, resume_file, sampling_time)
if input_args.no_plot is False:
plot_current_state(sampler, sampling_keys, outdir, label)
sampling_time += (datetime.datetime.now() - t0).total_seconds()
out = sampler.results
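Because the old and the new driver loops interleave in the hunk above, here is the shape of the new loop in isolation, assembled from the added lines: dynesty's low-level sample() generator replaces the repeated run_nested(maxcall=...) bursts, the clock is only consulted every n_check_point iterations, and a checkpoint (plus plots) is written once the resume file is older than check_point_deltaT seconds:

for it, res in enumerate(sampler.sample(**sampler_kwargs)):
    # sampler_kwargs: the n_effective / dlogz / save_bounds run criteria logged above
    dynesty.results.print_fn_fallback(res, it - 1, sampler.ncall, dlogz=input_args.dlogz)
    if it == 0 or it % input_args.n_check_point != 0:
        continue  # probe the wall clock only every n_check_point iterations
    sampling_time += (datetime.datetime.now() - t0).total_seconds()
    t0 = datetime.datetime.now()
    if os.path.isfile(resume_file):
        last_checkpoint_s = time.time() - os.path.getmtime(resume_file)
    else:
        last_checkpoint_s = np.inf
    if last_checkpoint_s > input_args.check_point_deltaT:
        write_current_state(sampler, resume_file, sampling_time)
        if input_args.no_plot is False:
            plot_current_state(sampler, sampling_keys, outdir, label)

# Add the final set of live points, then write one last checkpoint and set of plots.
for it_final, res in enumerate(sampler.add_live_points()):
    pass
write_current_state(sampler, resume_file, sampling_time)
if input_args.no_plot is False:
    plot_current_state(sampler, sampling_keys, outdir, label)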
@@ -126,9 +126,9 @@ def _add_dynesty_settings_to_parser(parser):
)
dynesty_group.add_argument(
"--n-check-point",
default=100000,
default=100,
type=int,
help="Steps to take before checkpoint",
help="Steps to take before attempting checkpoint",
)
return parser
@@ -221,12 +221,6 @@ def _add_ptemcee_settings_to_parser(parser):
"<http://arxiv.org/abs/1501.05823>`_ for details."
),
)
ptemcee_group.add_argument(
"--check-point-deltaT",
default=600,
type=float,
help="Write a checkpoint resume file and diagnostic plots every deltaT [s]",
)
return parser
@@ -258,6 +252,12 @@ def _add_misc_settings_to_parser(parser):
"resume files large (~GB)"
),
)
misc_group.add_argument(
"--check-point-deltaT",
default=600,
type=float,
help="Write a checkpoint resume file and diagnostic plots every deltaT [s]",
)
return parser
@@ -79,3 +79,22 @@ def get_initial_points_from_prior(
l_list = [point[2] for point in initial_points]
return np.array(u_list), np.array(v_list), np.array(l_list)
def safe_file_dump(data, filename, module):
""" Safely dump data to a .pickle file
Parameters
----------
data:
data to dump
filename: str
The file to dump to
module: pickle, dill
The python module to use
"""
temp_filename = filename + ".temp"
with open(temp_filename, "wb") as file:
module.dump(data, file)
os.rename(temp_filename, filename)
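The temp-file-then-rename dance above is what makes checkpointing safe against interruption: the data is fully written to <filename>.temp first, and os.rename onto the final name is atomic on POSIX filesystems (within a single filesystem), so a killed job leaves either the previous resume file or the complete new one, never a truncated pickle. Usage is one line; any module exposing dump() works, as the docstring notes:

import pickle

import dill

safe_file_dump({"iteration": 1000, "logz": -42.0}, "state.pickle", pickle)
safe_file_dump(lambda x: x ** 2, "func.pickle", dill)  # dill also handles lambdas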
@@ -46,7 +46,7 @@ class ParserTest(unittest.TestCase):
mem_per_cpu="1000",
min_eff=10,
minimum_frequency="20",
n_check_point=100000,
n_check_point=100,
n_effective=np.inf,
n_simulation=0,
n_parallel=4,
@@ -118,8 +118,9 @@ class ParserTest(unittest.TestCase):
label=None,
maxmcmc=5000,
min_eff=10,
n_check_point=100000,
n_check_point=100,
n_effective=np.inf,
check_point_deltaT=600,
nact=5,
nlive=1000,
no_plot=False,