From 44267c50c138e61f80e8f1670effef5771b52f17 Mon Sep 17 00:00:00 2001
From: Matthew Pitkin <matthew.pitkin@ligo.org>
Date: Fri, 9 Sep 2022 15:04:56 +0000
Subject: [PATCH] Switch PyMC3 to PyMC (which is the new name starting from
 PyMC v4.0.0)

---
 bilby/core/sampler/__init__.py                |   4 +-
 bilby/core/sampler/{pymc3.py => pymc.py}      | 276 +++++++++---------
 docs/samplers.txt                             |   2 +-
 ...ion_pymc3.py => linear_regression_pymc.py} |   4 +-
 ...near_regression_pymc_custom_likelihood.py} |  45 ++-
 sampler_requirements.txt                      |   2 +-
 .../sampler/{pymc3_test.py => pymc_test.py}   |  14 +-
 test/integration/sampler_run_test.py          |   4 +-
 8 files changed, 169 insertions(+), 182 deletions(-)
 rename bilby/core/sampler/{pymc3.py => pymc.py} (82%)
 rename examples/core_examples/alternative_samplers/{linear_regression_pymc3.py => linear_regression_pymc.py} (97%)
 rename examples/core_examples/alternative_samplers/{linear_regression_pymc3_custom_likelihood.py => linear_regression_pymc_custom_likelihood.py} (77%)
 rename test/core/sampler/{pymc3_test.py => pymc_test.py} (89%)

diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index 56c32ed3d..ce9422f85 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -21,7 +21,7 @@ from .nestle import Nestle
 from .polychord import PyPolyChord
 from .ptemcee import Ptemcee
 from .ptmcmc import PTMCMCSampler
-from .pymc3 import Pymc3
+from .pymc import Pymc
 from .pymultinest import Pymultinest
 from .ultranest import Ultranest
 from .zeus import Zeus
@@ -38,7 +38,7 @@ IMPLEMENTED_SAMPLERS = {
     "nestle": Nestle,
     "ptemcee": Ptemcee,
     "ptmcmcsampler": PTMCMCSampler,
-    "pymc3": Pymc3,
+    "pymc": Pymc,
     "pymultinest": Pymultinest,
     "pypolychord": PyPolyChord,
     "ultranest": Ultranest,
diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc.py
similarity index 82%
rename from bilby/core/sampler/pymc3.py
rename to bilby/core/sampler/pymc.py
index a67094eb0..95a57dd4e 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc.py
@@ -14,13 +14,13 @@ from ..utils import derivatives, infer_args_from_method
 from .base_sampler import MCMCSampler
 
 
-class Pymc3(MCMCSampler):
-    """bilby wrapper of the PyMC3 sampler (https://docs.pymc.io/)
+class Pymc(MCMCSampler):
+    """bilby wrapper of the PyMC sampler (https://www.pymc.io/)
 
     All keyword arguments (i.e., the kwargs) passed to `run_sampler` will be
-    propapated to `pymc3.sample` where appropriate, see documentation for that
+    propagated to `pymc.sample` where appropriate, see documentation for that
     class for further help. Under Other Parameters, we list commonly used
-    kwargs and the bilby, or where appropriate, PyMC3 defaults.
+    kwargs and the bilby, or where appropriate, PyMC defaults.
 
     Parameters
     ==========
@@ -40,11 +40,11 @@ class Pymc3(MCMCSampler):
         particular variable names (these are case insensitive). If passing a
         dictionary of methods, the values keyed on particular variables can be
         lists of methods to form compound steps. If no method is provided for
-        any particular variable then PyMC3 will automatically decide upon a
+        any particular variable then PyMC will automatically decide upon a
         default, with the first option being the NUTS sampler. The currently
         allowed methods are 'NUTS', 'HamiltonianMC', 'Metropolis',
         'BinaryMetropolis', 'BinaryGibbsMetropolis', 'Slice', and
-        'CategoricalGibbsMetropolis'. Note: you cannot provide a PyMC3 step
+        'CategoricalGibbsMetropolis'. Note: you cannot provide a PyMC step
         method function itself here as it is outside of the model context
         manager.
     step_kwargs: dict
@@ -59,7 +59,7 @@ class Pymc3(MCMCSampler):
         step=None,
         init="auto",
         n_init=200000,
-        start=None,
+        initvals=None,
         trace=None,
         chain_idx=0,
         chains=2,
@@ -109,7 +109,7 @@ class Pymc3(MCMCSampler):
         self.default_step_kwargs = {m.__name__.lower(): None for m in STEP_METHODS}
         self.default_kwargs.update(self.default_step_kwargs)
 
-        super(Pymc3, self).__init__(
+        super(Pymc, self).__init__(
             likelihood=likelihood,
             priors=priors,
             outdir=outdir,
@@ -124,24 +124,24 @@ class Pymc3(MCMCSampler):
 
     @staticmethod
     def _import_external_sampler():
-        import pymc3
-        from pymc3.sampling import STEP_METHODS
-        from pymc3.theanof import floatX
+        import pymc
+        from pymc.aesaraf import floatX
+        from pymc.step_methods import STEP_METHODS
 
-        return pymc3, STEP_METHODS, floatX
+        return pymc, STEP_METHODS, floatX
 
     @staticmethod
-    def _import_theano():
-        import theano  # noqa
-        import theano.tensor as tt
-        from theano.compile.ops import as_op  # noqa
+    def _import_aesara():
+        import aesara  # noqa
+        import aesara.tensor as tt
+        from aesara.compile.ops import as_op  # noqa
 
-        return theano, tt, as_op
+        return aesara, tt, as_op
 
     def _verify_parameters(self):
         """
         Change `_verify_parameters()` to just pass, i.e., don't try and
-        evaluate the likelihood for PyMC3.
+        evaluate the likelihood for PyMC.
         """
         pass
 
@@ -154,64 +154,64 @@ class Pymc3(MCMCSampler):
     def setup_prior_mapping(self):
         """
         Set the mapping between predefined bilby priors and the equivalent
-        PyMC3 distributions.
+        PyMC distributions.
         """
 
         prior_map = {}
         self.prior_map = prior_map
 
-        # predefined PyMC3 distributions
+        # predefined PyMC distributions
         prior_map["Gaussian"] = {
-            "pymc3": "Normal",
-            "argmap": {"mu": "mu", "sigma": "sd"},
+            "pymc": "Normal",
+            "argmap": {"mu": "mu", "sigma": "sigma"},
         }
         prior_map["TruncatedGaussian"] = {
-            "pymc3": "TruncatedNormal",
+            "pymc": "TruncatedNormal",
             "argmap": {
                 "mu": "mu",
-                "sigma": "sd",
+                "sigma": "sigma",
                 "minimum": "lower",
                 "maximum": "upper",
             },
         }
-        prior_map["HalfGaussian"] = {"pymc3": "HalfNormal", "argmap": {"sigma": "sd"}}
+        prior_map["HalfGaussian"] = {"pymc": "HalfNormal", "argmap": {"sigma": "sigma"}}
         prior_map["Uniform"] = {
-            "pymc3": "Uniform",
+            "pymc": "Uniform",
             "argmap": {"minimum": "lower", "maximum": "upper"},
         }
         prior_map["LogNormal"] = {
-            "pymc3": "Lognormal",
-            "argmap": {"mu": "mu", "sigma": "sd"},
+            "pymc": "Lognormal",
+            "argmap": {"mu": "mu", "sigma": "sigma"},
         }
         prior_map["Exponential"] = {
-            "pymc3": "Exponential",
+            "pymc": "Exponential",
             "argmap": {"mu": "lam"},
             "argtransform": {"mu": lambda mu: 1.0 / mu},
         }
         prior_map["StudentT"] = {
-            "pymc3": "StudentT",
-            "argmap": {"df": "nu", "mu": "mu", "scale": "sd"},
+            "pymc": "StudentT",
+            "argmap": {"df": "nu", "mu": "mu", "scale": "sigma"},
         }
         prior_map["Beta"] = {
-            "pymc3": "Beta",
+            "pymc": "Beta",
             "argmap": {"alpha": "alpha", "beta": "beta"},
         }
         prior_map["Logistic"] = {
-            "pymc3": "Logistic",
+            "pymc": "Logistic",
             "argmap": {"mu": "mu", "scale": "s"},
         }
         prior_map["Cauchy"] = {
-            "pymc3": "Cauchy",
+            "pymc": "Cauchy",
             "argmap": {"alpha": "alpha", "beta": "beta"},
         }
         prior_map["Gamma"] = {
-            "pymc3": "Gamma",
+            "pymc": "Gamma",
             "argmap": {"k": "alpha", "theta": "beta"},
             "argtransform": {"theta": lambda theta: 1.0 / theta},
         }
-        prior_map["ChiSquared"] = {"pymc3": "ChiSquared", "argmap": {"nu": "nu"}}
+        prior_map["ChiSquared"] = {"pymc": "ChiSquared", "argmap": {"nu": "nu"}}
         prior_map["Interped"] = {
-            "pymc3": "Interpolated",
+            "pymc": "Interpolated",
             "argmap": {"xx": "x_points", "yy": "pdf_points"},
         }
         prior_map["Normal"] = prior_map["Gaussian"]
@@ -237,7 +237,7 @@ class Pymc3(MCMCSampler):
 
     def _deltafunction_prior(self, key, **kwargs):
         """
-        Map the bilby delta function prior to a single value for PyMC3.
+        Map the bilby delta function prior to a single value for PyMC.
         """
 
         # check prior is a DeltaFunction
@@ -248,15 +248,15 @@ class Pymc3(MCMCSampler):
 
     def _sine_prior(self, key):
         """
-        Map the bilby Sine prior to a PyMC3 style function
+        Map the bilby Sine prior to a PyMC style function
         """
 
         # check prior is a Sine
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
         if isinstance(self.priors[key], Sine):
 
-            class Pymc3Sine(pymc3.Continuous):
+            class PymcSine(pymc.Continuous):
                 def __init__(self, lower=0.0, upper=np.pi):
                     if lower >= upper:
                         raise ValueError("Lower bound is above upper bound!")
@@ -272,20 +272,20 @@ class Pymc3(MCMCSampler):
                         - upper * tt.cos(upper)
                     ) / self.norm
 
-                    transform = pymc3.distributions.transforms.interval(lower, upper)
+                    transform = pymc.distributions.transforms.interval(lower, upper)
 
-                    super(Pymc3Sine, self).__init__(transform=transform)
+                    super(PymcSine, self).__init__(transform=transform)
 
                 def logp(self, value):
                     upper = self.upper
                     lower = self.lower
-                    return pymc3.distributions.dist_math.bound(
+                    return pymc.distributions.dist_math.bound(
                         tt.log(tt.sin(value) / self.norm),
                         lower <= value,
                         value <= upper,
                     )
 
-            return Pymc3Sine(
+            return PymcSine(
                 key, lower=self.priors[key].minimum, upper=self.priors[key].maximum
             )
         else:
@@ -293,15 +293,15 @@ class Pymc3(MCMCSampler):
 
     def _cosine_prior(self, key):
         """
-        Map the bilby Cosine prior to a PyMC3 style function
+        Map the bilby Cosine prior to a PyMC style function
         """
 
         # check prior is a Cosine
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
         if isinstance(self.priors[key], Cosine):
 
-            class Pymc3Cosine(pymc3.Continuous):
+            class PymcCosine(pymc.Continuous):
                 def __init__(self, lower=-np.pi / 2.0, upper=np.pi / 2.0):
                     if lower >= upper:
                         raise ValueError("Lower bound is above upper bound!")
@@ -316,20 +316,20 @@ class Pymc3(MCMCSampler):
                         - tt.cos(lower)
                     ) / self.norm
 
-                    transform = pymc3.distributions.transforms.interval(lower, upper)
+                    transform = pymc.distributions.transforms.interval(lower, upper)
 
-                    super(Pymc3Cosine, self).__init__(transform=transform)
+                    super(PymcCosine, self).__init__(transform=transform)
 
                 def logp(self, value):
                     upper = self.upper
                     lower = self.lower
-                    return pymc3.distributions.dist_math.bound(
+                    return pymc.distributions.dist_math.bound(
                         tt.log(tt.cos(value) / self.norm),
                         lower <= value,
                         value <= upper,
                     )
 
-            return Pymc3Cosine(
+            return PymcCosine(
                 key, lower=self.priors[key].minimum, upper=self.priors[key].maximum
             )
         else:
@@ -337,12 +337,12 @@ class Pymc3(MCMCSampler):
 
     def _powerlaw_prior(self, key):
         """
-        Map the bilby PowerLaw prior to a PyMC3 style function
+        Map the bilby PowerLaw prior to a PyMC style function
         """
 
         # check prior is a PowerLaw
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
         if isinstance(self.priors[key], PowerLaw):
 
             # check power law is set
@@ -353,12 +353,12 @@ class Pymc3(MCMCSampler):
                 # use Pareto distribution
                 palpha = -(1.0 + self.priors[key].alpha)
 
-                return pymc3.Bound(pymc3.Pareto, upper=self.priors[key].minimum)(
+                return pymc.Bound(pymc.Pareto, upper=self.priors[key].minimum)(
                     key, alpha=palpha, m=self.priors[key].maximum
                 )
             else:
 
-                class Pymc3PowerLaw(pymc3.Continuous):
+                class PymcPowerLaw(pymc.Continuous):
                     def __init__(self, lower, upper, alpha, testval=1):
                         falpha = alpha
                         self.lower = lower = tt.as_tensor_variable(floatX(lower))
@@ -374,11 +374,9 @@ class Pymc3(MCMCSampler):
                                 * (tt.pow(self.upper, beta) - tt.pow(self.lower, beta))
                             )
 
-                        transform = pymc3.distributions.transforms.interval(
-                            lower, upper
-                        )
+                        transform = pymc.distributions.transforms.interval(lower, upper)
 
-                        super(Pymc3PowerLaw, self).__init__(
+                        super(PymcPowerLaw, self).__init__(
                             transform=transform, testval=testval
                         )
 
@@ -387,13 +385,13 @@ class Pymc3(MCMCSampler):
                         lower = self.lower
                         alpha = self.alpha
 
-                        return pymc3.distributions.dist_math.bound(
+                        return pymc.distributions.dist_math.bound(
                             alpha * tt.log(value) + tt.log(self.norm),
                             lower <= value,
                             value <= upper,
                         )
 
-                return Pymc3PowerLaw(
+                return PymcPowerLaw(
                     key,
                     lower=self.priors[key].minimum,
                     upper=self.priors[key].maximum,
@@ -404,12 +402,12 @@ class Pymc3(MCMCSampler):
 
     def _multivariate_normal_prior(self, key):
         """
-        Map the bilby MultivariateNormal prior to a PyMC3 style function.
+        Map the bilby MultivariateNormal prior to a PyMC style function.
         """
 
         # check prior is a PowerLaw
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
         if isinstance(self.priors[key], MultivariateGaussian):
             # get names of multivariate Gaussian parameters
             mvpars = self.priors[key].mvg.names
@@ -447,7 +445,7 @@ class Pymc3(MCMCSampler):
                         upper[i] = maxmu[i] + 100.0 * maxsigma[i]
 
                 # create a bounded MultivariateNormal distribution
-                BoundedMvN = pymc3.Bound(pymc3.MvNormal, lower=lower, upper=upper)
+                BoundedMvN = pymc.Bound(pymc.MvNormal, lower=lower, upper=upper)
 
                 comp_dists = []  # list of any component modes
                 for i in range(mvg.nmodes):
@@ -462,7 +460,7 @@ class Pymc3(MCMCSampler):
 
                 # create a Mixture model
                 setname = f"mixture{self.multivariate_normal_num_sets}"
-                mix = pymc3.Mixture(
+                mix = pymc.Mixture(
                     setname,
                     w=mvg.weights,
                     comp_dists=comp_dists,
@@ -486,7 +484,7 @@ class Pymc3(MCMCSampler):
 
     def run_sampler(self):
         # set the step method
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
         step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
         if "step" in self._kwargs:
             self.step_method = self._kwargs.pop("step")
@@ -527,15 +525,15 @@ class Pymc3(MCMCSampler):
         else:
             self.step_method = None
 
-        # initialise the PyMC3 model
-        self.pymc3_model = pymc3.Model()
+        # initialise the PyMC model
+        self.pymc_model = pymc.Model()
 
         # set the prior
         self.set_prior()
 
         # if a custom log_likelihood function requires a `sampler` argument
         # then use that log_likelihood function, with the assumption that it
-        # takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
+        # takes in a Pymc Sampler, with a pymc_model attribute, and defines
         # the likelihood within that context manager
         likeargs = infer_args_from_method(self.likelihood.log_likelihood)
         if "sampler" in likeargs:
@@ -580,7 +578,7 @@ class Pymc3(MCMCSampler):
         if isinstance(self.step_method, dict):
             # create list of step methods (any not given will default to NUTS)
             self.kwargs["step"] = []
-            with self.pymc3_model:
+            with self.pymc_model:
                 for key in self.step_method:
                     # check for a compound step list
                     if isinstance(self.step_method[key], list):
@@ -591,7 +589,7 @@ class Pymc3(MCMCSampler):
                                 curmethod,
                                 key,
                                 nuts_kwargs,
-                                pymc3,
+                                pymc,
                                 step_kwargs,
                                 step_methods,
                             )
@@ -602,12 +600,12 @@ class Pymc3(MCMCSampler):
                             curmethod,
                             key,
                             nuts_kwargs,
-                            pymc3,
+                            pymc,
                             step_kwargs,
                             step_methods,
                         )
         else:
-            with self.pymc3_model:
+            with self.pymc_model:
                 # check for a compound step list
                 if isinstance(self.step_method, list):
                     compound = []
@@ -617,7 +615,7 @@ class Pymc3(MCMCSampler):
                         args, nuts_kwargs = self._create_args_and_nuts_kwargs(
                             curmethod, nuts_kwargs, step_kwargs
                         )
-                        compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
+                        compound.append(pymc.__dict__[step_methods[curmethod]](**args))
                         self.kwargs["step"] = compound
                 else:
                     self.kwargs["step"] = None
@@ -627,32 +625,32 @@ class Pymc3(MCMCSampler):
                         args, nuts_kwargs = self._create_args_and_nuts_kwargs(
                             curmethod, nuts_kwargs, step_kwargs
                         )
-                        self.kwargs["step"] = pymc3.__dict__[step_methods[curmethod]](
+                        self.kwargs["step"] = pymc.__dict__[step_methods[curmethod]](
                             **args
                         )
                     else:
                         # re-add step_kwargs if no step methods are set
                         if len(step_kwargs) > 0 and StrictVersion(
-                            pymc3.__version__
+                            pymc.__version__
                         ) < StrictVersion("3.7"):
                             self.kwargs["step_kwargs"] = step_kwargs
 
         # check whether only NUTS step method has been assigned
         if np.all([sm.lower() == "nuts" for sm in methodslist]):
-            # in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
+            # in this case we can let PyMC autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
             self.kwargs["step"] = None
 
-            if len(nuts_kwargs) > 0 and StrictVersion(
-                pymc3.__version__
-            ) < StrictVersion("3.7"):
+            if len(nuts_kwargs) > 0 and StrictVersion(pymc.__version__) < StrictVersion(
+                "3.7"
+            ):
                 self.kwargs["nuts_kwargs"] = nuts_kwargs
             elif len(nuts_kwargs) > 0:
                 # add NUTS kwargs to standard kwargs
                 self.kwargs.update(nuts_kwargs)
 
-        with self.pymc3_model:
+        with self.pymc_model:
             # perform the sampling
-            trace = pymc3.sample(**self.kwargs, return_inferencedata=True)
+            trace = pymc.sample(**self.kwargs)
 
         posterior = trace.posterior.to_dataframe().reset_index()
         self.result.samples = posterior[self.search_parameter_keys]
@@ -674,7 +672,7 @@ class Pymc3(MCMCSampler):
         return args, nuts_kwargs
 
     def _create_nuts_kwargs(
-        self, curmethod, key, nuts_kwargs, pymc3, step_kwargs, step_methods
+        self, curmethod, key, nuts_kwargs, pymc, step_kwargs, step_methods
     ):
         if curmethod == "nuts":
             args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs)
@@ -684,9 +682,7 @@ class Pymc3(MCMCSampler):
             else:
                 args = {}
         self.kwargs["step"].append(
-            pymc3.__dict__[step_methods[curmethod]](
-                vars=[self.pymc3_priors[key]], **args
-            )
+            pymc.__dict__[step_methods[curmethod]](vars=[self.pymc_priors[key]], **args)
         )
         return nuts_kwargs
 
@@ -702,57 +698,55 @@ class Pymc3(MCMCSampler):
             args = {}
         return args, nuts_kwargs
 
-    def _pymc3_version(self):
-        pymc3, _, _ = self._import_external_sampler()
-        return pymc3.__version__
+    def _pymc_version(self):
+        pymc, _, _ = self._import_external_sampler()
+        return pymc.__version__
 
     def set_prior(self):
         """
-        Set the PyMC3 prior distributions.
+        Set the PyMC prior distributions.
         """
 
         self.setup_prior_mapping()
 
-        self.pymc3_priors = dict()
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
+        self.pymc_priors = dict()
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
 
         # initialise a dictionary of multivariate Gaussian parameters
         self.multivariate_normal_sets = {}
         self.multivariate_normal_num_sets = 0
 
         # set the parameter prior distributions (in the model context manager)
-        with self.pymc3_model:
+        with self.pymc_model:
             for key in self.priors:
                 # if the prior contains ln_prob method that takes a 'sampler' argument
                 # then try using that
                 lnprobargs = infer_args_from_method(self.priors[key].ln_prob)
                 if "sampler" in lnprobargs:
                     try:
-                        self.pymc3_priors[key] = self.priors[key].ln_prob(sampler=self)
+                        self.pymc_priors[key] = self.priors[key].ln_prob(sampler=self)
                     except RuntimeError:
-                        raise RuntimeError(
-                            ("Problem setting PyMC3 prior for ", f"'{key}'")
-                        )
+                        raise RuntimeError((f"Problem setting PyMC prior for '{key}'"))
                 else:
                     # use Prior distribution name
                     distname = self.priors[key].__class__.__name__
 
                     if distname in self.prior_map:
-                        # check if we have a predefined PyMC3 distribution
+                        # check if we have a predefined PyMC distribution
                         if (
-                            "pymc3" in self.prior_map[distname]
+                            "pymc" in self.prior_map[distname]
                             and "argmap" in self.prior_map[distname]
                         ):
-                            # check the required arguments for the PyMC3 distribution
-                            pymc3distname = self.prior_map[distname]["pymc3"]
+                            # check the required arguments for the PyMC distribution
+                            pymcdistname = self.prior_map[distname]["pymc"]
 
-                            if pymc3distname not in pymc3.__dict__:
+                            if pymcdistname not in pymc.__dict__:
                                 raise ValueError(
-                                    f"Prior '{pymc3distname}' is not a known PyMC3 distribution."
+                                    f"Prior '{pymcdistname}' is not a known PyMC distribution."
                                 )
 
                             reqargs = infer_args_from_method(
-                                pymc3.__dict__[pymc3distname].__init__
+                                pymc.__dict__[pymcdistname].dist
                             )
 
                             # set keyword arguments
@@ -790,11 +784,11 @@ class Pymc3(MCMCSampler):
                                 else:
                                     if parg in reqargs:
                                         priorkwargs[parg] = None
-                            self.pymc3_priors[key] = pymc3.__dict__[pymc3distname](
+                            self.pymc_priors[key] = pymc.__dict__[pymcdistname](
                                 key, **priorkwargs
                             )
                         elif "internal" in self.prior_map[distname]:
-                            self.pymc3_priors[key] = self.prior_map[distname][
+                            self.pymc_priors[key] = self.prior_map[distname][
                                 "internal"
                             ](key)
                         else:
@@ -808,12 +802,12 @@ class Pymc3(MCMCSampler):
 
     def set_likelihood(self):
         """
-        Convert any bilby likelihoods to PyMC3 distributions.
+        Convert any bilby likelihoods to PyMC distributions.
         """
 
-        # create theano Op for the log likelihood if not using a predefined model
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
+        # create aesara Op for the log likelihood if not using a predefined model
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
 
         class LogLike(tt.Op):
 
@@ -845,7 +839,7 @@ class Pymc3(MCMCSampler):
                 (theta,) = inputs
                 return [g[0] * self.logpgrad(theta)]
 
-        # create theano Op for calculating the gradient of the log likelihood
+        # create aesara Op for calculating the gradient of the log likelihood
         class LogLikeGrad(tt.Op):
 
             itypes = [tt.dvector]
@@ -878,7 +872,7 @@ class Pymc3(MCMCSampler):
 
                 outputs[0][0] = grads
 
-        with self.pymc3_model:
+        with self.pymc_model:
             #  check if it is a predefined likelhood function
             if isinstance(self.likelihood, GaussianLikelihood):
                 # check required attributes exist
@@ -891,24 +885,24 @@ class Pymc3(MCMCSampler):
                         "Gaussian Likelihood does not have all the correct attributes!"
                     )
 
-                if "sigma" in self.pymc3_priors:
+                if "sigma" in self.pymc_priors:
                     # if sigma is suppled use that value
                     if self.likelihood.sigma is None:
-                        self.likelihood.sigma = self.pymc3_priors.pop("sigma")
+                        self.likelihood.sigma = self.pymc_priors.pop("sigma")
                     else:
-                        del self.pymc3_priors["sigma"]
+                        del self.pymc_priors["sigma"]
 
-                for key in self.pymc3_priors:
+                for key in self.pymc_priors:
                     if key not in self.likelihood.function_keys:
                         raise ValueError(f"Prior key '{key}' is not a function key!")
 
-                model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
+                model = self.likelihood.func(self.likelihood.x, **self.pymc_priors)
 
                 # set the distribution
-                pymc3.Normal(
+                pymc.Normal(
                     "likelihood",
                     mu=model,
-                    sd=self.likelihood.sigma,
+                    sigma=self.likelihood.sigma,
                     observed=self.likelihood.y,
                 )
             elif isinstance(self.likelihood, PoissonLikelihood):
@@ -920,15 +914,15 @@ class Pymc3(MCMCSampler):
                         "Poisson Likelihood does not have all the correct attributes!"
                     )
 
-                for key in self.pymc3_priors:
+                for key in self.pymc_priors:
                     if key not in self.likelihood.function_keys:
                         raise ValueError(f"Prior key '{key}' is not a function key!")
 
                 # get rate function
-                model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
+                model = self.likelihood.func(self.likelihood.x, **self.pymc_priors)
 
                 # set the distribution
-                pymc3.Poisson("likelihood", mu=model, observed=self.likelihood.y)
+                pymc.Poisson("likelihood", mu=model, observed=self.likelihood.y)
             elif isinstance(self.likelihood, ExponentialLikelihood):
                 # check required attributes exist
                 if not hasattr(self.likelihood, "x") or not hasattr(
@@ -938,15 +932,15 @@ class Pymc3(MCMCSampler):
                         "Exponential Likelihood does not have all the correct attributes!"
                     )
 
-                for key in self.pymc3_priors:
+                for key in self.pymc_priors:
                     if key not in self.likelihood.function_keys:
                         raise ValueError(f"Prior key '{key}' is not a function key!")
 
                 # get mean function
-                model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
+                model = self.likelihood.func(self.likelihood.x, **self.pymc_priors)
 
                 # set the distribution
-                pymc3.Exponential(
+                pymc.Exponential(
                     "likelihood", lam=1.0 / model, observed=self.likelihood.y
                 )
             elif isinstance(self.likelihood, StudentTLikelihood):
@@ -961,25 +955,25 @@ class Pymc3(MCMCSampler):
                         "StudentT Likelihood does not have all the correct attributes!"
                     )
 
-                if "nu" in self.pymc3_priors:
+                if "nu" in self.pymc_priors:
                     # if nu is suppled use that value
                     if self.likelihood.nu is None:
-                        self.likelihood.nu = self.pymc3_priors.pop("nu")
+                        self.likelihood.nu = self.pymc_priors.pop("nu")
                     else:
-                        del self.pymc3_priors["nu"]
+                        del self.pymc_priors["nu"]
 
-                for key in self.pymc3_priors:
+                for key in self.pymc_priors:
                     if key not in self.likelihood.function_keys:
                         raise ValueError(f"Prior key '{key}' is not a function key!")
 
-                model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
+                model = self.likelihood.func(self.likelihood.x, **self.pymc_priors)
 
                 # set the distribution
-                pymc3.StudentT(
+                pymc.StudentT(
                     "likelihood",
                     nu=self.likelihood.nu,
                     mu=model,
-                    sd=self.likelihood.sigma,
+                    sigma=self.likelihood.sigma,
                     observed=self.likelihood.y,
                 )
             elif isinstance(
@@ -988,22 +982,22 @@ class Pymc3(MCMCSampler):
             ):
                 # set theano Op - pass _search_parameter_keys, which only contains non-fixed variables
                 logl = LogLike(
-                    self._search_parameter_keys, self.likelihood, self.pymc3_priors
+                    self._search_parameter_keys, self.likelihood, self.pymc_priors
                 )
 
                 parameters = dict()
                 for key in self._search_parameter_keys:
                     try:
-                        parameters[key] = self.pymc3_priors[key]
+                        parameters[key] = self.pymc_priors[key]
                     except KeyError:
                         raise KeyError(
                             f"Unknown key '{key}' when setting GravitationalWaveTransient likelihood"
                         )
 
-                # convert to theano tensor variable
+                # convert to aesara tensor variable
                 values = tt.as_tensor_variable(list(parameters.values()))
 
-                pymc3.DensityDist(
+                pymc.DensityDist(
                     "likelihood", lambda v: logl(v), observed={"v": values}
                 )
             else:
diff --git a/docs/samplers.txt b/docs/samplers.txt
index fad1d62ba..69699dad0 100644
--- a/docs/samplers.txt
+++ b/docs/samplers.txt
@@ -69,7 +69,7 @@ MCMC samplers
 - bilby-mcmc :code:`bilby.bilby_mcmc.sampler.Bilby_MCMC`
 - emcee :code:`bilby.core.sampler.emcee.Emcee`
 - ptemcee :code:`bilby.core.sampler.ptemcee.Ptemcee`
-- pymc3 :code:`bilby.core.sampler.pymc3.Pymc3`
+- pymc :code:`bilby.core.sampler.pymc.Pymc`
 - zeus :code:`bilby.core.sampler.zeus.Zeus`
 
 
diff --git a/examples/core_examples/alternative_samplers/linear_regression_pymc3.py b/examples/core_examples/alternative_samplers/linear_regression_pymc.py
similarity index 97%
rename from examples/core_examples/alternative_samplers/linear_regression_pymc3.py
rename to examples/core_examples/alternative_samplers/linear_regression_pymc.py
index 75cbf16ae..0efc872be 100644
--- a/examples/core_examples/alternative_samplers/linear_regression_pymc3.py
+++ b/examples/core_examples/alternative_samplers/linear_regression_pymc.py
@@ -11,7 +11,7 @@ import numpy as np
 from bilby.core.likelihood import GaussianLikelihood
 
 # A few simple setup steps
-label = "linear_regression_pymc3"
+label = "linear_regression_pymc"
 outdir = "outdir"
 bilby.utils.check_directory_exists_and_if_not_mkdir(outdir)
 
@@ -58,7 +58,7 @@ priors["c"] = bilby.core.prior.Uniform(-2, 2, "c")
 result = bilby.run_sampler(
     likelihood=likelihood,
     priors=priors,
-    sampler="pymc3",
+    sampler="pymc",
     injection_parameters=injection_parameters,
     outdir=outdir,
     draws=2000,
diff --git a/examples/core_examples/alternative_samplers/linear_regression_pymc3_custom_likelihood.py b/examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py
similarity index 77%
rename from examples/core_examples/alternative_samplers/linear_regression_pymc3_custom_likelihood.py
rename to examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py
index d2074304f..e9763770c 100644
--- a/examples/core_examples/alternative_samplers/linear_regression_pymc3_custom_likelihood.py
+++ b/examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py
@@ -11,10 +11,10 @@ would give equivalent results as using the pre-defined 'Gaussian Likelihood'
 import bilby
 import matplotlib.pyplot as plt
 import numpy as np
-import pymc3 as pm
+import pymc as pm
 
 # A few simple setup steps
-label = "linear_regression_pymc3_custom_likelihood"
+label = "linear_regression_pymc_custom_likelihood"
 outdir = "outdir"
 bilby.utils.check_directory_exists_and_if_not_mkdir(outdir)
 
@@ -50,7 +50,7 @@ fig.savefig("{}/{}_data.png".format(outdir, label))
 
 # Parameter estimation: we now define a Gaussian Likelihood class relevant for
 # our model.
-class GaussianLikelihoodPyMC3(bilby.core.likelihood.GaussianLikelihood):
+class GaussianLikelihoodPyMC(bilby.core.likelihood.GaussianLikelihood):
     def __init__(self, x, y, sigma, func):
         """
         A general Gaussian likelihood - the parameters are inferred from the
@@ -68,45 +68,44 @@ class GaussianLikelihoodPyMC3(bilby.core.likelihood.GaussianLikelihood):
             will require a prior and will be sampled over (unless a fixed
             value is given).
         """
-        super(GaussianLikelihoodPyMC3, self).__init__(x=x, y=y, func=func, sigma=sigma)
+        super(GaussianLikelihoodPyMC, self).__init__(x=x, y=y, func=func, sigma=sigma)
 
     def log_likelihood(self, sampler=None):
         """
         Parameters
         ----------
-        sampler: :class:`bilby.core.sampler.Pymc3`
+        sampler: :class:`bilby.core.sampler.Pymc`
             A Sampler object must be passed containing the prior distributions
             and PyMC3 :class:`~pymc3.Model` to use as a context manager.
             If this is not passed, the super class is called and the regular
             likelihood is evaluated.
         """
 
-        from bilby.core.sampler import Pymc3
+        from bilby.core.sampler import Pymc
 
-        if not isinstance(sampler, Pymc3):
-            print(sampler, type(sampler))
-            return super(GaussianLikelihoodPyMC3, self).log_likelihood()
+        if not isinstance(sampler, Pymc):
+            return super(GaussianLikelihoodPyMC, self).log_likelihood()
 
-        if not hasattr(sampler, "pymc3_model"):
-            raise AttributeError("Sampler has not PyMC3 model attribute")
+        if not hasattr(sampler, "pymc_model"):
+            raise AttributeError("Sampler has not PyMC model attribute")
 
-        with sampler.pymc3_model:
-            mdist = sampler.pymc3_priors["m"]
-            cdist = sampler.pymc3_priors["c"]
+        with sampler.pymc_model:
+            mdist = sampler.pymc_priors["m"]
+            cdist = sampler.pymc_priors["c"]
 
             mu = model(time, mdist, cdist)
 
             # set the likelihood distribution
-            pm.Normal("likelihood", mu=mu, sd=self.sigma, observed=self.y)
+            pm.Normal("likelihood", mu=mu, sigma=self.sigma, observed=self.y)
 
 
 # Now lets instantiate a version of our GaussianLikelihood, giving it
 # the time, data and signal model
-likelihood = GaussianLikelihoodPyMC3(time, data, sigma, model)
+likelihood = GaussianLikelihoodPyMC(time, data, sigma, model)
 
 
-# Define a custom prior for one of the parameter for use with PyMC3
-class PyMC3UniformPrior(bilby.core.prior.Uniform):
+# Define a custom prior for one of the parameters for use with PyMC
+class PyMCUniformPrior(bilby.core.prior.Uniform):
     def __init__(self, minimum, maximum, name=None, latex_label=None):
         """
         Uniform prior with bounds (should be equivalent to bilby.prior.Uniform)
@@ -124,10 +123,10 @@ class PyMC3UniformPrior(bilby.core.prior.Uniform):
         float or array to be passed to the superclass.
         """
 
-        from bilby.core.sampler import Pymc3
+        from bilby.core.sampler import Pymc
 
-        if not isinstance(sampler, Pymc3):
-            return super(PyMC3UniformPrior, self).ln_prob(sampler)
+        if not isinstance(sampler, Pymc):
+            return super(PyMCUniformPrior, self).ln_prob(sampler)
 
         return pm.Uniform(self.name, lower=self.minimum, upper=self.maximum)
 
@@ -136,13 +135,13 @@ class PyMC3UniformPrior(bilby.core.prior.Uniform):
 # We make a prior
 priors = dict()
 priors["m"] = bilby.core.prior.Uniform(0, 5, "m")
-priors["c"] = PyMC3UniformPrior(-2, 2, "c")
+priors["c"] = PyMCUniformPrior(-2, 2, "c")
 
 # And run sampler
 result = bilby.run_sampler(
     likelihood=likelihood,
     priors=priors,
-    sampler="pymc3",
+    sampler="pymc",
     draws=1000,
     tune=1000,
     discard_tuned_samples=True,
diff --git a/sampler_requirements.txt b/sampler_requirements.txt
index d60093c8b..64d2c8c50 100644
--- a/sampler_requirements.txt
+++ b/sampler_requirements.txt
@@ -3,7 +3,7 @@ dynesty
 emcee
 nestle
 ptemcee
-pymc3>=3.6
+pymc>=4.0.0
 pymultinest
 kombine
 ultranest>=3.0.0
diff --git a/test/core/sampler/pymc3_test.py b/test/core/sampler/pymc_test.py
similarity index 89%
rename from test/core/sampler/pymc3_test.py
rename to test/core/sampler/pymc_test.py
index b3bb758d3..c904e1fd8 100644
--- a/test/core/sampler/pymc3_test.py
+++ b/test/core/sampler/pymc_test.py
@@ -1,22 +1,16 @@
 import unittest
 from unittest.mock import MagicMock
 
-import pytest
-
 import bilby
 
 
-@pytest.mark.xfail(
-    raises=AttributeError,
-    reason="Dependency issue with pymc3 causes attribute error on import",
-)
-class TestPyMC3(unittest.TestCase):
+class TestPyMC(unittest.TestCase):
     def setUp(self):
         self.likelihood = MagicMock()
         self.priors = bilby.core.prior.PriorDict(
             dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
         )
-        self.sampler = bilby.core.sampler.Pymc3(
+        self.sampler = bilby.core.sampler.Pymc(
             self.likelihood,
             self.priors,
             outdir="outdir",
@@ -37,7 +31,7 @@ class TestPyMC3(unittest.TestCase):
             step=None,
             init="auto",
             n_init=200000,
-            start=None,
+            initvals=None,
             trace=None,
             chain_idx=0,
             chains=2,
@@ -61,7 +55,7 @@ class TestPyMC3(unittest.TestCase):
             step=None,
             init="auto",
             n_init=200000,
-            start=None,
+            initvals=None,
             trace=None,
             chain_idx=0,
             chains=2,
diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py
index 3aa2157c0..17307e7d7 100644
--- a/test/integration/sampler_run_test.py
+++ b/test/integration/sampler_run_test.py
@@ -54,7 +54,7 @@ _sampler_kwargs = dict(
         frac_threshold=0.5,
     ),
     PTMCMCSampler=dict(Niter=101, burn=2, isave=100),
-    # pymc3=dict(draws=50, tune=50, n_init=250),  removed until testing issue can be resolved
+    pymc=dict(draws=50, tune=50, n_init=250),
     pymultinest=dict(nlive=100),
     pypolychord=dict(nlive=100),
     ultranest=dict(nlive=100, temporary_directory=False),
@@ -65,7 +65,7 @@ sampler_imports = dict(
     dynamic_dynesty="dynesty"
 )
 
-no_pool_test = ["dnest4", "pymultinest", "nestle", "ptmcmcsampler", "pypolychord", "ultranest"]
+no_pool_test = ["dnest4", "pymultinest", "nestle", "ptmcmcsampler", "pypolychord", "ultranest", "pymc"]
 
 
 def slow_func(x, m, c):
-- 
GitLab