Newer
Older
import unittest
from mock import MagicMock, patch, mock_open
import bilby
class TestNessai(unittest.TestCase):
    """Tests for how the bilby ``Nessai`` sampler wrapper processes kwargs.

    Every test assigns a modified dictionary to ``self.sampler.kwargs`` and
    compares the dictionary read back from the sampler against an expected
    dictionary derived from the baseline built in ``setUp``.
    """

    # Do not truncate the diff printed when ``assertDictEqual`` fails.
    maxDiff = None

    def setUp(self):
        """Build a sampler with a mocked likelihood and two uniform priors."""
        self.likelihood = MagicMock()
        self.priors = bilby.core.prior.PriorDict(
            dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
        )
        self.sampler = bilby.core.sampler.Nessai(
            self.likelihood,
            self.priors,
            outdir="outdir",
            label="label",
            use_ratio=False,
            plot=False,
            skip_import_verification=True,
        )
        # Baseline expectation: the sampler's defaults plus the output path
        # (presumably composed from ``outdir`` and ``label`` by the sampler).
        self.expected = self.sampler.default_kwargs
        self.expected['output'] = 'outdir/label_nessai/'

    def tearDown(self):
        """Drop the fixtures created in ``setUp``."""
        del self.likelihood
        del self.priors
        del self.sampler
        del self.expected

    def test_translate_kwargs_nlive(self):
        """Each alias for the number of live points ends up as ``nlive``."""
        expected = self.expected.copy()
        # nlive in the default kwargs is not fixed but depends on the
        # version of nessai, so get the value here and use it when setting
        # the equivalent kwargs.
        nlive = expected["nlive"]
        for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
            new_kwargs = self.sampler.kwargs.copy()
            del new_kwargs["nlive"]
            # Supply the value under the alias so the translation from
            # ``equiv`` to ``nlive`` is what the assertion exercises.
            new_kwargs[equiv] = nlive
            self.sampler.kwargs = new_kwargs
            self.assertDictEqual(expected, self.sampler.kwargs)

    def test_translate_kwargs_npool(self):
        """Each alias for the pool size ends up as ``n_pool``."""
        expected = self.expected.copy()
        expected["n_pool"] = None
        for equiv in bilby.core.sampler.base_sampler.NestedSampler.npool_equiv_kwargs:
            new_kwargs = self.sampler.kwargs.copy()
            del new_kwargs["n_pool"]
            new_kwargs[equiv] = None
            self.sampler.kwargs = new_kwargs
            self.assertDictEqual(expected, self.sampler.kwargs)

    def test_translate_kwargs_seed(self):
        """Each alias for the random seed ends up as ``seed``."""
        expected = self.expected.copy()
        expected["seed"] = 150914
        for equiv in bilby.core.sampler.nessai.Nessai.seed_equiv_kwargs:
            new_kwargs = self.sampler.kwargs.copy()
            del new_kwargs["seed"]
            new_kwargs[equiv] = 150914
            self.sampler.kwargs = new_kwargs
            self.assertDictEqual(expected, self.sampler.kwargs)

    def test_npool_max_threads(self):
        """Passing ``n_pool=1`` is expected to be read back as ``None``."""
        expected = self.expected.copy()
        expected["n_pool"] = None
        new_kwargs = self.sampler.kwargs.copy()
        new_kwargs["n_pool"] = 1
        self.sampler.kwargs = new_kwargs
        self.assertDictEqual(expected, self.sampler.kwargs)

    @patch("builtins.open", mock_open(read_data='{"nlive": 4000}'))
    def test_update_from_config_file(self):
        """Values from a ``config_file`` are merged into the sampler kwargs.

        ``open`` is patched so that any file read yields ``{"nlive": 4000}``.
        """
        expected = self.expected.copy()
        # The mocked config file sets nlive, so that value must show up in
        # the kwargs read back from the sampler.
        expected["nlive"] = 4000
        new_kwargs = self.expected.copy()
        new_kwargs["config_file"] = "config_file.json"
        self.sampler.kwargs = new_kwargs
        # NOTE(review): equality with ``expected`` also implies the sampler
        # removes the "config_file" key after loading it - confirm.
        self.assertDictEqual(expected, self.sampler.kwargs)
if __name__ == "__main__":
    # Run the test cases in this module when it is executed as a script.
    unittest.main()