From 76482db77bbec4217869cddd0b918dd74e65587d Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Mon, 7 Mar 2022 22:57:37 +0000
Subject: [PATCH] Allow ptemcee to be initialized from a dictionary

---
 bilby/core/sampler/ptemcee.py     | 30 ++++++++---
 test/core/sampler/ptemcee_test.py | 90 ++++++++++++++++++++-----------
 2 files changed, 83 insertions(+), 37 deletions(-)

diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 4e60496fe..69b6027c0 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -64,7 +64,7 @@ class Ptemcee(MCMCSampler):
     autocorr_c: int, (5)
         The step size for the window search used by emcee.autocorr.integrated_time
     safety: int, (1)
-        A multiplicitive factor for the estimated autocorrelation. Useful for
+        A multiplicative factor for the estimated autocorrelation. Useful for
         cases where non-convergence can be observed by eye but the automated
         tools are failing.
     autocorr_tau: int, (1)
@@ -92,14 +92,18 @@ class Ptemcee(MCMCSampler):
         is not recommended for cases where tau is large.
     ignore_keys_for_tau: str
         A pattern used to ignore keys in estimating the autocorrelation time.
-    pos0: str, list, np.ndarray
+    pos0: str, list, np.ndarray, dict
         If a string, one of "prior" or "minimize". For "prior", the initial
         positions of the sampler are drawn from the sampler. If "minimize",
         a scipy.optimize step is applied to all parameters a number of times.
         The walkers are then initialized from the range of values obtained.
         If a list, for the keys in the list the optimization step is applied,
-        otherwise the initial points are drawn from the prior. If a numpy array
-        the shape should be (ntemps, nwalkers, ndim).
+        otherwise the initial points are drawn from the prior.
+        If a :code:`numpy` array the shape should be
+        :code:`(ntemps, nwalkers, ndim)`.
+        If a :code:`dict`, this should be a dictionary with keys matching the
+        :code:`search_parameter_keys`. Each entry should be an array with
+        shape :code:`(ntemps, nwalkers)`.
 
     niterations_per_check: int (5)
         The number of iteration steps to take before checking ACT. This
@@ -282,13 +286,13 @@ class Ptemcee(MCMCSampler):
 
         """
         logger.info("Generating pos0 samples")
-        return [
+        return np.array([
             [
                 self.get_random_draw_from_prior()
                 for _ in range(self.nwalkers)
             ]
             for _ in range(self.kwargs["ntemps"])
-        ]
+        ])
 
     def get_pos0_from_minimize(self, minimize_list=None):
         """ Draw the initial positions using an initial minimization step
@@ -376,6 +380,18 @@ class Ptemcee(MCMCSampler):
         else:
             return self.pos0
 
+    def get_pos0_from_dict(self):
+        """
+        Initialize the starting points from a passed dictionary.
+
+        The :code:`pos0` passed to the :code:`Sampler` should be a dictionary
+        with keys matching the :code:`search_parameter_keys`.
+        Each entry should have shape :code:`(ntemps, nwalkers)`.
+        """
+        pos0 = np.array([self.pos0[key] for key in self.search_parameter_keys])
+        self.pos0 = np.moveaxis(pos0, 0, -1)
+        return self.get_pos0_from_array()
+
     def setup_sampler(self):
         """ Either initialize the sampler or read in the resume file """
         import ptemcee
@@ -461,6 +477,8 @@ class Ptemcee(MCMCSampler):
             return self.get_pos0_from_minimize(minimize_list=self.pos0)
         elif isinstance(self.pos0, np.ndarray):
             return self.get_pos0_from_array()
+        elif isinstance(self.pos0, dict):
+            return self.get_pos0_from_dict()
         else:
             raise SamplerError("pos0={} not implemented".format(self.pos0))
 
diff --git a/test/core/sampler/ptemcee_test.py b/test/core/sampler/ptemcee_test.py
index 2bc4d6580..65c49c4e2 100644
--- a/test/core/sampler/ptemcee_test.py
+++ b/test/core/sampler/ptemcee_test.py
@@ -1,32 +1,31 @@
 import unittest
-from unittest.mock import MagicMock
 
