Skip to content
Snippets Groups Projects
Commit 9debf63b authored by Michael Williams's avatar Michael Williams Committed by Colm Talbot
Browse files

TST: optionally skip tests that require glasflow

Uses pytest-requires to skip tests. This means the tests with showed as skipped rather than passing
parent df3aafce
No related branches found
No related tags found
1 merge request!1332MNT: change bilby MCMC to use glasflow instead of nflows
This commit is part of merge request !1332. Comments created here will be created in the context of that merge request.
......@@ -11,6 +11,7 @@ from bilby.bilby_mcmc.chain import Chain, Sample
from bilby.bilby_mcmc import proposals
from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY
import numpy as np
import pytest
class GivenProposal(proposals.BaseProposal):
......@@ -164,36 +165,32 @@ class TestProposals(TestBaseProposals):
else:
print("Unable to test GMM as sklearn is not installed")
@pytest.mark.requires("glasflow")
def test_NF_proposal(self):
priors = self.create_priors()
chain = self.create_chain(10000)
if proposals.NormalizingFlowProposal.check_dependencies():
prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
prop.steps_since_refit = 9999
start = time.time()
p, w = prop(chain)
dt = time.time() - start
print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
self.assertTrue(prop.trained)
self.proposal_check(prop)
else:
print("nflows not installed, unable to test NormalizingFlowProposal")
prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
prop.steps_since_refit = 9999
start = time.time()
p, w = prop(chain)
dt = time.time() - start
print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
self.assertTrue(prop.trained)
self.proposal_check(prop)
@pytest.mark.requires("glasflow")
def test_NF_proposal_15D(self):
ndim = 15
priors = self.create_priors(ndim)
chain = self.create_chain(10000, ndim=ndim)
if proposals.NormalizingFlowProposal.check_dependencies():
prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
prop.steps_since_refit = 9999
start = time.time()
p, w = prop(chain)
dt = time.time() - start
print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
self.assertTrue(prop.trained)
self.proposal_check(prop, ndim=ndim)
else:
print("nflows not installed, unable to test NormalizingFlowProposal")
prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
prop.steps_since_refit = 9999
start = time.time()
p, w = prop(chain)
dt = time.time() - start
print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
self.assertTrue(prop.trained)
self.proposal_check(prop, ndim=ndim)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment