From 145d3989afd1d2933773262434d7337927fb1af4 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Wed, 4 Jul 2018 16:37:24 +1000
Subject: [PATCH] Adds in the ability to pass through pos0 of the walkers

---
 tupak/core/sampler.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py
index 22e0797e..89d542df 100644
--- a/tupak/core/sampler.py
+++ b/tupak/core/sampler.py
@@ -348,8 +348,15 @@ class Sampler(object):
     def _log_summary_for_sampler(self):
         """Print a summary of the sampler used and its kwargs"""
         if self.cached_result is None:
+            kwargs_print = self.kwargs.copy()
+            for k in kwargs_print:
+                if type(kwargs_print[k]) in (list, np.ndarray):
+                    array_repr = np.array(kwargs_print[k])
+                    if array_repr.shape > 10:
+                        kwargs_print[k] = ('array_like, shape={}'
+                                           .format(array_repr.shape))
             logging.info("Using sampler {} with kwargs {}".format(
-                self.__class__.__name__, self.kwargs))
+                self.__class__.__name__, kwargs_print))
 
 
 class Nestle(Sampler):
@@ -764,7 +771,17 @@ class Emcee(Sampler):
         sampler = emcee.EnsembleSampler(
             nwalkers=self.nwalkers, dim=self.ndim, lnpostfn=self.lnpostfn,
             a=a)
-        pos0 = [self.get_random_draw_from_prior() for i in range(self.nwalkers)]
+
+        if 'pos0' in self.kwargs:
+            logging.debug("Using given initial positions for walkers")
+            pos0 = np.squeeze(self.kwargs['pos0'])
+            if pos0.shape != (self.ndim, self.nwalkers):
+                raise ValueError(
+                    'Input pos0 should be of shape ndim, nwalkers')
+        else:
+            logging.debug("Generating initial walker positions from prior")
+            pos0 = [self.get_random_draw_from_prior()
+                    for i in range(self.nwalkers)]
 
         for result in tqdm(
                 sampler.sample(pos0, iterations=self.nsteps), total=self.nsteps):
-- 
GitLab