From 286011fdf4563668d356767f51c415a80ffc01be Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Thu, 17 May 2018 11:42:33 +1000
Subject: [PATCH] bbh parameter filling uses likelihood rather than wg and ifos

---
 tupak/conversion.py | 70 ++++++++++++++++++++++++---------------------
 1 file changed, 38 insertions(+), 32 deletions(-)

diff --git a/tupak/conversion.py b/tupak/conversion.py
index f229d1a07..67e9ba7ac 100644
--- a/tupak/conversion.py
+++ b/tupak/conversion.py
@@ -98,7 +98,7 @@ def convert_to_lal_binary_black_hole_parameters(parameters, search_keys, remove=
     return ignored_keys
 
 
-def generate_all_bbh_parameters(sample, waveform_generator=None, interferometers=None, priors=None):
+def generate_all_bbh_parameters(sample, likelihood=None, priors=None):
     """
     From either a single sample or a set of samples fill in all missing BBH parameters, in place.
 
@@ -106,24 +106,21 @@ def generate_all_bbh_parameters(sample, waveform_generator=None, interferometers
     ----------
     sample: dict or pandas.DataFrame
         Samples to fill in with extra parameters, this may be either an injection or posterior samples.
-    waveform_generator: tupak.waveform_generator.WaveformGenerator, optional
-        If the waveform generator and interferometers are provided, the SNRs will be recorded.
-    interferometers: list, optional
-        List of tupak.detector.Interferometer objects.
-        If the waveform generator and interferometers are provided, the SNRs will be recorded.
+    likelihood: tupak.likelihood.Likelihood
+        Likelihood used for sampling, used for waveform and likelihood.interferometers.
     priors: dict, optional
         Dictionary of prior objects, used to fill in non-sampled parameters.
     """
 
-    if waveform_generator is not None:
-        sample['reference_frequency'] = waveform_generator.parameters['reference_frequency']
-        sample['waveform_approximant'] = waveform_generator.parameters['waveform_approximant']
+    if likelihood is not None:
+        sample['reference_frequency'] = likelihood.waveform_generator.parameters['reference_frequency']
+        sample['waveform_approximant'] = likelihood.waveform_generator.parameters['waveform_approximant']
 
     fill_from_fixed_priors(sample, priors)
     convert_to_lal_binary_black_hole_parameters(sample, [key for key in sample.keys()], remove=False)
     generate_non_standard_parameters(sample)
     generate_component_spins(sample)
-    compute_snrs(sample, waveform_generator, interferometers)
+    compute_snrs(sample, likelihood)
 
 
 def fill_from_fixed_priors(sample, priors):
@@ -194,39 +191,48 @@ def generate_component_spins(sample):
         logging.warning("Component spin extraction failed.")
 
 
-def compute_snrs(sample, waveform_generator, interferometers):
+def compute_snrs(sample, likelihood):
     """Compute the optimal and matched filter snrs of all posterior samples."""
-    temp_sample = sample.copy()
-    if waveform_generator is not None and interferometers is not None:
+    temp_sample = sample
+    if likelihood is not None:
         if isinstance(temp_sample, dict):
-            for key in waveform_generator.parameters.keys():
-                waveform_generator.parameters[key] = temp_sample[key]
-            signal_polarizations = waveform_generator.frequency_domain_strain()
-            for interferometer in interferometers:
-                signal = interferometer.get_detector_response(signal_polarizations, waveform_generator.parameters)
+            for key in likelihood.waveform_generator.parameters.keys():
+                likelihood.waveform_generator.parameters[key] = temp_sample[key]
+            signal_polarizations = likelihood.waveform_generator.frequency_domain_strain()
+            for interferometer in likelihood.interferometers:
+                signal = interferometer.get_detector_response(signal_polarizations,
+                                                              likelihood.waveform_generator.parameters)
                 sample['{}_matched_filter_snr'.format(interferometer.name)] = \
                     tupak.utils.matched_filter_snr_squared(signal, interferometer,
-                                                           waveform_generator.time_duration)**0.5
+                                                           likelihood.waveform_generator.time_duration)**0.5
                 sample['{}_optimal_snr'.format(interferometer.name)] = tupak.utils.optimal_snr_squared(
-                    signal, interferometer, waveform_generator.time_duration) ** 0.5
+                    signal, interferometer, likelihood.waveform_generator.time_duration) ** 0.5
         else:
             logging.info('Computing SNRs for every sample, this may take some time.')
-            matched_filter_snrs = {interferometer.name: [] for interferometer in interferometers}
-            optimal_snrs = {interferometer.name: [] for interferometer in interferometers}
+            all_interferometers = likelihood.interferometers
+            matched_filter_snrs = {interferometer.name: [] for interferometer in all_interferometers}
+            optimal_snrs = {interferometer.name: [] for interferometer in all_interferometers}
+            likelihoods = {interferometer.name: [] for interferometer in all_interferometers}
             for ii in range(len(temp_sample)):
-                for key in set(temp_sample.keys()).intersection(waveform_generator.parameters.keys()):
-                    waveform_generator.parameters[key] = temp_sample[key][ii]
-                for key in waveform_generator.search_parameter_keys:
-                    waveform_generator.parameters[key] = temp_sample[key][ii]
-                signal_polarizations = waveform_generator.frequency_domain_strain()
-                for interferometer in interferometers:
-                    signal = interferometer.get_detector_response(signal_polarizations, waveform_generator.parameters)
+                for key in set(temp_sample.keys()).intersection(likelihood.waveform_generator.parameters.keys()):
+                    likelihood.waveform_generator.parameters[key] = temp_sample[key][ii]
+                for key in likelihood.waveform_generator.sampling_parameter_keys:
+                    likelihood.waveform_generator.parameters[key] = temp_sample[key][ii]
+                signal_polarizations = likelihood.waveform_generator.frequency_domain_strain()
+                for interferometer in all_interferometers:
+                    signal = interferometer.get_detector_response(signal_polarizations,
+                                                                  likelihood.waveform_generator.parameters)
                     matched_filter_snrs[interferometer.name].append(tupak.utils.matched_filter_snr_squared(
-                        signal, interferometer, waveform_generator.time_duration)**0.5)
+                        signal, interferometer, likelihood.waveform_generator.time_duration)**0.5)
                     optimal_snrs[interferometer.name].append(tupak.utils.optimal_snr_squared(
-                        signal, interferometer, waveform_generator.time_duration) ** 0.5)
-            for interferometer in interferometers:
+                        signal, interferometer, likelihood.waveform_generator.time_duration) ** 0.5)
+
+            for interferometer in likelihood.interferometers:
                 sample['{}_matched_filter_snr'.format(interferometer.name)] = matched_filter_snrs[interferometer.name]
                 sample['{}_optimal_snr'.format(interferometer.name)] = optimal_snrs[interferometer.name]
+
+            likelihood.interferometers = all_interferometers
+            print([interferometer.name for interferometer in likelihood.interferometers])
+
     else:
         logging.info('Not computing SNRs.')
-- 
GitLab