Skip to content
Snippets Groups Projects
Commit 36f60962 authored by Michael Williams's avatar Michael Williams Committed by Colm Talbot
Browse files

Distinguish bilby native samplers

parent d714ddd4
No related branches found
No related tags found
No related merge requests found
......@@ -13,7 +13,74 @@ from ..utils import (
from . import proposal
from .base_sampler import Sampler, SamplingMarginalisedParameterError
IMPLEMENTED_SAMPLERS = get_entry_points("bilby.samplers")
class ImplementedSamplers:
"""Dictionary-like object that contains implemented samplers.
This class is singleton and only one instance can exist.
"""
_instance = None
_samplers = get_entry_points("bilby.samplers")
def keys(self):
"""Iterator of available samplers by name.
Reduces the list to its simplest. This includes removing the 'bilby.'
prefix from native samplers if a corresponding plugin is not available.
"""
keys = []
for key in self._samplers.keys():
name = key.replace("bilby.", "")
if name in self._samplers.keys():
keys.append(key)
else:
keys.append(name)
return iter(keys)
def values(self):
"""Iterator of sampler classes.
Note: the classes need to loaded using :code:`.load()` before being
called.
"""
return iter(self._samplers.values())
def items(self):
"""Iterator of tuples containing keys (sampler names) and classes.
Note: the classes need to loaded using :code:`.load()` before being
called.
"""
return iter(((k, v) for k, v in zip(self.keys(), self.values())))
def valid_keys(self):
"""All valid keys including bilby.<sampler name>."""
keys = set(self._samplers.keys())
return iter(keys.union({k.replace("bilby.", "") for k in keys}))
def __getitem__(self, key):
if key in self._samplers:
return self._samplers[key]
elif f"bilby.{key}" in self._samplers:
return self._samplers[f"bilby.{key}"]
else:
raise ValueError(
f"Sampler {key} is not implemented! "
f"Available samplers are: {list(self.keys())}"
)
def __contains__(self, value):
return value in self.valid_keys()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
IMPLEMENTED_SAMPLERS = ImplementedSamplers()
def get_implemented_samplers():
......@@ -51,14 +118,7 @@ def get_sampler_class(sampler):
ValueError
Raised if the sampler is not implemented.
"""
sampler = sampler.lower()
if sampler in IMPLEMENTED_SAMPLERS:
return IMPLEMENTED_SAMPLERS[sampler].load()
else:
raise ValueError(
f"Sampler {sampler} not yet implemented. "
f"The available samplers are: {get_implemented_samplers()}"
)
return IMPLEMENTED_SAMPLERS[sampler.lower()].load()
if command_line_args.sampler_help:
......
......@@ -61,6 +61,19 @@ here, we encourage you to open a
- This could be your sampler
Bilby-native samplers
---------------------
Some samplers are implemented directly in :code:`bilby` and these are avertised
under two possible names:
- :code:`bilby.<sampler name>`: always available, indicates the sampler is implemented in bilby,
- :code:`<sampler name>`: only refers to the native bilby implementation if an external plugin does not already provide this sampler.
This allows for an external plugin to provide a sampler without introducing
namespace conflicts.
--------------------------------
Information for bilby developers
--------------------------------
......
......@@ -73,6 +73,18 @@ MCMC samplers
- zeus :code:`bilby.core.sampler.zeus.Zeus`
--------------------------
Listing available samplers
--------------------------
A list of available samplers can be produced using
:py:func:`bilby.core.sampler.get_implemented_samplers`.
This will list native bilby samplers and any samplers available via a plugin.
If a plugin provides a sampler that is also implemented in bilby, the bilby
implementation will be labeled with the prfix `bilby.` to distinguish it from
the plugin version. See `sampler plugins`_ for more details.
-------------------
Installing samplers
-------------------
......
......@@ -84,23 +84,23 @@ setup(
"bilby_result=cli_bilby.bilby_result:main",
],
"bilby.samplers": [
"bilby_mcmc=bilby.bilby_mcmc.sampler:Bilby_MCMC",
"cpnest=bilby.core.sampler.cpnest:Cpnest",
"dnest4=bilby.core.sampler.dnest4:DNest4",
"dynesty=bilby.core.sampler.dynesty:Dynesty",
"dynamic_dynesty=bilby.core.sampler.dynamic_dynesty:DynamicDynesty",
"emcee=bilby.core.sampler.emcee:Emcee",
"kombine=bilby.core.sampler.kombine:Kombine",
"nessai=bilby.core.sampler.nessai:Nessai",
"nestle=bilby.core.sampler.nestle:Nestle",
"ptemcee=bilby.core.sampler.ptemcee:Ptemcee",
"ptmcmcsampler=bilby.core.sampler.ptmcmc:PTMCMCSampler",
"pymc=bilby.core.sampler.pymc:Pymc",
"pymultinest=bilby.core.sampler.pymultinest:Pymultinest",
"pypolychord=bilby.core.sampler.polychord:PyPolyChord",
"ultranest=bilby.core.sampler.ultranest:Ultranest",
"zeus=bilby.core.sampler.zeus:Zeus",
"fake_sampler=bilby.core.sampler.fake_sampler:FakeSampler",
"bilby.bilby_mcmc=bilby.bilby_mcmc.sampler:Bilby_MCMC",
"bilby.cpnest=bilby.core.sampler.cpnest:Cpnest",
"bilby.dnest4=bilby.core.sampler.dnest4:DNest4",
"bilby.dynesty=bilby.core.sampler.dynesty:Dynesty",
"bilby.dynamic_dynesty=bilby.core.sampler.dynamic_dynesty:DynamicDynesty",
"bilby.emcee=bilby.core.sampler.emcee:Emcee",
"bilby.kombine=bilby.core.sampler.kombine:Kombine",
"bilby.nessai=bilby.core.sampler.nessai:Nessai",
"bilby.nestle=bilby.core.sampler.nestle:Nestle",
"bilby.ptemcee=bilby.core.sampler.ptemcee:Ptemcee",
"bilby.ptmcmcsampler=bilby.core.sampler.ptmcmc:PTMCMCSampler",
"bilby.pymc=bilby.core.sampler.pymc:Pymc",
"bilby.pymultinest=bilby.core.sampler.pymultinest:Pymultinest",
"bilby.pypolychord=bilby.core.sampler.polychord:PyPolyChord",
"bilby.ultranest=bilby.core.sampler.ultranest:Ultranest",
"bilby.zeus=bilby.core.sampler.zeus:Zeus",
"bilby.fake_sampler=bilby.core.sampler.fake_sampler:FakeSampler",
],
},
classifiers=[
......
from bilby.core.sampler import IMPLEMENTED_SAMPLERS, ImplementedSamplers
import pytest
def test_singleton():
assert ImplementedSamplers() is IMPLEMENTED_SAMPLERS
def test_keys():
# The fake sampler should never have a plugin, so this should always work
assert "fake_sampler" in IMPLEMENTED_SAMPLERS.keys()
assert "bilby.fake_sampler" not in IMPLEMENTED_SAMPLERS.keys()
def test_allowed_keys():
# The fake sampler should never have a plugin, so this should always work
assert "fake_sampler" in IMPLEMENTED_SAMPLERS.valid_keys()
assert "bilby.fake_sampler" in IMPLEMENTED_SAMPLERS.valid_keys()
def test_values():
# Values and keys should have the same lengths
assert len(list(IMPLEMENTED_SAMPLERS.values())) \
== len(list(IMPLEMENTED_SAMPLERS.keys()))
assert len(list(IMPLEMENTED_SAMPLERS.values())) \
== len(list(IMPLEMENTED_SAMPLERS._samplers.values()))
def test_items():
keys, values = list(zip(*IMPLEMENTED_SAMPLERS.items()))
assert len(keys) == len(values)
# Keys and values should be the same as the individual methods
assert list(keys) == list(IMPLEMENTED_SAMPLERS.keys())
assert list(values) == list(IMPLEMENTED_SAMPLERS.values())
@pytest.mark.parametrize("sampler", ["fake_sampler", "bilby.fake_sampler"])
def test_in_operator(sampler):
assert sampler in IMPLEMENTED_SAMPLERS
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment