Skip to content
Snippets Groups Projects
Commit 797886f1 authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'Add_some_tests' into 'master'

Add some tests

See merge request Monash/tupak!193
parents 9352d877 a11ef839
No related branches found
No related tags found
No related merge requests found
......@@ -189,6 +189,12 @@ class TestGaussianLikelihood(unittest.TestCase):
likelihood.log_likelihood()
self.assertTrue(likelihood.sigma == 1)
def test_sigma_other(self):
likelihood = GaussianLikelihood(
self.x, self.y, self.function, sigma=None)
with self.assertRaises(ValueError):
likelihood.sigma = 'test'
def test_repr(self):
likelihood = GaussianLikelihood(
self.x, self.y, self.function, sigma=self.sigma)
......@@ -467,6 +473,10 @@ class TestExponentialLikelihood(unittest.TestCase):
m.return_value = 3
self.assertEqual(-3, exponential_likelihood.log_likelihood())
def test_repr(self):
expected = 'ExponentialLikelihood(x={}, y={}, func={})'.format(self.x, self.y, self.function.__name__)
self.assertEqual(expected, repr(self.exponential_likelihood))
class TestJointLikelihood(unittest.TestCase):
......@@ -554,6 +564,10 @@ class TestJointLikelihood(unittest.TestCase):
self.joint_likelihood.likelihoods = self.first_likelihood
self.assertEqual(self.first_likelihood.log_likelihood(), self.joint_likelihood.log_likelihood())
def test_setting_likelihood_other(self):
with self.assertRaises(ValueError):
self.joint_likelihood.likelihoods = 'test'
# Appending is not supported
# def test_appending(self):
# joint_likelihood = tupak.core.likelihood.JointLikelihood(self.first_likelihood, self.second_likelihood)
......
......@@ -47,6 +47,43 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC
self.waveform_generator.waveform_arguments)
self.assertEqual(expected, repr(self.waveform_generator))
def test_repr_with_time_domain_source_model(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
time_domain_source_model=dummy_func_dict_return_value)
expected = 'WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, ' \
'frequency_domain_source_model={}, time_domain_source_model={}, parameters={}, ' \
'parameter_conversion={}, non_standard_sampling_parameter_keys={}, waveform_arguments={})'\
.format(self.waveform_generator.duration,
self.waveform_generator.sampling_frequency,
self.waveform_generator.start_time,
self.waveform_generator.frequency_domain_source_model,
self.waveform_generator.time_domain_source_model.__name__,
self.waveform_generator.parameters,
None,
self.waveform_generator.non_standard_sampling_parameter_keys,
self.waveform_generator.waveform_arguments)
self.assertEqual(expected, repr(self.waveform_generator))
def test_repr_with_param_conversion(self):
def conversion_func():
pass
self.waveform_generator.parameter_conversion = conversion_func
expected = 'WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, ' \
'frequency_domain_source_model={}, time_domain_source_model={}, parameters={}, ' \
'parameter_conversion={}, non_standard_sampling_parameter_keys={}, waveform_arguments={})'\
.format(self.waveform_generator.duration,
self.waveform_generator.sampling_frequency,
self.waveform_generator.start_time,
self.waveform_generator.frequency_domain_source_model.__name__,
self.waveform_generator.time_domain_source_model,
self.waveform_generator.parameters,
conversion_func.__name__,
self.waveform_generator.non_standard_sampling_parameter_keys,
self.waveform_generator.waveform_arguments)
self.assertEqual(expected, repr(self.waveform_generator))
def test_duration(self):
self.assertEqual(self.waveform_generator.duration, 1)
......@@ -128,6 +165,16 @@ class TestSetters(unittest.TestCase):
self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())),
sorted(list(self.simulation_parameters.keys())))
def test_set_parameter_conversion_at_init(self):
def conversion_func():
pass
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
frequency_domain_source_model=dummy_func_dict_return_value,
parameter_conversion=conversion_func)
self.assertEqual(conversion_func, self.waveform_generator.parameter_conversion)
class TestFrequencyDomainStrainMethod(unittest.TestCase):
......
......@@ -75,7 +75,7 @@ class WaveformGenerator(object):
else:
fdsm_name = None
if self.time_domain_source_model is not None:
tdsm_name = self.frequency_domain_source_model.__name__
tdsm_name = self.time_domain_source_model.__name__
else:
tdsm_name = None
if self.parameter_conversion.__name__ == '<lambda>':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment