From 5ebe14d6dd034de5d2b1428e2a847916b6b79fbe Mon Sep 17 00:00:00 2001
From: Moritz <email@moritz-huebner.de>
Date: Mon, 19 Nov 2018 13:35:18 +1100
Subject: [PATCH] Put imports of sampler packages in separate functions

---
 bilby/core/sampler/cpnest.py                  | 18 ++++---
 bilby/core/sampler/dynesty.py                 | 17 ++++---
 bilby/core/sampler/emcee.py                   | 17 ++++---
 bilby/core/sampler/nestle.py                  | 17 ++++---
 bilby/core/sampler/ptemcee.py                 | 14 +++++-
 bilby/core/sampler/pymc3.py                   | 50 +++++++++++++------
 bilby/core/sampler/pymultinest.py             | 30 ++++++-----
 examples/injection_examples/basic_tutorial.py |  2 +-
 8 files changed, 112 insertions(+), 53 deletions(-)

diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py
index 68826ecae..a3d62b9ee 100644
--- a/bilby/core/sampler/cpnest.py
+++ b/bilby/core/sampler/cpnest.py
@@ -6,12 +6,6 @@ from pandas import DataFrame
 from .base_sampler import NestedSampler
 from ..utils import logger, check_directory_exists_and_if_not_mkdir
 
-try:
-    from cpnest import model as cpmodel, CPNest
-except ImportError:
-    logger.debug('CPNest is not installed on this system, you will '
-                 'not be able to use the CPNest sampler')
-
 
 class Cpnest(NestedSampler):
     """ bilby wrapper of cpnest (https://github.com/johnveitch/cpnest)
@@ -45,6 +39,16 @@ class Cpnest(NestedSampler):
                           seed=None, poolsize=100, nhamiltonian=0, resume=False,
                           output=None)
 
+    @staticmethod
+    def _import_external_sampler():
+        try:
+            from cpnest import model as cpmodel, CPNest
+        except ImportError:
+            logger.debug('CPNest is not installed on this system, you will '
+                         'not be able to use the CPNest sampler')
+            cpmodel, CPNest = None, None
+        return cpmodel, CPNest
+
     def _translate_kwargs(self, kwargs):
         if 'nlive' not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
@@ -54,6 +58,8 @@ class Cpnest(NestedSampler):
             logger.warning('No seed provided, cpnest will use 1234.')
 
     def run_sampler(self):
+        cpmodel, CPNest = self._import_external_sampler()
+
         class Model(cpmodel.Model):
             """ A wrapper class to pass our log_likelihood into cpnest """
 
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 3e64ff9f2..4373aa88a 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -10,12 +10,6 @@ from deepdish.io import load, save
 from ..utils import logger, check_directory_exists_and_if_not_mkdir
 from .base_sampler import Sampler, NestedSampler
 
-try:
-    import dynesty
-except ImportError:
-    logger.debug('Dynesty is not installed on this system, you will '
-                 'not be able to use the Dynesty sampler')
-
 
 class Dynesty(NestedSampler):
     """
@@ -109,6 +103,16 @@ class Dynesty(NestedSampler):
             n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
             self.n_check_point = n_check_point_rnd
 
+    @staticmethod
+    def _import_external_sampler():
+        try:
+            import dynesty
+        except ImportError:
+            logger.debug('Dynesty is not installed on this system, you will '
+                         'not be able to use the Dynesty sampler')
+            dynesty = None
+        return dynesty
+
     @property
     def sampler_function_kwargs(self):
         keys = ['dlogz', 'print_progress', 'print_func', 'maxiter',
@@ -174,6 +178,7 @@ class Dynesty(NestedSampler):
         sys.stderr.flush()
 
     def run_sampler(self):
+        dynesty = self._import_external_sampler()
         self.sampler = dynesty.NestedSampler(
             loglikelihood=self.log_likelihood,
             prior_transform=self.prior_transform,
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index aa8d00392..77be0dc30 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -6,12 +6,6 @@ from pandas import DataFrame
 from ..utils import logger, get_progress_bar
 from .base_sampler import MCMCSampler
 
-try:
-    import emcee
-except ImportError:
-    logger.debug('Emcee is not installed on this system, you will '
-                 'not be able to use the Emcee sampler')
-
 
 class Emcee(MCMCSampler):
     """bilby wrapper emcee (https://github.com/dfm/emcee)
@@ -60,6 +54,16 @@ class Emcee(MCMCSampler):
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
 
+    @staticmethod
+    def _import_external_sampler():
+        try:
+            import emcee
+        except ImportError:
+            logger.debug('Emcee is not installed on this system, you will '
+                         'not be able to use the Emcee sampler')
+            emcee = None
+        return emcee
+
     def _translate_kwargs(self, kwargs):
         if 'nwalkers' not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
@@ -111,6 +115,7 @@ class Emcee(MCMCSampler):
         self.kwargs['iterations'] = nsteps
 
     def run_sampler(self):
+        emcee = self._import_external_sampler()
         tqdm = get_progress_bar()
         sampler = emcee.EnsembleSampler(dim=self.ndim, lnpostfn=self.lnpostfn, **self.sampler_init_kwargs)
         self._set_pos0()
diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py
index 468e89d3c..9ea6b5eff 100644
--- a/bilby/core/sampler/nestle.py
+++ b/bilby/core/sampler/nestle.py
@@ -6,12 +6,6 @@ from pandas import DataFrame
 from ..utils import logger
 from .base_sampler import NestedSampler
 
-try:
-    import nestle
-except ImportError:
-    logger.debug('Nestle is not installed on this system, you will '
-                 'not be able to use the Nestle sampler')
-
 
 class Nestle(NestedSampler):
     """bilby wrapper `nestle.Sampler` (http://kylebarbary.com/nestle/)
@@ -38,6 +32,16 @@ class Nestle(NestedSampler):
                           maxcall=None, dlogz=None, decline_factor=None,
                           rstate=None, callback=None)
 
+    @staticmethod
+    def _import_external_sampler():
+        try:
+            import nestle
+        except ImportError:
+            logger.debug('Nestle is not installed on this system, you will '
+                         'not be able to use the Nestle sampler')
+            nestle = None
+        return nestle
+
     def _translate_kwargs(self, kwargs):
         if 'npoints' not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
@@ -59,6 +63,7 @@ class Nestle(NestedSampler):
         bilby.core.result.Result: Packaged information about the result
 
         """
+        nestle = self._import_external_sampler()
         out = nestle.sample(
             loglikelihood=self.log_likelihood,
             prior_transform=self.prior_transform,
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index dd6953d1d..de400fe31 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -50,6 +50,17 @@ class Ptemcee(Emcee):
                        use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification,
                        nburn=nburn, burn_in_fraction=burn_in_fraction, burn_in_act=burn_in_act, **kwargs)
 
+
+    @staticmethod
+    def _import_external_sampler():
+        try:
+            import ptemcee
+        except ImportError:
+            logger.debug('Nestle is not installed on this system, you will '
+                         'not be able to use the Nestle sampler')
+            ptemcee = None
+        return ptemcee
+
     @property
     def sampler_function_kwargs(self):
         keys = ['iterations', 'thin', 'storechain', 'adapt', 'swap_ratios']
@@ -62,8 +73,9 @@ class Ptemcee(Emcee):
                 if key not in self.sampler_function_kwargs}
 
     def run_sampler(self):
-        tqdm = get_progress_bar()
+        ptemcee = self._import_external_sampler()
 
+        tqdm = get_progress_bar()
         sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood,
                                   logp=self.log_prior, **self.sampler_init_kwargs)
         self.pos0 = [[self.get_random_draw_from_prior()
diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
index 5abd79343..573d662e9 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc3.py
@@ -12,20 +12,6 @@ from ..likelihood import GaussianLikelihood, PoissonLikelihood, ExponentialLikel
     StudentTLikelihood
 from ...gw.likelihood import BasicGravitationalWaveTransient, GravitationalWaveTransient
 
-try:
-    import pymc3
-    from pymc3.sampling import STEP_METHODS
-    from pymc3.theanof import floatX
-except ImportError:
-    logger.debug('PyMC3 is not installed on this system, you will '
-                 'not be able to use the PyMC3 sampler')
-try:
-    import theano  # noqa
-    import theano.tensor as tt
-    from theano.compile.ops import as_op  # noqa
-except ImportError:
-    logger.debug("You must have Theano installed to use PyMC3")
-
 
 class Pymc3(MCMCSampler):
     """ bilby wrapper of the PyMC3 sampler (https://docs.pymc.io/)
@@ -82,6 +68,31 @@ class Pymc3(MCMCSampler):
         self.draws = draws
         self.chains = self.__kwargs['chains']
 
+    @staticmethod
+    def _import_external_sampler():
+        try:
+            import pymc3
+            from pymc3.sampling import STEP_METHODS
+            from pymc3.theanof import floatX
+        except ImportError:
+            logger.debug('PyMC3 is not installed on this system, you will '
+                         'not be able to use the PyMC3 sampler')
+            pymc3 = None
+            STEP_METHODS = None
+            floatX = None
+        return pymc3, STEP_METHODS, floatX
+
+    @staticmethod
+    def _import_theano():
+        try:
+            import theano  # noqa
+            import theano.tensor as tt
+            from theano.compile.ops import as_op  # noqa
+        except ImportError:
+            logger.debug("You must have Theano installed to use PyMC3")
+            theano, tt, as_op = None, None, None
+        return theano, tt, as_op
+
     def _verify_parameters(self):
         """
         Change `_verify_parameters()` to just pass, i.e., don't try and
@@ -247,6 +258,8 @@ class Pymc3(MCMCSampler):
         """
 
         # check prior is a Sine
+        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
+        theano, tt, as_op = self._import_theano()
         if isinstance(self.priors[key], Sine):
 
             class Pymc3Sine(pymc3.Continuous):
@@ -285,6 +298,8 @@ class Pymc3(MCMCSampler):
         """
 
         # check prior is a Cosine
+        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
+        theano, tt, as_op = self._import_theano()
         if isinstance(self.priors[key], Cosine):
 
             class Pymc3Cosine(pymc3.Continuous):
@@ -322,6 +337,8 @@ class Pymc3(MCMCSampler):
         """
 
         # check prior is a PowerLaw
+        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
+        theano, tt, as_op = self._import_theano()
         if isinstance(self.priors[key], PowerLaw):
 
             # check power law is set
@@ -373,7 +390,7 @@ class Pymc3(MCMCSampler):
 
     def run_sampler(self):
         # set the step method
-
+        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
         step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
         if 'step' in self.__kwargs:
             self.step_method = self.__kwargs.pop('step')
@@ -454,6 +471,7 @@ class Pymc3(MCMCSampler):
         self.setup_prior_mapping()
 
         self.pymc3_priors = OrderedDict()
+        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
 
         # set the parameter prior distributions (in the model context manager)
         with self.pymc3_model:
@@ -517,6 +535,8 @@ class Pymc3(MCMCSampler):
         """
 
         # create theano Op for the log likelihood if not using a predefined model
+        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
+        theano, tt, as_op = self._import_theano()
         class LogLike(tt.Op):
 
             itypes = [tt.dvector]
diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py
index ac748a3b8..3e2a7314a 100644
--- a/bilby/core/sampler/pymultinest.py
+++ b/bilby/core/sampler/pymultinest.py
@@ -6,18 +6,6 @@ import os
 from ..utils import check_directory_exists_and_if_not_mkdir
 from .base_sampler import NestedSampler
 from ..utils import logger
-try:
-    try:
-        # Suppresses import error printouts from pymultinest
-        sys.stdout = open(os.devnull, 'w')
-        import pymultinest
-        sys.stdout = sys.__stdout__
-    except ImportError:
-        logger.debug('PyMultinest is not installed on this system, you will '
-                     'not be able to use the PyMultinest sampler')
-except SystemExit:
-    logger.debug('Multinest is not installed on this system, you will '
-                 'not be able to use the Multinest sampler')
 
 
 class Pymultinest(NestedSampler):
@@ -57,6 +45,23 @@ class Pymultinest(NestedSampler):
                           context=0, write_output=True, log_zero=-1e100,
                           max_iter=0, init_MPI=False, dump_callback=None)
 
+    @staticmethod
+    def _import_external_sampler():
+        try:
+            # Suppresses import error printouts from pymultinest
+            sys.stdout = open(os.devnull, 'w')
+            import pymultinest
+            sys.stdout = sys.__stdout__
+        except ImportError:
+            logger.debug('PyMultinest is not installed on this system, you will '
+                         'not be able to use the PyMultinest sampler')
+            pymultinest = None
+        except SystemExit:
+            logger.debug('Multinest is not installed on this system, you will '
+                         'not be able to use the Multinest sampler')
+            pymultinest = None
+        return pymultinest
+
     def _translate_kwargs(self, kwargs):
         if 'n_live_points' not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
@@ -87,6 +92,7 @@ class Pymultinest(NestedSampler):
         NestedSampler._verify_kwargs_against_default_kwargs(self)
 
     def run_sampler(self):
+        pymultinest = self._import_external_sampler()
         self._verify_kwargs_against_default_kwargs()
         out = pymultinest.solve(
             LogLikelihood=self.log_likelihood, Prior=self.prior_transform,
diff --git a/examples/injection_examples/basic_tutorial.py b/examples/injection_examples/basic_tutorial.py
index cb9831e2a..fe4e71eb0 100644
--- a/examples/injection_examples/basic_tutorial.py
+++ b/examples/injection_examples/basic_tutorial.py
@@ -80,7 +80,7 @@ likelihood = bilby.gw.GravitationalWaveTransient(
 
 # Run sampler.  In this case we're going to use the `dynesty` sampler
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
+    likelihood=likelihood, priors=priors, sampler='pymc3', npoints=1000,
     injection_parameters=injection_parameters, outdir=outdir, label=label)
 
 # Make a corner plot.
-- 
GitLab