Skip to content
Snippets Groups Projects
Commit 5ebe14d6 authored by Moritz's avatar Moritz
Browse files

Put imports of sampler packages in separate functions

parent faa5c8d9
No related branches found
No related tags found
1 merge request!230Resolve "PEP8 Imports"
......@@ -6,12 +6,6 @@ from pandas import DataFrame
from .base_sampler import NestedSampler
from ..utils import logger, check_directory_exists_and_if_not_mkdir
try:
from cpnest import model as cpmodel, CPNest
except ImportError:
logger.debug('CPNest is not installed on this system, you will '
'not be able to use the CPNest sampler')
class Cpnest(NestedSampler):
""" bilby wrapper of cpnest (https://github.com/johnveitch/cpnest)
......@@ -45,6 +39,16 @@ class Cpnest(NestedSampler):
seed=None, poolsize=100, nhamiltonian=0, resume=False,
output=None)
@staticmethod
def _import_external_sampler():
try:
from cpnest import model as cpmodel, CPNest
except ImportError:
logger.debug('CPNest is not installed on this system, you will '
'not be able to use the CPNest sampler')
cpmodel, CPNest = None, None
return cpmodel, CPNest
def _translate_kwargs(self, kwargs):
if 'nlive' not in kwargs:
for equiv in self.npoints_equiv_kwargs:
......@@ -54,6 +58,8 @@ class Cpnest(NestedSampler):
logger.warning('No seed provided, cpnest will use 1234.')
def run_sampler(self):
cpmodel, CPNest = self._import_external_sampler()
class Model(cpmodel.Model):
""" A wrapper class to pass our log_likelihood into cpnest """
......
......@@ -10,12 +10,6 @@ from deepdish.io import load, save
from ..utils import logger, check_directory_exists_and_if_not_mkdir
from .base_sampler import Sampler, NestedSampler
try:
import dynesty
except ImportError:
logger.debug('Dynesty is not installed on this system, you will '
'not be able to use the Dynesty sampler')
class Dynesty(NestedSampler):
"""
......@@ -109,6 +103,16 @@ class Dynesty(NestedSampler):
n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
self.n_check_point = n_check_point_rnd
@staticmethod
def _import_external_sampler():
try:
import dynesty
except ImportError:
logger.debug('Dynesty is not installed on this system, you will '
'not be able to use the Dynesty sampler')
dynesty = None
return dynesty
@property
def sampler_function_kwargs(self):
keys = ['dlogz', 'print_progress', 'print_func', 'maxiter',
......@@ -174,6 +178,7 @@ class Dynesty(NestedSampler):
sys.stderr.flush()
def run_sampler(self):
dynesty = self._import_external_sampler()
self.sampler = dynesty.NestedSampler(
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
......
......@@ -6,12 +6,6 @@ from pandas import DataFrame
from ..utils import logger, get_progress_bar
from .base_sampler import MCMCSampler
try:
import emcee
except ImportError:
logger.debug('Emcee is not installed on this system, you will '
'not be able to use the Emcee sampler')
class Emcee(MCMCSampler):
"""bilby wrapper emcee (https://github.com/dfm/emcee)
......@@ -60,6 +54,16 @@ class Emcee(MCMCSampler):
self.burn_in_fraction = burn_in_fraction
self.burn_in_act = burn_in_act
@staticmethod
def _import_external_sampler():
try:
import emcee
except ImportError:
logger.debug('Emcee is not installed on this system, you will '
'not be able to use the Emcee sampler')
emcee = None
return emcee
def _translate_kwargs(self, kwargs):
if 'nwalkers' not in kwargs:
for equiv in self.nwalkers_equiv_kwargs:
......@@ -111,6 +115,7 @@ class Emcee(MCMCSampler):
self.kwargs['iterations'] = nsteps
def run_sampler(self):
emcee = self._import_external_sampler()
tqdm = get_progress_bar()
sampler = emcee.EnsembleSampler(dim=self.ndim, lnpostfn=self.lnpostfn, **self.sampler_init_kwargs)
self._set_pos0()
......
......@@ -6,12 +6,6 @@ from pandas import DataFrame
from ..utils import logger
from .base_sampler import NestedSampler
try:
import nestle
except ImportError:
logger.debug('Nestle is not installed on this system, you will '
'not be able to use the Nestle sampler')
class Nestle(NestedSampler):
"""bilby wrapper `nestle.Sampler` (http://kylebarbary.com/nestle/)
......@@ -38,6 +32,16 @@ class Nestle(NestedSampler):
maxcall=None, dlogz=None, decline_factor=None,
rstate=None, callback=None)
@staticmethod
def _import_external_sampler():
try:
import nestle
except ImportError:
logger.debug('Nestle is not installed on this system, you will '
'not be able to use the Nestle sampler')
nestle = None
return nestle
def _translate_kwargs(self, kwargs):
if 'npoints' not in kwargs:
for equiv in self.npoints_equiv_kwargs:
......@@ -59,6 +63,7 @@ class Nestle(NestedSampler):
bilby.core.result.Result: Packaged information about the result
"""
nestle = self._import_external_sampler()
out = nestle.sample(
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
......
......@@ -50,6 +50,17 @@ class Ptemcee(Emcee):
use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification,
nburn=nburn, burn_in_fraction=burn_in_fraction, burn_in_act=burn_in_act, **kwargs)
@staticmethod
def _import_external_sampler():
try:
import ptemcee
except ImportError:
logger.debug('Nestle is not installed on this system, you will '
'not be able to use the Nestle sampler')
ptemcee = None
return ptemcee
@property
def sampler_function_kwargs(self):
keys = ['iterations', 'thin', 'storechain', 'adapt', 'swap_ratios']
......@@ -62,8 +73,9 @@ class Ptemcee(Emcee):
if key not in self.sampler_function_kwargs}
def run_sampler(self):
tqdm = get_progress_bar()
ptemcee = self._import_external_sampler()
tqdm = get_progress_bar()
sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood,
logp=self.log_prior, **self.sampler_init_kwargs)
self.pos0 = [[self.get_random_draw_from_prior()
......
......@@ -12,20 +12,6 @@ from ..likelihood import GaussianLikelihood, PoissonLikelihood, ExponentialLikel
StudentTLikelihood
from ...gw.likelihood import BasicGravitationalWaveTransient, GravitationalWaveTransient
try:
import pymc3
from pymc3.sampling import STEP_METHODS
from pymc3.theanof import floatX
except ImportError:
logger.debug('PyMC3 is not installed on this system, you will '
'not be able to use the PyMC3 sampler')
try:
import theano # noqa
import theano.tensor as tt
from theano.compile.ops import as_op # noqa
except ImportError:
logger.debug("You must have Theano installed to use PyMC3")
class Pymc3(MCMCSampler):
""" bilby wrapper of the PyMC3 sampler (https://docs.pymc.io/)
......@@ -82,6 +68,31 @@ class Pymc3(MCMCSampler):
self.draws = draws
self.chains = self.__kwargs['chains']
@staticmethod
def _import_external_sampler():
try:
import pymc3
from pymc3.sampling import STEP_METHODS
from pymc3.theanof import floatX
except ImportError:
logger.debug('PyMC3 is not installed on this system, you will '
'not be able to use the PyMC3 sampler')
pymc3 = None
STEP_METHODS = None
floatX = None
return pymc3, STEP_METHODS, floatX
@staticmethod
def _import_theano():
try:
import theano # noqa
import theano.tensor as tt
from theano.compile.ops import as_op # noqa
except ImportError:
logger.debug("You must have Theano installed to use PyMC3")
theano, tt, as_op = None, None, None
return theano, tt, as_op
def _verify_parameters(self):
"""
Change `_verify_parameters()` to just pass, i.e., don't try and
......@@ -247,6 +258,8 @@ class Pymc3(MCMCSampler):
"""
# check prior is a Sine
pymc3, STEP_METHODS, floatX = self._import_external_sampler()
theano, tt, as_op = self._import_theano()
if isinstance(self.priors[key], Sine):
class Pymc3Sine(pymc3.Continuous):
......@@ -285,6 +298,8 @@ class Pymc3(MCMCSampler):
"""
# check prior is a Cosine
pymc3, STEP_METHODS, floatX = self._import_external_sampler()
theano, tt, as_op = self._import_theano()
if isinstance(self.priors[key], Cosine):
class Pymc3Cosine(pymc3.Continuous):
......@@ -322,6 +337,8 @@ class Pymc3(MCMCSampler):
"""
# check prior is a PowerLaw
pymc3, STEP_METHODS, floatX = self._import_external_sampler()
theano, tt, as_op = self._import_theano()
if isinstance(self.priors[key], PowerLaw):
# check power law is set
......@@ -373,7 +390,7 @@ class Pymc3(MCMCSampler):
def run_sampler(self):
# set the step method
pymc3, STEP_METHODS, floatX = self._import_external_sampler()
step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
if 'step' in self.__kwargs:
self.step_method = self.__kwargs.pop('step')
......@@ -454,6 +471,7 @@ class Pymc3(MCMCSampler):
self.setup_prior_mapping()
self.pymc3_priors = OrderedDict()
pymc3, STEP_METHODS, floatX = self._import_external_sampler()
# set the parameter prior distributions (in the model context manager)
with self.pymc3_model:
......@@ -517,6 +535,8 @@ class Pymc3(MCMCSampler):
"""
# create theano Op for the log likelihood if not using a predefined model
pymc3, STEP_METHODS, floatX = self._import_external_sampler()
theano, tt, as_op = self._import_theano()
class LogLike(tt.Op):
itypes = [tt.dvector]
......
......@@ -6,18 +6,6 @@ import os
from ..utils import check_directory_exists_and_if_not_mkdir
from .base_sampler import NestedSampler
from ..utils import logger
try:
try:
# Suppresses import error printouts from pymultinest
sys.stdout = open(os.devnull, 'w')
import pymultinest
sys.stdout = sys.__stdout__
except ImportError:
logger.debug('PyMultinest is not installed on this system, you will '
'not be able to use the PyMultinest sampler')
except SystemExit:
logger.debug('Multinest is not installed on this system, you will '
'not be able to use the Multinest sampler')
class Pymultinest(NestedSampler):
......@@ -57,6 +45,23 @@ class Pymultinest(NestedSampler):
context=0, write_output=True, log_zero=-1e100,
max_iter=0, init_MPI=False, dump_callback=None)
@staticmethod
def _import_external_sampler():
try:
# Suppresses import error printouts from pymultinest
sys.stdout = open(os.devnull, 'w')
import pymultinest
sys.stdout = sys.__stdout__
except ImportError:
logger.debug('PyMultinest is not installed on this system, you will '
'not be able to use the PyMultinest sampler')
pymultinest = None
except SystemExit:
logger.debug('Multinest is not installed on this system, you will '
'not be able to use the Multinest sampler')
pymultinest = None
return pymultinest
def _translate_kwargs(self, kwargs):
if 'n_live_points' not in kwargs:
for equiv in self.npoints_equiv_kwargs:
......@@ -87,6 +92,7 @@ class Pymultinest(NestedSampler):
NestedSampler._verify_kwargs_against_default_kwargs(self)
def run_sampler(self):
pymultinest = self._import_external_sampler()
self._verify_kwargs_against_default_kwargs()
out = pymultinest.solve(
LogLikelihood=self.log_likelihood, Prior=self.prior_transform,
......
......@@ -80,7 +80,7 @@ likelihood = bilby.gw.GravitationalWaveTransient(
# Run sampler. In this case we're going to use the `dynesty` sampler
result = bilby.run_sampler(
likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
likelihood=likelihood, priors=priors, sampler='pymc3', npoints=1000,
injection_parameters=injection_parameters, outdir=outdir, label=label)
# Make a corner plot.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment