"""Unit tests for the bilby interface to the nessai sampler (nessai_test.py)."""
import unittest

from mock import MagicMock, patch, mock_open

import bilby


class TestNessai(unittest.TestCase):
    maxDiff = None

    def setUp(self):
        self.likelihood = MagicMock()
        self.priors = bilby.core.prior.PriorDict(
            dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
        )
        self.sampler = bilby.core.sampler.Nessai(
            self.likelihood,
            self.priors,
            outdir="outdir",
            label="label",
            use_ratio=False,
            plot=False,
            skip_import_verification=True,
        )
        self.expected = self.sampler.default_kwargs
        self.expected['output'] = 'outdir/label_nessai/'
Michael Williams's avatar
Michael Williams committed

    def tearDown(self):
        del self.likelihood
        del self.priors
        del self.sampler
        del self.expected

    def test_translate_kwargs_nlive(self):
        expected = self.expected.copy()
        # nlive in the default kwargs is not a fixed but depends on the
        # version of nessai, so get the value here and use it when setting
        # the equivalent kwargs.
        nlive = expected["nlive"]
Michael Williams's avatar
Michael Williams committed
        for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
            new_kwargs = self.sampler.kwargs.copy()
            del new_kwargs["nlive"]
            new_kwargs[equiv] = nlive
Michael Williams's avatar
Michael Williams committed
            self.sampler.kwargs = new_kwargs
            self.assertDictEqual(expected, self.sampler.kwargs)

    def test_translate_kwargs_npool(self):
        expected = self.expected.copy()
        expected["n_pool"] = None
        for equiv in bilby.core.sampler.base_sampler.NestedSampler.npool_equiv_kwargs:
            new_kwargs = self.sampler.kwargs.copy()
            del new_kwargs["n_pool"]
            new_kwargs[equiv] = None
            self.sampler.kwargs = new_kwargs
            self.assertDictEqual(expected, self.sampler.kwargs)

    def test_translate_kwargs_seed(self):
        expected = self.expected.copy()
        expected["seed"] = 150914
        for equiv in bilby.core.sampler.nessai.Nessai.seed_equiv_kwargs:
            new_kwargs = self.sampler.kwargs.copy()
            del new_kwargs["seed"]
            new_kwargs[equiv] = 150914
            self.sampler.kwargs = new_kwargs
            self.assertDictEqual(expected, self.sampler.kwargs)

    def test_npool_max_threads(self):
        expected = self.expected.copy()
        expected["n_pool"] = None
        new_kwargs = self.sampler.kwargs.copy()
        new_kwargs["n_pool"] = 1
        self.sampler.kwargs = new_kwargs
        self.assertDictEqual(expected, self.sampler.kwargs)

    @patch("builtins.open", mock_open(read_data='{"nlive": 4000}'))
Michael Williams's avatar
Michael Williams committed
    def test_update_from_config_file(self):
        expected = self.expected.copy()
        expected["nlive"] = 4000
Michael Williams's avatar
Michael Williams committed
        new_kwargs = self.expected.copy()
        new_kwargs["config_file"] = "config_file.json"
        self.sampler.kwargs = new_kwargs
        self.assertDictEqual(expected, self.sampler.kwargs)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()