Skip to content
Snippets Groups Projects
Commit 7c74e2d4 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'fix-mcmc-tests' into 'master'

Fix MCMC tests

See merge request lscsoft/bilby!985
parents 0c1baa01 13c120ae
No related branches found
No related tags found
No related merge requests found
......@@ -406,8 +406,9 @@ class Chain(object):
# Plot the histograms
for ax, key in zip(axes[:, 1], self.keys):
yy_all = all_samples[key]
ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k")
if all_samples is not None:
yy_all = all_samples[key]
ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k")
yy = self.get_1d_array(key)[nburn : self.position : self.thin]
ax.hist(yy, bins=50, alpha=0.8, density=True)
......
......@@ -103,7 +103,8 @@ class TestChain(unittest.TestCase):
def test_minimum_index(self):
chain = self.create_chain()
self.assertEqual(chain.minimum_index, 0)
# Test initialization
self.assertEqual(chain.minimum_index, 1)
chain._last_minimum_index = (chain.position, 10, "I")
self.assertEqual(chain.minimum_index, 10)
......
......@@ -3,6 +3,7 @@ import copy
import shutil
import unittest
import inspect
import importlib
import sys
import time
import bilby
......@@ -150,10 +151,13 @@ class TestProposals(TestBaseProposals):
self.assertTrue(prop.trained)
def test_GMM_proposal(self):
priors = self.create_priors()
prop = proposals.GMMProposal(priors)
self.proposal_check(prop, N=20000)
self.assertTrue(prop.trained)
if importlib.util.find_spec("sklearn") is not None:
priors = self.create_priors()
prop = proposals.GMMProposal(priors)
self.proposal_check(prop, N=20000)
self.assertTrue(prop.trained)
else:
print("Unable to test GMM as sklearn is not installed")
def test_NF_proposal(self):
priors = self.create_priors()
......
......@@ -3,7 +3,7 @@ import shutil
import unittest
import bilby
from bilby.bilby_mcmc.sampler import BilbyMCMC, BilbyMCMCSampler, _initialize_global_variables
from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler, _initialize_global_variables
from bilby.bilby_mcmc.utils import ConvergenceInputs
from bilby.core.sampler.base_sampler import SamplerError
import numpy as np
......@@ -12,7 +12,7 @@ import pandas as pd
class TestBilbyMCMCSampler(unittest.TestCase):
def setUp(self):
default_kwargs = BilbyMCMC.default_kwargs
default_kwargs = Bilby_MCMC.default_kwargs
default_kwargs["target_nsamples"] = 100
default_kwargs["L1steps"] = 1
self.convergence_inputs = ConvergenceInputs(
......@@ -78,8 +78,6 @@ class TestBilbyMCMCSampler(unittest.TestCase):
self.assertEqual(sampler.chain.position, nsteps)
self.assertEqual(sampler.accepted + sampler.rejected, nsteps)
self.assertTrue(isinstance(sampler.samples, pd.DataFrame))
for prop in sampler.proposal_cycle.proposal_list:
self.assertGreater(prop.n, 50)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment