Skip to content
Snippets Groups Projects
Commit b43227b0 authored by Colm Talbot's avatar Colm Talbot
Browse files

make resuming work

parent ce5b40aa
No related branches found
No related tags found
1 merge request!74Dynesty checkpointing
Pipeline #
......@@ -7,6 +7,8 @@ import sys
import numpy as np
import matplotlib.pyplot as plt
import datetime
import deepdish
from scipy.misc import logsumexp
from tupak.core.result import Result, read_in_result
from tupak.core.prior import Prior
......@@ -434,8 +436,8 @@ class Dynesty(Sampler):
@kwargs.setter
def kwargs(self, kwargs):
self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk',
walks=self.ndim * 5, verbose=True, n_check_point=5000)
self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk', resume=True,
walks=self.ndim * 5, verbose=True, n_check_point=1000000)
self.__kwargs.update(kwargs)
if 'nlive' not in self.__kwargs:
for equiv in ['nlives', 'n_live_points', 'npoint', 'npoints']:
......@@ -475,10 +477,6 @@ class Dynesty(Sampler):
print_str = "\r {}| {}={:6.3f} +/- {:6.3f} | dlogz: {:6.3f} > {:6.3f}".format(
niter, key, logz, logzerr, delta_logz, dlogz)
with open(self.sample_file, 'a') as sample_file:
sample_file.write('\t'.join([str(param) for param in vstar]) +
'\t{}\t{}\t{}\n'.format(loglstar, logz, logwt))
# Printing.
sys.stderr.write(print_str)
sys.stderr.flush()
......@@ -486,24 +484,21 @@ class Dynesty(Sampler):
def _run_external_sampler(self):
dynesty = self.external_sampler
self.sample_file = '{}/{}.samples'.format(self.outdir, self.label)
if os.path.isfile(self.sample_file):
os.rename(self.sample_file, self.sample_file + '_old')
with open(self.sample_file, 'w') as sample_file:
sample_file.write('\t'.join([key for key in self.priors.keys()]))
sample_file.write('\tlogl\tlogz\tlogwt\n')
if self.kwargs.get('dynamic', False) is False:
nested_sampler = dynesty.NestedSampler(
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
ndim=self.ndim, **self.kwargs)
old_ncall = 0
if self.kwargs['resume']:
resume = self.read_saved_state(nested_sampler)
if resume:
logging.info('Resuming from previous run.')
old_ncall = nested_sampler.ncall
maxcall = self.kwargs['n_check_point']
# maxcall = 5000
while True:
maxcall += self.kwargs['n_check_point']
print(nested_sampler.ncall, 'fhdslahfjkldsahfsa')
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
......@@ -511,8 +506,12 @@ class Dynesty(Sampler):
add_live=False)
if nested_sampler.ncall == old_ncall:
break
print(old_ncall, nested_sampler.ncall)
old_ncall = nested_sampler.ncall
self.write_current_state(nested_sampler)
self.read_saved_state(nested_sampler)
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
......@@ -538,6 +537,111 @@ class Dynesty(Sampler):
self.generate_trace_plots(out)
return self.result
def read_saved_state(self, nested_sampler):
resume_file = '{}/{}_resume.h5'.format(self.outdir, self.label)
if os.path.isfile(resume_file):
saved_state = deepdish.io.load(resume_file)
nested_sampler.saved_u = list(saved_state['unit_cube_samples'])
nested_sampler.saved_v = list(saved_state['physical_samples'])
nested_sampler.saved_logl = list(saved_state['sample_likelihoods'])
nested_sampler.saved_logvol = list(saved_state['sample_log_volume'])
nested_sampler.saved_logwt = list(saved_state['sample_log_weights'])
nested_sampler.saved_logz = list(saved_state['cumulative_log_evidence'])
nested_sampler.saved_logzvar = list(saved_state['cumulative_log_evidence_error'])
nested_sampler.saved_id = list(saved_state['id'])
nested_sampler.saved_it = list(saved_state['it'])
nested_sampler.saved_nc = list(saved_state['nc'])
nested_sampler.saved_boundidx = list(saved_state['boundidx'])
nested_sampler.saved_bounditer = list(saved_state['bounditer'])
nested_sampler.saved_scale = list(saved_state['scale'])
nested_sampler.saved_h = list(saved_state['cumulative_information'])
nested_sampler.ncall = saved_state['ncall']
nested_sampler.live_logl = list(saved_state['live_logl'])
nested_sampler.it = saved_state['iteration'] + 1
nested_sampler.live_u = saved_state['live_u']
nested_sampler.live_v = saved_state['live_v']
nested_sampler.nlive = saved_state['nlive']
nested_sampler.live_bound = saved_state['live_bound']
nested_sampler.live_it = saved_state['live_it']
nested_sampler.added_live = saved_state['added_live']
return True
else:
return False
def write_current_state(self, nested_sampler):
resume_file = '{}/{}_resume.h5'.format(self.outdir, self.label)
if os.path.isfile(resume_file):
saved_state = deepdish.io.load(resume_file)
print(np.shape(saved_state['sample_likelihoods']), np.shape(nested_sampler.saved_logl))
current_state = dict(
unit_cube_samples=np.vstack([saved_state['unit_cube_samples'], nested_sampler.saved_u[1:]]),
physical_samples=np.vstack([saved_state['physical_samples'], nested_sampler.saved_v[1:]]),
sample_likelihoods=np.concatenate([saved_state['sample_likelihoods'], nested_sampler.saved_logl[1:]]),
sample_log_volume=np.concatenate([saved_state['sample_log_volume'], nested_sampler.saved_logvol[1:]]),
sample_log_weights=np.concatenate([saved_state['sample_log_weights'], nested_sampler.saved_logwt[1:]]),
cumulative_log_evidence=np.concatenate([saved_state['cumulative_log_evidence'],
nested_sampler.saved_logz[1:]]),
cumulative_log_evidence_error=np.concatenate([saved_state['cumulative_log_evidence_error'],
nested_sampler.saved_logzvar[1:]]),
cumulative_information=np.concatenate([saved_state['cumulative_information'],
nested_sampler.saved_h[1:]]),
id=np.concatenate([saved_state['id'], nested_sampler.saved_id[1:]]),
it=np.concatenate([saved_state['it'], nested_sampler.saved_it[1:]]),
nc=np.concatenate([saved_state['nc'], nested_sampler.saved_nc[1:]]),
boundidx=np.concatenate([saved_state['boundidx'], nested_sampler.saved_boundidx[1:]]),
bounditer=np.concatenate([saved_state['bounditer'], nested_sampler.saved_bounditer[1:]]),
scale=np.concatenate([saved_state['scale'], nested_sampler.saved_scale[1:]]),
)
else:
current_state = dict(
unit_cube_samples=nested_sampler.saved_u,
physical_samples=nested_sampler.saved_v,
sample_likelihoods=nested_sampler.saved_logl,
sample_log_volume=nested_sampler.saved_logvol,
sample_log_weights=nested_sampler.saved_logwt,
cumulative_log_evidence=nested_sampler.saved_logz,
cumulative_log_evidence_error=nested_sampler.saved_logzvar,
cumulative_information=nested_sampler.saved_h,
id=nested_sampler.saved_id,
it=nested_sampler.saved_it,
nc=nested_sampler.saved_nc,
boundidx=nested_sampler.saved_boundidx,
bounditer=nested_sampler.saved_bounditer,
scale=nested_sampler.saved_scale,
)
current_state.update(
ncall=nested_sampler.ncall, live_logl=nested_sampler.live_logl, iteration=nested_sampler.it - 1,
live_u=nested_sampler.live_u, live_v=nested_sampler.live_v, nlive=nested_sampler.nlive,
live_bound=nested_sampler.live_bound, live_it=nested_sampler.live_it, added_live=nested_sampler.added_live
)
deepdish.io.save(resume_file, current_state)
nested_sampler.saved_id = [nested_sampler.saved_id[-1]]
nested_sampler.saved_u = [nested_sampler.saved_u[-1]]
nested_sampler.saved_v = [nested_sampler.saved_v[-1]]
nested_sampler.saved_logl = [nested_sampler.saved_logl[-1]]
nested_sampler.saved_logvol = [nested_sampler.saved_logvol[-1]]
nested_sampler.saved_logwt = [nested_sampler.saved_logwt[-1]]
nested_sampler.saved_logz = [nested_sampler.saved_logz[-1]]
nested_sampler.saved_logzvar = [nested_sampler.saved_logzvar[-1]]
nested_sampler.saved_h = [nested_sampler.saved_h[-1]]
nested_sampler.saved_nc = [nested_sampler.saved_nc[-1]]
nested_sampler.saved_boundidx = [nested_sampler.saved_boundidx[-1]]
nested_sampler.saved_it = [nested_sampler.saved_it[-1]]
nested_sampler.saved_bounditer = [nested_sampler.saved_bounditer[-1]]
nested_sampler.saved_scale = [nested_sampler.saved_scale[-1]]
def generate_trace_plots(self, dynesty_results):
filename = '{}/{}_trace.png'.format(self.outdir, self.label)
logging.debug("Writing trace plot to {}".format(filename))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment