From ad250b2fd5816e8b1210a0f54315c14573a19242 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Tue, 17 Sep 2019 13:49:18 +1000
Subject: [PATCH] Update tests

---
 test/sampler_test.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/test/sampler_test.py b/test/sampler_test.py
index 5835f5a15..04f42e81a 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -145,7 +145,7 @@ class TestDynesty(unittest.TestCase):
         del self.sampler
 
     def test_default_kwargs(self):
-        expected = dict(bound='multi', sample='rwalk', periodic=None, verbose=True,
+        expected = dict(bound='multi', sample='rwalk', periodic=None, reflective=None, verbose=True,
                         check_point_delta_t=600, nlive=1000, first_update=None,
                         npdim=None, rstate=None, queue_size=None, pool=None,
                         use_pool=None, live_points=None, logl_args=None, logl_kwargs=None,
@@ -155,15 +155,18 @@ class TestDynesty(unittest.TestCase):
                         logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False,
                         walks=20, update_interval=600, print_func='func')
         self.sampler.kwargs['print_func'] = 'func'  # set this manually as this is not testable otherwise
-        self.assertEqual([], self.sampler.kwargs['periodic'])  # Check this separately
-        self.sampler.kwargs['periodic'] = expected['periodic']  # The dict comparison can't handle lists
+        # DictEqual can't handle lists so we check these separately
+        self.assertEqual([], self.sampler.kwargs['periodic'])
+        self.assertEqual([], self.sampler.kwargs['reflective'])
+        self.sampler.kwargs['periodic'] = expected['periodic']
+        self.sampler.kwargs['reflective'] = expected['reflective']
         for key in self.sampler.kwargs.keys():
             print("key={}, expected={}, actual={}"
                   .format(key, expected[key], self.sampler.kwargs[key]))
         self.assertDictEqual(expected, self.sampler.kwargs)
 
     def test_translate_kwargs(self):
-        expected = dict(bound='multi', sample='rwalk', periodic=[], verbose=True,
+        expected = dict(bound='multi', sample='rwalk', periodic=[], reflective=[], verbose=True,
                         check_point_delta_t=600, nlive=1000, first_update=None,
                         npdim=None, rstate=None, queue_size=None, pool=None,
                         use_pool=None, live_points=None, logl_args=None, logl_kwargs=None,
@@ -191,9 +194,10 @@ class TestDynesty(unittest.TestCase):
                                                   outdir='outdir', label='label',
                                                   use_ratio=False, plot=False,
                                                   skip_import_verification=True)
-        self.assertEqual([0, 1, 3, 4], self.sampler.kwargs["periodic"])
-        self.assertEqual([0, 4], self.sampler._periodic)
-        self.assertEqual([1, 3], self.sampler._reflective)
+        self.assertEqual([0, 4], self.sampler.kwargs["periodic"])
+        self.assertEqual(self.sampler._periodic, self.sampler.kwargs["periodic"])
+        self.assertEqual([1, 3], self.sampler.kwargs["reflective"])
+        self.assertEqual(self.sampler._reflective, self.sampler.kwargs["reflective"])
 
 
 class TestEmcee(unittest.TestCase):
-- 
GitLab