Skip to content
Snippets Groups Projects
Commit 79a1888b authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Collection of clean ups of dynesty

1. Explicit with/without check pointing functions
2. Closes #150 if timing is not succesful it defaults to no
checkpointing
3. General clean up of some doc strings
4. General clean up and commenting of kwargs setup
parent dda8ebf6
No related branches found
No related tags found
1 merge request!118Clean up dynesty
......@@ -200,7 +200,13 @@ class Sampler(object):
return result
def _check_if_priors_can_be_sampled(self):
"""Check if all priors can be sampled properly. Raises AttributeError if prior can't be sampled."""
"""Check if all priors can be sampled properly.
Raises
------
AttributeError
prior can't be sampled.
"""
for key in self.priors:
try:
self.likelihood.parameters[key] = self.priors[key].sample()
......@@ -208,13 +214,26 @@ class Sampler(object):
logger.warning('Cannot sample from {}, {}'.format(key, e))
def _verify_parameters(self):
""" Sets initial values for likelihood.parameters. Raises TypeError if likelihood can't be evaluated."""
""" Sets initial values for likelihood.parameters.
Raises
------
TypeError
Likelihood can't be evaluated.
"""
self._check_if_priors_can_be_sampled()
try:
t1 = datetime.datetime.now()
self.likelihood.log_likelihood()
self._sample_log_likelihood_eval = (datetime.datetime.now() - t1).total_seconds()
logger.info("Single likelihood evaluation took {:.3e} s".format(self._sample_log_likelihood_eval))
self._log_likelihood_eval_time = (
datetime.datetime.now() - t1).total_seconds()
if self._log_likelihood_eval_time == 0:
self._log_likelihood_eval_time = np.nan
logger.info("Unable to measure single likelihood time")
else:
logger.info("Single likelihood evaluation took {:.3e} s"
.format(self._log_likelihood_eval_time))
except TypeError as e:
raise TypeError(
"Likelihood evaluation failed with message: \n'{}'\n"
......@@ -450,21 +469,34 @@ class Dynesty(Sampler):
@kwargs.setter
def kwargs(self, kwargs):
self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk', resume=True,
walks=self.ndim * 5, verbose=True, check_point_delta_t=60 * 10)
# Set some default values
self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk',
resume=True, walks=self.ndim * 5, verbose=True,
check_point_delta_t=60 * 10, nlive=250)
# Overwrite default values with user specified values
self.__kwargs.update(kwargs)
# Check if nlive was instead given by another name
if 'nlive' not in self.__kwargs:
for equiv in ['nlives', 'n_live_points', 'npoint', 'npoints']:
if equiv in self.__kwargs:
self.__kwargs['nlive'] = self.__kwargs.pop(equiv)
if 'nlive' not in self.__kwargs:
self.__kwargs['nlive'] = 250
# Set the update interval
if 'update_interval' not in self.__kwargs:
self.__kwargs['update_interval'] = int(0.6 * self.__kwargs['nlive'])
if 'n_check_point' not in kwargs:
# checkpointing done by default ~ every 10 minutes
# Set the checking pointing
# If the log_likelihood_eval_time was not able to be calculated
# then n_check_point is set to None (no checkpointing)
if np.isnan(self._log_likelihood_eval_time):
self.__kwargs['n_check_point'] = None
# If n_check_point is not already set, set it checkpoint every 10 mins
if 'n_check_point' not in self.__kwargs:
n_check_point_raw = (self.__kwargs['check_point_delta_t']
/ self._sample_log_likelihood_eval)
/ self._log_likelihood_eval_time)
n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
self.__kwargs['n_check_point'] = n_check_point_rnd
......@@ -509,6 +541,35 @@ class Dynesty(Sampler):
prior_transform=self.prior_transform,
ndim=self.ndim, **self.kwargs)
if self.kwargs['n_check_point']:
out = self._run_external_sampler_with_checkpointing(nested_sampler)
else:
out = self._run_external_sampler_without_checkpointing(nested_sampler)
# self.result.sampler_output = out
weights = np.exp(out['logwt'] - out['logz'][-1])
self.result.samples = dynesty.utils.resample_equal(
out.samples, weights)
self.result.log_likelihood_evaluations = out.logl
self.result.log_evidence = out.logz[-1]
self.result.log_evidence_err = out.logzerr[-1]
if self.plot:
self.generate_trace_plots(out)
return self.result
def _run_external_sampler_without_checkpointing(self, nested_sampler):
logger.debug("Running sampler without checkpointing")
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
print_func=self._print_func)
print("")
return nested_sampler.results
def _run_external_sampler_with_checkpointing(self, nested_sampler):
logger.debug("Running sampler with checkpointing")
if self.kwargs['resume']:
resume = self.read_saved_state(nested_sampler, continuing=True)
if resume:
......@@ -537,21 +598,8 @@ class Dynesty(Sampler):
print_func=self._print_func, add_live=True)
print("")
out = nested_sampler.results
# self.result.sampler_output = out
weights = np.exp(out['logwt'] - out['logz'][-1])
self.result.samples = dynesty.utils.resample_equal(
out.samples, weights)
self.result.log_likelihood_evaluations = out.logl
self.result.log_evidence = out.logz[-1]
self.result.log_evidence_err = out.logzerr[-1]
if self.plot:
self.generate_trace_plots(out)
self._remove_checkpoint()
return self.result
return nested_sampler.results
def _remove_checkpoint(self):
"""Remove checkpointed state"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment