Skip to content
Snippets Groups Projects
Commit 1065e347 authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'support-pymc5' into 'master'

MAINT: add compatibility with pymc_v5

Closes #658

See merge request !1191
parents eabbd678 dcb96dec
No related branches found
No related tags found
1 merge request!1191MAINT: add compatibility with pymc_v5
Pipeline #486987 failed
......@@ -124,18 +124,23 @@ class Pymc(MCMCSampler):
@staticmethod
def _import_external_sampler():
import pymc
from pymc.aesaraf import floatX
from pymc import floatX
from pymc.step_methods import STEP_METHODS
return pymc, STEP_METHODS, floatX
@staticmethod
def _import_aesara():
import aesara # noqa
import aesara.tensor as tt
from aesara.compile.ops import as_op # noqa
return aesara, tt, as_op
def _import_tensor():
try:
import pytensor as tensor # noqa
import pytensor.tensor as tt
from pytensor.compile.ops import as_op # noqa
except ImportError:
import aesara as tensor # noqa
import aesara.tensor as tt
from aesara.compile.ops import as_op # noqa
return tensor, tt, as_op
def _verify_parameters(self):
"""
......@@ -251,8 +256,8 @@ class Pymc(MCMCSampler):
"""
# check prior is a Sine
pymc, STEP_METHODS, floatX = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara()
pymc, _, floatX = self._import_external_sampler()
_, tt, _ = self._import_tensor()
if isinstance(self.priors[key], Sine):
class PymcSine(pymc.Continuous):
......@@ -296,8 +301,8 @@ class Pymc(MCMCSampler):
"""
# check prior is a Cosine
pymc, STEP_METHODS, floatX = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara()
pymc, _, floatX = self._import_external_sampler()
_, tt, _ = self._import_tensor()
if isinstance(self.priors[key], Cosine):
class PymcCosine(pymc.Continuous):
......@@ -340,8 +345,8 @@ class Pymc(MCMCSampler):
"""
# check prior is a PowerLaw
pymc, STEP_METHODS, floatX = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara()
pymc, _, floatX = self._import_external_sampler()
_, tt, _ = self._import_tensor()
if isinstance(self.priors[key], PowerLaw):
# check power law is set
......@@ -405,8 +410,7 @@ class Pymc(MCMCSampler):
"""
# check prior is a PowerLaw
pymc, STEP_METHODS, floatX = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara()
pymc, _, _ = self._import_external_sampler()
if isinstance(self.priors[key], MultivariateGaussian):
# get names of multivariate Gaussian parameters
mvpars = self.priors[key].mvg.names
......@@ -648,8 +652,14 @@ class Pymc(MCMCSampler):
self.kwargs.update(nuts_kwargs)
with self.pymc_model:
# perform the sampling
trace = pymc.sample(**self.kwargs)
# perform the sampling and then convert to inference data
trace = pymc.sample(**self.kwargs, return_inferencedata=False)
ikwargs = dict(
model=self.pymc_model,
save_warmup=not self.kwargs["discard_tuned_samples"],
log_likelihood=True,
)
trace = pymc.to_inference_data(trace, **ikwargs)
posterior = trace.posterior.to_dataframe().reset_index()
self.result.samples = posterior[self.search_parameter_keys]
......@@ -709,7 +719,7 @@ class Pymc(MCMCSampler):
self.setup_prior_mapping()
self.pymc_priors = dict()
pymc, STEP_METHODS, floatX = self._import_external_sampler()
pymc, _, _ = self._import_external_sampler()
# initialise a dictionary of multivariate Gaussian parameters
self.multivariate_normal_sets = {}
......@@ -804,9 +814,9 @@ class Pymc(MCMCSampler):
Convert any bilby likelihoods to PyMC distributions.
"""
# create aesara Op for the log likelihood if not using a predefined model
pymc, STEP_METHODS, floatX = self._import_external_sampler()
aesara, tt, as_op = self._import_aesara()
# create Op for the log likelihood if not using a predefined model
pymc, _, _ = self._import_external_sampler()
_, tt, _ = self._import_tensor()
class LogLike(tt.Op):
......@@ -838,7 +848,7 @@ class Pymc(MCMCSampler):
(theta,) = inputs
return [g[0] * self.logpgrad(theta)]
# create aesara Op for calculating the gradient of the log likelihood
# create Op for calculating the gradient of the log likelihood
class LogLikeGrad(tt.Op):
itypes = [tt.dvector]
......@@ -993,7 +1003,7 @@ class Pymc(MCMCSampler):
f"Unknown key '{key}' when setting GravitationalWaveTransient likelihood"
)
# convert to aesara tensor variable
# convert to tensor variable
values = tt.as_tensor_variable(list(parameters.values()))
pymc.DensityDist(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment