From 0857737ef8d94e688f89dfa88c7cb64ff46bc0d2 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Wed, 5 Feb 2020 21:57:38 -0600 Subject: [PATCH] Clean up of the prior sampling mechanism Introduces a sample_from_constrain_prior_array method to facilitate drawn an ordered array of samples. Removes redudant code. --- bilby/core/prior/dict.py | 20 ++++++++++++++ bilby/core/sampler/base_sampler.py | 35 +++++++----------------- test/prior_test.py | 7 +++++ test/sampler_test.py | 44 ++++++++++++++++++++---------- 4 files changed, 67 insertions(+), 39 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index a88f1fd31..79b9522cf 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -306,6 +306,26 @@ class PriorDict(dict): """ return self.sample_subset_constrained(keys=list(self.keys()), size=size) + def sample_subset_constrained_as_array(self, keys=iter([]), size=None): + """ Return an array of samples + + Parameters + ---------- + keys: list + A list of keys to sample in + size: int + The number of samples to draw + + Returns + ------- + array: array_like + An array of shape (len(key), size) of the samples (ordered by keys) + """ + samples_dict = self.sample_subset_constrained(keys=keys, size=size) + samples_dict = {key: np.atleast_1d(val) for key, val in samples_dict.items()} + samples_list = [samples_dict[key] for key in keys] + return np.array(samples_list) + def sample_subset(self, keys=iter([]), size=None): """Draw samples from the prior set for parameters which are not a DeltaFunction diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 4b5cc048f..9d72464c2 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -5,7 +5,7 @@ import numpy as np from pandas import DataFrame from ..utils import logger, command_line_args, Counter -from ..prior import Prior, PriorDict, ConditionalPriorDict, DeltaFunction, Constraint +from ..prior import Prior, PriorDict, DeltaFunction, Constraint from ..result import Result, read_in_result @@ -251,19 +251,13 @@ class Sampler(object): AttributeError prior can't be sampled. """ - if isinstance(self.priors, ConditionalPriorDict): + for key in self.priors: + if isinstance(self.priors[key], Constraint): + continue try: - self.likelihood.parameters = self.priors.sample() + self.priors[key].sample() except AttributeError as e: - logger.warning('Cannot sample from prior, {}'.format(e)) - else: - for key in self.priors: - if isinstance(self.priors[key], Constraint): - continue - try: - self.likelihood.parameters[key] = self.priors[key].sample() - except AttributeError as e: - logger.warning('Cannot sample from {}, {}'.format(key, e)) + logger.warning('Cannot sample from {}, {}'.format(key, e)) def _verify_parameters(self): """ Evaluate a set of parameters drawn from the prior @@ -281,13 +275,8 @@ class Sampler(object): raise IllegalSamplingSetError( "Your sampling set contains redundant parameters.") - self._check_if_priors_can_be_sampled() - if isinstance(self.priors, ConditionalPriorDict): - theta = self.priors.sample() - theta = [theta[key] for key in self._search_parameter_keys] - else: - theta = [self.priors[key].sample() - for key in self._search_parameter_keys] + theta = self.priors.sample_subset_constrained_as_array( + self.search_parameter_keys, size=1)[:, 0] try: self.log_likelihood(theta) except TypeError as e: @@ -308,12 +297,8 @@ class Sampler(object): t1 = datetime.datetime.now() for _ in range(n_evaluations): - if isinstance(self.priors, ConditionalPriorDict): - theta = self.priors.sample() - theta = [theta[key] for key in self._search_parameter_keys] - else: - theta = [self.priors[key].sample() - for key in self._search_parameter_keys] + theta = self.priors.sample_subset_constrained_as_array( + self._search_parameter_keys, size=1)[:, 0] self.log_likelihood(theta) total_time = (datetime.datetime.now() - t1).total_seconds() self._log_likelihood_eval_time = total_time / n_evaluations diff --git a/test/prior_test.py b/test/prior_test.py index 4d5e38360..d579e31a6 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -867,6 +867,13 @@ class TestPriorDict(unittest.TestCase): expected = dict(length=np.array([42., 42., 42.])) self.assertTrue(np.array_equal(expected['length'], samples['length'])) + def test_sample_subset_constrained_as_array(self): + size = 3 + keys = ["mass", "speed"] + out = self.prior_set_from_dict.sample_subset_constrained_as_array(keys, size) + self.assertTrue(isinstance(out, np.ndarray)) + self.assertTrue(out.shape == (len(keys), size)) + def test_sample(self): size = 7 np.random.seed(42) diff --git a/test/sampler_test.py b/test/sampler_test.py index beb0c65bd..43d3c0ded 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -98,7 +98,9 @@ class TestCPNest(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() - self.priors = dict() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) self.sampler = bilby.core.sampler.Cpnest(self.likelihood, self.priors, outdir='outdir', label='label', use_ratio=False, plot=False, @@ -133,9 +135,9 @@ class TestDynesty(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() - self.priors = bilby.core.prior.PriorDict() - self.priors['a'] = bilby.core.prior.Prior() - self.priors['b'] = bilby.core.prior.Prior() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) self.sampler = bilby.core.sampler.Dynesty(self.likelihood, self.priors, outdir='outdir', label='label', use_ratio=False, plot=False, @@ -208,7 +210,9 @@ class TestEmcee(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() - self.priors = dict() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) self.sampler = bilby.core.sampler.Emcee(self.likelihood, self.priors, outdir='outdir', label='label', use_ratio=False, plot=False, @@ -244,11 +248,13 @@ class TestKombine(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() - self.priors = dict() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) self.sampler = bilby.core.sampler.Kombine(self.likelihood, self.priors, - outdir='outdir', label='label', - use_ratio=False, plot=False, - skip_import_verification=True) + outdir='outdir', label='label', + use_ratio=False, plot=False, + skip_import_verification=True) def tearDown(self): del self.likelihood @@ -279,7 +285,9 @@ class TestNestle(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() - self.priors = dict() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) self.sampler = bilby.core.sampler.Nestle(self.likelihood, self.priors, outdir='outdir', label='label', use_ratio=False, plot=False, @@ -316,7 +324,9 @@ class TestPolyChord(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() - self.priors = dict(a=bilby.prior.Uniform(0, 1)) + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) self.sampler = bilby.core.sampler.PyPolyChord(self.likelihood, self.priors, outdir='outdir', label='polychord', use_ratio=False, plot=False, @@ -363,7 +373,9 @@ class TestPTEmcee(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() - self.priors = dict() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) self.sampler = bilby.core.sampler.Ptemcee(self.likelihood, self.priors, outdir='outdir', label='label', use_ratio=False, plot=False, @@ -410,7 +422,9 @@ class TestPyMC3(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() - self.priors = dict() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) self.sampler = bilby.core.sampler.Pymc3(self.likelihood, self.priors, outdir='outdir', label='label', use_ratio=False, plot=False, @@ -448,7 +462,9 @@ class TestPymultinest(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() - self.priors = bilby.core.prior.PriorDict() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) self.priors['a'] = bilby.core.prior.Prior(boundary='periodic') self.priors['b'] = bilby.core.prior.Prior(boundary='reflective') self.sampler = bilby.core.sampler.Pymultinest(self.likelihood, self.priors, -- GitLab