Commit 6c90e0a0 authored by Moritz Huebner's avatar Moritz Huebner Committed by Colm Talbot

Restructure tests

parent 514b1bd9
[run]
omit =
test/example_test.py
test/gw_example_test.py
test/noise_realisation_test.py
test/other_test.py
test/core/example_test.py
test/gw/example_test.py
test/gw/noise_realisation_test.py
test/gw/other_test.py
......@@ -91,8 +91,8 @@ python-3.7-samplers:
script:
- python -m pip install .
- pytest test/sampler_test.py --durations 10
- pytest test/sample_from_the_prior_test.py
- pytest test/core/sampler/sampler_run_test.py --durations 10
- pytest test/gw/sample_from_the_prior_test.py
# test samplers on python 3.6
python-3.6-samplers:
......@@ -101,7 +101,7 @@ python-3.6-samplers:
script:
- python -m pip install .
- pytest test/sampler_test.py
- pytest test/core/sampler/sampler_run_test.py
# Tests run at a fixed schedule rather than on push
scheduled-python-3.7:
......@@ -113,9 +113,9 @@ scheduled-python-3.7:
- python -m pip install .
# Run tests which are only done on schedule
- pytest test/example_test.py
- pytest test/gw_example_test.py
- pytest test/sample_from_the_prior_test.py
- pytest test/core/example_test.py
- pytest test/gw/example_test.py
- pytest test/gw/sample_from_the_prior_test.py
plotting:
stage: test
......@@ -126,7 +126,7 @@ plotting:
- python -m pip install .
- python -m pip install ligo.skymap
- pytest test/gw_plot_test.py
- pytest test/gw/plot_test.py
pages:
stage: deploy
......
......@@ -5,12 +5,12 @@ ignore = E129 W503 W504 W605 E203 E402
[tool:pytest]
addopts =
--ignore test/other_test.py
--ignore test/gw_example_test.py
--ignore test/example_test.py
--ignore test/sample_from_the_prior_test.py
--ignore test/gw_plot_test.py
--ignore test/sampler_test.py
--ignore test/gw/other_test.py
--ignore test/gw/example_test.py
--ignore test/core/example_test.py
--ignore test/gw/sample_from_the_prior_test.py
--ignore test/gw/plot_test.py
--ignore test/core/sampler/sampler_run_test.py
[metadata]
license_file = LICENSE.md
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import bilby # noqa
import unittest
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from mock import Mock
import bilby
class TestPriorInstantiationWithoutOptionalPriors(unittest.TestCase):
def setUp(self):
self.prior = bilby.core.prior.Prior()
def tearDown(self):
del self.prior
def test_name(self):
self.assertIsNone(self.prior.name)
def test_latex_label(self):
self.assertIsNone(self.prior.latex_label)
def test_is_fixed(self):
self.assertFalse(self.prior.is_fixed)
def test_class_instance(self):
self.assertIsInstance(self.prior, bilby.core.prior.Prior)
def test_magic_call_is_the_same_as_sampling(self):
self.prior.sample = Mock(return_value=0.5)
self.assertEqual(self.prior.sample(), self.prior())
def test_base_rescale_method(self):
self.assertIsNone(self.prior.rescale(1))
def test_base_repr(self):
"""
We compare that the strings contain all of the same characters in not
necessarily the same order as python2 doesn't conserve the order of the
arguments.
"""
self.prior = bilby.core.prior.Prior(
name="test_name",
latex_label="test_label",
minimum=0,
maximum=1,
check_range_nonzero=True,
boundary=None,
)
expected_string = (
"Prior(name='test_name', latex_label='test_label', unit=None, minimum=0, maximum=1, "
"check_range_nonzero=True, boundary=None)"
)
self.assertTrue(sorted(expected_string) == sorted(self.prior.__repr__()))
def test_base_prob(self):
self.assertTrue(np.isnan(self.prior.prob(5)))
def test_base_ln_prob(self):
self.prior.prob = lambda val: val
self.assertEqual(np.log(5), self.prior.ln_prob(5))
def test_is_in_prior(self):
self.prior.minimum = 0
self.prior.maximum = 1
val_below = self.prior.minimum - 0.1
val_at_minimum = self.prior.minimum
val_in_prior = (self.prior.minimum + self.prior.maximum) / 2.0
val_at_maximum = self.prior.maximum
val_above = self.prior.maximum + 0.1
self.assertTrue(self.prior.is_in_prior_range(val_at_minimum))
self.assertTrue(self.prior.is_in_prior_range(val_at_maximum))
self.assertTrue(self.prior.is_in_prior_range(val_in_prior))
self.assertFalse(self.prior.is_in_prior_range(val_below))
self.assertFalse(self.prior.is_in_prior_range(val_above))
def test_boundary_is_none(self):
self.assertIsNone(self.prior.boundary)
class TestPriorName(unittest.TestCase):
def setUp(self):
self.test_name = "test_name"
self.prior = bilby.core.prior.Prior(self.test_name)
def tearDown(self):
del self.prior
del self.test_name
def test_name_assignment(self):
self.prior.name = "other_name"
self.assertEqual(self.prior.name, "other_name")
class TestPriorLatexLabel(unittest.TestCase):
def setUp(self):
self.test_name = "test_name"
self.prior = bilby.core.prior.Prior(self.test_name)
def tearDown(self):
del self.test_name
del self.prior
def test_label_assignment(self):
test_label = "test_label"
self.prior.latex_label = "test_label"
self.assertEqual(test_label, self.prior.latex_label)
def test_default_label_assignment(self):
self.prior.name = "chirp_mass"
self.prior.latex_label = None
self.assertEqual(self.prior.latex_label, "$\mathcal{M}$")
def test_default_label_assignment_default(self):
self.assertTrue(self.prior.latex_label, self.prior.name)
class TestPriorIsFixed(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
del self.prior
def test_is_fixed_parent_class(self):
self.prior = bilby.core.prior.Prior()
self.assertFalse(self.prior.is_fixed)
def test_is_fixed_delta_function_class(self):
self.prior = bilby.core.prior.DeltaFunction(peak=0)
self.assertTrue(self.prior.is_fixed)
def test_is_fixed_uniform_class(self):
self.prior = bilby.core.prior.Uniform(minimum=0, maximum=10)
self.assertFalse(self.prior.is_fixed)
class TestPriorBoundary(unittest.TestCase):
def setUp(self):
self.prior = bilby.core.prior.Prior(boundary=None)
def tearDown(self):
del self.prior
def test_set_boundary_valid(self):
self.prior.boundary = "periodic"
self.assertEqual(self.prior.boundary, "periodic")
def test_set_boundary_invalid(self):
with self.assertRaises(ValueError):
self.prior.boundary = "else"
if __name__ == "__main__":
unittest.main()
import unittest
import mock
import bilby
class TestConditionalPrior(unittest.TestCase):
def setUp(self):
self.condition_func_call_counter = 0
def condition_func(reference_parameters, test_variable_1, test_variable_2):
self.condition_func_call_counter += 1
return {key: value + 1 for key, value in reference_parameters.items()}
self.condition_func = condition_func
self.minimum = 0
self.maximum = 5
self.test_variable_1 = 0
self.test_variable_2 = 1
self.prior = bilby.core.prior.ConditionalBasePrior(
condition_func=condition_func, minimum=self.minimum, maximum=self.maximum
)
def tearDown(self):
del self.condition_func
del self.condition_func_call_counter
del self.minimum
del self.maximum
del self.test_variable_1
del self.test_variable_2
del self.prior
def test_reference_params(self):
self.assertDictEqual(
dict(minimum=self.minimum, maximum=self.maximum),
self.prior.reference_params,
)
def test_required_variables(self):
self.assertListEqual(
["test_variable_1", "test_variable_2"],
sorted(self.prior.required_variables),
)
def test_required_variables_no_condition_func(self):
self.prior = bilby.core.prior.ConditionalBasePrior(
condition_func=None, minimum=self.minimum, maximum=self.maximum
)
self.assertListEqual([], self.prior.required_variables)
def test_get_instantiation_dict(self):
expected = dict(
minimum=0,
maximum=5,
name=None,
latex_label=None,
unit=None,
boundary=None,
condition_func=self.condition_func,
)
actual = self.prior.get_instantiation_dict()
for key, value in expected.items():
if key == "condition_func":
continue
self.assertEqual(value, actual[key])
def test_update_conditions_correct_variables(self):
self.prior.update_conditions(
test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2
)
self.assertEqual(1, self.condition_func_call_counter)
self.assertEqual(self.minimum + 1, self.prior.minimum)
self.assertEqual(self.maximum + 1, self.prior.maximum)
def test_update_conditions_no_variables(self):
self.prior.update_conditions(
test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2
)
self.prior.update_conditions()
self.assertEqual(1, self.condition_func_call_counter)
self.assertEqual(self.minimum + 1, self.prior.minimum)
self.assertEqual(self.maximum + 1, self.prior.maximum)
def test_update_conditions_illegal_variables(self):
with self.assertRaises(bilby.core.prior.IllegalRequiredVariablesException):
self.prior.update_conditions(test_parameter_1=self.test_variable_1)
def test_sample_calls_update_conditions(self):
with mock.patch.object(self.prior, "update_conditions") as m:
self.prior.sample(
1,
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
m.assert_called_with(
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
def test_rescale_calls_update_conditions(self):
with mock.patch.object(self.prior, "update_conditions") as m:
self.prior.rescale(
1,
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
m.assert_called_with(
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
def test_rescale_prob_update_conditions(self):
with mock.patch.object(self.prior, "update_conditions") as m:
self.prior.prob(
1,
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
m.assert_called_with(
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
def test_rescale_ln_prob_update_conditions(self):
with mock.patch.object(self.prior, "update_conditions") as m:
self.prior.ln_prob(
1,
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
calls = [
mock.call(
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
),
mock.call(),
]
m.assert_has_calls(calls)
def test_reset_to_reference_parameters(self):
self.prior.minimum = 10
self.prior.maximum = 20
self.prior.reset_to_reference_parameters()
self.assertEqual(self.prior.reference_params["minimum"], self.prior.minimum)
self.assertEqual(self.prior.reference_params["maximum"], self.prior.maximum)
def test_cond_prior_instantiation_no_boundary_prior(self):
prior = bilby.core.prior.ConditionalFermiDirac(
condition_func=None, sigma=1, mu=1
)
self.assertIsNone(prior.boundary)
class TestConditionalPriorDict(unittest.TestCase):
def setUp(self):
def condition_func_1(reference_parameters, var_0):
return reference_parameters
def condition_func_2(reference_parameters, var_0, var_1):
return reference_parameters
def condition_func_3(reference_parameters, var_1, var_2):
return reference_parameters
self.minimum = 0
self.maximum = 1
self.prior_0 = bilby.core.prior.Uniform(
minimum=self.minimum, maximum=self.maximum
)
self.prior_1 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_1, minimum=self.minimum, maximum=self.maximum
)
self.prior_2 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_2, minimum=self.minimum, maximum=self.maximum
)
self.prior_3 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_3, minimum=self.minimum, maximum=self.maximum
)
self.conditional_priors = bilby.core.prior.ConditionalPriorDict(
dict(
var_3=self.prior_3,
var_2=self.prior_2,
var_0=self.prior_0,
var_1=self.prior_1,
)
)
self.conditional_priors_manually_set_items = (
bilby.core.prior.ConditionalPriorDict()
)
self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4)
for key, value in dict(
var_0=self.prior_0,
var_1=self.prior_1,
var_2=self.prior_2,
var_3=self.prior_3,
).items():
self.conditional_priors_manually_set_items[key] = value
def tearDown(self):
del self.minimum
del self.maximum
del self.prior_0
del self.prior_1
del self.prior_2
del self.prior_3
del self.conditional_priors
del self.conditional_priors_manually_set_items
del self.test_sample
def test_conditions_resolved_upon_instantiation(self):
self.assertListEqual(
["var_0", "var_1", "var_2", "var_3"], self.conditional_priors.sorted_keys
)
def test_conditions_resolved_setting_items(self):
self.assertListEqual(
["var_0", "var_1", "var_2", "var_3"],
self.conditional_priors_manually_set_items.sorted_keys,
)
def test_unconditional_keys_upon_instantiation(self):
self.assertListEqual(["var_0"], self.conditional_priors.unconditional_keys)
def test_unconditional_keys_setting_items(self):
self.assertListEqual(
["var_0"], self.conditional_priors_manually_set_items.unconditional_keys
)
def test_conditional_keys_upon_instantiation(self):
self.assertListEqual(
["var_1", "var_2", "var_3"], self.conditional_priors.conditional_keys
)
def test_conditional_keys_setting_items(self):
self.assertListEqual(
["var_1", "var_2", "var_3"],
self.conditional_priors_manually_set_items.conditional_keys,
)
def test_prob(self):
self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample))
def test_prob_illegal_conditions(self):
del self.conditional_priors["var_0"]
with self.assertRaises(bilby.core.prior.IllegalConditionsException):
self.conditional_priors.prob(sample=self.test_sample)
def test_ln_prob(self):
self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample))
def test_ln_prob_illegal_conditions(self):
del self.conditional_priors["var_0"]
with self.assertRaises(bilby.core.prior.IllegalConditionsException):
self.conditional_priors.ln_prob(sample=self.test_sample)
def test_sample_subset_all_keys(self):
with mock.patch("numpy.random.uniform") as m:
m.return_value = 0.5
self.assertDictEqual(
dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5),
self.conditional_priors.sample_subset(
keys=["var_0", "var_1", "var_2", "var_3"]
),
)
def test_sample_illegal_subset(self):
with mock.patch("numpy.random.uniform") as m:
m.return_value = 0.5
with self.assertRaises(bilby.core.prior.IllegalConditionsException):
self.conditional_priors.sample_subset(keys=["var_1"])
def test_sample_multiple(self):
def condition_func(reference_params, a):
return dict(
minimum=reference_params["minimum"],
maximum=reference_params["maximum"],
alpha=reference_params["alpha"] * a,
)
priors = bilby.core.prior.ConditionalPriorDict()
priors["a"] = bilby.core.prior.Uniform(minimum=0, maximum=1)
priors["b"] = bilby.core.prior.ConditionalPowerLaw(
condition_func=condition_func, minimum=1, maximum=2, alpha=-2
)
print(priors.sample(2))
def test_rescale(self):
def condition_func_1_rescale(reference_parameters, var_0):
if var_0 == 0.5:
return dict(minimum=reference_parameters["minimum"], maximum=1)
return reference_parameters
def condition_func_2_rescale(reference_parameters, var_0, var_1):
if var_0 == 0.5 and var_1 == 0.5:
return dict(minimum=reference_parameters["minimum"], maximum=1)
return reference_parameters
def condition_func_3_rescale(reference_parameters, var_1, var_2):
if var_1 == 0.5 and var_2 == 0.5:
return dict(minimum=reference_parameters["minimum"], maximum=1)
return reference_parameters
self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1)
self.prior_1 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_1_rescale, minimum=self.minimum, maximum=2
)
self.prior_2 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_2_rescale, minimum=self.minimum, maximum=2
)
self.prior_3 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_3_rescale, minimum=self.minimum, maximum=2
)
self.conditional_priors = bilby.core.prior.ConditionalPriorDict(
dict(
var_3=self.prior_3,
var_2=self.prior_2,
var_0=self.prior_0,
var_1=self.prior_1,
)
)
ref_variables = [0.5, 0.5, 0.5, 0.5]
res = self.conditional_priors.rescale(
keys=list(self.test_sample.keys()), theta=ref_variables