Skip to content
Snippets Groups Projects
Commit 1297ea21 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'master' of git.ligo.org:Monash/tupak into clean-up-of-detectors

parents 11bd4249 88b125af
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,8 @@ import sys
import numpy as np
import matplotlib.pyplot as plt
import datetime
import deepdish
from scipy.misc import logsumexp
from tupak.core.result import Result, read_in_result
from tupak.core.prior import Prior
......@@ -211,9 +213,8 @@ class Sampler(object):
try:
t1 = datetime.datetime.now()
self.likelihood.log_likelihood()
logging.info(
"Single likelihood evaluation took {:.3e} s"
.format((datetime.datetime.now() - t1).total_seconds()))
self._sample_log_likelihood_eval = (datetime.datetime.now() - t1).total_seconds()
logging.info("Single likelihood evaluation took {:.3e} s".format(self._sample_log_likelihood_eval))
except TypeError as e:
raise TypeError(
"Likelihood evaluation failed with message: \n'{}'\n"
......@@ -434,8 +435,8 @@ class Dynesty(Sampler):
@kwargs.setter
def kwargs(self, kwargs):
self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk',
walks=self.ndim * 5, verbose=True)
self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk', resume=True,
walks=self.ndim * 5, verbose=True, check_point_delta_t=60*10)
self.__kwargs.update(kwargs)
if 'nlive' not in self.__kwargs:
for equiv in ['nlives', 'n_live_points', 'npoint', 'npoints']:
......@@ -445,6 +446,10 @@ class Dynesty(Sampler):
self.__kwargs['nlive'] = 250
if 'update_interval' not in self.__kwargs:
self.__kwargs['update_interval'] = int(0.6 * self.__kwargs['nlive'])
if 'n_check_point' not in kwargs:
# checkpointing done by default ~ every 10 minutes
self.__kwargs['n_check_point'] = int(self.__kwargs['check_point_delta_t']
// self._sample_log_likelihood_eval)
def _print_func(self, results, niter, ncall, dlogz, *args, **kwargs):
""" Replacing status update for dynesty.result.print_func """
......@@ -487,10 +492,33 @@ class Dynesty(Sampler):
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
ndim=self.ndim, **self.kwargs)
if self.kwargs['resume']:
resume = self.read_saved_state(nested_sampler, continuing=True)
if resume:
logging.info('Resuming from previous run.')
old_ncall = nested_sampler.ncall
maxcall = self.kwargs['n_check_point']
while True:
maxcall += self.kwargs['n_check_point']
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
print_func=self._print_func, maxcall=maxcall,
add_live=False)
if nested_sampler.ncall == old_ncall:
break
old_ncall = nested_sampler.ncall
self.write_current_state(nested_sampler)
self.read_saved_state(nested_sampler)
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
print_func=self._print_func)
print_func=self._print_func, add_live=True)
else:
nested_sampler = dynesty.DynamicNestedSampler(
loglikelihood=self.log_likelihood,
......@@ -510,8 +538,151 @@ class Dynesty(Sampler):
if self.plot:
self.generate_trace_plots(out)
self._remove_checkpoint()
return self.result
def _remove_checkpoint(self):
"""Remove checkpointed state"""
if os.path.isfile('{}/{}_resume.h5'.format(self.outdir, self.label)):
os.remove('{}/{}_resume.h5'.format(self.outdir, self.label))
def read_saved_state(self, nested_sampler, continuing=False):
"""
Read a saved state of the sampler to disk.
The required information to reconstruct the state of the run is read from an hdf5 file.
This currently adds the whole chain to the sampler.
We then remove the old checkpoint and write all unnecessary items back to disk.
FIXME: Load only the necessary quantities, rather than read/write?
Parameters
----------
nested_sampler: `dynesty.NestedSampler`
NestedSampler instance to reconstruct from the saved state.
continuing: bool
Whether the run is continuing or terminating, if True, the loaded state is mostly
written back to disk.
"""
resume_file = '{}/{}_resume.h5'.format(self.outdir, self.label)
if os.path.isfile(resume_file):
saved_state = deepdish.io.load(resume_file)
nested_sampler.saved_u = list(saved_state['unit_cube_samples'])
nested_sampler.saved_v = list(saved_state['physical_samples'])
nested_sampler.saved_logl = list(saved_state['sample_likelihoods'])
nested_sampler.saved_logvol = list(saved_state['sample_log_volume'])
nested_sampler.saved_logwt = list(saved_state['sample_log_weights'])
nested_sampler.saved_logz = list(saved_state['cumulative_log_evidence'])
nested_sampler.saved_logzvar = list(saved_state['cumulative_log_evidence_error'])
nested_sampler.saved_id = list(saved_state['id'])
nested_sampler.saved_it = list(saved_state['it'])
nested_sampler.saved_nc = list(saved_state['nc'])
nested_sampler.saved_boundidx = list(saved_state['boundidx'])
nested_sampler.saved_bounditer = list(saved_state['bounditer'])
nested_sampler.saved_scale = list(saved_state['scale'])
nested_sampler.saved_h = list(saved_state['cumulative_information'])
nested_sampler.ncall = saved_state['ncall']
nested_sampler.live_logl = list(saved_state['live_logl'])
nested_sampler.it = saved_state['iteration'] + 1
nested_sampler.live_u = saved_state['live_u']
nested_sampler.live_v = saved_state['live_v']
nested_sampler.nlive = saved_state['nlive']
nested_sampler.live_bound = saved_state['live_bound']
nested_sampler.live_it = saved_state['live_it']
nested_sampler.added_live = saved_state['added_live']
self._remove_checkpoint()
if continuing:
self.write_current_state(nested_sampler)
return True
else:
return False
def write_current_state(self, nested_sampler):
"""
Write the current state of the sampler to disk.
The required information to reconstruct the state of the run are written to an hdf5 file.
All but the most recent removed live point in the chain are removed from the sampler to reduce memory usage.
This means it is necessary to not append the first live point to the file if updating a previous checkpoint.
Parameters
----------
nested_sampler: `dynesty.NestedSampler`
NestedSampler to write to disk.
"""
resume_file = '{}/{}_resume.h5'.format(self.outdir, self.label)
if os.path.isfile(resume_file):
saved_state = deepdish.io.load(resume_file)
current_state = dict(
unit_cube_samples=np.vstack([saved_state['unit_cube_samples'], nested_sampler.saved_u[1:]]),
physical_samples=np.vstack([saved_state['physical_samples'], nested_sampler.saved_v[1:]]),
sample_likelihoods=np.concatenate([saved_state['sample_likelihoods'], nested_sampler.saved_logl[1:]]),
sample_log_volume=np.concatenate([saved_state['sample_log_volume'], nested_sampler.saved_logvol[1:]]),
sample_log_weights=np.concatenate([saved_state['sample_log_weights'], nested_sampler.saved_logwt[1:]]),
cumulative_log_evidence=np.concatenate([saved_state['cumulative_log_evidence'],
nested_sampler.saved_logz[1:]]),
cumulative_log_evidence_error=np.concatenate([saved_state['cumulative_log_evidence_error'],
nested_sampler.saved_logzvar[1:]]),
cumulative_information=np.concatenate([saved_state['cumulative_information'],
nested_sampler.saved_h[1:]]),
id=np.concatenate([saved_state['id'], nested_sampler.saved_id[1:]]),
it=np.concatenate([saved_state['it'], nested_sampler.saved_it[1:]]),
nc=np.concatenate([saved_state['nc'], nested_sampler.saved_nc[1:]]),
boundidx=np.concatenate([saved_state['boundidx'], nested_sampler.saved_boundidx[1:]]),
bounditer=np.concatenate([saved_state['bounditer'], nested_sampler.saved_bounditer[1:]]),
scale=np.concatenate([saved_state['scale'], nested_sampler.saved_scale[1:]]),
)
else:
current_state = dict(
unit_cube_samples=nested_sampler.saved_u,
physical_samples=nested_sampler.saved_v,
sample_likelihoods=nested_sampler.saved_logl,
sample_log_volume=nested_sampler.saved_logvol,
sample_log_weights=nested_sampler.saved_logwt,
cumulative_log_evidence=nested_sampler.saved_logz,
cumulative_log_evidence_error=nested_sampler.saved_logzvar,
cumulative_information=nested_sampler.saved_h,
id=nested_sampler.saved_id,
it=nested_sampler.saved_it,
nc=nested_sampler.saved_nc,
boundidx=nested_sampler.saved_boundidx,
bounditer=nested_sampler.saved_bounditer,
scale=nested_sampler.saved_scale,
)
current_state.update(
ncall=nested_sampler.ncall, live_logl=nested_sampler.live_logl, iteration=nested_sampler.it - 1,
live_u=nested_sampler.live_u, live_v=nested_sampler.live_v, nlive=nested_sampler.nlive,
live_bound=nested_sampler.live_bound, live_it=nested_sampler.live_it, added_live=nested_sampler.added_live
)
weights = np.exp(current_state['sample_log_weights'] - current_state['cumulative_log_evidence'][-1])
current_state['posterior'] = self.external_sampler.utils.resample_equal(
np.array(current_state['physical_samples']), weights)
deepdish.io.save(resume_file, current_state)
nested_sampler.saved_id = [nested_sampler.saved_id[-1]]
nested_sampler.saved_u = [nested_sampler.saved_u[-1]]
nested_sampler.saved_v = [nested_sampler.saved_v[-1]]
nested_sampler.saved_logl = [nested_sampler.saved_logl[-1]]
nested_sampler.saved_logvol = [nested_sampler.saved_logvol[-1]]
nested_sampler.saved_logwt = [nested_sampler.saved_logwt[-1]]
nested_sampler.saved_logz = [nested_sampler.saved_logz[-1]]
nested_sampler.saved_logzvar = [nested_sampler.saved_logzvar[-1]]
nested_sampler.saved_h = [nested_sampler.saved_h[-1]]
nested_sampler.saved_nc = [nested_sampler.saved_nc[-1]]
nested_sampler.saved_boundidx = [nested_sampler.saved_boundidx[-1]]
nested_sampler.saved_it = [nested_sampler.saved_it[-1]]
nested_sampler.saved_bounditer = [nested_sampler.saved_bounditer[-1]]
nested_sampler.saved_scale = [nested_sampler.saved_scale[-1]]
def generate_trace_plots(self, dynesty_results):
filename = '{}/{}_trace.png'.format(self.outdir, self.label)
logging.debug("Writing trace plot to {}".format(filename))
......
......@@ -67,6 +67,7 @@ class GravitationalWaveTransient(likelihood.Likelihood):
self.prior = prior
if self.distance_marginalization:
self.check_prior_is_set()
self.distance_array = np.array([])
self.delta_distance = 0
self.distance_prior_array = np.array([])
......@@ -74,10 +75,18 @@ class GravitationalWaveTransient(likelihood.Likelihood):
prior['luminosity_distance'] = 1 # this means the prior is a delta function fixed at the RHS value
if self.phase_marginalization:
self.check_prior_is_set()
self.bessel_function_interped = None
self.setup_phase_marginalization()
prior['phase'] = 0
if self.time_marginalization:
self.check_prior_is_set()
def check_prior_is_set(self):
if self.prior is None:
raise ValueError("You can't use a marginalized likelihood without specifying a prior")
@property
def prior(self):
return self.__prior
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment