Skip to content
Snippets Groups Projects
Commit c65f9f1f authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'default_reference_chirp_mass' into 'master'

Determine reference_chirp_mass for MBGravitationalWaveTransient from prior when it is not specified

See merge request !1169
parents c82ae742 aad4e48b
No related branches found
No related tags found
1 merge request!1169Determine reference_chirp_mass for MBGravitationalWaveTransient from prior when it is not specified
Pipeline #476291 passed
...@@ -8,6 +8,7 @@ from ...core.utils import ( ...@@ -8,6 +8,7 @@ from ...core.utils import (
logger, speed_of_light, solar_mass, radius_of_earth, logger, speed_of_light, solar_mass, radius_of_earth,
gravitational_constant, round_up_to_power_of_two gravitational_constant, round_up_to_power_of_two
) )
from ..prior import CBCPriorDict
class MBGravitationalWaveTransient(GravitationalWaveTransient): class MBGravitationalWaveTransient(GravitationalWaveTransient):
...@@ -21,8 +22,9 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): ...@@ -21,8 +22,9 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
A list of `bilby.detector.Interferometer` instances - contains the detector data and power spectral densities A list of `bilby.detector.Interferometer` instances - contains the detector data and power spectral densities
waveform_generator: `bilby.waveform_generator.WaveformGenerator` waveform_generator: `bilby.waveform_generator.WaveformGenerator`
An object which computes the frequency-domain strain of the signal, given some set of parameters An object which computes the frequency-domain strain of the signal, given some set of parameters
reference_chirp_mass: float reference_chirp_mass: float, optional
A reference chirp mass for determining the frequency banding A reference chirp mass for determining the frequency banding. This is set to prior minimum of chirp mass if
not specified. Hence a CBCPriorDict object needs to be passed to priors when this parameter is not specified.
highest_mode: int, optional highest_mode: int, optional
The maximum magnetic number of gravitational-wave moments. Default is 2 The maximum magnetic number of gravitational-wave moments. Default is 2
linear_interpolation: bool, optional linear_interpolation: bool, optional
...@@ -72,10 +74,11 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): ...@@ -72,10 +74,11 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
""" """
def __init__( def __init__(
self, interferometers, waveform_generator, reference_chirp_mass, highest_mode=2, linear_interpolation=True, self, interferometers, waveform_generator, reference_chirp_mass=None, highest_mode=2,
accuracy_factor=5, time_offset=None, delta_f_end=None, maximum_banding_frequency=None, linear_interpolation=True, accuracy_factor=5, time_offset=None, delta_f_end=None,
minimum_banding_duration=0., distance_marginalization=False, phase_marginalization=False, priors=None, maximum_banding_frequency=None, minimum_banding_duration=0., distance_marginalization=False,
distance_marginalization_lookup_table=None, reference_frame="sky", time_reference="geocenter" phase_marginalization=False, priors=None, distance_marginalization_lookup_table=None,
reference_frame="sky", time_reference="geocenter"
): ):
super(MBGravitationalWaveTransient, self).__init__( super(MBGravitationalWaveTransient, self).__init__(
interferometers=interferometers, waveform_generator=waveform_generator, priors=priors, interferometers=interferometers, waveform_generator=waveform_generator, priors=priors,
...@@ -108,7 +111,24 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): ...@@ -108,7 +111,24 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
if isinstance(reference_chirp_mass, int) or isinstance(reference_chirp_mass, float): if isinstance(reference_chirp_mass, int) or isinstance(reference_chirp_mass, float):
self._reference_chirp_mass = reference_chirp_mass self._reference_chirp_mass = reference_chirp_mass
else: else:
raise TypeError("reference_chirp_mass must be a number") logger.info(
"No int or float number has been passed to reference_chirp_mass. "
"Checking prior minimum of chirp mass ..."
)
if not isinstance(self.priors, CBCPriorDict):
raise TypeError(
f"priors: {self.priors} is not CBCPriorDict. Prior minimum of chirp mass can not be obtained."
)
self._reference_chirp_mass = self.priors.minimum_chirp_mass
if self._reference_chirp_mass is None:
raise Exception(
"Prior minimum of chirp mass can not be determined as priors does not contain necessary mass "
"parameters."
)
logger.info(
"reference_chirp_mass is automatically set to prior minimum of chirp mass: "
f"{self._reference_chirp_mass}."
)
@property @property
def highest_mode(self): def highest_mode(self):
......
...@@ -2,7 +2,6 @@ import itertools ...@@ -2,7 +2,6 @@ import itertools
import os import os
import pytest import pytest
import unittest import unittest
from copy import deepcopy
from itertools import product from itertools import product
from parameterized import parameterized from parameterized import parameterized
...@@ -1571,9 +1570,9 @@ class TestBBHLikelihoodSetUp(unittest.TestCase): ...@@ -1571,9 +1570,9 @@ class TestBBHLikelihoodSetUp(unittest.TestCase):
class TestMBLikelihood(unittest.TestCase): class TestMBLikelihood(unittest.TestCase):
def setUp(self): def setUp(self):
duration = 16 self.duration = 16
fmin = 20. self.fmin = 20.
sampling_frequency = 2048. self.sampling_frequency = 2048.
self.test_parameters = dict( self.test_parameters = dict(
chirp_mass=6.0, chirp_mass=6.0,
mass_ratio=0.5, mass_ratio=0.5,
...@@ -1592,18 +1591,18 @@ class TestMBLikelihood(unittest.TestCase): ...@@ -1592,18 +1591,18 @@ class TestMBLikelihood(unittest.TestCase):
dec=-1.2 dec=-1.2
) # Network SNR is ~50 ) # Network SNR is ~50
ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"]) self.ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
np.random.seed(170817) np.random.seed(170817)
ifos.set_strain_data_from_power_spectral_densities( self.ifos.set_strain_data_from_power_spectral_densities(
sampling_frequency=sampling_frequency, duration=duration, sampling_frequency=self.sampling_frequency, duration=self.duration,
start_time=self.test_parameters['geocent_time'] - duration + 2. start_time=self.test_parameters['geocent_time'] - self.duration + 2.
) )
for ifo in ifos: for ifo in self.ifos:
ifo.minimum_frequency = fmin ifo.minimum_frequency = self.fmin
spline_calibration_nodes = 10 spline_calibration_nodes = 10
self.calibration_parameters = {} self.calibration_parameters = {}
for ifo in ifos: for ifo in self.ifos:
ifo.calibration_model = bilby.gw.calibration.CubicSpline( ifo.calibration_model = bilby.gw.calibration.CubicSpline(
prefix=f"recalib_{ifo.name}_", prefix=f"recalib_{ifo.name}_",
minimum_frequency=ifo.minimum_frequency, minimum_frequency=ifo.minimum_frequency,
...@@ -1619,143 +1618,168 @@ class TestMBLikelihood(unittest.TestCase): ...@@ -1619,143 +1618,168 @@ class TestMBLikelihood(unittest.TestCase):
self.calibration_parameters[f"recalib_{ifo.name}_phase_{i}"] = \ self.calibration_parameters[f"recalib_{ifo.name}_phase_{i}"] = \
np.random.normal(loc=0, scale=5 * np.pi / 180) np.random.normal(loc=0, scale=5 * np.pi / 180)
priors = bilby.gw.prior.BBHPriorDict() self.priors = bilby.gw.prior.BBHPriorDict()
priors.pop("mass_1") self.priors.pop("mass_1")
priors.pop("mass_2") self.priors.pop("mass_2")
priors["chirp_mass"] = bilby.core.prior.Uniform(5.5, 6.5) self.priors["chirp_mass"] = bilby.core.prior.Uniform(5.5, 6.5)
priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1) self.priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1)
priors["geocent_time"] = bilby.core.prior.Uniform( self.priors["geocent_time"] = bilby.core.prior.Uniform(
self.test_parameters['geocent_time'] - 0.1, self.test_parameters['geocent_time'] - 0.1,
self.test_parameters['geocent_time'] + 0.1) self.test_parameters['geocent_time'] + 0.1)
approximant_22 = "IMRPhenomD" def tearDown(self):
approximant_homs = "IMRPhenomHM" del (
non_mb_wfg_22 = bilby.gw.WaveformGenerator( self.ifos,
duration=duration, sampling_frequency=sampling_frequency, self.priors
)
@parameterized.expand([
("IMRPhenomD", True, 2, False, 1.5e-2),
("IMRPhenomD", True, 2, True, 1.5e-2),
("IMRPhenomD", False, 2, False, 5e-3),
("IMRPhenomD", False, 2, True, 6e-3),
("IMRPhenomHM", False, 4, False, 8e-4),
("IMRPhenomHM", False, 4, True, 1e-3)
])
def test_matches_original_likelihood(
self, approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
):
"""
Check if multi-band likelihood values match original likelihood values
"""
wfg = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict( waveform_arguments=dict(
reference_frequency=fmin, minimum_frequency=fmin, approximant=approximant_22) reference_frequency=self.fmin, approximant=approximant
)
) )
mb_wfg_22 = bilby.gw.waveform_generator.WaveformGenerator( self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
duration=duration, sampling_frequency=sampling_frequency,
wfg_mb = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict( waveform_arguments=dict(
reference_frequency=fmin, approximant=approximant_22) reference_frequency=self.fmin, approximant=approximant
) )
non_mb_wfg_homs = bilby.gw.WaveformGenerator(
duration=duration, sampling_frequency=sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=fmin, minimum_frequency=fmin, approximant=approximant_homs)
) )
mb_wfg_homs = bilby.gw.waveform_generator.WaveformGenerator( likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
duration=duration, sampling_frequency=sampling_frequency, interferometers=self.ifos, waveform_generator=wfg
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=fmin, approximant=approximant_homs)
) )
likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
ifos_22 = deepcopy(ifos) interferometers=self.ifos, waveform_generator=wfg_mb,
ifos_22.inject_signal( reference_chirp_mass=self.test_parameters['chirp_mass'],
parameters=self.test_parameters, waveform_generator=non_mb_wfg_22 priors=self.priors.copy(), linear_interpolation=linear_interpolation,
highest_mode=highest_mode
) )
ifos_homs = deepcopy(ifos) likelihood.parameters.update(self.test_parameters)
ifos_homs.inject_signal( likelihood_mb.parameters.update(self.test_parameters)
parameters=self.test_parameters, waveform_generator=non_mb_wfg_homs if add_cal_errors:
likelihood.parameters.update(self.calibration_parameters)
likelihood_mb.parameters.update(self.calibration_parameters)
self.assertLess(
abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()),
tolerance
) )
self.non_mb_22 = bilby.gw.likelihood.GravitationalWaveTransient( def test_large_accuracy_factor(self):
interferometers=ifos_22, waveform_generator=non_mb_wfg_22 """
) Check if larger accuracy factor increases the accuracy.
self.non_mb_homs = bilby.gw.likelihood.GravitationalWaveTransient( """
interferometers=ifos_homs, waveform_generator=non_mb_wfg_homs approximant = "IMRPhenomD"
wfg = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
)
) )
self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
self.mb_22 = bilby.gw.likelihood.MBGravitationalWaveTransient( wfg_mb = bilby.gw.WaveformGenerator(
interferometers=ifos_22, waveform_generator=deepcopy(mb_wfg_22), duration=self.duration, sampling_frequency=self.sampling_frequency,
reference_chirp_mass=self.test_parameters['chirp_mass'], frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
priors=priors.copy() waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
)
) )
self.mb_ifftfft_22 = bilby.gw.likelihood.MBGravitationalWaveTransient( likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=ifos_22, waveform_generator=deepcopy(mb_wfg_22), interferometers=self.ifos, waveform_generator=wfg
reference_chirp_mass=self.test_parameters['chirp_mass'],
priors=priors.copy(), linear_interpolation=False
) )
self.mb_homs = bilby.gw.likelihood.MBGravitationalWaveTransient( likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
interferometers=ifos_homs, waveform_generator=deepcopy(mb_wfg_homs), interferometers=self.ifos, waveform_generator=wfg_mb,
reference_chirp_mass=self.test_parameters['chirp_mass'], reference_chirp_mass=self.test_parameters['chirp_mass'],
priors=priors.copy(), linear_interpolation=False, highest_mode=4 priors=self.priors.copy(), accuracy_factor=5
) )
self.mb_more_accurate = bilby.gw.likelihood.MBGravitationalWaveTransient( likelihood_mb_more_accurate = bilby.gw.likelihood.MBGravitationalWaveTransient(
interferometers=ifos_22, waveform_generator=deepcopy(mb_wfg_22), interferometers=self.ifos, waveform_generator=wfg_mb,
reference_chirp_mass=self.test_parameters['chirp_mass'], reference_chirp_mass=self.test_parameters['chirp_mass'],
priors=priors.copy(), accuracy_factor=50 priors=self.priors.copy(), accuracy_factor=50
)
def tearDown(self):
del (
self.non_mb_22,
self.non_mb_homs,
self.mb_22,
self.mb_ifftfft_22,
self.mb_homs,
self.mb_more_accurate
) )
likelihood.parameters.update(self.test_parameters)
@parameterized.expand([(False, ), (True, )]) likelihood_mb.parameters.update(self.test_parameters)
def test_matches_non_mb(self, add_cal_errors): likelihood_mb_more_accurate.parameters.update(self.test_parameters)
self.non_mb_22.parameters.update(self.test_parameters)
self.mb_22.parameters.update(self.test_parameters)
if add_cal_errors:
self.non_mb_22.parameters.update(self.calibration_parameters)
self.mb_22.parameters.update(self.calibration_parameters)
self.assertLess( self.assertLess(
abs(self.non_mb_22.log_likelihood_ratio() - self.mb_22.log_likelihood_ratio()), abs(likelihood.log_likelihood_ratio() - likelihood_mb_more_accurate.log_likelihood_ratio()),
1.5e-2 abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()) / 2
) )
@parameterized.expand([(False, ), (True, )]) def test_reference_chirp_mass_from_prior(self):
def test_ifft_fft(self, add_cal_errors):
""" """
Check if multi-banding likelihood with (h, h) computed with the Check if reference chirp mass is automatically determined from prior if no number has been passed
IFFT-FFT algorithm matches the original likelihood.
""" """
self.non_mb_22.parameters.update(self.test_parameters) wfg_mb = bilby.gw.WaveformGenerator(
self.mb_ifftfft_22.parameters.update(self.test_parameters) duration=self.duration, sampling_frequency=self.sampling_frequency,
if add_cal_errors: frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
self.non_mb_22.parameters.update(self.calibration_parameters) waveform_arguments=dict(
self.mb_ifftfft_22.parameters.update(self.calibration_parameters) reference_frequency=self.fmin, approximant="IMRPhenomD"
self.assertLess( )
abs(self.non_mb_22.log_likelihood_ratio() - self.mb_ifftfft_22.log_likelihood_ratio()), )
6e-3 likelihood1 = bilby.gw.likelihood.MBGravitationalWaveTransient(
interferometers=self.ifos, waveform_generator=wfg_mb,
reference_chirp_mass=self.priors["chirp_mass"].minimum,
priors=self.priors.copy()
)
likelihood2 = bilby.gw.likelihood.MBGravitationalWaveTransient(
interferometers=self.ifos, waveform_generator=wfg_mb,
priors=self.priors.copy()
) )
self.assertAlmostEqual(likelihood1.reference_chirp_mass, likelihood2.reference_chirp_mass)
@parameterized.expand([(False, ), (True, )]) def test_no_reference_chirp_mass(self):
def test_homs(self, add_cal_errors):
""" """
Check if multi-banding likelihood matches the original likelihood for higher-order moments. Check if an error is raised if either reference_chirp_mass or priors is not specified.
""" """
self.non_mb_homs.parameters.update(self.test_parameters) wfg_mb = bilby.gw.WaveformGenerator(
self.mb_homs.parameters.update(self.test_parameters) duration=self.duration, sampling_frequency=self.sampling_frequency,
if add_cal_errors: frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
self.non_mb_homs.parameters.update(self.calibration_parameters) waveform_arguments=dict(
self.mb_homs.parameters.update(self.calibration_parameters) reference_frequency=self.fmin, approximant="IMRPhenomD"
self.assertLess( )
abs(self.non_mb_homs.log_likelihood_ratio() - self.mb_homs.log_likelihood_ratio()),
1e-3
) )
with self.assertRaises(TypeError):
bilby.gw.likelihood.MBGravitationalWaveTransient(
interferometers=self.ifos, waveform_generator=wfg_mb
)
def test_large_accuracy_factor(self): def test_cannot_determine_reference_chirp_mass(self):
""" """
Check if larger accuracy factor increases the accuracy. Check if an error is raised if priors does not contain necessary information to determine reference chirp mass
""" """
self.non_mb_22.parameters.update(self.test_parameters) wfg_mb = bilby.gw.WaveformGenerator(
self.mb_22.parameters.update(self.test_parameters) duration=self.duration, sampling_frequency=self.sampling_frequency,
self.mb_more_accurate.parameters.update(self.test_parameters) frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
self.assertLess( waveform_arguments=dict(
abs(self.non_mb_22.log_likelihood_ratio() - self.mb_more_accurate.log_likelihood_ratio()), reference_frequency=self.fmin, approximant="IMRPhenomD"
abs(self.non_mb_22.log_likelihood_ratio() - self.mb_22.log_likelihood_ratio()) / 2 )
) )
for key in ["chirp_mass", "mass_1", "mass_2"]:
if key in self.priors:
self.priors.pop(key)
with self.assertRaises(Exception):
bilby.gw.likelihood.MBGravitationalWaveTransient(
interferometers=self.ifos, waveform_generator=wfg_mb, priors=self.priors
)
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment