From 0857737ef8d94e688f89dfa88c7cb64ff46bc0d2 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Wed, 5 Feb 2020 21:57:38 -0600
Subject: [PATCH] Clean up of the prior sampling mechanism

Introduces a sample_from_constrain_prior_array method to facilitate
drawn an ordered array of samples. Removes redudant code.
---
 bilby/core/prior/dict.py           | 20 ++++++++++++++
 bilby/core/sampler/base_sampler.py | 35 +++++++-----------------
 test/prior_test.py                 |  7 +++++
 test/sampler_test.py               | 44 ++++++++++++++++++++----------
 4 files changed, 67 insertions(+), 39 deletions(-)

diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index a88f1fd31..79b9522cf 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -306,6 +306,26 @@ class PriorDict(dict):
         """
         return self.sample_subset_constrained(keys=list(self.keys()), size=size)
 
+    def sample_subset_constrained_as_array(self, keys=iter([]), size=None):
+        """ Return an array of samples
+
+        Parameters
+        ----------
+        keys: list
+            A list of keys to sample in
+        size: int
+            The number of samples to draw
+
+        Returns
+        -------
+        array: array_like
+            An array of shape (len(key), size) of the samples (ordered by keys)
+        """
+        samples_dict = self.sample_subset_constrained(keys=keys, size=size)
+        samples_dict = {key: np.atleast_1d(val) for key, val in samples_dict.items()}
+        samples_list = [samples_dict[key] for key in keys]
+        return np.array(samples_list)
+
     def sample_subset(self, keys=iter([]), size=None):
         """Draw samples from the prior set for parameters which are not a DeltaFunction
 
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 4b5cc048f..9d72464c2 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -5,7 +5,7 @@ import numpy as np
 from pandas import DataFrame
 
 from ..utils import logger, command_line_args, Counter
-from ..prior import Prior, PriorDict, ConditionalPriorDict, DeltaFunction, Constraint
+from ..prior import Prior, PriorDict, DeltaFunction, Constraint
 from ..result import Result, read_in_result
 
 
@@ -251,19 +251,13 @@ class Sampler(object):
         AttributeError
             prior can't be sampled.
         """
-        if isinstance(self.priors, ConditionalPriorDict):
+        for key in self.priors:
+            if isinstance(self.priors[key], Constraint):
+                continue
             try:
-                self.likelihood.parameters = self.priors.sample()
+                self.priors[key].sample()
             except AttributeError as e:
-                logger.warning('Cannot sample from prior, {}'.format(e))
-        else:
-            for key in self.priors:
-                if isinstance(self.priors[key], Constraint):
-                    continue
-                try:
-                    self.likelihood.parameters[key] = self.priors[key].sample()
-                except AttributeError as e:
-                    logger.warning('Cannot sample from {}, {}'.format(key, e))
+                logger.warning('Cannot sample from {}, {}'.format(key, e))
 
     def _verify_parameters(self):
         """ Evaluate a set of parameters drawn from the prior
@@ -281,13 +275,8 @@ class Sampler(object):
             raise IllegalSamplingSetError(
                 "Your sampling set contains redundant parameters.")
 
-        self._check_if_priors_can_be_sampled()
-        if isinstance(self.priors, ConditionalPriorDict):
-            theta = self.priors.sample()
-            theta = [theta[key] for key in self._search_parameter_keys]
-        else:
-            theta = [self.priors[key].sample()
-                     for key in self._search_parameter_keys]
+        theta = self.priors.sample_subset_constrained_as_array(
+            self.search_parameter_keys, size=1)[:, 0]
         try:
             self.log_likelihood(theta)
         except TypeError as e:
@@ -308,12 +297,8 @@ class Sampler(object):
 
         t1 = datetime.datetime.now()
         for _ in range(n_evaluations):
-            if isinstance(self.priors, ConditionalPriorDict):
-                theta = self.priors.sample()
-                theta = [theta[key] for key in self._search_parameter_keys]
-            else:
-                theta = [self.priors[key].sample()
-                         for key in self._search_parameter_keys]
+            theta = self.priors.sample_subset_constrained_as_array(
+                self._search_parameter_keys, size=1)[:, 0]
             self.log_likelihood(theta)
         total_time = (datetime.datetime.now() - t1).total_seconds()
         self._log_likelihood_eval_time = total_time / n_evaluations
diff --git a/test/prior_test.py b/test/prior_test.py
index 4d5e38360..d579e31a6 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -867,6 +867,13 @@ class TestPriorDict(unittest.TestCase):
         expected = dict(length=np.array([42., 42., 42.]))
         self.assertTrue(np.array_equal(expected['length'], samples['length']))
 
+    def test_sample_subset_constrained_as_array(self):
+        size = 3
+        keys = ["mass", "speed"]
+        out = self.prior_set_from_dict.sample_subset_constrained_as_array(keys, size)
+        self.assertTrue(isinstance(out, np.ndarray))
+        self.assertTrue(out.shape == (len(keys), size))
+
     def test_sample(self):
         size = 7
         np.random.seed(42)
diff --git a/test/sampler_test.py b/test/sampler_test.py
index beb0c65bd..43d3c0ded 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -98,7 +98,9 @@ class TestCPNest(unittest.TestCase):
 
     def setUp(self):
         self.likelihood = MagicMock()
-        self.priors = dict()
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1),
+                 b=bilby.core.prior.Uniform(0, 1)))
         self.sampler = bilby.core.sampler.Cpnest(self.likelihood, self.priors,
                                                  outdir='outdir', label='label',
                                                  use_ratio=False, plot=False,
@@ -133,9 +135,9 @@ class TestDynesty(unittest.TestCase):
 
     def setUp(self):
         self.likelihood = MagicMock()
-        self.priors = bilby.core.prior.PriorDict()
-        self.priors['a'] = bilby.core.prior.Prior()
-        self.priors['b'] = bilby.core.prior.Prior()
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1),
+                 b=bilby.core.prior.Uniform(0, 1)))
         self.sampler = bilby.core.sampler.Dynesty(self.likelihood, self.priors,
                                                   outdir='outdir', label='label',
                                                   use_ratio=False, plot=False,
@@ -208,7 +210,9 @@ class TestEmcee(unittest.TestCase):
 
     def setUp(self):
         self.likelihood = MagicMock()
-        self.priors = dict()
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1),
+                 b=bilby.core.prior.Uniform(0, 1)))
         self.sampler = bilby.core.sampler.Emcee(self.likelihood, self.priors,
                                                 outdir='outdir', label='label',
                                                 use_ratio=False, plot=False,
@@ -244,11 +248,13 @@ class TestKombine(unittest.TestCase):
 
     def setUp(self):
         self.likelihood = MagicMock()
-        self.priors = dict()
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1),
+                 b=bilby.core.prior.Uniform(0, 1)))
         self.sampler = bilby.core.sampler.Kombine(self.likelihood, self.priors,
-                                                outdir='outdir', label='label',
-                                                use_ratio=False, plot=False,
-                                                skip_import_verification=True)
+                                                  outdir='outdir', label='label',
+                                                  use_ratio=False, plot=False,
+                                                  skip_import_verification=True)
 
     def tearDown(self):
         del self.likelihood
@@ -279,7 +285,9 @@ class TestNestle(unittest.TestCase):
 
     def setUp(self):
         self.likelihood = MagicMock()
-        self.priors = dict()
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1),
+                 b=bilby.core.prior.Uniform(0, 1)))
         self.sampler = bilby.core.sampler.Nestle(self.likelihood, self.priors,
                                                  outdir='outdir', label='label',
                                                  use_ratio=False, plot=False,
@@ -316,7 +324,9 @@ class TestPolyChord(unittest.TestCase):
 
     def setUp(self):
         self.likelihood = MagicMock()
-        self.priors = dict(a=bilby.prior.Uniform(0, 1))
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1),
+                 b=bilby.core.prior.Uniform(0, 1)))
         self.sampler = bilby.core.sampler.PyPolyChord(self.likelihood, self.priors,
                                                       outdir='outdir', label='polychord',
                                                       use_ratio=False, plot=False,
@@ -363,7 +373,9 @@ class TestPTEmcee(unittest.TestCase):
 
     def setUp(self):
         self.likelihood = MagicMock()
-        self.priors = dict()
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1),
+                 b=bilby.core.prior.Uniform(0, 1)))
         self.sampler = bilby.core.sampler.Ptemcee(self.likelihood, self.priors,
                                                   outdir='outdir', label='label',
                                                   use_ratio=False, plot=False,
@@ -410,7 +422,9 @@ class TestPyMC3(unittest.TestCase):
 
     def setUp(self):
         self.likelihood = MagicMock()
-        self.priors = dict()
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1),
+                 b=bilby.core.prior.Uniform(0, 1)))
         self.sampler = bilby.core.sampler.Pymc3(self.likelihood, self.priors,
                                                 outdir='outdir', label='label',
                                                 use_ratio=False, plot=False,
@@ -448,7 +462,9 @@ class TestPymultinest(unittest.TestCase):
 
     def setUp(self):
         self.likelihood = MagicMock()
-        self.priors = bilby.core.prior.PriorDict()
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1),
+                 b=bilby.core.prior.Uniform(0, 1)))
         self.priors['a'] = bilby.core.prior.Prior(boundary='periodic')
         self.priors['b'] = bilby.core.prior.Prior(boundary='reflective')
         self.sampler = bilby.core.sampler.Pymultinest(self.likelihood, self.priors,
-- 
GitLab