From 214b16a99e959d33a346807d6f1fa079df8075b2 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Tue, 24 Mar 2020 21:08:32 -0700
Subject: [PATCH] Add flag for storing walkers and exit signal

---
 bilby/core/sampler/ptemcee.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 15edd1edc..f314563e5 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -49,7 +49,7 @@ class Ptemcee(MCMCSampler):
                  resume=True, nsamples=5000, burn_in_nact=50, thin_by_nact=1,
                  autocorr_c=5, safety=1, frac_threshold=0.01,
                  autocorr_tol=50, min_tau=1, check_point_deltaT=600, 
-                 threads=1, **kwargs):
+                 threads=1, exit_code=77, store_walkers=False, **kwargs):
         super(Ptemcee, self).__init__(
             likelihood=likelihood, priors=priors, outdir=outdir,
             label=label, use_ratio=use_ratio, plot=plot,
@@ -71,8 +71,10 @@ class Ptemcee(MCMCSampler):
         self.check_point_deltaT = check_point_deltaT
 
         self.threads = threads
+        self.store_walkers = store_walkers
 
         self.resume_file = "{}/{}_checkpoint_resume.pickle".format(self.outdir, self.label)
+        self.exit_code = exit_code
 
     @property
     def sampler_function_kwargs(self):
@@ -256,7 +258,6 @@ class Ptemcee(MCMCSampler):
 
         # Get 0-likelihood samples and store in the result
         samples = sampler.chain[0, :, :, :]  # nwalkers, nsteps, ndim
-        self.result.walkers = samples[:, :sampler.time:, :]
         self.result.samples = (
             samples[:, self.nburn: sampler.time:self.thin, :].reshape((-1, self.ndim)))
         loglikelihood = sampler.loglikelihood[
@@ -264,7 +265,8 @@ class Ptemcee(MCMCSampler):
         ]  # nwalkers, nsteps
         self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
 
-        self.result.walkers = self.sampler.chain
+        if self.store_walkers:
+            self.result.walkers = self.sampler.chain
         self.result.nburn = self.nburn
 
         log_evidence, log_evidence_err = compute_evidence(
@@ -283,7 +285,7 @@ class Ptemcee(MCMCSampler):
             self.write_current_state(plot=False)
             logger.warning("Closing pool")
             self.pool.close()
-        sys.exit(77)
+        sys.exit(self.exit_code)
 
     def write_current_state(self, plot=True):
         checkpoint(self.outdir, self.label, self.nsamples_effective,
-- 
GitLab