Newer
Older
# Compare the full (major, minor) tuple rather than only the minor component,
# so the skip condition cannot misfire under a different major version.
@pytest.mark.skipif(sys.version_info < (3, 7), reason="pymc3 is broken in py36")
@pytest.mark.xfail(
    raises=AttributeError,
    reason="Dependency issue with pymc3 causes attribute error on import",
)
class TestPyMC3(unittest.TestCase):
    """Tests for keyword-argument defaults and translation of ``Pymc3``."""

    def setUp(self):
        # These tests only inspect the sampler's kwargs, so a mock likelihood
        # is sufficient.
        self.likelihood = MagicMock()
        self.priors = bilby.core.prior.PriorDict(
            dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
        )
        self.sampler = bilby.core.sampler.Pymc3(
            self.likelihood,
            self.priors,
            outdir="outdir",
            label="label",
            use_ratio=False,
            plot=False,
            # Presumably bypasses checking that pymc3 is importable; confirm
            # against the Sampler base class.
            skip_import_verification=True,
        )

    def tearDown(self):
        del self.likelihood
        del self.priors
        del self.sampler

    def _expected_default_kwargs(self):
        """Return the kwargs dict a freshly constructed sampler should hold.

        The top-level defaults are merged with the sampler's
        ``default_nuts_kwargs`` and ``default_step_kwargs`` (later updates win),
        so both tests below compare against a single source of truth.
        """
        expected = dict(
            draws=500,
            step=None,
            init="auto",
            n_init=200000,
            start=None,
            trace=None,
            chain_idx=0,
            chains=2,
            cores=1,
            tune=500,
            progressbar=True,
            model=None,
            nuts_kwargs=None,
            step_kwargs=None,
            random_seed=None,
            discard_tuned_samples=True,
            compute_convergence_checks=True,
        )
        expected.update(self.sampler.default_nuts_kwargs)
        expected.update(self.sampler.default_step_kwargs)
        return expected

    def test_default_kwargs(self):
        self.assertDictEqual(self._expected_default_kwargs(), self.sampler.kwargs)

    def test_translate_kwargs(self):
        expected = self._expected_default_kwargs()
        self.sampler.kwargs["draws"] = 123
        equiv_kwargs = bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs
        for equiv in equiv_kwargs:
            # subTest reports which equivalent kwarg failed instead of only
            # the first assertion error in the loop.
            with self.subTest(equiv=equiv):
                new_kwargs = self.sampler.kwargs.copy()
                # ``draws`` is removed before each reassignment, so the
                # expected draws=500 can only come from translating ``equiv``.
                del new_kwargs["draws"]
                new_kwargs[equiv] = 500
                self.sampler.kwargs = new_kwargs
                self.assertDictEqual(expected, self.sampler.kwargs)
if __name__ == "__main__":
    # Executing this module directly hands control to unittest's runner.
    unittest.main()