From 57a2373714090ecac3ea9ec1e38dd7b0e112fb9c Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Tue, 2 Apr 2019 16:03:54 +1100
Subject: [PATCH] Initial work on adding checkpointing to ptemcee

---
 bilby/core/sampler/emcee.py   |  72 ++++++++++++++++----
 bilby/core/sampler/ptemcee.py | 124 ++++++++++++++++++++++++++++------
 2 files changed, 161 insertions(+), 35 deletions(-)

diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 3117a6c7a..6c1e16972 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -65,7 +65,6 @@ class Emcee(MCMCSampler):
         self.nburn = nburn
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
-        self._old_chain = None
 
     def _translate_kwargs(self, kwargs):
         if 'nwalkers' not in kwargs:
@@ -173,10 +172,7 @@ class Emcee(MCMCSampler):
         d["_Sampler__kwargs"]["pool"] = None
         return d
 
-    def run_sampler(self):
-        import emcee
-        tqdm = get_progress_bar()
-        sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
+    def set_up_checkpoint(self):
         out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
         out_file = os.path.join(out_dir, 'chain.dat')
 
@@ -188,13 +184,26 @@ class Emcee(MCMCSampler):
         check_directory_exists_and_if_not_mkdir(out_dir)
         if not os.path.isfile(out_file):
             with open(out_file, "w") as ff:
-                ff.write('walker\t{}\tlog_l'.format(
+                ff.write('walker\t{}\tlog_l\n'.format(
                     '\t'.join(self.search_parameter_keys)))
         template =\
             '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
 
-        for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
-                           total=self.nsteps):
+        return out_file, template
+
+    def run_sampler(self):
+        import emcee
+        tqdm = get_progress_bar()
+        sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
+        out_file, template = self.set_up_checkpoint()
+
+        sampler_function_kwargs = self.sampler_function_kwargs
+        iterations = sampler_function_kwargs.pop('iterations')
+        iterations -= self._previous_iterations
+
+        for sample in tqdm(
+                sampler.sample(iterations=iterations, **sampler_function_kwargs),
+                total=iterations):
             if self.prerelease:
                 points = np.hstack([sample.coords, sample.blobs])
             else:
@@ -232,6 +241,9 @@ class Emcee(MCMCSampler):
         self.result.log_evidence_err = np.nan
         return self.result
 
+    def _draw_pos0_from_prior(self):  # one prior draw per walker
+        return [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
+
     def _set_pos0(self):
         if self.pos0 is not None:
             logger.debug("Using given initial positions for walkers")
@@ -248,19 +260,49 @@ class Emcee(MCMCSampler):
                 self.check_draw(draw)
         else:
             logger.debug("Generating initial walker positions from prior")
-            self.pos0 = [self.get_random_draw_from_prior()
-                         for _ in range(self.nwalkers)]
+            self.pos0 = self._draw_pos0_from_prior()
+
+    @property
+    def _old_chain(self):
+        """Old chain cut to a whole number of nwalkers-row groups, or None."""
+        try:
+            chain = self.__old_chain
+            n_rows = chain.shape[0]
+            return chain[:n_rows - np.mod(n_rows, self.nwalkers), :]
+        except AttributeError:
+            return None
+
+    @_old_chain.setter
+    def _old_chain(self, old_chain):
+        self.__old_chain = old_chain  # kept untrimmed; the getter trims on read
+
+    @property
+    def _previous_iterations(self):
+        """Number of whole iterations held in the old chain (0 if none)."""
+        chain = self._old_chain
+        try:
+            return 0 if chain is None else chain.shape[0] // self.nwalkers
+        except AttributeError:
+            logger.warning(
+                "Unable to calculate previous iterations from checkpoint,"
+                " defaulting to zero")
+            return 0
 
     def load_old_chain(self, file_name=None):
         if file_name is None:
             out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
             file_name = os.path.join(out_dir, 'chain.dat')
         if os.path.isfile(file_name):
-            old_chain = np.genfromtxt(file_name, skip_header=1)
-            self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
-                         for ii in range(self.nwalkers)]
-            self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
-            logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
+            try:  # an unreadable chain file should not abort the run
+                old_chain = np.genfromtxt(file_name, skip_header=1)  # skip the header row
+                self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])  # last nwalkers rows
+                             for ii in range(self.nwalkers)]
+                self._old_chain = old_chain[:-self.nwalkers + 1, 1:]  # drop the walker-index column
+                logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
+            except Exception:  # deliberately broad: any failure here means "start over"
+                logger.warning('Failed to resume. Corrupt checkpoint file {}.'
+                               .format(file_name))
+                self._set_pos0()
         else:
             logger.warning('Failed to resume. {} not found.'.format(file_name))
             self._set_pos0()
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index e0c2401ca..7f02f45e1 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -1,8 +1,12 @@
 from __future__ import absolute_import, division, print_function
 