-import bilby
+from bilby.core.likelihood import GaussianLikelihood
+from bilby.core.prior import Uniform, PriorDict
+from bilby.core.sampler import Ptemcee
+from bilby.core.sampler.base_sampler import MCMCSampler
+import numpy as np
 
 
 class TestPTEmcee(unittest.TestCase):
     def setUp(self):
-        self.likelihood = MagicMock()
-        self.priors = bilby.core.prior.PriorDict(
-            dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
+        self.likelihood = GaussianLikelihood(
+            x=np.linspace(0, 1, 2),
+            y=np.linspace(0, 1, 2),
+            func=lambda x, **kwargs: x,
+            sigma=1,
         )
-        self.sampler = bilby.core.sampler.Ptemcee(
-            self.likelihood,
-            self.priors,
+        self.priors = PriorDict(dict(a=Uniform(0, 1), b=Uniform(0, 1)))
+        self._args = (self.likelihood, self.priors)
+        self._kwargs = dict(
             outdir="outdir",
             label="label",
             use_ratio=False,
             plot=False,
             skip_import_verification=True,
         )
-
-    def tearDown(self):
-        del self.likelihood
-        del self.priors
-        del self.sampler
-
-    def test_default_kwargs(self):
-        expected = dict(
+        self.sampler = Ptemcee(*self._args, **self._kwargs)
+        self.expected = dict(
             ntemps=10,
             nwalkers=100,
             Tmax=None,
@@ -38,27 +37,56 @@ class TestPTEmcee(unittest.TestCase):
             adapt=False,
             swap_ratios=False,
         )
-        self.assertDictEqual(expected, self.sampler.kwargs)
+
+    def tearDown(self):
+        del self.likelihood
+        del self.priors
+        del self.sampler
+
+    def test_default_kwargs(self):
+        self.assertDictEqual(self.expected, self.sampler.kwargs)
 
     def test_translate_kwargs(self):
-        expected = dict(
-            ntemps=10,
-            nwalkers=100,
-            Tmax=None,
-            betas=None,
-            a=2.0,
-            adaptation_lag=10000,
-            adaptation_time=100,
-            random=None,
-            adapt=False,
-            swap_ratios=False,
-        )
-        for equiv in bilby.core.sampler.base_sampler.MCMCSampler.nwalkers_equiv_kwargs:
+        for equiv in MCMCSampler.nwalkers_equiv_kwargs:
             new_kwargs = self.sampler.kwargs.copy()
             del new_kwargs["nwalkers"]
             new_kwargs[equiv] = 100
             self.sampler.kwargs = new_kwargs
-            self.assertDictEqual(expected, self.sampler.kwargs)
+            self.assertDictEqual(self.expected, self.sampler.kwargs)
+
+    def test_set_pos0_using_array(self):
+        """
+        Verify that setting the initial points from an array matches the
+        default method.
+        """
+        pos0 = self.sampler.get_pos0()
+        new_sampler = Ptemcee(*self._args, **self._kwargs, pos0=pos0)
+        self.assertTrue(np.array_equal(new_sampler.get_pos0(), pos0))
+
+    def test_set_pos0_using_dict(self):
+        """
+        Verify that setting the initial points from a dictionary matches the
+        default method.
+        """
+        old = np.array(self.sampler.get_pos0())
+        pos0 = np.moveaxis(old, -1, 0)
+        pos0 = {
+            key: points for key, points in
+            zip(self.sampler.search_parameter_keys, pos0)
+        }
+        new_sampler = Ptemcee(*self._args, **self._kwargs, pos0=pos0)
+        new = new_sampler.get_pos0()
+        self.assertTrue(np.array_equal(new, old))
+
+    def test_set_pos0_from_minimize(self):
+        """
+        Verify that the minimize method of setting the initial points
+        returns the same shape as the default.
+        """
+        old = self.sampler.get_pos0().shape
+        new_sampler = Ptemcee(*self._args, **self._kwargs, pos0="minimize")
+        new = new_sampler.get_pos0().shape
+        self.assertEqual(old, new)
 
 
 if __name__ == "__main__":
-- 
GitLab