Commit 0e6d36ea authored by Gregory Ashton's avatar Gregory Ashton

Improve testing for prior files

parent 0998b75b
from __future__ import division
import re
from importlib import import_module
import os
from collections import OrderedDict
......@@ -750,7 +751,10 @@ class Prior(object):
@classmethod
def _parse_argument_string(cls, val):
if '(' in val:
if re.sub(r'\'.*\'', '', val) in ['r', 'u']:
# If the val is a latex label like "r"\log(x)"' then ignore it
pass
elif '(' in val:
other_cls = val.split('(')[0]
vals = '('.join(val.split('(')[1:])[:-1]
if "." in other_cls:
......
......@@ -110,6 +110,37 @@ class TestBBHPriorDict(unittest.TestCase):
self.assertFalse(self.bbh_prior_dict.test_has_redundant_keys())
class TestPackagedPriors(unittest.TestCase):
""" Test that the prepackaged priors load """
def test_aligned(self):
filename = 'aligned_spin_binary_black_holes.prior'
prior_dict = bilby.gw.prior.BBHPriorDict(filename=filename)
self.assertTrue('chi_1' in prior_dict)
self.assertTrue('chi_2' in prior_dict)
def test_GW150914(self):
filename = 'GW150914.prior'
prior_dict = bilby.gw.prior.BBHPriorDict(filename=filename)
self.assertTrue('geocent_time' in prior_dict)
def test_precessing(self):
filename = 'precessing_binary_neutron_stars.prior'
prior_dict = bilby.gw.prior.BBHPriorDict(filename=filename)
self.assertTrue('lambda_1' in prior_dict)
self.assertTrue('lambda_2' in prior_dict)
def test_binary_black_holes(self):
filename = 'binary_black_holes.prior'
prior_dict = bilby.gw.prior.BBHPriorDict(filename=filename)
self.assertTrue('a_1' in prior_dict)
def test_binary_neutron_stars(self):
filename = 'binary_neutron_stars.prior'
prior_dict = bilby.gw.prior.BNSPriorDict(filename=filename)
self.assertTrue('lambda_1' in prior_dict)
class TestBNSPriorDict(unittest.TestCase):
def setUp(self):
......
mass_1 = Uniform(name='mass_1', minimum=5, maximum=100, unit='$M_{\odot}$', boundary=None)
mass_2 = 20
logA = Uniform(name='logA', minimum=10, maximum=20, latex_label=r'$\log(A_{0})$')
......@@ -702,6 +702,22 @@ class TestPriorDict(unittest.TestCase):
self.assertFalse(self.prior_set_from_dict.test_redundancy(key=key))
class TestLoadPrior(unittest.TestCase):
def test_load_prior_with_float(self):
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'prior_files/prior_with_floats.prior')
prior = bilby.core.prior.PriorDict(filename)
self.assertTrue("mass_1" in prior)
self.assertTrue("mass_2" in prior)
self.assertTrue(prior['mass_2'].peak == 20)
def test_load_prior_with_parentheses(self):
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'prior_files/prior_with_parentheses.prior')
prior = bilby.core.prior.PriorDict(filename)
self.assertTrue(isinstance(prior['logA'], bilby.core.prior.Uniform))
class TestFillPrior(unittest.TestCase):
def setUp(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment