From 134a3414677c3ae1904c8ad5fe2306fbb3093536 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Mon, 8 Oct 2018 14:42:50 +1100 Subject: [PATCH] fix nburn and autocorr for ptemcee --- bilby/core/sampler/base_sampler.py | 3 ++- bilby/core/sampler/ptemcee.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index fdbc5a325..a8ddc2fb0 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -1,3 +1,4 @@ +from __future__ import absolute_import import datetime import numpy as np from pandas import DataFrame @@ -439,7 +440,7 @@ class MCMCSampler(Sampler): def print_nburn_logging_info(self): """ Prints logging info as to how nburn was calculated """ - if type(self.kwargs['nburn']) in [float, int]: + if type(self.nburn) in [float, int]: logger.info("Discarding {} steps for burn-in".format(self.nburn)) elif self.result.max_autocorrelation_time is None: logger.info("Autocorrelation time not calculated, discarding {} " diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 2efcab08c..5ac414c03 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -1,4 +1,4 @@ -from __future__ import absolute_import +from __future__ import absolute_import, division, print_function import numpy as np from ..utils import get_progress_bar, logger @@ -69,16 +69,18 @@ class Ptemcee(Emcee): total=self.nsteps): pass - self.result.nburn = self.nburn + self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim))) self.result.sampler_output = np.nan + self.print_nburn_logging_info() + self.result.nburn = self.nburn + if self.result.nburn > self.nsteps: + logger.warning('Chain not burned in, no samples generated.') self.result.samples = sampler.chain[0, :, self.nburn:, :].reshape( (-1, self.ndim)) + self.result.betas = sampler.betas + self.result.log_evidence, self.result.log_evidence_err =\ + sampler.log_evidence_estimate( + sampler.loglikelihood, self.nburn / self.nsteps) self.result.walkers = sampler.chain[0, :, :, :] - self.result.log_evidence = np.nan - self.result.log_evidence_err = np.nan - logger.info("Max autocorr time = {}" - .format(np.max(sampler.get_autocorr_time()))) - logger.info("Tswap frac = {}" - .format(sampler.tswap_acceptance_fraction)) return self.result -- GitLab