diff --git a/tupak/prior.py b/tupak/prior.py index 69338bacf0ea78be929e65c885656b8bb77cbdcb..8eb15c937131cfdb5d69ed1d47ebf6ac12097169 100644 --- a/tupak/prior.py +++ b/tupak/prior.py @@ -10,11 +10,41 @@ import os class Prior(object): - """Prior class""" + """ + Prior class + + Methods + ------- + __init__: + Instantiate a prior object. + __call__: + Draw a single sample from the prior. + __repr__: + Print prior type and parameters. + sample(size=None): + Draw samples of size size from the prior. + rescale(val): + Rescale samples from a uniform distribution on [0, 1] to samples from the prior. + test_valid_for_recaling(val): + Test whether val is in [0, 1] and hence valid for rescaling. - def __init__(self, name=None, latex_label=None): + Parameters + ---------- + name: str + Name associated with prior. + latex_label: str + Latex label associated with prior, used for plotting. + minimum: float, optional + Minimum of the domain, default=-np.inf + maximum: float, optional + Maximum of the domain, default=np.inf + """ + + def __init__(self, name=None, latex_label=None, minimum=-np.inf, maximum=np.inf): self.name = name self.latex_label = latex_label + self.minimum = minimum + self.maximum = maximum def __call__(self): return self.sample() @@ -93,9 +123,7 @@ class Uniform(Prior): """Uniform prior""" def __init__(self, minimum, maximum, name=None, latex_label=None): - Prior.__init__(self, name, latex_label) - self.minimum = minimum - self.maximum = maximum + Prior.__init__(self, name, latex_label, minimum, maximum) self.support = maximum - minimum def rescale(self, val): @@ -104,7 +132,7 @@ class Uniform(Prior): def prob(self, val): """Return the prior probability of val""" - if (self.minimum < val) and (val < self.maximum): + if (self.minimum < val) & (val < self.maximum): return 1 / self.support else: return 0 @@ -114,7 +142,7 @@ class DeltaFunction(Prior): """Dirac delta function prior, this always returns peak.""" def __init__(self, peak, name=None, latex_label=None): - Prior.__init__(self, name, latex_label) + Prior.__init__(self, name, latex_label, minimum=peak, maximum=peak) self.peak = peak def rescale(self, val): @@ -135,10 +163,8 @@ class PowerLaw(Prior): def __init__(self, alpha, minimum, maximum, name=None, latex_label=None): """Power law with bounds and alpha, spectral index""" - Prior.__init__(self, name, latex_label) + Prior.__init__(self, name, latex_label, minimum, maximum) self.alpha = alpha - self.minimum = minimum - self.maximum = maximum def rescale(self, val): """ @@ -167,8 +193,8 @@ class PowerLaw(Prior): class Cosine(Prior): - def __init__(self, name=None, latex_label=None): - Prior.__init__(self, name, latex_label) + def __init__(self, name=None, latex_label=None, minimum=-np.pi / 2, maximum=np.pi / 2): + Prior.__init__(self, name, latex_label, minimum, maximum) def rescale(self, val): """ @@ -182,7 +208,7 @@ class Cosine(Prior): @staticmethod def prob(val): """Return the prior probability of val, defined over [-pi/2, pi/2]""" - if (val > -np.pi / 2) and (val < np.pi / 2): + if (val > np.minimum) and (val < np.maximum): return np.cos(val) / 2 else: return 0 @@ -190,8 +216,8 @@ class Cosine(Prior): class Sine(Prior): - def __init__(self, name=None, latex_label=None): - Prior.__init__(self, name, latex_label) + def __init__(self, name=None, latex_label=None, minimum=0, maximum=np.pi): + Prior.__init__(self, name, latex_label, minimum, maximum) def rescale(self, val): """