diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 0907e49b6addf30dc77abb8008fa5f105a88bbad..2a1cff07c7868169b89e8f84d30beea5d52f27b3 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -1,7 +1,7 @@ import os import numpy as np from scipy.interpolate import UnivariateSpline -from ..core.prior import (PriorSet, Uniform, FromFile, Prior, DeltaFunction, +from ..core.prior import (PriorDict, Uniform, FromFile, Prior, DeltaFunction, Gaussian, Interped) from ..core.utils import logger @@ -65,7 +65,7 @@ class AlignedSpin(Interped): latex_label=latex_label, unit=unit) -class BBHPriorSet(PriorSet): +class BBHPriorDict(PriorDict): def __init__(self, dictionary=None, filename=None): """ Initialises a Prior set for Binary Black holes @@ -82,7 +82,7 @@ class BBHPriorSet(PriorSet): elif filename is not None: if not os.path.isfile(filename): filename = os.path.join(os.path.dirname(__file__), 'prior_files', filename) - PriorSet.__init__(self, dictionary=dictionary, filename=filename) + PriorDict.__init__(self, dictionary=dictionary, filename=filename) def test_redundancy(self, key): """ @@ -135,7 +135,15 @@ class BBHPriorSet(PriorSet): return redundant -class BNSPriorSet(PriorSet): +class BBHPriorSet(BBHPriorDict): + + def __init__(self, dictionary=None, filename=None): + """ DEPRECATED: USE BBHPriorDict INSTEAD""" + logger.warning("The name 'BBHPriorSet' is deprecated use 'BBHPriorDict' instead") + super(BBHPriorSet, self).__init__(dictionary, filename) + + +class BNSPriorDict(PriorDict): def __init__(self, dictionary=None, filename=None): """ Initialises a Prior set for Binary Neutron Stars @@ -153,10 +161,10 @@ class BNSPriorSet(PriorSet): elif filename is not None: if not os.path.isfile(filename): filename = os.path.join(os.path.dirname(__file__), 'prior_files', filename) - PriorSet.__init__(self, dictionary=dictionary, filename=filename) + PriorDict.__init__(self, dictionary=dictionary, filename=filename) def test_redundancy(self, key): - bbh_redundancy = BBHPriorSet().test_redundancy(key) + bbh_redundancy = BBHPriorDict().test_redundancy(key) if bbh_redundancy: return True redundant = False @@ -174,6 +182,14 @@ class BNSPriorSet(PriorSet): return redundant +class BNSPriorSet(BNSPriorDict): + + def __init__(self, dictionary=None, filename=None): + """ DEPRECATED: USE BNSPriorDict INSTEAD""" + super(BNSPriorSet, self).__init__(dictionary, filename) + logger.warning("The name 'BNSPriorSet' is deprecated use 'BNSPriorDict' instead") + + Prior._default_latex_labels = { 'mass_1': '$m_1$', 'mass_2': '$m_2$', @@ -203,7 +219,7 @@ Prior._default_latex_labels = { 'delta_lambda': '$\\delta\\Lambda$'} -class CalibrationPriorSet(PriorSet): +class CalibrationPriorDict(PriorDict): def __init__(self, dictionary=None, filename=None): """ Initialises a Prior set for Binary Black holes @@ -218,7 +234,7 @@ class CalibrationPriorSet(PriorSet): if dictionary is None and filename is not None: filename = os.path.join(os.path.dirname(__file__), 'prior_files', filename) - PriorSet.__init__(self, dictionary=dictionary, filename=filename) + PriorDict.__init__(self, dictionary=dictionary, filename=filename) self.source = None def write_to_file(self, outdir, label): @@ -233,7 +249,7 @@ class CalibrationPriorSet(PriorSet): label: str Label for prior. """ - PriorSet.write_to_file(self, outdir=outdir, label=label) + PriorDict.write_to_file(self, outdir=outdir, label=label) if self.source is not None: prior_file = os.path.join(outdir, "{}.prior".format(label)) with open(prior_file, "a") as outfile: @@ -264,7 +280,7 @@ class CalibrationPriorSet(PriorSet): Returns ------- - prior: PriorSet + prior: PriorDict Priors for the relevant parameters. This includes the frequencies of the nodes which are _not_ sampled. """ @@ -287,7 +303,7 @@ class CalibrationPriorSet(PriorSet): phase_sigma_nodes =\ UnivariateSpline(frequency_array, phase_sigma)(nodes) - prior = CalibrationPriorSet() + prior = CalibrationPriorDict() for ii in range(n_nodes): name = "recalib_{}_amplitude_{}".format(label, ii) latex_label = "$A^{}_{}$".format(label, ii) @@ -334,7 +350,7 @@ class CalibrationPriorSet(PriorSet): Returns ------- - prior: PriorSet + prior: PriorDict Priors for the relevant parameters. This includes the frequencies of the nodes which are _not_ sampled. """ @@ -346,7 +362,7 @@ class CalibrationPriorSet(PriorSet): phase_mean_nodes = [0] * n_nodes phase_sigma_nodes = [phase_sigma] * n_nodes - prior = CalibrationPriorSet() + prior = CalibrationPriorDict() for ii in range(n_nodes): name = "recalib_{}_amplitude_{}".format(label, ii) latex_label = "$A^{}_{}$".format(label, ii) @@ -366,3 +382,11 @@ class CalibrationPriorSet(PriorSet): latex_label=latex_label) return prior + + +class CalibrationPriorSet(CalibrationPriorDict): + + def __init__(self, dictionary=None, filename=None): + """ DEPRECATED: USE BNSPriorDict INSTEAD""" + super(CalibrationPriorSet, self).__init__(dictionary, filename) + logger.warning("The name 'CalibrationPriorSet' is deprecated use 'CalibrationPriorDict' instead")