From edefa404f49f0508a8973f2293e95a0f3d0d3457 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Thu, 7 Nov 2019 15:01:29 -0600 Subject: [PATCH] Collection of fixes to resolve issues with testing --- .gitlab-ci.yml | 10 ---------- bilby/core/sampler/emcee.py | 22 +++++++++++++++------- test/gw_utils_test.py | 6 +++--- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index fc710b579..54485557f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -44,15 +44,6 @@ basic-3.7: <<: *test-python image: python:3.7 -# test example on python 2 -python-2.7: - stage: test - image: bilbydev/bilby-test-suite-python27 - script: - - python -m pip install . - # Run tests without finding coverage - - pytest --ignore=test/utils_py3_test.py - # test example on python 3 python-3.7: stage: test @@ -97,7 +88,6 @@ pages: stage: deploy dependencies: - python-3.7 - - python-2.7 script: - mkdir public/ - mv htmlcov/ public/ diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index c2fc9a6f5..12d795d33 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -57,6 +57,8 @@ class Emcee(MCMCSampler): pos0=None, nburn=None, burn_in_fraction=0.25, resume=True, burn_in_act=3, **kwargs): import emcee + self.emcee = emcee + if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'): self.prerelease = True else: @@ -93,8 +95,8 @@ class Emcee(MCMCSampler): @property def sampler_function_kwargs(self): - import emcee - keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'mh_proposal'] + keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', + 'storechain', 'mh_proposal'] # updated function keywords for emcee > v2.2.1 updatekeys = {'p0': 'initial_state', @@ -107,7 +109,8 @@ class Emcee(MCMCSampler): if self.prerelease: if function_kwargs['mh_proposal'] is not None: logger.warning("The 'mh_proposal' option is no longer used " - "in emcee v{}, and will be ignored.".format(emcee.__version__)) + "in emcee v{}, and will be ignored.".format( + self.emcee.__version__)) del function_kwargs['mh_proposal'] for key in updatekeys: @@ -259,8 +262,7 @@ class Emcee(MCMCSampler): sys.exit() def _initialise_sampler(self): - import emcee - self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs) + self._sampler = self.emcee.EnsembleSampler(**self.sampler_init_kwargs) self._init_chain_file() @property @@ -307,7 +309,10 @@ class Emcee(MCMCSampler): This is used when loading in a sampler from a pickle file to figure out how much of the run has already been completed """ - return len(self.sampler.blobs) + try: + return len(self.sampler.blobs) + except AttributeError: + return 0 def _draw_pos0_from_prior(self): return np.array( @@ -344,7 +349,10 @@ class Emcee(MCMCSampler): iterations = sampler_function_kwargs.pop('iterations') iterations -= self._previous_iterations - sampler_function_kwargs['p0'] = self.pos0 + if self.prerelease: + sampler_function_kwargs['initial_state'] = self.pos0 + else: + sampler_function_kwargs['p0'] = self.pos0 # main iteration loop for sample in tqdm( diff --git a/test/gw_utils_test.py b/test/gw_utils_test.py index b8d1838ca..e9b7b6262 100644 --- a/test/gw_utils_test.py +++ b/test/gw_utils_test.py @@ -124,7 +124,7 @@ class TestGWUtils(unittest.TestCase): strain = gwutils.read_frame_file( filename, start_time=None, end_time=None, channel=channel) self.assertEqual(strain.channel.name, channel) - self.assertTrue(np.all(strain.value==data)) + self.assertTrue(np.all(strain.value==data[:-1])) # Check reading with time limits start_cut = 2 @@ -138,12 +138,12 @@ class TestGWUtils(unittest.TestCase): # Check reading with unknown channels strain = gwutils.read_frame_file( filename, start_time=None, end_time=None) - self.assertTrue(np.all(strain.value==data)) + self.assertTrue(np.all(strain.value==data[:-1])) # Check reading with incorrect channel strain = gwutils.read_frame_file( filename, start_time=None, end_time=None, channel='WRONG') - self.assertTrue(np.all(strain.value==data)) + self.assertTrue(np.all(strain.value==data[:-1])) ts = gwpy.timeseries.TimeSeries(data=data, times=times, t0=0) ts.name = 'NOT-A-KNOWN-CHANNEL' -- GitLab