From ad250b2fd5816e8b1210a0f54315c14573a19242 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Tue, 17 Sep 2019 13:49:18 +1000 Subject: [PATCH] Update tests --- test/sampler_test.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/sampler_test.py b/test/sampler_test.py index 5835f5a15..04f42e81a 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -145,7 +145,7 @@ class TestDynesty(unittest.TestCase): del self.sampler def test_default_kwargs(self): - expected = dict(bound='multi', sample='rwalk', periodic=None, verbose=True, + expected = dict(bound='multi', sample='rwalk', periodic=None, reflective=None, verbose=True, check_point_delta_t=600, nlive=1000, first_update=None, npdim=None, rstate=None, queue_size=None, pool=None, use_pool=None, live_points=None, logl_args=None, logl_kwargs=None, @@ -155,15 +155,18 @@ class TestDynesty(unittest.TestCase): logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False, walks=20, update_interval=600, print_func='func') self.sampler.kwargs['print_func'] = 'func' # set this manually as this is not testable otherwise - self.assertEqual([], self.sampler.kwargs['periodic']) # Check this separately - self.sampler.kwargs['periodic'] = expected['periodic'] # The dict comparison can't handle lists + # DictEqual can't handle lists so we check these separately + self.assertEqual([], self.sampler.kwargs['periodic']) + self.assertEqual([], self.sampler.kwargs['reflective']) + self.sampler.kwargs['periodic'] = expected['periodic'] + self.sampler.kwargs['reflective'] = expected['reflective'] for key in self.sampler.kwargs.keys(): print("key={}, expected={}, actual={}" .format(key, expected[key], self.sampler.kwargs[key])) self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs(self): - expected = dict(bound='multi', sample='rwalk', periodic=[], verbose=True, + expected = dict(bound='multi', sample='rwalk', periodic=[], reflective=[], verbose=True, check_point_delta_t=600, nlive=1000, first_update=None, npdim=None, rstate=None, queue_size=None, pool=None, use_pool=None, live_points=None, logl_args=None, logl_kwargs=None, @@ -191,9 +194,10 @@ class TestDynesty(unittest.TestCase): outdir='outdir', label='label', use_ratio=False, plot=False, skip_import_verification=True) - self.assertEqual([0, 1, 3, 4], self.sampler.kwargs["periodic"]) - self.assertEqual([0, 4], self.sampler._periodic) - self.assertEqual([1, 3], self.sampler._reflective) + self.assertEqual([0, 4], self.sampler.kwargs["periodic"]) + self.assertEqual(self.sampler._periodic, self.sampler.kwargs["periodic"]) + self.assertEqual([1, 3], self.sampler.kwargs["reflective"]) + self.assertEqual(self.sampler._reflective, self.sampler.kwargs["reflective"]) class TestEmcee(unittest.TestCase): -- GitLab