Skip to content
Snippets Groups Projects
Commit a52d6223 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch '160-add-more-__repr__-methods' into 'master'

Resolve "Add more __repr__ methods"

Closes #160

See merge request Monash/tupak!170
parents 37acb467 dd021cde
No related branches found
No related tags found
1 merge request!170Resolve "Add more __repr__ methods"
Pipeline #30266 passed
......@@ -11,6 +11,11 @@ class TestBaseClass(unittest.TestCase):
def tearDown(self):
del self.model
def test_repr(self):
expected = 'Recalibrate(prefix={})'.format('\'recalib_\'')
actual = repr(self.model)
self.assertEqual(expected, actual)
def test_calibration_factor(self):
frequency_array = np.linspace(20, 1024, 1000)
cal_factor = self.model.get_calibration_factor(frequency_array)
......@@ -20,14 +25,22 @@ class TestBaseClass(unittest.TestCase):
class TestCubicSpline(unittest.TestCase):
def setUp(self):
self.prefix = 'recalib_'
self.minimum_frequency = 20
self.maximum_frequency = 1024
self.n_points = 5
self.model = calibration.CubicSpline(
prefix='recalib_', minimum_frequency=20, maximum_frequency=1024,
n_points=5)
prefix=self.prefix, minimum_frequency=self.minimum_frequency,
maximum_frequency=self.maximum_frequency, n_points=self.n_points)
self.parameters = {'recalib_{}_{}'.format(param, ii): 0.0
for ii in range(5)
for param in ['amplitude', 'phase']}
def tearDown(self):
del self.prefix
del self.minimum_frequency
del self.maximum_frequency
del self.n_points
del self.model
del self.parameters
......@@ -37,6 +50,12 @@ class TestCubicSpline(unittest.TestCase):
**self.parameters)
assert np.alltrue(cal_factor.real == np.ones_like(frequency_array))
def test_repr(self):
expected = 'CubicSpline(prefix=\'{}\', minimum_frequency={}, maximum_frequency={}, n_points={})'\
.format(self.prefix, self.minimum_frequency, self.maximum_frequency, self.n_points)
actual = repr(self.model)
self.assertEqual(expected, actual)
class TestCubicSplineRequiresFourNodes(unittest.TestCase):
......
......@@ -304,6 +304,17 @@ class TestDetector(unittest.TestCase):
self.assertTrue(np.array_equal(expected[0], actual[0])) # array-like element has to be evaluated separately
self.assertListEqual(expected[1], actual[1])
def test_repr(self):
expected = 'Interferometer(name=\'{}\', power_spectral_density={}, minimum_frequency={}, ' \
'maximum_frequency={}, length={}, latitude={}, longitude={}, elevation={}, xarm_azimuth={}, ' \
'yarm_azimuth={}, xarm_tilt={}, yarm_tilt={})' \
.format(self.name, self.power_spectral_density, float(self.minimum_frequency),
float(self.maximum_frequency), float(self.length), float(self.latitude), float(self.longitude),
float(self.elevation), float(self.xarm_azimuth), float(self.yarm_azimuth), float(self.xarm_tilt),
float(self.yarm_tilt))
print(repr(self.ifo))
self.assertEqual(expected, repr(self.ifo))
class TestInterferometerStrainData(unittest.TestCase):
......@@ -536,10 +547,10 @@ class TestInterferometerStrainData(unittest.TestCase):
def test_frequency_domain_strain_when_set(self):
self.ifosd.sampling_frequency = 200
self.ifosd.duration = 4
expected_strain = self.ifosd.frequency_array*self.ifosd.frequency_mask
expected_strain = self.ifosd.frequency_array * self.ifosd.frequency_mask
self.ifosd._frequency_domain_strain = expected_strain
self.assertTrue(np.array_equal(expected_strain,
self.ifosd.frequency_domain_strain))
self.ifosd.frequency_domain_strain))
@patch('tupak.core.utils.nfft')
def test_frequency_domain_strain_from_frequency_domain_strain(self, m):
......
......@@ -61,6 +61,11 @@ class TestBasicGWTransient(unittest.TestCase):
np.nan_to_num(-np.inf))
self.likelihood.waveform_generator.parameters['mass_2'] = 29
def test_repr(self):
expected = 'BasicGravitationalWaveTransient(interferometers={},\n\twaveform_generator={})'.format(
self.interferometers, self.waveform_generator)
self.assertEqual(expected, repr(self.likelihood))
class TestGWTransient(unittest.TestCase):
......@@ -133,6 +138,12 @@ class TestGWTransient(unittest.TestCase):
np.nan_to_num(-np.inf))
self.likelihood.waveform_generator.parameters['mass_2'] = 29
def test_repr(self):
expected = 'GravitationalWaveTransient(interferometers={},\n\twaveform_generator={},\n\t' \
'time_marginalization={}, distance_marginalization={}, phase_marginalization={}, ' \
'prior={})'.format(self.interferometers, self.waveform_generator, False, False, False, self.prior)
self.assertEqual(expected, repr(self.likelihood))
class TestTimeMarginalization(unittest.TestCase):
......
......@@ -19,6 +19,11 @@ class TestLikelihoodBase(unittest.TestCase):
def tearDown(self):
del self.likelihood
def test_repr(self):
self.likelihood = tupak.core.likelihood.Likelihood(parameters=['a', 'b'])
expected = 'Likelihood(parameters=[\'a\', \'b\'])'
self.assertEqual(expected, repr(self.likelihood))
def test_base_log_likelihood(self):
self.assertTrue(np.isnan(self.likelihood.log_likelihood()))
......@@ -125,6 +130,10 @@ class TestAnalytical1DLikelihood(unittest.TestCase):
parameter2=self.parameter2_value)
self.assertDictEqual(expected_model_parameters, self.analytical_1d_likelihood.model_parameters)
def test_repr(self):
expected = 'Analytical1DLikelihood(x={}, y={}, func={})'.format(self.x, self.y, self.func.__name__)
self.assertEqual(expected, repr(self.analytical_1d_likelihood))
class TestGaussianLikelihood(unittest.TestCase):
......@@ -182,6 +191,13 @@ class TestGaussianLikelihood(unittest.TestCase):
likelihood.log_likelihood()
self.assertTrue(likelihood.sigma is None)
def test_repr(self):
likelihood = tupak.core.likelihood.GaussianLikelihood(
self.x, self.y, self.function, sigma=self.sigma)
expected = 'GaussianLikelihood(x={}, y={}, func={}, sigma={})' \
.format(self.x, self.y, self.function.__name__, self.sigma)
self.assertEqual(expected, repr(likelihood))
class TestStudentTLikelihood(unittest.TestCase):
......@@ -258,6 +274,15 @@ class TestStudentTLikelihood(unittest.TestCase):
self.assertAlmostEqual(4.0, likelihood.lam)
def test_repr(self):
nu = 0
sigma = 0.5
likelihood = tupak.core.likelihood.StudentTLikelihood(
self.x, self.y, self.function, nu=nu, sigma=sigma)
expected = 'StudentTLikelihood(x={}, y={}, func={}, nu={}, sigma={})' \
.format(self.x, self.y, self.function.__name__, nu, sigma)
self.assertEqual(expected, repr(likelihood))
class TestPoissonLikelihood(unittest.TestCase):
......@@ -357,6 +382,12 @@ class TestPoissonLikelihood(unittest.TestCase):
m.return_value = 1
self.assertEqual(0, poisson_likelihood.log_likelihood())
def test_repr(self):
likelihood = tupak.core.likelihood.PoissonLikelihood(
self.x, self.y, self.function)
expected = 'PoissonLikelihood(x={}, y={}, func={})'.format(self.x, self.y, self.function.__name__)
self.assertEqual(expected, repr(likelihood))
class TestExponentialLikelihood(unittest.TestCase):
......
......@@ -32,6 +32,21 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC
del self.waveform_generator
del self.simulation_parameters
def test_repr(self):
expected = 'WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, ' \
'frequency_domain_source_model={}, time_domain_source_model={}, parameters={}, ' \
'parameter_conversion={}, non_standard_sampling_parameter_keys={}, waveform_arguments={})'\
.format(self.waveform_generator.duration,
self.waveform_generator.sampling_frequency,
self.waveform_generator.start_time,
self.waveform_generator.frequency_domain_source_model.__name__,
self.waveform_generator.time_domain_source_model,
self.waveform_generator.parameters,
None,
self.waveform_generator.non_standard_sampling_parameter_keys,
self.waveform_generator.waveform_arguments)
self.assertEqual(expected, repr(self.waveform_generator))
def test_duration(self):
self.assertEqual(self.waveform_generator.duration, 1)
......
......@@ -17,6 +17,9 @@ class Likelihood(object):
"""
self.parameters = parameters
def __repr__(self):
return self.__class__.__name__ + '(parameters={})'.format(self.parameters)
def log_likelihood(self):
"""
......@@ -69,6 +72,9 @@ class Analytical1DLikelihood(Likelihood):
self.__func = func
self.__function_keys = list(self.parameters.keys())
def __repr__(self):
return self.__class__.__name__ + '(x={}, y={}, func={})'.format(self.x, self.y, self.func.__name__)
@property
def func(self):
""" Make func read-only """
......@@ -147,6 +153,10 @@ class GaussianLikelihood(Analytical1DLikelihood):
if self.sigma is None:
self.parameters['sigma'] = None
def __repr__(self):
return self.__class__.__name__ + '(x={}, y={}, func={}, sigma={})'\
.format(self.x, self.y, self.func.__name__, self.sigma)
def log_likelihood(self):
return self.__summed_log_likelihood(sigma=self.__get_sigma())
......@@ -189,6 +199,9 @@ class PoissonLikelihood(Analytical1DLikelihood):
Analytical1DLikelihood.__init__(self, x=x, y=y, func=func)
def __repr__(self):
return Analytical1DLikelihood.__repr__(self)
@property
def y(self):
""" Property assures that y-value is a positive integer. """
......@@ -236,6 +249,9 @@ class ExponentialLikelihood(Analytical1DLikelihood):
"""
Analytical1DLikelihood.__init__(self, x=x, y=y, func=func)
def __repr__(self):
return Analytical1DLikelihood.__repr__(self)
@property
def y(self):
""" Property assures that y-value is positive. """
......@@ -295,6 +311,10 @@ class StudentTLikelihood(Analytical1DLikelihood):
if self.nu is None:
self.parameters['nu'] = None
def __repr__(self):
return self.__class__.__name__ + '(x={}, y={}, func={}, nu={}, sigma={})'\
.format(self.x, self.y, self.func.__name__, self.nu, self.sigma)
@property
def lam(self):
""" Converts 'scale' to 'precision' """
......
......@@ -21,6 +21,9 @@ class Recalibrate(object):
self.params = dict()
self.prefix = prefix
def __repr__(self):
return self.__class__.__name__ + '(prefix=\'{}\')'.format(self.prefix)
def get_calibration_factor(self, frequency_array, **params):
"""Apply calibration model
......@@ -75,7 +78,17 @@ class CubicSpline(Recalibrate):
if n_points < 4:
raise ValueError('Cubic spline calibration requires at least 4 spline nodes.')
self.n_points = n_points
self.spline_points = np.logspace(np.log10(minimum_frequency), np.log10(maximum_frequency), n_points)
self.minimum_frequency = minimum_frequency
self.maximum_frequency = maximum_frequency
self.__spline_points = np.logspace(np.log10(minimum_frequency), np.log10(maximum_frequency), n_points)
@property
def spline_points(self):
return self.__spline_points
def __repr__(self):
return self.__class__.__name__ + '(prefix=\'{}\', minimum_frequency={}, maximum_frequency={}, n_points={})'\
.format(self.prefix, self.minimum_frequency, self.maximum_frequency, self.n_points)
def get_calibration_factor(self, frequency_array, **params):
"""Apply calibration model
......
......@@ -808,6 +808,16 @@ class Interferometer(object):
minimum_frequency=minimum_frequency,
maximum_frequency=maximum_frequency)
def __repr__(self):
return self.__class__.__name__ + '(name=\'{}\', power_spectral_density={}, minimum_frequency={}, ' \
'maximum_frequency={}, length={}, latitude={}, longitude={}, elevation={}, ' \
'xarm_azimuth={}, yarm_azimuth={}, xarm_tilt={}, yarm_tilt={})' \
.format(self.name, self.power_spectral_density, float(self.minimum_frequency),
float(self.maximum_frequency), float(self.length), float(self.latitude), float(self.longitude),
float(self.elevation), float(self.xarm_azimuth), float(self.yarm_azimuth), float(self.xarm_tilt),
float(self.yarm_tilt))
@property
def minimum_frequency(self):
return self.strain_data.minimum_frequency
......
......@@ -82,6 +82,12 @@ class GravitationalWaveTransient(likelihood.Likelihood):
self._setup_distance_marginalization()
prior['luminosity_distance'] = float(self._ref_dist)
def __repr__(self):
return self.__class__.__name__ + '(interferometers={},\n\twaveform_generator={},\n\ttime_marginalization={}, ' \
'distance_marginalization={}, phase_marginalization={}, prior={})'\
.format(self.interferometers, self.waveform_generator, self.time_marginalization,
self.distance_marginalization, self.phase_marginalization, self.prior)
def _check_set_duration_and_sampling_frequency_of_waveform_generator(self):
""" Check the waveform_generator has the same duration and
sampling_frequency as the interferometers. If they are unset, then
......@@ -307,6 +313,11 @@ class BasicGravitationalWaveTransient(likelihood.Likelihood):
self.interferometers = interferometers
self.waveform_generator = waveform_generator
def __repr__(self):
return self.__class__.__name__ + '(interferometers={},\n\twaveform_generator={})'\
.format(self.interferometers, self.waveform_generator)
def noise_log_likelihood(self):
""" Calculates the real part of noise log-likelihood
......
......@@ -6,7 +6,7 @@ class WaveformGenerator(object):
def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None,
time_domain_source_model=None, parameters=None,
parameter_conversion=lambda parameters, search_keys: (parameters, []),
parameter_conversion=None,
non_standard_sampling_parameter_keys=None,
waveform_arguments=None):
""" A waveform generator
......@@ -52,7 +52,10 @@ class WaveformGenerator(object):
self.__parameters_from_source_model()
self.duration = duration
self.sampling_frequency = sampling_frequency
self.parameter_conversion = parameter_conversion
if parameter_conversion is None:
self.parameter_conversion = lambda params, search_keys: (params, [])
else:
self.parameter_conversion = parameter_conversion
self.non_standard_sampling_parameter_keys = non_standard_sampling_parameter_keys
self.parameters = parameters
if waveform_arguments is not None:
......@@ -66,6 +69,27 @@ class WaveformGenerator(object):
self.__full_source_model_keyword_arguments.update(self.parameters)
self.__added_keys = []
def __repr__(self):
if self.frequency_domain_source_model is not None:
fdsm_name = self.frequency_domain_source_model.__name__
else:
fdsm_name = None
if self.time_domain_source_model is not None:
tdsm_name = self.frequency_domain_source_model.__name__
else:
tdsm_name = None
if self.parameter_conversion.__name__ == '<lambda>':
param_conv_name = None
else:
param_conv_name = self.parameter_conversion.__name__
return self.__class__.__name__ + '(duration={}, sampling_frequency={}, start_time={}, ' \
'frequency_domain_source_model={}, time_domain_source_model={}, ' \
'parameters={}, parameter_conversion={}, ' \
'non_standard_sampling_parameter_keys={}, waveform_arguments={})'\
.format(self.duration, self.sampling_frequency, self.start_time, fdsm_name, tdsm_name, self.parameters,
param_conv_name, self.non_standard_sampling_parameter_keys, self.waveform_arguments)
def frequency_domain_strain(self):
""" Rapper to source_model.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment