diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index d52e05ae8d21f7bff9054da180dcc3fb1ecbeb6d..b494dca9f4a947f63cce12004bae181c478e5c7e 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -838,12 +838,8 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion, output_sample = fill_from_fixed_priors(output_sample, priors) output_sample, _ = base_conversion(output_sample) if likelihood is not None: - if ( - hasattr(likelihood, 'phase_marginalization') or - hasattr(likelihood, 'time_marginalization') or - hasattr(likelihood, 'distance_marginalization') or - hasattr(likelihood, 'calibration_marginalization') - ): + marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) + if len(marginalized_parameters) > 0: try: generate_posterior_samples_from_marginalized_likelihood( samples=output_sample, likelihood=likelihood, npool=npool) @@ -854,10 +850,17 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion, "interpretation.".format(e) ) if priors is not None: - for par, name in zip( - ['distance', 'phase', 'time'], - ['luminosity_distance', 'phase', 'geocent_time']): - if getattr(likelihood, '{}_marginalization'.format(par), False): + misnamed_marginalizations = dict( + distance="luminosity_distance", + time="geocent_time", + calibration="recalib_index", + ) + for par in marginalized_parameters: + name = misnamed_marginalizations.get(par, par) + if ( + getattr(likelihood, f'{par}_marginalization', False) + and name in likelihood.priors + ): priors[name] = likelihood.priors[name] if ( @@ -1296,10 +1299,8 @@ def generate_posterior_samples_from_marginalized_likelihood( sample: DataFrame Returns the posterior with new samples. """ - if not any([likelihood.phase_marginalization, - likelihood.distance_marginalization, - likelihood.time_marginalization, - likelihood.calibration_marginalization]): + marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) + if len(marginalized_parameters) == 0: return samples # pass through a dictionary @@ -1382,11 +1383,8 @@ def generate_posterior_samples_from_marginalized_likelihood( [np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"] ) - samples['geocent_time'] = new_samples[:, 0] - samples['luminosity_distance'] = new_samples[:, 1] - samples['phase'] = new_samples[:, 2] - if likelihood.calibration_marginalization: - samples['recalib_index'] = new_samples[:, 3] + for ii, key in enumerate(marginalized_parameters): + samples[key] = new_samples[:, ii] return samples @@ -1413,13 +1411,8 @@ def generate_sky_frame_parameters(samples, likelihood): def fill_sample(args): ii, sample, likelihood = args + marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) sample = dict(sample).copy() likelihood.parameters.update(dict(sample).copy()) new_sample = likelihood.generate_posterior_sample_from_marginalized_likelihood() - - if not likelihood.calibration_marginalization: - return new_sample["geocent_time"], new_sample["luminosity_distance"],\ - new_sample["phase"] - else: - return new_sample["geocent_time"], new_sample["luminosity_distance"],\ - new_sample["phase"], new_sample['recalib_index'] + return (new_sample[key] for key in marginalized_parameters) diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index d4d8eca345bb8f6dfca921f4cd20bfcbc3c5cb6c..fd2aea3ea7365ba7ccf3a23f57f162879800ee2c 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -448,8 +448,7 @@ class GravitationalWaveTransient(Likelihood): This involves a deepcopy of the signal to avoid issues with waveform caching, as the signal is overwritten in place. """ - if any([self.phase_marginalization, self.distance_marginalization, - self.time_marginalization, self.calibration_marginalization]): + if len(self._marginalized_parameters) > 0: signal_polarizations = copy.deepcopy( self.waveform_generator.frequency_domain_strain( self.parameters)) diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index 54dd064f08e956587debb4951dcf08270efaa612..e89bdce00164bc24a2440f18dd2a21dc25b41c02 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -465,6 +465,39 @@ class TestGenerateAllParameters(unittest.TestCase): for key in expected: self.assertIn(key, new_parameters) + def test_generate_bbh_paramters_with_likelihood(self): + priors = bilby.gw.prior.BBHPriorDict() + priors["geocent_time"] = bilby.core.prior.Uniform(0.4, 0.6) + ifos = bilby.gw.detector.InterferometerList(["H1"]) + ifos.set_strain_data_from_power_spectral_densities(duration=1, sampling_frequency=256) + wfg = bilby.gw.waveform_generator.WaveformGenerator( + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole + ) + likelihood = bilby.gw.likelihood.GravitationalWaveTransient( + interferometers=ifos, + waveform_generator=wfg, + priors=priors, + phase_marginalization=True, + time_marginalization=True, + reference_frame="H1L1", + ) + self.parameters["zenith"] = 0.0 + self.parameters["azimuth"] = 0.0 + del self.parameters["ra"], self.parameters["dec"] + converted = bilby.gw.conversion.generate_all_bbh_parameters( + sample=self.parameters, likelihood=likelihood, priors=priors + ) + extra_expected = [ + "geocent_time", + "phase", + "H1_optimal_snr", + "H1_matched_filter_snr", + "ra", + "dec", + ] + for key in extra_expected: + self.assertIn(key, converted) + class TestDistanceTransformations(unittest.TestCase): def setUp(self): diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index a4528e623dee9729ae769812af3fc049fa992683..3489dab6872591e68f491ab7b9493457f34732d8 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -1,3 +1,4 @@ +import itertools import os import pytest import unittest @@ -648,6 +649,43 @@ class TestMarginalizations(unittest.TestCase): prior=prior, ) + @parameterized.expand( + itertools.product(["regular", "roq"], *itertools.repeat([True, False], 3)), + name_func=lambda func, num, param: ( + f"{func.__name__}_{num}__{param.args[0]}_" + "_".join([ + ["D", "P", "T"][ii] for ii, val + in enumerate(param.args[1:]) if val + ]) + ) + ) + def test_marginalization_reconstruction(self, kind, distance, phase, time): + if time and kind == "roq": + pytest.skip("Time reconstruction not supported for ROQ likelihood") + marginalizations = dict( + geocent_time=time, + luminosity_distance=distance, + phase=phase, + ) + like = self.get_likelihood( + kind=kind, + distance_marginalization=distance, + time_marginalization=time, + phase_marginalization=phase, + ) + params = self.parameters.copy() + reference_values = dict( + luminosity_distance=self.priors["luminosity_distance"].rescale(0.5), + geocent_time=self.interferometers.start_time, + phase=0.0, + ) + for key in marginalizations: + if marginalizations[key]: + params[key] = reference_values[key] + like.parameters.update(params) + output = like.generate_posterior_sample_from_marginalized_likelihood() + for key in marginalizations: + self.assertFalse(marginalizations[key] and reference_values[key] == output[key]) + class TestROQLikelihood(unittest.TestCase): def setUp(self):