From 765090051a82f294f99c8b21e2317cbe632b0524 Mon Sep 17 00:00:00 2001 From: Matthew Pitkin <matthew.pitkin@ligo.org> Date: Sat, 4 Aug 2018 23:00:29 +0100 Subject: [PATCH] Move setting of PyMC3 priors: - move the setting of the PyMC3 priors into the Pymc3 sampler class --- tupak/core/likelihood.py | 20 +++++++-------- tupak/core/prior.py | 55 ++++++++++++++++++++++++++-------------- tupak/core/sampler.py | 40 ++++++++++++++++++++++++++++- 3 files changed, 85 insertions(+), 30 deletions(-) diff --git a/tupak/core/likelihood.py b/tupak/core/likelihood.py index c8d57fca3..575e2d2a7 100644 --- a/tupak/core/likelihood.py +++ b/tupak/core/likelihood.py @@ -124,20 +124,20 @@ class GaussianLikelihood(Likelihood): if samplername != 'pymc3': raise ValueError("Only use this class method for PyMC3 sampler") - model_parameters = dict() - for key in sampler.priors: - if key == 'sigma' and self.sigma is None: - self.sigma = sampler.priors[key].pymc3_prior(sampler) + if 'sigma' in sampler.pymc3_priors: + # if sigma is suppled use that value + if self.sigma is None: + self.sigma = sampler.pymc3_priors.pop('sigma') else: - if key in self.function_keys: - model_parameters[key] = sampler.priors[key].pymc3_prior(sampler) - else: - raise ValueError("Prior key '{}' is not a function key!".format(key)) + del sampler.pymc3_priors['sigma'] - model = self.function(self.x, **model_parameters) + for key in sampler.pymc3_priors: + if key not in self.function_keys: + raise ValueError("Prior key '{}' is not a function key!".format(key)) - return sampler.external_sampler.Normal('likelihood', mu=model, sd=self.sigma, observed=self.y) + model = self.function(self.x, **sampler.pymc3_priors) + return sampler.external_sampler.Normal('likelihood', mu=model, sd=self.sigma, observed=self.y) class PoissonLikelihood(Likelihood): def __init__(self, x, counts, func): diff --git a/tupak/core/prior.py b/tupak/core/prior.py index 41dd444dc..d48617970 100644 --- a/tupak/core/prior.py +++ b/tupak/core/prior.py @@ -467,22 +467,29 @@ class Prior(object): label = self.name return label - def set_pymc3_prior(self, sampler, priortype, **kwargs): + def pymc3_prior(self, sampler): + """ + 'pymc3_prior' A user defined PyMC3 prior. + + This should be overwritten by each subclass if needed. + + Parameters + ---------- + val: float + A random number between 0 and 1 + + Returns + ------- + None + + """ + from tupak.core.sampler import Sampler if not isinstance(sampler, Sampler): raise ValueError("'sampler' is not a Sampler class") - try: - samplername = sampler.external_sampler.__name__ - except ValueError: - raise ValueError("Sampler's 'external_sampler' has not been initialised") - - if samplername != 'pymc3': - raise ValueError("Only use this class method for PyMC3 sampler") - - if priortype in sampler.external_sampler.__dict__: - return sampler.external_sampler.__dict__[priortype](self.name, **kwargs) + return None class DeltaFunction(Prior): @@ -644,6 +651,10 @@ class Uniform(Prior): """ Prior.__init__(self, name, latex_label, minimum, maximum) + # set PyMC3 Uniform distribution attributes + self.lower = self.minimum + self.upper = self.maximum + def rescale(self, val): Prior.test_valid_for_rescaling(val) return self.minimum + val * (self.maximum - self.minimum) @@ -676,14 +687,6 @@ class Uniform(Prior): return scipy.stats.uniform.logpdf(val, loc=self.minimum, scale=self.maximum-self.minimum) - def pymc3_prior(self, sampler): - priortype = 'Uniform' - priorargs = {} - priorargs['lower'] = self.minimum - priorargs['upper'] = self.maximum - - return self.set_pymc3_prior(sampler, priortype, **priorargs) - class LogUniform(PowerLaw): @@ -855,6 +858,20 @@ class Gaussian(Prior): return Prior._subclass_repr_helper(self, subclass_args=['mu', 'sigma']) +class Normal(Gaussian): + def __init__(self, mu, sigma, name=None, latex_label=None): + """A copy of the Gaussian prior, but with "Normal" name to copy that + used for the distribution in PyMC3. + + """ + + Gaussian.__init__(self, mu, sigma, name, latex_label) + + # set argument names used in PyMC3 distribution + self.mu = mu + self.sd = sigma + + class TruncatedGaussian(Prior): def __init__(self, mu, sigma, minimum, maximum, name=None, latex_label=None): diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py index 041b2c14f..b3cee6884 100644 --- a/tupak/core/sampler.py +++ b/tupak/core/sampler.py @@ -958,6 +958,8 @@ class Pymc3(Sampler): model = pymc3.Model() with model: + self.set_prior() + likelihood = self.likelihood.pymc3_likelihood(self) # perform the sampling @@ -965,7 +967,7 @@ class Pymc3(Sampler): chains=self.chains, discard_tuned_samples=self.discard_tuned_samples) - nparams = int(len(trace.varnames)/self.chains) + nparams = len(trace.varnames) nsamples = len(trace)*self.chains self.result.samples = np.zeros((nsamples, nparams)) @@ -978,6 +980,42 @@ class Pymc3(Sampler): self.result.log_evidence_err = np.nan return self.result + def set_prior(self): + """ Set the PyMC3 prior distributions. + + """ + + self.pymc3_priors = dict() + + # set the parameter prior distributions + for key in self.priors: + # if the prior contains a pymc3_prior method use that otherwise try + # and find the PyMC3 distribution + if self.priors[key].pymc3_prior(self) is not None: + self.pymc3_priors[key] = self.priors[key].pymc3_prior(self) + else: + # use Prior distribution name + distname = self.priors[key].__class__.__name__ + + # check whether name is a PyMC3 distribution + if distname in self.external_sampler.__dict__: + # check the required arguments for the PyMC3 distribution + reqargs = inspect.getargspec(self.external_sampler.__dict__[distname].__init__).args[1:] + + priorkwargs = dict() + + # check whether the Prior class has required attributes + for arg in reqargs: + if hasattr(self.priors[key], arg): + priorkwargs[arg] = getattr(self.priors[key], arg) + else: + priorkwargs[arg] = None + + # set the prior + self.pymc3_priors[key] = self.external_sampler.__dict__[distname](key, **priorkwargs) + else: + raise ValueError("Prior '{}' is not a PyMC3 distribution.".format(distname)) + def calculate_autocorrelation(self, samples, c=3): """ Uses the `emcee.autocorr` module to estimate the autocorrelation -- GitLab