Skip to content
Snippets Groups Projects
Commit 1f7e12bf authored by moritz's avatar moritz
Browse files

Moritz Huebner: Got rid of value property in parameter class

parent 99609f17
No related branches found
No related tags found
1 merge request!17Unify parameter and prior classes
......@@ -4,7 +4,6 @@ from __future__ import print_function, division
from . import utils
from . import detector
from . import prior
from . import parameter
from . import source
from . import likelihood
from . import waveform_generator
......
from __future__ import division, print_function, absolute_import
import numpy as np
from . import prior
class Parameter(object):
def __init__(self, name, prior=None, latex_label=None):
self.name = name
self.prior = prior
self.latex_label = latex_label
@property
def prior(self):
return self.__prior
@property
def latex_label(self):
return self.__latex_label
@property
def is_fixed(self):
return isinstance(self.__prior, prior.DeltaFunction)
@prior.setter
def prior(self, prior=None):
if prior is None:
self.set_default_prior()
else:
self.__prior = prior
@latex_label.setter
def latex_label(self, latex_label=None):
if latex_label is None:
self.set_default_latex_label()
else:
self.__latex_label = latex_label
def fix(self, value=None):
"""
Specify parameter as fixed, this will not be sampled.
"""
if value is None or np.isnan(value):
raise ValueError("You can't fix the value to be np.nan. You need to assign it a legal value")
self.prior = prior.DeltaFunction(value)
def set_default_prior(self):
if self.name == 'mass_1':
self.__prior = prior.PowerLaw(alpha=0, bounds=(5, 100))
elif self.name == 'mass_2':
self.__prior = prior.PowerLaw(alpha=0, bounds=(5, 100))
elif self.name == 'mchirp':
self.__prior = prior.PowerLaw(alpha=0, bounds=(5, 100))
elif self.name == 'q':
self.__prior = prior.PowerLaw(alpha=0, bounds=(0, 1))
elif self.name == 'a1':
self.__prior = prior.PowerLaw(alpha=0, bounds=(0, 1))
elif self.name == 'a2':
self.__prior = prior.PowerLaw(alpha=0, bounds=(0, 1))
elif self.name == 'tilt1':
self.__prior = prior.Sine()
elif self.name == 'tilt2':
self.__prior = prior.Sine()
elif self.name == 'phi1':
self.__prior = prior.PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
elif self.name == 'phi2':
self.__prior = prior.PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
elif self.name == 'luminosity_distance':
self.__prior = prior.PowerLaw(alpha=2, bounds=(1e2, 5e3))
elif self.name == 'dec':
self.__prior = prior.Cosine()
elif self.name == 'ra':
self.__prior = prior.PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
elif self.name == 'iota':
self.__prior = prior.Sine()
elif self.name == 'psi':
self.__prior = prior.PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
elif self.name == 'phase':
self.__prior = prior.PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
else:
self.__prior = None
def set_default_values(self):
# spins
if self.name == 'a1':
self.__prior = prior.DeltaFunction(0)
elif self.name == 'a2':
self.__prior = prior.DeltaFunction(0)
elif self.name == 'tilt1':
self.__prior = prior.DeltaFunction(0)
elif self.name == 'tilt2':
self.__prior = prior.DeltaFunction(0)
elif self.name == 'phi1':
self.__prior = prior.DeltaFunction(0)
elif self.name == 'phi2':
self.__prior = prior.DeltaFunction(0)
def set_default_latex_label(self):
if self.name == 'mass_1':
self.__latex_label = '$m_1$'
elif self.name == 'mass_2':
self.__latex_label = '$m_2$'
elif self.name == 'mchirp':
self.__latex_label = '$\mathcal{M}$'
elif self.name == 'q':
self.__latex_label = 'q'
elif self.name == 'a1':
self.__latex_label = 'a_1'
elif self.name == 'a2':
self.__latex_label = 'a_2'
elif self.name == 'tilt1':
self.__latex_label = 'tilt_1'
elif self.name == 'tilt2':
self.__latex_label = 'tilt_2'
elif self.name == 'phi1':
self.__latex_label = '\phi_1'
elif self.name == 'phi2':
self.__latex_label = '\phi_2'
elif self.name == 'luminosity_distance':
self.__latex_label = 'd_L'
elif self.name == 'dec':
self.__latex_label = '\mathrm{DEC}'
elif self.name == 'ra':
self.__latex_label = '\mathrm{RA}'
elif self.name == 'iota':
self.__latex_label = '\iota'
elif self.name == 'psi':
self.__latex_label = '\psi'
elif self.name == 'phase':
self.__latex_label = '\phi'
elif self.name == 'tc':
self.__latex_label = 't_c'
else:
self.__latex_label = self.name
@staticmethod
def parse_floats_to_parameters(old_parameters):
parameters = old_parameters.copy()
for key in parameters:
if type(parameters[key]) is not float and type(parameters[key]) is not int \
and type(parameters[key]) is not Parameter:
print("Expected parameter " + str(key) + " to be a float or int but was " + str(type(parameters[key]))
+ " instead. Will not be converted.")
continue
elif type(parameters[key]) is Parameter:
continue
parameters[key] = Parameter(key)
parameters[key].fix(old_parameters[key])
return parameters
@staticmethod
def parse_keys_to_parameters(keys):
parameters = {}
for key in keys:
parameters[key] = Parameter(key)
return parameters
\ No newline at end of file
......@@ -8,8 +8,9 @@ from scipy.integrate import cumtrapz
class Prior(object):
"""Prior class"""
def __init__(self, **kwargs):
return
def __init__(self, name=None, latex_label=None):
self.name = name
self.latex_label = latex_label
def __call__(self):
return self.sample(1)
......@@ -32,12 +33,66 @@ class Prior(object):
['{}={}'.format(k, v) for k, v in self.__dict__.items()])
return "{}({})".format(prior_name, prior_args)
@property
def is_fixed(self):
return isinstance(self, DeltaFunction)
@property
def latex_label(self):
return self.__latex_label
@latex_label.setter
def latex_label(self, latex_label=None):
if latex_label is None:
self.__latex_label = self.__default_latex_label
else:
self.__latex_label = latex_label
@property
def __default_latex_label(self):
if self.name == 'mass_1':
return '$m_1$'
elif self.name == 'mass_2':
return '$m_2$'
elif self.name == 'mchirp':
return '$\mathcal{M}$'
elif self.name == 'q':
return 'q'
elif self.name == 'a1':
return 'a_1'
elif self.name == 'a2':
return 'a_2'
elif self.name == 'tilt1':
return 'tilt_1'
elif self.name == 'tilt2':
return 'tilt_2'
elif self.name == 'phi1':
return '\phi_1'
elif self.name == 'phi2':
return '\phi_2'
elif self.name == 'luminosity_distance':
return 'd_L'
elif self.name == 'dec':
return '\mathrm{DEC}'
elif self.name == 'ra':
return '\mathrm{RA}'
elif self.name == 'iota':
return '\iota'
elif self.name == 'psi':
return '\psi'
elif self.name == 'phase':
return '\phi'
elif self.name == 'tc':
return 't_c'
else:
return self.name
class Uniform(Prior):
"""Uniform prior"""
def __init__(self, lower, upper):
Prior.__init__(self)
def __init__(self, lower, upper, name=None, latex_label=None):
Prior.__init__(self, name, latex_label)
self.lower = lower
self.upper = upper
self.support = upper - lower
......@@ -56,8 +111,8 @@ class Uniform(Prior):
class DeltaFunction(Prior):
"""Dirac delta function prior, this always returns peak."""
def __init__(self, peak):
Prior.__init__(self)
def __init__(self, peak, name=None, latex_label=None):
Prior.__init__(self, name, latex_label)
self.peak = peak
def rescale(self, val):
......@@ -75,9 +130,9 @@ class DeltaFunction(Prior):
class PowerLaw(Prior):
"""Power law prior distribution"""
def __init__(self, alpha, bounds):
def __init__(self, alpha, bounds, name=None, latex_label=None):
"""Power law with bounds and alpha, spectral index"""
Prior.__init__(self)
Prior.__init__(self, name, latex_label)
self.alpha = alpha
self.low, self.high = bounds
......@@ -101,8 +156,8 @@ class PowerLaw(Prior):
class Cosine(Prior):
def __init__(self):
Prior.__init__(self)
def __init__(self, name=None, latex_label=None):
Prior.__init__(self, name, latex_label)
def rescale(self, val):
"""
......@@ -120,8 +175,8 @@ class Cosine(Prior):
class Sine(Prior):
def __init__(self):
Prior.__init__(self)
def __init__(self, name=None, latex_label=None):
Prior.__init__(self, name, latex_label)
def rescale(self, val):
"""
......@@ -139,9 +194,9 @@ class Sine(Prior):
class Interped(Prior):
def __init__(self, xx, yy):
def __init__(self, xx, yy, name=None, latex_label=None):
"""Initialise object from arrays of x and y=p(x)"""
Prior.__init__(self)
Prior.__init__(self, name, latex_label)
self.xx = xx
self.low = min(self.xx)
self.high = max(self.xx)
......@@ -179,3 +234,71 @@ class FromFile(Interped):
print("Can't load {}.".format(file_name))
print("Format should be:")
print(r"x\tp(x)")
def fix(prior, value=None):
if value is None or np.isnan(value):
raise ValueError("You can't fix the value to be np.nan. You need to assign it a legal value")
prior = DeltaFunction(name=prior.name,
latex_label=prior.latex_label,
peak=value)
return prior
def create_default_prior(name):
if name == 'mass_1':
prior = PowerLaw(alpha=0, bounds=(5, 100))
elif name == 'mass_2':
prior = PowerLaw(alpha=0, bounds=(5, 100))
elif name == 'mchirp':
prior = PowerLaw(alpha=0, bounds=(5, 100))
elif name == 'q':
prior = PowerLaw(alpha=0, bounds=(0, 1))
elif name == 'a1':
prior = PowerLaw(alpha=0, bounds=(0, 1))
elif name == 'a2':
prior = PowerLaw(alpha=0, bounds=(0, 1))
elif name == 'tilt1':
prior = Sine()
elif name == 'tilt2':
prior = Sine()
elif name == 'phi1':
prior = PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
elif name == 'phi2':
prior = PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
elif name == 'luminosity_distance':
prior = PowerLaw(alpha=2, bounds=(1e2, 5e3))
elif name == 'dec':
prior = Cosine()
elif name == 'ra':
prior = PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
elif name == 'iota':
prior = Sine()
elif name == 'psi':
prior = PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
elif name == 'phase':
prior = PowerLaw(alpha=0, bounds=(0, 2 * np.pi))
else:
prior = None
return prior
def parse_floats_to_fixed_priors(old_parameters):
parameters = old_parameters.copy()
for key in parameters:
if type(parameters[key]) is not float and type(parameters[key]) is not int \
and type(parameters[key]) is not Prior:
print("Expected parameter " + str(key) + " to be a float or int but was " + str(type(parameters[key]))
+ " instead. Will not be converted.")
continue
elif type(parameters[key]) is Prior:
continue
parameters[key] = DeltaFunction(name=key, latex_label=None, peak=old_parameters[key])
return parameters
def parse_keys_to_parameters(keys):
parameters = {}
for key in keys:
parameters[key] = create_default_prior(key)
return parameters
......@@ -8,7 +8,7 @@ import sys
import numpy as np
from .result import Result
from .parameter import Parameter
from peyote import prior
class Sampler(object):
......@@ -98,18 +98,17 @@ class Sampler(object):
def initialise_parameters(self):
for key in self.priors:
if isinstance(self.priors[key], Parameter) \
and self.priors[key].prior is not None \
if isinstance(self.priors[key], prior.Prior) is True \
and self.priors[key].is_fixed is False:
self.__search_parameter_keys.append(key)
elif isinstance(self.priors[key], Parameter) \
elif isinstance(self.priors[key], prior.Prior) \
and self.priors[key].is_fixed is True:
self.likelihood.waveform_generator.parameters[key] = \
self.priors[key].prior.sample()
logging.info("Search parameters:")
for key in self.__search_parameter_keys:
logging.info(' {} ~ {}'.format(key, self.priors[key].prior))
logging.info(' {} ~ {}'.format(key, self.priors[key]))
def verify_parameters(self):
required_keys = self.priors
......@@ -119,7 +118,7 @@ class Sampler(object):
"Source model does not contain keys {}".format(unmatched_keys))
def prior_transform(self, theta):
return [self.priors[key].prior.rescale(t) for key, t in zip(self.__search_parameter_keys, theta)]
return [self.priors[key].rescale(t) for key, t in zip(self.__search_parameter_keys, theta)]
def log_likelihood(self, theta):
for i, k in enumerate(self.__search_parameter_keys):
......
import inspect
from . import utils
from . import parameter
class WaveformGenerator(object):
""" A waveform generator
......
......@@ -7,7 +7,7 @@ class TestParameterInstantiationWithoutOptionalParameters(unittest.TestCase):
def setUp(self):
self.test_name = 'test_name'
self.parameter = peyote.parameter.Parameter(self.test_name)
self.parameter = peyote.parameter.PriorFactory(self.test_name)
def tearDown(self):
del self.parameter
......@@ -32,7 +32,7 @@ class TestParameterName(unittest.TestCase):
def setUp(self):
self.test_name = 'test_name'
self.parameter = peyote.parameter.Parameter(self.test_name)
self.parameter = peyote.parameter.PriorFactory(self.test_name)
def tearDown(self):
del self.parameter
......@@ -46,7 +46,7 @@ class TestParameterPrior(unittest.TestCase):
def setUp(self):
self.test_name = 'test_name'
self.parameter = peyote.parameter.Parameter(self.test_name)
self.parameter = peyote.parameter.PriorFactory(self.test_name)
def tearDown(self):
del self.parameter
......@@ -66,7 +66,7 @@ class TestParameterPrior(unittest.TestCase):
class TestParameterValue(unittest.TestCase):
def setUp(self):
self.test_name = 'test_name'
self.parameter = peyote.parameter.Parameter(self.test_name)
self.parameter = peyote.parameter.PriorFactory(self.test_name)
def tearDown(self):
del self.parameter
......@@ -89,7 +89,7 @@ class TestParameterValue(unittest.TestCase):
class TestParameterLatexLabel(unittest.TestCase):
def setUp(self):
self.test_name = 'test_name'
self.parameter = peyote.parameter.Parameter(self.test_name)
self.parameter = peyote.parameter.PriorFactory(self.test_name)
def tearDown(self):
del self.parameter
......@@ -111,7 +111,7 @@ class TestParameterLatexLabel(unittest.TestCase):
class TestParameterIsFixed(unittest.TestCase):
def setUp(self):
self.test_name = 'test_name'
self.parameter = peyote.parameter.Parameter(self.test_name)
self.parameter = peyote.parameter.PriorFactory(self.test_name)
def tearDown(self):
del self.parameter
......@@ -128,7 +128,7 @@ class TestFixMethod(unittest.TestCase):
def setUp(self):
self.test_name = 'test_name'
self.parameter = peyote.parameter.Parameter(self.test_name)
self.parameter = peyote.parameter.PriorFactory(self.test_name)
def tearDown(self):
del self.parameter
......
......@@ -50,8 +50,8 @@ class Test(unittest.TestCase):
[self.msd['IFO']], self.msd['waveform_generator'])
dL = self.msd['simulation_parameters']['luminosity_distance']
priors = {'luminosity_distance' : peyote.parameter.Parameter('luminosity_distance',
prior=peyote.prior.Uniform(lower=dL - 10, upper=dL + 10))
priors = {'luminosity_distance' : peyote.parameter.PriorFactory('luminosity_distance',
prior=peyote.prior.Uniform(lower=dL - 10, upper=dL + 10))
}
result = peyote.sampler.run_sampler(likelihood, priors, sampler='nestle',
......
......@@ -3,7 +3,7 @@ import pylab as plt
import dynesty.plotting as dyplot
import corner
import peyote
import peyote.prior
peyote.utils.setup_logger()
......@@ -38,11 +38,9 @@ waveform_generator = peyote.waveform_generator.WaveformGenerator(
parameters=injection_parameters)
hf_signal = waveform_generator.frequency_domain_strain()
sampling_parameters = peyote.parameter.Parameter.\
parse_floats_to_parameters(injection_parameters)
#sampling_parameters = peyote.prior.parse_floats_to_fixed_priors(injection_parameters)
# sampling_parameters = peyote.parameter.Parameter.\
# parse_keys_to_parameters(simulation_parameters.keys())
sampling_parameters = peyote.prior.parse_keys_to_parameters(injection_parameters.keys())
# Simulate the data in H1
......@@ -81,8 +79,8 @@ fig.savefig('data')
likelihood = peyote.likelihood.Likelihood(IFOs, waveform_generator)
# New way way of doing it, still not perfect
sampling_parameters['mass_1'].prior = peyote.prior.Uniform(lower=35, upper=37)
sampling_parameters['luminosity_distance'].prior = peyote.prior.Uniform(lower=30, upper=200)
sampling_parameters['mass_1'] = peyote.prior.Uniform(lower=35, upper=37, name='mass1')
sampling_parameters['luminosity_distance'] = peyote.prior.Uniform(lower=30, upper=200, name='luminosity_distance')
#sampling_parameters["geocent_time"].prior = peyote.prior.Uniform(lower=injection_parameters["geocent_time"] - 0.1,
# upper=injection_parameters["geocent_time"]+0.1)
......
......@@ -108,14 +108,14 @@ prior = dict(spin11=0, spin12=0, spin13=0, spin21=0, spin22=0, spin23=0,
waveform_approximant='IMRPhenomPv2', reference_frequency=50.,
ra=1.375, dec=-1.2108, geocent_time=time_of_event, psi=2.659,
mass_1=36, mass_2=29)
prior = peyote.parameter.Parameter.parse_floats_to_parameters(prior)
prior['mass_1'] = peyote.parameter.Parameter(
prior = peyote.parameter.PriorFactory.parse_floats_to_parameters(prior)
prior['mass_1'] = peyote.parameter.PriorFactory(
'mass_1', prior=peyote.prior.Uniform(lower=35, upper=41),
latex_label='$m_1$')
prior['mass_2'] = peyote.parameter.Parameter(
prior['mass_2'] = peyote.parameter.PriorFactory(
'mass_2', prior=peyote.prior.Uniform(lower=20, upper=35),
latex_label='$m_2$')
prior['geocent_time'] = peyote.parameter.Parameter(
prior['geocent_time'] = peyote.parameter.PriorFactory(
'mass_2', prior=peyote.prior.Uniform(
lower=time_of_event-0.1, upper=time_of_event+0.1))
......
......@@ -32,10 +32,10 @@ print(likelihood.log_likelihood())
print(likelihood.log_likelihood_ratio())
prior = source.copy()
prior.mass_1 = peyote.parameter.Parameter('mass_1', prior=peyote.prior.Uniform(lower=35, upper=37),
latex_label='$m_1$')
prior.mass_2 = peyote.parameter.Parameter('mass_2', prior=peyote.prior.Uniform(lower=28, upper=30),
latex_label='$m_2$')
prior.mass_1 = peyote.parameter.PriorFactory('mass_1', prior=peyote.prior.Uniform(lower=35, upper=37),
latex_label='$m_1$')
prior.mass_2 = peyote.parameter.PriorFactory('mass_2', prior=peyote.prior.Uniform(lower=28, upper=30),
latex_label='$m_2$')
# result = peyote.sampler.run_sampler(likelihood, prior, sampler='dynesty', npoints=100, print_progress=True)
......
......@@ -71,7 +71,7 @@ simulation_parameters = dict(amplitude=1e-21,
dec=-1.2108,
geocent_time=1126259642.413,
psi=2.659)
sampling_parameters = peyote.parameter.Parameter.parse_floats_to_parameters(simulation_parameters)
sampling_parameters = peyote.parameter.PriorFactory.parse_floats_to_parameters(simulation_parameters)
wg = peyote.waveform_generator.WaveformGenerator(
source_model=gaussian_frequency_domain_strain,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment