From db2e46a7d5b396cd7050397090ec37495cfd0a0b Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Thu, 9 May 2019 21:57:16 -0500
Subject: [PATCH] make loading distance marginalisation table testing more
 robust

---
 bilby/gw/likelihood.py | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index 72966d35f..43835e935 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -551,13 +551,14 @@ class GravitationalWaveTransient(likelihood.Likelihood):
     def load_lookup_table(self, filename):
         if os.path.exists(filename):
             loaded_file = dict(np.load(filename))
-            if self._test_cached_lookup_table(loaded_file):
+            match, failure = self._test_cached_lookup_table(loaded_file)
+            if match:
                 logger.info('Loaded distance marginalisation lookup table from '
                             '{}.'.format(filename))
                 return loaded_file
             else:
                 logger.info('Loaded distance marginalisation lookup table does '
-                            'not match prior')
+                            'not match for {}.'.format(failure))
                 return None
         elif isinstance(filename, str):
             logger.info('Distance marginalisation file {} does not '
@@ -571,13 +572,22 @@ class GravitationalWaveTransient(likelihood.Likelihood):
                  distance_array=self._distance_array,
                  prior_array=self.distance_prior_array,
                  lookup_table=self._dist_margd_loglikelihood_array,
-                 reference_distance=self._ref_dist)
+                 reference_distance=self._ref_dist,
+                 phase_marginalization=self.phase_marginalization)
 
     def _test_cached_lookup_table(self, loaded_file):
-        cond_a = np.all(self._distance_array == loaded_file['distance_array'])
-        cond_b = np.all(self.distance_prior_array == loaded_file['prior_array'])
-        cond_c = self._ref_dist == loaded_file['reference_distance']
-        return all([cond_a, cond_b, cond_c])
+        pairs = dict(
+            distance_array=self._distance_array,
+            prior_array=self.distance_prior_array,
+            reference_distance=self._ref_dist,
+            phase_marginalization=self.phase_marginalization)
+        for key in pairs:
+            if key not in loaded_file:
+                return False, key
+            elif not np.array_equal(np.atleast_1d(loaded_file[key]),
+                                    np.atleast_1d(pairs[key])):
+                return False, key
+        return True, None
 
     def _create_lookup_table(self):
         """ Make the lookup table """
-- 
GitLab