Skip to content
Snippets Groups Projects
Commit 9c89194d authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'bugfix-cbc-waveform-generator' into 'master'

BUGFIX: CBCWaveformGenerator: don't access hidden enum

See merge request !1140
parents 69ccfdfe 027a3d44
No related branches found
No related tags found
1 merge request!1140BUGFIX: CBCWaveformGenerator: don't access hidden enum
Pipeline #449180 passed
......@@ -258,14 +258,16 @@ class WaveformGenerator(object):
class LALCBCWaveformGenerator(WaveformGenerator):
""" A waveform generator with specific checks for LAL CBC waveforms """
LAL_SIM_INSPIRAL_SPINS_FLOW = 1
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.validate_reference_frequency()
def validate_reference_frequency(self):
from lalsimulation import SimInspiralGetSpinFreqFromApproximant, LAL_SIM_INSPIRAL_SPINS_FLOW
from lalsimulation import SimInspiralGetSpinFreqFromApproximant
waveform_approximant = self.waveform_arguments["waveform_approximant"]
waveform_approximant_number = lalsim_GetApproximantFromString(waveform_approximant)
if SimInspiralGetSpinFreqFromApproximant(waveform_approximant_number) == LAL_SIM_INSPIRAL_SPINS_FLOW:
if SimInspiralGetSpinFreqFromApproximant(waveform_approximant_number) == self.LAL_SIM_INSPIRAL_SPINS_FLOW:
if self.waveform_arguments["reference_frequency"] != self.waveform_arguments["minimum_frequency"]:
raise ValueError(f"For {waveform_approximant}, reference_frequency must equal minimum_frequency")
import unittest
from unittest import mock
import bilby
import lalsimulation
import numpy as np
......@@ -159,6 +161,69 @@ class TestWaveformArgumentsSetting(unittest.TestCase):
)
class TestLALCBCWaveformArgumentsSetting(unittest.TestCase):
def setUp(self):
self.kwargs = dict(
duration=4,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
sampling_frequency=2048,
)
def tearDown(self):
del self.kwargs
def test_spin_reference_enumeration(self):
"""
Verify that the value of the reference enumerator hasn't changed by comparing
against a known approximant.
"""
self.assertEqual(
lalsimulation.SimInspiralGetSpinFreqFromApproximant(lalsimulation.SEOBNRv3),
bilby.gw.waveform_generator.LALCBCWaveformGenerator.LAL_SIM_INSPIRAL_SPINS_FLOW,
)
def test_create_waveform_generator_non_precessing(self):
self.kwargs["waveform_arguments"] = dict(
minimum_frequency=20.0,
reference_frequency=50.0,
waveform_approximant="TaylorF2",
)
wfg = bilby.gw.waveform_generator.LALCBCWaveformGenerator(**self.kwargs)
self.assertDictEqual(
wfg.waveform_arguments,
dict(
minimum_frequency=20.0,
reference_frequency=50.0,
waveform_approximant="TaylorF2",
),
)
def test_create_waveform_generator_eob_succeeds(self):
self.kwargs["waveform_arguments"] = dict(
minimum_frequency=20.0,
reference_frequency=20.0,
waveform_approximant="SEOBNRv3",
)
wfg = bilby.gw.waveform_generator.LALCBCWaveformGenerator(**self.kwargs)
self.assertDictEqual(
wfg.waveform_arguments,
dict(
minimum_frequency=20.0,
reference_frequency=20.0,
waveform_approximant="SEOBNRv3",
),
)
def test_create_waveform_generator_eob_fails(self):
self.kwargs["waveform_arguments"] = dict(
minimum_frequency=20.0,
reference_frequency=50.0,
waveform_approximant="SEOBNRv3",
)
with self.assertRaises(ValueError):
_ = bilby.gw.waveform_generator.LALCBCWaveformGenerator(**self.kwargs)
class TestSetters(unittest.TestCase):
def setUp(self):
self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment