Skip to content
Snippets Groups Projects

Clean up dynesty

Merged Gregory Ashton requested to merge clean-up-dynesty into master
1 file
+ 93
50
Compare changes
  • Side-by-side
  • Inline
+ 93
50
@@ -200,7 +200,13 @@ class Sampler(object):
return result
def _check_if_priors_can_be_sampled(self):
"""Check if all priors can be sampled properly. Raises AttributeError if prior can't be sampled."""
"""Check if all priors can be sampled properly.
Raises
------
AttributeError
prior can't be sampled.
"""
for key in self.priors:
try:
self.likelihood.parameters[key] = self.priors[key].sample()
@@ -208,13 +214,26 @@ class Sampler(object):
logger.warning('Cannot sample from {}, {}'.format(key, e))
def _verify_parameters(self):
""" Sets initial values for likelihood.parameters. Raises TypeError if likelihood can't be evaluated."""
""" Sets initial values for likelihood.parameters.
Raises
------
TypeError
Likelihood can't be evaluated.
"""
self._check_if_priors_can_be_sampled()
try:
t1 = datetime.datetime.now()
self.likelihood.log_likelihood()
self._sample_log_likelihood_eval = (datetime.datetime.now() - t1).total_seconds()
logger.info("Single likelihood evaluation took {:.3e} s".format(self._sample_log_likelihood_eval))
self._log_likelihood_eval_time = (
datetime.datetime.now() - t1).total_seconds()
if self._log_likelihood_eval_time == 0:
self._log_likelihood_eval_time = np.nan
logger.info("Unable to measure single likelihood time")
else:
logger.info("Single likelihood evaluation took {:.3e} s"
.format(self._log_likelihood_eval_time))
except TypeError as e:
raise TypeError(
"Likelihood evaluation failed with message: \n'{}'\n"
@@ -450,21 +469,34 @@ class Dynesty(Sampler):
@kwargs.setter
def kwargs(self, kwargs):
self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk', resume=True,
walks=self.ndim * 5, verbose=True, check_point_delta_t=60 * 10)
# Set some default values
self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk',
resume=True, walks=self.ndim * 5, verbose=True,
check_point_delta_t=60 * 10, nlive=250)
# Overwrite default values with user specified values
self.__kwargs.update(kwargs)
# Check if nlive was instead given by another name
if 'nlive' not in self.__kwargs:
for equiv in ['nlives', 'n_live_points', 'npoint', 'npoints']:
if equiv in self.__kwargs:
self.__kwargs['nlive'] = self.__kwargs.pop(equiv)
if 'nlive' not in self.__kwargs:
self.__kwargs['nlive'] = 250
# Set the update interval
if 'update_interval' not in self.__kwargs:
self.__kwargs['update_interval'] = int(0.6 * self.__kwargs['nlive'])
if 'n_check_point' not in kwargs:
# checkpointing done by default ~ every 10 minutes
# Set the checking pointing
# If the log_likelihood_eval_time was not able to be calculated
# then n_check_point is set to None (no checkpointing)
if np.isnan(self._log_likelihood_eval_time):
self.__kwargs['n_check_point'] = None
# If n_check_point is not already set, set it checkpoint every 10 mins
if 'n_check_point' not in self.__kwargs:
n_check_point_raw = (self.__kwargs['check_point_delta_t']
/ self._sample_log_likelihood_eval)
/ self._log_likelihood_eval_time)
n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
self.__kwargs['n_check_point'] = n_check_point_rnd
@@ -504,46 +536,19 @@ class Dynesty(Sampler):
def _run_external_sampler(self):
dynesty = self.external_sampler
if self.kwargs.get('dynamic', False) is False:
nested_sampler = dynesty.NestedSampler(
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
ndim=self.ndim, **self.kwargs)
if self.kwargs['resume']:
resume = self.read_saved_state(nested_sampler, continuing=True)
if resume:
logger.info('Resuming from previous run.')
old_ncall = nested_sampler.ncall
maxcall = self.kwargs['n_check_point']
while True:
maxcall += self.kwargs['n_check_point']
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
print_func=self._print_func, maxcall=maxcall,
add_live=False)
if nested_sampler.ncall == old_ncall:
break
old_ncall = nested_sampler.ncall
self.write_current_state(nested_sampler)
self.read_saved_state(nested_sampler)
nested_sampler = dynesty.NestedSampler(
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
ndim=self.ndim, **self.kwargs)
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
print_func=self._print_func, add_live=True)
if self.kwargs['n_check_point']:
out = self._run_external_sampler_with_checkpointing(nested_sampler)
else:
nested_sampler = dynesty.DynamicNestedSampler(
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
ndim=self.ndim, **self.kwargs)
nested_sampler.run_nested(print_progress=self.kwargs['verbose'])
print("")
out = nested_sampler.results
out = self._run_external_sampler_without_checkpointing(nested_sampler)
# Flushes the output to force a line break
if self.kwargs["verbose"]:
print("")
# self.result.sampler_output = out
weights = np.exp(out['logwt'] - out['logz'][-1])
@@ -556,9 +561,47 @@ class Dynesty(Sampler):
if self.plot:
self.generate_trace_plots(out)
self._remove_checkpoint()
return self.result
def _run_external_sampler_without_checkpointing(self, nested_sampler):
logger.debug("Running sampler without checkpointing")
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
print_func=self._print_func)
return nested_sampler.results
def _run_external_sampler_with_checkpointing(self, nested_sampler):
logger.debug("Running sampler with checkpointing")
if self.kwargs['resume']:
resume = self.read_saved_state(nested_sampler, continuing=True)
if resume:
logger.info('Resuming from previous run.')
old_ncall = nested_sampler.ncall
maxcall = self.kwargs['n_check_point']
while True:
maxcall += self.kwargs['n_check_point']
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
print_func=self._print_func, maxcall=maxcall,
add_live=False)
if nested_sampler.ncall == old_ncall:
break
old_ncall = nested_sampler.ncall
self.write_current_state(nested_sampler)
self.read_saved_state(nested_sampler)
nested_sampler.run_nested(
dlogz=self.kwargs['dlogz'],
print_progress=self.kwargs['verbose'],
print_func=self._print_func, add_live=True)
self._remove_checkpoint()
return nested_sampler.results
def _remove_checkpoint(self):
"""Remove checkpointed state"""
if os.path.isfile('{}/{}_resume.h5'.format(self.outdir, self.label)):
Loading