+import os
+from collections import namedtuple
+
 import numpy as np
 
-from ..utils import get_progress_bar
+from ..utils import (
+    logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
 from . import Emcee
 from .base_sampler import SamplerError
 
@@ -36,13 +40,14 @@ class Ptemcee(Emcee):
 
     def __init__(self, likelihood, priors, outdir='outdir', label='label',
                  use_ratio=False, plot=False, skip_import_verification=False,
-                 nburn=None, burn_in_fraction=0.25, burn_in_act=3, **kwargs):
+                 nburn=None, burn_in_fraction=0.25, burn_in_act=3, resume=True,
+                 **kwargs):
         Emcee.__init__(
             self, likelihood=likelihood, priors=priors, outdir=outdir,
             label=label, use_ratio=use_ratio, plot=plot,
             skip_import_verification=skip_import_verification,
             nburn=nburn, burn_in_fraction=burn_in_fraction,
-            burn_in_act=burn_in_act, **kwargs)
+            burn_in_act=burn_in_act, resume=resume, **kwargs)  # pass the caller's flag on
 
     @property
     def sampler_function_kwargs(self):
@@ -55,23 +60,102 @@ class Ptemcee(Emcee):
                 for key, value in self.kwargs.items()
                 if key not in self.sampler_function_kwargs}
 
+    @property
+    def checkpoint_info(self):
+        """Checkpoint file paths and row template; writes the chain header once."""
+        keys = self.search_parameter_keys
+        directory = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label))
+        chain_path = os.path.join(directory, 'chain.dat')
+
+        check_directory_exists_and_if_not_mkdir(directory)
+        if not os.path.isfile(chain_path):
+            # New run: start the file with its tab-separated column names
+            with open(chain_path, "w") as ff:
+                header = 'walker\t{}\tlog_l\tlog_p\n'.format('\t'.join(keys))
+                ff.write(header)
+
+        # Row layout: integer walker index, then parameters, log_l and log_p
+        row_template = '{:d}' + '\t{:.9e}' * (len(keys) + 2) + '\n'
+        CheckpointInfo = namedtuple(
+            'CheckpointInfo', ['last_pos_file', 'chain_file', 'template'])
+        return CheckpointInfo(
+            last_pos_file=os.path.join(directory, 'last_pos.npy'),
+            chain_file=chain_path, template=row_template)
+
+    def _draw_pos0_from_prior(self):
+        """Draw nwalkers prior samples for each of the ntemps temperatures."""
+        draw = self.get_random_draw_from_prior
+        return [[draw() for _ in range(self.nwalkers)] for _ in range(self.kwargs['ntemps'])]
+
+    @property
+    def _old_chain(self):
+        """Old chain cut to a whole number of nwalkers-row groups, or None."""
+        try:
+            chain = self.__old_chain
+            n_rows = chain.shape[0]
+            return chain[:n_rows - np.mod(n_rows, self.nwalkers)]
+        except AttributeError:
+            return None
+
+    @_old_chain.setter
+    def _old_chain(self, old_chain):
+        self.__old_chain = old_chain  # kept untrimmed; the getter trims on read
+
+    @property
+    def stored_chain(self):  # parsed afresh from the chain file on every access
+        return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
+
+    @property
+    def stored_samples(self):  # only the search-parameter columns of the chain
+        return self.stored_chain[self.search_parameter_keys]
+
+    @property
+    def stored_loglike(self):  # the 'log_l' column of the chain file
+        return self.stored_chain['log_l']
+
+    @property
+    def stored_logprior(self):  # the 'log_p' column of the chain file
+        return self.stored_chain['log_p']
+
+    def load_old_chain(self):
+        """Restore pos0 and the old chain from disk, else call _set_pos0()."""
+        try:
+            self.pos0 = np.load(self.checkpoint_info.last_pos_file)
+            self._old_chain = self.stored_samples
+            message = 'Resuming from {} with {} iterations'.format(
+                self.checkpoint_info.chain_file, self._previous_iterations)
+            logger.info(message)
+        except Exception:
+            # Any failure while reading the checkpoint means no resume
+            logger.info('Unable to resume')
+            self._set_pos0()
+
     def run_sampler(self):
         import ptemcee
         tqdm = get_progress_bar()
         sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood,
                                   logp=self.log_prior, **self.sampler_init_kwargs)
