import numpy as np
import pandas as pd

from bilby.core.likelihood import Multinomial
from bilby.core.prior import DirichletPriorDict
from bilby.core.sampler import run_sampler


# Common prefix of the per-category fraction names ("dirichlet_0", ...).
label = "dirichlet_"
# Number of categories in the multinomial model.
n_dim = 3
# Dirichlet prior over the category fractions; its keys are built from
# `label` followed by an integer index.
priors = DirichletPriorDict(label=label, n_dim=n_dim)

# True category fractions used to build the simulated data: an equal share
# for every category. The keys are derived from `label` and `n_dim`, so they
# stay in sync with the prior if either setting changes. For n_dim = 3 this
# is {"dirichlet_0": 1/3, "dirichlet_1": 1/3, "dirichlet_2": 1/3}.
injection_parameters = {label + str(ii): 1 / n_dim for ii in range(n_dim)}

# Observed number of objects in each category for roughly 1000 draws.
# A multinomial likelihood expects non-negative integer counts, so the
# expected counts are rounded rather than passed as fractional values.
# round() on a float returns an int in Python 3.
data = [
    round(injection_parameters[label + str(ii)] * 1000)
    for ii in range(n_dim)
]

# Multinomial likelihood over the n_dim categories. `label` is the same name
# prefix used by the prior. NOTE(review): presumably the likelihood uses it
# to look up the sampled fractions; confirm against Multinomial.
likelihood = Multinomial(
    label=label,
    n_dimensions=n_dim,
    data=data,
)

# Run the sampler and keep the result for plotting. `nlive` and `walks` are
# extra sampler settings forwarded by run_sampler.
result = run_sampler(
    likelihood=likelihood,
    priors=priors,
    label="multinomial",
    injection_parameters=injection_parameters,
    nlive=100,
    walks=10,
)

# The prior keys cover every fraction except the final one. Recover that one
# from the sum-to-one constraint, so the corner plot shows all n_dim
# categories.
last_fraction = label + str(n_dim - 1)
sampled_fractions = [result.posterior[name] for name in priors]
result.posterior[last_fraction] = 1 - np.sum(sampled_fractions, axis=0)
# The injection dict selects the parameters to plot. NOTE(review): bilby
# presumably also draws its values as truth markers; confirm.
result.plot_corner(parameters=injection_parameters)

# Draw directly from the prior. The draws hold every fraction except the
# final one, which is filled in from the sum-to-one constraint. The
# right-hand side is evaluated before the new key is inserted.
samples = priors.sample(10000)
last_fraction = label + str(n_dim - 1)
samples[last_fraction] = 1 - np.sum([samples[name] for name in samples], axis=0)

# Reuse the result's plotting helper by replacing its posterior with the
# prior draws. The sampler posterior stored on `result` is overwritten.
result.posterior = pd.DataFrame(samples)
result.plot_corner(
    parameters=list(samples),
    filename="outdir/dirichlet_prior_corner.png",
)