From be3a886a1666e60d288811b604ab549997d858e0 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Thu, 7 Feb 2019 20:33:26 -0600
Subject: [PATCH] basic checkpointing and resuming

---
 CHANGELOG.md                |  1 +
 bilby/core/sampler/emcee.py | 88 +++++++++++++++++++++++++++++--------
 2 files changed, 71 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f22560ff..2a08a648 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,7 @@
 ## Unreleased
 
 ### Added
+- `emcee` now writes all progress to disk and can resume from a previous run.
 - 
 
 ### Changed
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 56e6887d..b52c2bee 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -1,10 +1,13 @@
 from __future__ import absolute_import, print_function
 
+import os
+
 import numpy as np
 from pandas import DataFrame
 from distutils.version import LooseVersion
 
-from ..utils import logger, get_progress_bar
+from ..utils import (
+    logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
 from .base_sampler import MCMCSampler, SamplerError
 
 
@@ -41,19 +44,23 @@ class Emcee(MCMCSampler):
     default_kwargs = dict(nwalkers=500, a=2, args=[], kwargs={},
                           postargs=None, pool=None, live_dangerously=False,
                           runtime_sortingfn=None, lnprob0=None, rstate0=None,
-                          blobs0=None, iterations=100, thin=1, storechain=True, mh_proposal=None)
+                          blobs0=None, iterations=100, thin=1, storechain=True,
+                          mh_proposal=None)
 
-    def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False,
-                 skip_import_verification=False, pos0=None, nburn=None, burn_in_fraction=0.25,
+    def __init__(self, likelihood, priors, outdir='outdir', label='label',
+                 use_ratio=False, plot=False, skip_import_verification=False,
+                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
                  burn_in_act=3, **kwargs):
-        MCMCSampler.__init__(self, likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                             use_ratio=use_ratio, plot=plot,
-                             skip_import_verification=skip_import_verification,
-                             **kwargs)
+        MCMCSampler.__init__(
+            self, likelihood=likelihood, priors=priors, outdir=outdir,
+            label=label, use_ratio=use_ratio, plot=plot,
+            skip_import_verification=skip_import_verification, **kwargs)
+        self.resume = resume
         self.pos0 = pos0
         self.nburn = nburn
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
+        self._old_chain = None
 
     def _translate_kwargs(self, kwargs):
         if 'nwalkers' not in kwargs:
@@ -168,23 +175,54 @@ class Emcee(MCMCSampler):
         import emcee
         tqdm = get_progress_bar()
         sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
-        self._set_pos0()
-        for _ in tqdm(sampler.sample(**self.sampler_function_kwargs),
-                      total=self.nsteps):
-            pass
+        out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+        out_file = os.path.join(out_dir, 'chain.dat')
+
+        if self.resume:
+            self.load_old_chain(out_file)
+        else:
+            self._set_pos0()
+
+        check_directory_exists_and_if_not_mkdir(out_dir)
+        if not os.path.isfile(out_file):
+            with open(out_file, "w") as ff:
+                ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
+                    '\t'.join(self.search_parameter_keys)))
+        template =\
+            '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
+
+        for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
+                           total=self.nsteps):
+            points = np.hstack([sample[0], np.array(sample[3])])
+            # append each walker's position and (log_l, log_p) for this step
+            with open(out_file, "a") as ff:
+                for ii, point in enumerate(points):
+                    ff.write(template.format(ii, *point))
+
         self.result.sampler_output = np.nan
-        self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
+        blobs_flat = np.array(sampler.blobs).reshape((-1, 2))
+        log_likelihoods, log_priors = blobs_flat.T
+        new_chain = sampler.chain.transpose((1, 0, 2)).reshape((-1, self.ndim))
+        if self._old_chain is not None:
+            chain = np.vstack([self._old_chain[:, :-2], new_chain])
+            log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods])
+            log_ps = np.hstack([self._old_chain[:, -1], log_priors])
+            self.nsteps = chain.shape[0] // self.nwalkers
+        else:
+            chain = new_chain
+            log_ls = log_likelihoods
+            log_ps = log_priors
+        self.calculate_autocorrelation(chain)
         self.print_nburn_logging_info()
         self.result.nburn = self.nburn
+        n_samples = self.nwalkers * self.nburn
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
                 "`nburn < nsteps`. Try increasing the number of steps.")
-        self.result.samples = sampler.chain[:, self.nburn:, :].reshape((-1, self.ndim))
-        blobs_flat = np.array(sampler.blobs)[self.nburn:, :, :].reshape((-1, 2))
-        log_likelihoods, log_priors = blobs_flat.T
-        self.result.log_likelihood_evaluations = log_likelihoods
-        self.result.log_prior_evaluations = log_priors
+        self.result.samples = chain[n_samples:, :]
+        self.result.log_likelihood_evaluations = log_ls[n_samples:]
+        self.result.log_prior_evaluations = log_ps[n_samples:]
         self.result.walkers = sampler.chain
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
@@ -209,6 +247,20 @@ class Emcee(MCMCSampler):
             self.pos0 = [self.get_random_draw_from_prior()
                          for _ in range(self.nwalkers)]
 
+    def load_old_chain(self, file_name=None):
+        if file_name is None:
+            out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+            file_name = os.path.join(out_dir, 'chain.dat')
+        if os.path.isfile(file_name):
+            old_chain = np.genfromtxt(file_name, skip_header=1)
+            self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
+                         for ii in range(self.nwalkers)]
+            self._old_chain = old_chain[:, 1:]
+            logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
+        else:
+            logger.warning('Failed to resume. {} not found.'.format(file_name))
+            self._set_pos0()
+
     def lnpostfn(self, theta):
         log_prior = self.log_prior(theta)
         if np.isinf(log_prior):
-- 
GitLab