From 164ca227a109f739890efff54689a24d245aa02a Mon Sep 17 00:00:00 2001
From: Michael Williams <michael.williams@ligo.org>
Date: Mon, 26 Feb 2024 16:48:31 +0000
Subject: [PATCH] MNT: change bilby MCMC to use glasflow instead of nflows

---
 bilby/bilby_mcmc/flows.py         | 12 ++++-----
 bilby/bilby_mcmc/proposals.py     |  4 +--
 containers/env-template.yml       |  3 ++-
 mcmc_requirements.txt             |  2 +-
 optional_requirements.txt         |  1 +
 test/bilby_mcmc/test_proposals.py | 41 ++++++++++++++-----------------
 6 files changed, 31 insertions(+), 32 deletions(-)

diff --git a/bilby/bilby_mcmc/flows.py b/bilby/bilby_mcmc/flows.py
index 5fbaf196b..b08ea3a93 100644
--- a/bilby/bilby_mcmc/flows.py
+++ b/bilby/bilby_mcmc/flows.py
@@ -1,17 +1,17 @@
 import torch
-from nflows.distributions.normal import StandardNormal
-from nflows.flows.base import Flow
-from nflows.nn import nets as nets
-from nflows.transforms import (
+from glasflow.nflows.distributions.normal import StandardNormal
+from glasflow.nflows.flows.base import Flow
+from glasflow.nflows.nn import nets as nets
+from glasflow.nflows.transforms import (
     CompositeTransform,
     MaskedAffineAutoregressiveTransform,
     RandomPermutation,
 )
-from nflows.transforms.coupling import (
+from glasflow.nflows.transforms.coupling import (
     AdditiveCouplingTransform,
     AffineCouplingTransform,
 )
-from nflows.transforms.normalization import BatchNorm
+from glasflow.nflows.transforms.normalization import BatchNorm
 from torch.nn import functional as F
 
 # Turn off parallelism
diff --git a/bilby/bilby_mcmc/proposals.py b/bilby/bilby_mcmc/proposals.py
index 17892e050..6100d75f8 100644
--- a/bilby/bilby_mcmc/proposals.py
+++ b/bilby/bilby_mcmc/proposals.py
@@ -754,10 +754,10 @@ class NormalizingFlowProposal(DensityEstimateProposal):
 
     @staticmethod
     def check_dependencies(warn=True):
-        if importlib.util.find_spec("nflows") is None:
+        if importlib.util.find_spec("glasflow") is None:
             if warn:
                 logger.warning(
-                    "Unable to utilise NormalizingFlowProposal as nflows is not installed"
+                    "Unable to utilise NormalizingFlowProposal as glasflow is not installed"
                 )
             return False
         else:
diff --git a/containers/env-template.yml b/containers/env-template.yml
index b62a94623..064364352 100644
--- a/containers/env-template.yml
+++ b/containers/env-template.yml
@@ -20,6 +20,7 @@ dependencies:
   - dill
   - black
   - pytest-cov
+  - pytest-requires
   - arviz
   - parameterized
   - scikit-image
@@ -65,8 +66,8 @@ dependencies:
   - jupyter
   - nbconvert
   - twine
+  - glasflow
   - pip:
     - autodoc
     - ipykernel
     - build
-    - nflows
diff --git a/mcmc_requirements.txt b/mcmc_requirements.txt
index 6f5678c04..441ba479c 100644
--- a/mcmc_requirements.txt
+++ b/mcmc_requirements.txt
@@ -1,2 +1,2 @@
 scikit-learn
-nflows
+glasflow
diff --git a/optional_requirements.txt b/optional_requirements.txt
index 60e3fb4ba..07934750c 100644
--- a/optional_requirements.txt
+++ b/optional_requirements.txt
@@ -1,3 +1,4 @@
 celerite
 george
 plotly
+pytest-requires
diff --git a/test/bilby_mcmc/test_proposals.py b/test/bilby_mcmc/test_proposals.py
index 37fa0a0fe..cd36f94fa 100644
--- a/test/bilby_mcmc/test_proposals.py
+++ b/test/bilby_mcmc/test_proposals.py
@@ -11,6 +11,7 @@ from bilby.bilby_mcmc.chain import Chain, Sample
 from bilby.bilby_mcmc import proposals
 from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY
 import numpy as np
+import pytest
 
 
 class GivenProposal(proposals.BaseProposal):
@@ -164,36 +165,32 @@ class TestProposals(TestBaseProposals):
         else:
             print("Unable to test GMM as sklearn is not installed")
 
+    @pytest.mark.requires("glasflow")
     def test_NF_proposal(self):
         priors = self.create_priors()
         chain = self.create_chain(10000)
-        if proposals.NormalizingFlowProposal.check_dependencies():
-            prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
-            prop.steps_since_refit = 9999
-            start = time.time()
-            p, w = prop(chain)
-            dt = time.time() - start
-            print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
-            self.assertTrue(prop.trained)
-            self.proposal_check(prop)
-        else:
-            print("nflows not installed, unable to test NormalizingFlowProposal")
+        prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
+        prop.steps_since_refit = 9999
+        start = time.time()
+        p, w = prop(chain)
+        dt = time.time() - start
+        print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
+        self.assertTrue(prop.trained)
+        self.proposal_check(prop)
 
+    @pytest.mark.requires("glasflow")
     def test_NF_proposal_15D(self):
         ndim = 15
         priors = self.create_priors(ndim)
         chain = self.create_chain(10000, ndim=ndim)
-        if proposals.NormalizingFlowProposal.check_dependencies():
-            prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
-            prop.steps_since_refit = 9999
-            start = time.time()
-            p, w = prop(chain)
-            dt = time.time() - start
-            print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
-            self.assertTrue(prop.trained)
-            self.proposal_check(prop, ndim=ndim)
-        else:
-            print("nflows not installed, unable to test NormalizingFlowProposal")
+        prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
+        prop.steps_since_refit = 9999
+        start = time.time()
+        p, w = prop(chain)
+        dt = time.time() - start
+        print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
+        self.assertTrue(prop.trained)
+        self.proposal_check(prop, ndim=ndim)
 
 
 if __name__ == "__main__":
-- 
GitLab