From e310aec5885a46c17f9ad20dc341fadd1c4b3208 Mon Sep 17 00:00:00 2001 From: Michael Williams <michael.williams@ligo.org> Date: Fri, 15 Mar 2024 22:22:59 +0000 Subject: [PATCH] ENH: Add support for sampler plugins --- .gitlab-ci.yml | 2 +- bilby/core/sampler/__init__.py | 133 +++++++++++---------- bilby/core/utils/__init__.py | 1 + bilby/core/utils/entry_points.py | 18 +++ docs/index.txt | 1 + docs/plugins.txt | 80 +++++++++++++ requirements.txt | 1 + setup.py | 21 +++- test/core/sampler/base_sampler_test.py | 8 +- test/core/sampler/cpnest_test.py | 3 +- test/core/sampler/dnest4_test.py | 3 +- test/core/sampler/dynamic_dynesty_test.py | 3 +- test/core/sampler/dynesty_test.py | 5 +- test/core/sampler/emcee_test.py | 3 +- test/core/sampler/general_sampler_tests.py | 32 +++++ test/core/sampler/kombine_test.py | 3 +- test/core/sampler/nessai_test.py | 3 +- test/core/sampler/nestle_test.py | 3 +- test/core/sampler/ptemcee_test.py | 2 +- test/core/sampler/pymc_test.py | 3 +- test/core/sampler/pymultinest_test.py | 3 +- test/core/sampler/ultranest_test.py | 14 ++- test/core/sampler/zeus_test.py | 3 +- test/integration/sampler_run_test.py | 4 +- test/test_samplers_import.py | 32 ++--- 25 files changed, 285 insertions(+), 99 deletions(-) create mode 100644 bilby/core/utils/entry_points.py create mode 100644 docs/plugins.txt create mode 100644 test/core/sampler/general_sampler_tests.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 13c4774c2..3fcef7103 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -84,7 +84,7 @@ basic-3.11: script: - python -m pip install . - *list-env - - python test/test_samplers_import.py + - pytest test/test_samplers_import.py -v import-samplers-3.9: <<: *test-samplers-import diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index ab0c8f154..5d56b3ea5 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -2,54 +2,69 @@ import datetime import inspect import sys -import bilby -from bilby.bilby_mcmc import Bilby_MCMC - from ..prior import DeltaFunction, PriorDict -from ..utils import command_line_args, env_package_list, loaded_modules_dict, logger +from ..utils import ( + command_line_args, + env_package_list, + get_entry_points, + loaded_modules_dict, + logger, +) from . import proposal from .base_sampler import Sampler, SamplingMarginalisedParameterError -from .cpnest import Cpnest -from .dnest4 import DNest4 -from .dynamic_dynesty import DynamicDynesty -from .dynesty import Dynesty -from .emcee import Emcee -from .fake_sampler import FakeSampler -from .kombine import Kombine -from .nessai import Nessai -from .nestle import Nestle -from .polychord import PyPolyChord -from .ptemcee import Ptemcee -from .ptmcmc import PTMCMCSampler -from .pymc import Pymc -from .pymultinest import Pymultinest -from .ultranest import Ultranest -from .zeus import Zeus - -IMPLEMENTED_SAMPLERS = { - "bilby_mcmc": Bilby_MCMC, - "cpnest": Cpnest, - "dnest4": DNest4, - "dynamic_dynesty": DynamicDynesty, - "dynesty": Dynesty, - "emcee": Emcee, - "kombine": Kombine, - "nessai": Nessai, - "nestle": Nestle, - "ptemcee": Ptemcee, - "ptmcmcsampler": PTMCMCSampler, - "pymc": Pymc, - "pymultinest": Pymultinest, - "pypolychord": PyPolyChord, - "ultranest": Ultranest, - "zeus": Zeus, - "fake_sampler": FakeSampler, -} + +IMPLEMENTED_SAMPLERS = get_entry_points("bilby.samplers") + + +def get_implemented_samplers(): + """Get a list of the names of the implemented samplers. + + This includes natively supported samplers (e.g. dynesty) and any additional + samplers that are supported through the sampler plugins. + + Returns + ------- + list + The list of implemented samplers. + """ + return list(IMPLEMENTED_SAMPLERS.keys()) + + +def get_sampler_class(sampler): + """Get the class for a sampler from its name. + + This includes natively supported samplers (e.g. dynesty) and any additional + samplers that are supported through the sampler plugins. + + Parameters + ---------- + sampler : str + The name of the sampler. + + Returns + ------- + Sampler + The sampler class. + + Raises + ------ + ValueError + Raised if the sampler is not implemented. + """ + sampler = sampler.lower() + if sampler in IMPLEMENTED_SAMPLERS: + return IMPLEMENTED_SAMPLERS[sampler].load() + else: + raise ValueError( + f"Sampler {sampler} not yet implemented. " + f"The available samplers are: {get_implemented_samplers()}" + ) + if command_line_args.sampler_help: sampler = command_line_args.sampler_help if sampler in IMPLEMENTED_SAMPLERS: - sampler_class = IMPLEMENTED_SAMPLERS[sampler] + sampler_class = IMPLEMENTED_SAMPLERS[sampler].load() print(f'Help for sampler "{sampler}":') print(sampler_class.__doc__) else: @@ -60,7 +75,7 @@ if command_line_args.sampler_help: ) else: print(f"Requested sampler {sampler} not implemented") - print(f"Available samplers = {IMPLEMENTED_SAMPLERS}") + print(f"Available samplers = {get_implemented_samplers()}") sys.exit() @@ -185,24 +200,20 @@ def run_sampler( if isinstance(sampler, Sampler): pass elif isinstance(sampler, str): - if sampler.lower() in IMPLEMENTED_SAMPLERS: - sampler_class = IMPLEMENTED_SAMPLERS[sampler.lower()] - sampler = sampler_class( - likelihood, - priors=priors, - outdir=outdir, - label=label, - injection_parameters=injection_parameters, - meta_data=meta_data, - use_ratio=use_ratio, - plot=plot, - result_class=result_class, - npool=npool, - **kwargs, - ) - else: - print(IMPLEMENTED_SAMPLERS) - raise ValueError(f"Sampler {sampler} not yet implemented") + sampler_class = get_sampler_class(sampler) + sampler = sampler_class( + likelihood, + priors=priors, + outdir=outdir, + label=label, + injection_parameters=injection_parameters, + meta_data=meta_data, + use_ratio=use_ratio, + plot=plot, + result_class=result_class, + npool=npool, + **kwargs, + ) elif inspect.isclass(sampler): sampler = sampler.__init__( likelihood, @@ -219,7 +230,7 @@ def run_sampler( else: raise ValueError( "Provided sampler should be a Sampler object or name of a known " - f"sampler: {', '.join(IMPLEMENTED_SAMPLERS.keys())}." + f"sampler: {get_implemented_samplers()}." ) if sampler.cached_result: diff --git a/bilby/core/utils/__init__.py b/bilby/core/utils/__init__.py index 25d1eda93..bb5991564 100644 --- a/bilby/core/utils/__init__.py +++ b/bilby/core/utils/__init__.py @@ -6,6 +6,7 @@ from .constants import * from .conversion import * from .counter import * from .docs import * +from .entry_points import * from .introspection import * from .io import * from .log import * diff --git a/bilby/core/utils/entry_points.py b/bilby/core/utils/entry_points.py new file mode 100644 index 000000000..305fc5704 --- /dev/null +++ b/bilby/core/utils/entry_points.py @@ -0,0 +1,18 @@ +import sys +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + + +def get_entry_points(group): + """Return a dictionary of entry points for a given group + + Parameters + ---------- + group: str + Entry points you wish to query + """ + return { + custom.name: custom for custom in entry_points(group=group) + } diff --git a/docs/index.txt b/docs/index.txt index 7bc1d0916..1c46bc9bc 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -18,6 +18,7 @@ Welcome to bilby's documentation! samplers dynesty-guide bilby-mcmc-guide + plugins bilby-output compact-binary-coalescence-parameter-estimation transient-gw-data diff --git a/docs/plugins.txt b/docs/plugins.txt new file mode 100644 index 000000000..38422e646 --- /dev/null +++ b/docs/plugins.txt @@ -0,0 +1,80 @@ +======= +Plugins +======= + +---------------- +Defining plugins +---------------- + +:code:`bilby` allows for additional customizations/extra features via plugins. +This allows users to add new functionality without the need to modify the main +:code:`bilby` codebase, for example to add a new sampler. + +To make your plugins discoverable to :code:`bilby`, you need to specify a plugin +group (which :code:`bilby` knows to search for), a name for the plugin, and the +python path to your function/class within your package metadata, see `here +<https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata>`_ +for details. For example, if you have a package called :code:`mypackage` and +you wish to add a plugin called :code:`my_awesome_plugin` within the group +:code:`bilby.plugin`, you would specify the following in your `pyproject.toml +<https://packaging.python.org/en/latest/guides/writing-pyproject-toml/>`_ +file:: + + [project.entry-points."bilby.plugin"] + my_awesome_plugin = "mypackage.plugin" + +Currently :code:`bilby` allows for the following plugin groups: + +- :code:`"bilby.samplers"`: group for adding samplers to :code:`bilby`. See :ref:`Sampler plugins` for more details. + + +--------------- +Sampler plugins +--------------- + +Sampler plugins can specified via the :code:`"bilby.samplers"` group and these +are automatically added to the 'known' samplers in :code:`bilby`. +This allows users to add support for new samplers without having to modify the +core :code:`bilby` codebase. +Sampler plugins should implement a sampler class that in inherits from one of +the following classes: + +- :py:class:`bilby.core.sampler.base_sampler.Sampler` +- :py:class:`bilby.core.sampler.base_sampler.NestedSampler` +- :py:class:`bilby.core.sampler.base_sampler.MCMCSampler` + +We provide a `template <https://github.com/bilby-plugins/sampler-template>`_ +for creating sampler plugins on GitHub. + +.. note:: + When implementing a new sampler plugin, please avoid using a generic name for + the plugin (e.g. 'nest', 'mcmc') as this may lead to naming conflicts. + + +Sampler plugin library +---------------------- + +This is a list of known sampler plugins. if you don't see your plugin listed +here, we encourage you to open a +`merge request <https://git.ligo.org/lscsoft/bilby/-/merge_requests/new>`_ to add it. + +- This could be your sampler + + +-------------------------------- +Information for bilby developers +-------------------------------- + +Using plugins within bilby +-------------------------- + +Within :code:`bilby`, plugins are discovered with the +:py:func:`bilby.core.utils.get_entry_points` function, +and can be used throughout the :code:`bilby` infrastructure. + +Adding a new plugin group +------------------------- + +If you want to add support for a new plugin group, please +`open an issue <https://git.ligo.org/lscsoft/bilby/-/issues/new>`_ +to discuss the details with other developers. diff --git a/requirements.txt b/requirements.txt index 73c9d23b2..336b9e000 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ dill tqdm h5py attrs +importlib-metadata>=3.6; python_version < '3.10' diff --git a/setup.py b/setup.py index 7a1523eaa..b929d5bb4 100644 --- a/setup.py +++ b/setup.py @@ -82,7 +82,26 @@ setup( "console_scripts": [ "bilby_plot=cli_bilby.plot_multiple_posteriors:main", "bilby_result=cli_bilby.bilby_result:main", - ] + ], + "bilby.samplers": [ + "bilby_mcmc=bilby.bilby_mcmc.sampler:Bilby_MCMC", + "cpnest=bilby.core.sampler.cpnest:Cpnest", + "dnest4=bilby.core.sampler.dnest4:DNest4", + "dynesty=bilby.core.sampler.dynesty:Dynesty", + "dynamic_dynesty=bilby.core.sampler.dynamic_dynesty:DynamicDynesty", + "emcee=bilby.core.sampler.emcee:Emcee", + "kombine=bilby.core.sampler.kombine:Kombine", + "nessai=bilby.core.sampler.nessai:Nessai", + "nestle=bilby.core.sampler.nestle:Nestle", + "ptemcee=bilby.core.sampler.ptemcee:Ptemcee", + "ptmcmcsampler=bilby.core.sampler.ptmcmc:PTMCMCSampler", + "pymc=bilby.core.sampler.pymc:Pymc", + "pymultinest=bilby.core.sampler.pymultinest:Pymultinest", + "pypolychord=bilby.core.sampler.polychord:PyPolyChord", + "ultranest=bilby.core.sampler.ultranest:Ultranest", + "zeus=bilby.core.sampler.zeus:Zeus", + "fake_sampler=bilby.core.sampler.fake_sampler:FakeSampler", + ], }, classifiers=[ "Programming Language :: Python :: 3.9", diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py index 47cc2003e..1250fa0d6 100644 --- a/test/core/sampler/base_sampler_test.py +++ b/test/core/sampler/base_sampler_test.py @@ -10,6 +10,8 @@ import numpy as np import bilby from bilby.core import prior +loaded_samplers = {k: v.load() for k, v in bilby.core.sampler.IMPLEMENTED_SAMPLERS.items()} + class TestSampler(unittest.TestCase): def setUp(self, soft_init=False): @@ -170,9 +172,7 @@ class GenericSamplerTest(unittest.TestCase): @parameterized.expand(samplers) def test_pool_creates_properly_no_pool(self, sampler_name): - sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler_name]( - self.likelihood, self.priors - ) + sampler = loaded_samplers[sampler_name](self.likelihood, self.priors) sampler._setup_pool() if sampler_name == "kombine": from kombine import SerialPool @@ -184,7 +184,7 @@ class GenericSamplerTest(unittest.TestCase): @parameterized.expand(samplers) def test_pool_creates_properly_pool(self, sampler): - sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler]( + sampler = loaded_samplers[sampler]( self.likelihood, self.priors, npool=2 ) sampler._setup_pool() diff --git a/test/core/sampler/cpnest_test.py b/test/core/sampler/cpnest_test.py index 08a23b0a8..d56412d7a 100644 --- a/test/core/sampler/cpnest_test.py +++ b/test/core/sampler/cpnest_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.cpnest class TestCPNest(unittest.TestCase): @@ -10,7 +11,7 @@ class TestCPNest(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Cpnest( + self.sampler = bilby.core.sampler.cpnest.Cpnest( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/dnest4_test.py b/test/core/sampler/dnest4_test.py index dac5289f6..4ce7fc62d 100644 --- a/test/core/sampler/dnest4_test.py +++ b/test/core/sampler/dnest4_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.dnest4 class TestDnest4(unittest.TestCase): @@ -10,7 +11,7 @@ class TestDnest4(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.DNest4( + self.sampler = bilby.core.sampler.dnest4.DNest4( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/dynamic_dynesty_test.py b/test/core/sampler/dynamic_dynesty_test.py index f5119affc..36d7e6b08 100644 --- a/test/core/sampler/dynamic_dynesty_test.py +++ b/test/core/sampler/dynamic_dynesty_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.dynamic_dynesty class TestDynamicDynesty(unittest.TestCase): @@ -10,7 +11,7 @@ class TestDynamicDynesty(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.DynamicDynesty( + self.sampler = bilby.core.sampler.dynamic_dynesty.DynamicDynesty( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py index d88ba4de9..39ac6f231 100644 --- a/test/core/sampler/dynesty_test.py +++ b/test/core/sampler/dynesty_test.py @@ -5,6 +5,7 @@ from attr import define import bilby import numpy as np import parameterized +import bilby.core.sampler.dynesty from bilby.core.sampler import dynesty_utils from scipy.stats import gamma, ks_1samp, uniform, powerlaw import shutil @@ -41,7 +42,7 @@ class TestDynesty(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Dynesty( + self.sampler = bilby.core.sampler.dynesty.Dynesty( self.likelihood, self.priors, outdir="outdir", @@ -84,7 +85,7 @@ class TestDynesty(unittest.TestCase): self.priors["c"] = bilby.core.prior.Prior(boundary=None) self.priors["d"] = bilby.core.prior.Prior(boundary="reflective") self.priors["e"] = bilby.core.prior.Prior(boundary="periodic") - self.sampler = bilby.core.sampler.Dynesty( + self.sampler = bilby.core.sampler.dynesty.Dynesty( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/emcee_test.py b/test/core/sampler/emcee_test.py index 66265e51e..e22d891d5 100644 --- a/test/core/sampler/emcee_test.py +++ b/test/core/sampler/emcee_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.emcee class TestEmcee(unittest.TestCase): @@ -10,7 +11,7 @@ class TestEmcee(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Emcee( + self.sampler = bilby.core.sampler.emcee.Emcee( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/general_sampler_tests.py b/test/core/sampler/general_sampler_tests.py new file mode 100644 index 000000000..38700c632 --- /dev/null +++ b/test/core/sampler/general_sampler_tests.py @@ -0,0 +1,32 @@ +from bilby.core.sampler import ( + get_implemented_samplers, + get_sampler_class, +) +import pytest + + +def test_get_implemented_samplers(): + """Assert the function returns a list of the correct length""" + from bilby.core.sampler import IMPLEMENTED_SAMPLERS + + out = get_implemented_samplers() + assert isinstance(out, list) + assert len(out) == len(IMPLEMENTED_SAMPLERS) + assert "dynesty" in out + + +def test_get_sampler_class(): + """Assert the function returns the correct class""" + from bilby.core.sampler.dynesty import Dynesty + + sampler_class = get_sampler_class("dynesty") + assert sampler_class is Dynesty + + +def test_get_sampler_class_not_implemented(): + """Assert an error is raised if the sampler is not recognized""" + with pytest.raises( + ValueError, + match=r"Sampler not_a_valid_sampler not yet implemented" + ): + get_sampler_class("not_a_valid_sampler") diff --git a/test/core/sampler/kombine_test.py b/test/core/sampler/kombine_test.py index d16eb8c90..042352056 100644 --- a/test/core/sampler/kombine_test.py +++ b/test/core/sampler/kombine_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.kombine class TestKombine(unittest.TestCase): @@ -10,7 +11,7 @@ class TestKombine(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Kombine( + self.sampler = bilby.core.sampler.kombine.Kombine( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py index 0cac7a45b..cca5d22b0 100644 --- a/test/core/sampler/nessai_test.py +++ b/test/core/sampler/nessai_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock, patch, mock_open import bilby +import bilby.core.sampler.nessai class TestNessai(unittest.TestCase): @@ -12,7 +13,7 @@ class TestNessai(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Nessai( + self.sampler = bilby.core.sampler.nessai.Nessai( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/nestle_test.py b/test/core/sampler/nestle_test.py index e5623ef33..f6f8a698c 100644 --- a/test/core/sampler/nestle_test.py +++ b/test/core/sampler/nestle_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.nestle class TestNestle(unittest.TestCase): @@ -10,7 +11,7 @@ class TestNestle(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Nestle( + self.sampler = bilby.core.sampler.nestle.Nestle( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/ptemcee_test.py b/test/core/sampler/ptemcee_test.py index 65c49c4e2..ec135eeef 100644 --- a/test/core/sampler/ptemcee_test.py +++ b/test/core/sampler/ptemcee_test.py @@ -2,7 +2,7 @@ import unittest from bilby.core.likelihood import GaussianLikelihood from bilby.core.prior import Uniform, PriorDict -from bilby.core.sampler import Ptemcee +from bilby.core.sampler.ptemcee import Ptemcee from bilby.core.sampler.base_sampler import MCMCSampler import numpy as np diff --git a/test/core/sampler/pymc_test.py b/test/core/sampler/pymc_test.py index 3ef4fac80..15e3275f4 100644 --- a/test/core/sampler/pymc_test.py +++ b/test/core/sampler/pymc_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.pymc class TestPyMC(unittest.TestCase): @@ -10,7 +11,7 @@ class TestPyMC(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Pymc( + self.sampler = bilby.core.sampler.pymc.Pymc( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/pymultinest_test.py b/test/core/sampler/pymultinest_test.py index 8ffcef674..7ec64b486 100644 --- a/test/core/sampler/pymultinest_test.py +++ b/test/core/sampler/pymultinest_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.pymultinest class TestPymultinest(unittest.TestCase): @@ -12,7 +13,7 @@ class TestPymultinest(unittest.TestCase): ) self.priors["a"] = bilby.core.prior.Prior(boundary="periodic") self.priors["b"] = bilby.core.prior.Prior(boundary="reflective") - self.sampler = bilby.core.sampler.Pymultinest( + self.sampler = bilby.core.sampler.pymultinest.Pymultinest( self.likelihood, self.priors, outdir="outdir", diff --git a/test/core/sampler/ultranest_test.py b/test/core/sampler/ultranest_test.py index be22c1a1f..c0219295b 100644 --- a/test/core/sampler/ultranest_test.py +++ b/test/core/sampler/ultranest_test.py @@ -3,6 +3,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.ultranest class TestUltranest(unittest.TestCase): @@ -15,10 +16,15 @@ class TestUltranest(unittest.TestCase): b=bilby.core.prior.Uniform(0, 1))) self.priors["a"] = bilby.core.prior.Prior(boundary="periodic") self.priors["b"] = bilby.core.prior.Prior(boundary="reflective") - self.sampler = bilby.core.sampler.Ultranest(self.likelihood, self.priors, - outdir="outdir", label="label", - use_ratio=False, plot=False, - skip_import_verification=True) + self.sampler = bilby.core.sampler.ultranest.Ultranest( + self.likelihood, + self.priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + skip_import_verification=True, + ) def tearDown(self): del self.likelihood diff --git a/test/core/sampler/zeus_test.py b/test/core/sampler/zeus_test.py index 2b3e2b5de..0f8dea9b1 100644 --- a/test/core/sampler/zeus_test.py +++ b/test/core/sampler/zeus_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import MagicMock import bilby +import bilby.core.sampler.zeus class TestZeus(unittest.TestCase): @@ -10,7 +11,7 @@ class TestZeus(unittest.TestCase): self.priors = bilby.core.prior.PriorDict( dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) ) - self.sampler = bilby.core.sampler.Zeus( + self.sampler = bilby.core.sampler.zeus.Zeus( self.likelihood, self.priors, outdir="outdir", diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py index 00678539b..79a61a7b4 100644 --- a/test/integration/sampler_run_test.py +++ b/test/integration/sampler_run_test.py @@ -66,6 +66,8 @@ sampler_imports = dict( no_pool_test = ["dnest4", "pymultinest", "nestle", "ptmcmcsampler", "ultranest", "pymc"] +loaded_samplers = {k: v.load() for k, v in bilby.core.sampler.IMPLEMENTED_SAMPLERS.items()} + def slow_func(x, m, c): time.sleep(0.01) @@ -154,7 +156,7 @@ class TestRunningSamplers(unittest.TestCase): def _run_with_signal_handling(self, sampler, pool_size=1): pytest.importorskip(sampler_imports.get(sampler, sampler)) - if bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler.lower()].hard_exit: + if loaded_samplers[sampler.lower()].hard_exit: pytest.skip(f"{sampler} hard exits, can't test signal handling.") if pool_size > 1 and sampler.lower() in no_pool_test: pytest.skip(f"{sampler} cannot be parallelized") diff --git a/test/test_samplers_import.py b/test/test_samplers_import.py index d607d87f3..1cab28a76 100644 --- a/test/test_samplers_import.py +++ b/test/test_samplers_import.py @@ -1,17 +1,21 @@ -""" -Tests that all of the implemented samplers can be initialized. - -The :code:`FakeSampler` is omitted as that doesn't require importing -any package. -""" import bilby +import pytest + + +@pytest.mark.parametrize( + "sampler_name", bilby.core.sampler.IMPLEMENTED_SAMPLERS.keys() +) +def test_sampler_import(sampler_name): + """ + Tests that all of the implemented samplers can be initialized. -bilby.core.utils.logger.setLevel("ERROR") -IMPLEMENTED_SAMPLERS = bilby.core.sampler.IMPLEMENTED_SAMPLERS -likelihood = bilby.core.likelihood.Likelihood(dict()) -priors = bilby.core.prior.PriorDict(dict(a=bilby.core.prior.Uniform(0, 1))) -for sampler in IMPLEMENTED_SAMPLERS: - if sampler in ["fake_sampler", "pypolychord"]: - continue - sampler_class = IMPLEMENTED_SAMPLERS[sampler] + Do not test :code:`FakeSampler` since it requires an additional argument. + """ + if sampler_name in ["fake_sampler", "pypolychord"]: + pytest.skip(f"Skipping import test for {sampler_name}") + bilby.core.utils.logger.setLevel("ERROR") + likelihood = bilby.core.likelihood.Likelihood(dict()) + priors = bilby.core.prior.PriorDict(dict(a=bilby.core.prior.Uniform(0, 1))) + sampler_class = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler_name].load() sampler = sampler_class(likelihood=likelihood, priors=priors) + assert sampler is not None -- GitLab