Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
sampler.py 45.06 KiB
import datetime
import os
import time
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd

from ..core.result import rejection_sample
from ..core.sampler.base_sampler import (
    MCMCSampler,
    ResumeError,
    SamplerError,
    _sampling_convenience_dump,
    signal_wrapper,
)
from ..core.utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump
from . import proposals
from .chain import Chain, Sample
from .utils import LOGLKEY, LOGPKEY, ConvergenceInputs, ParallelTemperingInputs


class Bilby_MCMC(MCMCSampler):
    """The built-in Bilby MCMC sampler

    Parameters
    ----------
    likelihood: likelihood.Likelihood
        An object with a log_l method
    priors: bilby.core.prior.PriorDict, dict
        Priors to be used in the search.
        This has attributes for each parameter to be sampled.
    outdir: str, optional
        Name of the output directory
    label: str, optional
        Naming scheme of the output files
    use_ratio: bool, optional
        Switch to set whether or not you want to use the log-likelihood ratio
        or just the log-likelihood
    skip_import_verification: bool
        Skips the check if the sampler is installed if true. This is
        only advisable for testing environments
    check_point_plot: bool
        If true, create plots at the check point
    check_point_delta_t: float
        The time in seconds after which to checkpoint (defaults to 30 minutes)
    diagnostic: bool
        If true, create deep-diagnostic plots used for checking convergence
        problems.
    resume: bool
        If true, resume from any existing check point files
    exit_code: int
        The code on which to raise if exiting
    nsamples: int (1000)
        The number of samples to draw
    nensemble: int (1)
        The number of ensemble-chains to run (with periodic communication)
    pt_ensemble: bool (False)
        If true, run a parallel-tempered set of chains for each
        ensemble-chain (in which case the total number of chains is
        nensemble * ntemps). Else, only the zero-ensemble chain is run with
        parallel-tempering (in which case the total number of chains is
        nensemble + ntemps - 1).
    ntemps: int (1)
        The number of parallel-tempered chains to run
    Tmax: float, (None)
        If given, the maximum temperature to set the initial temperature
        ladder
    Tmax_from_SNR: float (20)
        (Alternative to Tmax): The SNR to estimate an appropriate Tmax from.
    initial_betas: list (None)
        (Alternative to Tmax and Tmax_from_SNR): If given, an initial choice of
        the inverse temperature ladder.
    pt_rejection_sample: bool (False)
        If true, use rejection sampling to draw samples from the pt-chains.
    adapt, adapt_t0, adapt_nu: bool, float, float (True, 100, 10)
        Whether to use adaptation and the adaptation parameters.
        See arXiv:1501.05823 for a description of adapt_t0 and adapt_nu.
    burn_in_nact, thin_by_nact, fixed_discard: float, float, float (10, 1, 0)
        The number of auto-correlation times to discard for burn-in and to
        thin by. The fixed_discard is the number of steps discarded before
        automatic autocorrelation time analysis begins.
    autocorr_c: float (5)
        The step-size for the window search. See emcee.autocorr.integrated_time
        for additional details.
    L1steps: int
        The number of internal steps to take. Improves the scaling performance
        of multiprocessing. Note, all ACTs are calculated based on the saved
        steps. So, the total ACT (or number of steps) is L1steps * tau
        (or L1steps * position).
    L2steps: int
        The number of steps to take before swapping between parallel-tempered
        and ensemble chains.
    npool: int
        The number of multiprocessing cores to use. For efficiency, this must be
        matched to an integer number of the total number of chains.
    printdt: float
        Print an update on the progress every printdt s. Note, each print
        requires an evaluation of the ACT so short print times are unwise.
    min_tau: int (1)
        The minimum allowed ACT. Can be used to force a larger ACT.
    proposal_cycle: str, bilby.core.sampler.bilby_mcmc.proposals.ProposalCycle
        Either a string pointing to one of the built-in proposal cycles or,
        a proposal cycle.
    stop_after_convergence:
        If running with parallel-tempered chains. Stop updating the chains once
        they have converged. After this time, random samples will be drawn at
        swap time.
    fixed_tau: int
        A fixed value for the ACT: used for testing purposes.
    tau_window: int, None
        Using tau', a previous estimate of tau, calculate the new tau using
        the last tau_window * tau' steps. If None, the entire chain is used.
    evidence_method: str, [stepping_stone, thermodynamic]
        The evidence calculation method to use. Defaults to stepping_stone, but
        the results of all available methods are stored in the ln_z_dict.
    verbose: bool
        Whether to print diagnostic output during the run.

    """

    default_kwargs = dict(
        nsamples=1000,
        nensemble=1,
        pt_ensemble=False,
        ntemps=1,
        Tmax=None,
        Tmax_from_SNR=20,
        initial_betas=None,
        adapt=True,
        adapt_t0=100,
        adapt_nu=10,
        pt_rejection_sample=False,
        burn_in_nact=10,
        thin_by_nact=1,
        fixed_discard=0,
        autocorr_c=5,
        L1steps=100,
        L2steps=3,
        printdt=60,
        min_tau=1,
        proposal_cycle="default",
        stop_after_convergence=False,
        fixed_tau=None,
        tau_window=None,
        evidence_method="stepping_stone",
    )

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        skip_import_verification=True,
        check_point_plot=True,
        check_point_delta_t=1800,
        diagnostic=False,
        resume=True,
        exit_code=130,
        verbose=True,
        **kwargs,
    ):

        super().__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            skip_import_verification=skip_import_verification,
            exit_code=exit_code,
            **kwargs,
        )

        self.check_point_plot = check_point_plot
        self.diagnostic = diagnostic
        # The requested number of samples is kept under a separate key
        self.kwargs["target_nsamples"] = self.kwargs["nsamples"]
        self.L1steps = self.kwargs["L1steps"]
        self.L2steps = self.kwargs["L2steps"]
        # Split the flat kwargs into the two named-tuple groups by field name
        self.pt_inputs = ParallelTemperingInputs(
            **{key: self.kwargs[key] for key in ParallelTemperingInputs._fields}
        )
        self.convergence_inputs = ConvergenceInputs(
            **{key: self.kwargs[key] for key in ConvergenceInputs._fields}
        )
        self.proposal_cycle = self.kwargs["proposal_cycle"]
        self.pt_rejection_sample = self.kwargs["pt_rejection_sample"]
        self.evidence_method = self.kwargs["evidence_method"]

        self.printdt = self.kwargs["printdt"]
        check_directory_exists_and_if_not_mkdir(self.outdir)
        self.resume = resume
        self.check_point_delta_t = check_point_delta_t
        self.resume_file = "{}/{}_resume.pickle".format(self.outdir, self.label)

        self.verify_configuration()
        self.verbose = verbose

    def verify_configuration(self):
        """Warn when burn_in_nact exceeds 10% of the target number of samples"""
        if self.convergence_inputs.burn_in_nact / self.kwargs["target_nsamples"] > 0.1:
            logger.warning("Burn-in inefficiency fraction greater than 10%")

    def _translate_kwargs(self, kwargs):
        """Map equivalent keyword spellings onto the names used by this sampler

        ``print_dt`` / ``print_update`` are renamed to ``printdt`` and any of
        ``self.npool_equiv_kwargs`` to ``npool``, unless the canonical key is
        already present.

        Parameters
        ----------
        kwargs: dict
            The keyword arguments to translate

        Returns
        -------
        kwargs: dict
            The dictionary returned by the parent implementation with the
            renamed keys applied.
        """
        kwargs = super()._translate_kwargs(kwargs)
        if "printdt" not in kwargs:
            for equiv in ["print_dt", "print_update"]:
                if equiv in kwargs:
                    kwargs["printdt"] = kwargs.pop(equiv)
        if "npool" not in kwargs:
            for equiv in self.npool_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["npool"] = kwargs.pop(equiv)
        # Return the result so the value obtained from super() is propagated
        # to callers (and to subclasses chaining through super()).
        return kwargs

    @property
    def target_nsamples(self):
        """The number of samples requested by the user"""
        return self.kwargs["target_nsamples"]

    @signal_wrapper
    def run_sampler(self):
        """Run the sampler to completion and return the populated result"""
        self._setup_pool()
        self.setup_chain_set()
        self.start_time = datetime.datetime.now()
        self.draw()
        self._close_pool()
        self.check_point(ignore_time=True)

        self.result = self.add_data_to_result(
            result=self.result,
            ptsampler=self.ptsampler,
            outdir=self.outdir,
            label=self.label,
            make_plots=self.check_point_plot,
        )

        return self.result

    @staticmethod
    def add_data_to_result(result, ptsampler, outdir, label, make_plots):
        """Copy samples, evidence, timing and run statistics onto the result

        Parameters
        ----------
        result:
            The result object to populate (modified in place and returned)
        ptsampler: BilbyPTMCMCSampler
            The sampler holding the chains
        outdir, label: str
            Forwarded to the evidence calculation
        make_plots: bool
            Forwarded to the evidence calculation
        """
        result.samples = ptsampler.samples
        result.log_likelihood_evaluations = result.samples[LOGLKEY]
        result.log_prior_evaluations = result.samples[LOGPKEY]
        ptsampler.compute_evidence(
            outdir=outdir,
            label=label,
            make_plots=make_plots,
        )
        result.log_evidence = ptsampler.ln_z
        result.log_evidence_err = ptsampler.ln_z_err
        result.sampling_time = datetime.timedelta(seconds=ptsampler.sampling_time)
        result.meta_data["bilby_mcmc"] = dict(
            tau=ptsampler.tau,
            convergence_inputs=ptsampler.convergence_inputs._asdict(),
            pt_inputs=ptsampler.pt_inputs._asdict(),
            total_steps=ptsampler.position,
            nsamples=ptsampler.nsamples,
        )
        if ptsampler.pool is not None:
            npool = ptsampler.pool._processes
        else:
            npool = 1
        result.meta_data["run_statistics"] = dict(
            nlikelihood=ptsampler.position * ptsampler.L1steps * ptsampler._nsamplers,
            neffsamples=ptsampler.nsamples * ptsampler.convergence_inputs.thin_by_nact,
            # timedelta.seconds is only the sub-day component; total_seconds()
            # keeps runs longer than 24 hours from wrapping around.
            sampling_time_s=int(result.sampling_time.total_seconds()),
            ncores=npool,
        )

        return result

    def setup_chain_set(self):
        """Load the sampler from the resume file if allowed, else build it"""
        if os.path.isfile(self.resume_file) and self.resume is True:
            self.read_current_state()
            # The pool is stripped before pickling, so re-attach the live one
            self.ptsampler.pool = self.pool
        else:
            self.init_ptsampler()

    def init_ptsampler(self):
        """Create a fresh BilbyPTMCMCSampler from the current settings"""

        logger.info(f"Initializing BilbyPTMCMCSampler with:\n{self.get_setup_string()}")
        self.ptsampler = BilbyPTMCMCSampler(
            convergence_inputs=self.convergence_inputs,
            pt_inputs=self.pt_inputs,
            proposal_cycle=self.proposal_cycle,
            pt_rejection_sample=self.pt_rejection_sample,
            pool=self.pool,
            use_ratio=self.use_ratio,
            evidence_method=self.evidence_method,
        )

    def get_setup_string(self):
        """Return a multi-line summary of the sampler settings for logging"""
        string = (
            f"  Convergence settings: {self.convergence_inputs}\n"
            f"  Parallel-tempering settings: {self.pt_inputs}\n"
            f"  proposal_cycle: {self.proposal_cycle}\n"
            f"  pt_rejection_sample: {self.pt_rejection_sample}"
        )
        return string

    def draw(self):
        """Step all chains until the target number of samples is reached

        Progress is printed every ``printdt`` seconds of sampling time and
        ``check_point`` is called after every step.
        """
        self._steps_since_last_print = 0
        self._time_since_last_print = 0
        logger.info(f"Drawing {self.target_nsamples} samples")
        logger.info(f"Checkpoint every {self.check_point_delta_t}s")
        logger.info(f"Print update every {self.printdt}s")

        while True:
            t0 = datetime.datetime.now()
            self.ptsampler.step_all_chains()
            dt = (datetime.datetime.now() - t0).total_seconds()
            self.ptsampler.sampling_time += dt
            self._time_since_last_print += dt
            self._steps_since_last_print += self.ptsampler.L1steps

            if self._time_since_last_print > self.printdt:
                tp0 = datetime.datetime.now()
                self.print_progress()
                tp = datetime.datetime.now()
                # Fraction of the elapsed interval spent producing the print
                ppt_frac = (tp - tp0).total_seconds() / self._time_since_last_print
                if ppt_frac > 0.01:
                    logger.warning(
                        f"Non-negligible print progress time (ppt_frac={ppt_frac:0.2f})"
                    )
                self._steps_since_last_print = 0
                self._time_since_last_print = 0

            self.check_point()

            if self.ptsampler.nsamples_last >= self.target_nsamples:
                # Perform a second check without cached values
                if self.ptsampler.nsamples_nocache >= self.target_nsamples:
                    logger.info("Reached convergence: exiting sampling")
                    break

    def check_point(self, ignore_time=False):
        """Write a checkpoint if enough time has passed

        A checkpoint is written when both the time since the start of the run
        and the time since the resume file was last modified exceed
        ``check_point_delta_t``, or unconditionally if ``ignore_time`` is True.
        """
        tS = (datetime.datetime.now() - self.start_time).total_seconds()
        if os.path.isfile(self.resume_file):
            tR = time.time() - os.path.getmtime(self.resume_file)
        else:
            tR = np.inf

        if ignore_time or np.min([tS, tR]) > self.check_point_delta_t:
            logger.info("Checkpoint start")
            self.write_current_state()
            self.print_long_progress()
            logger.info("Checkpoint finished")

    def _remove_checkpoint(self):
        """Remove checkpointed state"""
        if os.path.isfile(self.resume_file):
            os.remove(self.resume_file)

    def read_current_state(self):
        """Load the sampler from the resume file and apply current settings

        Raises
        ------
        ResumeError
            If the stored parallel-tempering inputs differ from the current
            ones.
        """
        import dill

        # NOTE(review): dill.load executes arbitrary code from the file; the
        # resume file must come from a trusted location.
        with open(self.resume_file, "rb") as file:
            self.ptsampler = dill.load(file)
            if self.ptsampler.pt_inputs != self.pt_inputs:
                msg = (
                    f"pt_inputs has changed: {self.ptsampler.pt_inputs} "
                    f"-> {self.pt_inputs}"
                )
                raise ResumeError(msg)
            self.ptsampler.set_convergence_inputs(self.convergence_inputs)
            self.ptsampler.proposal_cycle = self.proposal_cycle
            self.ptsampler.pt_rejection_sample = self.pt_rejection_sample

        logger.info(
            f"Loaded resume file {self.resume_file} "
            f"with {self.ptsampler.position} steps "
            f"setup:\n{self.get_setup_string()}"
        )

    def write_current_state(self):
        """Pickle the sampler (without its pool) to the resume file"""
        import dill

        if not hasattr(self, "ptsampler"):
            logger.debug("Attempted checkpoint before initialization")
            return
        logger.debug("Check point")
        check_directory_exists_and_if_not_mkdir(self.outdir)

        # The pool is detached while pickling and always re-attached, even if
        # writing the file raises, so sampling can continue afterwards.
        _pool = self.ptsampler.pool
        self.ptsampler.pool = None
        try:
            if dill.pickles(self.ptsampler):
                safe_file_dump(self.ptsampler, self.resume_file, dill)
                logger.info("Written checkpoint file {}".format(self.resume_file))
            else:
                logger.warning(
                    "Cannot write pickle resume file! "
                    "Job may not resume if interrupted."
                )
                # Touch the file to postpone next check-point attempt
                Path(self.resume_file).touch(exist_ok=True)
        finally:
            self.ptsampler.pool = _pool

    def print_long_progress(self):
        """Log the detailed progress report and, optionally, create plots"""
        self.print_per_proposal()
        self.print_tau_dict()
        if self.ptsampler.ntemps > 1:
            self.print_pt_acceptance()
        if self.ptsampler.nensemble > 1:
            self.print_ensemble_acceptance()
        if self.check_point_plot:
            self.plot_progress(
                self.ptsampler, self.label, self.outdir, self.priors, self.diagnostic
            )
            self.ptsampler.compute_evidence(
                outdir=self.outdir, label=self.label, make_plots=True
            )

    def print_ensemble_acceptance(self):
        """Log the number of ensemble swaps and the ensemble proposal cycle"""
        logger.info(f"Ensemble swaps = {self.ptsampler.swap_counter['ensemble']}")
        logger.info(self.ptsampler.ensemble_proposal_cycle)

    def print_progress(self):
        """Print a one-line progress summary (only if ``verbose`` is set)"""
        position = self.ptsampler.position

        # Total sampling time
        sampling_time = datetime.timedelta(seconds=self.ptsampler.sampling_time)
        elapsed = str(sampling_time).split(".")[0]

        # Time for last evaluation set
        time_per_eval_ms = (
            1000 * self._time_since_last_print / self._steps_since_last_print
        )

        # Pull out progress summary
        tau = self.ptsampler.tau
        nsamples = self.ptsampler.nsamples
        minimum_index = self.ptsampler.primary_sampler.chain.minimum_index
        method = self.ptsampler.primary_sampler.chain.minimum_index_method
        mindex_str = f"{minimum_index:0.2e}({method})"
        alpha = self.ptsampler.primary_sampler.acceptance_ratio
        maxl = self.ptsampler.primary_sampler.chain.max_log_likelihood

        nlikelihood = position * self.L1steps * self.ptsampler._nsamplers
        eff = 100 * nsamples / nlikelihood

        # Estimated time til finish (ETF)
        if tau < np.inf:
            remaining_samples = self.target_nsamples - nsamples
            remaining_evals = (
                remaining_samples
                * self.convergence_inputs.thin_by_nact
                * tau
                * self.L1steps
            )
            remaining_time_s = time_per_eval_ms * 1e-3 * remaining_evals
            remaining_time_dt = datetime.timedelta(seconds=remaining_time_s)
            if remaining_samples > 0:
                remaining_time = str(remaining_time_dt).split(".")[0]
            else:
                remaining_time = "0"
        else:
            remaining_time = "-"

        msg = (
            f"{position:0.2e}|{elapsed}|{mindex_str}|t={tau:0.0f}|"
            f"n={nsamples:0.0f}|a={alpha:0.2f}|e={eff:0.1e}%|"
            f"{time_per_eval_ms:0.2f}ms/ev|maxl={maxl:0.2f}|"
            f"ETF={remaining_time}"
        )

        if self.pt_rejection_sample:
            count = self.ptsampler.rejection_sampling_count
            # No samples yet means the ratio is undefined: show a placeholder
            # rather than dividing by zero.
            if nsamples > 0:
                rse = 100 * count / nsamples
                msg += f"|rse={rse:0.2f}%"
            else:
                msg += "|rse=-"

        if self.verbose:
            print(msg, flush=True)

    def print_per_proposal(self):
        """Log each proposal of the first sampler's proposal cycle"""
        logger.info("Zero-temperature proposals:")
        for prop in self.ptsampler[0].proposal_cycle.proposal_list:
            logger.info(prop)

    def print_pt_acceptance(self):
        """Log the temperature-swap statistics of every tempered chain"""
        logger.info(f"Temperature swaps = {self.ptsampler.swap_counter['temperature']}")
        for column in self.ptsampler.sampler_list_of_tempered_lists:
            for ii, sampler in enumerate(column):
                total = sampler.pt_accepted + sampler.pt_rejected
                beta = sampler.beta
                if total > 0:
                    ratio = f"{sampler.pt_accepted / total:0.2f}"
                else:
                    ratio = "-"
                logger.info(
                    f"Temp:{ii}<->{ii+1}|"
                    f"beta={beta:0.4g}|"
                    f"hot-samp={sampler.nsamples}|"
                    f"swap={ratio}|"
                    f"conv={sampler.chain.converged}|"
                )

    def print_tau_dict(self):
        """Log the per-parameter tau dictionary of the primary chain"""
        msg = f"Current taus={self.ptsampler.primary_sampler.chain.tau_dict}"
        logger.info(msg)

    @staticmethod
    def plot_progress(ptsampler, label, outdir, priors, diagnostic=False):
        """Plot the chains of the sampler

        Only chains with ``beta == 1`` are plotted unless ``diagnostic`` is
        True, in which case every chain is plotted.
        """
        logger.info("Creating diagnostic plots")
        for row in ptsampler.sampler_dictionary.values():
            for sampler in row:
                plot_label = f"{label}_E{sampler.Eindex}_T{sampler.Tindex}"
                if diagnostic is True or sampler.beta == 1:
                    sampler.chain.plot(
                        outdir=outdir,
                        label=plot_label,
                        priors=priors,
                        all_samples=ptsampler.samples,
                    )


class BilbyPTMCMCSampler(object):
    def __init__(
        self,
        convergence_inputs,
        pt_inputs,
        proposal_cycle,
        pt_rejection_sample,
        pool,
        use_ratio,
        evidence_method,
    ):
        """Build the set of chains and the bookkeeping used while sampling

        Parameters
        ----------
        convergence_inputs:
            Convergence settings, forwarded to every individual sampler and
            the source of L1steps and L2steps.
        pt_inputs:
            Parallel-tempering settings.
        proposal_cycle:
            Passed to each individual sampler on construction.
        pt_rejection_sample: bool
            Whether tempered chains contribute to the samples.
        pool:
            Stored as-is; when truthy its ``map`` is used to step the chains.
        use_ratio: bool
            Forwarded to each individual sampler.
        evidence_method: str
            Key selecting which stored evidence is reported by ln_z.
        """
        # Order matters here: the samplers are built from the pt inputs and
        # use_ratio, and must exist before the convergence inputs are broadcast.
        self.set_pt_inputs(pt_inputs)
        self.use_ratio = use_ratio
        self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
        self.set_convergence_inputs(convergence_inputs)

        self.pt_rejection_sample = pt_rejection_sample
        self.pool = pool
        self.evidence_method = evidence_method

        # Initialize counters; the ensemble L2 counter starts half a period
        # (plus one) ahead of the temperature one.
        self.swap_counter = Counter(
            {
                "temperature": 0,
                "L2-temperature": 0,
                "ensemble": 0,
                "L2-ensemble": int(self.L2steps / 2) + 1,
            }
        )

        self._nsamples_dict = {}
        priors = _sampling_convenience_dump.priors
        self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle(
            priors
        )
        self.sampling_time = 0
        self.ln_z_dict = {}
        self.ln_z_err_dict = {}

    def get_initial_betas(self):
        pt_inputs = self.pt_inputs
        if self.ntemps == 1:
            betas = np.array([1])
        elif pt_inputs.initial_betas is not None:
            betas = np.array(pt_inputs.initial_betas)
        elif pt_inputs.Tmax is not None:
            betas = np.logspace(0, -np.log10(pt_inputs.Tmax), pt_inputs.ntemps)
        elif pt_inputs.Tmax_from_SNR is not None:
            ndim = len(_sampling_convenience_dump.priors.non_fixed_keys)
            target_hot_likelihood = ndim / 2
            Tmax = pt_inputs.Tmax_from_SNR**2 / (2 * target_hot_likelihood)
            betas = np.logspace(0, -np.log10(Tmax), pt_inputs.ntemps)
        else:
            raise SamplerError("Unable to set temperature ladder from inputs")

        if len(betas) != self.ntemps:
            raise SamplerError("Temperatures do not match ntemps")

        return betas

    def setup_sampler_dictionary(self, convergence_inputs, proposal_cycle):
        """Create one row of samplers per temperature

        Rows are stored in ``self.sampler_dictionary`` keyed by temperature
        index. A row holds ``nensemble`` samplers when its beta equals 1 or
        when ``pt_ensemble`` is set, and a single sampler otherwise. The
        total number of samplers is recorded in ``self._nsamplers``.
        """

        betas = self.get_initial_betas()
        logger.info(
            f"Initializing BilbyPTMCMCSampler with:"
            f"ntemps={self.ntemps},"
            f"nensemble={self.nensemble},"
            f"pt_ensemble={self.pt_ensemble},"
            f"initial_betas={betas}\n"
        )

        # Arguments common to every individual sampler
        shared = dict(
            convergence_inputs=convergence_inputs,
            proposal_cycle=proposal_cycle,
            use_ratio=self.use_ratio,
        )

        self.sampler_dictionary = dict()
        for Tindex, beta in enumerate(betas):
            nchains = self.nensemble if (beta == 1 or self.pt_ensemble) else 1
            row = []
            for Eindex in range(nchains):
                row.append(
                    BilbyMCMCSampler(beta=beta, Tindex=Tindex, Eindex=Eindex, **shared)
                )
            self.sampler_dictionary[Tindex] = row

        # Store data
        self._nsamplers = len(self.sampler_list)

    @property
    def sampler_list(self):
        """A list of all individual samplers"""
        return [s for item in self.sampler_dictionary.values() for s in item]

    @sampler_list.setter
    def sampler_list(self, sampler_list):
        for sampler in sampler_list:
            self.sampler_dictionary[sampler.Tindex][sampler.Eindex] = sampler

    def sampler_list_by_column(self, column):
        return [row[column] for row in self.sampler_dictionary.values()]

    @property
    def sampler_list_of_tempered_lists(self):
        if self.pt_ensemble:
            return [self.sampler_list_by_column(ii) for ii in range(self.nensemble)]
        else:
            return [self.sampler_list_by_column(0)]

    @property
    def tempered_sampler_list(self):
        return [s for s in self.sampler_list if s.beta < 1]

    @property
    def zerotemp_sampler_list(self):
        return [s for s in self.sampler_list if s.beta == 1]

    @property
    def primary_sampler(self):
        return self.sampler_dictionary[0][0]

    def set_pt_inputs(self, pt_inputs):
        """Store the parallel-tempering inputs and mirror selected fields

        The full object is kept as ``self.pt_inputs``; the fields used
        directly by this class are also copied onto attributes of the same
        name.
        """
        logger.info(f"Setting parallel tempering inputs={pt_inputs}")
        self.pt_inputs = pt_inputs

        # Pull out only what is needed
        mirrored = ("ntemps", "nensemble", "pt_ensemble", "adapt", "adapt_t0", "adapt_nu")
        for field in mirrored:
            setattr(self, field, getattr(pt_inputs, field))

    def set_convergence_inputs(self, convergence_inputs):
        """Store the convergence inputs and pass them on to every sampler

        ``L1steps`` and ``L2steps`` are additionally copied onto this object.
        """
        logger.info(f"Setting convergence_inputs={convergence_inputs}")
        self.convergence_inputs = convergence_inputs
        self.L1steps, self.L2steps = (
            convergence_inputs.L1steps,
            convergence_inputs.L2steps,
        )
        for individual in self.sampler_list:
            individual.set_convergence_inputs(convergence_inputs)

    @property
    def tau(self):
        return self.primary_sampler.chain.tau

    @property
    def minimum_index(self):
        return self.primary_sampler.chain.minimum_index

    @property
    def nsamples(self):
        pos = self.primary_sampler.chain.position
        if hasattr(self, "_nsamples_dict") is False:
            self._nsamples_dict = {}
        if pos in self._nsamples_dict:
            return self._nsamples_dict[pos]
        logger.debug(f"Calculating nsamples at {pos}")
        self._nsamples_dict[pos] = self._calculate_nsamples()
        return self._nsamples_dict[pos]

    @property
    def nsamples_last(self):
        if len(self._nsamples_dict) > 0:
            return list(self._nsamples_dict.values())[-1]
        else:
            return 0

    @property
    def nsamples_nocache(self):
        for sampler in self.sampler_list:
            sampler.chain.tau_nocache
        pos = self.primary_sampler.chain.position
        self._nsamples_dict[pos] = self._calculate_nsamples()
        return self._nsamples_dict[pos]

    def _calculate_nsamples(self):
        nsamples_list = []
        for sampler in self.zerotemp_sampler_list:
            nsamples_list.append(sampler.nsamples)
        if self.pt_rejection_sample:
            for samp in self.sampler_list[1:]:
                nsamples_list.append(
                    len(samp.rejection_sample_zero_temperature_samples())
                )
        return sum(nsamples_list)

    @property
    def samples(self):
        sample_list = []
        for sampler in self.zerotemp_sampler_list:
            sample_list.append(sampler.samples)
        if self.pt_rejection_sample:
            for sampler in self.tempered_sampler_list:
                sample_list.append(sampler.samples)
        return pd.concat(sample_list)

    @property
    def position(self):
        return self.primary_sampler.chain.position

    @property
    def evaluations(self):
        return int(self.position * len(self.sampler_list))

    def __getitem__(self, index):
        return self.sampler_list[index]

    def step_all_chains(self):
        if self.pool:
            self.sampler_list = self.pool.map(call_step, self.sampler_list)
        else:
            for ii, sampler in enumerate(self.sampler_list):
                self.sampler_list[ii] = sampler.step()

        if self.nensemble > 1 and self.swap_counter["L2-ensemble"] >= self.L2steps:
            self.swap_counter["ensemble"] += 1
            self.swap_counter["L2-ensemble"] = 0
            self.ensemble_step()

        if self.ntemps > 1 and self.swap_counter["L2-temperature"] >= self.L2steps:
            self.swap_counter["temperature"] += 1
            self.swap_counter["L2-temperature"] = 0
            self.swap_tempered_chains()
            if self.position < self.adapt_t0 * 10:
                if self.adapt:
                    self.adapt_temperatures()
            elif self.adapt:
                logger.info(
                    f"Adaptation of temperature chains finished at step {self.position}"
                )
                self.adapt = False

        self.swap_counter["L2-ensemble"] += 1
        self.swap_counter["L2-temperature"] += 1

    @staticmethod
    def _get_sample_to_swap(sampler):
        """Pick the sample a sampler offers for a swap, with its log-likelihood

        The last element of the chain is used unless the chain's
        ``converged`` flag is anything other than ``False``, in which case
        the chain's ``random_sample`` is used.

        Returns
        -------
        tuple
            The chosen sample and its entry under the log-likelihood key.
        """
        picked = (
            sampler.chain[-1]
            if sampler.chain.converged is False
            else sampler.chain.random_sample
        )
        return picked, picked[LOGLKEY]

    def swap_tempered_chains(self):
        """Propose a swap between each pair of adjacent temperatures

        For every ladder (all ensemble columns with ``pt_ensemble``, else
        column 0 only) neighbouring samplers exchange the last element of
        their chains with probability
        ``exp((beta_j - beta_i) * (logl_i - logl_j))``. The outcome is
        recorded on the sampler with the lower temperature index.
        """
        ensemble_columns = range(self.nensemble) if self.pt_ensemble else [0]
        for Eindex in ensemble_columns:
            for Tindex in range(self.ntemps - 1):
                lower = self.sampler_dictionary[Tindex][Eindex]
                upper = self.sampler_dictionary[Tindex + 1][Eindex]

                # The lower-index sample is drawn first, then the upper one
                lower_sample, lower_logl = self._get_sample_to_swap(lower)
                upper_sample, upper_logl = self._get_sample_to_swap(upper)

                beta_gap = upper.beta - lower.beta
                # Overflow just means the swap is certain to be accepted
                with np.errstate(over="ignore"):
                    alpha_swap = np.exp(beta_gap * (lower_logl - upper_logl))

                if not (np.random.uniform(0, 1) <= alpha_swap):
                    lower.pt_rejected += 1
                    continue

                lower.chain[-1] = upper_sample
                upper.chain[-1] = lower_sample
                self.sampler_dictionary[Tindex][Eindex] = lower
                self.sampler_dictionary[Tindex + 1][Eindex] = upper
                lower.pt_accepted += 1

    def ensemble_step(self):
        """Take one ensemble-proposal step for every multi-sampler row

        For each sampler in a row with more than one sampler, a proposal
        from the ensemble proposal cycle is generated using the chains of
        the other samplers in the row, then accepted or rejected with a
        Metropolis-Hastings test at the sampler's beta. Proposals with a
        log-prior of ``-inf`` are rejected without evaluating the
        likelihood.
        """
        for Tindex, row in self.sampler_dictionary.items():
            if len(row) <= 1:
                continue
            for Eindex, sampler in enumerate(row):
                current = sampler.chain.current_sample
                proposal = self.ensemble_proposal_cycle.get_proposal()
                others = [other.chain for other in row if other != sampler]
                candidate, log_factor = proposal(sampler.chain, others)
                logp = sampler.log_prior(candidate)

                if logp == -np.inf:
                    sampler.reject_proposal(current, proposal)
                    self.sampler_dictionary[Tindex][Eindex] = sampler
                    continue

                candidate[LOGPKEY] = logp
                candidate[LOGLKEY] = sampler.log_likelihood(candidate)
                log_alpha = (
                    log_factor
                    + sampler.beta * candidate[LOGLKEY]
                    + candidate[LOGPKEY]
                    - sampler.beta * current[LOGLKEY]
                    - current[LOGPKEY]
                )
                alpha = np.exp(log_alpha)

                accepted = np.random.uniform(0, 1) <= alpha
                if accepted:
                    sampler.accept_proposal(candidate, proposal)
                else:
                    sampler.reject_proposal(current, proposal)
                self.sampler_dictionary[Tindex][Eindex] = sampler

    def adapt_temperatures(self):
        """Adapt the temperature of the chains

        Using the dynamic temperature selection described in arXiv:1501.05823,
        adapt the chains to target a constant swap ratio. This method is based
        on github.com/willvousden/ptemcee/tree/master/ptemcee
        """

        self.primary_sampler.chain.minimum_index_adapt = self.position
        tt = self.swap_counter["temperature"]
        for sampler_list in self.sampler_list_of_tempered_lists:
            betas = np.array([s.beta for s in sampler_list])
            ratios = np.array([s.acceptance_ratio for s in sampler_list[:-1]])

            # Modulate temperature adjustments with a hyperbolic decay.
            decay = self.adapt_t0 / (tt + self.adapt_t0)
            kappa = decay / self.adapt_nu

            # Construct temperature adjustments.
            dSs = kappa * (ratios[:-1] - ratios[1:])

            # Compute new ladder (hottest and coldest chains don't move).
            deltaTs = np.diff(1 / betas[:-1])
            deltaTs *= np.exp(dSs)
            betas[1:-1] = 1 / (np.cumsum(deltaTs) + 1 / betas[0])
            for sampler, beta in zip(sampler_list, betas):
                sampler.beta = beta

    @property
    def ln_z(self):
        """Log-evidence stored for ``self.evidence_method`` (NaN if absent)"""
        estimates = self.ln_z_dict
        return estimates.get(self.evidence_method, np.nan)

    @property
    def ln_z_err(self):
        """Log-evidence uncertainty stored for ``self.evidence_method`` (NaN if absent)"""
        uncertainties = self.ln_z_err_dict
        return uncertainties.get(self.evidence_method, np.nan)

    def compute_evidence(self, outdir, label, make_plots=True):
        """Compute and store the log-evidence with every available method

        Does nothing when only a single temperature is used. Otherwise the
        thermodynamic-integration and stepping-stone estimates are computed
        via ``compute_evidence_per_ensemble``, stored in ``self.ln_z_dict`` and
        ``self.ln_z_err_dict`` under the method name, and logged.

        Parameters
        ----------
        outdir: str
            Directory passed on to the estimators for their diagnostic plots
        label: str
            Label passed on to the estimators for naming their plots
        make_plots: bool
            Passed on to the estimators to switch plotting on or off
        """
        if self.ntemps == 1:
            return

        estimator_kwargs = {"outdir": outdir, "label": label, "make_plots": make_plots}
        estimators = {
            "thermodynamic": self.thermodynamic_integration_evidence,
            "stepping_stone": self.stepping_stone_evidence,
        }
        for name, estimator in estimators.items():
            ln_z, ln_z_err = self.compute_evidence_per_ensemble(
                estimator, estimator_kwargs
            )
            self.ln_z_dict[name] = ln_z
            self.ln_z_err_dict[name] = ln_z_err
            logger.info(
                f"Log-evidence of {ln_z:0.2f}+/-{ln_z_err:0.2f} calculated using {name} method"
            )

    def compute_evidence_per_ensemble(self, method, kwargs):
        """Apply an evidence estimator to every tempered ladder and combine

        Parameters
        ----------
        method: callable
            Called as ``method(ptchain, **kwargs)`` for each entry of
            ``self.sampler_list_of_tempered_lists``; must return a
            ``(ln_z, ln_z_err)`` pair
        kwargs: dict
            Keyword arguments forwarded to ``method``

        Returns
        -------
        lnZ, lnZerr: float
            The combined log-evidence and its uncertainty; both are NaN when
            only a single temperature is used
        """
        from scipy.special import logsumexp

        if self.ntemps == 1:
            return np.nan, np.nan

        ln_z_values = []
        ln_z_err_values = []
        for ptchain in self.sampler_list_of_tempered_lists:
            ln_z, ln_z_err = method(ptchain, **kwargs)
            ln_z_values.append(ln_z)
            ln_z_err_values.append(ln_z_err)

        weight = 1.0 / len(ln_z_values)

        # Log of the equally-weighted mean of exp(ln_z)
        combined_ln_z = logsumexp(ln_z_values, b=weight)

        # Combine the per-ladder uncertainties with the same weighting
        combined_ln_z_err = 0.5 * logsumexp(2 * np.array(ln_z_err_values), b=weight)

        return combined_ln_z, combined_ln_z_err

    def thermodynamic_integration_evidence(
        self, ptchain, outdir, label, make_plots=True
    ):
        """Estimate the log-evidence of one ladder by thermodynamic integration

        For each sampler the log-likelihoods stored from the chain's
        ``minimum_index`` onwards are averaged (burn-in removed, no thinning)
        and the averages are integrated over beta.

        Parameters
        ----------
        ptchain: list
            The samplers making up one tempered ladder
        outdir: str
            Directory in which the diagnostic plot is written
        label: str
            Prefix of the plot label; the ensemble index of the first sampler
            is appended
        make_plots: bool
            If true, create the diagnostic plot

        Returns
        -------
        lnZ, lnZerr: float
            The log-evidence estimate and its uncertainty estimate
        """
        from scipy.stats import sem

        betas, mean_lnlikes, sem_lnlikes = [], [], []
        for sampler in ptchain:
            chain = sampler.chain
            post_burnin = chain.get_1d_array(LOGLKEY)[chain.minimum_index :]
            mean_lnlikes.append(np.mean(post_burnin))
            sem_lnlikes.append(sem(post_burnin))
            betas.append(sampler.beta)

        # Flip the ordering of all three arrays
        betas = np.array(betas)[::-1]
        mean_lnlikes = np.array(mean_lnlikes)[::-1]
        sem_lnlikes = np.array(sem_lnlikes)[::-1]

        lnZ, lnZerr = self._compute_evidence_from_mean_lnlikes(betas, mean_lnlikes)

        if make_plots:
            self._create_lnZ_plots(
                betas=betas,
                mean_lnlikes=mean_lnlikes,
                outdir=outdir,
                label=f"{label}_E{ptchain[0].Eindex}",
                sem_lnlikes=sem_lnlikes,
            )

        return lnZ, lnZerr

    def stepping_stone_evidence(self, ptchain, outdir, label, make_plots=True):
        """
        Compute the evidence using the stepping stone approximation.

        See https://arxiv.org/abs/1810.04488 and
        https://pubmed.ncbi.nlm.nih.gov/21187451/ for details.

        The uncertainty is estimated with a block bootstrap over the thinned
        samples: the evidence is recomputed for a number of resampled
        realisations and the standard deviation of the results is returned.

        Parameters
        ----------
        ptchain: list
            The samplers making up one tempered ladder. The list is reversed
            internally to obtain increasing beta; the caller's list is not
            modified.
        outdir: str
            Directory in which the diagnostic plot is written
        label: str
            Prefix of the plot label; the ensemble index of the first sampler
            (after re-ordering) is appended
        make_plots: bool
            If true, create the diagnostic plot

        Returns
        -------
        ln_z: float
            Estimate of the natural log evidence
        ln_z_err: float
            Estimate of the uncertainty in the evidence

        Both are NaN if fewer than two usable samples are available or if the
        autocorrelation time ``self.tau`` is not finite.
        """
        # Order in increasing beta. Work on a reversed copy: reversing in
        # place would silently re-order the list owned by the caller.
        ptchain = list(ptchain)[::-1]

        # Get maximum usable set of samples across the ptchain
        min_index = max(samp.chain.minimum_index for samp in ptchain)
        all_ln_likes = [samp.chain.get_1d_array(LOGLKEY) for samp in ptchain]
        max_index = min(len(values) for values in all_ln_likes)
        tau = self.tau

        if max_index - min_index <= 1 or not np.isfinite(tau):
            return np.nan, np.nan

        # Read in log likelihoods (one column per sampler, last one dropped)
        ln_likes = np.array(
            [values[min_index:max_index] for values in all_ln_likes]
        )[:-1].T

        # Thin to only independent samples. The step is clamped to one so an
        # autocorrelation time below one cannot produce a zero slice step.
        thin = max(1, int(tau))
        ln_likes = ln_likes[::thin, :]
        steps = ln_likes.shape[0]

        # Calculate delta betas
        betas = np.array([samp.beta for samp in ptchain])

        ln_z, ln_ratio = self._calculate_stepping_stone(betas, ln_likes)

        # Implementation of the bootstrap method described in Maturana-Russel
        # et. al. (2019) to estimate the evidence uncertainty.
        ll = 50  # Block length
        repeats = 100  # Repeats
        n_starts = steps - ll
        ln_z_err = np.nan
        if n_starts > 0:
            ln_z_realisations = []
            try:
                for _ in range(repeats):
                    idxs = [np.random.randint(i, i + ll) for i in range(n_starts)]
                    ln_z_realisations.append(
                        self._calculate_stepping_stone(betas, ln_likes[idxs, :])[0]
                    )
                ln_z_err = np.std(ln_z_realisations)
            except ValueError:
                logger.info("Failed to estimate stepping stone uncertainty")
                ln_z_err = np.nan
        else:
            # With no more thinned samples than the block length there is no
            # start position to draw from, so skip the bootstrap explicitly
            # instead of relying on an error from an empty selection.
            logger.info("Failed to estimate stepping stone uncertainty")

        if make_plots:
            plot_label = f"{label}_E{ptchain[0].Eindex}"
            self._create_stepping_stone_plot(
                means=ln_ratio,
                outdir=outdir,
                label=plot_label,
            )

        return ln_z, ln_z_err

    @staticmethod
    def _calculate_stepping_stone(betas, ln_likes):
        from scipy.special import logsumexp

        n_samples = ln_likes.shape[0]
        d_betas = betas[1:] - betas[:-1]
        ln_ratio = logsumexp(d_betas * ln_likes, axis=0) - np.log(n_samples)
        return sum(ln_ratio), ln_ratio

    @staticmethod
    def _compute_evidence_from_mean_lnlikes(betas, mean_lnlikes):
        lnZ = np.trapz(mean_lnlikes, betas)
        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1], betas[::-1][::2][::-1])
        lnZerr = np.abs(lnZ - z2)
        return lnZ, lnZerr

    def _create_lnZ_plots(self, betas, mean_lnlikes, outdir, label, sem_lnlikes=None):
        """Plot the mean log-likelihood against beta and save it to disk

        The figure is written to ``{outdir}/{label}_beta_lnl.png``. The beta
        axis is logarithmic, so points with a non-positive beta are left out.

        Parameters
        ----------
        betas: array_like
            Inverse temperatures
        mean_lnlikes: array_like
            Mean log-likelihood at each beta
        outdir: str
            Directory in which the figure is saved
        label: str
            Prefix of the figure file name
        sem_lnlikes: array_like, optional
            Standard error of each mean log-likelihood; drawn as error bars
            when given
        """
        import matplotlib.pyplot as plt

        logger.debug("Creating thermodynamic evidence diagnostic plot")

        # Drop betas that cannot be shown on a log axis wherever they sit in
        # the array (checking only the last entry misses a zero beta at the
        # front), and apply the same selection to the error bars so the
        # arrays handed to ``errorbar`` always have matching lengths.
        betas = np.asarray(betas)
        plottable = betas > 0
        x = betas[plottable]
        y = np.asarray(mean_lnlikes)[plottable]

        fig, ax1 = plt.subplots()
        if sem_lnlikes is not None:
            ax1.errorbar(x, y, np.asarray(sem_lnlikes)[plottable], fmt="-")
        else:
            ax1.plot(x, y, "-o")
        ax1.set_xscale("log")
        ax1.set_xlabel(r"$\beta$")
        ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")

        fig.tight_layout()
        fig.savefig("{}/{}_beta_lnl.png".format(outdir, label))
        # Close this specific figure rather than whichever one is current
        plt.close(fig)

    def _create_stepping_stone_plot(self, means, outdir, label):
        """Plot the stepping-stone log ratios and their running sum

        Two stacked panels are drawn against the step number: the per-step
        values in ``means`` and their cumulative sum. The figure is written to
        ``{outdir}/{label}_stepping_stone.png``.

        Parameters
        ----------
        means: array_like
            The per-step log ratios
        outdir: str
            Directory in which the figure is saved
        label: str
            Prefix of the figure file name
        """
        import matplotlib.pyplot as plt

        logger.debug("Creating stepping stone evidence diagnostic plot")

        step_count = len(means)

        fig, (ratio_ax, cumulative_ax) = plt.subplots(nrows=2, figsize=(8, 10))

        # Upper panel: the individual ratios
        ratio_ax.plot(np.arange(1, step_count + 1), means)
        ratio_ax.set_xlabel("$k$")
        ratio_ax.set_ylabel("$r_{k}$")

        # Lower panel: the running total
        cumulative_ax.plot(np.arange(1, step_count + 1), np.cumsum(means))
        cumulative_ax.set_xlabel("$k$")
        cumulative_ax.set_ylabel("Cumulative $\\ln Z$")

        plt.tight_layout()
        fig.savefig("{}/{}_stepping_stone.png".format(outdir, label))
        plt.close()

    @property
    def rejection_sampling_count(self):
        """Total ``rejection_sampling_count`` summed over all samplers

        Returns ``None`` when ``self.pt_rejection_sample`` is falsy.
        """
        if not self.pt_rejection_sample:
            return None
        return sum(
            sampler.rejection_sampling_count
            for column in self.sampler_list_of_tempered_lists
            for sampler in column
        )


class BilbyMCMCSampler(object):
    """A single Metropolis-Hastings chain at a fixed inverse temperature

    The likelihood and priors are read from the module-level
    ``_sampling_convenience_dump``. On construction an initial point is drawn
    from the prior and used to start the chain.

    Parameters
    ----------
    convergence_inputs: namedtuple-like
        Object providing ``_asdict()``, ``target_nsamples`` and
        ``stop_after_convergence``; every field is copied onto the chain
    proposal_cycle: str or proposals.ProposalCycle
        Either the name of a proposal cycle (built with
        ``proposals.get_proposal_cycle``) or a ready-made cycle. Any other
        value, including the default ``None``, raises ``SamplerError``.
    beta: float
        Inverse temperature applied to the log-likelihood
    Tindex: int
        Temperature index of this sampler
    Eindex: int
        Ensemble index of this sampler
    use_ratio: bool
        If true, use the log-likelihood ratio instead of the log-likelihood

    Raises
    ------
    SamplerError
        If ``proposal_cycle`` is neither a string nor a ``ProposalCycle``
    """

    def __init__(
        self,
        convergence_inputs,
        proposal_cycle=None,
        beta=1,
        Tindex=0,
        Eindex=0,
        use_ratio=False,
    ):
        self.beta = beta
        self.Tindex = Tindex
        self.Eindex = Eindex
        self.use_ratio = use_ratio

        priors = _sampling_convenience_dump.priors
        self.parameters = priors.non_fixed_keys
        self.ndim = len(self.parameters)

        # Start from a prior draw restricted to the non-fixed parameters
        full_sample_dict = priors.sample()
        initial_sample = {
            k: v for k, v in full_sample_dict.items() if k in self.parameters
        }
        initial_sample = Sample(initial_sample)
        initial_sample[LOGLKEY] = self.log_likelihood(initial_sample)
        initial_sample[LOGPKEY] = self.log_prior(initial_sample)

        self.chain = Chain(initial_sample=initial_sample)
        self.set_convergence_inputs(convergence_inputs)

        self.accepted = 0
        self.rejected = 0
        self.pt_accepted = 0
        self.pt_rejected = 0
        self.rejection_sampling_count = 0

        if isinstance(proposal_cycle, str):
            # Only print warnings for the primary sampler
            warn = Tindex == 0 and Eindex == 0

            self.proposal_cycle = proposals.get_proposal_cycle(
                proposal_cycle,
                _sampling_convenience_dump.priors,
                L1steps=self.chain.L1steps,
                warn=warn,
            )
        elif isinstance(proposal_cycle, proposals.ProposalCycle):
            self.proposal_cycle = proposal_cycle
        else:
            raise SamplerError("Proposal cycle not understood")

        if self.Tindex == 0 and self.Eindex == 0:
            logger.info(f"Using {self.proposal_cycle}")

    def set_convergence_inputs(self, convergence_inputs):
        """Copy every convergence setting onto the chain

        ``target_nsamples`` and ``stop_after_convergence`` are additionally
        stored on the sampler itself.
        """
        for key, val in convergence_inputs._asdict().items():
            setattr(self.chain, key, val)
        self.target_nsamples = convergence_inputs.target_nsamples
        self.stop_after_convergence = convergence_inputs.stop_after_convergence

    def log_likelihood(self, sample):
        """Evaluate the log-likelihood (or ratio, if ``use_ratio``) at ``sample``

        The shared likelihood's ``parameters`` are updated with
        ``sample.sample_dict`` before the evaluation.
        """
        _sampling_convenience_dump.likelihood.parameters.update(sample.sample_dict)

        if self.use_ratio:
            logl = _sampling_convenience_dump.likelihood.log_likelihood_ratio()
        else:
            logl = _sampling_convenience_dump.likelihood.log_likelihood()

        return logl

    def log_prior(self, sample):
        """Evaluate the log prior probability of ``sample``"""
        return _sampling_convenience_dump.priors.ln_prob(sample.parameter_only_dict)

    def accept_proposal(self, prop, proposal):
        """Append the proposed point and count an acceptance

        Both the sampler's and the proposal's ``accepted`` counters are
        incremented.
        """
        self.chain.append(prop)
        self.accepted += 1
        proposal.accepted += 1

    def reject_proposal(self, curr, proposal):
        """Re-append the current point and count a rejection

        Both the sampler's and the proposal's ``rejected`` counters are
        incremented.
        """
        self.chain.append(curr)
        self.rejected += 1
        proposal.rejected += 1

    def step(self):
        """Run ``chain.L1steps`` internal proposals and store the final state

        Each internal step draws a proposal from the cycle and applies the
        Metropolis-Hastings test with the likelihood tempered by ``beta``.
        Proposals with a non-finite log prior or log likelihood are rejected
        outright. Only the state after the last internal step is appended to
        the chain.

        Returns
        -------
        self
            The sampler itself, unchanged if ``stop_after_convergence`` is set
            and the chain is already flagged as converged
        """
        if self.stop_after_convergence and self.chain.converged:
            return self

        internal_steps = 0
        internal_accepted = 0
        internal_rejected = 0
        curr = self.chain.current_sample.copy()
        while internal_steps < self.chain.L1steps:
            internal_steps += 1
            proposal = self.proposal_cycle.get_proposal()
            prop, log_factor = proposal(self.chain)
            logp = self.log_prior(prop)

            # Skip the likelihood evaluation when the prior rules the point out
            if np.isinf(logp) or np.isnan(logp):
                internal_rejected += 1
                proposal.rejected += 1
                continue

            prop[LOGPKEY] = logp
            prop[LOGLKEY] = self.log_likelihood(prop)

            if np.isinf(prop[LOGLKEY]) or np.isnan(prop[LOGLKEY]):
                internal_rejected += 1
                proposal.rejected += 1
                continue

            # An overflowing exponential just means certain acceptance
            with np.errstate(over="ignore"):
                alpha = np.exp(
                    log_factor
                    + self.beta * prop[LOGLKEY]
                    + prop[LOGPKEY]
                    - self.beta * curr[LOGLKEY]
                    - curr[LOGPKEY]
                )

            if np.random.uniform(0, 1) <= alpha:
                internal_accepted += 1
                proposal.accepted += 1
                curr = prop
                self.chain.current_sample = curr
            else:
                internal_rejected += 1
                proposal.rejected += 1

        self.chain.append(curr)
        self.rejected += internal_rejected
        self.accepted += internal_accepted
        return self

    @property
    def nsamples(self):
        """Number of samples reported by the chain

        As a side effect the chain is flagged as converged the first time
        this number exceeds ``target_nsamples``.
        """
        nsamples = self.chain.nsamples
        if nsamples > self.target_nsamples and self.chain.converged is False:
            logger.debug(f"Temperature {self.Tindex} chain reached convergence")
            self.chain.converged = True
        return nsamples

    @property
    def acceptance_ratio(self):
        """Fraction of counted proposals that were accepted

        Raises ``ZeroDivisionError`` if nothing has been counted yet.
        """
        return self.accepted / (self.accepted + self.rejected)

    @property
    def samples(self):
        """Samples of this chain

        For ``beta == 1`` these are the chain's own samples; otherwise the
        samples are rejection sampled with
        ``rejection_sample_zero_temperature_samples``.
        """
        if self.beta == 1:
            return self.chain.samples
        else:
            return self.rejection_sample_zero_temperature_samples(print_message=True)

    def rejection_sample_zero_temperature_samples(self, print_message=False):
        """Rejection sample this chain's samples with weights L^(1 - beta)

        Parameters
        ----------
        print_message: bool
            If true, log the number of samples obtained

        Returns
        -------
        samples: pandas.DataFrame
            The rejection-sampled rows, or the empty frame if the chain holds
            no usable samples. The stored chain values are left untouched.
        """
        beta = self.beta
        chain = self.chain
        hot_samples = pd.DataFrame(
            chain._chain_array[chain.minimum_index : chain.position], columns=chain.keys
        )
        if len(hot_samples) == 0:
            logger.debug(
                f"Rejection sampling for Temp {self.Tindex} failed: "
                "no usable hot samples"
            )
            return hot_samples

        # Pull out log likelihood
        zerotemp_logl = hot_samples[LOGLKEY]

        # Revert to true likelihood if needed. Build a new Series rather than
        # adding in place: an in-place update can write through to
        # ``hot_samples`` and, since the frame is built from a slice of the
        # chain array, possibly to the chain's own storage. The offset is a
        # constant, so it cancels in the normalised weights below.
        if _sampling_convenience_dump.use_ratio:
            zerotemp_logl = (
                zerotemp_logl
                + _sampling_convenience_dump.likelihood.noise_log_likelihood()
            )

        # Calculate normalised weights
        log_weights = (1 - beta) * zerotemp_logl
        max_weight = np.max(log_weights)
        unnormalised_weights = np.exp(log_weights - max_weight)
        weights = unnormalised_weights / np.sum(unnormalised_weights)

        # Rejection sample
        samples = rejection_sample(hot_samples, weights)

        # Logging
        self.rejection_sampling_count = len(samples)

        if print_message:
            logger.info(
                f"Rejection sampling Temp {self.Tindex}, beta={beta:0.2f} "
                f"yielded {len(samples)} samples"
            )
        return samples


# Methods used to aid parallelisation:


def call_step(sampler):
    """Call ``sampler.step()`` and hand back whatever it returns"""
    return sampler.step()