From cfc6973d8929d7205025a8ac1f5a3aed10d1e501 Mon Sep 17 00:00:00 2001 From: Moritz Huebner <email@moritz-huebner.de> Date: Tue, 22 May 2018 18:58:20 +1000 Subject: [PATCH] Moritz Huebner: Added more tests for the Sampler class --- test/sampler_tests.py | 131 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 114 insertions(+), 17 deletions(-) diff --git a/test/sampler_tests.py b/test/sampler_tests.py index 494d4a18b..44b93ba9e 100644 --- a/test/sampler_tests.py +++ b/test/sampler_tests.py @@ -1,33 +1,130 @@ from context import tupak from tupak import prior -from tupak import likelihood +from tupak.result import Result import unittest -from mock import Mock, MagicMock +from mock import MagicMock import numpy as np +import inspect +import os +import copy class TestSamplerInstantiation(unittest.TestCase): def setUp(self): - self.likelihood = likelihood.Likelihood() - self.likelihood.parameters = MagicMock(return_value=dict(a=1, b=2, c=3)) - delta_prior = prior.DeltaFunction(0) - delta_prior.peak = MagicMock(return_value=0) - delta_prior.rescale = MagicMock(return_value=delta_prior) + likelihood = tupak.likelihood.Likelihood() + likelihood.parameters = dict(a=1, b=2, c=3) + delta_prior = prior.DeltaFunction(peak=0) + delta_prior.rescale = MagicMock(return_value=prior.DeltaFunction(peak=1)) delta_prior.prob = MagicMock(return_value=1) + delta_prior.sample = MagicMock(return_value=0) uniform_prior = prior.Uniform(0, 1) - uniform_prior.minimum = MagicMock(return_value=0) - uniform_prior.maximum = MagicMock(return_value=1) - uniform_prior.rescale = MagicMock(return_value=uniform_prior) + uniform_prior.rescale = MagicMock(return_value=prior.Uniform(0, 2)) uniform_prior.prob = MagicMock(return_value=1) + uniform_prior.sample = MagicMock(return_value=0.5) - self.priors = dict(a=delta_prior, b='string', c=uniform_prior) - self.likelihood.log_likelihood_ratio = MagicMock(return_value=1) - self.likelihood.log_likelihood = MagicMock(return_value=2) + priors = dict(a=delta_prior, b='string', c=uniform_prior) + likelihood.log_likelihood_ratio = MagicMock(return_value=1) + likelihood.log_likelihood = MagicMock(return_value=2) + test_directory = 'test_directory' + if os.path.isdir(test_directory): + os.rmdir(test_directory) + self.sampler = tupak.sampler.Sampler(likelihood=likelihood, + priors=priors, + external_sampler='nestle', + outdir=test_directory, + use_ratio=False) def tearDown(self): - del self.likelihood - del self.priors + os.rmdir(self.sampler.outdir) + del self.sampler - def test_default_instantiation(self): - sampler = tupak.sampler.Sampler(self.likelihood, self.priors) + def test_search_parameter_keys(self): + expected_search_parameter_keys = ['c'] + self.assertListEqual(self.sampler.search_parameter_keys, expected_search_parameter_keys) + + def test_fixed_parameter_keys(self): + expected_fixed_parameter_keys = ['a'] + self.assertListEqual(self.sampler.fixed_parameter_keys, expected_fixed_parameter_keys) + + def test_ndim(self): + self.assertEqual(self.sampler.ndim, 1) + + def test_kwargs(self): + self.assertDictEqual(self.sampler.kwargs, {}) + + def test_label(self): + self.assertEqual(self.sampler.label, 'label') + + def test_if_external_sampler_is_module(self): + self.assertTrue(inspect.ismodule(self.sampler.external_sampler)) + + def test_if_external_sampler_has_the_correct_module_name(self): + expected_name = 'nestle' + self.assertEqual(self.sampler.external_sampler.__name__, expected_name) + + def test_external_sampler_raises_if_sampler_not_installed(self): + with self.assertRaises(ImportError): + self.sampler.external_sampler = 'unexpected_sampler' + + def test_setting_custom_sampler(self): + other_sampler = tupak.sampler.Sampler(self.sampler.likelihood, + self.sampler.priors) + self.sampler.external_sampler = other_sampler + self.assertEqual(self.sampler.external_sampler, other_sampler) + + def test_setting_external_sampler_to_something_else_raises_error(self): + with self.assertRaises(TypeError): + self.sampler.external_sampler = object() + + def test_result(self): + expected_result = Result() + expected_result.search_parameter_keys = ['c'] + expected_result.fixed_parameter_keys = ['a'] + expected_result.parameter_labels = ['c'] + expected_result.label = 'label' + expected_result.outdir = 'outdir' + expected_result.kwargs = {} + self.assertDictEqual(self.sampler.result.__dict__, expected_result.__dict__) + + def test_make_outdir_if_no_outdir_exists(self): + self.assertTrue(os.path.isdir(self.sampler.outdir)) + + def test_prior_transform_transforms_search_parameter_keys(self): + self.sampler.prior_transform([0]) + expected_prior = prior.Uniform(0, 1) + self.assertListEqual([self.sampler.priors['c'].minimum, + self.sampler.priors['c'].maximum], + [expected_prior.minimum, + expected_prior.maximum]) + + def test_prior_transform_does_not_transform_fixed_parameter_keys(self): + self.sampler.prior_transform([0]) + self.assertEqual(self.sampler.priors['a'].peak, + prior.DeltaFunction(peak=0).peak) + + def test_log_prior(self): + self.assertEqual(self.sampler.log_prior({1}), 0.0) + + def test_log_likelihood_with_use_ratio(self): + self.sampler.use_ratio = True + self.assertEqual(self.sampler.log_likelihood([0]), 1) + + def test_log_likelihood_without_use_ratio(self): + self.sampler.use_ratio = False + self.assertEqual(self.sampler.log_likelihood([0]), 2) + + def test_log_likelihood_correctly_sets_parameters(self): + expected_dict = dict(a=0, + b=2, + c=0) + _ = self.sampler.log_likelihood([0]) + self.assertDictEqual(self.sampler.likelihood.parameters, expected_dict) + + def test_get_random_draw(self): + self.assertEqual(self.sampler.get_random_draw_from_prior(), np.array([0.5])) + + def test_base_run_sampler(self): + sampler_copy = copy.copy(self.sampler) + self.sampler.run_sampler() + self.assertDictEqual(sampler_copy.__dict__, self.sampler.__dict__) \ No newline at end of file -- GitLab