diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py index d331f712fcda8c70f4317a24fe0c14bb88b4a9ff..822dc8a6febe7966f0e360c2ba710f39a1e91a05 100644 --- a/tupak/core/sampler.py +++ b/tupak/core/sampler.py @@ -200,7 +200,13 @@ class Sampler(object): return result def _check_if_priors_can_be_sampled(self): - """Check if all priors can be sampled properly. Raises AttributeError if prior can't be sampled.""" + """Check if all priors can be sampled properly. + + Raises + ------ + AttributeError + prior can't be sampled. + """ for key in self.priors: try: self.likelihood.parameters[key] = self.priors[key].sample() @@ -208,13 +214,26 @@ class Sampler(object): logger.warning('Cannot sample from {}, {}'.format(key, e)) def _verify_parameters(self): - """ Sets initial values for likelihood.parameters. Raises TypeError if likelihood can't be evaluated.""" + """ Sets initial values for likelihood.parameters. + + Raises + ------ + TypeError + Likelihood can't be evaluated. + + """ self._check_if_priors_can_be_sampled() try: t1 = datetime.datetime.now() self.likelihood.log_likelihood() - self._sample_log_likelihood_eval = (datetime.datetime.now() - t1).total_seconds() - logger.info("Single likelihood evaluation took {:.3e} s".format(self._sample_log_likelihood_eval)) + self._log_likelihood_eval_time = ( + datetime.datetime.now() - t1).total_seconds() + if self._log_likelihood_eval_time == 0: + self._log_likelihood_eval_time = np.nan + logger.info("Unable to measure single likelihood time") + else: + logger.info("Single likelihood evaluation took {:.3e} s" + .format(self._log_likelihood_eval_time)) except TypeError as e: raise TypeError( "Likelihood evaluation failed with message: \n'{}'\n" @@ -450,21 +469,34 @@ class Dynesty(Sampler): @kwargs.setter def kwargs(self, kwargs): - self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk', resume=True, - walks=self.ndim * 5, verbose=True, check_point_delta_t=60 * 10) + # Set some default values + self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk', + resume=True, walks=self.ndim * 5, verbose=True, + check_point_delta_t=60 * 10, nlive=250) + + # Overwrite default values with user specified values self.__kwargs.update(kwargs) + + # Check if nlive was instead given by another name if 'nlive' not in self.__kwargs: for equiv in ['nlives', 'n_live_points', 'npoint', 'npoints']: if equiv in self.__kwargs: self.__kwargs['nlive'] = self.__kwargs.pop(equiv) - if 'nlive' not in self.__kwargs: - self.__kwargs['nlive'] = 250 + + # Set the update interval if 'update_interval' not in self.__kwargs: self.__kwargs['update_interval'] = int(0.6 * self.__kwargs['nlive']) - if 'n_check_point' not in kwargs: - # checkpointing done by default ~ every 10 minutes + + # Set the checking pointing + # If the log_likelihood_eval_time was not able to be calculated + # then n_check_point is set to None (no checkpointing) + if np.isnan(self._log_likelihood_eval_time): + self.__kwargs['n_check_point'] = None + + # If n_check_point is not already set, set it checkpoint every 10 mins + if 'n_check_point' not in self.__kwargs: n_check_point_raw = (self.__kwargs['check_point_delta_t'] - / self._sample_log_likelihood_eval) + / self._log_likelihood_eval_time) n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw))) self.__kwargs['n_check_point'] = n_check_point_rnd @@ -504,46 +536,19 @@ class Dynesty(Sampler): def _run_external_sampler(self): dynesty = self.external_sampler - if self.kwargs.get('dynamic', False) is False: - nested_sampler = dynesty.NestedSampler( - loglikelihood=self.log_likelihood, - prior_transform=self.prior_transform, - ndim=self.ndim, **self.kwargs) - - if self.kwargs['resume']: - resume = self.read_saved_state(nested_sampler, continuing=True) - if resume: - logger.info('Resuming from previous run.') - - old_ncall = nested_sampler.ncall - maxcall = self.kwargs['n_check_point'] - while True: - maxcall += self.kwargs['n_check_point'] - nested_sampler.run_nested( - dlogz=self.kwargs['dlogz'], - print_progress=self.kwargs['verbose'], - print_func=self._print_func, maxcall=maxcall, - add_live=False) - if nested_sampler.ncall == old_ncall: - break - old_ncall = nested_sampler.ncall - - self.write_current_state(nested_sampler) - - self.read_saved_state(nested_sampler) + nested_sampler = dynesty.NestedSampler( + loglikelihood=self.log_likelihood, + prior_transform=self.prior_transform, + ndim=self.ndim, **self.kwargs) - nested_sampler.run_nested( - dlogz=self.kwargs['dlogz'], - print_progress=self.kwargs['verbose'], - print_func=self._print_func, add_live=True) + if self.kwargs['n_check_point']: + out = self._run_external_sampler_with_checkpointing(nested_sampler) else: - nested_sampler = dynesty.DynamicNestedSampler( - loglikelihood=self.log_likelihood, - prior_transform=self.prior_transform, - ndim=self.ndim, **self.kwargs) - nested_sampler.run_nested(print_progress=self.kwargs['verbose']) - print("") - out = nested_sampler.results + out = self._run_external_sampler_without_checkpointing(nested_sampler) + + # Flushes the output to force a line break + if self.kwargs["verbose"]: + print("") # self.result.sampler_output = out weights = np.exp(out['logwt'] - out['logz'][-1]) @@ -556,9 +561,47 @@ class Dynesty(Sampler): if self.plot: self.generate_trace_plots(out) - self._remove_checkpoint() return self.result + def _run_external_sampler_without_checkpointing(self, nested_sampler): + logger.debug("Running sampler without checkpointing") + nested_sampler.run_nested( + dlogz=self.kwargs['dlogz'], + print_progress=self.kwargs['verbose'], + print_func=self._print_func) + return nested_sampler.results + + def _run_external_sampler_with_checkpointing(self, nested_sampler): + logger.debug("Running sampler with checkpointing") + if self.kwargs['resume']: + resume = self.read_saved_state(nested_sampler, continuing=True) + if resume: + logger.info('Resuming from previous run.') + + old_ncall = nested_sampler.ncall + maxcall = self.kwargs['n_check_point'] + while True: + maxcall += self.kwargs['n_check_point'] + nested_sampler.run_nested( + dlogz=self.kwargs['dlogz'], + print_progress=self.kwargs['verbose'], + print_func=self._print_func, maxcall=maxcall, + add_live=False) + if nested_sampler.ncall == old_ncall: + break + old_ncall = nested_sampler.ncall + + self.write_current_state(nested_sampler) + + self.read_saved_state(nested_sampler) + + nested_sampler.run_nested( + dlogz=self.kwargs['dlogz'], + print_progress=self.kwargs['verbose'], + print_func=self._print_func, add_live=True) + self._remove_checkpoint() + return nested_sampler.results + def _remove_checkpoint(self): """Remove checkpointed state""" if os.path.isfile('{}/{}_resume.h5'.format(self.outdir, self.label)):