diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 3de39b4422aacbb7cab714cb3d97dc0372b3e83a..3ffcfb193df6f34ed07deb9689def2adf444bb60 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -313,6 +313,21 @@ class UniformInComponentsChirpMass(PowerLaw): name=name, latex_label=latex_label, unit=unit, boundary=boundary) +class WrappedInterp1d(interp1d): + """ A wrapper around scipy interp1d which sets equality-by-instantiation """ + def __eq__(self, other): + + for key in self.__dict__: + if type(self.__dict__[key]) is np.ndarray: + if not np.array_equal(self.__dict__[key], other.__dict__[key]): + return False + elif key == "_spline": + pass + elif getattr(self, key) != getattr(other, key): + return False + return True + + class UniformInComponentsMassRatio(Prior): def __init__(self, minimum, maximum, name='mass_ratio', latex_label='$q$', @@ -339,8 +354,9 @@ class UniformInComponentsMassRatio(Prior): latex_label=latex_label, unit=unit, boundary=boundary) self.norm = self._integral(maximum) - self._integral(minimum) qs = np.linspace(minimum, maximum, 1000) - self.icdf = interp1d(self.cdf(qs), qs, kind='cubic', - bounds_error=False, fill_value=(minimum, maximum)) + self.icdf = WrappedInterp1d( + self.cdf(qs), qs, kind='cubic', + bounds_error=False, fill_value=(minimum, maximum)) @staticmethod def _integral(q):