diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index 72966d35f7380973beca2276fe6ca7f7ad9e2eab..43835e9357da88abe991394b1f3bff4e6e97f632 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -551,13 +551,14 @@ class GravitationalWaveTransient(likelihood.Likelihood): def load_lookup_table(self, filename): if os.path.exists(filename): loaded_file = dict(np.load(filename)) - if self._test_cached_lookup_table(loaded_file): + match, failure = self._test_cached_lookup_table(loaded_file) + if match: logger.info('Loaded distance marginalisation lookup table from ' '{}.'.format(filename)) return loaded_file else: logger.info('Loaded distance marginalisation lookup table does ' - 'not match prior') + 'not match for {}.'.format(failure)) return None elif isinstance(filename, str): logger.info('Distance marginalisation file {} does not ' @@ -571,13 +572,22 @@ class GravitationalWaveTransient(likelihood.Likelihood): distance_array=self._distance_array, prior_array=self.distance_prior_array, lookup_table=self._dist_margd_loglikelihood_array, - reference_distance=self._ref_dist) + reference_distance=self._ref_dist, + phase_marginalization=self.phase_marginalization) def _test_cached_lookup_table(self, loaded_file): - cond_a = np.all(self._distance_array == loaded_file['distance_array']) - cond_b = np.all(self.distance_prior_array == loaded_file['prior_array']) - cond_c = self._ref_dist == loaded_file['reference_distance'] - return all([cond_a, cond_b, cond_c]) + pairs = dict( + distance_array=self._distance_array, + prior_array=self.distance_prior_array, + reference_distance=self._ref_dist, + phase_marginalization=self.phase_marginalization) + for key in pairs: + if key not in loaded_file: + return False, key + elif not np.array_equal(np.atleast_1d(loaded_file[key]), + np.atleast_1d(pairs[key])): + return False, key + return True, None def _create_lookup_table(self): """ Make the lookup table """