Skip to content
Snippets Groups Projects
Commit def7f5e3 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Improve resume behaviour from checkpoint

parent 49581fb5
No related branches found
No related tags found
1 merge request!750Improve ptemcee
Pipeline #113248 failed
......@@ -97,22 +97,36 @@ class Ptemcee(MCMCSampler):
if os.path.isfile(self.resume_file) and self.resume is True:
logger.info("Resume data {} found".format(self.resume_file))
with open(self.resume_file, "rb") as file:
import IPython; IPython.embed()
data = dill.load(file)
self.sampler = data["sampler"]
self.sampler.pool = self.pool
self.sampler.threads = self.threads
self.tau_list = data["tau_list"]
self.tau_list_n = data["tau_list_n"]
self.time_per_check = data["time_per_check"]
self.sampler.pool = self.pool
self.sampler.threads = self.threads
pos0 = None
logger.info("Resuming from previous run with time={}".format(self.sampler.time))
else:
# Initialize the PTSampler
self.sampler = ptemcee.Sampler(
dim=self.ndim, logl=do_nothing_function, logp=do_nothing_function,
pool=self.pool, threads=self.threads, **self.sampler_init_kwargs)
# Overwrite the _likeprior to improve performance with threads > 1
self.sampler._likeprior = LikePriorEvaluator(
self.search_parameter_keys, use_ratio=self.use_ratio)
# Set up empty lists
self.tau_list = []
self.tau_list_n = []
self.time_per_check = []
# Initialize the walker postitions
pos0 = self.get_pos0_from_prior()
return self.sampler, pos0
......@@ -138,9 +152,6 @@ class Ptemcee(MCMCSampler):
def run_sampler_internal(self):
import emcee
sampler, pos0 = self.setup_sampler()
self.time_per_check = []
self.tau_list = []
self.tau_list_n = []
t0 = datetime.datetime.now()
logger.info("Starting to sample")
......@@ -230,7 +241,7 @@ class Ptemcee(MCMCSampler):
last_checkpoint_s = np.sum(self.time_per_check)
if last_checkpoint_s > self.check_point_deltaT:
self.write_current_state()
self.write_current_state(plot=self.plot)
# Check if we reached the end without converging
if sampler.time == self.sampler_function_kwargs["iterations"]:
......@@ -241,7 +252,7 @@ class Ptemcee(MCMCSampler):
)
# Run a final checkpoint to update the plots and samples
self.write_current_state()
self.write_current_state(plot=self.plot)
# Get 0-likelihood samples and store in the result
samples = sampler.chain[0, :, :, :] # nwalkers, nsteps, ndim
......@@ -268,16 +279,27 @@ class Ptemcee(MCMCSampler):
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
if self.pool:
if getattr(self, 'pool', None):
self.write_current_state(plot=False)
logger.warning("Closing pool")
self.pool.close()
self.write_current_state()
sys.exit(77)
def write_current_state(self):
def write_current_state(self, plot=True):
checkpoint(self.outdir, self.label, self.nsamples_effective,
self.sampler, self.nburn, self.thin,
self.search_parameter_keys, self.resume_file, self.tau_list,
self.tau_list_n)
self.tau_list_n, self.time_per_check)
if plot:
# Generate the walkers plot diagnostic
plot_walkers(
self.sampler.chain[0, :, : self.sampler.time, :],
self.nburn, self.search_parameter_keys, self.outdir, self.label
)
# Generate the tau plot diagnostic
plot_tau(self.tau_list_n, self.tau_list, self.outdir, self.label)
def print_progress(
......@@ -340,7 +362,8 @@ def print_progress(
def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
search_parameter_keys, resume_file, tau_list, tau_list_n):
search_parameter_keys, resume_file, tau_list, tau_list_n,
time_per_check):
logger.info("Writing checkpoint and diagnostics")
ndim = sampler.dim
......@@ -360,20 +383,17 @@ def checkpoint(outdir, label, nsamples_effective, sampler, nburn, thin,
sampler_copy._logposterior = sampler._logposterior[:, :, : sampler.time]
sampler_copy._loglikelihood = sampler._loglikelihood[:, :, : sampler.time]
sampler_copy._beta_history = sampler._beta_history[:, : sampler.time]
data = dict(sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n)
data = dict(
sampler=sampler_copy, tau_list=tau_list, tau_list_n=tau_list_n,
time_per_check=time_per_check)
with open(resume_file, "wb") as file:
dill.dump(data, file, protocol=4)
del data, sampler_copy
# Generate the walkers plot diagnostic
plot_walkers(
sampler.chain[0, :, : sampler.time, :], nburn, search_parameter_keys, outdir, label
)
# Generate the tau plot diagnostic
plot_tau(tau_list_n, tau_list, outdir, label)
logger.info("Finished writing checkpoint and diagnostics")
logger.info("Finished writing checkpoint")
def plot_walkers(walkers, nburn, parameter_labels, outdir, label):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment