From bad30b7cfd732d16b62322e45db5bb5f9ebc4f3c Mon Sep 17 00:00:00 2001
From: Rhiannon Udall <rhiannon.udall@ligo.org>
Date: Mon, 22 Jul 2024 16:12:00 +0000
Subject: [PATCH] Add identity conversion and generation functions

---
 bilby/gw/conversion.py     | 56 +++++++++++++++++++++++++++++++++
 test/gw/conversion_test.py | 64 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 120 insertions(+)

diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 85bdf115a..fe1688f4f 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -2556,3 +2556,59 @@ def fill_sample(args):
     likelihood.parameters.update(dict(sample).copy())
     new_sample = likelihood.generate_posterior_sample_from_marginalized_likelihood()
     return tuple((new_sample[key] for key in marginalized_parameters))
+
+
+def identity_map_conversion(parameters):
+    """An identity map conversion function that makes no changes to the parameters,
+    but returns the correct signature expected by other conversion functions
+    (e.g. convert_to_lal_binary_black_hole_parameters)"""
+    return parameters, []
+
+
+def identity_map_generation(sample, likelihood=None, priors=None, npool=1):
+    """An identity map generation function that handles marginalizations, SNRs, etc. correctly,
+    but does not attempt e.g. conversions in mass or spins
+
+    Parameters
+    ==========
+    sample: dict or pandas.DataFrame
+        Samples to fill in with extra parameters, this may be either an
+        injection or posterior samples.
+    likelihood: bilby.gw.likelihood.GravitationalWaveTransient, optional
+        GravitationalWaveTransient used for sampling, used for waveform and
+        likelihood.interferometers.
+    priors: dict, optional
+        Dictionary of prior objects, used to fill in non-sampled parameters.
+
+    Returns
+    =======
+
+    """
+    output_sample = sample.copy()
+
+    output_sample = fill_from_fixed_priors(output_sample, priors)
+
+    if likelihood is not None:
+        compute_per_detector_log_likelihoods(
+            samples=output_sample, likelihood=likelihood, npool=npool)
+
+        marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
+        if len(marginalized_parameters) > 0:
+            try:
+                generate_posterior_samples_from_marginalized_likelihood(
+                    samples=output_sample, likelihood=likelihood, npool=npool)
+            except MarginalizedLikelihoodReconstructionError as e:
+                logger.warning(
+                    "Marginalised parameter reconstruction failed with message "
+                    "{}. Some parameters may not have the intended "
+                    "interpretation.".format(e)
+                )
+
+        if ("ra" in output_sample.keys() and "dec" in output_sample.keys() and "psi" in output_sample.keys()):
+            compute_snrs(output_sample, likelihood, npool=npool)
+        else:
+            logger.info(
+                "Skipping SNR computation since samples have insufficient sky location information"
+            )
+
+    return output_sample
diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py
index 8ce40e3b8..f75dd5c22 100644
--- a/test/gw/conversion_test.py
+++ b/test/gw/conversion_test.py
@@ -171,6 +171,26 @@ class TestBasicConversions(unittest.TestCase):
         )
         self.assertTrue((self.delta_lambda_tilde - delta_lambda_tilde) < 1e-5)
 
+    def test_identity_conversion(self):
+        original_samples = dict(
+            mass_1=self.mass_1,
+            mass_2=self.mass_2,
+            mass_ratio=self.mass_ratio,
+            total_mass=self.total_mass,
+            chirp_mass=self.chirp_mass,
+            symmetric_mass_ratio=self.symmetric_mass_ratio,
+            cos_angle=self.cos_angle,
+            angle=self.angle,
+            lambda_1=self.lambda_1,
+            lambda_2=self.lambda_2,
+            lambda_tilde=self.lambda_tilde,
+            delta_lambda_tilde=self.delta_lambda_tilde
+        )
+        identity_samples, blank_list = conversion.identity_map_conversion(original_samples)
+        assert blank_list == []
+        for key, val in identity_samples.items():
+            assert val == self.__dict__[key]
+
 
 class TestConvertToLALParams(unittest.TestCase):
     def setUp(self):
@@ -509,6 +529,50 @@ class TestGenerateAllParameters(unittest.TestCase):
         for key in extra_expected:
             self.assertIn(key, converted)
 
+    def test_identity_generation_no_likelihood(self):
+        test_fixed_prior = bilby.core.prior.PriorDict({
+            "test_param_a": bilby.core.prior.DeltaFunction(0, name="test_param_a"),
+            "test_param_b": bilby.core.prior.DeltaFunction(1, name="test_param_b")
+        }
+        )
+        output_sample = conversion.identity_map_generation(self.parameters, priors=test_fixed_prior)
+        assert output_sample.pop("test_param_a") == 0
+        assert output_sample.pop("test_param_b") == 1
+        for key, val in self.parameters.items():
+            assert output_sample.pop(key) == val
+        assert output_sample == {}
+
+    def test_identity_generation_with_likelihood(self):
+        priors = bilby.gw.prior.BBHPriorDict()
+        priors["geocent_time"] = bilby.core.prior.Uniform(0.4, 0.6)
+        self.parameters["time_jitter"] = 0.0
+        # Note we do *not* switch to azimuth/zenith, because the identity generation function
+        # is not intended to be capable of that conversion
+        ifos = bilby.gw.detector.InterferometerList(["H1"])
+        ifos.set_strain_data_from_power_spectral_densities(duration=1, sampling_frequency=256)
+        wfg = bilby.gw.waveform_generator.WaveformGenerator(
+            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole
+        )
+        likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
+            interferometers=ifos,
+            waveform_generator=wfg,
+            priors=priors,
+            phase_marginalization=True,
+            time_marginalization=True,
+            reference_frame="sky",
+        )
+        output_sample = conversion.identity_map_generation(self.parameters, priors=priors, likelihood=likelihood)
+        extra_expected = [
+            "phase",
+            "geocent_time",
+            "H1_optimal_snr",
+            "H1_matched_filter_snr",
+        ]
+        for key in extra_expected:
+            self.assertIn(key, output_sample)
+        for key, val in self.parameters.items():
+            self.assertTrue(output_sample[key] == val)
+
 
 class TestDistanceTransformations(unittest.TestCase):
     def setUp(self):
-- 
GitLab