From 200846a84fbad647659bb8a667cc54d34d496340 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Sun, 13 May 2018 19:18:20 +1000 Subject: [PATCH] Allow multiple prior draws - Rename n_samples to size to fit with standard numpy conventions - Fix check to handle multiple values --- tupak/prior.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tupak/prior.py b/tupak/prior.py index fec2d9579..69338bacf 100644 --- a/tupak/prior.py +++ b/tupak/prior.py @@ -17,11 +17,11 @@ class Prior(object): self.latex_label = latex_label def __call__(self): - return self.sample(1) + return self.sample() - def sample(self, n_samples=None): - """Draw a sample from the prior, this rescales a unit line element according to the rescaling function""" - return self.rescale(np.random.uniform(0, 1, n_samples)) + def sample(self, size=None): + """Draw a sample from the prior """ + return self.rescale(np.random.uniform(0, 1, size)) def rescale(self, val): """ @@ -34,7 +34,9 @@ class Prior(object): @staticmethod def test_valid_for_rescaling(val): """Test if 0 < val < 1""" - if (val < 0) or (val > 1): + val = np.atleast_1d(val) + tests = (val < 0) + (val > 1) + if np.any(tests): raise ValueError("Number to be rescaled should be in [0, 1]") def __repr__(self): -- GitLab