Code owners
Assign users and groups as approvers for specific file changes. Learn more.
sampler_run_test.py 5.81 KiB
import multiprocessing
import os
import sys
import threading
import time
from signal import SIGINT
multiprocessing.set_start_method("fork") # noqa
import unittest
import pytest
from parameterized import parameterized
import shutil
import bilby
import numpy as np
# Keyword arguments forwarded to ``bilby.run_sampler`` for each sampler under
# test. The keys are also used to parameterize the test methods, so their
# order determines the order in which the per-sampler tests are generated.
_sampler_kwargs = {
    "bilby_mcmc": {"nsamples": 200, "printdt": 1},
    "cpnest": {"nlive": 100},
    "dnest4": {
        "max_num_levels": 2,
        "num_steps": 10,
        "new_level_interval": 10,
        "num_per_step": 10,
        "thread_steps": 1,
        "num_particles": 50,
        "max_pool": 1,
    },
    "dynesty": {
        "nlive": 10,
        "sample": "acceptance-walk",
        "nact": 5,
        "proposals": ["diff"],
    },
    "dynamic_dynesty": {
        "nlive_init": 10,
        "nlive_batch": 10,
        "dlogz_init": 1.0,
        "maxbatch": 0,
        "maxcall": 100,
        "sample": "act-walk",
    },
    "emcee": {"iterations": 1000, "nwalkers": 10},
    "kombine": {"iterations": 200, "nwalkers": 10, "autoburnin": False},
    "nessai": {
        "nlive": 100,
        "poolsize": 100,
        "max_iteration": 500,
    },
    "nestle": {"nlive": 100},
    "ptemcee": {
        "nsamples": 100,
        "nwalkers": 50,
        "burn_in_act": 1,
        "ntemps": 1,
        "frac_threshold": 0.5,
    },
    "PTMCMCSampler": {"Niter": 101, "burn": 100, "covUpdate": 100, "isave": 100},
    "pymc": {"draws": 50, "tune": 50, "n_init": 250},
    "pymultinest": {"nlive": 100},
    "ultranest": {"nlive": 100, "temporary_directory": False},
    "zeus": {"nwalkers": 10, "iterations": 100},
}
# Module name to import-check for samplers whose importable module differs
# from the sampler name. Samplers not listed here are checked under their own
# name (the tests use ``sampler_imports.get(sampler, sampler)``).
sampler_imports = {
    "bilby_mcmc": "bilby",
    "dynamic_dynesty": "dynesty",
}
# Lower-cased sampler names for which the pooled (pool_size > 1) tests are
# skipped.
no_pool_test = [
    "dnest4",
    "pymultinest",
    "nestle",
    "ptmcmcsampler",
    "ultranest",
    "pymc",
]
# Map each implemented sampler name to the object returned by ``.load()`` on
# its registry entry; the interrupt tests read ``hard_exit`` from these.
loaded_samplers = {
    sampler_name: registered.load()
    for sampler_name, registered in bilby.core.sampler.IMPLEMENTED_SAMPLERS.items()
}
def slow_func(x, m, c):
    """Evaluate the straight line ``m * x + c`` after a 10 ms pause.

    The interrupt tests swap this in for the likelihood's model function so
    that each evaluation takes noticeably longer.
    """
    time.sleep(0.01)
    scaled = m * x
    return scaled + c
def model(x, m, c):
    """Return the straight line with slope ``m`` and intercept ``c`` at ``x``."""
    scaled = m * x
    return scaled + c
