diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py index 30f938aa2971d94c84486dbd625d298c52661044..3c21baceebf063845365e795165150274b4b84d1 100644 --- a/tupak/core/sampler.py +++ b/tupak/core/sampler.py @@ -200,7 +200,13 @@ class Sampler(object): return result def _check_if_priors_can_be_sampled(self): - """Check if all priors can be sampled properly. Raises AttributeError if prior can't be sampled.""" + """Check if all priors can be sampled properly. + + Raises + ------ + AttributeError + prior can't be sampled. + """ for key in self.priors: try: self.likelihood.parameters[key] = self.priors[key].sample() @@ -208,13 +214,26 @@ class Sampler(object): logger.warning('Cannot sample from {}, {}'.format(key, e)) def _verify_parameters(self): - """ Sets initial values for likelihood.parameters. Raises TypeError if likelihood can't be evaluated.""" + """ Sets initial values for likelihood.parameters. + + Raises + ------ + TypeError + Likelihood can't be evaluated. + + """ self._check_if_priors_can_be_sampled() try: t1 = datetime.datetime.now() self.likelihood.log_likelihood() - self._sample_log_likelihood_eval = (datetime.datetime.now() - t1).total_seconds() - logger.info("Single likelihood evaluation took {:.3e} s".format(self._sample_log_likelihood_eval)) + self._log_likelihood_eval_time = ( + datetime.datetime.now() - t1).total_seconds() + if self._log_likelihood_eval_time == 0: + self._log_likelihood_eval_time = np.nan + logger.info("Unable to measure single likelihood time") + else: + logger.info("Single likelihood evaluation took {:.3e} s" + .format(self._log_likelihood_eval_time)) except TypeError as e: raise TypeError( "Likelihood evaluation failed with message: \n'{}'\n" @@ -450,21 +469,34 @@ class Dynesty(Sampler): @kwargs.setter def kwargs(self, kwargs): - self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk', resume=True, - walks=self.ndim * 5, verbose=True, check_point_delta_t=60 * 10) + # Set some default values + self.__kwargs = dict(dlogz=0.1, bound='multi', sample='rwalk', + resume=True, walks=self.ndim * 5, verbose=True, + check_point_delta_t=60 * 10, nlive=250) + + # Overwrite default values with user specified values self.__kwargs.update(kwargs) + + # Check if nlive was instead given by another name if 'nlive' not in self.__kwargs: for equiv in ['nlives', 'n_live_points', 'npoint', 'npoints']: if equiv in self.__kwargs: self.__kwargs['nlive'] = self.__kwargs.pop(equiv) - if 'nlive' not in self.__kwargs: - self.__kwargs['nlive'] = 250 + + # Set the update interval if 'update_interval' not in self.__kwargs: self.__kwargs['update_interval'] = int(0.6 * self.__kwargs['nlive']) - if 'n_check_point' not in kwargs: - # checkpointing done by default ~ every 10 minutes + + # Set the checking pointing + # If the log_likelihood_eval_time was not able to be calculated + # then n_check_point is set to None (no checkpointing) + if np.isnan(self._log_likelihood_eval_time): + self.__kwargs['n_check_point'] = None + + # If n_check_point is not already set, set it checkpoint every 10 mins + if 'n_check_point' not in self.__kwargs: n_check_point_raw = (self.__kwargs['check_point_delta_t'] - / self._sample_log_likelihood_eval) + / self._log_likelihood_eval_time) n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw))) self.__kwargs['n_check_point'] = n_check_point_rnd @@ -509,6 +541,35 @@ class Dynesty(Sampler): prior_transform=self.prior_transform, ndim=self.ndim, **self.kwargs) + if self.kwargs['n_check_point']: + out = self._run_external_sampler_with_checkpointing(nested_sampler) + else: + out = self._run_external_sampler_without_checkpointing(nested_sampler) + + # self.result.sampler_output = out + weights = np.exp(out['logwt'] - out['logz'][-1]) + self.result.samples = dynesty.utils.resample_equal( + out.samples, weights) + self.result.log_likelihood_evaluations = out.logl + self.result.log_evidence = out.logz[-1] + self.result.log_evidence_err = out.logzerr[-1] + + if self.plot: + self.generate_trace_plots(out) + + return self.result + + def _run_external_sampler_without_checkpointing(self, nested_sampler): + logger.debug("Running sampler without checkpointing") + nested_sampler.run_nested( + dlogz=self.kwargs['dlogz'], + print_progress=self.kwargs['verbose'], + print_func=self._print_func) + print("") + return nested_sampler.results + + def _run_external_sampler_with_checkpointing(self, nested_sampler): + logger.debug("Running sampler with checkpointing") if self.kwargs['resume']: resume = self.read_saved_state(nested_sampler, continuing=True) if resume: @@ -537,21 +598,8 @@ class Dynesty(Sampler): print_func=self._print_func, add_live=True) print("") - out = nested_sampler.results - - # self.result.sampler_output = out - weights = np.exp(out['logwt'] - out['logz'][-1]) - self.result.samples = dynesty.utils.resample_equal( - out.samples, weights) - self.result.log_likelihood_evaluations = out.logl - self.result.log_evidence = out.logz[-1] - self.result.log_evidence_err = out.logzerr[-1] - - if self.plot: - self.generate_trace_plots(out) - self._remove_checkpoint() - return self.result + return nested_sampler.results def _remove_checkpoint(self): """Remove checkpointed state"""