diff --git a/tupak/prior.py b/tupak/prior.py index fec2d957940462270e13fbd366ed41f4fab4ace0..69338bacf0ea78be929e65c885656b8bb77cbdcb 100644 --- a/tupak/prior.py +++ b/tupak/prior.py @@ -17,11 +17,11 @@ class Prior(object): self.latex_label = latex_label def __call__(self): - return self.sample(1) + return self.sample() - def sample(self, n_samples=None): - """Draw a sample from the prior, this rescales a unit line element according to the rescaling function""" - return self.rescale(np.random.uniform(0, 1, n_samples)) + def sample(self, size=None): + """Draw a sample from the prior """ + return self.rescale(np.random.uniform(0, 1, size)) def rescale(self, val): """ @@ -34,7 +34,9 @@ class Prior(object): @staticmethod def test_valid_for_rescaling(val): """Test if 0 < val < 1""" - if (val < 0) or (val > 1): + val = np.atleast_1d(val) + tests = (val < 0) + (val > 1) + if np.any(tests): raise ValueError("Number to be rescaled should be in [0, 1]") def __repr__(self):