From 0a2f71aefd539641753e28ccdb038b7ba2682214 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Sun, 1 Dec 2019 18:12:19 -0600
Subject: [PATCH] make waveform generator class saved in GWT meta data

---
 bilby/core/utils.py        | 4 +++-
 bilby/gw/likelihood.py     | 1 +
 bilby/gw/result.py         | 9 +++++++--
 test/gw_likelihood_test.py | 1 +
 4 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/bilby/core/utils.py b/bilby/core/utils.py
index d058c5599..66dc06368 100644
--- a/bilby/core/utils.py
+++ b/bilby/core/utils.py
@@ -998,6 +998,8 @@ class BilbyJsonEncoder(json.JSONEncoder):
             return {'__dataframe__': True, 'content': obj.to_dict(orient='list')}
         if inspect.isfunction(obj):
             return {"__function__": True, "__module__": obj.__module__, "__name__": obj.__name__}
+        if inspect.isclass(obj):
+            return {"__class__": True, "__module__": obj.__module__, "__name__": obj.__name__}
         return json.JSONEncoder.default(self, obj)
 
 
@@ -1036,7 +1038,7 @@ def decode_bilby_json(dct):
         return complex(dct["real"], dct["imag"])
     if dct.get("__dataframe__", False):
         return pd.DataFrame(dct['content'])
-    if dct.get("__function__", False):
+    if dct.get("__function__", False) or dct.get("__class__", False):
         default = ".".join([dct["__module__"], dct["__name__"]])
         return getattr(import_module(dct["__module__"]), dct["__name__"], default)
     return dct
diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index 804dab655..02bc2452e 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -702,6 +702,7 @@ class GravitationalWaveTransient(likelihood.Likelihood):
             time_marginalization=self.time_marginalization,
             phase_marginalization=self.phase_marginalization,
             distance_marginalization=self.distance_marginalization,
+            waveform_generator_class=self.waveform_generator.__class__,
             waveform_arguments=self.waveform_generator.waveform_arguments,
             frequency_domain_source_model=self.waveform_generator.frequency_domain_source_model,
             parameter_conversion=self.waveform_generator.parameter_conversion,
diff --git a/bilby/gw/result.py b/bilby/gw/result.py
index ed9125a94..52af660fc 100644
--- a/bilby/gw/result.py
+++ b/bilby/gw/result.py
@@ -11,7 +11,6 @@ import numpy as np
 from ..core.result import Result as CoreResult
 from ..core.utils import infft, logger, check_directory_exists_and_if_not_mkdir
 from .utils import plot_spline_pos, spline_angle_xform, asd_from_freq_series
-from .waveform_generator import WaveformGenerator
 from .detector import get_empty_interferometer, Interferometer
 
 
@@ -79,6 +78,12 @@ class CompactBinaryCoalescenceResult(CoreResult):
         return self.__get_from_nested_meta_data(
             'likelihood', 'waveform_arguments', 'waveform_approximant')
 
+    @property
+    def waveform_generator_class(self):
+        """ Dict of waveform arguments """
+        return self.__get_from_nested_meta_data(
+            'likelihood', 'waveform_generator_class')
+
     @property
     def waveform_arguments(self):
         """ Dict of waveform arguments """
@@ -347,7 +352,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
         plot_times = interferometer.time_array[time_idxs] - interferometer.strain_data.start_time
         plot_frequencies = interferometer.frequency_array[frequency_idxs]
 
-        waveform_generator = WaveformGenerator(
+        waveform_generator = self.waveform_generator_class(
             duration=self.duration, sampling_frequency=self.sampling_frequency,
             start_time=self.start_time,
             frequency_domain_source_model=self.frequency_domain_source_model,
diff --git a/test/gw_likelihood_test.py b/test/gw_likelihood_test.py
index 91966dc45..c95c4a0df 100644
--- a/test/gw_likelihood_test.py
+++ b/test/gw_likelihood_test.py
@@ -156,6 +156,7 @@ class TestGWTransient(unittest.TestCase):
             time_marginalization=False,
             phase_marginalization=False,
             distance_marginalization=False,
+            waveform_generator_class=self.waveform_generator.__class__,
             waveform_arguments=self.waveform_generator.waveform_arguments,
             frequency_domain_source_model=self.waveform_generator.frequency_domain_source_model,
             parameter_conversion=self.waveform_generator.parameter_conversion,
-- 
GitLab