class TestRunningSamplers(unittest.TestCase):
    """Check that every configured sampler runs and can be interrupted.

    Each test fits the straight-line ``model`` to a small noisy data set,
    using the per-sampler keyword arguments stored in ``_sampler_kwargs``.
    """

    def setUp(self):
        """Create the data, likelihood and priors and a fresh ``outdir``."""
        bilby.core.utils.random.seed(42)
        bilby.core.utils.command_line_args.bilby_test_mode = False
        rng = bilby.core.utils.random.rng
        self.x = np.linspace(0, 1, 11)
        self.injection_parameters = dict(m=0.5, c=0.2)
        self.sigma = 0.1
        # Simulated observations: the injected line plus Gaussian noise.
        self.y = model(self.x, **self.injection_parameters) + rng.normal(
            0, self.sigma, len(self.x)
        )
        self.likelihood = bilby.likelihood.GaussianLikelihood(
            self.x, self.y, model, self.sigma
        )
        self.priors = bilby.core.prior.PriorDict()
        self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
        self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
        # Start every test from an empty output directory.
        self._remove_tree()
        bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")

    @staticmethod
    def conversion_function(parameters, likelihood, prior):
        """Return a copy of ``parameters`` with a ``derived`` entry added.

        ``derived`` is set to ``m * c`` unless it is already present.
        ``likelihood`` and ``prior`` are accepted but not used here.
        """
        converted = parameters.copy()
        if "derived" not in converted:
            converted["derived"] = converted["m"] * converted["c"]
        return converted

    def tearDown(self):
        """Drop the per-test objects, reset test mode and delete ``outdir``."""
        del self.likelihood
        del self.priors
        bilby.core.utils.command_line_args.bilby_test_mode = False
        self._remove_tree()

    def _remove_tree(self):
        """Delete ``outdir`` if possible; failures are deliberately ignored."""
        shutil.rmtree("outdir", ignore_errors=True)

    @parameterized.expand(_sampler_kwargs.keys())
    def test_run_sampler_single(self, sampler):
        """Run ``sampler`` with ``npool=1``."""
        self._run_sampler(sampler, pool_size=1)

    @parameterized.expand(_sampler_kwargs.keys())
    def test_run_sampler_pool(self, sampler):
        """Run ``sampler`` with ``npool=2``."""
        self._run_sampler(sampler, pool_size=2)

    def _run_sampler(self, sampler, pool_size, **extra_kwargs):
        """Run ``sampler`` through ``bilby.run_sampler`` and check the result.

        The test is skipped when the sampler's module cannot be imported, or
        when ``pool_size > 1`` and the sampler is listed in ``no_pool_test``.
        ``extra_kwargs`` are forwarded to ``bilby.run_sampler`` in addition to
        the sampler's entry in ``_sampler_kwargs``.
        """
        pytest.importorskip(sampler_imports.get(sampler, sampler))
        if pool_size > 1 and sampler.lower() in no_pool_test:
            pytest.skip(f"{sampler} cannot be parallelized")
        bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
        kwargs = _sampler_kwargs[sampler]
        res = bilby.run_sampler(
            likelihood=self.likelihood,
            priors=self.priors,
            sampler=sampler,
            save="hdf5",
            npool=pool_size,
            conversion_function=self.conversion_function,
            **kwargs,
            **extra_kwargs,
        )
        # The conversion function must have been applied to the posterior.
        assert "derived" in res.posterior
        # NOTE(review): dnest4 is exempt from this check; presumably it does
        # not record likelihood evaluations - confirm.
        if sampler != "dnest4":
            assert res.log_likelihood_evaluations is not None

    @parameterized.expand(_sampler_kwargs.keys())
    def test_interrupt_sampler_single(self, sampler):
        """Interrupt ``sampler`` while it runs with ``npool=1``."""
        self._run_with_signal_handling(sampler, pool_size=1)

    @parameterized.expand(_sampler_kwargs.keys())
    def test_interrupt_sampler_pool(self, sampler):
        """Interrupt ``sampler`` while it runs with ``npool=2``."""
        self._run_with_signal_handling(sampler, pool_size=2)

    def _run_with_signal_handling(self, sampler, pool_size=1):
        """Send SIGINT to this process mid-run and check the sampler exits.

        A background thread sends ``SIGINT`` to the current process after four
        seconds while the sampler is run repeatedly with ``exit_code=5``. The
        test passes if ``KeyboardInterrupt`` is raised, or ``SystemExit`` with
        code 5.
        """
        pytest.importorskip(sampler_imports.get(sampler, sampler))
        if loaded_samplers[sampler.lower()].hard_exit:
            pytest.skip(f"{sampler} hard exits, can't test signal handling.")
        if pool_size > 1 and sampler.lower() in no_pool_test:
            pytest.skip(f"{sampler} cannot be parallelized")
        # ``lower`` must be called: comparing the bound method itself with a
        # string is always False, so this skip could never trigger.
        if sys.version_info[:2] == (3, 8) and sampler.lower() == "cpnest":
            pytest.skip("Pool interrupting broken for cpnest with py3.8")
        pid = os.getpid()
        print(sampler)

        def trigger_signal():
            # You could do something more robust, e.g. wait until port is listening
            time.sleep(4)
            os.kill(pid, SIGINT)

        thread = threading.Thread(target=trigger_signal)
        thread.daemon = True
        thread.start()

        # Slow each model evaluation down so the sampler is still running when
        # the signal arrives.
        self.likelihood._func = slow_func

        with self.assertRaises((SystemExit, KeyboardInterrupt)):
            try:
                # Keep re-running in case a run completes before the signal.
                while True:
                    self._run_sampler(sampler=sampler, pool_size=pool_size, exit_code=5)
            except SystemExit as error:
                self.assertEqual(error.code, 5)
                raise
# Run the test suite through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()