Skip to content
Snippets Groups Projects
Commit 1abfa584 authored by Matthew Pitkin's avatar Matthew Pitkin
Browse files

Move setting of PyMC3 prior into Prior class

parent 7aaab5fe
No related branches found
No related tags found
1 merge request!139Add PyMC3 sampler
......@@ -127,10 +127,10 @@ class GaussianLikelihood(Likelihood):
model_parameters = dict()
for key in sampler.priors:
if key == 'sigma' and self.sigma is None:
self.sigma = sampler.priors[key].pymc3(sampler)
self.sigma = sampler.priors[key].pymc3_prior(sampler)
else:
if key in self.function_keys:
model_parameters[key] = sampler.priors[key].pymc3(sampler)
model_parameters[key] = sampler.priors[key].pymc3_prior(sampler)
else:
raise ValueError("Prior key '{}' is not a function key!".format(key))
......
......@@ -472,6 +472,23 @@ class Prior(object):
label = self.name
return label
def set_pymc3_prior(self, sampler, priortype, **kwargs):
from tupak.core.sampler import Sampler
if not isinstance(sampler, Sampler):
raise ValueError("'sampler' is not a Sampler class")
try:
samplername = sampler.external_sampler.__name__
except ValueError:
raise ValueError("Sampler's 'external_sampler' has not been initialised")
if samplername != 'pymc3':
raise ValueError("Only use this class method for PyMC3 sampler")
if priortype in sampler.external_sampler.__dict__:
return sampler.external_sampler.__dict__[priortype](self.name, **kwargs)
class DeltaFunction(Prior):
......@@ -524,7 +541,7 @@ class DeltaFunction(Prior):
else:
return 0
def pymc3(self, sampler):
def pymc3_prior(self, sampler):
# just return the value
return self.peak
......@@ -666,22 +683,13 @@ class Uniform(Prior):
return scipy.stats.uniform.logpdf(val, loc=self.minimum,
scale=self.maximum-self.minimum)
def pymc3(self, sampler):
from tupak.core.sampler import Sampler
def pymc3_prior(self, sampler):
priortype = 'Uniform'
priorargs = {}
priorargs['lower'] = self.minimum
priorargs['upper'] = self.maximum
if not isinstance(sampler, Sampler):
raise ValueError("'sampler' is not a Sampler class")
try:
samplername = sampler.external_sampler.__name__
except ValueError:
raise ValueError("Sampler's 'external_sampler' has not been initialised")
if samplername != 'pymc3':
raise ValueError("Only use this class method for PyMC3 sampler")
return sampler.external_sampler.Uniform(self.name, lower=self.minimum,
upper=self.maximum)
return self.set_pymc3_prior(sampler, priortype, **priorargs)
class LogUniform(PowerLaw):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment