Skip to content
Snippets Groups Projects
Commit 6071737e authored by Colm Talbot's avatar Colm Talbot
Browse files

Refactor marginalized likelihood tests

parent a2390302
No related branches found
Tags 1.1.5
1 merge request!1078Refactor marginalized likelihood tests
import os
import pytest
import unittest
from copy import deepcopy
import os
from itertools import product
from parameterized import parameterized
import numpy as np
import bilby
......@@ -427,9 +430,17 @@ class TestMarginalizations(unittest.TestCase):
For time, this is strongly dependent on the specific time grid used.
The `time_jitter` parameter makes this a weaker dependence during sampling.
"""
_parameters = product(
["regular", "roq"],
["luminosity_distance", "geocent_time", "phase"],
[True, False],
[True, False],
[True, False],
)
lookup_phase = "distance_lookup_phase.npz"
lookup_no_phase = "distance_lookup_no_phase.npz"
path_to_roq_weights = "weights.npz"
def setUp(self):
np.random.seed(500)
......@@ -466,6 +477,11 @@ class TestMarginalizations(unittest.TestCase):
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
start_time=1126259640,
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
)
)
self.interferometers.inject_signal(
parameters=self.parameters, waveform_generator=self.waveform_generator
......@@ -473,9 +489,38 @@ class TestMarginalizations(unittest.TestCase):
self.priors = bilby.gw.prior.BBHPriorDict()
self.priors["geocent_time"] = bilby.prior.Uniform(
minimum=self.interferometers.start_time,
maximum=self.interferometers.start_time + self.interferometers.duration,
minimum=self.parameters["geocent_time"] - 0.1,
maximum=self.parameters["geocent_time"] + 0.1
)
trial_roq_paths = [
"/roq_basis",
os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"),
"/home/cbc/ROQ_data/IMRPhenomPv2/4s",
]
roq_dir = None
for path in trial_roq_paths:
if os.path.isdir(path):
roq_dir = path
break
if roq_dir is None:
raise Exception("Unable to load ROQ basis: cannot proceed with tests")
self.roq_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
duration=self.duration,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
start_time=1126259640,
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
frequency_nodes_linear=np.load(f"{roq_dir}/fnodes_linear.npy"),
frequency_nodes_quadratic=np.load(f"{roq_dir}/fnodes_quadratic.npy"),
)
)
self.roq_linear_matrix_file = f"{roq_dir}/B_linear.npy"
self.roq_quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy"
def tearDown(self):
del self.duration
......@@ -483,22 +528,17 @@ class TestMarginalizations(unittest.TestCase):
del self.parameters
del self.interferometers
del self.waveform_generator
del self.roq_waveform_generator
del self.priors
@classmethod
def tearDownClass(cls):
# remove lookup tables so that they are not used accidentally in subsequent tests
for filename in [cls.lookup_phase, cls.lookup_no_phase]:
"""remove lookup tables so that they are not used accidentally in subsequent tests"""
for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]:
if os.path.exists(filename):
os.remove(filename)
def get_likelihood(
self,
time_marginalization=False,
phase_marginalization=False,
distance_marginalization=False,
priors=None
):
def likelihood_kwargs(self, kind, time_marginalization, phase_marginalization, distance_marginalization, priors):
if priors is None:
priors = self.priors.copy()
if distance_marginalization and phase_marginalization:
......@@ -507,7 +547,7 @@ class TestMarginalizations(unittest.TestCase):
lookup = TestMarginalizations.lookup_no_phase
else:
lookup = None
like = bilby.gw.likelihood.GravitationalWaveTransient(
kwargs = dict(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
distance_marginalization=distance_marginalization,
......@@ -516,6 +556,36 @@ class TestMarginalizations(unittest.TestCase):
distance_marginalization_lookup_table=lookup,
priors=priors,
)
if kind == "roq":
kwargs.update(dict(
linear_matrix=self.roq_linear_matrix_file,
quadratic_matrix=self.roq_quadratic_matrix_file,
waveform_generator=self.roq_waveform_generator,
))
if os.path.exists(self.__class__.path_to_roq_weights):
kwargs["weights"] = self.__class__.path_to_roq_weights
return kwargs
def get_likelihood(
self,
kind,
time_marginalization=False,
phase_marginalization=False,
distance_marginalization=False,
priors=None
):
kwargs = self.likelihood_kwargs(
kind, time_marginalization, phase_marginalization, distance_marginalization, priors
)
if kind == "regular":
cls_ = bilby.gw.likelihood.GravitationalWaveTransient
elif kind == "roq":
cls_ = bilby.gw.likelihood.ROQGravitationalWaveTransient
else:
raise ValueError(f"kind {kind} not understood")
like = cls_(**kwargs)
if kind == "roq" and not os.path.exists(self.__class__.path_to_roq_weights):
like.save_weights(self.__class__.path_to_roq_weights)
like.parameters = self.parameters.copy()
if time_marginalization:
like.parameters["geocent_time"] = self.interferometers.start_time
......@@ -538,290 +608,46 @@ class TestMarginalizations(unittest.TestCase):
marg_like, marginalized.log_likelihood_ratio(), delta=0.5
)
def test_distance_marginalisation(self):
@parameterized.expand(_parameters)
def test_marginalisation(self, kind, key, distance, time, phase):
if all([distance, time, phase]):
pytest.skip()
tested_args = dict(
distance_marginalization=distance,
time_marginalization=time,
phase_marginalization=phase,
)
marg_key = f"{key.split('_')[-1]}_marginalization"
if tested_args[marg_key]:
pytest.skip()
reference_args = tested_args.copy()
reference_args[marg_key] = True
self._template(
self.get_likelihood(distance_marginalization=True),
self.get_likelihood(),
key="luminosity_distance",
self.get_likelihood(kind, **reference_args),
self.get_likelihood(kind, **tested_args),
key=key,
)
def test_distance_phase_marginalisation(self):
self._template(
self.get_likelihood(distance_marginalization=True, phase_marginalization=True),
self.get_likelihood(phase_marginalization=True),
key="luminosity_distance",
)
def test_distance_time_marginalisation(self):
self._template(
self.get_likelihood(distance_marginalization=True, time_marginalization=True),
self.get_likelihood(time_marginalization=True),
key="luminosity_distance",
)
def test_distance_phase_time_marginalisation(self):
"""
Test phase marginalised likelihood matches brute force version when
also marginalising over time.
"""
self._template(
self.get_likelihood(distance_marginalization=True, phase_marginalization=True, time_marginalization=True),
self.get_likelihood(phase_marginalization=True, time_marginalization=True),
key="luminosity_distance",
)
def test_phase_marginalisation(self):
self._template(
self.get_likelihood(phase_marginalization=True),
self.get_likelihood(),
key="phase",
)
def test_phase_distance_marginalisation(self):
self._template(
self.get_likelihood(distance_marginalization=True, phase_marginalization=True),
self.get_likelihood(distance_marginalization=True),
key="phase",
)
def test_phase_time_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True),
self.get_likelihood(time_marginalization=True),
key="phase",
)
def test_phase_distance_time_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, distance_marginalization=True, phase_marginalization=True),
self.get_likelihood(time_marginalization=True, distance_marginalization=True),
key="phase",
)
def test_time_marginalisation(self):
times = self.waveform_generator.time_array
self._template(
self.get_likelihood(time_marginalization=True),
self.get_likelihood(),
key="geocent_time",
values=times,
)
def test_time_distance_marginalisation(self):
times = self.waveform_generator.time_array
self._template(
self.get_likelihood(time_marginalization=True, distance_marginalization=True),
self.get_likelihood(distance_marginalization=True),
key="geocent_time",
values=times
)
def test_time_phase_marginalisation(self):
times = self.waveform_generator.time_array
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True),
self.get_likelihood(phase_marginalization=True),
key="geocent_time",
values=times
)
def test_time_distance_phase_marginalisation(self):
times = self.waveform_generator.time_array
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True),
self.get_likelihood(phase_marginalization=True, distance_marginalization=True),
key="geocent_time",
values=times
)
def test_time_marginalisation_partial_segment(self):
def test_time_marginalisation_full_segment(self):
"""
Test time marginalised likelihood matches brute force version over
just part of a segment.
"""
priors = self.priors.copy()
prior = bilby.prior.Uniform(
minimum=self.parameters["geocent_time"] - 0.1,
maximum=self.parameters["geocent_time"] + 0.1,
minimum=self.interferometers.start_time,
maximum=self.interferometers.start_time + self.interferometers.duration,
)
priors["geocent_time"] = prior
self._template(
self.get_likelihood(time_marginalization=True, priors=priors.copy()),
self.get_likelihood(priors=priors.copy()),
self.get_likelihood("regular", time_marginalization=True, priors=priors.copy()),
self.get_likelihood("regular", priors=priors.copy()),
key="geocent_time",
values=self.waveform_generator.time_array,
prior=prior,
)
class TestMarginalizationsROQ(TestMarginalizations):
lookup_phase = "distance_lookup_phase.npz"
lookup_no_phase = "distance_lookup_no_phase.npz"
path_to_roq_weights = "weights.npz"
def setUp(self):
np.random.seed(500)
self.duration = 4
self.sampling_frequency = 2048
self.parameters = dict(
mass_1=31.0,
mass_2=29.0,
a_1=0.4,
a_2=0.3,
tilt_1=0.0,
tilt_2=0.0,
phi_12=1.7,
phi_jl=0.3,
luminosity_distance=4000.0,
theta_jn=0.4,
psi=2.659,
phase=1.3,
geocent_time=1126259642.413,
ra=1.375,
dec=-1.2108,
time_jitter=0,
)
self.interferometers = bilby.gw.detector.InterferometerList(["H1"])
self.interferometers.set_strain_data_from_power_spectral_densities(
sampling_frequency=self.sampling_frequency,
duration=self.duration,
start_time=1126259640,
)
waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
duration=self.duration,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
start_time=1126259640,
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2"
)
)
self.interferometers.inject_signal(
parameters=self.parameters, waveform_generator=waveform_generator
)
self.priors = bilby.gw.prior.BBHPriorDict()
# prior range should be a part of segment since ROQ likelihood can not
# calculate values at samples close to edges
self.priors["geocent_time"] = bilby.prior.Uniform(
minimum=self.parameters["geocent_time"] - 0.1,
maximum=self.parameters["geocent_time"] + 0.1
)
# Possible locations for the ROQ: in the docker image, local, or on CIT
trial_roq_paths = [
"/roq_basis",
os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"),
"/home/cbc/ROQ_data/IMRPhenomPv2/4s",
]
roq_dir = None
for path in trial_roq_paths:
if os.path.isdir(path):
roq_dir = path
break
if roq_dir is None:
raise Exception("Unable to load ROQ basis: cannot proceed with tests")
self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
duration=self.duration,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
start_time=1126259640,
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
frequency_nodes_linear=np.load("{}/fnodes_linear.npy".format(roq_dir)),
frequency_nodes_quadratic=np.load("{}/fnodes_quadratic.npy".format(roq_dir)),
)
)
self.roq_linear_matrix_file = "{}/B_linear.npy".format(roq_dir)
self.roq_quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir)
@classmethod
def tearDownClass(cls):
for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]:
if os.path.exists(filename):
os.remove(filename)
def get_likelihood(
self,
time_marginalization=False,
phase_marginalization=False,
distance_marginalization=False,
priors=None
):
if priors is None:
priors = self.priors.copy()
if distance_marginalization and phase_marginalization:
lookup = TestMarginalizationsROQ.lookup_phase
elif distance_marginalization:
lookup = TestMarginalizationsROQ.lookup_no_phase
else:
lookup = None
kwargs = dict(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
distance_marginalization=distance_marginalization,
phase_marginalization=phase_marginalization,
time_marginalization=time_marginalization,
distance_marginalization_lookup_table=lookup,
priors=priors
)
if os.path.exists(TestMarginalizationsROQ.path_to_roq_weights):
kwargs.update(dict(weights=TestMarginalizationsROQ.path_to_roq_weights))
like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs)
else:
kwargs.update(
dict(
linear_matrix=self.roq_linear_matrix_file,
quadratic_matrix=self.roq_quadratic_matrix_file
)
)
like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs)
like.save_weights(TestMarginalizationsROQ.path_to_roq_weights)
like.parameters = self.parameters.copy()
if time_marginalization:
like.parameters["geocent_time"] = self.interferometers.start_time
return like
def test_time_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True),
self.get_likelihood(),
key="geocent_time",
)
def test_time_distance_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, distance_marginalization=True),
self.get_likelihood(distance_marginalization=True),
key="geocent_time",
)
def test_time_phase_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True),
self.get_likelihood(phase_marginalization=True),
key="geocent_time",
)
def test_time_distance_phase_marginalisation(self):
self._template(
self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True),
self.get_likelihood(phase_marginalization=True, distance_marginalization=True),
key="geocent_time",
)
def test_time_marginalisation_partial_segment(self):
pass
class TestROQLikelihood(unittest.TestCase):
def setUp(self):
self.duration = 4
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment