From 164ca227a109f739890efff54689a24d245aa02a Mon Sep 17 00:00:00 2001 From: Michael Williams <michael.williams@ligo.org> Date: Mon, 26 Feb 2024 16:48:31 +0000 Subject: [PATCH] MNT: change bilby MCMC to use glasflow instead of nflows --- bilby/bilby_mcmc/flows.py | 12 ++++----- bilby/bilby_mcmc/proposals.py | 4 +-- containers/env-template.yml | 3 ++- mcmc_requirements.txt | 2 +- optional_requirements.txt | 1 + test/bilby_mcmc/test_proposals.py | 41 ++++++++++++++----------------- 6 files changed, 31 insertions(+), 32 deletions(-) diff --git a/bilby/bilby_mcmc/flows.py b/bilby/bilby_mcmc/flows.py index 5fbaf196b..b08ea3a93 100644 --- a/bilby/bilby_mcmc/flows.py +++ b/bilby/bilby_mcmc/flows.py @@ -1,17 +1,17 @@ import torch -from nflows.distributions.normal import StandardNormal -from nflows.flows.base import Flow -from nflows.nn import nets as nets -from nflows.transforms import ( +from glasflow.nflows.distributions.normal import StandardNormal +from glasflow.nflows.flows.base import Flow +from glasflow.nflows.nn import nets as nets +from glasflow.nflows.transforms import ( CompositeTransform, MaskedAffineAutoregressiveTransform, RandomPermutation, ) -from nflows.transforms.coupling import ( +from glasflow.nflows.transforms.coupling import ( AdditiveCouplingTransform, AffineCouplingTransform, ) -from nflows.transforms.normalization import BatchNorm +from glasflow.nflows.transforms.normalization import BatchNorm from torch.nn import functional as F # Turn off parallelism diff --git a/bilby/bilby_mcmc/proposals.py b/bilby/bilby_mcmc/proposals.py index 17892e050..6100d75f8 100644 --- a/bilby/bilby_mcmc/proposals.py +++ b/bilby/bilby_mcmc/proposals.py @@ -754,10 +754,10 @@ class NormalizingFlowProposal(DensityEstimateProposal): @staticmethod def check_dependencies(warn=True): - if importlib.util.find_spec("nflows") is None: + if importlib.util.find_spec("glasflow") is None: if warn: logger.warning( - "Unable to utilise NormalizingFlowProposal as nflows is not installed" + "Unable to utilise NormalizingFlowProposal as glasflow is not installed" ) return False else: diff --git a/containers/env-template.yml b/containers/env-template.yml index b62a94623..064364352 100644 --- a/containers/env-template.yml +++ b/containers/env-template.yml @@ -20,6 +20,7 @@ dependencies: - dill - black - pytest-cov + - pytest-requires - arviz - parameterized - scikit-image @@ -65,8 +66,8 @@ dependencies: - jupyter - nbconvert - twine + - glasflow - pip: - autodoc - ipykernel - build - - nflows diff --git a/mcmc_requirements.txt b/mcmc_requirements.txt index 6f5678c04..441ba479c 100644 --- a/mcmc_requirements.txt +++ b/mcmc_requirements.txt @@ -1,2 +1,2 @@ scikit-learn -nflows +glasflow diff --git a/optional_requirements.txt b/optional_requirements.txt index 60e3fb4ba..07934750c 100644 --- a/optional_requirements.txt +++ b/optional_requirements.txt @@ -1,3 +1,4 @@ celerite george plotly +pytest-requires diff --git a/test/bilby_mcmc/test_proposals.py b/test/bilby_mcmc/test_proposals.py index 37fa0a0fe..cd36f94fa 100644 --- a/test/bilby_mcmc/test_proposals.py +++ b/test/bilby_mcmc/test_proposals.py @@ -11,6 +11,7 @@ from bilby.bilby_mcmc.chain import Chain, Sample from bilby.bilby_mcmc import proposals from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY import numpy as np +import pytest class GivenProposal(proposals.BaseProposal): @@ -164,36 +165,32 @@ class TestProposals(TestBaseProposals): else: print("Unable to test GMM as sklearn is not installed") + @pytest.mark.requires("glasflow") def test_NF_proposal(self): priors = self.create_priors() chain = self.create_chain(10000) - if proposals.NormalizingFlowProposal.check_dependencies(): - prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) - prop.steps_since_refit = 9999 - start = time.time() - p, w = prop(chain) - dt = time.time() - start - print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") - self.assertTrue(prop.trained) - self.proposal_check(prop) - else: - print("nflows not installed, unable to test NormalizingFlowProposal") + prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) + prop.steps_since_refit = 9999 + start = time.time() + p, w = prop(chain) + dt = time.time() - start + print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") + self.assertTrue(prop.trained) + self.proposal_check(prop) + @pytest.mark.requires("glasflow") def test_NF_proposal_15D(self): ndim = 15 priors = self.create_priors(ndim) chain = self.create_chain(10000, ndim=ndim) - if proposals.NormalizingFlowProposal.check_dependencies(): - prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) - prop.steps_since_refit = 9999 - start = time.time() - p, w = prop(chain) - dt = time.time() - start - print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") - self.assertTrue(prop.trained) - self.proposal_check(prop, ndim=ndim) - else: - print("nflows not installed, unable to test NormalizingFlowProposal") + prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) + prop.steps_since_refit = 9999 + start = time.time() + p, w = prop(chain) + dt = time.time() - start + print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") + self.assertTrue(prop.trained) + self.proposal_check(prop, ndim=ndim) if __name__ == "__main__": -- GitLab