Commit 47c905c8 authored by MoritzThomasHuebner's avatar MoritzThomasHuebner
Browse files

Eliminated remaining PriorSet references

parent 274019b7
Pipeline #34428 passed with stage
in 9 minutes and 16 seconds
...@@ -115,7 +115,7 @@ First `pip` installable version https://pypi.org/project/BILBY/ . ...@@ -115,7 +115,7 @@ First `pip` installable version https://pypi.org/project/BILBY/ .
- Major effort to update all docstrings and add some documentation. - Major effort to update all docstrings and add some documentation.
- Marginalized likelihoods. - Marginalized likelihoods.
- Examples of searches for gravitational waves from a Supernova and using a sine-Gaussian. - Examples of searches for gravitational waves from a Supernova and using a sine-Gaussian.
- A `PriorSet` to handle sets of priors and allows reading in from a standardised prior file (see https://monash.docs.ligo.org/bilby/prior.html). - A `PriorDict` to handle sets of priors and allows reading in from a standardised prior file (see https://monash.docs.ligo.org/bilby/prior.html).
- A standardised file for storing detector data. - A standardised file for storing detector data.
### Removed ### Removed
......
...@@ -36,7 +36,7 @@ class PriorDict(OrderedDict): ...@@ -36,7 +36,7 @@ class PriorDict(OrderedDict):
elif type(filename) is str: elif type(filename) is str:
self.from_file(filename) self.from_file(filename)
elif dictionary is not None: elif dictionary is not None:
raise ValueError("PriorSet input dictionary not understood") raise ValueError("PriorDict input dictionary not understood")
def to_file(self, outdir, label): def to_file(self, outdir, label):
""" Write the prior distribution to file. """ Write the prior distribution to file.
......
...@@ -10,7 +10,7 @@ from collections import OrderedDict ...@@ -10,7 +10,7 @@ from collections import OrderedDict
from . import utils from . import utils
from .utils import logger, infer_parameters_from_function from .utils import logger, infer_parameters_from_function
from .prior import PriorSet, DeltaFunction from .prior import PriorDict, DeltaFunction
def result_file_name(outdir, label): def result_file_name(outdir, label):
...@@ -81,7 +81,7 @@ class Result(dict): ...@@ -81,7 +81,7 @@ class Result(dict):
setattr(self, key, val) setattr(self, key, val)
if getattr(self, 'priors', None) is not None: if getattr(self, 'priors', None) is not None:
self.priors = PriorSet(self.priors) self.priors = PriorDict(self.priors)
def __add__(self, other): def __add__(self, other):
matches = ['sampler', 'search_parameter_keys'] matches = ['sampler', 'search_parameter_keys']
...@@ -311,9 +311,9 @@ class Result(dict): ...@@ -311,9 +311,9 @@ class Result(dict):
parameters: (list, dict), optional parameters: (list, dict), optional
If given, either a list of the parameter names to include, or a If given, either a list of the parameter names to include, or a
dictionary of parameter names and their "true" values to plot. dictionary of parameter names and their "true" values to plot.
priors: {bool (False), bilby.core.prior.PriorSet} priors: {bool (False), bilby.core.prior.PriorDict}
If true, add the stored prior probability density functions to the If true, add the stored prior probability density functions to the
one-dimensional marginal distributions. If instead a PriorSet one-dimensional marginal distributions. If instead a PriorDict
is provided, this will be plotted. is provided, this will be plotted.
titles: bool titles: bool
If true, add 1D titles of the median and (by default 1-sigma) If true, add 1D titles of the median and (by default 1-sigma)
...@@ -577,7 +577,7 @@ class Result(dict): ...@@ -577,7 +577,7 @@ class Result(dict):
Parameters Parameters
---------- ----------
priors: dict, PriorSet priors: dict, PriorDict
Prior distributions Prior distributions
""" """
self.prior_values = pd.DataFrame() self.prior_values = pd.DataFrame()
......
...@@ -4,7 +4,7 @@ import datetime ...@@ -4,7 +4,7 @@ import datetime
from collections import OrderedDict from collections import OrderedDict
from ..utils import command_line_args, logger from ..utils import command_line_args, logger
from ..prior import PriorSet from ..prior import PriorDict
from .base_sampler import Sampler from .base_sampler import Sampler
from .cpnest import Cpnest from .cpnest import Cpnest
...@@ -47,8 +47,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', ...@@ -47,8 +47,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
---------- ----------
likelihood: `bilby.Likelihood` likelihood: `bilby.Likelihood`
A `Likelihood` instance A `Likelihood` instance
priors: `bilby.PriorSet` priors: `bilby.PriorDict`
A PriorSet/dictionary of the priors for each parameter - missing A PriorDict/dictionary of the priors for each parameter - missing
parameters will use default priors, if None, all priors will be default parameters will use default priors, if None, all priors will be default
label: str label: str
Name for the run, used in output files Name for the run, used in output files
...@@ -101,8 +101,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', ...@@ -101,8 +101,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
priors = dict() priors = dict()
if type(priors) in [dict, OrderedDict]: if type(priors) in [dict, OrderedDict]:
priors = PriorSet(priors) priors = PriorDict(priors)
elif isinstance(priors, PriorSet): elif isinstance(priors, PriorDict):
pass pass
else: else:
raise ValueError("Input priors not understood") raise ValueError("Input priors not understood")
......
...@@ -3,7 +3,7 @@ import datetime ...@@ -3,7 +3,7 @@ import datetime
import numpy as np import numpy as np
from pandas import DataFrame from pandas import DataFrame
from ..utils import logger, command_line_args from ..utils import logger, command_line_args
from ..prior import Prior, PriorSet from ..prior import Prior, PriorDict
from ..result import Result, read_in_result from ..result import Result, read_in_result
...@@ -14,7 +14,7 @@ class Sampler(object): ...@@ -14,7 +14,7 @@ class Sampler(object):
---------- ----------
likelihood: likelihood.Likelihood likelihood: likelihood.Likelihood
A object with a log_l method A object with a log_l method
priors: bilby.core.prior.PriorSet, dict priors: bilby.core.prior.PriorDict, dict
Priors to be used in the search. Priors to be used in the search.
This has attributes for each parameter to be sampled. This has attributes for each parameter to be sampled.
external_sampler: str, Sampler, optional external_sampler: str, Sampler, optional
...@@ -35,7 +35,7 @@ class Sampler(object): ...@@ -35,7 +35,7 @@ class Sampler(object):
------- -------
likelihood: likelihood.Likelihood likelihood: likelihood.Likelihood
A object with a log_l method A object with a log_l method
priors: bilby.core.prior.PriorSet priors: bilby.core.prior.PriorDict
Priors to be used in the search. Priors to be used in the search.
This has attributes for each parameter to be sampled. This has attributes for each parameter to be sampled.
external_sampler: Module external_sampler: Module
...@@ -74,10 +74,10 @@ class Sampler(object): ...@@ -74,10 +74,10 @@ class Sampler(object):
self, likelihood, priors, outdir='outdir', label='label', self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False, **kwargs): use_ratio=False, plot=False, skip_import_verification=False, **kwargs):
self.likelihood = likelihood self.likelihood = likelihood
if isinstance(priors, PriorSet): if isinstance(priors, PriorDict):
self.priors = priors self.priors = priors
else: else:
self.priors = PriorSet(priors) self.priors = PriorDict(priors)
self.label = label self.label = label
self.outdir = outdir self.outdir = outdir
self.use_ratio = use_ratio self.use_ratio = use_ratio
......
...@@ -12,7 +12,7 @@ from ..core import likelihood ...@@ -12,7 +12,7 @@ from ..core import likelihood
from ..core.utils import logger from ..core.utils import logger
from ..core.prior import Prior, Uniform from ..core.prior import Prior, Uniform
from .detector import InterferometerList from .detector import InterferometerList
from .prior import BBHPriorSet from .prior import BBHPriorDict
from .source import lal_binary_black_hole from .source import lal_binary_black_hole
from .utils import noise_weighted_inner_product from .utils import noise_weighted_inner_product
from .waveform_generator import WaveformGenerator from .waveform_generator import WaveformGenerator
...@@ -122,7 +122,7 @@ class GravitationalWaveTransient(likelihood.Likelihood): ...@@ -122,7 +122,7 @@ class GravitationalWaveTransient(likelihood.Likelihood):
self.interferometers.start_time, self.interferometers.start_time,
self.interferometers.start_time + self.interferometers.duration) self.interferometers.start_time + self.interferometers.duration)
else: else:
self.prior[key] = BBHPriorSet()[key] self.prior[key] = BBHPriorDict()[key]
@property @property
def prior(self): def prior(self):
......
...@@ -89,7 +89,7 @@ for interferometer in interferometers: ...@@ -89,7 +89,7 @@ for interferometer in interferometers:
interferometer.plot_data(signal=signal, outdir=outdir, label=label) interferometer.plot_data(signal=signal, outdir=outdir, label=label)
# set up priors # set up priors
priors = bilby.gw.prior.BBHPriorSet() priors = bilby.gw.prior.BBHPriorDict()
for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'psi', for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'psi',
'geocent_time', 'phase']: 'geocent_time', 'phase']:
priors[key] = injection_parameters[key] priors[key] = injection_parameters[key]
......
...@@ -54,7 +54,7 @@ ifos.set_strain_data_from_power_spectral_densities( ...@@ -54,7 +54,7 @@ ifos.set_strain_data_from_power_spectral_densities(
ifos.inject_signal(waveform_generator=waveform_generator, ifos.inject_signal(waveform_generator=waveform_generator,
parameters=injection_parameters) parameters=injection_parameters)
# Set up a PriorSet, which inherits from dict. # Set up a PriorDict, which inherits from dict.
# By default we will sample all terms in the signal models. However, this will # By default we will sample all terms in the signal models. However, this will
# take a long time for the calculation, so for this example we will set almost # take a long time for the calculation, so for this example we will set almost
# all of the priors to be equall to their injected values. This implies the # all of the priors to be equall to their injected values. This implies the
...@@ -64,7 +64,7 @@ ifos.inject_signal(waveform_generator=waveform_generator, ...@@ -64,7 +64,7 @@ ifos.inject_signal(waveform_generator=waveform_generator,
# The above list does *not* include mass_1, mass_2, iota and luminosity # The above list does *not* include mass_1, mass_2, iota and luminosity
# distance, which means those are the parameters that will be included in the # distance, which means those are the parameters that will be included in the
# sampler. If we do nothing, then the default priors get used. # sampler. If we do nothing, then the default priors get used.
priors = bilby.gw.prior.BBHPriorSet() priors = bilby.gw.prior.BBHPriorDict()
priors['geocent_time'] = bilby.core.prior.Uniform( priors['geocent_time'] = bilby.core.prior.Uniform(
minimum=injection_parameters['geocent_time'] - 1, minimum=injection_parameters['geocent_time'] - 1,
maximum=injection_parameters['geocent_time'] + 1, maximum=injection_parameters['geocent_time'] + 1,
......
...@@ -64,7 +64,7 @@ interferometers.inject_signal(parameters=injection_parameters, ...@@ -64,7 +64,7 @@ interferometers.inject_signal(parameters=injection_parameters,
# Load the default prior for binary neutron stars. # Load the default prior for binary neutron stars.
# We're going to sample in chirp_mass, symmetric_mass_ratio, lambda_tilde, and # We're going to sample in chirp_mass, symmetric_mass_ratio, lambda_tilde, and
# delta_lambda rather than mass_1, mass_2, lambda_1, and lambda_2. # delta_lambda rather than mass_1, mass_2, lambda_1, and lambda_2.
priors = bilby.gw.prior.BNSPriorSet() priors = bilby.gw.prior.BNSPriorDict()
for key in ['psi', 'geocent_time', 'ra', 'dec', 'chi_1', 'chi_2', for key in ['psi', 'geocent_time', 'ra', 'dec', 'chi_1', 'chi_2',
'iota', 'luminosity_distance', 'phase']: 'iota', 'luminosity_distance', 'phase']:
priors[key] = injection_parameters[key] priors[key] = injection_parameters[key]
......
...@@ -48,7 +48,7 @@ ifos.inject_signal(waveform_generator=waveform_generator, ...@@ -48,7 +48,7 @@ ifos.inject_signal(waveform_generator=waveform_generator,
# Set up prior # Set up prior
# Note it is possible to sample in different parameters to those that were # Note it is possible to sample in different parameters to those that were
# injected. # injected.
priors = bilby.gw.prior.BBHPriorSet() priors = bilby.gw.prior.BBHPriorDict()
priors.pop('mass_1') priors.pop('mass_1')
priors.pop('mass_2') priors.pop('mass_2')
priors.pop('luminosity_distance') priors.pop('luminosity_distance')
......
...@@ -58,7 +58,7 @@ ifos.inject_signal(waveform_generator=waveform_generator, ...@@ -58,7 +58,7 @@ ifos.inject_signal(waveform_generator=waveform_generator,
parameters=injection_parameters) parameters=injection_parameters)
# Now we set up the priors on each of the binary parameters. # Now we set up the priors on each of the binary parameters.
priors = bilby.core.prior.PriorSet() priors = bilby.core.prior.PriorDict()
priors["mass_1"] = bilby.core.prior.Uniform( priors["mass_1"] = bilby.core.prior.Uniform(
name='mass_1', minimum=5, maximum=60, unit='$M_{\\odot}$') name='mass_1', minimum=5, maximum=60, unit='$M_{\\odot}$')
priors["mass_2"] = bilby.core.prior.Uniform( priors["mass_2"] = bilby.core.prior.Uniform(
......
...@@ -39,7 +39,7 @@ ifos.inject_signal(waveform_generator=waveform_generator, ...@@ -39,7 +39,7 @@ ifos.inject_signal(waveform_generator=waveform_generator,
# Set up prior # Set up prior
# This loads in a predefined set of priors for BBHs. # This loads in a predefined set of priors for BBHs.
priors = bilby.gw.prior.BBHPriorSet() priors = bilby.gw.prior.BBHPriorDict()
# These parameters will not be sampled # These parameters will not be sampled
for key in ['tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'phase', 'iota', 'ra', for key in ['tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'phase', 'iota', 'ra',
'dec', 'geocent_time', 'psi']: 'dec', 'geocent_time', 'psi']:
......
...@@ -38,7 +38,7 @@ ifos.inject_signal(waveform_generator=waveform_generator, ...@@ -38,7 +38,7 @@ ifos.inject_signal(waveform_generator=waveform_generator,
parameters=injection_parameters) parameters=injection_parameters)
# Set up prior # Set up prior
priors = bilby.gw.prior.BBHPriorSet() priors = bilby.gw.prior.BBHPriorDict()
# These parameters will not be sampled # These parameters will not be sampled
for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'iota', 'ra', for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'iota', 'ra',
'dec']: 'dec']:
......
...@@ -30,7 +30,7 @@ ifos.set_strain_data_from_power_spectral_densities( ...@@ -30,7 +30,7 @@ ifos.set_strain_data_from_power_spectral_densities(
ifos.inject_signal(waveform_generator=waveform_generator, ifos.inject_signal(waveform_generator=waveform_generator,
parameters=injection_parameters) parameters=injection_parameters)
priors = bilby.gw.prior.BBHPriorSet() priors = bilby.gw.prior.BBHPriorDict()
for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'psi', for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'psi',
'mass_1', 'mass_2', 'phase', 'geocent_time', 'luminosity_distance', 'mass_1', 'mass_2', 'phase', 'geocent_time', 'luminosity_distance',
'iota']: 'iota']:
......
...@@ -30,7 +30,7 @@ interferometers = bilby.gw.detector.get_event_data(label) ...@@ -30,7 +30,7 @@ interferometers = bilby.gw.detector.get_event_data(label)
# The prior is printed to the terminal at run-time. # The prior is printed to the terminal at run-time.
# You can overwrite this using the syntax below in the file, # You can overwrite this using the syntax below in the file,
# or choose a fixed value by just providing a float value as the prior. # or choose a fixed value by just providing a float value as the prior.
prior = bilby.gw.prior.BBHPriorSet(filename='GW150914.prior') prior = bilby.gw.prior.BBHPriorDict(filename='GW150914.prior')
# In this step we define a `waveform_generator`. This is out object which # In this step we define a `waveform_generator`. This is out object which
# creates the frequency-domain strain. In this instance, we are using the # creates the frequency-domain strain. In this instance, we are using the
......
...@@ -6,7 +6,7 @@ stimation on GW150914 using open data. ...@@ -6,7 +6,7 @@ stimation on GW150914 using open data.
""" """
import bilby import bilby
prior = bilby.gw.prior.BBHPriorSet(filename='GW150914.prior') prior = bilby.gw.prior.BBHPriorDict(filename='GW150914.prior')
interferometers = bilby.gw.detector.get_event_data("GW150914") interferometers = bilby.gw.detector.get_event_data("GW150914")
likelihood = bilby.gw.likelihood.get_binary_black_hole_likelihood(interferometers) likelihood = bilby.gw.likelihood.get_binary_black_hole_likelihood(interferometers)
result = bilby.run_sampler(likelihood, prior, label='GW150914') result = bilby.run_sampler(likelihood, prior, label='GW150914')
......
...@@ -85,7 +85,7 @@ class TestGWTransient(unittest.TestCase): ...@@ -85,7 +85,7 @@ class TestGWTransient(unittest.TestCase):
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
) )
self.prior = bilby.gw.prior.BBHPriorSet() self.prior = bilby.gw.prior.BBHPriorDict()
self.prior['geocent_time'] = bilby.prior.Uniform( self.prior['geocent_time'] = bilby.prior.Uniform(
minimum=self.parameters['geocent_time'] - self.duration / 2, minimum=self.parameters['geocent_time'] - self.duration / 2,
maximum=self.parameters['geocent_time'] + self.duration / 2) maximum=self.parameters['geocent_time'] + self.duration / 2)
...@@ -158,7 +158,7 @@ class TestTimeMarginalization(unittest.TestCase): ...@@ -158,7 +158,7 @@ class TestTimeMarginalization(unittest.TestCase):
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
) )
self.prior = bilby.gw.prior.BBHPriorSet() self.prior = bilby.gw.prior.BBHPriorDict()
self.prior['geocent_time'] = bilby.prior.Uniform( self.prior['geocent_time'] = bilby.prior.Uniform(
minimum=self.parameters['geocent_time'] - self.duration / 2, minimum=self.parameters['geocent_time'] - self.duration / 2,
maximum=self.parameters['geocent_time'] + self.duration / 2) maximum=self.parameters['geocent_time'] + self.duration / 2)
...@@ -224,7 +224,7 @@ class TestMarginalizedLikelihood(unittest.TestCase): ...@@ -224,7 +224,7 @@ class TestMarginalizedLikelihood(unittest.TestCase):
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
) )
self.prior = bilby.gw.prior.BBHPriorSet() self.prior = bilby.gw.prior.BBHPriorDict()
self.prior['geocent_time'] = bilby.prior.Uniform( self.prior['geocent_time'] = bilby.prior.Uniform(
minimum=self.parameters['geocent_time'] - self.duration / 2, minimum=self.parameters['geocent_time'] - self.duration / 2,
maximum=self.parameters['geocent_time'] + self.duration / 2) maximum=self.parameters['geocent_time'] + self.duration / 2)
...@@ -287,7 +287,7 @@ class TestPhaseMarginalization(unittest.TestCase): ...@@ -287,7 +287,7 @@ class TestPhaseMarginalization(unittest.TestCase):
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
) )
self.prior = bilby.gw.prior.BBHPriorSet() self.prior = bilby.gw.prior.BBHPriorDict()
self.prior['geocent_time'] = bilby.prior.Uniform( self.prior['geocent_time'] = bilby.prior.Uniform(
minimum=self.parameters['geocent_time'] - self.duration / 2, minimum=self.parameters['geocent_time'] - self.duration / 2,
maximum=self.parameters['geocent_time'] + self.duration / 2) maximum=self.parameters['geocent_time'] + self.duration / 2)
...@@ -350,7 +350,7 @@ class TestTimePhaseMarginalization(unittest.TestCase): ...@@ -350,7 +350,7 @@ class TestTimePhaseMarginalization(unittest.TestCase):
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
) )
self.prior = bilby.gw.prior.BBHPriorSet() self.prior = bilby.gw.prior.BBHPriorDict()
self.prior['geocent_time'] = bilby.prior.Uniform( self.prior['geocent_time'] = bilby.prior.Uniform(
minimum=self.parameters['geocent_time'] - self.duration / 2, minimum=self.parameters['geocent_time'] - self.duration / 2,
maximum=self.parameters['geocent_time'] + self.duration / 2) maximum=self.parameters['geocent_time'] + self.duration / 2)
......
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
import sys import sys
class TestBBHPriorSet(unittest.TestCase): class TestBBHPriorDict(unittest.TestCase):
def setUp(self): def setUp(self):
self.prior_dict = dict() self.prior_dict = dict()
...@@ -13,7 +13,7 @@ class TestBBHPriorSet(unittest.TestCase): ...@@ -13,7 +13,7 @@ class TestBBHPriorSet(unittest.TestCase):
'/'.join(os.path.dirname( '/'.join(os.path.dirname(
os.path.abspath(sys.argv[0])).split('/')[:-1]) os.path.abspath(sys.argv[0])).split('/')[:-1])
self.filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prior_files/binary_black_holes.prior') self.filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prior_files/binary_black_holes.prior')
self.default_prior = bilby.gw.prior.BBHPriorSet( self.default_prior = bilby.gw.prior.BBHPriorDict(
filename=self.filename) filename=self.filename)
def tearDown(self): def tearDown(self):
...@@ -21,7 +21,7 @@ class TestBBHPriorSet(unittest.TestCase): ...@@ -21,7 +21,7 @@ class TestBBHPriorSet(unittest.TestCase):
del self.filename del self.filename
def test_create_default_prior(self): def test_create_default_prior(self):
default = bilby.gw.prior.BBHPriorSet() default = bilby.gw.prior.BBHPriorDict()
minima = all([self.default_prior[key].minimum == default[key].minimum minima = all([self.default_prior[key].minimum == default[key].minimum
for key in default.keys()]) for key in default.keys()])
maxima = all([self.default_prior[key].maximum == default[key].maximum maxima = all([self.default_prior[key].maximum == default[key].maximum
...@@ -32,10 +32,10 @@ class TestBBHPriorSet(unittest.TestCase): ...@@ -32,10 +32,10 @@ class TestBBHPriorSet(unittest.TestCase):
self.assertTrue(all([minima, maxima, names])) self.assertTrue(all([minima, maxima, names]))
def test_create_from_dict(self): def test_create_from_dict(self):
bilby.gw.prior.BBHPriorSet(dictionary=self.prior_dict) bilby.gw.prior.BBHPriorDict(dictionary=self.prior_dict)
def test_create_from_filename(self): def test_create_from_filename(self):
bilby.gw.prior.BBHPriorSet(filename=self.filename) bilby.gw.prior.BBHPriorDict(filename=self.filename)
def test_key_in_prior_not_redundant(self): def test_key_in_prior_not_redundant(self):
test = self.default_prior.test_redundancy('mass_1') test = self.default_prior.test_redundancy('mass_1')
...@@ -62,7 +62,7 @@ class TestCalibrationPrior(unittest.TestCase): ...@@ -62,7 +62,7 @@ class TestCalibrationPrior(unittest.TestCase):
phase_sigma = 0.1 phase_sigma = 0.1
n_nodes = 9 n_nodes = 9
label = 'test' label = 'test'
test = bilby.gw.prior.CalibrationPriorSet.constant_uncertainty_spline( test = bilby.gw.prior.CalibrationPriorDict.constant_uncertainty_spline(
amplitude_sigma, phase_sigma, self.minimum_frequency, amplitude_sigma, phase_sigma, self.minimum_frequency,
self.maximum_frequency, n_nodes, label) self.maximum_frequency, n_nodes, label)
......
...@@ -295,7 +295,7 @@ class TestPriorClasses(unittest.TestCase): ...@@ -295,7 +295,7 @@ class TestPriorClasses(unittest.TestCase):
self.assertEqual(prior, repr_prior) self.assertEqual(prior, repr_prior)
class TestPriorSet(unittest.TestCase): class TestPriorDict(unittest.TestCase):
def setUp(self): def setUp(self):
self.first_prior = bilby.core.prior.Uniform(name='a', minimum=0, maximum=1, unit='kg') self.first_prior = bilby.core.prior.Uniform(name='a', minimum=0, maximum=1, unit='kg')
...@@ -304,10 +304,10 @@ class TestPriorSet(unittest.TestCase): ...@@ -304,10 +304,10 @@ class TestPriorSet(unittest.TestCase):
self.prior_dict = dict(mass=self.first_prior, self.prior_dict = dict(mass=self.first_prior,
speed=self.second_prior, speed=self.second_prior,
length=self.third_prior) length=self.third_prior)
self.prior_set_from_dict = bilby.core.prior.PriorSet(dictionary=self.prior_dict) self.prior_set_from_dict = bilby.core.prior.PriorDict(dictionary=self.prior_dict)
self.default_prior_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), self.default_prior_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'prior_files/binary_black_holes.prior') 'prior_files/binary_black_holes.prior')
self.prior_set_from_file = bilby.core.prior.PriorSet(filename=self.default_prior_file) self.prior_set_from_file = bilby.core.prior.PriorDict(filename=self.default_prior_file)
def tearDown(self): def tearDown(self):
del self.first_prior del self.first_prior
...@@ -387,7 +387,7 @@ class TestPriorSet(unittest.TestCase): ...@@ -387,7 +387,7 @@ class TestPriorSet(unittest.TestCase):
self.assertDictEqual(expected, self.prior_set_from_dict) self.assertDictEqual(expected, self.prior_set_from_dict)
def test_prior_set_from_dict_but_using_a_string(self): def test_prior_set_from_dict_but_using_a_string(self):
prior_set = bilby.core.prior.PriorSet(dictionary=self.default_prior_file) prior_set = bilby.core.prior.PriorDict(dictionary=self.default_prior_file)
expected = dict( expected = dict(
mass_1=bilby.core.prior.Uniform( mass_1=bilby.core.prior.Uniform(
name='mass_1', minimum=5, maximum=100, unit='$M_{\\odot}$'), name='mass_1', minimum=5, maximum=100, unit='$M_{\\odot}$'),
...@@ -416,7 +416,7 @@ class TestPriorSet(unittest.TestCase): ...@@ -416,7 +416,7 @@ class TestPriorSet(unittest.TestCase):
def test_dict_argument_is_not_string_or_dict(self): def test_dict_argument_is_not_string_or_dict(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
bilby.core.prior.PriorSet(dictionary=list()) bilby.core.prior.PriorDict(dictionary=list())
def test_sample_subset_correct_size(self): def test_sample_subset_correct_size(self):
size = 7 size = 7
...@@ -476,7 +476,7 @@ class TestFillPrior(unittest.TestCase): ...@@ -476,7 +476,7 @@ class TestFillPrior(unittest.TestCase):
self.likelihood.parameters = dict(a=0, b=0, c=0, d=0, asdf=0, ra=1) self.likelihood.parameters = dict(a=0, b=0, c=0, d=0, asdf=0, ra=1)
self.likelihood.non_standard_sampling_parameter_keys = dict(t=8) self.likelihood.non_standard_sampling_parameter_keys = dict(t=8)
self.priors = dict(a=1, b=1.1, c='string', d=bilby.core.prior.Uniform(0, 1)) self.priors = dict(a=1, b=1.1, c='string', d=bilby.core.prior.Uniform(0, 1))
self.priors = bilby.core.prior.PriorSet(dictionary=self.priors)