From 134a3414677c3ae1904c8ad5fe2306fbb3093536 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Mon, 8 Oct 2018 14:42:50 +1100
Subject: [PATCH] fix nburn and autocorr for ptemcee

---
 bilby/core/sampler/base_sampler.py |  3 ++-
 bilby/core/sampler/ptemcee.py      | 18 ++++++++++--------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index fdbc5a325..a8ddc2fb0 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 import datetime
 import numpy as np
 from pandas import DataFrame
@@ -439,7 +440,7 @@ class MCMCSampler(Sampler):
 
     def print_nburn_logging_info(self):
         """ Prints logging info as to how nburn was calculated """
-        if type(self.kwargs['nburn']) in [float, int]:
+        if type(self.nburn) in [float, int]:
             logger.info("Discarding {} steps for burn-in".format(self.nburn))
         elif self.result.max_autocorrelation_time is None:
             logger.info("Autocorrelation time not calculated, discarding {} "
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 2efcab08c..5ac414c03 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -1,4 +1,4 @@
-from __future__ import absolute_import
+from __future__ import absolute_import, division, print_function
 
 import numpy as np
 from ..utils import get_progress_bar, logger
@@ -69,16 +69,18 @@ class Ptemcee(Emcee):
                 total=self.nsteps):
             pass
 
-        self.result.nburn = self.nburn
+        self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
         self.result.sampler_output = np.nan
+        self.print_nburn_logging_info()
+        self.result.nburn = self.nburn
+        if self.result.nburn > self.nsteps:
+            logger.warning('Chain not burned in, no samples generated.')
         self.result.samples = sampler.chain[0, :, self.nburn:, :].reshape(
             (-1, self.ndim))
+        self.result.betas = sampler.betas
+        self.result.log_evidence, self.result.log_evidence_err =\
+            sampler.log_evidence_estimate(
+                sampler.loglikelihood, self.nburn / self.nsteps)
         self.result.walkers = sampler.chain[0, :, :, :]
-        self.result.log_evidence = np.nan
-        self.result.log_evidence_err = np.nan
 
-        logger.info("Max autocorr time = {}"
-                    .format(np.max(sampler.get_autocorr_time())))
-        logger.info("Tswap frac = {}"
-                    .format(sampler.tswap_acceptance_fraction))
         return self.result
-- 
GitLab