From 5ebe14d6dd034de5d2b1428e2a847916b6b79fbe Mon Sep 17 00:00:00 2001 From: Moritz <email@moritz-huebner.de> Date: Mon, 19 Nov 2018 13:35:18 +1100 Subject: [PATCH] Put imports of sampler packages in separate functions --- bilby/core/sampler/cpnest.py | 18 ++++--- bilby/core/sampler/dynesty.py | 17 ++++--- bilby/core/sampler/emcee.py | 17 ++++--- bilby/core/sampler/nestle.py | 17 ++++--- bilby/core/sampler/ptemcee.py | 14 +++++- bilby/core/sampler/pymc3.py | 50 +++++++++++++------ bilby/core/sampler/pymultinest.py | 30 ++++++----- examples/injection_examples/basic_tutorial.py | 2 +- 8 files changed, 112 insertions(+), 53 deletions(-) diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index 68826ecae..a3d62b9ee 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -6,12 +6,6 @@ from pandas import DataFrame from .base_sampler import NestedSampler from ..utils import logger, check_directory_exists_and_if_not_mkdir -try: - from cpnest import model as cpmodel, CPNest -except ImportError: - logger.debug('CPNest is not installed on this system, you will ' - 'not be able to use the CPNest sampler') - class Cpnest(NestedSampler): """ bilby wrapper of cpnest (https://github.com/johnveitch/cpnest) @@ -45,6 +39,16 @@ class Cpnest(NestedSampler): seed=None, poolsize=100, nhamiltonian=0, resume=False, output=None) + @staticmethod + def _import_external_sampler(): + try: + from cpnest import model as cpmodel, CPNest + except ImportError: + logger.debug('CPNest is not installed on this system, you will ' + 'not be able to use the CPNest sampler') + cpmodel, CPNest = None, None + return cpmodel, CPNest + def _translate_kwargs(self, kwargs): if 'nlive' not in kwargs: for equiv in self.npoints_equiv_kwargs: @@ -54,6 +58,8 @@ class Cpnest(NestedSampler): logger.warning('No seed provided, cpnest will use 1234.') def run_sampler(self): + cpmodel, CPNest = self._import_external_sampler() + class Model(cpmodel.Model): """ A wrapper class to pass our log_likelihood into cpnest """ diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 3e64ff9f2..4373aa88a 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -10,12 +10,6 @@ from deepdish.io import load, save from ..utils import logger, check_directory_exists_and_if_not_mkdir from .base_sampler import Sampler, NestedSampler -try: - import dynesty -except ImportError: - logger.debug('Dynesty is not installed on this system, you will ' - 'not be able to use the Dynesty sampler') - class Dynesty(NestedSampler): """ @@ -109,6 +103,16 @@ class Dynesty(NestedSampler): n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw))) self.n_check_point = n_check_point_rnd + @staticmethod + def _import_external_sampler(): + try: + import dynesty + except ImportError: + logger.debug('Dynesty is not installed on this system, you will ' + 'not be able to use the Dynesty sampler') + dynesty = None + return dynesty + @property def sampler_function_kwargs(self): keys = ['dlogz', 'print_progress', 'print_func', 'maxiter', @@ -174,6 +178,7 @@ class Dynesty(NestedSampler): sys.stderr.flush() def run_sampler(self): + dynesty = self._import_external_sampler() self.sampler = dynesty.NestedSampler( loglikelihood=self.log_likelihood, prior_transform=self.prior_transform, diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index aa8d00392..77be0dc30 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -6,12 +6,6 @@ from pandas import DataFrame from ..utils import logger, get_progress_bar from .base_sampler import MCMCSampler -try: - import emcee -except ImportError: - logger.debug('Emcee is not installed on this system, you will ' - 'not be able to use the Emcee sampler') - class Emcee(MCMCSampler): """bilby wrapper emcee (https://github.com/dfm/emcee) @@ -60,6 +54,16 @@ class Emcee(MCMCSampler): self.burn_in_fraction = burn_in_fraction self.burn_in_act = burn_in_act + @staticmethod + def _import_external_sampler(): + try: + import emcee + except ImportError: + logger.debug('Emcee is not installed on this system, you will ' + 'not be able to use the Emcee sampler') + emcee = None + return emcee + def _translate_kwargs(self, kwargs): if 'nwalkers' not in kwargs: for equiv in self.nwalkers_equiv_kwargs: @@ -111,6 +115,7 @@ class Emcee(MCMCSampler): self.kwargs['iterations'] = nsteps def run_sampler(self): + emcee = self._import_external_sampler() tqdm = get_progress_bar() sampler = emcee.EnsembleSampler(dim=self.ndim, lnpostfn=self.lnpostfn, **self.sampler_init_kwargs) self._set_pos0() diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py index 468e89d3c..9ea6b5eff 100644 --- a/bilby/core/sampler/nestle.py +++ b/bilby/core/sampler/nestle.py @@ -6,12 +6,6 @@ from pandas import DataFrame from ..utils import logger from .base_sampler import NestedSampler -try: - import nestle -except ImportError: - logger.debug('Nestle is not installed on this system, you will ' - 'not be able to use the Nestle sampler') - class Nestle(NestedSampler): """bilby wrapper `nestle.Sampler` (http://kylebarbary.com/nestle/) @@ -38,6 +32,16 @@ class Nestle(NestedSampler): maxcall=None, dlogz=None, decline_factor=None, rstate=None, callback=None) + @staticmethod + def _import_external_sampler(): + try: + import nestle + except ImportError: + logger.debug('Nestle is not installed on this system, you will ' + 'not be able to use the Nestle sampler') + nestle = None + return nestle + def _translate_kwargs(self, kwargs): if 'npoints' not in kwargs: for equiv in self.npoints_equiv_kwargs: @@ -59,6 +63,7 @@ class Nestle(NestedSampler): bilby.core.result.Result: Packaged information about the result """ + nestle = self._import_external_sampler() out = nestle.sample( loglikelihood=self.log_likelihood, prior_transform=self.prior_transform, diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index dd6953d1d..de400fe31 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -50,6 +50,17 @@ class Ptemcee(Emcee): use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification, nburn=nburn, burn_in_fraction=burn_in_fraction, burn_in_act=burn_in_act, **kwargs) + + @staticmethod + def _import_external_sampler(): + try: + import ptemcee + except ImportError: + logger.debug('Nestle is not installed on this system, you will ' + 'not be able to use the Nestle sampler') + ptemcee = None + return ptemcee + @property def sampler_function_kwargs(self): keys = ['iterations', 'thin', 'storechain', 'adapt', 'swap_ratios'] @@ -62,8 +73,9 @@ class Ptemcee(Emcee): if key not in self.sampler_function_kwargs} def run_sampler(self): - tqdm = get_progress_bar() + ptemcee = self._import_external_sampler() + tqdm = get_progress_bar() sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior, **self.sampler_init_kwargs) self.pos0 = [[self.get_random_draw_from_prior() diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index 5abd79343..573d662e9 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -12,20 +12,6 @@ from ..likelihood import GaussianLikelihood, PoissonLikelihood, ExponentialLikel StudentTLikelihood from ...gw.likelihood import BasicGravitationalWaveTransient, GravitationalWaveTransient -try: - import pymc3 - from pymc3.sampling import STEP_METHODS - from pymc3.theanof import floatX -except ImportError: - logger.debug('PyMC3 is not installed on this system, you will ' - 'not be able to use the PyMC3 sampler') -try: - import theano # noqa - import theano.tensor as tt - from theano.compile.ops import as_op # noqa -except ImportError: - logger.debug("You must have Theano installed to use PyMC3") - class Pymc3(MCMCSampler): """ bilby wrapper of the PyMC3 sampler (https://docs.pymc.io/) @@ -82,6 +68,31 @@ class Pymc3(MCMCSampler): self.draws = draws self.chains = self.__kwargs['chains'] + @staticmethod + def _import_external_sampler(): + try: + import pymc3 + from pymc3.sampling import STEP_METHODS + from pymc3.theanof import floatX + except ImportError: + logger.debug('PyMC3 is not installed on this system, you will ' + 'not be able to use the PyMC3 sampler') + pymc3 = None + STEP_METHODS = None + floatX = None + return pymc3, STEP_METHODS, floatX + + @staticmethod + def _import_theano(): + try: + import theano # noqa + import theano.tensor as tt + from theano.compile.ops import as_op # noqa + except ImportError: + logger.debug("You must have Theano installed to use PyMC3") + theano, tt, as_op = None, None, None + return theano, tt, as_op + def _verify_parameters(self): """ Change `_verify_parameters()` to just pass, i.e., don't try and @@ -247,6 +258,8 @@ class Pymc3(MCMCSampler): """ # check prior is a Sine + pymc3, STEP_METHODS, floatX = self._import_external_sampler() + theano, tt, as_op = self._import_theano() if isinstance(self.priors[key], Sine): class Pymc3Sine(pymc3.Continuous): @@ -285,6 +298,8 @@ class Pymc3(MCMCSampler): """ # check prior is a Cosine + pymc3, STEP_METHODS, floatX = self._import_external_sampler() + theano, tt, as_op = self._import_theano() if isinstance(self.priors[key], Cosine): class Pymc3Cosine(pymc3.Continuous): @@ -322,6 +337,8 @@ class Pymc3(MCMCSampler): """ # check prior is a PowerLaw + pymc3, STEP_METHODS, floatX = self._import_external_sampler() + theano, tt, as_op = self._import_theano() if isinstance(self.priors[key], PowerLaw): # check power law is set @@ -373,7 +390,7 @@ class Pymc3(MCMCSampler): def run_sampler(self): # set the step method - + pymc3, STEP_METHODS, floatX = self._import_external_sampler() step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS} if 'step' in self.__kwargs: self.step_method = self.__kwargs.pop('step') @@ -454,6 +471,7 @@ class Pymc3(MCMCSampler): self.setup_prior_mapping() self.pymc3_priors = OrderedDict() + pymc3, STEP_METHODS, floatX = self._import_external_sampler() # set the parameter prior distributions (in the model context manager) with self.pymc3_model: @@ -517,6 +535,8 @@ class Pymc3(MCMCSampler): """ # create theano Op for the log likelihood if not using a predefined model + pymc3, STEP_METHODS, floatX = self._import_external_sampler() + theano, tt, as_op = self._import_theano() class LogLike(tt.Op): itypes = [tt.dvector] diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py index ac748a3b8..3e2a7314a 100644 --- a/bilby/core/sampler/pymultinest.py +++ b/bilby/core/sampler/pymultinest.py @@ -6,18 +6,6 @@ import os from ..utils import check_directory_exists_and_if_not_mkdir from .base_sampler import NestedSampler from ..utils import logger -try: - try: - # Suppresses import error printouts from pymultinest - sys.stdout = open(os.devnull, 'w') - import pymultinest - sys.stdout = sys.__stdout__ - except ImportError: - logger.debug('PyMultinest is not installed on this system, you will ' - 'not be able to use the PyMultinest sampler') -except SystemExit: - logger.debug('Multinest is not installed on this system, you will ' - 'not be able to use the Multinest sampler') class Pymultinest(NestedSampler): @@ -57,6 +45,23 @@ class Pymultinest(NestedSampler): context=0, write_output=True, log_zero=-1e100, max_iter=0, init_MPI=False, dump_callback=None) + @staticmethod + def _import_external_sampler(): + try: + # Suppresses import error printouts from pymultinest + sys.stdout = open(os.devnull, 'w') + import pymultinest + sys.stdout = sys.__stdout__ + except ImportError: + logger.debug('PyMultinest is not installed on this system, you will ' + 'not be able to use the PyMultinest sampler') + pymultinest = None + except SystemExit: + logger.debug('Multinest is not installed on this system, you will ' + 'not be able to use the Multinest sampler') + pymultinest = None + return pymultinest + def _translate_kwargs(self, kwargs): if 'n_live_points' not in kwargs: for equiv in self.npoints_equiv_kwargs: @@ -87,6 +92,7 @@ class Pymultinest(NestedSampler): NestedSampler._verify_kwargs_against_default_kwargs(self) def run_sampler(self): + pymultinest = self._import_external_sampler() self._verify_kwargs_against_default_kwargs() out = pymultinest.solve( LogLikelihood=self.log_likelihood, Prior=self.prior_transform, diff --git a/examples/injection_examples/basic_tutorial.py b/examples/injection_examples/basic_tutorial.py index cb9831e2a..fe4e71eb0 100644 --- a/examples/injection_examples/basic_tutorial.py +++ b/examples/injection_examples/basic_tutorial.py @@ -80,7 +80,7 @@ likelihood = bilby.gw.GravitationalWaveTransient( # Run sampler. In this case we're going to use the `dynesty` sampler result = bilby.run_sampler( - likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000, + likelihood=likelihood, priors=priors, sampler='pymc3', npoints=1000, injection_parameters=injection_parameters, outdir=outdir, label=label) # Make a corner plot. -- GitLab