Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • john-veitch/bilby
  • duncanmmacleod/bilby
  • colm.talbot/bilby
  • lscsoft/bilby
  • matthew-pitkin/bilby
  • salvatore-vitale/tupak
  • charlie.hoy/bilby
  • bfarr/bilby
  • virginia.demilio/bilby
  • vivien/bilby
  • eric-howell/bilby
  • sebastian-khan/bilby
  • rhys.green/bilby
  • moritz.huebner/bilby
  • joseph.mills/bilby
  • scott.coughlin/bilby
  • matthew.carney/bilby
  • hyungwon.lee/bilby
  • monica.rizzo/bilby
  • christopher-berry/bilby
  • lindsay.demarchi/bilby
  • kaushik.rao/bilby
  • charles.kimball/bilby
  • andrew.matas/bilby
  • juan.calderonbustillo/bilby
  • patrick-meyers/bilby
  • hannah.middleton/bilby
  • eve.chase/bilby
  • grant.meadors/bilby
  • khun.phukon/bilby
  • sumeet.kulkarni/bilby
  • daniel.reardon/bilby
  • cjhaster/bilby
  • sylvia.biscoveanu/bilby
  • james-clark/bilby
  • meg.millhouse/bilby
  • joshua.willis/bilby
  • nikhil.sarin/bilby
  • paul.easter/bilby
  • youngmin/bilby
  • daniel-williams/bilby
  • shanika.galaudage/bilby
  • bruce.edelman/bilby
  • avi.vajpeyi/bilby
  • isobel.romero-shaw/bilby
  • andrew.kim/bilby
  • dominika.zieba/bilby
  • jonathan.davies/bilby
  • marc.arene/bilby
  • srishti.tiwari/bilby-tidal-heating-eccentric
  • aditya.vijaykumar/bilby
  • michael.williams/bilby
  • cecilio.garcia-quiros/bilby
  • rory-smith/bilby
  • maite.mateu-lucena/bilby
  • wushichao/bilby
  • kaylee.desoto/bilby
  • brandon.piotrzkowski/bilby
  • rossella.gamba/bilby
  • hunter.gabbard/bilby
  • deep.chatterjee/bilby
  • tathagata.ghosh/bilby
  • arunava.mukherjee/bilby
  • philip.relton/bilby
  • reed.essick/bilby
  • pawan.gupta/bilby
  • francisco.hernandez/bilby
  • rhiannon.udall/bilby
  • leo.tsukada/bilby
  • will-farr/bilby
  • vijay.varma/bilby
  • jeremy.baier/bilby
  • joshua.brandt/bilby
  • ethan.payne/bilby
  • ka-lok.lo/bilby
  • antoni.ramos-buades/bilby
  • oliviastephany.wilk/bilby
  • jack.heinzel/bilby
  • samson.leong/bilby-psi4
  • viviana.caceres/bilby
  • nadia.qutob/bilby
  • michael-coughlin/bilby
  • hemantakumar.phurailatpam/bilby
  • boris.goncharov/bilby
  • sama.al-shammari/bilby
  • siqi.zhong/bilby
  • jocelyn-read/bilby
  • marc.penuliar/bilby
  • stephanie.letourneau/bilby
  • alexandresebastien.goettel/bilby
  • alec.gunny/bilby
  • serguei.ossokine/bilby
  • pratyusava.baral/bilby
  • sophie.hourihane/bilby
  • eunsub/bilby
  • james.hart/bilby
  • pratyusava.baral/bilby-tg
  • zhaozc/bilby
  • pratyusava.baral/bilby_SoG
  • tomasz.baka/bilby
  • nicogerardo.bers/bilby
  • soumen.roy/bilby
  • isaac.mcmahon/healpix-redundancy
  • asamakai.baker/bilby-frequency-dependent-antenna-pattern-functions
  • anna.puecher/bilby
  • pratyusava.baral/bilby-x-g
  • thibeau.wouters/bilby
  • christian.adamcewicz/bilby
  • raffi.enficiaud/bilby