-        self.pos0 = [[self.get_random_draw_from_prior()
-                      for _ in range(self.nwalkers)]
-                     for _ in range(self.kwargs['ntemps'])]
 
-        log_likelihood_evaluations = []
-        log_prior_evaluations = []
+        if self.resume:
+            self.load_old_chain()
+        else:
+            self._set_pos0()
+
+        sampler_function_kwargs = self.sampler_function_kwargs
+        iterations = sampler_function_kwargs.pop('iterations')
+        iterations -= self._previous_iterations
+
         for pos, logpost, loglike in tqdm(
-                sampler.sample(self.pos0, **self.sampler_function_kwargs),
-                total=self.nsteps):
-            log_likelihood_evaluations.append(loglike)
-            log_prior_evaluations.append(logpost - loglike)
-            pass
+                sampler.sample(self.pos0, iterations=iterations,
+                               **sampler_function_kwargs),
+                total=iterations):
+            np.save(self.checkpoint_info.last_pos_file, pos)
+            with open(self.checkpoint_info.chain_file, "a") as ff:
+                loglike = np.squeeze(loglike[:1, :])
+                logprior = np.squeeze(logpost[:1, :]) - loglike
+                for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)):
+                    line = np.concatenate((point, [logl, logp]))
+                    ff.write(self.checkpoint_info.template.format(ii, *line))
 
         self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
         self.result.sampler_output = np.nan
@@ -81,16 +165,16 @@ class Ptemcee(Emcee):
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
                 "`nburn < nsteps`. Try increasing the number of steps.")
-        self.result.samples = sampler.chain[0, :, self.nburn:, :].reshape(
-            (-1, self.ndim))
-        self.result.log_likelihood_evaluations = np.array(
-            log_likelihood_evaluations)[self.nburn:, 0, :].reshape((-1))
-        self.result.log_prior_evaluations = np.array(
-            log_prior_evaluations)[self.nburn:, 0, :].reshape((-1))
+        walkers = self.stored_samples.view((float, self.ndim))
+        walkers = walkers.reshape(self.nwalkers, self.nsteps, self.ndim)
+        self.result.walkers = walkers
+        self.result.samples = walkers[:, self.nburn:, :].reshape((-1, self.ndim))
+        n_samples = self.nwalkers * self.nburn
+        self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:]
+        self.result.log_prior_evaluations = self.stored_logprior[n_samples:]
         self.result.betas = sampler.betas
         self.result.log_evidence, self.result.log_evidence_err =\
             sampler.log_evidence_estimate(
                 sampler.loglikelihood, self.nburn / self.nsteps)
-        self.result.walkers = sampler.chain[0, :, :, :]
 
         return self.result
-- 
GitLab