From a7654deb7b973b69b68245524faa051fc256924b Mon Sep 17 00:00:00 2001 From: Matthew Pitkin <matthew.pitkin@ligo.org> Date: Mon, 13 Aug 2018 11:24:39 +0100 Subject: [PATCH] Allow PyMC3 prior to be set through the Prior ln_prob method - added this to PyMC3 example --- ...near_regression_pymc3_custom_likelihood.py | 28 +++++++++- tupak/core/prior.py | 24 --------- tupak/core/sampler.py | 53 +++++++++++++++++-- 3 files changed, 76 insertions(+), 29 deletions(-) diff --git a/examples/other_examples/linear_regression_pymc3_custom_likelihood.py b/examples/other_examples/linear_regression_pymc3_custom_likelihood.py index 9535d58d9..3b3ab8af1 100644 --- a/examples/other_examples/linear_regression_pymc3_custom_likelihood.py +++ b/examples/other_examples/linear_regression_pymc3_custom_likelihood.py @@ -109,11 +109,37 @@ class GaussianLikelihoodPyMC3(tupak.Likelihood): # the time, data and signal model likelihood = GaussianLikelihoodPyMC3(time, data, sigma, model) + +# Define a custom prior for one of the parameter for use with PyMC3 +class PriorPyMC3(tupak.core.prior.Prior): + def __init__(self, minimum, maximum, name=None, latex_label=None): + """ + Uniform prior with bounds (should be equivalent to tupak.prior.Uniform) + """ + + tupak.core.prior.Prior.__init__(self, name, latex_label, + minimum=minimum, + maximum=maximum) + + def ln_prob(self, sampler=None): + """ + Change ln_prob method to take in a Sampler and return a PyMC3 + distribution. + """ + + from tupak.core.sampler import Pymc3 + + if not isinstance(sampler, Pymc3): + raise ValueError("Sampler is not a tupak Pymc3 sampler object") + + return pm.Uniform(self.name, lower=self.minimum, + upper=self.maximum) + # From hereon, the syntax is exactly equivalent to other tupak examples # We make a prior priors = {} priors['m'] = tupak.core.prior.Uniform(0, 5, 'm') -priors['c'] = tupak.core.prior.Uniform(-2, 2, 'c') +priors['c'] = PriorPyMC3(-2, 2, 'c') # And run sampler result = tupak.run_sampler( diff --git a/tupak/core/prior.py b/tupak/core/prior.py index 31f04c93f..0a4151601 100644 --- a/tupak/core/prior.py +++ b/tupak/core/prior.py @@ -467,30 +467,6 @@ class Prior(object): label = self.name return label - def pymc3_prior(self, sampler): - """ - 'pymc3_prior' A user defined PyMC3 prior. - - This should be overwritten by each subclass if needed. - - Parameters - ---------- - sampler: `tupak.core.sampler.Sampler` - A Sampler class - - Returns - ------- - None - - """ - - from tupak.core.sampler import Sampler - - if not isinstance(sampler, Sampler): - raise ValueError("'sampler' is not a Sampler class") - - return None - class DeltaFunction(Prior): diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py index 045d1f667..eca15e6c8 100644 --- a/tupak/core/sampler.py +++ b/tupak/core/sampler.py @@ -957,6 +957,46 @@ class Pymc3(Sampler): """ pass + def _initialise_parameters(self): + """ + Change `_initialise_parameters()`, so that it does call the `sample` + method in the Prior class. + + """ + + self.__search_parameter_keys = [] + self.__fixed_parameter_keys = [] + + for key in self.priors: + if isinstance(self.priors[key], Prior) \ + and self.priors[key].is_fixed is False: + self.__search_parameter_keys.append(key) + elif isinstance(self.priors[key], Prior) \ + and self.priors[key].is_fixed is True: + self.__fixed_parameter_keys.append(key) + + logger.info("Search parameters:") + for key in self.__search_parameter_keys: + logger.info(' {} = {}'.format(key, self.priors[key])) + for key in self.__fixed_parameter_keys: + logger.info(' {} = {}'.format(key, self.priors[key].peak)) + + def _initialise_result(self): + """ + Initialise results within Pymc3 subclass. + """ + result = Result() + result.sampler = self.__class__.__name__.lower() + result.search_parameter_keys = self.__search_parameter_keys + result.fixed_parameter_keys = self.__fixed_parameter_keys + result.parameter_labels = [ + self.priors[k].latex_label for k in + self.__search_parameter_keys] + result.label = self.label + result.outdir = self.outdir + result.kwargs = self.kwargs + return result + @property def kwargs(self): """ Ensures that proper keyword arguments are used for the Pymc3 sampler. @@ -1253,10 +1293,15 @@ class Pymc3(Sampler): # set the parameter prior distributions (in the model context manager) with self.pymc3_model: for key in self.priors: - # if the prior contains a pymc3_prior method use that otherwise try - # and find the PyMC3 distribution - if self.priors[key].pymc3_prior(self) is not None: - self.pymc3_priors[key] = self.priors[key].pymc3_prior(self) + # if the prior contains ln_prob method that takes a 'sampler' argument + # then try using that + lnprobargs = inspect.getargspec(self.priors[key].ln_prob).args + if 'sampler' in lnprobargs: + try: + self.pymc3_priors[key] = self.priors[key].ln_prob(sampler=self) + except RuntimeError: + raise RuntimeError(("Problem setting PyMC3 prior for ", + "'{}'".format(key))) else: # use Prior distribution name distname = self.priors[key].__class__.__name__ -- GitLab