109 results
Show changes
......@@ -78,13 +78,13 @@ class TestRelativeBinningLikelihood(unittest.TestCase):
duration=duration, sampling_frequency=sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=fmin, minimum_frequency=fmin, approximant=approximant)
reference_frequency=fmin, minimum_frequency=fmin, waveform_approximant=approximant)
)
bin_wfg = bilby.gw.waveform_generator.WaveformGenerator(
duration=duration, sampling_frequency=sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning,
waveform_arguments=dict(
reference_frequency=fmin, approximant=approximant, minimum_frequency=fmin)
reference_frequency=fmin, waveform_approximant=approximant, minimum_frequency=fmin)
)
ifos.inject_signal(
parameters=self.test_parameters,
......
......@@ -3,6 +3,7 @@ import unittest
import tempfile
from itertools import product
from parameterized import parameterized
import pytest
import h5py
import numpy as np
......@@ -273,6 +274,7 @@ class TestGWTransient(unittest.TestCase):
)
@pytest.mark.requires_roqs
class TestROQLikelihood(unittest.TestCase):
def setUp(self):
self.duration = 4
......@@ -341,7 +343,7 @@ class TestROQLikelihood(unittest.TestCase):
waveform_arguments=dict(
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
waveform_approximant="IMRPhenomPv2",
),
)
......@@ -360,7 +362,7 @@ class TestROQLikelihood(unittest.TestCase):
frequency_nodes_quadratic=fnodes_quadratic,
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
waveform_approximant="IMRPhenomPv2",
),
)
......@@ -538,6 +540,7 @@ class TestROQLikelihood(unittest.TestCase):
)
@pytest.mark.requires_roqs
class TestRescaledROQLikelihood(unittest.TestCase):
def test_rescaling(self):
......@@ -597,7 +600,7 @@ class TestRescaledROQLikelihood(unittest.TestCase):
frequency_nodes_quadratic=fnodes_quadratic,
reference_frequency=20.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
waveform_approximant="IMRPhenomPv2",
),
)
......@@ -612,6 +615,7 @@ class TestRescaledROQLikelihood(unittest.TestCase):
)
@pytest.mark.requires_roqs
class TestROQLikelihoodHDF5(unittest.TestCase):
"""
Test ROQ likelihood constructed from .hdf5 basis
......@@ -732,6 +736,32 @@ class TestROQLikelihoodHDF5(unittest.TestCase):
)
def test_likelihood_accuracy(self, basis_linear, basis_quadratic, mc_range, roq_scale_factor, add_cal_errors):
"Compare with log likelihood ratios computed by the non-ROQ likelihood"
# The maximum error of log likelihood ratio. It is set to be larger for roq_scale_factor=1 as the injected SNR
# is higher.
if roq_scale_factor == 1:
max_llr_error = 5e-1
elif roq_scale_factor == 2:
max_llr_error = 5e-2
else:
raise
self.assertLess_likelihood_errors(
basis_linear, basis_quadratic, mc_range, roq_scale_factor, add_cal_errors, max_llr_error
)
@parameterized.expand([(_path_to_basis_mb, 100, 1024), (_path_to_basis_mb, 20, 200), (_path_to_basis_mb, 100, 200)])
def test_likelihood_accuracy_narrower_frequency_range(self, basis, minimum_frequency, maximum_frequency):
"""Compare with log likelihood ratios computed by the non-ROQ likelihood in the case where analyzed frequency
range is narrower than the basis frequency range"""
self.assertLess_likelihood_errors(
basis, basis, (8, 9), 1, False, 1.5e-1,
minimum_frequency=minimum_frequency, maximum_frequency=maximum_frequency
)
def assertLess_likelihood_errors(
self, basis_linear, basis_quadratic, mc_range, roq_scale_factor, add_cal_errors, max_llr_error,
minimum_frequency=None, maximum_frequency=None
):
self.minimum_frequency *= roq_scale_factor
self.sampling_frequency *= roq_scale_factor
self.duration /= roq_scale_factor
......@@ -745,7 +775,12 @@ class TestROQLikelihoodHDF5(unittest.TestCase):
interferometers = bilby.gw.detector.InterferometerList(["H1", "L1"])
for ifo in interferometers:
ifo.minimum_frequency = self.minimum_frequency
if minimum_frequency is None:
ifo.minimum_frequency = self.minimum_frequency
else:
ifo.minimum_frequency = minimum_frequency
if maximum_frequency is not None:
ifo.maximum_frequency = maximum_frequency
interferometers.set_strain_data_from_zero_noise(
sampling_frequency=self.sampling_frequency,
duration=self.duration,
......@@ -805,14 +840,6 @@ class TestROQLikelihoodHDF5(unittest.TestCase):
quadratic_matrix=basis_quadratic,
roq_scale_factor=roq_scale_factor
)
# The maximum error of log likelihood ratio. It is set to be larger for roq_scale_factor=1 as the injected SNR
# is higher.
if roq_scale_factor == 1:
max_llr_error = 5e-1
elif roq_scale_factor == 2:
max_llr_error = 5e-2
else:
raise
for mc in np.linspace(self.priors["chirp_mass"].minimum, self.priors["chirp_mass"].maximum, 11):
parameters = self.injection_parameters.copy()
parameters["chirp_mass"] = mc
......@@ -823,6 +850,7 @@ class TestROQLikelihoodHDF5(unittest.TestCase):
self.assertLess(np.abs(llr - llr_roq), max_llr_error)
@pytest.mark.requires_roqs
class TestCreateROQLikelihood(unittest.TestCase):
"""
Test if ROQ likelihood is constructed without any errors from .hdf5 or .npy basis
......@@ -943,6 +971,7 @@ class TestCreateROQLikelihood(unittest.TestCase):
)
@pytest.mark.requires_roqs
class TestInOutROQWeights(unittest.TestCase):
@parameterized.expand(['npz', 'hdf5'])
......@@ -1217,7 +1246,7 @@ class TestMBLikelihood(unittest.TestCase):
("IMRPhenomHM", False, 4, True, 1e-3)
])
def test_matches_original_likelihood(
self, approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
self, waveform_approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
):
"""
Check if multi-band likelihood values match original likelihood values
......@@ -1226,7 +1255,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
......@@ -1235,7 +1264,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
......@@ -1261,12 +1290,12 @@ class TestMBLikelihood(unittest.TestCase):
"""
Check if larger accuracy factor increases the accuracy.
"""
approximant = "IMRPhenomD"
waveform_approximant = "IMRPhenomD"
wfg = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
......@@ -1275,7 +1304,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
......@@ -1307,7 +1336,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant="IMRPhenomD"
reference_frequency=self.fmin, waveform_approximant="IMRPhenomD"
)
)
likelihood1 = bilby.gw.likelihood.MBGravitationalWaveTransient(
......@@ -1329,7 +1358,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant="IMRPhenomD"
reference_frequency=self.fmin, waveform_approximant="IMRPhenomD"
)
)
with self.assertRaises(TypeError):
......@@ -1345,7 +1374,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant="IMRPhenomD"
reference_frequency=self.fmin, waveform_approximant="IMRPhenomD"
)
)
for key in ["chirp_mass", "mass_1", "mass_2"]:
......@@ -1362,12 +1391,12 @@ class TestMBLikelihood(unittest.TestCase):
Check if multiband weights can be saved as a file, and a likelihood object constructed from the weights file
produces the same likelihood value.
"""
approximant = "IMRPhenomD"
waveform_approximant = "IMRPhenomD"
wfg = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(
......@@ -1378,7 +1407,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
......@@ -1401,7 +1430,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient(
......@@ -1418,12 +1447,12 @@ class TestMBLikelihood(unittest.TestCase):
"""
Check if a likelihood object constructed from dictionary-like weights produce the same likelihood value
"""
approximant = "IMRPhenomD"
waveform_approximant = "IMRPhenomD"
wfg = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(
......@@ -1434,7 +1463,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
......@@ -1451,7 +1480,7 @@ class TestMBLikelihood(unittest.TestCase):
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, approximant=approximant
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
weights = likelihood_mb.weights
......@@ -1463,6 +1492,55 @@ class TestMBLikelihood(unittest.TestCase):
self.assertAlmostEqual(llr, llr_from_weights)
@parameterized.expand([
("IMRPhenomD", True, 2, False, 1e-2),
("IMRPhenomD", True, 2, True, 1e-2),
("IMRPhenomHM", False, 4, False, 5e-3),
])
def test_matches_original_likelihood_low_maximum_frequency(
self, waveform_approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
):
"""
Test for maximum frequency < sampling frequency / 2
"""
for ifo in self.ifos:
ifo.maximum_frequency = self.sampling_frequency / 8
wfg = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
wfg_mb = bilby.gw.WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
waveform_arguments=dict(
reference_frequency=self.fmin, waveform_approximant=waveform_approximant
)
)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=self.ifos, waveform_generator=wfg
)
likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
interferometers=self.ifos, waveform_generator=wfg_mb,
reference_chirp_mass=self.test_parameters['chirp_mass'],
priors=self.priors.copy(), linear_interpolation=linear_interpolation,
highest_mode=highest_mode
)
likelihood.parameters.update(self.test_parameters)
likelihood_mb.parameters.update(self.test_parameters)
if add_cal_errors:
likelihood.parameters.update(self.calibration_parameters)
likelihood_mb.parameters.update(self.calibration_parameters)
self.assertLess(
abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()),
tolerance
)
if __name__ == "__main__":
unittest.main()
......@@ -306,7 +306,7 @@ class TestROQBBH(unittest.TestCase):
frequency_nodes_quadratic=fnodes_quadratic,
reference_frequency=50.0,
minimum_frequency=20.0,
approximant="IMRPhenomPv2",
waveform_approximant="IMRPhenomPv2",
)
self.frequency_array = bilby.core.utils.create_frequency_series(2048, 4)
......
import unittest
import os
from shutil import rmtree
from importlib.metadata import version
import numpy as np
import lal
......@@ -89,12 +90,28 @@ class TestGWUtils(unittest.TestCase):
with self.assertRaises(ValueError):
gwutils.get_event_time("GW010290")
@pytest.mark.skipif(version("gwpy") < "3.0.8", reason="GWpy version < 3.0.8")
def test_read_frame_file(self):
"""
Test that reading a frame file works as expected
for a few conditions.
1. Reading without time limits returns the full data
2. Reading with time limits returns the expected data
(inclusive of start time if present, exclusive of end time)
3. Reading without the channel name provided finds a standard name
4. Reading without the channel with a non-standard name returns None.
Notes
=====
There was a longstanding bug in gwpy that we previously tested for
here, but this has been fixed in gwpy 3.0.8.
"""
start_time = 0
end_time = 10
channel = "H1:GDS-CALIB_STRAIN"
N = 100
times = np.linspace(start_time, end_time, N)
times = np.linspace(start_time, end_time, N, endpoint=False)
data = np.random.normal(0, 1, N)
ts = TimeSeries(data=data, times=times, t0=0)
ts.channel = Channel(channel)
......@@ -107,7 +124,7 @@ class TestGWUtils(unittest.TestCase):
filename, start_time=None, end_time=None, channel=channel
)
self.assertEqual(strain.name, channel)
self.assertTrue(np.all(strain.value == data[:-1]))
self.assertTrue(np.all(strain.value == data))
# Check reading with time limits
start_cut = 2
......@@ -115,19 +132,18 @@ class TestGWUtils(unittest.TestCase):
strain = gwutils.read_frame_file(
filename, start_time=start_cut, end_time=end_cut, channel=channel
)
idxs = (times > start_cut) & (times < end_cut)
# Dropping the last element - for some reason gwpy drops the last element when reading in data
self.assertTrue(np.all(strain.value == data[idxs][:-1]))
idxs = (times >= start_cut) & (times < end_cut)
self.assertTrue(np.all(strain.value == data[idxs]))
# Check reading with unknown channels
strain = gwutils.read_frame_file(filename, start_time=None, end_time=None)
self.assertTrue(np.all(strain.value == data[:-1]))
self.assertTrue(np.all(strain.value == data))
# Check reading with incorrect channel
strain = gwutils.read_frame_file(
filename, start_time=None, end_time=None, channel="WRONG"
)
self.assertTrue(np.all(strain.value == data[:-1]))
self.assertTrue(np.all(strain.value == data))
ts = TimeSeries(data=data, times=times, t0=0)
ts.name = "NOT-A-KNOWN-CHANNEL"
......
......@@ -438,42 +438,42 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters(
self,
):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters
)
new_waveform = self.waveform_generator.time_domain_strain()
self.assertNotEqual(original_waveform, new_waveform)
self.assertFalse(_test_caching_different_domain(
self.waveform_generator.frequency_domain_strain,
self.waveform_generator.time_domain_strain,
self.simulation_parameters,
None,
))
def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters(
self,
):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters
)
new_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters
)
self.assertNotEqual(original_waveform, new_waveform)
self.assertFalse(_test_caching_different_domain(
self.waveform_generator.frequency_domain_strain,
self.waveform_generator.time_domain_strain,
self.simulation_parameters,
self.simulation_parameters,
))
def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters(
self,
):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters
)
new_waveform = self.waveform_generator.frequency_domain_strain()
self.assertNotEqual(original_waveform, new_waveform)
self.assertFalse(_test_caching_different_domain(
self.waveform_generator.time_domain_strain,
self.waveform_generator.frequency_domain_strain,
self.simulation_parameters,
None,
))
def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters(
self,
):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters
)
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters
)
self.assertNotEqual(original_waveform, new_waveform)
self.assertFalse(_test_caching_different_domain(
self.waveform_generator.time_domain_strain,
self.waveform_generator.frequency_domain_strain,
self.simulation_parameters,
self.simulation_parameters,
))
def test_frequency_domain_caching_changing_model(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
......@@ -648,42 +648,51 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters(
self,
):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters
)
new_waveform = self.waveform_generator.time_domain_strain()
self.assertNotEqual(original_waveform, new_waveform)
self.assertFalse(_test_caching_different_domain(
self.waveform_generator.frequency_domain_strain,
self.waveform_generator.time_domain_strain,
self.simulation_parameters,
None,
))
def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters(
self,
):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters
)
new_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters
)
self.assertNotEqual(original_waveform, new_waveform)
self.assertFalse(_test_caching_different_domain(
self.waveform_generator.frequency_domain_strain,
self.waveform_generator.time_domain_strain,
self.simulation_parameters,
self.simulation_parameters,
))
def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters(
self,
):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters
)
new_waveform = self.waveform_generator.frequency_domain_strain()
self.assertNotEqual(original_waveform, new_waveform)
self.assertFalse(_test_caching_different_domain(
self.waveform_generator.time_domain_strain,
self.waveform_generator.frequency_domain_strain,
self.simulation_parameters,
None,
))
def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters(
self,
):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters
)
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters
)
self.assertNotEqual(original_waveform, new_waveform)
self.assertFalse(_test_caching_different_domain(
self.waveform_generator.time_domain_strain,
self.waveform_generator.frequency_domain_strain,
self.simulation_parameters,
self.simulation_parameters,
))
def _test_caching_different_domain(func1, func2, params1, params2):
original_waveform = func1(parameters=params1)
new_waveform = func2(parameters=params2)
output = True
for key in original_waveform:
output &= np.array_equal(original_waveform[key], new_waveform[key])
return output
if __name__ == "__main__":
......
......@@ -55,7 +55,6 @@ _sampler_kwargs = dict(
PTMCMCSampler=dict(Niter=101, burn=100, covUpdate=100, isave=100),
pymc=dict(draws=50, tune=50, n_init=250),
pymultinest=dict(nlive=100),
pypolychord=dict(nlive=100),
ultranest=dict(nlive=100, temporary_directory=False),
zeus=dict(nwalkers=10, iterations=100)
)
......@@ -65,7 +64,7 @@ sampler_imports = dict(
dynamic_dynesty="dynesty"
)
no_pool_test = ["dnest4", "pymultinest", "nestle", "ptmcmcsampler", "pypolychord", "ultranest", "pymc"]
no_pool_test = ["dnest4", "pymultinest", "nestle", "ptmcmcsampler", "ultranest", "pymc"]
def slow_func(x, m, c):
......
......@@ -11,7 +11,7 @@ IMPLEMENTED_SAMPLERS = bilby.core.sampler.IMPLEMENTED_SAMPLERS
likelihood = bilby.core.likelihood.Likelihood(dict())
priors = bilby.core.prior.PriorDict(dict(a=bilby.core.prior.Uniform(0, 1)))
for sampler in IMPLEMENTED_SAMPLERS:
if sampler == "fake_sampler":
if sampler in ["fake_sampler", "pypolychord"]:
continue
sampler_class = IMPLEMENTED_SAMPLERS[sampler]
sampler = sampler_class(likelihood=likelihood, priors=priors)