diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py
index eda23e4afe8010bebc768f72abb8a5c7a527da5f..d331f712fcda8c70f4317a24fe0c14bb88b4a9ff 100644
--- a/tupak/core/sampler.py
+++ b/tupak/core/sampler.py
@@ -773,7 +773,9 @@ class Emcee(Sampler):
     def _run_external_sampler(self):
         self.nwalkers = self.kwargs.get('nwalkers', 100)
         self.nsteps = self.kwargs.get('nsteps', 100)
-        self.nburn = self.kwargs.get('nburn', 50)
+        self.nburn = self.kwargs.get('nburn', None)
+        self.burn_in_fraction = self.kwargs.get('burn_in_fraction', 0.25)
+        self.burn_in_act = self.kwargs.get('burn_in_act', 3)
         a = self.kwargs.get('a', 2)
         emcee = self.external_sampler
         tqdm = utils.get_progress_bar(self.kwargs.pop('tqdm', 'tqdm'))
@@ -806,18 +808,14 @@ class Emcee(Sampler):
             pass
 
         self.result.sampler_output = np.nan
+        self.calculate_autocorrelation(sampler)
+        self.setup_nburn()
+        self.result.nburn = self.nburn
         self.result.samples = sampler.chain[:, self.nburn:, :].reshape(
             (-1, self.ndim))
         self.result.walkers = sampler.chain[:, :, :]
-        self.result.nburn = self.nburn
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
-
-        try:
-            logger.info("Max autocorr time = {}".format(
-                np.max(sampler.get_autocorr_time())))
-        except emcee.autocorr.AutocorrError as e:
-            logger.info("Unable to calculate autocorr time: {}".format(e))
         return self.result
 
     def lnpostfn(self, theta):
@@ -827,6 +825,79 @@ class Emcee(Sampler):
         else:
             return self.log_likelihood(theta) + p
 
+    def setup_nburn(self):
+        """ Set `self.nburn`, the number of steps discarded as burn-in
+
+        The value is taken, in order of preference, from:
+
+        1. a user-supplied numeric `nburn` (truncated to an int),
+        2. `burn_in_fraction * nsteps`, when no autocorrelation time
+           estimate is available,
+        3. `burn_in_act * max_autocorrelation_time` otherwise.
+
+        Unless `nburn` was supplied, `calculate_autocorrelation` must
+        have been called first, as it sets
+        `self.result.max_autocorrelation_time`.
+        """
+        # numpy scalars count as user-supplied values too; bool is excluded
+        # as it subclasses int but is not a meaningful step count
+        user_supplied = (
+            isinstance(self.nburn, (int, float, np.integer, np.floating))
+            and not isinstance(self.nburn, bool))
+        if user_supplied:
+            self.nburn = int(self.nburn)
+            logger.info("Discarding {} steps for burn-in".format(self.nburn))
+        elif self.result.max_autocorrelation_time is None:
+            self.nburn = int(self.burn_in_fraction * self.nsteps)
+            logger.info("Autocorrelation time not calculated, discarding {} "
+                        "steps for burn-in".format(self.nburn))
+        else:
+            self.nburn = int(
+                self.burn_in_act * self.result.max_autocorrelation_time)
+            logger.info("Discarding {} steps for burn-in, estimated from "
+                        "autocorr".format(self.nburn))
+
+        if self.nburn >= self.nsteps:
+            # Slicing the chain from nburn onwards would then be empty
+            logger.warning(
+                "nburn = {} is not smaller than nsteps = {}: no samples will "
+                "remain after burn-in".format(self.nburn, self.nsteps))
+
+    def calculate_autocorrelation(self, sampler, c=3):
+        """ Uses the `emcee.autocorr` module to estimate the autocorrelation
+
+        The largest value returned by `sampler.get_autocorr_time`, truncated
+        to an integer, is stored in `self.result.max_autocorrelation_time`;
+        it is set to `None` if no usable estimate could be obtained.
+
+        Parameters
+        ----------
+        sampler: object
+            The sampler after running; must provide `get_autocorr_time`.
+        c: float
+            The minimum number of autocorrelation times needed to trust the
+            estimate (default: `3`). See `emcee.autocorr.integrated_time`.
+        """
+
+        import emcee
+        try:
+            max_act = np.max(sampler.get_autocorr_time(c=c))
+        except emcee.autocorr.AutocorrError as e:
+            self.result.max_autocorrelation_time = None
+            logger.info("Unable to calculate autocorr time: {}".format(e))
+            return
+
+        if not np.isfinite(max_act):
+            # int() cannot convert nan/inf, so treat it as "no estimate"
+            self.result.max_autocorrelation_time = None
+            logger.info("Unable to calculate autocorr time: estimate is "
+                        "{}".format(max_act))
+            return
+
+        self.result.max_autocorrelation_time = int(max_act)
+        logger.info("Max autocorr time = {}".format(
+            self.result.max_autocorrelation_time))
+
 
 class Ptemcee(Emcee):
     """ https://github.com/willvousden/ptemcee """