diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 5d56b3ea5ce4585e33637ced1ef6ce73382a67ab..ef65ce1dddd70b2d57aaf07c15cfe999a47ba901 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -13,7 +13,74 @@ from ..utils import ( from . import proposal from .base_sampler import Sampler, SamplingMarginalisedParameterError -IMPLEMENTED_SAMPLERS = get_entry_points("bilby.samplers") + +class ImplementedSamplers: + """Dictionary-like object that contains implemented samplers. + + This class is singleton and only one instance can exist. + """ + + _instance = None + + _samplers = get_entry_points("bilby.samplers") + + def keys(self): + """Iterator of available samplers by name. + + Reduces the list to its simplest. This includes removing the 'bilby.' + prefix from native samplers if a corresponding plugin is not available. + """ + keys = [] + for key in self._samplers.keys(): + name = key.replace("bilby.", "") + if name in self._samplers.keys(): + keys.append(key) + else: + keys.append(name) + return iter(keys) + + def values(self): + """Iterator of sampler classes. + + Note: the classes need to loaded using :code:`.load()` before being + called. + """ + return iter(self._samplers.values()) + + def items(self): + """Iterator of tuples containing keys (sampler names) and classes. + + Note: the classes need to loaded using :code:`.load()` before being + called. + """ + return iter(((k, v) for k, v in zip(self.keys(), self.values()))) + + def valid_keys(self): + """All valid keys including bilby.<sampler name>.""" + keys = set(self._samplers.keys()) + return iter(keys.union({k.replace("bilby.", "") for k in keys})) + + def __getitem__(self, key): + if key in self._samplers: + return self._samplers[key] + elif f"bilby.{key}" in self._samplers: + return self._samplers[f"bilby.{key}"] + else: + raise ValueError( + f"Sampler {key} is not implemented! " + f"Available samplers are: {list(self.keys())}" + ) + + def __contains__(self, value): + return value in self.valid_keys() + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + +IMPLEMENTED_SAMPLERS = ImplementedSamplers() def get_implemented_samplers(): @@ -51,14 +118,7 @@ def get_sampler_class(sampler): ValueError Raised if the sampler is not implemented. """ - sampler = sampler.lower() - if sampler in IMPLEMENTED_SAMPLERS: - return IMPLEMENTED_SAMPLERS[sampler].load() - else: - raise ValueError( - f"Sampler {sampler} not yet implemented. " - f"The available samplers are: {get_implemented_samplers()}" - ) + return IMPLEMENTED_SAMPLERS[sampler.lower()].load() if command_line_args.sampler_help: diff --git a/docs/plugins.txt b/docs/plugins.txt index 38422e646c62096ff27719e0f9573f173e7f4734..a73f0ce4931e2c6a0457342e98117b8c84963a2c 100644 --- a/docs/plugins.txt +++ b/docs/plugins.txt @@ -61,6 +61,19 @@ here, we encourage you to open a - This could be your sampler +Bilby-native samplers +--------------------- + +Some samplers are implemented directly in :code:`bilby` and these are avertised +under two possible names: + +- :code:`bilby.<sampler name>`: always available, indicates the sampler is implemented in bilby, +- :code:`<sampler name>`: only refers to the native bilby implementation if an external plugin does not already provide this sampler. + +This allows for an external plugin to provide a sampler without introducing +namespace conflicts. + + -------------------------------- Information for bilby developers -------------------------------- diff --git a/docs/samplers.txt b/docs/samplers.txt index 69699dad0e3b22725b4d3f3f6d942763406df0da..7b4a190f1b972257ec05a51ee45ca3070350edcd 100644 --- a/docs/samplers.txt +++ b/docs/samplers.txt @@ -73,6 +73,18 @@ MCMC samplers - zeus :code:`bilby.core.sampler.zeus.Zeus` +-------------------------- +Listing available samplers +-------------------------- + +A list of available samplers can be produced using +:py:func:`bilby.core.sampler.get_implemented_samplers`. +This will list native bilby samplers and any samplers available via a plugin. +If a plugin provides a sampler that is also implemented in bilby, the bilby +implementation will be labeled with the prfix `bilby.` to distinguish it from +the plugin version. See `sampler plugins`_ for more details. + + ------------------- Installing samplers ------------------- diff --git a/setup.py b/setup.py index b929d5bb41f4f17571b09bdffe75e545d894ba8a..353e92ee443d1e6e90eb104e07c15ec063229020 100644 --- a/setup.py +++ b/setup.py @@ -84,23 +84,23 @@ setup( "bilby_result=cli_bilby.bilby_result:main", ], "bilby.samplers": [ - "bilby_mcmc=bilby.bilby_mcmc.sampler:Bilby_MCMC", - "cpnest=bilby.core.sampler.cpnest:Cpnest", - "dnest4=bilby.core.sampler.dnest4:DNest4", - "dynesty=bilby.core.sampler.dynesty:Dynesty", - "dynamic_dynesty=bilby.core.sampler.dynamic_dynesty:DynamicDynesty", - "emcee=bilby.core.sampler.emcee:Emcee", - "kombine=bilby.core.sampler.kombine:Kombine", - "nessai=bilby.core.sampler.nessai:Nessai", - "nestle=bilby.core.sampler.nestle:Nestle", - "ptemcee=bilby.core.sampler.ptemcee:Ptemcee", - "ptmcmcsampler=bilby.core.sampler.ptmcmc:PTMCMCSampler", - "pymc=bilby.core.sampler.pymc:Pymc", - "pymultinest=bilby.core.sampler.pymultinest:Pymultinest", - "pypolychord=bilby.core.sampler.polychord:PyPolyChord", - "ultranest=bilby.core.sampler.ultranest:Ultranest", - "zeus=bilby.core.sampler.zeus:Zeus", - "fake_sampler=bilby.core.sampler.fake_sampler:FakeSampler", + "bilby.bilby_mcmc=bilby.bilby_mcmc.sampler:Bilby_MCMC", + "bilby.cpnest=bilby.core.sampler.cpnest:Cpnest", + "bilby.dnest4=bilby.core.sampler.dnest4:DNest4", + "bilby.dynesty=bilby.core.sampler.dynesty:Dynesty", + "bilby.dynamic_dynesty=bilby.core.sampler.dynamic_dynesty:DynamicDynesty", + "bilby.emcee=bilby.core.sampler.emcee:Emcee", + "bilby.kombine=bilby.core.sampler.kombine:Kombine", + "bilby.nessai=bilby.core.sampler.nessai:Nessai", + "bilby.nestle=bilby.core.sampler.nestle:Nestle", + "bilby.ptemcee=bilby.core.sampler.ptemcee:Ptemcee", + "bilby.ptmcmcsampler=bilby.core.sampler.ptmcmc:PTMCMCSampler", + "bilby.pymc=bilby.core.sampler.pymc:Pymc", + "bilby.pymultinest=bilby.core.sampler.pymultinest:Pymultinest", + "bilby.pypolychord=bilby.core.sampler.polychord:PyPolyChord", + "bilby.ultranest=bilby.core.sampler.ultranest:Ultranest", + "bilby.zeus=bilby.core.sampler.zeus:Zeus", + "bilby.fake_sampler=bilby.core.sampler.fake_sampler:FakeSampler", ], }, classifiers=[ diff --git a/test/core/sampler/implemented_samplers_test.py b/test/core/sampler/implemented_samplers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c34398185051997efef52c2f8e967bb866072d39 --- /dev/null +++ b/test/core/sampler/implemented_samplers_test.py @@ -0,0 +1,39 @@ +from bilby.core.sampler import IMPLEMENTED_SAMPLERS, ImplementedSamplers +import pytest + + +def test_singleton(): + assert ImplementedSamplers() is IMPLEMENTED_SAMPLERS + + +def test_keys(): + # The fake sampler should never have a plugin, so this should always work + assert "fake_sampler" in IMPLEMENTED_SAMPLERS.keys() + assert "bilby.fake_sampler" not in IMPLEMENTED_SAMPLERS.keys() + + +def test_allowed_keys(): + # The fake sampler should never have a plugin, so this should always work + assert "fake_sampler" in IMPLEMENTED_SAMPLERS.valid_keys() + assert "bilby.fake_sampler" in IMPLEMENTED_SAMPLERS.valid_keys() + + +def test_values(): + # Values and keys should have the same lengths + assert len(list(IMPLEMENTED_SAMPLERS.values())) \ + == len(list(IMPLEMENTED_SAMPLERS.keys())) + assert len(list(IMPLEMENTED_SAMPLERS.values())) \ + == len(list(IMPLEMENTED_SAMPLERS._samplers.values())) + + +def test_items(): + keys, values = list(zip(*IMPLEMENTED_SAMPLERS.items())) + assert len(keys) == len(values) + # Keys and values should be the same as the individual methods + assert list(keys) == list(IMPLEMENTED_SAMPLERS.keys()) + assert list(values) == list(IMPLEMENTED_SAMPLERS.values()) + + +@pytest.mark.parametrize("sampler", ["fake_sampler", "bilby.fake_sampler"]) +def test_in_operator(sampler): + assert sampler in IMPLEMENTED_SAMPLERS