Skip to content
Snippets Groups Projects
Commit 13c120ae authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner
Browse files

Fix MCMC tests

parent 0c1baa01
No related branches found
No related tags found
1 merge request!985Fix MCMC tests
......@@ -406,8 +406,9 @@ class Chain(object):
# Plot the histograms
for ax, key in zip(axes[:, 1], self.keys):
yy_all = all_samples[key]
ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k")
if all_samples is not None:
yy_all = all_samples[key]
ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k")
yy = self.get_1d_array(key)[nburn : self.position : self.thin]
ax.hist(yy, bins=50, alpha=0.8, density=True)
......
......@@ -103,7 +103,8 @@ class TestChain(unittest.TestCase):
def test_minimum_index(self):
chain = self.create_chain()
self.assertEqual(chain.minimum_index, 0)
# Test initialization
self.assertEqual(chain.minimum_index, 1)
chain._last_minimum_index = (chain.position, 10, "I")
self.assertEqual(chain.minimum_index, 10)
......
......@@ -3,6 +3,7 @@ import copy
import shutil
import unittest
import inspect
import importlib
import sys
import time
import bilby
......@@ -150,10 +151,13 @@ class TestProposals(TestBaseProposals):
self.assertTrue(prop.trained)
def test_GMM_proposal(self):
priors = self.create_priors()
prop = proposals.GMMProposal(priors)
self.proposal_check(prop, N=20000)
self.assertTrue(prop.trained)
if importlib.util.find_spec("sklearn") is not None:
priors = self.create_priors()
prop = proposals.GMMProposal(priors)
self.proposal_check(prop, N=20000)
self.assertTrue(prop.trained)
else:
print("Unable to test GMM as sklearn is not installed")
def test_NF_proposal(self):
priors = self.create_priors()
......
......@@ -3,7 +3,7 @@ import shutil
import unittest
import bilby
from bilby.bilby_mcmc.sampler import BilbyMCMC, BilbyMCMCSampler, _initialize_global_variables
from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler, _initialize_global_variables
from bilby.bilby_mcmc.utils import ConvergenceInputs
from bilby.core.sampler.base_sampler import SamplerError
import numpy as np
......@@ -12,7 +12,7 @@ import pandas as pd
class TestBilbyMCMCSampler(unittest.TestCase):
def setUp(self):
default_kwargs = BilbyMCMC.default_kwargs
default_kwargs = Bilby_MCMC.default_kwargs
default_kwargs["target_nsamples"] = 100
default_kwargs["L1steps"] = 1
self.convergence_inputs = ConvergenceInputs(
......@@ -78,8 +78,6 @@ class TestBilbyMCMCSampler(unittest.TestCase):
self.assertEqual(sampler.chain.position, nsteps)
self.assertEqual(sampler.accepted + sampler.rejected, nsteps)
self.assertTrue(isinstance(sampler.samples, pd.DataFrame))
for prop in sampler.proposal_cycle.proposal_list:
self.assertGreater(prop.n, 50)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment