From c30494b1ac0003f2b7800e1e9db0e7d81e96f1b2 Mon Sep 17 00:00:00 2001
From: John Veitch <john.veitch@ligo.org>
Date: Thu, 16 Mar 2023 02:16:31 +0000
Subject: [PATCH] Parallelise the reweighting step

---
 bilby/gw/likelihood/base.py | 33 +++++++++++++++++++++++----------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py
index 1daa9a1e5..614d5c724 100644
--- a/bilby/gw/likelihood/base.py
+++ b/bilby/gw/likelihood/base.py
@@ -492,27 +492,41 @@ class GravitationalWaveTransient(Likelihood):
         else:
             return self.parameters
 
-        if self.calibration_marginalization and self.time_marginalization:
-            raise AttributeError(
-                "Cannot use time and calibration marginalization simultaneously for regeneration at the moment!"
-                "The matrix manipulation has not been tested.")
+        """
+        As we go through and repopulate we will undo the marginalizations
+        for the next function in the chain - back up the state info first
+        """
+        backup_options = (self.time_marginalization,
+                   self.calibration_marginalization,
+                   self.distance_marginalization,
+                   self.phase_marginalization)
 
-        if self.calibration_marginalization:
-            new_calibration = self.generate_calibration_sample_from_marginalized_likelihood(
-                signal_polarizations=signal_polarizations)
-            self.parameters['recalib_index'] = new_calibration
         if self.time_marginalization:
             new_time = self.generate_time_sample_from_marginalized_likelihood(
                 signal_polarizations=signal_polarizations)
             self.parameters['geocent_time'] = new_time
+            self.time_marginalization = False
         if self.distance_marginalization:
             new_distance = self.generate_distance_sample_from_marginalized_likelihood(
                 signal_polarizations=signal_polarizations)
             self.parameters['luminosity_distance'] = new_distance
+            self.distance_marginalization = False
         if self.phase_marginalization:
             new_phase = self.generate_phase_sample_from_marginalized_likelihood(
                 signal_polarizations=signal_polarizations)
             self.parameters['phase'] = new_phase
+        if self.calibration_marginalization:
+            new_calibration = self.generate_calibration_sample_from_marginalized_likelihood(
+                signal_polarizations=signal_polarizations)
+            self.parameters['recalib_index'] = new_calibration
+            self.calibration_marginalization = False
+
+        # Restore state info here
+        self.time_marginalization, \
+        self.calibration_marginalization, \
+        self.distance_marginalization, \
+        self.phase_marginalization = backup_options
+
         return self.parameters.copy()
 
     def generate_calibration_sample_from_marginalized_likelihood(
@@ -539,8 +553,7 @@ class GravitationalWaveTransient(Likelihood):
                 self.waveform_generator.frequency_domain_strain(self.parameters)
 
         log_like = self.get_calibration_log_likelihoods(signal_polarizations=signal_polarizations)
-
-        calibration_post = np.exp(log_like - max(log_like))
+        calibration_post = np.exp(log_like - log_like.max())
         calibration_post /= np.sum(calibration_post)
 
         new_calibration = np.random.choice(self.number_of_response_curves, p=calibration_post)
-- 
GitLab