From 264a2b73f63f29be598a8190de140901cc20d7bc Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Thu, 25 Mar 2021 02:59:54 +0000 Subject: [PATCH] Reduce imports --- .gitlab-ci.yml | 1 + bilby/__init__.py | 2 +- bilby/core/grid.py | 12 ++- bilby/core/prior/analytical.py | 2 +- bilby/core/prior/base.py | 2 +- bilby/core/prior/conditional.py | 6 +- bilby/core/prior/dict.py | 14 +-- bilby/core/prior/interpolated.py | 4 +- bilby/core/prior/joint.py | 2 +- bilby/core/prior/slabspike.py | 4 +- bilby/core/result.py | 30 +++--- bilby/core/sampler/__init__.py | 1 - bilby/core/sampler/dynamic_dynesty.py | 7 +- bilby/core/sampler/dynesty.py | 20 ++-- bilby/core/sampler/emcee.py | 13 +-- bilby/core/sampler/fake_sampler.py | 1 + bilby/core/sampler/kombine.py | 6 +- bilby/core/sampler/ptemcee.py | 24 +++-- bilby/core/utils.py | 25 +++-- bilby/gw/conversion.py | 15 ++- bilby/gw/cosmology.py | 23 +++-- bilby/gw/detector/__init__.py | 15 +-- bilby/gw/detector/calibration.py | 3 +- bilby/gw/detector/interferometer.py | 20 ++-- bilby/gw/detector/strain_data.py | 130 ++++++++++++++------------ bilby/gw/eos/eos.py | 11 ++- bilby/gw/eos/tov_solver.py | 2 +- bilby/gw/likelihood.py | 24 ++--- bilby/gw/prior.py | 25 ++--- bilby/gw/result.py | 10 +- bilby/gw/source.py | 13 +-- bilby/gw/utils.py | 54 +++++------ requirements.txt | 2 +- sampler_requirements.txt | 3 +- setup.cfg | 1 + test/gw/cosmology_test.py | 2 +- test/gw/utils_test.py | 9 +- test/import_test.py | 22 +++++ 38 files changed, 298 insertions(+), 262 deletions(-) create mode 100644 test/import_test.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a19412a7..add1efb9 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -33,6 +33,7 @@ stages: - python -c "import bilby.gw.sampler" - python -c "import bilby.hyper" - python -c "import cli_bilby" + - python test/import_test.py - for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do ${script} --help; done diff --git a/bilby/__init__.py b/bilby/__init__.py index 30e2531b..04a29e1b 100644 --- a/bilby/__init__.py +++ b/bilby/__init__.py @@ -29,7 +29,7 @@ __version__ = utils.get_version_information() if sys.version_info < (3,): raise ImportError( -"""You are running bilby 0.6.4 on Python 2 +"""You are running bilby >= 0.6.4 on Python 2 Bilby 0.6.4 and above are no longer compatible with Python 2, and you still ended up with this version installed. That's unfortunate; sorry about that. diff --git a/bilby/core/grid.py b/bilby/core/grid.py index ccc1e9d7..8562e6d6 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -1,12 +1,14 @@ -import numpy as np -import os import json +import os from collections import OrderedDict +import numpy as np + from .prior import Prior, PriorDict -from .utils import (logtrapzexp, check_directory_exists_and_if_not_mkdir, - logger) -from .utils import BilbyJsonEncoder, load_json, move_old_file +from .utils import ( + logtrapzexp, check_directory_exists_and_if_not_mkdir, logger, + BilbyJsonEncoder, load_json, move_old_file +) from .result import FileMovedError diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 08e56a98..63ce9683 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -4,7 +4,7 @@ from scipy.special._ufuncs import xlogy, erf, log1p, stdtrit, gammaln, stdtr, \ btdtri, betaln, btdtr, gammaincinv, gammainc from .base import Prior -from bilby.core.utils import logger +from ..utils import logger class DeltaFunction(Prior): diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index 05df00a7..ab3f2e12 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -5,7 +5,6 @@ import re import numpy as np import scipy.stats -from scipy.integrate import cumtrapz from scipy.interpolate import interp1d from bilby.core.utils import infer_args_from_method, BilbyJsonEncoder, decode_bilby_json, logger, \ @@ -162,6 +161,7 @@ class Prior(object): def cdf(self, val): """ Generic method to calculate CDF, can be overwritten in subclass """ + from scipy.integrate import cumtrapz if np.any(np.isinf([self.minimum, self.maximum])): raise ValueError( "Unable to use the generic CDF calculation for priors with" diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index fbe48469..10782282 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -1,11 +1,11 @@ import numpy as np from .base import Prior, PriorException -from bilby.core.prior.interpolated import Interped -from bilby.core.prior.analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ +from .interpolated import Interped +from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac -from bilby.core.utils import infer_args_from_method, infer_parameters_from_function +from ..utils import infer_args_from_method, infer_parameters_from_function def conditional_prior_factory(prior_class): diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index fc9918b0..7c82b4f4 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -1,15 +1,14 @@ -from importlib import import_module -from io import open as ioopen import json import os +from importlib import import_module +from io import open as ioopen -from matplotlib.cbook import flatten import numpy as np -from bilby.core.prior.analytical import DeltaFunction -from bilby.core.prior.base import Prior, Constraint -from bilby.core.prior.joint import JointPrior -from bilby.core.utils import logger, check_directory_exists_and_if_not_mkdir, BilbyJsonEncoder, decode_bilby_json +from .analytical import DeltaFunction +from .base import Prior, Constraint +from .joint import JointPrior +from ..utils import logger, check_directory_exists_and_if_not_mkdir, BilbyJsonEncoder, decode_bilby_json class PriorDict(dict): @@ -496,6 +495,7 @@ class PriorDict(dict): ======= list: List of floats containing the rescaled sample """ + from matplotlib.cbook import flatten return list(flatten([self[key].rescale(sample) for key, sample in zip(keys, theta)])) def test_redundancy(self, key, disable_logging=False): diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index 992a75b5..ece311a5 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -1,9 +1,8 @@ import numpy as np -from scipy.integrate import cumtrapz from scipy.interpolate import interp1d from .base import Prior -from bilby.core.utils import logger +from ..utils import logger class Interped(Prior): @@ -164,6 +163,7 @@ class Interped(Prior): self._initialize_attributes() def _initialize_attributes(self): + from scipy.integrate import cumtrapz if np.trapz(self._yy, self.xx) != 1: logger.debug('Supplied PDF for {} is not normalised, normalising.'.format(self.name)) self._yy /= np.trapz(self._yy, self.xx) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index b3ee2f0b..a10a0add 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -3,7 +3,7 @@ import scipy.stats from scipy.special import erfinv from .base import Prior, PriorException -from bilby.core.utils import logger, infer_args_from_method, get_dict_with_properties +from ..utils import logger, infer_args_from_method, get_dict_with_properties class BaseJointPriorDist(object): diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 54872b8b..8482b814 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -1,7 +1,7 @@ import numpy as np -from bilby.core.prior.base import Prior -from bilby.core.utils import logger +from .base import Prior +from ..utils import logger class SlabSpikePrior(Prior): diff --git a/bilby/core/result.py b/bilby/core/result.py index 789fa5d9..072d1e52 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -1,22 +1,14 @@ import inspect +import json import os from collections import OrderedDict, namedtuple from copy import copy -from distutils.version import LooseVersion from importlib import import_module from itertools import product -from tqdm import tqdm -import corner -import h5py -import json -import matplotlib -import matplotlib.pyplot as plt -from matplotlib import lines as mpllines import numpy as np import pandas as pd import scipy.stats -from scipy.special import logsumexp from . import utils from .utils import ( @@ -126,6 +118,7 @@ def get_weights_for_reweighting( n_checkpoint: int Number of samples to reweight before writing a resume file """ + from tqdm.auto import tqdm nposterior = len(result.posterior) @@ -257,6 +250,7 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None, An array of the natural-log priors from the old likelihood """ + from scipy.special import logsumexp result = copy(result) @@ -511,6 +505,7 @@ class Result(object): @classmethod @docstring(_load_doctstring.format(format="hdf5")) def from_hdf5(cls, filename=None, outdir=None, label=None): + import h5py filename = _determine_file_name(filename, outdir, label, 'hdf5', False) with h5py.File(filename, "r") as ff: data = recursively_load_dict_contents_from_group(ff, '/') @@ -770,6 +765,7 @@ class Result(object): with open(filename, 'w') as file: json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder) elif extension == 'hdf5': + import h5py dictionary["__module__"] = self.__module__ dictionary["__name__"] = self.__class__.__name__ with h5py.File(filename, 'w') as h5file: @@ -984,6 +980,7 @@ class Result(object): figure: matplotlib.pyplot.figure A matplotlib figure object """ + import matplotlib.pyplot as plt logger.info('Plotting {} marginal distribution'.format(key)) label = self.get_latex_labels_from_parameter_keys([key])[0] fig, ax = plt.subplots() @@ -1153,6 +1150,8 @@ class Result(object): A matplotlib figure instance """ + import corner + import matplotlib.pyplot as plt # If in testing mode, not corner plots are generated if utils.command_line_args.bilby_test_mode: @@ -1165,12 +1164,7 @@ class Result(object): truth_color='tab:orange', quantiles=[0.16, 0.84], levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)), plot_density=False, plot_datapoints=True, fill_contours=True, - max_n_ticks=3) - - if LooseVersion(matplotlib.__version__) < "2.1": - defaults_kwargs['hist_kwargs'] = dict(normed=True) - else: - defaults_kwargs['hist_kwargs'] = dict(density=True) + max_n_ticks=3, hist_kwargs=dict(density=True)) if 'lionize' in kwargs and kwargs['lionize'] is True: defaults_kwargs['truth_color'] = 'tab:blue' @@ -1282,6 +1276,7 @@ class Result(object): @latex_plot_format def plot_walkers(self, **kwargs): """ Method to plot the trace of the walkers in an ensemble MCMC plot """ + import matplotlib.pyplot as plt if hasattr(self, 'walkers') is False: logger.warning("Cannot plot_walkers as no walkers are saved") return @@ -1345,6 +1340,7 @@ class Result(object): Path to the outdir. Default is the one store in the result object. """ + import matplotlib.pyplot as plt # Determine model_posterior, the subset of the full posterior which # should be passed into the model @@ -1794,6 +1790,7 @@ class ResultList(list): result: bilby.core.result.Result The result object with the combined evidences. """ + from scipy.special import logsumexp self.check_nested_samples() # Combine evidences @@ -1886,6 +1883,8 @@ def plot_multiple(results, filename=None, labels=None, colours=None, A matplotlib figure instance """ + import matplotlib.pyplot as plt + import matplotlib.lines as mpllines kwargs['show_titles'] = False kwargs['truths'] = None @@ -1977,6 +1976,7 @@ def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0 matplotlib figure and a NamedTuple with attributes `combined_pvalue`, `pvalues`, and `names`. """ + import matplotlib.pyplot as plt if keys is None: keys = results[0].search_parameter_keys diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 48858d6d..21184ba6 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -6,7 +6,6 @@ from collections import OrderedDict import bilby from ..utils import command_line_args, logger, loaded_modules_dict from ..prior import PriorDict, DeltaFunction - from .base_sampler import Sampler, SamplingMarginalisedParameterError from .cpnest import Cpnest from .dynamic_dynesty import DynamicDynesty diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py index 88b16763..8bb6d647 100644 --- a/bilby/core/sampler/dynamic_dynesty.py +++ b/bilby/core/sampler/dynamic_dynesty.py @@ -1,6 +1,5 @@ import os -import dill as pickle import signal import numpy as np @@ -168,18 +167,20 @@ class DynamicDynesty(Dynesty): def write_current_state(self): """ """ + import dill check_directory_exists_and_if_not_mkdir(self.outdir) with open(self.resume_file, 'wb') as file: - pickle.dump(self, file) + dill.dump(self, file) def read_saved_state(self, continuing=False): """ """ + import dill logger.debug("Reading resume file {}".format(self.resume_file)) if os.path.isfile(self.resume_file): with open(self.resume_file, 'rb') as file: - self = pickle.load(file) + self = dill.load(file) else: logger.debug( "Failed to read resume file {}".format(self.resume_file)) diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index d453ffd7..7248ff80 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -1,13 +1,10 @@ import datetime -import dill import os import sys -import pickle import signal import time +import warnings -from tqdm.auto import tqdm -import matplotlib.pyplot as plt import numpy as np from pandas import DataFrame @@ -20,11 +17,6 @@ from ..utils import ( from .base_sampler import Sampler, NestedSampler from ..result import rejection_sample -from numpy import linalg -from dynesty.utils import unitcheck -import warnings - - _likelihood = None _priors = None _search_parameter_keys = None @@ -218,6 +210,7 @@ class Dynesty(NestedSampler): kwargs['queue_size'] = kwargs.pop(equiv) def _verify_kwargs_against_default_kwargs(self): + from tqdm.auto import tqdm if not self.kwargs['walks']: self.kwargs['walks'] = self.ndim * 10 if not self.kwargs['update_interval']: @@ -323,6 +316,7 @@ class Dynesty(NestedSampler): def run_sampler(self): import dynesty + import dill logger.info("Using dynesty version {}".format(dynesty.__version__)) if self.kwargs.get("sample", "rwalk") == "rwalk": @@ -376,7 +370,7 @@ class Dynesty(NestedSampler): check_directory_exists_and_if_not_mkdir(self.outdir) dynesty_result = "{}/{}_dynesty.pickle".format(self.outdir, self.label) with open(dynesty_result, 'wb') as file: - pickle.dump(out, file) + dill.dump(out, file) self._generate_result(out) self.calc_likelihood_count() @@ -482,6 +476,7 @@ class Dynesty(NestedSampler): """ from ... import __version__ as bilby_version from dynesty import __version__ as dynesty_version + import dill versions = dict(bilby=bilby_version, dynesty=dynesty_version) if os.path.isfile(self.resume_file): logger.info("Reading resume file {}".format(self.resume_file)) @@ -560,6 +555,7 @@ class Dynesty(NestedSampler): from ... import __version__ as bilby_version from dynesty import __version__ as dynesty_version + import dill check_directory_exists_and_if_not_mkdir(self.outdir) end_time = datetime.datetime.now() if hasattr(self, 'start_time'): @@ -605,6 +601,7 @@ class Dynesty(NestedSampler): df.to_csv(filename, index=False, header=True, sep=' ') def plot_current_state(self): + import matplotlib.pyplot as plt if self.check_point_plot: import dynesty.plotting as dyplot labels = [label.replace('_', ' ') for label in self.search_parameter_keys] @@ -704,6 +701,7 @@ class Dynesty(NestedSampler): def sample_rwalk_bilby(args): """ Modified bilby-implemented version of dynesty.sampling.sample_rwalk """ + from dynesty.utils import unitcheck # Unzipping. (u, loglstar, axes, scale, @@ -737,7 +735,7 @@ def sample_rwalk_bilby(args): # Propose a direction on the unit n-sphere. drhat = rstate.randn(n) - drhat /= linalg.norm(drhat) + drhat /= np.linalg.norm(drhat) # Scale based on dimensionality. dr = drhat * rstate.rand() ** (1.0 / n) diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 81e4eedd..47185198 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -1,14 +1,13 @@ -from collections import namedtuple import os import signal import shutil -from shutil import copyfile import sys +from collections import namedtuple +from distutils.version import LooseVersion +from shutil import copyfile import numpy as np from pandas import DataFrame -from distutils.version import LooseVersion -import dill as pickle from ..utils import logger, check_directory_exists_and_if_not_mkdir from .base_sampler import MCMCSampler, SamplerError @@ -254,13 +253,14 @@ class Emcee(MCMCSampler): def checkpoint(self): """ Writes a pickle file of the sampler to disk using dill """ + import dill logger.info("Checkpointing sampler to file {}" .format(self.checkpoint_info.sampler_file)) with open(self.checkpoint_info.sampler_file, 'wb') as f: # Overwrites the stored sampler chain with one that is truncated # to only the completed steps self.sampler._chain = self.sampler_chain - pickle.dump(self._sampler, f) + dill.dump(self._sampler, f) def checkpoint_and_exit(self, signum, frame): logger.info("Recieved signal {}".format(signum)) @@ -283,10 +283,11 @@ class Emcee(MCMCSampler): if hasattr(self, '_sampler'): pass elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file): + import dill logger.info("Resuming run from checkpoint file {}" .format(self.checkpoint_info.sampler_file)) with open(self.checkpoint_info.sampler_file, 'rb') as f: - self._sampler = pickle.load(f) + self._sampler = dill.load(f) self._set_pos0_for_resume() else: self._initialise_sampler() diff --git a/bilby/core/sampler/fake_sampler.py b/bilby/core/sampler/fake_sampler.py index 1992d74b..8d218472 100644 --- a/bilby/core/sampler/fake_sampler.py +++ b/bilby/core/sampler/fake_sampler.py @@ -1,5 +1,6 @@ import numpy as np + from .base_sampler import Sampler from ..result import read_in_result diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py index d077499c..83947fc8 100644 --- a/bilby/core/sampler/kombine.py +++ b/bilby/core/sampler/kombine.py @@ -1,7 +1,9 @@ -from ..utils import logger -import numpy as np import os + +import numpy as np + from .emcee import Emcee +from ..utils import logger class Kombine(Emcee): diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 2be681f3..471439a1 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -1,18 +1,14 @@ - -import os -import datetime import copy +import datetime +import logging +import os import signal import sys import time -import dill from collections import namedtuple -import logging import numpy as np import pandas as pd -import matplotlib.pyplot as plt -import scipy.signal from ..utils import logger, check_directory_exists_and_if_not_mkdir from .base_sampler import SamplerError, MCMCSampler @@ -372,6 +368,7 @@ class Ptemcee(MCMCSampler): import ptemcee if os.path.isfile(self.resume_file) and self.resume is True: + import dill logger.info("Resume data {} found".format(self.resume_file)) with open(self.resume_file, "rb") as file: data = dill.load(file) @@ -833,10 +830,12 @@ def check_iteration( def get_max_gradient(x, axis=0, window_length=11, polyorder=2, smooth=False): + from scipy.signal import savgol_filter if smooth: - x = scipy.signal.savgol_filter( - x, axis=axis, window_length=window_length, polyorder=3) - return np.max(scipy.signal.savgol_filter( + x = savgol_filter( + x, axis=axis, window_length=window_length, polyorder=3 + ) + return np.max(savgol_filter( x, axis=axis, window_length=window_length, polyorder=polyorder, deriv=1)) @@ -963,6 +962,7 @@ def checkpoint( Q_list, time_per_check, ): + import dill logger.info("Writing checkpoint and diagnostics") ndim = sampler.dim @@ -1002,6 +1002,7 @@ def checkpoint( def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label, discard=0): """ Method to plot the trace of the walkers in an ensemble MCMC plot """ + import matplotlib.pyplot as plt nwalkers, nsteps, ndim = walkers.shape if np.isnan(nburn): nburn = nsteps @@ -1053,6 +1054,7 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label, def plot_tau( tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau, ): + import matplotlib.pyplot as plt fig, ax = plt.subplots() for i, key in enumerate(search_parameter_keys): ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key) @@ -1065,6 +1067,7 @@ def plot_tau( def plot_mean_log_posterior(mean_log_posterior, outdir, label): + import matplotlib.pyplot as plt ntemps, nsteps = mean_log_posterior.shape ymax = np.max(mean_log_posterior) @@ -1085,6 +1088,7 @@ def plot_mean_log_posterior(mean_log_posterior, outdir, label): def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nburn, thin, iteration, make_plots=True): """ Computes the evidence using thermodynamic integration """ + import matplotlib.pyplot as plt betas = sampler.betas # We compute the evidence without the burnin samples, but we do not thin lnlike = log_likelihood_array[:, :, discard + nburn : iteration] diff --git a/bilby/core/utils.py b/bilby/core/utils.py index 7bfb1faa..ccb7fc11 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -1,26 +1,24 @@ - -from distutils.spawn import find_executable -import logging -import os -import shutil -import sys -from math import fmod import argparse -import inspect import functools +import inspect +import json +import logging +import multiprocessing +import os import types +import shutil import subprocess -import multiprocessing +import sys +import warnings +from distutils.spawn import find_executable from importlib import import_module +from math import fmod from numbers import Number -import json -import warnings import numpy as np +import pandas as pd from scipy.interpolate import interp2d from scipy.special import logsumexp -import pandas as pd -import matplotlib.pyplot as plt logger = logging.getLogger('bilby') @@ -1208,6 +1206,7 @@ def latex_plot_format(func): """ @functools.wraps(func) def wrapper_decorator(*args, **kwargs): + import matplotlib.pyplot as plt from matplotlib import rcParams if "BILBY_STYLE" in kwargs: diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index f034dfe8..1ded6605 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -1,7 +1,6 @@ import sys import multiprocessing -from tqdm.auto import tqdm import numpy as np from pandas import DataFrame @@ -12,13 +11,6 @@ from .utils import lalsim_SimInspiralTransformPrecessingNewInitialConditions from .eos.eos import SpectralDecompositionEOS, EOSFamily, IntegrateTOV from .cosmology import get_cosmology -try: - from astropy import units - from astropy.cosmology import z_at_value -except ImportError: - logger.debug("You do not have astropy installed currently. You will" - " not be able to use some of the prebuilt functions.") - def redshift_to_luminosity_distance(redshift, cosmology=None): cosmology = get_cosmology(cosmology) @@ -32,12 +24,16 @@ def redshift_to_comoving_distance(redshift, cosmology=None): @np.vectorize def luminosity_distance_to_redshift(distance, cosmology=None): + from astropy import units + from astropy.cosmology import z_at_value cosmology = get_cosmology(cosmology) return z_at_value(cosmology.luminosity_distance, distance * units.Mpc) @np.vectorize def comoving_distance_to_redshift(distance, cosmology=None): + from astropy import units + from astropy.cosmology import z_at_value cosmology = get_cosmology(cosmology) return z_at_value(cosmology.comoving_distance, distance * units.Mpc) @@ -1148,6 +1144,7 @@ def compute_snrs(sample, likelihood): sample['{}_optimal_snr'.format(ifo.name)] = \ per_detector_snr.optimal_snr_squared.real ** 0.5 else: + from tqdm.auto import tqdm logger.info( 'Computing SNRs for every sample.') @@ -1212,6 +1209,7 @@ def generate_posterior_samples_from_marginalized_likelihood( return samples elif not isinstance(samples, DataFrame): raise ValueError("Unable to handle input samples of type {}".format(type(samples))) + from tqdm.auto import tqdm logger.info('Reconstructing marginalised parameters.') @@ -1242,6 +1240,7 @@ def generate_sky_frame_parameters(samples, likelihood): return elif not isinstance(samples, DataFrame): raise ValueError + from tqdm.auto import tqdm logger.info('Generating sky frame parameters.') new_samples = list() diff --git a/bilby/gw/cosmology.py b/bilby/gw/cosmology.py index 861fdf23..ded3f5eb 100644 --- a/bilby/gw/cosmology.py +++ b/bilby/gw/cosmology.py @@ -1,19 +1,20 @@ -from ..core.utils import logger +DEFAULT_COSMOLOGY = None +COSMOLOGY = [None, str(None)] -try: + +def _set_default_cosmology(): from astropy import cosmology as cosmo - DEFAULT_COSMOLOGY = cosmo.Planck15 - COSMOLOGY = [DEFAULT_COSMOLOGY, DEFAULT_COSMOLOGY.name] -except ImportError: - logger.debug("You do not have astropy installed currently. You will" - " not be able to use some of the prebuilt functions.") - DEFAULT_COSMOLOGY = None - COSMOLOGY = [None, str(None)] + global DEFAULT_COSMOLOGY, COSMOLOGY + if DEFAULT_COSMOLOGY is None: + DEFAULT_COSMOLOGY = cosmo.Planck15 + COSMOLOGY = [DEFAULT_COSMOLOGY, DEFAULT_COSMOLOGY.name] def get_cosmology(cosmology=None): + from astropy import cosmology as cosmo + _set_default_cosmology() if cosmology is None: - cosmology = COSMOLOGY[0] + cosmology = DEFAULT_COSMOLOGY elif isinstance(cosmology, str): cosmology = cosmo.__dict__[cosmology] return cosmology @@ -41,6 +42,8 @@ def set_cosmology(cosmology=None): cosmo: astropy.cosmology.FLRW Cosmology instance """ + from astropy import cosmology as cosmo + _set_default_cosmology() if cosmology is None: cosmology = DEFAULT_COSMOLOGY elif isinstance(cosmology, cosmo.FLRW): diff --git a/bilby/gw/detector/__init__.py b/bilby/gw/detector/__init__.py index 37d60276..b68313c0 100644 --- a/bilby/gw/detector/__init__.py +++ b/bilby/gw/detector/__init__.py @@ -5,13 +5,6 @@ from .networks import * from .psd import * from .strain_data import * -try: - import lal - import lalsimulation as lalsim -except ImportError: - logger.debug("You do not have lalsuite installed currently. You will" - " not be able to use some of the prebuilt functions.") - def get_safe_signal_duration(mass_1, mass_2, a_1, a_2, tilt_1, tilt_2, flow=10): """ Calculate the safe signal duration, given the parameters @@ -30,8 +23,10 @@ def get_safe_signal_duration(mass_1, mass_2, a_1, a_2, tilt_1, tilt_2, flow=10): to the nearest power of 2) """ - chirp_time = lalsim.SimInspiralChirpTimeBound( - flow, mass_1 * lal.MSUN_SI, mass_2 * lal.MSUN_SI, + from lal import MSUN_SI + from lalsimulation import SimInspiralChirpTimeBound + chirp_time = SimInspiralChirpTimeBound( + flow, mass_1 * MSUN_SI, mass_2 * MSUN_SI, a_1 * np.cos(tilt_1), a_2 * np.cos(tilt_2)) return max(2**(int(np.log2(chirp_time)) + 1), 4) @@ -329,7 +324,7 @@ def load_data_by_channel_name( ifo = get_empty_interferometer(det) ifo.set_strain_data_from_channel_name( - channel = channel_name, + channel=channel_name, sampling_frequency=sampling_frequency, duration=segment_duration, start_time=start_time) diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py index 88a5a664..b52eb928 100644 --- a/bilby/gw/detector/calibration.py +++ b/bilby/gw/detector/calibration.py @@ -2,7 +2,6 @@ """ import numpy as np -import tables from scipy.interpolate import interp1d @@ -29,6 +28,7 @@ def read_calibration_file(filename, frequency_array, number_of_response_curves, Shape is (number_of_response_curves x len(frequency_array)) """ + import tables calibration_file = tables.open_file(filename, 'r') calibration_amplitude = \ @@ -71,6 +71,7 @@ def write_calibration_file(filename, frequency_array, calibration_draws, calibra Parameters used to generate the random draws of the calibration response curves """ + import tables calibration_file = tables.open_file(filename, 'w') deltaR_group = calibration_file.create_group(calibration_file.root, 'deltaR') diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 72a00168..7155d0f4 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -2,7 +2,6 @@ import os import sys import numpy as np -from matplotlib import pyplot as plt from ...core import utils from ...core.utils import docstring, logger @@ -12,13 +11,6 @@ from .calibration import Recalibrate from .geometry import InterferometerGeometry from .strain_data import InterferometerStrainData -try: - import gwpy - import gwpy.signal -except ImportError: - logger.debug("You do not have gwpy installed currently. You will " - " not be able to use some of the prebuilt functions.") - class Interferometer(object): """Class for the Interferometer """ @@ -599,6 +591,7 @@ class Interferometer(object): header='f h(f)') def plot_data(self, signal=None, outdir='.', label=None): + import matplotlib.pyplot as plt if utils.command_line_args.bilby_test_mode: return @@ -656,23 +649,26 @@ class Interferometer(object): plotting. """ + import matplotlib.pyplot as plt + from gwpy.timeseries import TimeSeries + from gwpy.signal.filter_design import bandpass, concatenate_zpks, notch # We use the gwpy timeseries to perform bandpass and notching if notches is None: notches = list() - timeseries = gwpy.timeseries.TimeSeries( + timeseries = TimeSeries( data=self.strain_data.time_domain_strain, times=self.strain_data.time_array) zpks = [] if bandpass_frequencies is not None: - zpks.append(gwpy.signal.filter_design.bandpass( + zpks.append(bandpass( bandpass_frequencies[0], bandpass_frequencies[1], self.strain_data.sampling_frequency)) if notches is not None: for line in notches: - zpks.append(gwpy.signal.filter_design.notch( + zpks.append(notch( line, self.strain_data.sampling_frequency)) if len(zpks) > 0: - zpk = gwpy.signal.filter_design.concatenate_zpks(*zpks) + zpk = concatenate_zpks(*zpks) strain = timeseries.filter(zpk, filtfilt=False) else: strain = timeseries diff --git a/bilby/gw/detector/strain_data.py b/bilby/gw/detector/strain_data.py index 71d8acf6..b6e2b46c 100644 --- a/bilby/gw/detector/strain_data.py +++ b/bilby/gw/detector/strain_data.py @@ -1,5 +1,4 @@ import numpy as np -from scipy.signal.windows import tukey from ...core import utils from ...core.series import CoupledTimeAndFrequencySeries @@ -7,19 +6,6 @@ from ...core.utils import logger from .. import utils as gwutils from ..utils import PropertyAccessor -try: - import gwpy - import gwpy.signal -except ImportError: - logger.debug("You do not have gwpy installed currently. You will " - " not be able to use some of the prebuilt functions.") - -try: - import lal -except ImportError: - logger.debug("You do not have lalsuite installed currently. You will" - " not be able to use some of the prebuilt functions.") - class InterferometerStrainData(object): """ Strain data for an interferometer """ @@ -198,6 +184,7 @@ class InterferometerStrainData(object): window: array Window function over time array """ + from scipy.signal.windows import tukey if roll_off is not None: self.roll_off = roll_off elif alpha is not None: @@ -253,11 +240,15 @@ class InterferometerStrainData(object): """ Output the time series strain data as a :class:`gwpy.timeseries.TimeSeries`. """ + try: + from gwpy.timeseries import TimeSeries + except ModuleNotFoundError: + raise ModuleNotFoundError("Cannot output strain data as gwpy TimeSeries") - return gwpy.timeseries.TimeSeries(self.time_domain_strain, - sample_rate=self.sampling_frequency, - t0=self.start_time, - channel=self.channel) + return TimeSeries( + self.time_domain_strain, sample_rate=self.sampling_frequency, + t0=self.start_time, channel=self.channel + ) def to_pycbc_timeseries(self): """ @@ -265,37 +256,48 @@ class InterferometerStrainData(object): """ try: - import pycbc - except ImportError: - raise ImportError("Cannot output strain data as PyCBC TimeSeries") + from pycbc.types.timeseries import TimeSeries + from lal import LIGOTimeGPS + except ModuleNotFoundError: + raise ModuleNotFoundError("Cannot output strain data as PyCBC TimeSeries") - return pycbc.types.timeseries.TimeSeries(self.time_domain_strain, - delta_t=(1. / self.sampling_frequency), - epoch=lal.LIGOTimeGPS(self.start_time)) + return TimeSeries( + self.time_domain_strain, delta_t=(1. / self.sampling_frequency), + epoch=LIGOTimeGPS(self.start_time) + ) def to_lal_timeseries(self): """ Output the time series strain data as a LAL TimeSeries object. """ + try: + from lal import CreateREAL8TimeSeries, LIGOTimeGPS, SecondUnit + except ModuleNotFoundError: + raise ModuleNotFoundError("Cannot output strain data as PyCBC TimeSeries") - laldata = lal.CreateREAL8TimeSeries("", - lal.LIGOTimeGPS(self.start_time), - 0., (1. / self.sampling_frequency), - lal.SecondUnit, - len(self.time_domain_strain)) - laldata.data.data[:] = self.time_domain_strain + lal_data = CreateREAL8TimeSeries( + "", LIGOTimeGPS(self.start_time), 0, 1 / self.sampling_frequency, + SecondUnit, len(self.time_domain_strain) + ) + lal_data.data.data[:] = self.time_domain_strain - return laldata + return lal_data def to_gwpy_frequencyseries(self): """ Output the frequency series strain data as a :class:`gwpy.frequencyseries.FrequencySeries`. """ + try: + from gwpy.frequencyseries import FrequencySeries + except ModuleNotFoundError: + raise ModuleNotFoundError("Cannot output strain data as gwpy FrequencySeries") - return gwpy.frequencyseries.FrequencySeries(self.frequency_domain_strain, - frequencies=self.frequency_array, - epoch=self.start_time, - channel=self.channel) + return FrequencySeries( + self.frequency_domain_strain, + frequencies=self.frequency_array, + epoch=self.start_time, + channel=self.channel + ) def to_pycbc_frequencyseries(self): """ @@ -303,31 +305,42 @@ class InterferometerStrainData(object): """ try: - import pycbc + from pycbc.types.frequencyseries import FrequencySeries + from lal import LIGOTimeGPS except ImportError: raise ImportError("Cannot output strain data as PyCBC FrequencySeries") - return pycbc.types.frequencyseries.FrequencySeries(self.frequency_domain_strain, - delta_f=(self.frequency_array[1] - self.frequency_array[0]), - epoch=lal.LIGOTimeGPS(self.start_time)) + return FrequencySeries( + self.frequency_domain_strain, + delta_f=1 / self.duration, + epoch=LIGOTimeGPS(self.start_time) + ) def to_lal_frequencyseries(self): """ Output the frequency series strain data as a LAL FrequencySeries object. """ - - laldata = lal.CreateCOMPLEX16FrequencySeries("", - lal.LIGOTimeGPS(self.start_time), - self.frequency_array[0], - (self.frequency_array[1] - self.frequency_array[0]), - lal.SecondUnit, - len(self.frequency_domain_strain)) - laldata.data.data[:] = self.frequency_domain_strain - - return laldata + try: + from lal import CreateCOMPLEX16FrequencySeries, LIGOTimeGPS, SecondUnit + except ModuleNotFoundError: + raise ModuleNotFoundError("Cannot output strain data as PyCBC TimeSeries") + + lal_data = CreateCOMPLEX16FrequencySeries( + "", + LIGOTimeGPS(self.start_time), + self.frequency_array[0], + 1 / self.duration, + SecondUnit, + len(self.frequency_domain_strain) + ) + lal_data.data.data[:] = self.frequency_domain_strain + + return lal_data def low_pass_filter(self, filter_freq=None): """ Low pass filter the data """ + from gwpy.signal.filter_design import lowpass + from gwpy.timeseries import TimeSeries if filter_freq is None: logger.debug( @@ -342,10 +355,8 @@ class InterferometerStrainData(object): return logger.debug("Applying low pass filter with filter frequency {}".format(filter_freq)) - bp = gwpy.signal.filter_design.lowpass( - filter_freq, self.sampling_frequency) - strain = gwpy.timeseries.TimeSeries( - self.time_domain_strain, sample_rate=self.sampling_frequency) + bp = lowpass(filter_freq, self.sampling_frequency) + strain = TimeSeries(self.time_domain_strain, sample_rate=self.sampling_frequency) strain = strain.filter(bp, filtfilt=True) self._time_domain_strain = strain.value @@ -379,6 +390,7 @@ class InterferometerStrainData(object): The frequencies and power spectral density array """ + from gwpy.timeseries import TimeSeries data = self.time_domain_strain @@ -394,7 +406,7 @@ class InterferometerStrainData(object): data = data[idxs] # WARNING this line can cause issues if the data is non-contiguous - strain = gwpy.timeseries.TimeSeries(data=data, sample_rate=self.sampling_frequency) + strain = TimeSeries(data=data, sample_rate=self.sampling_frequency) psd_alpha = 2 * self.roll_off / fft_length logger.info( "Tukey window PSD data with alpha={}, roll off={}".format( @@ -500,8 +512,9 @@ class InterferometerStrainData(object): The data to use """ + from gwpy.timeseries import TimeSeries logger.debug('Setting data using provided gwpy TimeSeries object') - if type(time_series) != gwpy.timeseries.TimeSeries: + if not isinstance(time_series, TimeSeries): raise ValueError("Input time_series is not a gwpy TimeSeries") self._times_and_frequencies = \ CoupledTimeAndFrequencySeries(duration=time_series.duration.value, @@ -557,7 +570,8 @@ class InterferometerStrainData(object): The path to the file to read in """ - timeseries = gwpy.timeseries.TimeSeries.read(filename, format='csv') + from gwpy.timeseries import TimeSeries + timeseries = TimeSeries.read(filename, format='csv') self.set_from_gwpy_timeseries(timeseries) def set_from_frequency_domain_strain( @@ -700,6 +714,7 @@ class InterferometerStrainData(object): The sampling frequency (in Hz) """ + from gwpy.timeseries import TimeSeries channel_comp = channel.split(':') if len(channel_comp) != 2: raise IndexError('Channel name must have format `IFO:Channel`') @@ -709,8 +724,7 @@ class InterferometerStrainData(object): start_time=start_time) logger.info('Fetching data using channel {}'.format(channel)) - strain = gwpy.timeseries.TimeSeries.get( - channel, start_time, start_time + duration) + strain = TimeSeries.get(channel, start_time, start_time + duration) strain = strain.resample(sampling_frequency) self.set_from_gwpy_timeseries(strain) diff --git a/bilby/gw/eos/eos.py b/bilby/gw/eos/eos.py index b2a45dc2..f63c2b83 100644 --- a/bilby/gw/eos/eos.py +++ b/bilby/gw/eos/eos.py @@ -1,9 +1,6 @@ import os import numpy as np -import matplotlib.pyplot as plt -from scipy.integrate import cumtrapz, quad from scipy.interpolate import interp1d, CubicSpline -from scipy.optimize import minimize_scalar from .tov_solver import IntegrateTOV from ...core import utils @@ -58,6 +55,7 @@ class TabularEOS(object): """ def __init__(self, eos, sampling_flag=False, warning_flag=False): + from scipy.integrate import cumtrapz self.sampling_flag = sampling_flag self.warning_flag = warning_flag @@ -398,6 +396,7 @@ class TabularEOS(object): fig: matplotlib.figure.Figure EOS plot. """ + import matplotlib.pyplot as plt # Set data based on specified representation varnames = rep.split('-') @@ -538,7 +537,7 @@ class SpectralDecompositionEOS(TabularEOS): return 1. / spectral_adiabatic_index(self.gammas, x) def mu(self, x): - + from scipy.integrate import quad return np.exp(-quad(self.__mu_integrand, 0, x)[0]) def __eps_integrand(self, x): @@ -546,7 +545,7 @@ class SpectralDecompositionEOS(TabularEOS): return np.exp(x) * self.mu(x) / spectral_adiabatic_index(self.gammas, x) def energy_density(self, x, eps0): - + from scipy.integrate import quad quad_result, quad_err = quad(self.__eps_integrand, 0, x) eps_of_x = (eps0 * C_CGS ** 2.) / self.mu(x) + self.p0 / self.mu(x) * quad_result return eps_of_x @@ -638,6 +637,7 @@ class EOSFamily(object): populated here via the TOV solver upon object construction. """ def __init__(self, eos, npts=500): + from scipy.optimize import minimize_scalar self.eos = eos # FIXME: starting_energy_density is set somewhat arbitrarily @@ -781,6 +781,7 @@ class EOSFamily(object): fig: matplotlib.figure.Figure EOS Family plot. """ + import matplotlib.pyplot as plt # Set data based on specified representation varnames = rep.split('-') diff --git a/bilby/gw/eos/tov_solver.py b/bilby/gw/eos/tov_solver.py index 5b5d7e13..0135086e 100644 --- a/bilby/gw/eos/tov_solver.py +++ b/bilby/gw/eos/tov_solver.py @@ -1,7 +1,6 @@ # Monica Rizzo, 2019 import numpy as np -from scipy.integrate import solve_ivp class IntegrateTOV: @@ -111,6 +110,7 @@ class IntegrateTOV: """ Evolves TOV+k2 equations and returns final quantities """ + from scipy.integrate import solve_ivp # integration settings the same as in lalsimulation rel_err = 1e-4 diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index 0f113683..b6fdfea7 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -5,16 +5,9 @@ import json import copy import numpy as np -import scipy.integrate as integrate -from scipy.interpolate import interp1d import pandas as pd -from tqdm import tqdm - -try: - from scipy.special import logsumexp -except ImportError: - from scipy.misc import logsumexp -from scipy.special import i0e +from scipy.interpolate import interp1d +from scipy.special import logsumexp, i0e from ..core.likelihood import Likelihood from ..core.utils import BilbyJsonEncoder, decode_bilby_json @@ -939,10 +932,11 @@ class GravitationalWaveTransient(Likelihood): self.cache_lookup_table() def _setup_phase_marginalization(self, min_bound=-5, max_bound=10): + x_values = np.logspace(min_bound, max_bound, int(1e6)) self._bessel_function_interped = interp1d( - np.logspace(-5, max_bound, int(1e6)), np.logspace(-5, max_bound, int(1e6)) + - np.log([i0e(snr) for snr in np.logspace(-5, max_bound, int(1e6))]), - bounds_error=False, fill_value=(0, np.nan)) + x_values, x_values + np.log([i0e(snr) for snr in x_values]), + bounds_error=False, fill_value=(0, np.nan) + ) def _setup_time_marginalization(self): self._delta_tc = 2 / self.waveform_generator.sampling_frequency @@ -980,6 +974,7 @@ class GravitationalWaveTransient(Likelihood): self.number_of_response_curves, self.starting_index) else: # generate the fake curves + from tqdm.auto import tqdm self.calibration_parameter_draws[interferometer.name] =\ pd.DataFrame(calibration_priors.sample(self.number_of_response_curves)) @@ -1609,10 +1604,11 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): f_high: float The maximum frequency which must be considered """ + from scipy.integrate import simps integrand1 = np.power(freq, -7. / 3) / psd - integral1 = integrate.simps(integrand1, freq) + integral1 = simps(integrand1, freq) integrand3 = np.power(freq, 2. / 3.) / (psd * integral1) - f_3_bar = integrate.simps(integrand3, freq) + f_3_bar = simps(integrand3, freq) f_high = scaling * f_3_bar**(1 / 3) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 899e9806..8c21948a 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -3,7 +3,6 @@ import copy import numpy as np from scipy.interpolate import InterpolatedUnivariateSpline, interp1d -from scipy.integrate import cumtrapz from scipy.special import hyp2f1 from scipy.stats import norm @@ -22,12 +21,6 @@ from .conversion import ( total_mass_and_mass_ratio_to_component_masses) from .cosmology import get_cosmology -try: - from astropy import cosmology as cosmo, units -except ImportError: - logger.debug("You do not have astropy installed currently. You will" - " not be able to use some of the prebuilt functions.") - DEFAULT_PRIOR_DIR = os.path.join(os.path.dirname(__file__), 'prior_files') @@ -99,6 +92,7 @@ class Cosmological(Interped): @property def _default_args_dict(self): + from astropy import units return dict( redshift=dict(name='redshift', latex_label='$z$', unit=None), luminosity_distance=dict( @@ -108,6 +102,7 @@ class Cosmological(Interped): def __init__(self, minimum, maximum, cosmology=None, name=None, latex_label=None, unit=None, boundary=None): + from astropy import units self.cosmology = get_cosmology(cosmology) if name not in self._default_args_dict: raise ValueError( @@ -172,6 +167,7 @@ class Cosmological(Interped): recalculate_array: boolean Determines if the distance arrays are recalculated """ + from astropy.cosmology import z_at_value cosmology = get_cosmology(self.cosmology) limit_dict[self.name] = value if self.name == 'redshift': @@ -183,8 +179,9 @@ class Cosmological(Interped): if value == 0: limit_dict['redshift'] = 0 else: - limit_dict['redshift'] = cosmo.z_at_value( - cosmology.luminosity_distance, value * self.unit) + limit_dict['redshift'] = z_at_value( + cosmology.luminosity_distance, value * self.unit + ) limit_dict['comoving_distance'] = ( cosmology.comoving_distance(limit_dict['redshift']).value ) @@ -192,8 +189,9 @@ class Cosmological(Interped): if value == 0: limit_dict['redshift'] = 0 else: - limit_dict['redshift'] = cosmo.z_at_value( - cosmology.comoving_distance, value * self.unit) + limit_dict['redshift'] = z_at_value( + cosmology.comoving_distance, value * self.unit + ) limit_dict['luminosity_distance'] = ( cosmology.luminosity_distance(limit_dict['redshift']).value ) @@ -255,8 +253,10 @@ class Cosmological(Interped): """ Get a dictionary containing the arguments needed to reproduce this object. """ + from astropy.cosmology.core import Cosmology + from astropy import units dict_with_properties = super(Cosmological, self)._repr_dict - if isinstance(dict_with_properties['cosmology'], cosmo.core.Cosmology): + if isinstance(dict_with_properties['cosmology'], Cosmology): if dict_with_properties['cosmology'].name is not None: dict_with_properties['cosmology'] = dict_with_properties['cosmology'].name if isinstance(dict_with_properties['unit'], units.Unit): @@ -1113,6 +1113,7 @@ class HealPixMapPriorDist(BaseJointPriorDist): """ Method that builds the inverse cdf of the P(pixel) distribution for rescaling """ + from scipy.integrate import cumtrapz yy = self._all_interped(self.pix_xx) yy /= np.trapz(yy, self.pix_xx) YY = cumtrapz(yy, self.pix_xx, initial=0) diff --git a/bilby/gw/result.py b/bilby/gw/result.py index 7116a5bf..fc599232 100644 --- a/bilby/gw/result.py +++ b/bilby/gw/result.py @@ -1,10 +1,7 @@ - import json -import pickle import os +import pickle -import matplotlib.pyplot as plt -from matplotlib import rcParams import numpy as np from ..core.result import Result as CoreResult @@ -149,6 +146,7 @@ class CompactBinaryCoalescenceResult(CoreResult): format: str Format to save the plot, default=png, options are png/pdf """ + import matplotlib.pyplot as plt if format not in ["png", "pdf"]: raise ValueError("Format should be one of png or pdf") @@ -394,6 +392,8 @@ class CompactBinaryCoalescenceResult(CoreResult): ) ) else: + import matplotlib.pyplot as plt + from matplotlib import rcParams old_font_size = rcParams["font.size"] rcParams["font.size"] = 20 fig, axs = plt.subplots( @@ -734,6 +734,8 @@ class CompactBinaryCoalescenceResult(CoreResult): If true, load the cached pickle file (default name), or the pickle-file give as a path. """ + import matplotlib.pyplot as plt + from matplotlib import rcParams try: from astropy.time import Time diff --git a/bilby/gw/source.py b/bilby/gw/source.py index 96e61afc..b874b8a0 100644 --- a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -10,13 +10,6 @@ from .utils import (lalsim_GetApproximantFromString, lalsim_SimInspiralWaveformParamsInsertTidalLambda2, lalsim_SimInspiralChooseFDWaveformSequence) -try: - import lal - import lalsimulation as lalsim -except ImportError: - logger.debug("You do not have lalsuite installed currently. You will" - " not be able to use some of the prebuilt functions.") - def lal_binary_black_hole( frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, @@ -293,6 +286,9 @@ def _base_lal_cbc_fd_waveform( ======= dict: A dictionary with the plus and cross polarisation strain modes """ + import lal + import lalsimulation as lalsim + waveform_approximant = waveform_kwargs['waveform_approximant'] reference_frequency = waveform_kwargs['reference_frequency'] minimum_frequency = waveform_kwargs['minimum_frequency'] @@ -509,6 +505,7 @@ def _base_roq_waveform( Dict containing plus and cross modes evaluated at the linear and quadratic frequency nodes. """ + from lal import CreateDict frequency_nodes_linear = waveform_arguments['frequency_nodes_linear'] frequency_nodes_quadratic = waveform_arguments['frequency_nodes_quadratic'] reference_frequency = waveform_arguments['reference_frequency'] @@ -519,7 +516,7 @@ def _base_roq_waveform( mass_1 = mass_1 * utils.solar_mass mass_2 = mass_2 * utils.solar_mass - waveform_dictionary = lal.CreateDict() + waveform_dictionary = CreateDict() lalsim_SimInspiralWaveformParamsInsertTidalLambda1( waveform_dictionary, lambda_1) lalsim_SimInspiralWaveformParamsInsertTidalLambda2( diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index cf7c0f3e..fa81766b 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -1,29 +1,15 @@ -import os import json +import os from math import fmod import numpy as np from scipy.interpolate import interp1d -import matplotlib.pyplot as plt from ..core.utils import (ra_dec_to_theta_phi, speed_of_light, logger, run_commandline, check_directory_exists_and_if_not_mkdir, SamplesSummary, theta_phi_to_ra_dec) -try: - from gwpy.timeseries import TimeSeries -except ImportError: - logger.debug("You do not have gwpy installed currently. You will " - " not be able to use some of the prebuilt functions.") - -try: - import lal - import lalsimulation as lalsim -except ImportError: - logger.debug("You do not have lalsuite installed currently. You will" - " not be able to use some of the prebuilt functions.") - def asd_from_freq_series(freq_data, df): """ @@ -88,7 +74,8 @@ def time_delay_geocentric(detector1, detector2, ra, dec, time): float: Time delay between the two detectors in the geocentric frame """ - gmst = fmod(lal.GreenwichMeanSiderealTime(time), 2 * np.pi) + from lal import GreenwichMeanSiderealTime + gmst = fmod(GreenwichMeanSiderealTime(time), 2 * np.pi) theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) omega = np.array([np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)]) delta_d = detector2 - detector1 @@ -122,7 +109,8 @@ def get_polarization_tensor(ra, dec, time, psi, mode): array_like: A 3x3 representation of the polarization_tensor for the specified mode. """ - gmst = fmod(lal.GreenwichMeanSiderealTime(time), 2 * np.pi) + from lal import GreenwichMeanSiderealTime + gmst = fmod(GreenwichMeanSiderealTime(time), 2 * np.pi) theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) u = np.array([np.cos(phi) * np.cos(theta), np.cos(theta) * np.sin(phi), -np.sin(theta)]) v = np.array([-np.sin(phi), np.cos(phi), 0]) @@ -390,8 +378,9 @@ def zenith_azimuth_to_ra_dec(zenith, azimuth, geocent_time, ifos): ra, dec: float The zenith and azimuthal angles in the sky frame. """ + from lal import GreenwichMeanSiderealTime theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, ifos) - gmst = lal.GreenwichMeanSiderealTime(geocent_time) + gmst = GreenwichMeanSiderealTime(geocent_time) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) ra = ra % (2 * np.pi) return ra, dec @@ -478,6 +467,7 @@ def get_open_strain_data( fails, this function retruns `None`. """ + from gwpy.timeseries import TimeSeries filename = '{}/{}_{}_{}.txt'.format(outdir, name, start_time, end_time) if buffer_time < 0: @@ -529,6 +519,7 @@ def read_frame_file(file_name, start_time, end_time, channel=None, buffer_time=0 strain: gwpy.timeseries.TimeSeries """ + from gwpy.timeseries import TimeSeries loaded = False strain = None @@ -793,17 +784,19 @@ def convert_args_list_to_float(*args_list): def lalsim_SimInspiralTransformPrecessingNewInitialConditions( theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, reference_frequency, phase): + from lalsimulation import SimInspiralTransformPrecessingNewInitialConditions args_list = convert_args_list_to_float( theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, reference_frequency, phase) - return lalsim.SimInspiralTransformPrecessingNewInitialConditions(*args_list) + return SimInspiralTransformPrecessingNewInitialConditions(*args_list) def lalsim_GetApproximantFromString(waveform_approximant): + from lalsimulation import GetApproximantFromString if isinstance(waveform_approximant, str): - return lalsim.GetApproximantFromString(waveform_approximant) + return GetApproximantFromString(waveform_approximant) else: raise ValueError("waveform_approximant must be of type str") @@ -840,6 +833,7 @@ def lalsim_SimInspiralFD( waveform_dictionary: None, lal.Dict approximant: int, str """ + from lalsimulation import SimInspiralFD args = convert_args_list_to_float( mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, @@ -849,7 +843,7 @@ def lalsim_SimInspiralFD( approximant = _get_lalsim_approximant(approximant) - return lalsim.SimInspiralFD(*args, waveform_dictionary, approximant) + return SimInspiralFD(*args, waveform_dictionary, approximant) def lalsim_SimInspiralChooseFDWaveform( @@ -884,6 +878,7 @@ def lalsim_SimInspiralChooseFDWaveform( waveform_dictionary: None, lal.Dict approximant: int, str """ + from lalsimulation import SimInspiralChooseFDWaveform args = convert_args_list_to_float( mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, @@ -893,7 +888,7 @@ def lalsim_SimInspiralChooseFDWaveform( approximant = _get_lalsim_approximant(approximant) - return lalsim.SimInspiralChooseFDWaveform(*args, waveform_dictionary, approximant) + return SimInspiralChooseFDWaveform(*args, waveform_dictionary, approximant) def _get_lalsim_approximant(approximant): @@ -931,6 +926,8 @@ def lalsim_SimInspiralChooseFDWaveformSequence( approximant: int, str frequency_array: np.ndarray, lal.REAL8Vector """ + from lal import REAL8Vector, CreateREAL8Vector + from lalsimulation import SimInspiralChooseFDWaveformSequence [mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, luminosity_distance, iota, phase, reference_frequency] = convert_args_list_to_float( @@ -944,12 +941,12 @@ def lalsim_SimInspiralChooseFDWaveformSequence( else: raise ValueError("approximant not an int") - if not isinstance(frequency_array, lal.REAL8Vector): + if not isinstance(frequency_array, REAL8Vector): old_frequency_array = frequency_array.copy() - frequency_array = lal.CreateREAL8Vector(len(old_frequency_array)) + frequency_array = CreateREAL8Vector(len(old_frequency_array)) frequency_array.data = old_frequency_array - return lalsim.SimInspiralChooseFDWaveformSequence( + return SimInspiralChooseFDWaveformSequence( phase, mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, reference_frequency, luminosity_distance, iota, waveform_dictionary, approximant, frequency_array) @@ -957,23 +954,25 @@ def lalsim_SimInspiralChooseFDWaveformSequence( def lalsim_SimInspiralWaveformParamsInsertTidalLambda1( waveform_dictionary, lambda_1): + from lalsimulation import SimInspiralWaveformParamsInsertTidalLambda1 try: lambda_1 = float(lambda_1) except ValueError: raise ValueError("Unable to convert lambda_1 to float") - return lalsim.SimInspiralWaveformParamsInsertTidalLambda1( + return SimInspiralWaveformParamsInsertTidalLambda1( waveform_dictionary, lambda_1) def lalsim_SimInspiralWaveformParamsInsertTidalLambda2( waveform_dictionary, lambda_2): + from lalsimulation import SimInspiralWaveformParamsInsertTidalLambda2 try: lambda_2 = float(lambda_2) except ValueError: raise ValueError("Unable to convert lambda_2 to float") - return lalsim.SimInspiralWaveformParamsInsertTidalLambda2( + return SimInspiralWaveformParamsInsertTidalLambda2( waveform_dictionary, lambda_2) @@ -1020,6 +1019,7 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label= Function to transform the spline into plotted values. """ + import matplotlib.pyplot as plt freq_points = np.exp(log_freqs) freqs = np.logspace(min(log_freqs), max(log_freqs), nfreqs, base=np.exp(1)) diff --git a/requirements.txt b/requirements.txt index 05d86e2f..4c80f78a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ dynesty emcee corner numpy<1.20 -matplotlib>=2.0 +matplotlib>=2.1 scipy>=0.16 pandas mock diff --git a/sampler_requirements.txt b/sampler_requirements.txt index 2e715558..69a03701 100644 --- a/sampler_requirements.txt +++ b/sampler_requirements.txt @@ -3,8 +3,7 @@ dynesty emcee nestle ptemcee -pymc3==3.6; python_version <= '2.7' -pymc3>=3.6; python_version > '3.4' +pymc3>=3.6 pymultinest kombine ultranest>=3.0.0 diff --git a/setup.cfg b/setup.cfg index a9ae5b2d..8dd2e017 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,7 @@ ignore = E129 W503 W504 W605 E203 E402 [tool:pytest] addopts = + --ignore test/import_test.py --ignore test/integration/other_test.py --ignore test/integration/example_test.py --ignore test/integration/sample_from_the_prior_test.py diff --git a/test/gw/cosmology_test.py b/test/gw/cosmology_test.py index 158375d5..0864c4cc 100644 --- a/test/gw/cosmology_test.py +++ b/test/gw/cosmology_test.py @@ -46,7 +46,7 @@ class TestGetCosmology(unittest.TestCase): self.assertEqual(cosmology.get_cosmology("WMAP9").name, "WMAP9") def test_getting_cosmology_with_default(self): - self.assertEqual(cosmology.get_cosmology(), cosmology.COSMOLOGY[0]) + self.assertEqual(cosmology.get_cosmology().name, "Planck15") if __name__ == "__main__": diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index fee929fa..a3a7582f 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -3,9 +3,10 @@ import os from shutil import rmtree import numpy as np -import gwpy import lal import lalsimulation as lalsim +from gwpy.timeseries import TimeSeries +from gwpy.detector import Channel from scipy.stats import ks_2samp import bilby @@ -125,8 +126,8 @@ class TestGWUtils(unittest.TestCase): N = 100 times = np.linspace(start_time, end_time, N) data = np.random.normal(0, 1, N) - ts = gwpy.timeseries.TimeSeries(data=data, times=times, t0=0) - ts.channel = gwpy.detector.Channel(channel) + ts = TimeSeries(data=data, times=times, t0=0) + ts.channel = Channel(channel) ts.name = channel filename = os.path.join(self.outdir, "test.gwf") ts.write(filename, format="gwf") @@ -158,7 +159,7 @@ class TestGWUtils(unittest.TestCase): ) self.assertTrue(np.all(strain.value == data[:-1])) - ts = gwpy.timeseries.TimeSeries(data=data, times=times, t0=0) + ts = TimeSeries(data=data, times=times, t0=0) ts.name = "NOT-A-KNOWN-CHANNEL" ts.write(filename, format="gwf") strain = gwutils.read_frame_file(filename, start_time=None, end_time=None) diff --git a/test/import_test.py b/test/import_test.py new file mode 100644 index 00000000..29a4c602 --- /dev/null +++ b/test/import_test.py @@ -0,0 +1,22 @@ +import sys + +import bilby # noqa + +unique_packages = set(sys.modules) + +unwanted = { + "lal", "lalsimulation", "matplotlib", + "h5py", "dill", "tqdm", "tables", "deepdish", "corner", +} + +for filename in ["sampler_requirements.txt", "optional_requirements.txt"]: + with open(filename, "r") as ff: + packages = ff.readlines() + for package in packages: + package = package.split(">")[0].split("<")[0].split("=")[0].strip() + unwanted.add(package) + +if not unique_packages.isdisjoint(unwanted): + raise ImportError( + f"{' '.join(unique_packages.intersection(unwanted))} imported with Bilby" + ) -- GitLab