diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 8428ee0eab3a178f968584408ef522f7ff67a52c..e7e707c287aeb473cf3bb0528d9c78e876a39dd4 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -2,6 +2,7 @@ import itertools import os import pytest import unittest +from copy import deepcopy from itertools import product from parameterized import parameterized @@ -555,7 +556,7 @@ class TestMarginalizations(unittest.TestCase): def likelihood_kwargs(self, kind, time_marginalization, phase_marginalization, distance_marginalization, priors): if priors is None: - priors = self.priors.copy() + priors = deepcopy(self.priors) if distance_marginalization and phase_marginalization: lookup = TestMarginalizations.lookup_phase elif distance_marginalization: @@ -580,7 +581,7 @@ class TestMarginalizations(unittest.TestCase): if os.path.exists(self.__class__.path_to_roq_weights): kwargs["weights"] = self.__class__.path_to_roq_weights elif kind == "relbin": - kwargs["fiducial_parameters"] = self.parameters.copy() + kwargs["fiducial_parameters"] = deepcopy(self.parameters) kwargs["waveform_generator"] = self.relbin_waveform_generator return kwargs @@ -601,6 +602,7 @@ class TestMarginalizations(unittest.TestCase): cls_ = bilby.gw.likelihood.ROQGravitationalWaveTransient elif kind == "relbin": cls_ = bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient + kwargs["epsilon"] = 0.3 self.parameters["fiducial"] = 0 else: raise ValueError(f"kind {kind} not understood") @@ -610,6 +612,10 @@ class TestMarginalizations(unittest.TestCase): like.parameters = self.parameters.copy() if time_marginalization: like.parameters["geocent_time"] = self.interferometers.start_time + if distance_marginalization: + like.parameters["luminosity_distance"] = like._ref_dist + if phase_marginalization: + like.parameters["phase"] = 0.0 return like def _template(self, marginalized, non_marginalized, key, prior=None, values=None): @@ -633,8 +639,6 @@ class TestMarginalizations(unittest.TestCase): def test_marginalisation(self, kind, key, distance, time, phase): if all([distance, time, phase]): pytest.skip() - if key == "geocent_time" and kind == "relbin": - pytest.skip() tested_args = dict( distance_marginalization=distance, time_marginalization=time, @@ -651,7 +655,8 @@ class TestMarginalizations(unittest.TestCase): key=key, ) - def test_time_marginalisation_full_segment(self): + @parameterized.expand(["regular", "relbin"]) + def test_time_marginalisation_full_segment(self, kind): """ Test time marginalised likelihood matches brute force version over just part of a segment. @@ -663,15 +668,15 @@ class TestMarginalizations(unittest.TestCase): ) priors["geocent_time"] = prior self._template( - self.get_likelihood("regular", time_marginalization=True, priors=priors.copy()), - self.get_likelihood("regular", priors=priors.copy()), + self.get_likelihood(kind, time_marginalization=True, priors=priors.copy()), + self.get_likelihood(kind, priors=priors.copy()), key="geocent_time", values=self.waveform_generator.time_array, prior=prior, ) @parameterized.expand( - itertools.product(["regular", "roq"], *itertools.repeat([True, False], 3)), + itertools.product(["regular", "roq", "relbin"], *itertools.repeat([True, False], 3)), name_func=lambda func, num, param: ( f"{func.__name__}_{num}__{param.args[0]}_" + "_".join([ ["D", "P", "T"][ii] for ii, val @@ -853,19 +858,6 @@ class TestROQLikelihood(unittest.TestCase): self.roq.parameters["geocent_time"] = -5 self.assertEqual(self.roq.log_likelihood_ratio(), np.nan_to_num(-np.inf)) - def test_phase_marginalisation_roq(self): - """Test phase marginalised likelihood matches brute force version""" - self.non_roq_phase.parameters = self.test_parameters.copy() - self.roq_phase.parameters = self.test_parameters.copy() - self.assertLess( - abs( - self.non_roq_phase.log_likelihood_ratio() - - self.roq_phase.log_likelihood_ratio() - ) - / self.non_roq_phase.log_likelihood_ratio(), - 1e-3, - ) - def test_create_roq_weights_with_params(self): roq = bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=self.ifos,