Skip to content
Snippets Groups Projects
Commit dcb96dec authored by Colm Talbot's avatar Colm Talbot
Browse files

MAINT: add compatibility with pymc_v5

parent eabbd678
No related branches found
No related tags found
1 merge request!1191MAINT: add compatibility with pymc_v5
...@@ -124,18 +124,23 @@ class Pymc(MCMCSampler): ...@@ -124,18 +124,23 @@ class Pymc(MCMCSampler):
@staticmethod @staticmethod
def _import_external_sampler(): def _import_external_sampler():
import pymc import pymc
from pymc.aesaraf import floatX from pymc import floatX
from pymc.step_methods import STEP_METHODS from pymc.step_methods import STEP_METHODS
return pymc, STEP_METHODS, floatX return pymc, STEP_METHODS, floatX
@staticmethod @staticmethod
def _import_aesara(): def _import_tensor():
import aesara # noqa try:
import aesara.tensor as tt import pytensor as tensor # noqa
from aesara.compile.ops import as_op # noqa import pytensor.tensor as tt
from pytensor.compile.ops import as_op # noqa
return aesara, tt, as_op except ImportError:
import aesara as tensor # noqa
import aesara.tensor as tt
from aesara.compile.ops import as_op # noqa
return tensor, tt, as_op
def _verify_parameters(self): def _verify_parameters(self):
""" """
...@@ -251,8 +256,8 @@ class Pymc(MCMCSampler): ...@@ -251,8 +256,8 @@ class Pymc(MCMCSampler):
""" """
# check prior is a Sine # check prior is a Sine
pymc, STEP_METHODS, floatX = self._import_external_sampler() pymc, _, floatX = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara() _, tt, _ = self._import_tensor()
if isinstance(self.priors[key], Sine): if isinstance(self.priors[key], Sine):
class PymcSine(pymc.Continuous): class PymcSine(pymc.Continuous):
...@@ -296,8 +301,8 @@ class Pymc(MCMCSampler): ...@@ -296,8 +301,8 @@ class Pymc(MCMCSampler):
""" """
# check prior is a Cosine # check prior is a Cosine
pymc, STEP_METHODS, floatX = self._import_external_sampler() pymc, _, floatX = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara() _, tt, _ = self._import_tensor()
if isinstance(self.priors[key], Cosine): if isinstance(self.priors[key], Cosine):
class PymcCosine(pymc.Continuous): class PymcCosine(pymc.Continuous):
...@@ -340,8 +345,8 @@ class Pymc(MCMCSampler): ...@@ -340,8 +345,8 @@ class Pymc(MCMCSampler):
""" """
# check prior is a PowerLaw # check prior is a PowerLaw
pymc, STEP_METHODS, floatX = self._import_external_sampler() pymc, _, floatX = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara() _, tt, _ = self._import_tensor()
if isinstance(self.priors[key], PowerLaw): if isinstance(self.priors[key], PowerLaw):
# check power law is set # check power law is set
...@@ -405,8 +410,7 @@ class Pymc(MCMCSampler): ...@@ -405,8 +410,7 @@ class Pymc(MCMCSampler):
""" """
# check prior is a PowerLaw # check prior is a PowerLaw
pymc, STEP_METHODS, floatX = self._import_external_sampler() pymc, _, _ = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara()
if isinstance(self.priors[key], MultivariateGaussian): if isinstance(self.priors[key], MultivariateGaussian):
# get names of multivariate Gaussian parameters # get names of multivariate Gaussian parameters
mvpars = self.priors[key].mvg.names mvpars = self.priors[key].mvg.names
...@@ -648,8 +652,14 @@ class Pymc(MCMCSampler): ...@@ -648,8 +652,14 @@ class Pymc(MCMCSampler):
self.kwargs.update(nuts_kwargs) self.kwargs.update(nuts_kwargs)
with self.pymc_model: with self.pymc_model:
# perform the sampling # perform the sampling and then convert to inference data
trace = pymc.sample(**self.kwargs) trace = pymc.sample(**self.kwargs, return_inferencedata=False)
ikwargs = dict(
model=self.pymc_model,
save_warmup=not self.kwargs["discard_tuned_samples"],
log_likelihood=True,
)
trace = pymc.to_inference_data(trace, **ikwargs)
posterior = trace.posterior.to_dataframe().reset_index() posterior = trace.posterior.to_dataframe().reset_index()
self.result.samples = posterior[self.search_parameter_keys] self.result.samples = posterior[self.search_parameter_keys]
...@@ -709,7 +719,7 @@ class Pymc(MCMCSampler): ...@@ -709,7 +719,7 @@ class Pymc(MCMCSampler):
self.setup_prior_mapping() self.setup_prior_mapping()
self.pymc_priors = dict() self.pymc_priors = dict()
pymc, STEP_METHODS, floatX = self._import_external_sampler() pymc, _, _ = self._import_external_sampler()
# initialise a dictionary of multivariate Gaussian parameters # initialise a dictionary of multivariate Gaussian parameters
self.multivariate_normal_sets = {} self.multivariate_normal_sets = {}
...@@ -804,9 +814,9 @@ class Pymc(MCMCSampler): ...@@ -804,9 +814,9 @@ class Pymc(MCMCSampler):
Convert any bilby likelihoods to PyMC distributions. Convert any bilby likelihoods to PyMC distributions.
""" """
# create aesara Op for the log likelihood if not using a predefined model # create Op for the log likelihood if not using a predefined model
pymc, STEP_METHODS, floatX = self._import_external_sampler() pymc, _, _ = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara() _, tt, _ = self._import_tensor()
class LogLike(tt.Op): class LogLike(tt.Op):
...@@ -838,7 +848,7 @@ class Pymc(MCMCSampler): ...@@ -838,7 +848,7 @@ class Pymc(MCMCSampler):
(theta,) = inputs (theta,) = inputs
return [g[0] * self.logpgrad(theta)] return [g[0] * self.logpgrad(theta)]
# create aesara Op for calculating the gradient of the log likelihood # create Op for calculating the gradient of the log likelihood
class LogLikeGrad(tt.Op): class LogLikeGrad(tt.Op):
itypes = [tt.dvector] itypes = [tt.dvector]
...@@ -993,7 +1003,7 @@ class Pymc(MCMCSampler): ...@@ -993,7 +1003,7 @@ class Pymc(MCMCSampler):
f"Unknown key '{key}' when setting GravitationalWaveTransient likelihood" f"Unknown key '{key}' when setting GravitationalWaveTransient likelihood"
) )
# convert to aesara tensor variable # convert to tensor variable
values = tt.as_tensor_variable(list(parameters.values())) values = tt.as_tensor_variable(list(parameters.values()))
pymc.DensityDist( pymc.DensityDist(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment