Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
sampler_run_test.py 5.81 KiB
import multiprocessing
import os
import sys
import threading
import time
from signal import SIGINT

multiprocessing.set_start_method("fork")  # noqa

import unittest
import pytest
from parameterized import parameterized
import shutil

import bilby
import numpy as np


_sampler_kwargs = dict(
    bilby_mcmc=dict(nsamples=200, printdt=1),
    cpnest=dict(nlive=100),
    dnest4=dict(
        max_num_levels=2,
        num_steps=10,
        new_level_interval=10,
        num_per_step=10,
        thread_steps=1,
        num_particles=50,
        max_pool=1,
    ),
    dynesty=dict(nlive=10, sample="acceptance-walk", nact=5, proposals=["diff"]),
    dynamic_dynesty=dict(
        nlive_init=10,
        nlive_batch=10,
        dlogz_init=1.0,
        maxbatch=0,
        maxcall=100,
        sample="act-walk",
    ),
    emcee=dict(iterations=1000, nwalkers=10),
    kombine=dict(iterations=200, nwalkers=10, autoburnin=False),
    nessai=dict(
        nlive=100,
        poolsize=100,
        max_iteration=500,
    ),
    nestle=dict(nlive=100),
    ptemcee=dict(
        nsamples=100,
        nwalkers=50,
        burn_in_act=1,
        ntemps=1,
        frac_threshold=0.5,
    ),
    PTMCMCSampler=dict(Niter=101, burn=100, covUpdate=100, isave=100),
    pymc=dict(draws=50, tune=50, n_init=250),
    pymultinest=dict(nlive=100),
    ultranest=dict(nlive=100, temporary_directory=False),
    zeus=dict(nwalkers=10, iterations=100)
)

# Samplers whose importable module name differs from the sampler name; the
# mapped module is what gets passed to ``pytest.importorskip``. Samplers not
# listed here are imported under their own name.
sampler_imports = {
    "bilby_mcmc": "bilby",
    "dynamic_dynesty": "dynesty",
}

# Lower-cased sampler names that are skipped in the pooled (npool > 1) tests.
no_pool_test = [
    "dnest4",
    "pymultinest",
    "nestle",
    "ptmcmcsampler",
    "ultranest",
    "pymc",
]

# Resolve every registered sampler once, keyed by its registered name, so the
# tests can inspect class attributes such as ``hard_exit``.
# NOTE(review): ``.load()`` presumably imports each sampler via an entry point; confirm.
loaded_samplers = {
    name: entry.load()
    for name, entry in bilby.core.sampler.IMPLEMENTED_SAMPLERS.items()
}


def slow_func(x, m, c):
    """Evaluate the straight line ``m * x + c`` after a 10 ms pause.

    The pause makes each likelihood evaluation slow, so a sampler is still
    running when the interrupt-handling tests send their signal.
    """
    pause_seconds = 0.01
    time.sleep(pause_seconds)
    prediction = m * x + c
    return prediction


def model(x, m, c):
    """Straight-line model ``m * x + c``, used both to simulate and to fit data."""
    line = m * x + c
    return line


class TestRunningSamplers(unittest.TestCase):
    """Smoke tests that run each configured sampler on a toy straight-line fit.

    Every sampler listed in ``_sampler_kwargs`` is run serially
    (``npool=1``) and with a two-process pool. Each mode has a plain run and
    a run that is interrupted with SIGINT to exercise signal handling.
    """

    def setUp(self):
        """Build a seeded straight-line dataset, likelihood, priors and a clean ``outdir``."""
        bilby.core.utils.random.seed(42)
        bilby.core.utils.command_line_args.bilby_test_mode = False
        rng = bilby.core.utils.random.rng
        self.x = np.linspace(0, 1, 11)
        self.injection_parameters = dict(m=0.5, c=0.2)
        self.sigma = 0.1
        # Simulated data: the injected line plus Gaussian noise of width sigma.
        self.y = model(self.x, **self.injection_parameters) + rng.normal(
            0, self.sigma, len(self.x)
        )
        self.likelihood = bilby.likelihood.GaussianLikelihood(
            self.x, self.y, model, self.sigma
        )

        self.priors = bilby.core.prior.PriorDict()
        # The two parameters get different boundary types (periodic and
        # reflective), so samplers that use prior boundaries see both.
        self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
        self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
        self._remove_tree()
        bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")

    @staticmethod
    def conversion_function(parameters, likelihood, prior):
        """Return a copy of ``parameters`` with ``derived = m * c`` added.

        An existing ``derived`` entry is left unchanged. ``likelihood`` and
        ``prior`` are accepted to match the expected call signature but are
        unused.
        """
        converted = parameters.copy()
        if "derived" not in converted:
            converted["derived"] = converted["m"] * converted["c"]
        return converted

    def tearDown(self):
        """Drop per-test objects, reset bilby test mode and remove ``outdir``."""
        del self.likelihood
        del self.priors
        bilby.core.utils.command_line_args.bilby_test_mode = False
        self._remove_tree()

    def _remove_tree(self):
        """Delete ``outdir`` recursively; a missing directory or removal error is ignored."""
        shutil.rmtree("outdir", ignore_errors=True)

    @parameterized.expand(_sampler_kwargs.keys())
    def test_run_sampler_single(self, sampler):
        self._run_sampler(sampler, pool_size=1)

    @parameterized.expand(_sampler_kwargs.keys())
    def test_run_sampler_pool(self, sampler):
        self._run_sampler(sampler, pool_size=2)

    def _run_sampler(self, sampler, pool_size, **extra_kwargs):
        """Run ``sampler`` once on the toy problem and sanity-check the result.

        Parameters
        ----------
        sampler : str
            Key into ``_sampler_kwargs``; also the name passed to bilby.
        pool_size : int
            Value passed as ``npool``. Pooled runs are skipped for samplers
            listed in ``no_pool_test``.
        **extra_kwargs
            Additional keyword arguments forwarded to ``bilby.run_sampler``.

        The test is skipped if the sampler's module cannot be imported.
        """
        pytest.importorskip(sampler_imports.get(sampler, sampler))
        if pool_size > 1 and sampler.lower() in no_pool_test:
            pytest.skip(f"{sampler} cannot be parallelized")
        bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
        kwargs = _sampler_kwargs[sampler]
        res = bilby.run_sampler(
            likelihood=self.likelihood,
            priors=self.priors,
            sampler=sampler,
            save="hdf5",
            npool=pool_size,
            conversion_function=self.conversion_function,
            **kwargs,
            **extra_kwargs,
        )
        # "derived" is added by conversion_function, so its presence shows the
        # conversion was applied to the posterior.
        assert "derived" in res.posterior
        # NOTE(review): dnest4 is exempted, presumably because it does not
        # record log-likelihood evaluations; confirm.
        if sampler != "dnest4":
            assert res.log_likelihood_evaluations is not None

    @parameterized.expand(_sampler_kwargs.keys())
    def test_interrupt_sampler_single(self, sampler):
        self._run_with_signal_handling(sampler, pool_size=1)

    @parameterized.expand(_sampler_kwargs.keys())
    def test_interrupt_sampler_pool(self, sampler):
        self._run_with_signal_handling(sampler, pool_size=2)

    def _run_with_signal_handling(self, sampler, pool_size=1):
        """Run ``sampler`` repeatedly until a SIGINT sent to this process stops it.

        A timer thread sends SIGINT to the current process after 4 seconds.
        The likelihood function is swapped for ``slow_func`` so the sampler
        is still running when the signal arrives. The run must end with
        either ``KeyboardInterrupt`` or ``SystemExit`` whose code equals the
        requested ``exit_code`` (5).

        Parameters
        ----------
        sampler : str
            Key into ``_sampler_kwargs``.
        pool_size : int
            Value passed as ``npool``.
        """
        pytest.importorskip(sampler_imports.get(sampler, sampler))
        if loaded_samplers[sampler.lower()].hard_exit:
            pytest.skip(f"{sampler} hard exits, can't test signal handling.")
        if pool_size > 1 and sampler.lower() in no_pool_test:
            pytest.skip(f"{sampler} cannot be parallelized")
        # ``lower`` must be called: comparing the bound method itself to a
        # string is always False and would never trigger this skip.
        if sys.version_info[:2] == (3, 8) and sampler.lower() == "cpnest":
            pytest.skip("Pool interrupting broken for cpnest with py3.8")
        pid = os.getpid()
        print(sampler)

        # Send SIGINT to this process after a fixed delay. Waiting for the
        # sampler to report it is running would be more robust than a fixed
        # delay. A cancellable Timer is used so that if the run fails early
        # for some other reason, the pending signal is cancelled and does not
        # interrupt whichever test runs next.
        timer = threading.Timer(4, os.kill, args=(pid, SIGINT))
        timer.daemon = True
        timer.start()

        # Slow every likelihood call so the sampler is mid-run at signal time.
        self.likelihood._func = slow_func

        try:
            with self.assertRaises((SystemExit, KeyboardInterrupt)):
                try:
                    # Keep sampling until the signal ends the loop via an
                    # exception.
                    while True:
                        self._run_sampler(
                            sampler=sampler, pool_size=pool_size, exit_code=5
                        )
                except SystemExit as error:
                    self.assertEqual(error.code, 5)
                    raise
        finally:
            timer.cancel()


if __name__ == "__main__":
    unittest.main()