Skip to content
Snippets Groups Projects
Commit e2cb1d50 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch '332-sampling_frequency-and-duration-should-be-saved-in-the-cbcresult' into 'master'

Resolve "sampling_frequency and duration should be saved in the CBCResult"

Closes #332

See merge request lscsoft/bilby!449
parents de3e6d3d 80a553a6
No related branches found
No related tags found
No related merge requests found
...@@ -91,14 +91,6 @@ class GravitationalWaveTransient(likelihood.Likelihood): ...@@ -91,14 +91,6 @@ class GravitationalWaveTransient(likelihood.Likelihood):
self.phase_marginalization = phase_marginalization self.phase_marginalization = phase_marginalization
self.priors = priors self.priors = priors
self._check_set_duration_and_sampling_frequency_of_waveform_generator() self._check_set_duration_and_sampling_frequency_of_waveform_generator()
self.meta_data = dict(
interferometers=self.interferometers.meta_data,
time_marginalization=self.time_marginalization,
phase_marginalization=self.phase_marginalization,
distance_marginalization=self.distance_marginalization,
waveform_arguments=waveform_generator.waveform_arguments,
frequency_domain_source_model=str(
waveform_generator.frequency_domain_source_model))
if self.time_marginalization: if self.time_marginalization:
self._check_prior_is_set(key='geocent_time') self._check_prior_is_set(key='geocent_time')
...@@ -635,6 +627,20 @@ class GravitationalWaveTransient(likelihood.Likelihood): ...@@ -635,6 +627,20 @@ class GravitationalWaveTransient(likelihood.Likelihood):
for mode in signal: for mode in signal:
signal[mode] *= self._ref_dist / new_distance signal[mode] *= self._ref_dist / new_distance
@property
def meta_data(self):
return dict(
interferometers=self.interferometers.meta_data,
time_marginalization=self.time_marginalization,
phase_marginalization=self.phase_marginalization,
distance_marginalization=self.distance_marginalization,
waveform_arguments=self.waveform_generator.waveform_arguments,
frequency_domain_source_model=str(
self.waveform_generator.frequency_domain_source_model),
sampling_frequency=self.waveform_generator.sampling_frequency,
duration=self.waveform_generator.duration,
start_time=self.waveform_generator.start_time)
class BasicGravitationalWaveTransient(likelihood.Likelihood): class BasicGravitationalWaveTransient(likelihood.Likelihood):
......
...@@ -17,14 +17,33 @@ class CompactBinaryCoalesenceResult(CoreResult): ...@@ -17,14 +17,33 @@ class CompactBinaryCoalesenceResult(CoreResult):
def __get_from_nested_meta_data(self, *keys): def __get_from_nested_meta_data(self, *keys):
dictionary = self.meta_data dictionary = self.meta_data
try: try:
item = None
for k in keys: for k in keys:
item = dictionary[k] item = dictionary[k]
dictionary = item dictionary = item
return item return item
except KeyError: except KeyError:
raise ValueError( raise AttributeError(
"No information stored for {}".format('/'.join(keys))) "No information stored for {}".format('/'.join(keys)))
@property
def sampling_frequency(self):
""" Sampling frequency in Hertz"""
return self.__get_from_nested_meta_data(
'likelihood', 'sampling_frequency')
@property
def duration(self):
""" Duration in seconds """
return self.__get_from_nested_meta_data(
'likelihood', 'duration')
@property
def start_time(self):
""" Start time in seconds """
return self.__get_from_nested_meta_data(
'likelihood', 'start_time')
@property @property
def time_marginalization(self): def time_marginalization(self):
""" Boolean for if the likelihood used time marginalization """ """ Boolean for if the likelihood used time marginalization """
...@@ -82,7 +101,7 @@ class CompactBinaryCoalesenceResult(CoreResult): ...@@ -82,7 +101,7 @@ class CompactBinaryCoalesenceResult(CoreResult):
try: try:
return self.__get_from_nested_meta_data( return self.__get_from_nested_meta_data(
'likelihood', 'interferometers', detector) 'likelihood', 'interferometers', detector)
except ValueError: except AttributeError:
logger.info("No injection for detector {}".format(detector)) logger.info("No injection for detector {}".format(detector))
return None return None
......
...@@ -147,6 +147,20 @@ class TestGWTransient(unittest.TestCase): ...@@ -147,6 +147,20 @@ class TestGWTransient(unittest.TestCase):
self.assertListEqual(bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers) self.assertListEqual(bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers)
self.assertTrue(type(self.likelihood.interferometers) == bilby.gw.detector.InterferometerList) self.assertTrue(type(self.likelihood.interferometers) == bilby.gw.detector.InterferometerList)
def test_meta_data(self):
expected = dict(
interferometers=self.interferometers.meta_data,
time_marginalization=False,
phase_marginalization=False,
distance_marginalization=False,
waveform_arguments=self.waveform_generator.waveform_arguments,
frequency_domain_source_model=str(
self.waveform_generator.frequency_domain_source_model),
sampling_frequency=self.waveform_generator.sampling_frequency,
duration=self.waveform_generator.duration,
start_time=self.waveform_generator.start_time)
self.assertDictEqual(expected, self.likelihood.meta_data)
class TestTimeMarginalization(unittest.TestCase): class TestTimeMarginalization(unittest.TestCase):
......
...@@ -49,8 +49,8 @@ class TestCBCResult(unittest.TestCase): ...@@ -49,8 +49,8 @@ class TestCBCResult(unittest.TestCase):
def test_phase_marginalization_unset(self): def test_phase_marginalization_unset(self):
self.result.meta_data['likelihood'].pop('phase_marginalization') self.result.meta_data['likelihood'].pop('phase_marginalization')
with self.assertRaises(ValueError): with self.assertRaises(AttributeError):
self.result.phase_marginalization, self.result.phase_marginalization
def test_time_marginalization(self): def test_time_marginalization(self):
self.assertEqual( self.assertEqual(
...@@ -59,8 +59,8 @@ class TestCBCResult(unittest.TestCase): ...@@ -59,8 +59,8 @@ class TestCBCResult(unittest.TestCase):
def test_time_marginalization_unset(self): def test_time_marginalization_unset(self):
self.result.meta_data['likelihood'].pop('time_marginalization') self.result.meta_data['likelihood'].pop('time_marginalization')
with self.assertRaises(ValueError): with self.assertRaises(AttributeError):
self.result.time_marginalization, self.result.time_marginalization
def test_distance_marginalization(self): def test_distance_marginalization(self):
self.assertEqual( self.assertEqual(
...@@ -69,8 +69,8 @@ class TestCBCResult(unittest.TestCase): ...@@ -69,8 +69,8 @@ class TestCBCResult(unittest.TestCase):
def test_distance_marginalization_unset(self): def test_distance_marginalization_unset(self):
self.result.meta_data['likelihood'].pop('distance_marginalization') self.result.meta_data['likelihood'].pop('distance_marginalization')
with self.assertRaises(ValueError): with self.assertRaises(AttributeError):
self.result.distance_marginalization, self.result.distance_marginalization
def test_reference_frequency(self): def test_reference_frequency(self):
self.assertEqual( self.assertEqual(
...@@ -79,8 +79,8 @@ class TestCBCResult(unittest.TestCase): ...@@ -79,8 +79,8 @@ class TestCBCResult(unittest.TestCase):
def test_reference_frequency_unset(self): def test_reference_frequency_unset(self):
self.result.meta_data['likelihood']['waveform_arguments'].pop('reference_frequency') self.result.meta_data['likelihood']['waveform_arguments'].pop('reference_frequency')
with self.assertRaises(ValueError): with self.assertRaises(AttributeError):
self.result.reference_frequency, self.result.reference_frequency
def test_waveform_approximant(self): def test_waveform_approximant(self):
self.assertEqual( self.assertEqual(
...@@ -89,8 +89,8 @@ class TestCBCResult(unittest.TestCase): ...@@ -89,8 +89,8 @@ class TestCBCResult(unittest.TestCase):
def test_waveform_approximant_unset(self): def test_waveform_approximant_unset(self):
self.result.meta_data['likelihood']['waveform_arguments'].pop('waveform_approximant') self.result.meta_data['likelihood']['waveform_arguments'].pop('waveform_approximant')
with self.assertRaises(ValueError): with self.assertRaises(AttributeError):
self.result.waveform_approximant, self.result.waveform_approximant
def test_frequency_domain_source_model(self): def test_frequency_domain_source_model(self):
self.assertEqual( self.assertEqual(
...@@ -99,7 +99,7 @@ class TestCBCResult(unittest.TestCase): ...@@ -99,7 +99,7 @@ class TestCBCResult(unittest.TestCase):
def test_frequency_domain_source_model_unset(self): def test_frequency_domain_source_model_unset(self):
self.result.meta_data['likelihood'].pop('frequency_domain_source_model') self.result.meta_data['likelihood'].pop('frequency_domain_source_model')
with self.assertRaises(ValueError): with self.assertRaises(AttributeError):
self.result.frequency_domain_source_model self.result.frequency_domain_source_model
def test_detector_injection_properties(self): def test_detector_injection_properties(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment