# conditional_prior.py
"""
This tutorial demonstrates how we can sample a prior in the shape of a ball.
Note that this will not end up sampling uniformly within the ball; constraint
priors are more suitable for that.
This implementation draws a value for the x-coordinate from p(x), then, given
that draw, a value for the y-coordinate from p(y|x), and finally, given both
of those, a value for the z-coordinate from p(z|x,y).
Only the x-coordinate will end up being uniformly distributed with this prior.
"""
import bilby
import numpy as np


class ZeroLikelihood(bilby.core.likelihood.Likelihood):
    """Likelihood that is constant over the whole parameter space.

    ``log_likelihood`` always evaluates to 0 (i.e. a likelihood of 1), so
    the distribution explored by the sampler is exactly the prior.
    """

    def log_likelihood(self):
        """Return the constant log-likelihood value, zero."""
        constant_log_l = 0
        return constant_log_l


def condition_func_y(reference_params, x):
    """Condition function for our p(y|x) prior.

    Restricts y to the chord of the origin-centred ball at the given x.

    Parameters
    ----------
    reference_params: dict
        Reference bounds of the y prior, with keys ``"minimum"`` and
        ``"maximum"``. The ball radius is taken as half their difference.
    x: float or array_like
        The already-drawn x-coordinate(s).

    Returns
    -------
    dict
        ``minimum`` and ``maximum`` for y, i.e. ``-/+ sqrt(radius**2 - x**2)``.
    """
    radius = 0.5 * (reference_params["maximum"] - reference_params["minimum"])
    # Clip the radicand at zero: floating-point round-off (or an x on/over the
    # boundary) can make it slightly negative, and np.sqrt would return NaN.
    y_max = np.sqrt(np.maximum(radius ** 2 - x ** 2, 0))
    return dict(minimum=-y_max, maximum=y_max)


def condition_func_z(reference_params, x, y):
    """Condition function for our p(z|x, y) prior.

    Restricts z to the segment of the origin-centred ball at the given (x, y).

    Parameters
    ----------
    reference_params: dict
        Reference bounds of the z prior, with keys ``"minimum"`` and
        ``"maximum"``. The ball radius is taken as half their difference.
    x: float or array_like
        The already-drawn x-coordinate(s).
    y: float or array_like
        The already-drawn y-coordinate(s).

    Returns
    -------
    dict
        ``minimum`` and ``maximum`` for z,
        i.e. ``-/+ sqrt(radius**2 - x**2 - y**2)``.
    """
    radius = 0.5 * (reference_params["maximum"] - reference_params["minimum"])
    # y can be drawn right at +/- sqrt(radius**2 - x**2), so round-off can push
    # the radicand a few ulps below zero; clip it so np.sqrt never yields NaN.
    z_max = np.sqrt(np.maximum(radius ** 2 - x ** 2 - y ** 2, 0))
    return dict(minimum=-z_max, maximum=z_max)


# Build the prior: x ~ U(-1, 1), then y | x and z | x, y are uniform on the
# ranges returned by the condition functions above. bilby hands the initial
# minimum/maximum given here to those functions as ``reference_params``.
priors = bilby.core.prior.ConditionalPriorDict()
priors["x"] = bilby.core.prior.Uniform(minimum=-1, maximum=1, latex_label="$x$")
priors["y"] = bilby.core.prior.ConditionalUniform(
    minimum=-1, maximum=1, condition_func=condition_func_y, latex_label="$y$"
)
priors["z"] = bilby.core.prior.ConditionalUniform(
    minimum=-1, maximum=1, condition_func=condition_func_z, latex_label="$z$"
)

# A constant likelihood makes the sampled posterior identical to the prior.
likelihood = ZeroLikelihood(parameters={"x": 0, "y": 0, "z": 0})

# Draw samples from the prior with dynesty and show them in a corner plot.
sampler_settings = {
    "sampler": "dynesty",
    "nlive": 5000,
    "label": "conditional_prior",
    "outdir": "outdir",
    "resume": False,
    "clean": True,
}
res = bilby.run_sampler(likelihood=likelihood, priors=priors, **sampler_settings)
res.plot_corner()