From db2e46a7d5b396cd7050397090ec37495cfd0a0b Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Thu, 9 May 2019 21:57:16 -0500 Subject: [PATCH] make loading distance marginalisation table testing more robust --- bilby/gw/likelihood.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index 72966d35f..43835e935 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -551,13 +551,14 @@ class GravitationalWaveTransient(likelihood.Likelihood): def load_lookup_table(self, filename): if os.path.exists(filename): loaded_file = dict(np.load(filename)) - if self._test_cached_lookup_table(loaded_file): + match, failure = self._test_cached_lookup_table(loaded_file) + if match: logger.info('Loaded distance marginalisation lookup table from ' '{}.'.format(filename)) return loaded_file else: logger.info('Loaded distance marginalisation lookup table does ' - 'not match prior') + 'not match for {}.'.format(failure)) return None elif isinstance(filename, str): logger.info('Distance marginalisation file {} does not ' @@ -571,13 +572,22 @@ class GravitationalWaveTransient(likelihood.Likelihood): distance_array=self._distance_array, prior_array=self.distance_prior_array, lookup_table=self._dist_margd_loglikelihood_array, - reference_distance=self._ref_dist) + reference_distance=self._ref_dist, + phase_marginalization=self.phase_marginalization) def _test_cached_lookup_table(self, loaded_file): - cond_a = np.all(self._distance_array == loaded_file['distance_array']) - cond_b = np.all(self.distance_prior_array == loaded_file['prior_array']) - cond_c = self._ref_dist == loaded_file['reference_distance'] - return all([cond_a, cond_b, cond_c]) + pairs = dict( + distance_array=self._distance_array, + prior_array=self.distance_prior_array, + reference_distance=self._ref_dist, + phase_marginalization=self.phase_marginalization) + for key in pairs: + if key not in loaded_file: + return False, key + elif not np.array_equal(np.atleast_1d(loaded_file[key]), + np.atleast_1d(pairs[key])): + return False, key + return True, None def _create_lookup_table(self): """ Make the lookup table """ -- GitLab