Skip to content
Snippets Groups Projects
Commit 98d07fac authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'refactor-marg-likelihood-tests' into 'master'

Refactor marginalized likelihood tests

See merge request !1078
parents c3bf8a04 6071737e
No related branches found
No related tags found
1 merge request!1078Refactor marginalized likelihood tests
Pipeline #367227 failed
import os
import pytest
import unittest
from copy import deepcopy
import os
from itertools import product
from parameterized import parameterized
import numpy as np
import bilby
......@@ -427,9 +430,17 @@ class TestMarginalizations(unittest.TestCase):
For time, this is strongly dependent on the specific time grid used.
The `time_jitter` parameter makes this a weaker dependence during sampling.
"""
_parameters = product(
["regular", "roq"],
["luminosity_distance", "geocent_time", "phase"],
[True, False],
[True, False],
[True, False],
)
lookup_phase = "distance_lookup_phase.npz"
lookup_no_phase = "distance_lookup_no_phase.npz"
path_to_roq_weights = "weights.npz"
def setUp(self):
np.random.seed(500)
......@@ -466,6 +477,11 @@ class TestMarginalizations(unittest.TestCase):
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
start_time=1126259640,
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
)
)
self.interferometers.inject_signal(
parameters=self.parameters, waveform_generator=self.waveform_generator
......@@ -473,9 +489,38 @@ class TestMarginalizations(unittest.TestCase):
self.priors = bilby.gw.prior.BBHPriorDict()
self.priors["geocent_time"] = bilby.prior.Uniform(
minimum=self.interferometers.start_time,
maximum=self.interferometers.start_time + self.interferometers.duration,
minimum=self.parameters["geocent_time"] - 0.1,
maximum=self.parameters["geocent_time"] + 0.1
)
trial_roq_paths = [
"/roq_basis",
os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"),
"/home/cbc/ROQ_data/IMRPhenomPv2/4s",
]
roq_dir = None
for path in trial_roq_paths:
if os.path.isdir(path):
roq_dir = path
break
if roq_dir is None:
raise Exception("Unable to load ROQ basis: cannot proceed with tests")
self.roq_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
duration=self.duration,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
start_time=1126259640,
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
frequency_nodes_linear=np.load(f"{roq_dir}/fnodes_linear.npy"),
frequency_nodes_quadratic=np.load(f"{roq_dir}/fnodes_quadratic.npy"),
)
)
self.roq_linear_matrix_file = f"{roq_dir}/B_linear.npy"
self.roq_quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy"
def tearDown(self):
del self.duration
......@@ -483,22 +528,17 @@ class TestMarginalizations(unittest.TestCase):
del self.parameters
del self.interferometers
del self.waveform_generator
del self.roq_waveform_generator
del self.priors
@classmethod
def tearDownClass(cls):
# remove lookup tables so that they are not used accidentally in subsequent tests
for filename in [cls.lookup_phase, cls.lookup_no_phase]:
"""remove lookup tables so that they are not used accidentally in subsequent tests"""
for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]:
if os.path.exists(filename):
os.remove(filename)
def get_likelihood(
self,
time_marginalization=False,
phase_marginalization=False,
distance_marginalization=False,
priors=None
):
def likelihood_kwargs(self, kind, time_marginalization, phase_marginalization, distance_marginalization, priors):
if priors is None:
priors = self.priors.copy()
if distance_marginalization and phase_marginalization:
......@@ -507,7 +547,7 @@ class TestMarginalizations(unittest.TestCase):
lookup = TestMarginalizations.lookup_no_phase
else:
lookup = None
like = bilby.gw.likelihood.GravitationalWaveTransient(
kwargs = dict(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
distance_marginalization=distance_marginalization,
......@@ -516,6 +556,36 @@ class TestMarginalizations(unittest.TestCase):
distance_marginalization_lookup_table=lookup,
priors=priors,
)
if kind == "roq":
kwargs.update(dict(
linear_matrix=self.roq_linear_matrix_file,
quadratic_matrix=self.roq_quadratic_matrix_file,
waveform_generator=self.roq_waveform_generator,
))
if os.path.exists(self.__class__.path_to_roq_weights):
kwargs["weights"] = self.__class__.path_to_roq_weights
return kwargs
def get_likelihood(
self,
kind,
time_marginalization=False,
phase_marginalization=False,
distance_marginalization=False,
priors=None
):
kwargs = self.likelihood_kwargs(
kind, time_marginalization, phase_marginalization, distance_marginalization, priors
)
if kind == "regular":
cls_ = bilby.gw.likelihood.GravitationalWaveTransient
elif kind == "roq":
cls_ = bilby.gw.likelihood.ROQGravitationalWaveTransient
else:
raise ValueError(f"kind {kind} not understood")
like = cls_(**kwargs)
if kind == "roq" and not os.path.exists(self.__class__.path_to_roq_weights):
like.save_weights(self.__class__.path_to_roq_weights)
like.parameters = self.parameters.copy()
if time_marginalization:
like.parameters["geocent_time"] = self.interferometers.start_time
......@@ -538,290 +608,46 @@ class TestMarginalizations(unittest.TestCase):
marg_like, marginalized.log_likelihood_ratio(), delta=0.5
)
def test_distance_marginalisation(self):
@parameterized.expand(_parameters)
def test_marginalisation(self, kind, key, distance, time, phase):
if all([distance, time, phase]):
pytest.skip()
tested_args = dict(
distance_marginalization=distance,
time_marginalization=time,
phase_marginalization=phase,
)
marg_key = f"{key.split('_')[-1]}_marginalization"
if tested_args[marg_key]:
pytest.skip()
reference_args = tested_args.copy()
reference_args[marg_key] = True
self._template(
self.get_likelihood(distance_marginalization=True),
self.get_likelihood(),
key="luminosity_distance",
self.get_likelihood(kind, **reference_args),
self.get_likelihood(kind, **tested_args),
key=key,
)
def test_distance_phase_marginalisation(self):
self._template(
self.get_likelihood(distance_marginalization=True, phase_marginalization=True),
self.get_likelihood(phase_marginalization=True),
key="luminosity_distance",
)
def test_distance_time_marginalisation(self):
self._template(
self.get_likelihood(distance_marginalization=True, time_marginalization=True),
self.get_likelihood(time_marginalization=True),
key="luminosity_distance",
)
def test_distance_phase_time_marginalisation(self):
"""
Test phase marginalised likelihood matches brute force version when
also marginalising over time.
"""
self._template(
self.get_likelihood(distance_marginalization=True, phase_marginalization=True, time_marginalization=True),
self.get_likelihood(phase_marginalization=True, time_marginalization=True),
key="luminosity_distance",
)
def test_phase_marginalisation(self):
self._template(
self.get_likelihood(phase_marginalization=True),
self.get_likelihood(),
key="phase",
)
def test_phase_distance_marginalisation(self):
self._template(
self.get_likelihood(distance_marginalization=True, phase_marginalization=True),
self.get_likelihood(distance_marginalization=True),
key="phase",
)
def test_phase_time_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True),
self.get_likelihood(time_marginalization=True),
key="phase",
)
def test_phase_distance_time_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, distance_marginalization=True, phase_marginalization=True),
self.get_likelihood(time_marginalization=True, distance_marginalization=True),
key="phase",
)
def test_time_marginalisation(self):
times = self.waveform_generator.time_array
self._template(
self.get_likelihood(time_marginalization=True),
self.get_likelihood(),
key="geocent_time",
values=times,
)
def test_time_distance_marginalisation(self):
times = self.waveform_generator.time_array
self._template(
self.get_likelihood(time_marginalization=True, distance_marginalization=True),
self.get_likelihood(distance_marginalization=True),
key="geocent_time",
values=times
)
def test_time_phase_marginalisation(self):
times = self.waveform_generator.time_array
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True),
self.get_likelihood(phase_marginalization=True),
key="geocent_time",
values=times
)
def test_time_distance_phase_marginalisation(self):
times = self.waveform_generator.time_array
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True),
self.get_likelihood(phase_marginalization=True, distance_marginalization=True),
key="geocent_time",
values=times
)
def test_time_marginalisation_partial_segment(self):
def test_time_marginalisation_full_segment(self):
"""
Test time marginalised likelihood matches brute force version over
just part of a segment.
"""
priors = self.priors.copy()
prior = bilby.prior.Uniform(
minimum=self.parameters["geocent_time"] - 0.1,
maximum=self.parameters["geocent_time"] + 0.1,
minimum=self.interferometers.start_time,
maximum=self.interferometers.start_time + self.interferometers.duration,
)
priors["geocent_time"] = prior
self._template(
self.get_likelihood(time_marginalization=True, priors=priors.copy()),
self.get_likelihood(priors=priors.copy()),
self.get_likelihood("regular", time_marginalization=True, priors=priors.copy()),
self.get_likelihood("regular", priors=priors.copy()),
key="geocent_time",
values=self.waveform_generator.time_array,
prior=prior,
)
class TestMarginalizationsROQ(TestMarginalizations):
lookup_phase = "distance_lookup_phase.npz"
lookup_no_phase = "distance_lookup_no_phase.npz"
path_to_roq_weights = "weights.npz"
def setUp(self):
np.random.seed(500)
self.duration = 4
self.sampling_frequency = 2048
self.parameters = dict(
mass_1=31.0,
mass_2=29.0,
a_1=0.4,
a_2=0.3,
tilt_1=0.0,
tilt_2=0.0,
phi_12=1.7,
phi_jl=0.3,
luminosity_distance=4000.0,
theta_jn=0.4,
psi=2.659,
phase=1.3,
geocent_time=1126259642.413,
ra=1.375,
dec=-1.2108,
time_jitter=0,
)
self.interferometers = bilby.gw.detector.InterferometerList(["H1"])
self.interferometers.set_strain_data_from_power_spectral_densities(
sampling_frequency=self.sampling_frequency,
duration=self.duration,
start_time=1126259640,
)
waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
duration=self.duration,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
start_time=1126259640,
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2"
)
)
self.interferometers.inject_signal(
parameters=self.parameters, waveform_generator=waveform_generator
)
self.priors = bilby.gw.prior.BBHPriorDict()
# prior range should be a part of segment since ROQ likelihood can not
# calculate values at samples close to edges
self.priors["geocent_time"] = bilby.prior.Uniform(
minimum=self.parameters["geocent_time"] - 0.1,
maximum=self.parameters["geocent_time"] + 0.1
)
# Possible locations for the ROQ: in the docker image, local, or on CIT
trial_roq_paths = [
"/roq_basis",
os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"),
"/home/cbc/ROQ_data/IMRPhenomPv2/4s",
]
roq_dir = None
for path in trial_roq_paths:
if os.path.isdir(path):
roq_dir = path
break
if roq_dir is None:
raise Exception("Unable to load ROQ basis: cannot proceed with tests")
self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
duration=self.duration,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
start_time=1126259640,
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
frequency_nodes_linear=np.load("{}/fnodes_linear.npy".format(roq_dir)),
frequency_nodes_quadratic=np.load("{}/fnodes_quadratic.npy".format(roq_dir)),
)
)
self.roq_linear_matrix_file = "{}/B_linear.npy".format(roq_dir)
self.roq_quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir)
@classmethod
def tearDownClass(cls):
for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]:
if os.path.exists(filename):
os.remove(filename)
def get_likelihood(
self,
time_marginalization=False,
phase_marginalization=False,
distance_marginalization=False,
priors=None
):
if priors is None:
priors = self.priors.copy()
if distance_marginalization and phase_marginalization:
lookup = TestMarginalizationsROQ.lookup_phase
elif distance_marginalization:
lookup = TestMarginalizationsROQ.lookup_no_phase
else:
lookup = None
kwargs = dict(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
distance_marginalization=distance_marginalization,
phase_marginalization=phase_marginalization,
time_marginalization=time_marginalization,
distance_marginalization_lookup_table=lookup,
priors=priors
)
if os.path.exists(TestMarginalizationsROQ.path_to_roq_weights):
kwargs.update(dict(weights=TestMarginalizationsROQ.path_to_roq_weights))
like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs)
else:
kwargs.update(
dict(
linear_matrix=self.roq_linear_matrix_file,
quadratic_matrix=self.roq_quadratic_matrix_file
)
)
like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs)
like.save_weights(TestMarginalizationsROQ.path_to_roq_weights)
like.parameters = self.parameters.copy()
if time_marginalization:
like.parameters["geocent_time"] = self.interferometers.start_time
return like
def test_time_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True),
self.get_likelihood(),
key="geocent_time",
)
def test_time_distance_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, distance_marginalization=True),
self.get_likelihood(distance_marginalization=True),
key="geocent_time",
)
def test_time_phase_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True),
self.get_likelihood(phase_marginalization=True),
key="geocent_time",
)
def test_time_distance_phase_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True),
self.get_likelihood(phase_marginalization=True, distance_marginalization=True),
key="geocent_time",
)
def test_time_marginalisation_partial_segment(self):
pass
class TestROQLikelihood(unittest.TestCase):
def setUp(self):
self.duration = 4
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment