diff --git a/bilby/core/utils.py b/bilby/core/utils.py index d058c55995d5c2b183a33065c6668daefe93d4a6..66dc063684170387a856c43c3b179b4b2c14b680 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -998,6 +998,8 @@ class BilbyJsonEncoder(json.JSONEncoder): return {'__dataframe__': True, 'content': obj.to_dict(orient='list')} if inspect.isfunction(obj): return {"__function__": True, "__module__": obj.__module__, "__name__": obj.__name__} + if inspect.isclass(obj): + return {"__class__": True, "__module__": obj.__module__, "__name__": obj.__name__} return json.JSONEncoder.default(self, obj) @@ -1036,7 +1038,7 @@ def decode_bilby_json(dct): return complex(dct["real"], dct["imag"]) if dct.get("__dataframe__", False): return pd.DataFrame(dct['content']) - if dct.get("__function__", False): + if dct.get("__function__", False) or dct.get("__class__", False): default = ".".join([dct["__module__"], dct["__name__"]]) return getattr(import_module(dct["__module__"]), dct["__name__"], default) return dct diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index 804dab655743e8ed16d72ab225be92ca15c7d72d..02bc2452e1548b47c6dc8e0b6d9f171689e4504c 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -702,6 +702,7 @@ class GravitationalWaveTransient(likelihood.Likelihood): time_marginalization=self.time_marginalization, phase_marginalization=self.phase_marginalization, distance_marginalization=self.distance_marginalization, + waveform_generator_class=self.waveform_generator.__class__, waveform_arguments=self.waveform_generator.waveform_arguments, frequency_domain_source_model=self.waveform_generator.frequency_domain_source_model, parameter_conversion=self.waveform_generator.parameter_conversion, diff --git a/bilby/gw/result.py b/bilby/gw/result.py index ed9125a94840b6f48518d0761ba18c3dc675f08c..52af660fcf7cce6ad9f6158cdbb9f9799f3d6a1b 100644 --- a/bilby/gw/result.py +++ b/bilby/gw/result.py @@ -11,7 +11,6 @@ import numpy as np from ..core.result import Result as CoreResult from ..core.utils import infft, logger, check_directory_exists_and_if_not_mkdir from .utils import plot_spline_pos, spline_angle_xform, asd_from_freq_series -from .waveform_generator import WaveformGenerator from .detector import get_empty_interferometer, Interferometer @@ -79,6 +78,12 @@ class CompactBinaryCoalescenceResult(CoreResult): return self.__get_from_nested_meta_data( 'likelihood', 'waveform_arguments', 'waveform_approximant') + @property + def waveform_generator_class(self): + """ Dict of waveform arguments """ + return self.__get_from_nested_meta_data( + 'likelihood', 'waveform_generator_class') + @property def waveform_arguments(self): """ Dict of waveform arguments """ @@ -347,7 +352,7 @@ class CompactBinaryCoalescenceResult(CoreResult): plot_times = interferometer.time_array[time_idxs] - interferometer.strain_data.start_time plot_frequencies = interferometer.frequency_array[frequency_idxs] - waveform_generator = WaveformGenerator( + waveform_generator = self.waveform_generator_class( duration=self.duration, sampling_frequency=self.sampling_frequency, start_time=self.start_time, frequency_domain_source_model=self.frequency_domain_source_model, diff --git a/test/gw_likelihood_test.py b/test/gw_likelihood_test.py index 91966dc458e56b8504acafa8297363226ace1697..c95c4a0dfceee341811f286c340cbeab34736bb8 100644 --- a/test/gw_likelihood_test.py +++ b/test/gw_likelihood_test.py @@ -156,6 +156,7 @@ class TestGWTransient(unittest.TestCase): time_marginalization=False, phase_marginalization=False, distance_marginalization=False, + waveform_generator_class=self.waveform_generator.__class__, waveform_arguments=self.waveform_generator.waveform_arguments, frequency_domain_source_model=self.waveform_generator.frequency_domain_source_model, parameter_conversion=self.waveform_generator.parameter_conversion,