diff --git a/bilby/bilby_mcmc/proposals.py b/bilby/bilby_mcmc/proposals.py index e7363fe8d172a1442c86610ecc5672b1c79dba9c..29c40d057a217f194576bf75964c9b331a865739 100644 --- a/bilby/bilby_mcmc/proposals.py +++ b/bilby/bilby_mcmc/proposals.py @@ -547,6 +547,7 @@ class GMMProposal(DensityEstimateProposal): def _sample(self, nsamples=None): return np.squeeze(self.density.sample(n_samples=nsamples)[0]) + @staticmethod def check_dependencies(warn=True): if importlib.util.find_spec("sklearn") is None: if warn: @@ -593,12 +594,15 @@ class NormalizingFlowProposal(DensityEstimateProposal): fallback=fallback, scale_fits=scale_fits, ) - self.setup_flow() - self.setup_optimizer() - + self.initialised = False self.max_training_epochs = max_training_epochs self.js_factor = js_factor + def initialise(self): + self.setup_flow() + self.setup_optimizer() + self.initialised = True + def setup_flow(self): if self.ndim < 3: self.setup_basic_flow() @@ -699,6 +703,9 @@ class NormalizingFlowProposal(DensityEstimateProposal): self.trained = True def propose(self, chain): + if self.initialised is False: + self.initialise() + import torch self.steps_since_refit += 1 @@ -728,6 +735,7 @@ class NormalizingFlowProposal(DensityEstimateProposal): return theta, float(log_factor) + @staticmethod def check_dependencies(warn=True): if importlib.util.find_spec("nflows") is None: if warn: @@ -1094,10 +1102,6 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True): ] if GMMProposal.check_dependencies(warn=warn): plist.append(GMMProposal(priors, weight=big_weight, scale_fits=L1steps)) - if NormalizingFlowProposal.check_dependencies(warn=warn): - plist.append( - NormalizingFlowProposal(priors, weight=big_weight, scale_fits=L1steps) - ) plist = remove_proposals_using_string(plist, string) return ProposalCycle(plist) diff --git a/test/bilby_mcmc/test_proposals.py b/test/bilby_mcmc/test_proposals.py index 3bb70b168569edf2258c8f6b42674f53f8175a6f..84042d2d409dbaea08ecf21d02604cb82bdb45bc 100644 --- a/test/bilby_mcmc/test_proposals.py +++ b/test/bilby_mcmc/test_proposals.py @@ -165,29 +165,33 @@ class TestProposals(TestBaseProposals): def test_NF_proposal(self): priors = self.create_priors() chain = self.create_chain(10000) - prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) - prop.steps_since_refit = 9999 - start = time.time() - p, w = prop(chain) - dt = time.time() - start - print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") - self.assertTrue(prop.trained) - - self.proposal_check(prop) + if proposals.NormalizingFlowProposal.check_dependencies(): + prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) + prop.steps_since_refit = 9999 + start = time.time() + p, w = prop(chain) + dt = time.time() - start + print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") + self.assertTrue(prop.trained) + self.proposal_check(prop) + else: + print("nflows not installed, unable to test NormalizingFlowProposal") def test_NF_proposal_15D(self): ndim = 15 priors = self.create_priors(ndim) chain = self.create_chain(10000, ndim=ndim) - prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) - prop.steps_since_refit = 9999 - start = time.time() - p, w = prop(chain) - dt = time.time() - start - print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") - self.assertTrue(prop.trained) - - self.proposal_check(prop, ndim=ndim) + if proposals.NormalizingFlowProposal.check_dependencies(): + prop = proposals.NormalizingFlowProposal(priors, first_fit=10000) + prop.steps_since_refit = 9999 + start = time.time() + p, w = prop(chain) + dt = time.time() - start + print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]") + self.assertTrue(prop.trained) + self.proposal_check(prop, ndim=ndim) + else: + print("nflows not installed, unable to test NormalizingFlowProposal") if __name__ == "__main__": diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py index 7f6ec21a8a5d26b606e8c6f8aa3cae3ede905ca2..cbb084735ec50274b45d2dd629772c92d0d3daed 100644 --- a/test/core/sampler/nessai_test.py +++ b/test/core/sampler/nessai_test.py @@ -21,6 +21,7 @@ class TestNessai(unittest.TestCase): plot=False, skip_import_verification=True, sampling_seed=150914, + npool=None, # TODO: remove when support for nessai<0.7.0 is dropped ) self.expected = self.sampler.default_kwargs self.expected['output'] = 'outdir/label_nessai/' @@ -59,10 +60,13 @@ class TestNessai(unittest.TestCase): assert self.expected["seed"] == 150914 def test_npool_max_threads(self): + # TODO: remove when support for nessai<0.7.0 is dropped expected = self.expected.copy() expected["n_pool"] = None + expected["max_threads"] = 1 new_kwargs = self.sampler.kwargs.copy() new_kwargs["n_pool"] = 1 + new_kwargs["max_threads"] = 1 self.sampler.kwargs = new_kwargs self.assertDictEqual(expected, self.sampler.kwargs) diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index 7946a99c83b99958160636f482f4593103f6c7fe..a63c3ae0dde56686c5e28450f117241c864d16ba 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -104,7 +104,7 @@ class TestGWUtils(unittest.TestCase): strain = gwutils.read_frame_file( filename, start_time=None, end_time=None, channel=channel ) - self.assertEqual(strain.channel.name, channel) + self.assertEqual(strain.name, channel) self.assertTrue(np.all(strain.value == data[:-1])) # Check reading with time limits