From 6071737e307ed7d16d216d68c61cd59da2fab416 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Tue, 8 Mar 2022 02:29:41 +0000 Subject: [PATCH] Refactor marginalized likelihood tests --- test/gw/likelihood_test.py | 384 ++++++++++--------------------------- 1 file changed, 105 insertions(+), 279 deletions(-) diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 6e34f5872..8ad6cac41 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -1,6 +1,9 @@ +import os +import pytest import unittest from copy import deepcopy -import os +from itertools import product +from parameterized import parameterized import numpy as np import bilby @@ -427,9 +430,17 @@ class TestMarginalizations(unittest.TestCase): For time, this is strongly dependent on the specific time grid used. The `time_jitter` parameter makes this a weaker dependence during sampling. """ + _parameters = product( + ["regular", "roq"], + ["luminosity_distance", "geocent_time", "phase"], + [True, False], + [True, False], + [True, False], + ) lookup_phase = "distance_lookup_phase.npz" lookup_no_phase = "distance_lookup_no_phase.npz" + path_to_roq_weights = "weights.npz" def setUp(self): np.random.seed(500) @@ -466,6 +477,11 @@ class TestMarginalizations(unittest.TestCase): sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, start_time=1126259640, + waveform_arguments=dict( + reference_frequency=20.0, + minimum_frequency=20.0, + approximant="IMRPhenomPv2", + ) ) self.interferometers.inject_signal( parameters=self.parameters, waveform_generator=self.waveform_generator @@ -473,9 +489,38 @@ class TestMarginalizations(unittest.TestCase): self.priors = bilby.gw.prior.BBHPriorDict() self.priors["geocent_time"] = bilby.prior.Uniform( - minimum=self.interferometers.start_time, - maximum=self.interferometers.start_time + self.interferometers.duration, + minimum=self.parameters["geocent_time"] - 0.1, + maximum=self.parameters["geocent_time"] + 0.1 + ) + + trial_roq_paths = [ + "/roq_basis", + os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), + "/home/cbc/ROQ_data/IMRPhenomPv2/4s", + ] + roq_dir = None + for path in trial_roq_paths: + if os.path.isdir(path): + roq_dir = path + break + if roq_dir is None: + raise Exception("Unable to load ROQ basis: cannot proceed with tests") + + self.roq_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + duration=self.duration, + sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + start_time=1126259640, + waveform_arguments=dict( + reference_frequency=20.0, + minimum_frequency=20.0, + approximant="IMRPhenomPv2", + frequency_nodes_linear=np.load(f"{roq_dir}/fnodes_linear.npy"), + frequency_nodes_quadratic=np.load(f"{roq_dir}/fnodes_quadratic.npy"), + ) ) + self.roq_linear_matrix_file = f"{roq_dir}/B_linear.npy" + self.roq_quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" def tearDown(self): del self.duration @@ -483,22 +528,17 @@ class TestMarginalizations(unittest.TestCase): del self.parameters del self.interferometers del self.waveform_generator + del self.roq_waveform_generator del self.priors @classmethod def tearDownClass(cls): - # remove lookup tables so that they are not used accidentally in subsequent tests - for filename in [cls.lookup_phase, cls.lookup_no_phase]: + """remove lookup tables so that they are not used accidentally in subsequent tests""" + for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]: if os.path.exists(filename): os.remove(filename) - def get_likelihood( - self, - time_marginalization=False, - phase_marginalization=False, - distance_marginalization=False, - priors=None - ): + def likelihood_kwargs(self, kind, time_marginalization, phase_marginalization, distance_marginalization, priors): if priors is None: priors = self.priors.copy() if distance_marginalization and phase_marginalization: @@ -507,7 +547,7 @@ class TestMarginalizations(unittest.TestCase): lookup = TestMarginalizations.lookup_no_phase else: lookup = None - like = bilby.gw.likelihood.GravitationalWaveTransient( + kwargs = dict( interferometers=self.interferometers, waveform_generator=self.waveform_generator, distance_marginalization=distance_marginalization, @@ -516,6 +556,36 @@ class TestMarginalizations(unittest.TestCase): distance_marginalization_lookup_table=lookup, priors=priors, ) + if kind == "roq": + kwargs.update(dict( + linear_matrix=self.roq_linear_matrix_file, + quadratic_matrix=self.roq_quadratic_matrix_file, + waveform_generator=self.roq_waveform_generator, + )) + if os.path.exists(self.__class__.path_to_roq_weights): + kwargs["weights"] = self.__class__.path_to_roq_weights + return kwargs + + def get_likelihood( + self, + kind, + time_marginalization=False, + phase_marginalization=False, + distance_marginalization=False, + priors=None + ): + kwargs = self.likelihood_kwargs( + kind, time_marginalization, phase_marginalization, distance_marginalization, priors + ) + if kind == "regular": + cls_ = bilby.gw.likelihood.GravitationalWaveTransient + elif kind == "roq": + cls_ = bilby.gw.likelihood.ROQGravitationalWaveTransient + else: + raise ValueError(f"kind {kind} not understood") + like = cls_(**kwargs) + if kind == "roq" and not os.path.exists(self.__class__.path_to_roq_weights): + like.save_weights(self.__class__.path_to_roq_weights) like.parameters = self.parameters.copy() if time_marginalization: like.parameters["geocent_time"] = self.interferometers.start_time @@ -538,290 +608,46 @@ class TestMarginalizations(unittest.TestCase): marg_like, marginalized.log_likelihood_ratio(), delta=0.5 ) - def test_distance_marginalisation(self): + @parameterized.expand(_parameters) + def test_marginalisation(self, kind, key, distance, time, phase): + if all([distance, time, phase]): + pytest.skip() + tested_args = dict( + distance_marginalization=distance, + time_marginalization=time, + phase_marginalization=phase, + ) + marg_key = f"{key.split('_')[-1]}_marginalization" + if tested_args[marg_key]: + pytest.skip() + reference_args = tested_args.copy() + reference_args[marg_key] = True self._template( - self.get_likelihood(distance_marginalization=True), - self.get_likelihood(), - key="luminosity_distance", + self.get_likelihood(kind, **reference_args), + self.get_likelihood(kind, **tested_args), + key=key, ) - def test_distance_phase_marginalisation(self): - self._template( - self.get_likelihood(distance_marginalization=True, phase_marginalization=True), - self.get_likelihood(phase_marginalization=True), - key="luminosity_distance", - ) - - def test_distance_time_marginalisation(self): - self._template( - self.get_likelihood(distance_marginalization=True, time_marginalization=True), - self.get_likelihood(time_marginalization=True), - key="luminosity_distance", - ) - - def test_distance_phase_time_marginalisation(self): - """ - Test phase marginalised likelihood matches brute force version when - also marginalising over time. - """ - self._template( - self.get_likelihood(distance_marginalization=True, phase_marginalization=True, time_marginalization=True), - self.get_likelihood(phase_marginalization=True, time_marginalization=True), - key="luminosity_distance", - ) - - def test_phase_marginalisation(self): - self._template( - self.get_likelihood(phase_marginalization=True), - self.get_likelihood(), - key="phase", - ) - - def test_phase_distance_marginalisation(self): - self._template( - self.get_likelihood(distance_marginalization=True, phase_marginalization=True), - self.get_likelihood(distance_marginalization=True), - key="phase", - ) - - def test_phase_time_marginalisation(self): - self._template( - self.get_likelihood(time_marginalization=True, phase_marginalization=True), - self.get_likelihood(time_marginalization=True), - key="phase", - ) - - def test_phase_distance_time_marginalisation(self): - self._template( - self.get_likelihood(time_marginalization=True, distance_marginalization=True, phase_marginalization=True), - self.get_likelihood(time_marginalization=True, distance_marginalization=True), - key="phase", - ) - - def test_time_marginalisation(self): - times = self.waveform_generator.time_array - self._template( - self.get_likelihood(time_marginalization=True), - self.get_likelihood(), - key="geocent_time", - values=times, - ) - - def test_time_distance_marginalisation(self): - times = self.waveform_generator.time_array - self._template( - self.get_likelihood(time_marginalization=True, distance_marginalization=True), - self.get_likelihood(distance_marginalization=True), - key="geocent_time", - values=times - ) - - def test_time_phase_marginalisation(self): - times = self.waveform_generator.time_array - self._template( - self.get_likelihood(time_marginalization=True, phase_marginalization=True), - self.get_likelihood(phase_marginalization=True), - key="geocent_time", - values=times - ) - - def test_time_distance_phase_marginalisation(self): - times = self.waveform_generator.time_array - self._template( - self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True), - self.get_likelihood(phase_marginalization=True, distance_marginalization=True), - key="geocent_time", - values=times - ) - - def test_time_marginalisation_partial_segment(self): + def test_time_marginalisation_full_segment(self): """ Test time marginalised likelihood matches brute force version over just part of a segment. """ priors = self.priors.copy() prior = bilby.prior.Uniform( - minimum=self.parameters["geocent_time"] - 0.1, - maximum=self.parameters["geocent_time"] + 0.1, + minimum=self.interferometers.start_time, + maximum=self.interferometers.start_time + self.interferometers.duration, ) priors["geocent_time"] = prior self._template( - self.get_likelihood(time_marginalization=True, priors=priors.copy()), - self.get_likelihood(priors=priors.copy()), + self.get_likelihood("regular", time_marginalization=True, priors=priors.copy()), + self.get_likelihood("regular", priors=priors.copy()), key="geocent_time", values=self.waveform_generator.time_array, prior=prior, ) -class TestMarginalizationsROQ(TestMarginalizations): - - lookup_phase = "distance_lookup_phase.npz" - lookup_no_phase = "distance_lookup_no_phase.npz" - path_to_roq_weights = "weights.npz" - - def setUp(self): - np.random.seed(500) - self.duration = 4 - self.sampling_frequency = 2048 - self.parameters = dict( - mass_1=31.0, - mass_2=29.0, - a_1=0.4, - a_2=0.3, - tilt_1=0.0, - tilt_2=0.0, - phi_12=1.7, - phi_jl=0.3, - luminosity_distance=4000.0, - theta_jn=0.4, - psi=2.659, - phase=1.3, - geocent_time=1126259642.413, - ra=1.375, - dec=-1.2108, - time_jitter=0, - ) - - self.interferometers = bilby.gw.detector.InterferometerList(["H1"]) - self.interferometers.set_strain_data_from_power_spectral_densities( - sampling_frequency=self.sampling_frequency, - duration=self.duration, - start_time=1126259640, - ) - - waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=self.duration, - sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - start_time=1126259640, - waveform_arguments=dict( - reference_frequency=20.0, - minimum_frequency=20.0, - approximant="IMRPhenomPv2" - ) - ) - self.interferometers.inject_signal( - parameters=self.parameters, waveform_generator=waveform_generator - ) - - self.priors = bilby.gw.prior.BBHPriorDict() - # prior range should be a part of segment since ROQ likelihood can not - # calculate values at samples close to edges - self.priors["geocent_time"] = bilby.prior.Uniform( - minimum=self.parameters["geocent_time"] - 0.1, - maximum=self.parameters["geocent_time"] + 0.1 - ) - - # Possible locations for the ROQ: in the docker image, local, or on CIT - trial_roq_paths = [ - "/roq_basis", - os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), - "/home/cbc/ROQ_data/IMRPhenomPv2/4s", - ] - roq_dir = None - for path in trial_roq_paths: - if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") - - self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=self.duration, - sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, - start_time=1126259640, - waveform_arguments=dict( - reference_frequency=20.0, - minimum_frequency=20.0, - approximant="IMRPhenomPv2", - frequency_nodes_linear=np.load("{}/fnodes_linear.npy".format(roq_dir)), - frequency_nodes_quadratic=np.load("{}/fnodes_quadratic.npy".format(roq_dir)), - ) - ) - self.roq_linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - self.roq_quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - - @classmethod - def tearDownClass(cls): - for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]: - if os.path.exists(filename): - os.remove(filename) - - def get_likelihood( - self, - time_marginalization=False, - phase_marginalization=False, - distance_marginalization=False, - priors=None - ): - if priors is None: - priors = self.priors.copy() - if distance_marginalization and phase_marginalization: - lookup = TestMarginalizationsROQ.lookup_phase - elif distance_marginalization: - lookup = TestMarginalizationsROQ.lookup_no_phase - else: - lookup = None - kwargs = dict( - interferometers=self.interferometers, - waveform_generator=self.waveform_generator, - distance_marginalization=distance_marginalization, - phase_marginalization=phase_marginalization, - time_marginalization=time_marginalization, - distance_marginalization_lookup_table=lookup, - priors=priors - ) - if os.path.exists(TestMarginalizationsROQ.path_to_roq_weights): - kwargs.update(dict(weights=TestMarginalizationsROQ.path_to_roq_weights)) - like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs) - else: - kwargs.update( - dict( - linear_matrix=self.roq_linear_matrix_file, - quadratic_matrix=self.roq_quadratic_matrix_file - ) - ) - like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs) - like.save_weights(TestMarginalizationsROQ.path_to_roq_weights) - like.parameters = self.parameters.copy() - if time_marginalization: - like.parameters["geocent_time"] = self.interferometers.start_time - return like - - def test_time_marginalisation(self): - self._template( - self.get_likelihood(time_marginalization=True), - self.get_likelihood(), - key="geocent_time", - ) - - def test_time_distance_marginalisation(self): - self._template( - self.get_likelihood(time_marginalization=True, distance_marginalization=True), - self.get_likelihood(distance_marginalization=True), - key="geocent_time", - ) - - def test_time_phase_marginalisation(self): - self._template( - self.get_likelihood(time_marginalization=True, phase_marginalization=True), - self.get_likelihood(phase_marginalization=True), - key="geocent_time", - ) - - def test_time_distance_phase_marginalisation(self): - self._template( - self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True), - self.get_likelihood(phase_marginalization=True, distance_marginalization=True), - key="geocent_time", - ) - - def test_time_marginalisation_partial_segment(self): - pass - - class TestROQLikelihood(unittest.TestCase): def setUp(self): self.duration = 4 -- GitLab