From 200846a84fbad647659bb8a667cc54d34d496340 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Sun, 13 May 2018 19:18:20 +1000
Subject: [PATCH] Allow multiple prior draws

- Rename n_samples to size to fit with standard numpy conventions
- Fix check to handle multiple values
---
 tupak/prior.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tupak/prior.py b/tupak/prior.py
index fec2d9579..69338bacf 100644
--- a/tupak/prior.py
+++ b/tupak/prior.py
@@ -17,11 +17,11 @@ class Prior(object):
         self.latex_label = latex_label
 
     def __call__(self):
-        return self.sample(1)
+        return self.sample()
 
-    def sample(self, n_samples=None):
-        """Draw a sample from the prior, this rescales a unit line element according to the rescaling function"""
-        return self.rescale(np.random.uniform(0, 1, n_samples))
+    def sample(self, size=None):
+        """Draw a sample from the prior """
+        return self.rescale(np.random.uniform(0, 1, size))
 
     def rescale(self, val):
         """
@@ -34,7 +34,9 @@ class Prior(object):
     @staticmethod
     def test_valid_for_rescaling(val):
         """Test if 0 < val < 1"""
-        if (val < 0) or (val > 1):
+        val = np.atleast_1d(val)
+        tests = (val < 0) + (val > 1)
+        if np.any(tests):
             raise ValueError("Number to be rescaled should be in [0, 1]")
 
     def __repr__(self):
-- 
GitLab