diff --git a/test/likelihood_test.py b/test/likelihood_test.py index 283f8065427c72ea7f636cf7981b2f5e2d155b1d..5fbf02cdb5b07441e8919028ef115bceacfe2bb2 100644 --- a/test/likelihood_test.py +++ b/test/likelihood_test.py @@ -189,6 +189,12 @@ class TestGaussianLikelihood(unittest.TestCase): likelihood.log_likelihood() self.assertTrue(likelihood.sigma == 1) + def test_sigma_other(self): + likelihood = GaussianLikelihood( + self.x, self.y, self.function, sigma=None) + with self.assertRaises(ValueError): + likelihood.sigma = 'test' + def test_repr(self): likelihood = GaussianLikelihood( self.x, self.y, self.function, sigma=self.sigma) @@ -467,6 +473,10 @@ class TestExponentialLikelihood(unittest.TestCase): m.return_value = 3 self.assertEqual(-3, exponential_likelihood.log_likelihood()) + def test_repr(self): + expected = 'ExponentialLikelihood(x={}, y={}, func={})'.format(self.x, self.y, self.function.__name__) + self.assertEqual(expected, repr(self.exponential_likelihood)) + class TestJointLikelihood(unittest.TestCase): @@ -554,6 +564,10 @@ class TestJointLikelihood(unittest.TestCase): self.joint_likelihood.likelihoods = self.first_likelihood self.assertEqual(self.first_likelihood.log_likelihood(), self.joint_likelihood.log_likelihood()) + def test_setting_likelihood_other(self): + with self.assertRaises(ValueError): + self.joint_likelihood.likelihoods = 'test' + # Appending is not supported # def test_appending(self): # joint_likelihood = tupak.core.likelihood.JointLikelihood(self.first_likelihood, self.second_likelihood) diff --git a/test/waveform_generator_test.py b/test/waveform_generator_test.py index 99dd5d3f005f3dd39503c2fbdc747d63f45a8e2e..a7cd95ebe9ee6ce43f4bbe5dbbf5a04bf44f9c2a 100644 --- a/test/waveform_generator_test.py +++ b/test/waveform_generator_test.py @@ -47,6 +47,43 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC self.waveform_generator.waveform_arguments) self.assertEqual(expected, repr(self.waveform_generator)) + def test_repr_with_time_domain_source_model(self): + self.waveform_generator = \ + tupak.gw.waveform_generator.WaveformGenerator(1, 4096, + time_domain_source_model=dummy_func_dict_return_value) + expected = 'WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, ' \ + 'frequency_domain_source_model={}, time_domain_source_model={}, parameters={}, ' \ + 'parameter_conversion={}, non_standard_sampling_parameter_keys={}, waveform_arguments={})'\ + .format(self.waveform_generator.duration, + self.waveform_generator.sampling_frequency, + self.waveform_generator.start_time, + self.waveform_generator.frequency_domain_source_model, + self.waveform_generator.time_domain_source_model.__name__, + self.waveform_generator.parameters, + None, + self.waveform_generator.non_standard_sampling_parameter_keys, + self.waveform_generator.waveform_arguments) + self.assertEqual(expected, repr(self.waveform_generator)) + + def test_repr_with_param_conversion(self): + def conversion_func(): + pass + + self.waveform_generator.parameter_conversion = conversion_func + expected = 'WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, ' \ + 'frequency_domain_source_model={}, time_domain_source_model={}, parameters={}, ' \ + 'parameter_conversion={}, non_standard_sampling_parameter_keys={}, waveform_arguments={})'\ + .format(self.waveform_generator.duration, + self.waveform_generator.sampling_frequency, + self.waveform_generator.start_time, + self.waveform_generator.frequency_domain_source_model.__name__, + self.waveform_generator.time_domain_source_model, + self.waveform_generator.parameters, + conversion_func.__name__, + self.waveform_generator.non_standard_sampling_parameter_keys, + self.waveform_generator.waveform_arguments) + self.assertEqual(expected, repr(self.waveform_generator)) + def test_duration(self): self.assertEqual(self.waveform_generator.duration, 1) @@ -128,6 +165,16 @@ class TestSetters(unittest.TestCase): self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())), sorted(list(self.simulation_parameters.keys()))) + def test_set_parameter_conversion_at_init(self): + def conversion_func(): + pass + + self.waveform_generator = \ + tupak.gw.waveform_generator.WaveformGenerator(1, 4096, + frequency_domain_source_model=dummy_func_dict_return_value, + parameter_conversion=conversion_func) + self.assertEqual(conversion_func, self.waveform_generator.parameter_conversion) + class TestFrequencyDomainStrainMethod(unittest.TestCase): diff --git a/tupak/gw/waveform_generator.py b/tupak/gw/waveform_generator.py index 1c0f785823ac373a1dd1c8d2306e524ae0d17b51..b5e08cc15548c53ba661ab3393db8925509f281c 100644 --- a/tupak/gw/waveform_generator.py +++ b/tupak/gw/waveform_generator.py @@ -75,7 +75,7 @@ class WaveformGenerator(object): else: fdsm_name = None if self.time_domain_source_model is not None: - tdsm_name = self.frequency_domain_source_model.__name__ + tdsm_name = self.time_domain_source_model.__name__ else: tdsm_name = None if self.parameter_conversion.__name__ == '<lambda>':