Skip to content
Snippets Groups Projects
pymc3_test.py 2.61 KiB
Newer Older
Moritz Huebner's avatar
Moritz Huebner committed
import unittest
Colm Talbot's avatar
Colm Talbot committed
import pytest
import sys
Moritz Huebner's avatar
Moritz Huebner committed

from mock import MagicMock

import bilby


Colm Talbot's avatar
Colm Talbot committed
# Compare the full (major, minor) version tuple rather than only the minor
# component, so the condition does not depend on the major version number.
@pytest.mark.skipif(sys.version_info < (3, 7), reason="pymc3 is broken in py36")
@pytest.mark.xfail(
    raises=AttributeError,
    reason="Dependency issue with pymc3 causes attribute error on import",
)
class TestPyMC3(unittest.TestCase):
    """Tests for the keyword-argument handling of ``bilby.core.sampler.Pymc3``.

    The likelihood is a ``MagicMock`` and ``skip_import_verification=True`` is
    passed, so these tests only inspect the sampler's ``kwargs``; no sampling
    is run.
    """

    def setUp(self):
        """Build a ``Pymc3`` sampler over two independent Uniform(0, 1) priors."""
        self.likelihood = MagicMock()
        self.priors = bilby.core.prior.PriorDict(
            dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
        )
        self.sampler = bilby.core.sampler.Pymc3(
            self.likelihood,
            self.priors,
            outdir="outdir",
            label="label",
            use_ratio=False,
            plot=False,
            skip_import_verification=True,
        )

    def tearDown(self):
        """Drop the per-test fixtures."""
        del self.likelihood
        del self.priors
        del self.sampler

    def _expected_default_kwargs(self):
        """Return the kwargs a freshly constructed sampler is expected to hold.

        The top-level defaults are spelled out explicitly so that an unintended
        change to them in ``Pymc3`` makes the tests fail. The NUTS and step
        defaults are merged in from the sampler's own ``default_nuts_kwargs``
        and ``default_step_kwargs``.
        """
        expected = dict(
            draws=500,
            step=None,
            init="auto",
            n_init=200000,
            start=None,
            trace=None,
            chain_idx=0,
            chains=2,
            cores=1,
            tune=500,
            progressbar=True,
            model=None,
            nuts_kwargs=None,
            step_kwargs=None,
            random_seed=None,
            discard_tuned_samples=True,
            compute_convergence_checks=True,
        )
        expected.update(self.sampler.default_nuts_kwargs)
        expected.update(self.sampler.default_step_kwargs)
        return expected

    def test_default_kwargs(self):
        """A sampler built without extra kwargs exposes exactly the defaults."""
        self.assertDictEqual(self._expected_default_kwargs(), self.sampler.kwargs)

    def test_translate_kwargs(self):
        """Each ``npoints``-style alias is translated into ``draws``."""
        # Use a value that differs from the default draws=500. If an alias were
        # dropped instead of translated, ``draws`` would presumably come back
        # as its default, and a test that used 500 could not tell the two apart.
        n_draws = 123
        expected = self._expected_default_kwargs()
        expected["draws"] = n_draws
        aliases = bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs
        for equiv in aliases:
            # subTest reports which alias failed instead of only the first one.
            with self.subTest(equiv=equiv):
                new_kwargs = self.sampler.kwargs.copy()
                # Remove ``draws`` so the alias is the only source of its value.
                del new_kwargs["draws"]
                new_kwargs[equiv] = n_draws
                self.sampler.kwargs = new_kwargs
                self.assertDictEqual(expected, self.sampler.kwargs)


if __name__ == "__main__":
    # Running this file directly executes its test cases with the stdlib runner;
    # "__main__" is unittest.main's default module, passed here explicitly.
    unittest.main(module="__main__")