Skip to content
Snippets Groups Projects
Commit a11ef839 authored by Moritz Huebner's avatar Moritz Huebner Committed by Colm Talbot
Browse files

Add some tests

parent 9352d877
No related branches found
No related tags found
No related merge requests found
......@@ -189,6 +189,12 @@ class TestGaussianLikelihood(unittest.TestCase):
likelihood.log_likelihood()
self.assertTrue(likelihood.sigma == 1)
def test_sigma_other(self):
likelihood = GaussianLikelihood(
self.x, self.y, self.function, sigma=None)
with self.assertRaises(ValueError):
likelihood.sigma = 'test'
def test_repr(self):
likelihood = GaussianLikelihood(
self.x, self.y, self.function, sigma=self.sigma)
......@@ -467,6 +473,10 @@ class TestExponentialLikelihood(unittest.TestCase):
m.return_value = 3
self.assertEqual(-3, exponential_likelihood.log_likelihood())
def test_repr(self):
expected = 'ExponentialLikelihood(x={}, y={}, func={})'.format(self.x, self.y, self.function.__name__)
self.assertEqual(expected, repr(self.exponential_likelihood))
class TestJointLikelihood(unittest.TestCase):
......@@ -554,6 +564,10 @@ class TestJointLikelihood(unittest.TestCase):
self.joint_likelihood.likelihoods = self.first_likelihood
self.assertEqual(self.first_likelihood.log_likelihood(), self.joint_likelihood.log_likelihood())
def test_setting_likelihood_other(self):
with self.assertRaises(ValueError):
self.joint_likelihood.likelihoods = 'test'
# Appending is not supported
# def test_appending(self):
# joint_likelihood = tupak.core.likelihood.JointLikelihood(self.first_likelihood, self.second_likelihood)
......
......@@ -47,6 +47,43 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC
self.waveform_generator.waveform_arguments)
self.assertEqual(expected, repr(self.waveform_generator))
def test_repr_with_time_domain_source_model(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
time_domain_source_model=dummy_func_dict_return_value)
expected = 'WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, ' \
'frequency_domain_source_model={}, time_domain_source_model={}, parameters={}, ' \
'parameter_conversion={}, non_standard_sampling_parameter_keys={}, waveform_arguments={})'\
.format(self.waveform_generator.duration,
self.waveform_generator.sampling_frequency,
self.waveform_generator.start_time,
self.waveform_generator.frequency_domain_source_model,
self.waveform_generator.time_domain_source_model.__name__,
self.waveform_generator.parameters,
None,
self.waveform_generator.non_standard_sampling_parameter_keys,
self.waveform_generator.waveform_arguments)
self.assertEqual(expected, repr(self.waveform_generator))
def test_repr_with_param_conversion(self):
def conversion_func():
pass
self.waveform_generator.parameter_conversion = conversion_func
expected = 'WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, ' \
'frequency_domain_source_model={}, time_domain_source_model={}, parameters={}, ' \
'parameter_conversion={}, non_standard_sampling_parameter_keys={}, waveform_arguments={})'\
.format(self.waveform_generator.duration,
self.waveform_generator.sampling_frequency,
self.waveform_generator.start_time,
self.waveform_generator.frequency_domain_source_model.__name__,
self.waveform_generator.time_domain_source_model,
self.waveform_generator.parameters,
conversion_func.__name__,
self.waveform_generator.non_standard_sampling_parameter_keys,
self.waveform_generator.waveform_arguments)
self.assertEqual(expected, repr(self.waveform_generator))
def test_duration(self):
self.assertEqual(self.waveform_generator.duration, 1)
......@@ -128,6 +165,16 @@ class TestSetters(unittest.TestCase):
self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())),
sorted(list(self.simulation_parameters.keys())))
def test_set_parameter_conversion_at_init(self):
def conversion_func():
pass
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
frequency_domain_source_model=dummy_func_dict_return_value,
parameter_conversion=conversion_func)
self.assertEqual(conversion_func, self.waveform_generator.parameter_conversion)
class TestFrequencyDomainStrainMethod(unittest.TestCase):
......
......@@ -75,7 +75,7 @@ class WaveformGenerator(object):
else:
fdsm_name = None
if self.time_domain_source_model is not None:
tdsm_name = self.frequency_domain_source_model.__name__
tdsm_name = self.time_domain_source_model.__name__
else:
tdsm_name = None
if self.parameter_conversion.__name__ == '<lambda>':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment