diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 13c4774c2f7f3f8be4067a2718892bf79d146fb2..3fcef71031c0227aefdc83b5f55ab64fd457b4e6 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -84,7 +84,7 @@ basic-3.11: script: - python -m pip install . - *list-env - - python test/test_samplers_import.py + - pytest test/test_samplers_import.py -v import-samplers-3.9: <<: *test-samplers-import diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index ab0c8f15480d83c6bcb11e5afba4cec6e584b0e6..5d56b3ea5ce4585e33637ced1ef6ce73382a67ab 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -2,54 +2,69 @@ import datetime import inspect import sys -import bilby -from bilby.bilby_mcmc import Bilby_MCMC - from ..prior import DeltaFunction, PriorDict -from ..utils import command_line_args, env_package_list, loaded_modules_dict, logger +from ..utils import ( + command_line_args, + env_package_list, + get_entry_points, + loaded_modules_dict, + logger, +) from . import proposal from .base_sampler import Sampler, SamplingMarginalisedParameterError -from .cpnest import Cpnest -from .dnest4 import DNest4 -from .dynamic_dynesty import DynamicDynesty -from .dynesty import Dynesty -from .emcee import Emcee -from .fake_sampler import FakeSampler -from .kombine import Kombine -from .nessai import Nessai -from .nestle import Nestle -from .polychord import PyPolyChord -from .ptemcee import Ptemcee -from .ptmcmc import PTMCMCSampler -from .pymc import Pymc -from .pymultinest import Pymultinest -from .ultranest import Ultranest -from .zeus import Zeus - -IMPLEMENTED_SAMPLERS = { - "bilby_mcmc": Bilby_MCMC, - "cpnest": Cpnest, - "dnest4": DNest4, - "dynamic_dynesty": DynamicDynesty, - "dynesty": Dynesty, - "emcee": Emcee, - "kombine": Kombine, - "nessai": Nessai, - "nestle": Nestle, - "ptemcee": Ptemcee, - "ptmcmcsampler": PTMCMCSampler, - "pymc": Pymc, - "pymultinest": Pymultinest, - "pypolychord": PyPolyChord, - "ultranest": Ultranest, - "zeus": Zeus, - "fake_sampler": FakeSampler, -} + +IMPLEMENTED_SAMPLERS = get_entry_points("bilby.samplers") + + +def get_implemented_samplers(): + """Get a list of the names of the implemented samplers. + + This includes natively supported samplers (e.g. dynesty) and any additional + samplers that are supported through the sampler plugins. + + Returns + ------- + list + The list of implemented samplers. + """ + return list(IMPLEMENTED_SAMPLERS.keys()) + + +def get_sampler_class(sampler): + """Get the class for a sampler from its name. + + This includes natively supported samplers (e.g. dynesty) and any additional + samplers that are supported through the sampler plugins. + + Parameters + ---------- + sampler : str + The name of the sampler. + + Returns + ------- + Sampler + The sampler class. + + Raises + ------ + ValueError + Raised if the sampler is not implemented. + """ + sampler = sampler.lower() + if sampler in IMPLEMENTED_SAMPLERS: + return IMPLEMENTED_SAMPLERS[sampler].load() + else: + raise ValueError( + f"Sampler {sampler} not yet implemented. " + f"The available samplers are: {get_implemented_samplers()}" + ) + if command_line_args.sampler_help: sampler = command_line_args.sampler_help if sampler in IMPLEMENTED_SAMPLERS: - sampler_class = IMPLEMENTED_SAMPLERS[sampler] + sampler_class = IMPLEMENTED_SAMPLERS[sampler].load() print(f'Help for sampler "{sampler}":') print(sampler_class.__doc__) else: @@ -60,7 +75,7 @@ if command_line_args.sampler_help: ) else: print(f"Requested sampler {sampler} not implemented") - print(f"Available samplers = {IMPLEMENTED_SAMPLERS}") + print(f"Available samplers = {get_implemented_samplers()}") sys.exit() @@ -185,24 +200,20 @@ def run_sampler( if isinstance(sampler, Sampler): pass elif isinstance(sampler, str): - if sampler.lower() in IMPLEMENTED_SAMPLERS: - sampler_class = IMPLEMENTED_SAMPLERS[sampler.lower()] - sampler = sampler_class( - likelihood, - priors=priors, - outdir=outdir, - label=label, - injection_parameters=injection_parameters, - meta_data=meta_data, - use_ratio=use_ratio, - plot=plot, - result_class=result_class, - npool=npool, - **kwargs, - ) - else: - print(IMPLEMENTED_SAMPLERS) - raise ValueError(f"Sampler {sampler} not yet implemented") + sampler_class = get_sampler_class(sampler) + sampler = sampler_class( + likelihood, + priors=priors, + outdir=outdir, + label=label, + injection_parameters=injection_parameters, + meta_data=meta_data, + use_ratio=use_ratio, + plot=plot, + result_class=result_class, + npool=npool, + **kwargs, + ) elif inspect.isclass(sampler): sampler = sampler.__init__( likelihood, @@ -219,7 +230,7 @@ def run_sampler( else: raise ValueError( "Provided sampler should be a Sampler object or name of a known " - f"sampler: {', '.join(IMPLEMENTED_SAMPLERS.keys())}." + f"sampler: {get_implemented_samplers()}." ) if sampler.cached_result: diff --git a/bilby/core/utils/__init__.py b/bilby/core/utils/__init__.py index 25d1eda934151fdcb85806dd336406a9f5ea7a90..bb59915642f7b144533437a8a0f7e41cc4d15b66 100644 --- a/bilby/core/utils/__init__.py +++ b/bilby/core/utils/__init__.py @@ -6,6 +6,7 @@ from .constants import * from .conversion import * from .counter import * from .docs import * +from .entry_points import * from .introspection import * from .io import * from .log import * diff --git a/bilby/core/utils/entry_points.py b/bilby/core/utils/entry_points.py new file mode 100644 index 0000000000000000000000000000000000000000..305fc57040dfdbbbc1abec8c1aa82cbf8f63c25c --- /dev/null +++ b/bilby/core/utils/entry_points.py @@ -0,0 +1,18 @@ +import sys +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + + +def get_entry_points(group): + """Return a dictionary of entry points for a given group + + Parameters + ---------- + group: str + Entry points you wish to query + """ + return { + custom.name: custom for custom in entry_points(group=group) + } diff --git a/docs/index.txt b/docs/index.txt index 7bc1d091660e759c7e5928e05fab5f27abfe5329..1c46bc9bcfdd440e1321ef86f8cbe57fe2966144 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -18,6 +18,7 @@ Welcome to bilby's documentation! samplers dynesty-guide bilby-mcmc-guide + plugins bilby-output compact-binary-coalescence-parameter-estimation transient-gw-data diff --git a/docs/plugins.txt b/docs/plugins.txt new file mode 100644 index 0000000000000000000000000000000000000000..38422e646c62096ff27719e0f9573f173e7f4734 --- /dev/null +++ b/docs/plugins.txt @@ -0,0 +1,80 @@ +======= +Plugins +======= + +---------------- +Defining plugins +---------------- + +:code:`bilby` allows for additional customizations/extra features via plugins. +This allows users to add new functionality without the need to modify the main +:code:`bilby` codebase, for example to add a new sampler. + +To make your plugins discoverable to :code:`bilby`, you need to specify a plugin +group (which :code:`bilby` knows to search for), a name for the plugin, and the +python path to your function/class within your package metadata, see `here +<https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata>`_ +for details. For example, if you have a package called :code:`mypackage` and +you wish to add a plugin called :code:`my_awesome_plugin` within the group +:code:`bilby.plugin`, you would specify the following in your `pyproject.toml +<https://packaging.python.org/en/latest/guides/writing-pyproject-toml/>`_ +file:: + + [project.entry-points."bilby.plugin"] + my_awesome_plugin = "mypackage.plugin" + +Currently :code:`bilby` allows for the following plugin groups: + +- :code:`"bilby.samplers"`: group for adding samplers to :code:`bilby`. See :ref:`Sampler plugins` for more details. + + +--------------- +Sampler plugins +--------------- + +Sampler plugins can specified via the :code:`"bilby.samplers"` group and these +are automatically added to the 'known' samplers in :code:`bilby`. +This allows users to add support for new samplers without having to modify the +core :code:`bilby` codebase. +Sampler plugins should implement a sampler class that in inherits from one of +the following classes: + +- :py:class:`bilby.core.sampler.base_sampler.Sampler` +- :py:class:`bilby.core.sampler.base_sampler.NestedSampler` +- :py:class:`bilby.core.sampler.base_sampler.MCMCSampler` + +We provide a `template <https://github.com/bilby-plugins/sampler-template>`_ +for creating sampler plugins on GitHub. + +.. note:: + When implementing a new sampler plugin, please avoid using a generic name for + the plugin (e.g. 'nest', 'mcmc') as this may lead to naming conflicts. + + +Sampler plugin library +---------------------- + +This is a list of known sampler plugins. if you don't see your plugin listed +here, we encourage you to open a +`merge request <https://git.ligo.org/lscsoft/bilby/-/merge_requests/new>`_ to add it. + +- This could be your sampler + + +-------------------------------- +Information for bilby developers +-------------------------------- + +Using plugins within bilby +-------------------------- + +Within :code:`bilby`, plugins are discovered with the +:py:func:`bilby.core.utils.get_entry_points` function, +and can be used throughout the :code:`bilby` infrastructure. + +Adding a new plugin group +------------------------- + +If you want to add support for a new plugin group, please +`open an issue <https://git.ligo.org/lscsoft/bilby/-/issues/new>`_ +to discuss the details with other developers. diff --git a/requirements.txt b/requirements.txt index 73c9d23b209c0ca70aa52ed981b9da5b307e1326..336b9e00088d0da761d2182c6924041ff134c93b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ dill tqdm h5py attrs +importlib-metadata>=3.6; python_version < '3.10' diff --git a/setup.py b/setup.py index 7a1523eaa513126baa31ffaee0ea4828a7658ff6..b929d5bb41f4f17571b09bdffe75e545d894ba8a 100644 --- a/setup.py +++ b/setup.py @@ -82,7 +82,26 @@ setup( "console_scripts": [ "bilby_plot=cli_bilby.plot_multiple_posteriors:main", "bilby_result=cli_bilby.bilby_result:main", - ] + ], + "bilby.samplers": [ + "bilby_mcmc=bilby.bilby_mcmc.sampler:Bilby_MCMC", + "cpnest=bilby.core.sampler.cpnest:Cpnest", + "dnest4=bilby.core.sampler.dnest4:DNest4", + "dynesty=bilby.core.sampler.dynesty:Dynesty", + "dynamic_dynesty=bilby.core.sampler.dynamic_dynesty:DynamicDynesty", + "emcee=bilby.core.sampler.emcee:Emcee", + "kombine=bilby.core.sampler.kombine:Kombine", + "nessai=bilby.core.sampler.nessai:Nessai", + "nestle=bilby.core.sampler.nestle:Nestle", + "ptemcee=bilby.core.sampler.ptemcee:Ptemcee", + "ptmcmcsampler=bilby.core.sampler.ptmcmc:PTMCMCSampler", + "pymc=bilby.core.sampler.pymc:Pymc", + "pymultinest=bilby.core.sampler.pymultinest:Pymultinest", + "pypolychord=bilby.core.sampler.polychord:PyPolyChord", + "ultranest=bilby.core.sampler.ultranest:Ultranest", + "zeus=bilby.core.sampler.zeus:Zeus", + "fake_sampler=bilby.core.sampler.fake_sampler:FakeSampler", + ], }, classifiers=[ "Programming Language :: Python :: 3.9", diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py index 47cc2003e08da4bd972275de44a57682b8aa79cb..1250fa0d682d126f8d9881c8ef51a254479fe534 100644 --- a/test/core/sampler/base_sampler_test.py +++ b/test/core/sampler/base_sampler_test.py @@ -10,6 +10,8 @@ import numpy as np import bilby from bilby.core import prior +loaded_samplers = {k: v.load() for k, v in bilby.core.sampler.IMPLEMENTED_SAMPLERS.items()} + class TestSampler(unittest.TestCase): def setUp(self, soft_init=False): @@ -170,9 +172,7 @@ class GenericSamplerTest(unittest.TestCase): @parameterized.expand(samplers) def test_pool_creates_properly_no_pool(self, sampler_name): - sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler_name]( - self.likelihood, self.priors - ) + sampler = loaded_samplers[sampler_name](self.likelihood, self.priors) sampler._setup_pool() if sampler_name == "kombine": from kombine import SerialPool @@ -184,7 +184,7 @@ class GenericSamplerTest(unittest.TestCase): @parameterized.expand(samplers) def test_pool_creates_properly_pool(self, sampler): - sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler]( + sampler = loaded_samplers[sampler]( self.likelihood, self.priors, npool=2 ) sampler._setup_pool() diff --git a/test/core/sampler/cpnest_test.py b/test/core/sampler/cpnest_test.py index 08a23b0a85276818423f83c583765ac5f6721547..d56412d7aa918eb09df56d73cfd0ddd19f70c325 100644 --- a/test/core/sampler/cpnest_test.py +++ b/test/core/sampler/cpnest_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.cpnest class TestCPNest(unittest.TestCase): @@ -10,7 +11,7 @@ class TestCPNest(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Cpnest( + self.sampler = bilby.core.sampler.cpnest.Cpnest( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/dnest4_test.py b/test/core/sampler/dnest4_test.py index dac5289f68ec775e1a00189a85244200b85891f8..4ce7fc62d5cc251f0008101a38de6dd5a0681306 100644 --- a/test/core/sampler/dnest4_test.py +++ b/test/core/sampler/dnest4_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.dnest4 class TestDnest4(unittest.TestCase): @@ -10,7 +11,7 @@ class TestDnest4(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.DNest4( + self.sampler = bilby.core.sampler.dnest4.DNest4( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/dynamic_dynesty_test.py b/test/core/sampler/dynamic_dynesty_test.py index f5119affcf63a1d6d23de69a6325e63c487bc8c4..36d7e6b088e4ea882f0eae859cd72ee31226c9ca 100644 --- a/test/core/sampler/dynamic_dynesty_test.py +++ b/test/core/sampler/dynamic_dynesty_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.dynamic_dynesty class TestDynamicDynesty(unittest.TestCase): @@ -10,7 +11,7 @@ class TestDynamicDynesty(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.DynamicDynesty( + self.sampler = bilby.core.sampler.dynamic_dynesty.DynamicDynesty( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py index d88ba4de9558bd0848e9ea33e19d890d42cbc9b4..39ac6f2318fd3e4d7f7d2fa0ac2b433126cdfb26 100644 --- a/test/core/sampler/dynesty_test.py +++ b/test/core/sampler/dynesty_test.py @@ -5,6 +5,7 @@ from attr import define import bilby import numpy as np import parameterized +import bilby.core.sampler.dynesty from bilby.core.sampler import dynesty_utils from scipy.stats import gamma, ks_1samp, uniform, powerlaw import shutil @@ -41,7 +42,7 @@ class TestDynesty(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Dynesty( + self.sampler = bilby.core.sampler.dynesty.Dynesty( self.likelihood, self.priors, outdir="outdir", @@ -84,7 +85,7 @@ class TestDynesty(unittest.TestCase): self.priors["c"] = bilby.core.prior.Prior(boundary=None) self.priors["d"] = bilby.core.prior.Prior(boundary="reflective") self.priors["e"] = bilby.core.prior.Prior(boundary="periodic") - self.sampler = bilby.core.sampler.Dynesty( + self.sampler = bilby.core.sampler.dynesty.Dynesty( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/emcee_test.py b/test/core/sampler/emcee_test.py index 66265e51e44797123a90b77193137c8e1fba4c2b..e22d891d51e848e5e7e98f17def490eee9e5b000 100644 --- a/test/core/sampler/emcee_test.py +++ b/test/core/sampler/emcee_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.emcee class TestEmcee(unittest.TestCase): @@ -10,7 +11,7 @@ class TestEmcee(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Emcee( + self.sampler = bilby.core.sampler.emcee.Emcee( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/general_sampler_tests.py b/test/core/sampler/general_sampler_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..38700c632d3312a0d8202e93a36081c46be436a2 --- /dev/null +++ b/test/core/sampler/general_sampler_tests.py @@ -0,0 +1,32 @@ +from bilby.core.sampler import ( + get_implemented_samplers, + get_sampler_class, +) +import pytest + + +def test_get_implemented_samplers(): + """Assert the function returns a list of the correct length""" + from bilby.core.sampler import IMPLEMENTED_SAMPLERS + + out = get_implemented_samplers() + assert isinstance(out, list) + assert len(out) == len(IMPLEMENTED_SAMPLERS) + assert "dynesty" in out + + +def test_get_sampler_class(): + """Assert the function returns the correct class""" + from bilby.core.sampler.dynesty import Dynesty + + sampler_class = get_sampler_class("dynesty") + assert sampler_class is Dynesty + + +def test_get_sampler_class_not_implemented(): + """Assert an error is raised if the sampler is not recognized""" + with pytest.raises( + ValueError, + match=r"Sampler not_a_valid_sampler not yet implemented" + ): + get_sampler_class("not_a_valid_sampler") diff --git a/test/core/sampler/kombine_test.py b/test/core/sampler/kombine_test.py index d16eb8c90c7f11f6cf4280e00f8ca0c8e5ebb1a3..0423520561d90215a8c2b5615fd3aa832dbad97f 100644 --- a/test/core/sampler/kombine_test.py +++ b/test/core/sampler/kombine_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.kombine class TestKombine(unittest.TestCase): @@ -10,7 +11,7 @@ class TestKombine(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Kombine( + self.sampler = bilby.core.sampler.kombine.Kombine( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py index 0cac7a45b24e9174336ed454e11908fd0e0e6555..cca5d22b035480e3b46eb7db7730ef86d82f0239 100644 --- a/test/core/sampler/nessai_test.py +++ b/test/core/sampler/nessai_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock, patch, mock_open import bilby +import bilby.core.sampler.nessai class TestNessai(unittest.TestCase): @@ -12,7 +13,7 @@ class TestNessai(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Nessai( + self.sampler = bilby.core.sampler.nessai.Nessai( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/nestle_test.py b/test/core/sampler/nestle_test.py index e5623ef336552a0b3ba8936c49d8704fad22ee0d..f6f8a698cc3274715cb2453e97aeb157535149fe 100644 --- a/test/core/sampler/nestle_test.py +++ b/test/core/sampler/nestle_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.nestle class TestNestle(unittest.TestCase): @@ -10,7 +11,7 @@ class TestNestle(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Nestle( + self.sampler = bilby.core.sampler.nestle.Nestle( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/ptemcee_test.py b/test/core/sampler/ptemcee_test.py index 65c49c4e28ea1529d0a879a2106b8ad8f3da0bb9..ec135eeef09bf5bdaa0fb3c7ff87b8fd7ea536da 100644 --- a/test/core/sampler/ptemcee_test.py +++ b/test/core/sampler/ptemcee_test.py @@ -2,7 +2,7 @@ import unittest from bilby.core.likelihood import GaussianLikelihood from bilby.core.prior import Uniform, PriorDict -from bilby.core.sampler import Ptemcee +from bilby.core.sampler.ptemcee import Ptemcee from bilby.core.sampler.base_sampler import MCMCSampler import numpy as np diff --git a/test/core/sampler/pymc_test.py b/test/core/sampler/pymc_test.py index 3ef4fac80826a53768ce0e876dceb300ec48784a..15e3275f44efc9852220fa451ce99b9494c28b38 100644 --- a/test/core/sampler/pymc_test.py +++ b/test/core/sampler/pymc_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.pymc class TestPyMC(unittest.TestCase): @@ -10,7 +11,7 @@ class TestPyMC(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Pymc( + self.sampler = bilby.core.sampler.pymc.Pymc( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/pymultinest_test.py b/test/core/sampler/pymultinest_test.py index 8ffcef6745b89ed350dd1bc00a784b5e9999585d..7ec64b4867f47dee216770559505e2ebebcd4f9b 100644 --- a/test/core/sampler/pymultinest_test.py +++ b/test/core/sampler/pymultinest_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.pymultinest class TestPymultinest(unittest.TestCase): @@ -12,7 +13,7 @@ class TestPymultinest(unittest.TestCase): ) self.priors["a"] = bilby.core.prior.Prior(boundary="periodic") self.priors["b"] = bilby.core.prior.Prior(boundary="reflective") - self.sampler = bilby.core.sampler.Pymultinest( + self.sampler = bilby.core.sampler.pymultinest.Pymultinest( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/ultranest_test.py b/test/core/sampler/ultranest_test.py index be22c1a1f50b8d304000fcb8d0e4816e57c9c1b9..c0219295bc78c74de302d7a5334da1ccc9bb5173 100644 --- a/test/core/sampler/ultranest_test.py +++ b/test/core/sampler/ultranest_test.py @@ -3,6 +3,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.ultranest class TestUltranest(unittest.TestCase): @@ -15,10 +16,15 @@ class TestUltranest(unittest.TestCase): b=bilby.core.prior.Uniform(0, 1))) self.priors["a"] = bilby.core.prior.Prior(boundary="periodic") self.priors["b"] = bilby.core.prior.Prior(boundary="reflective") - self.sampler = bilby.core.sampler.Ultranest(self.likelihood, self.priors, - outdir="outdir", label="label", - use_ratio=False, plot=False, - skip_import_verification=True) + self.sampler = bilby.core.sampler.ultranest.Ultranest( + self.likelihood, + self.priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + skip_import_verification=True, + ) def tearDown(self): del self.likelihood diff --git a/test/core/sampler/zeus_test.py b/test/core/sampler/zeus_test.py index 2b3e2b5dea14dfe19c6b8930bb353d6871d73409..0f8dea9b1269fd43191e12a5bf89d49bee8d9cee 100644 --- a/test/core/sampler/zeus_test.py +++ b/test/core/sampler/zeus_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.zeus class TestZeus(unittest.TestCase): @@ -10,7 +11,7 @@ class TestZeus(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Zeus( + self.sampler = bilby.core.sampler.zeus.Zeus( self.likelihood, self.priors, outdir="outdir", diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py index 00678539b8dd83dfa499a9ffc008dc53cac44c9e..79a61a7b42f57a10ab07476029ba222f2978439c 100644 --- a/test/integration/sampler_run_test.py +++ b/test/integration/sampler_run_test.py @@ -66,6 +66,8 @@ sampler_imports = dict( no_pool_test = ["dnest4", "pymultinest", "nestle", "ptmcmcsampler", "ultranest", "pymc"] +loaded_samplers = {k: v.load() for k, v in bilby.core.sampler.IMPLEMENTED_SAMPLERS.items()} + def slow_func(x, m, c): time.sleep(0.01) @@ -154,7 +156,7 @@ class TestRunningSamplers(unittest.TestCase): def _run_with_signal_handling(self, sampler, pool_size=1): pytest.importorskip(sampler_imports.get(sampler, sampler)) - if bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler.lower()].hard_exit: + if loaded_samplers[sampler.lower()].hard_exit: pytest.skip(f"{sampler} hard exits, can't test signal handling.") if pool_size > 1 and sampler.lower() in no_pool_test: pytest.skip(f"{sampler} cannot be parallelized") diff --git a/test/test_samplers_import.py b/test/test_samplers_import.py index d607d87f3370b5787c0fa13eded78c41ff8aca8d..1cab28a76b229bd6e8001be7ba9ef7b7118f3ab3 100644 --- a/test/test_samplers_import.py +++ b/test/test_samplers_import.py @@ -1,17 +1,21 @@ -""" -Tests that all of the implemented samplers can be initialized. - -The :code:`FakeSampler` is omitted as that doesn't require importing -any package. -""" import bilby +import pytest + + +@pytest.mark.parametrize( + "sampler_name", bilby.core.sampler.IMPLEMENTED_SAMPLERS.keys() +) +def test_sampler_import(sampler_name): + """ + Tests that all of the implemented samplers can be initialized. -bilby.core.utils.logger.setLevel("ERROR") -IMPLEMENTED_SAMPLERS = bilby.core.sampler.IMPLEMENTED_SAMPLERS -likelihood = bilby.core.likelihood.Likelihood(dict()) -priors = bilby.core.prior.PriorDict(dict(a=bilby.core.prior.Uniform(0, 1))) -for sampler in IMPLEMENTED_SAMPLERS: - if sampler in ["fake_sampler", "pypolychord"]: - continue - sampler_class = IMPLEMENTED_SAMPLERS[sampler] + Do not test :code:`FakeSampler` since it requires an additional argument. + """ + if sampler_name in ["fake_sampler", "pypolychord"]: + pytest.skip(f"Skipping import test for {sampler_name}") + bilby.core.utils.logger.setLevel("ERROR") + likelihood = bilby.core.likelihood.Likelihood(dict()) + priors = bilby.core.prior.PriorDict(dict(a=bilby.core.prior.Uniform(0, 1))) + sampler_class = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler_name].load() sampler = sampler_class(likelihood=likelihood, priors=priors) + assert sampler is not None