Skip to content
Snippets Groups Projects

make loading distance marginalisation table testing more robust

Merged Colm Talbot requested to merge distance-marg-lookup-save-phase into master
1 file
+ 17
7
Compare changes
  • Side-by-side
  • Inline
+ 17
7
@@ -551,13 +551,14 @@ class GravitationalWaveTransient(likelihood.Likelihood):
def load_lookup_table(self, filename):
if os.path.exists(filename):
loaded_file = dict(np.load(filename))
if self._test_cached_lookup_table(loaded_file):
match, failure = self._test_cached_lookup_table(loaded_file)
if match:
logger.info('Loaded distance marginalisation lookup table from '
'{}.'.format(filename))
return loaded_file
else:
logger.info('Loaded distance marginalisation lookup table does '
'not match prior')
'not match for {}.'.format(failure))
return None
elif isinstance(filename, str):
logger.info('Distance marginalisation file {} does not '
@@ -571,13 +572,22 @@ class GravitationalWaveTransient(likelihood.Likelihood):
distance_array=self._distance_array,
prior_array=self.distance_prior_array,
lookup_table=self._dist_margd_loglikelihood_array,
reference_distance=self._ref_dist)
reference_distance=self._ref_dist,
phase_marginalization=self.phase_marginalization)
def _test_cached_lookup_table(self, loaded_file):
cond_a = np.all(self._distance_array == loaded_file['distance_array'])
cond_b = np.all(self.distance_prior_array == loaded_file['prior_array'])
cond_c = self._ref_dist == loaded_file['reference_distance']
return all([cond_a, cond_b, cond_c])
pairs = dict(
distance_array=self._distance_array,
prior_array=self.distance_prior_array,
reference_distance=self._ref_dist,
phase_marginalization=self.phase_marginalization)
for key in pairs:
if key not in loaded_file:
return False, key
elif not np.array_equal(np.atleast_1d(loaded_file[key]),
np.atleast_1d(pairs[key])):
return False, key
return True, None
def _create_lookup_table(self):
""" Make the lookup table """
Loading