From 0bb1113d0d1c50e2a0b37e05c92bd49d8d0be00e Mon Sep 17 00:00:00 2001 From: Moritz Huebner <moritz.huebner@ligo.org> Date: Tue, 28 Apr 2020 01:14:46 -0500 Subject: [PATCH] Reduce redundant code --- bilby/core/grid.py | 32 ++------- bilby/core/prior/base.py | 14 ++-- bilby/core/prior/joint.py | 8 +-- bilby/core/result.py | 23 +------ bilby/core/sampler/dynamic_dynesty.py | 16 +---- bilby/core/sampler/dynesty.py | 19 +++--- bilby/core/sampler/emcee.py | 18 +++-- bilby/core/sampler/kombine.py | 18 ++--- bilby/core/sampler/pymc3.py | 95 +++++++++++---------------- bilby/core/utils.py | 44 +++++++++++++ bilby/gw/likelihood.py | 29 ++++---- bilby/gw/utils.py | 38 ++++------- 12 files changed, 148 insertions(+), 206 deletions(-) diff --git a/bilby/core/grid.py b/bilby/core/grid.py index ee09e2fec..42fb58dd0 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -8,7 +8,7 @@ from collections import OrderedDict from .prior import Prior, PriorDict from .utils import (logtrapzexp, check_directory_exists_and_if_not_mkdir, logger) -from .utils import BilbyJsonEncoder, decode_bilby_json +from .utils import BilbyJsonEncoder, load_json, move_old_file from .result import FileMovedError @@ -397,18 +397,7 @@ class Grid(object): filename = grid_file_name(outdir, self.label, gzip) - if os.path.isfile(filename): - if overwrite: - logger.debug('Removing existing file {}'.format(filename)) - os.remove(filename) - else: - logger.debug( - 'Renaming existing file {} to {}.old'.format(filename, - filename)) - os.rename(filename, filename + '.old') - - logger.debug("Saving result to {}".format(filename)) - + move_old_file(filename, overwrite) dictionary = self._get_save_data_dictionary() try: @@ -452,23 +441,14 @@ class Grid(object): """ - if filename is not None: - fname = filename - else: + if filename is None: if (outdir is None) and (label is None): raise ValueError("No information given to load file") else: - fname = grid_file_name(outdir, label, gzip) + filename = grid_file_name(outdir, label, gzip) - if os.path.isfile(fname): - if gzip or os.path.splitext(fname)[1].lstrip('.') == 'gz': - import gzip - with gzip.GzipFile(fname, 'r') as file: - json_str = file.read().decode('utf-8') - dictionary = json.loads(json_str, object_hook=decode_bilby_json) - else: - with open(fname, 'r') as file: - dictionary = json.load(file, object_hook=decode_bilby_json) + if os.path.isfile(filename): + dictionary = load_json(filename, gzip) try: grid = cls(likelihood=None, priors=dictionary['priors'], grid_size=dictionary['sample_points'], diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index 29efe925d..db5880da1 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -8,7 +8,8 @@ import scipy.stats from scipy.integrate import cumtrapz from scipy.interpolate import interp1d -from bilby.core.utils import infer_args_from_method, BilbyJsonEncoder, decode_bilby_json, logger +from bilby.core.utils import infer_args_from_method, BilbyJsonEncoder, decode_bilby_json, logger, \ + get_dict_with_properties class Prior(object): @@ -280,15 +281,8 @@ class Prior(object): def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) - property_names = [p for p in dir(self.__class__) - if isinstance(getattr(self.__class__, p), property)] - dict_with_properties = self.__dict__.copy() - for key in property_names: - dict_with_properties[key] = getattr(self, key) - instantiation_dict = dict() - for key in subclass_args: - instantiation_dict[key] = dict_with_properties[key] - return instantiation_dict + dict_with_properties = get_dict_with_properties(self) + return {key: dict_with_properties[key] for key in subclass_args} @property def boundary(self): diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 613183f87..bac8fceb5 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -3,7 +3,7 @@ import scipy.stats from scipy.special import erfinv from .base import Prior, PriorException -from bilby.core.utils import logger, infer_args_from_method +from bilby.core.utils import logger, infer_args_from_method, get_dict_with_properties class BaseJointPriorDist(object): @@ -105,11 +105,7 @@ class BaseJointPriorDist(object): def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) - property_names = [p for p in dir(self.__class__) - if isinstance(getattr(self.__class__, p), property)] - dict_with_properties = self.__dict__.copy() - for key in property_names: - dict_with_properties[key] = getattr(self, key) + dict_with_properties = get_dict_with_properties(self) instantiation_dict = dict() for key in subclass_args: if isinstance(dict_with_properties[key], list): diff --git a/bilby/core/result.py b/bilby/core/result.py index 5bc278e58..8236f24b8 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -22,7 +22,7 @@ from .utils import ( check_directory_exists_and_if_not_mkdir, latex_plot_format, safe_save_figure, ) -from .utils import BilbyJsonEncoder, decode_bilby_json +from .utils import BilbyJsonEncoder, load_json, move_old_file from .prior import Prior, PriorDict, DeltaFunction @@ -259,14 +259,7 @@ class Result(object): filename = _determine_file_name(filename, outdir, label, 'json', gzip) if os.path.isfile(filename): - if gzip or os.path.splitext(filename)[1].lstrip('.') == 'gz': - import gzip - with gzip.GzipFile(filename, 'r') as file: - json_str = file.read().decode('utf-8') - dictionary = json.loads(json_str, object_hook=decode_bilby_json) - else: - with open(filename, 'r') as file: - dictionary = json.load(file, object_hook=decode_bilby_json) + dictionary = load_json(filename, gzip) try: return cls(**dictionary) except TypeError as e: @@ -468,17 +461,7 @@ class Result(object): if filename is None: filename = result_file_name(outdir, self.label, extension, gzip) - if os.path.isfile(filename): - if overwrite: - logger.debug('Removing existing file {}'.format(filename)) - os.remove(filename) - else: - logger.debug( - 'Renaming existing file {} to {}.old'.format(filename, - filename)) - os.rename(filename, filename + '.old') - - logger.debug("Saving result to {}".format(filename)) + move_old_file(filename, overwrite) # Convert the prior to a string representation for saving on disk dictionary = self._get_save_data_dictionary() diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py index 4e7e1f56d..48d726e4a 100644 --- a/bilby/core/sampler/dynamic_dynesty.py +++ b/bilby/core/sampler/dynamic_dynesty.py @@ -5,7 +5,6 @@ import dill as pickle import signal import numpy as np -from pandas import DataFrame from ..utils import logger, check_directory_exists_and_if_not_mkdir from .base_sampler import Sampler @@ -139,20 +138,7 @@ class DynamicDynesty(Dynesty): print("") # self.result.sampler_output = out - weights = np.exp(out['logwt'] - out['logz'][-1]) - nested_samples = DataFrame( - out.samples, columns=self.search_parameter_keys) - nested_samples['weights'] = weights - nested_samples['log_likelihood'] = out.logl - - self.result.samples = dynesty.utils.resample_equal(out.samples, weights) - self.result.nested_samples = nested_samples - self.result.log_likelihood_evaluations = self.reorder_loglikelihoods( - unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples, - sorted_samples=self.result.samples) - self.result.log_evidence = out.logz[-1] - self.result.log_evidence_err = out.logzerr[-1] - + self._generate_result(out) if self.plot: self.generate_trace_plots(out) diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 688efc417..9c2c6683c 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -357,26 +357,29 @@ class Dynesty(NestedSampler): with open(dynesty_result, 'wb') as file: pickle.dump(out, file) + self._generate_result(out) + self.calc_likelihood_count() + self.result.sampling_time = self.sampling_time + + if self.plot: + self.generate_trace_plots(out) + + return self.result + + def _generate_result(self, out): + import dynesty weights = np.exp(out['logwt'] - out['logz'][-1]) nested_samples = DataFrame( out.samples, columns=self.search_parameter_keys) nested_samples['weights'] = weights nested_samples['log_likelihood'] = out.logl - self.result.samples = dynesty.utils.resample_equal(out.samples, weights) self.result.nested_samples = nested_samples self.result.log_likelihood_evaluations = self.reorder_loglikelihoods( unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples, sorted_samples=self.result.samples) - self.calc_likelihood_count() self.result.log_evidence = out.logz[-1] self.result.log_evidence_err = out.logzerr[-1] - self.result.sampling_time = self.sampling_time - - if self.plot: - self.generate_trace_plots(out) - - return self.result def _run_nested_wrapper(self, kwargs): """ Wrapper function to run_nested diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 144bede6d..33f21fcf9 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -373,20 +373,26 @@ class Emcee(MCMCSampler): self.calculate_autocorrelation( self.sampler.chain.reshape((-1, self.ndim))) self.print_nburn_logging_info() - self.calc_likelihood_count() + + self._generate_result() + + self.result.samples = self.sampler.chain[:, self.nburn:, :].reshape( + (-1, self.ndim)) + self.result.walkers = self.sampler.chain + return self.result + + def _generate_result(self): self.result.nburn = self.nburn + self.calc_likelihood_count() if self.result.nburn > self.nsteps: raise SamplerError( "The run has finished, but the chain is not burned in: " - "`nburn < nsteps`. Try increasing the number of steps.") - self.result.samples = self.sampler.chain[:, self.nburn:, :].reshape( - (-1, self.ndim)) + "`nburn < nsteps` ({} < {}). Try increasing the " + "number of steps.".format(self.result.nburn, self.nsteps)) blobs = np.array(self.sampler.blobs) blobs_trimmed = blobs[self.nburn:, :, :].reshape((-1, 2)) log_likelihoods, log_priors = blobs_trimmed.T self.result.log_likelihood_evaluations = log_likelihoods self.result.log_prior_evaluations = log_priors - self.result.walkers = self.sampler.chain self.result.log_evidence = np.nan self.result.log_evidence_err = np.nan - return self.result diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py index 48e85342a..f7c7768ec 100644 --- a/bilby/core/sampler/kombine.py +++ b/bilby/core/sampler/kombine.py @@ -3,7 +3,6 @@ from ..utils import logger, get_progress_bar import numpy as np import os from .emcee import Emcee -from .base_sampler import SamplerError class Kombine(Emcee): @@ -157,20 +156,11 @@ class Kombine(Emcee): tmp_chain = self.sampler.chain.copy() self.calculate_autocorrelation(tmp_chain.reshape((-1, self.ndim))) self.print_nburn_logging_info() - self.result.nburn = self.nburn - if self.result.nburn > self.nsteps: - raise SamplerError( - "The run has finished, but the chain is not burned in: `nburn < nsteps` ({} < {}). Try increasing the " - "number of steps.".format(self.result.nburn, self.nsteps)) + + self._generate_result() + self.result.log_evidence_err = np.nan + tmp_chain = self.sampler.chain[self.nburn:, :, :].copy() self.result.samples = tmp_chain.reshape((-1, self.ndim)) - blobs = np.array(self.sampler.blobs) - blobs_trimmed = blobs[self.nburn:, :, :].reshape((-1, 2)) - self.calc_likelihood_count() - log_likelihoods, log_priors = blobs_trimmed.T - self.result.log_likelihood_evaluations = log_likelihoods - self.result.log_prior_evaluations = log_priors self.result.walkers = self.sampler.chain.reshape((self.nwalkers, self.nsteps, self.ndim)) - self.result.log_evidence = np.nan - self.result.log_evidence_err = np.nan return self.result diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index 7a8745c8f..824f294ce 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -469,43 +469,13 @@ class Pymc3(MCMCSampler): for sms in self.step_method[key]: curmethod = sms.lower() methodslist.append(curmethod) - args = {} - if curmethod == 'nuts': - if nuts_kwargs is not None: - args = nuts_kwargs - elif step_kwargs is not None: - args = step_kwargs.pop('nuts', {}) - # add values into nuts_kwargs - nuts_kwargs = args - else: - args = {} - else: - if step_kwargs is not None: - args = step_kwargs.get(curmethod, {}) - else: - args = {} - self.kwargs['step'].append( - pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args)) + nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs, + step_methods) else: curmethod = self.step_method[key].lower() methodslist.append(curmethod) - args = {} - if curmethod == 'nuts': - if nuts_kwargs is not None: - args = nuts_kwargs - elif step_kwargs is not None: - args = step_kwargs.pop('nuts', {}) - # add values into nuts_kwargs - nuts_kwargs = args - else: - args = {} - else: - if step_kwargs is not None: - args = step_kwargs.get(curmethod, {}) - else: - args = {} - self.kwargs['step'].append( - pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args)) + nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs, + step_methods) else: with self.pymc3_model: # check for a compound step list @@ -514,18 +484,7 @@ class Pymc3(MCMCSampler): for sms in self.step_method: curmethod = sms.lower() methodslist.append(curmethod) - args = {} - if curmethod == 'nuts': - if nuts_kwargs is not None: - args = nuts_kwargs - elif step_kwargs is not None: - args = step_kwargs.pop('nuts', {}) - # add values into nuts_kwargs - nuts_kwargs = args - else: - args = {} - else: - args = step_kwargs.get(curmethod, {}) + args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs) compound.append(pymc3.__dict__[step_methods[curmethod]](**args)) self.kwargs['step'] = compound else: @@ -533,18 +492,7 @@ class Pymc3(MCMCSampler): if self.step_method is not None: curmethod = self.step_method.lower() methodslist.append(curmethod) - args = {} - if curmethod == 'nuts': - if nuts_kwargs is not None: - args = nuts_kwargs - elif step_kwargs is not None: - args = step_kwargs.pop('nuts', {}) - # add values into nuts_kwargs - nuts_kwargs = args - else: - args = {} - else: - args = step_kwargs.get(curmethod, {}) + args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs) self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args) else: # re-add step_kwargs if no step methods are set @@ -582,6 +530,37 @@ class Pymc3(MCMCSampler): self.calc_likelihood_count() return self.result + def _create_args_and_nuts_kwargs(self, curmethod, nuts_kwargs, step_kwargs): + if curmethod == 'nuts': + args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs) + else: + args = step_kwargs.get(curmethod, {}) + return args, nuts_kwargs + + def _create_nuts_kwargs(self, curmethod, key, nuts_kwargs, pymc3, step_kwargs, step_methods): + if curmethod == 'nuts': + args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs) + else: + if step_kwargs is not None: + args = step_kwargs.get(curmethod, {}) + else: + args = {} + self.kwargs['step'].append( + pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args)) + return nuts_kwargs + + @staticmethod + def _get_nuts_args(nuts_kwargs, step_kwargs): + if nuts_kwargs is not None: + args = nuts_kwargs + elif step_kwargs is not None: + args = step_kwargs.pop('nuts', {}) + # add values into nuts_kwargs + nuts_kwargs = args + else: + args = {} + return args, nuts_kwargs + def set_prior(self): """ Set the PyMC3 prior distributions. diff --git a/bilby/core/utils.py b/bilby/core/utils.py index 71b0b5028..01bd97dca 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -131,6 +131,15 @@ def _infer_args_from_function_except_for_first_arg(func): return infer_args_from_function_except_n_args(func=func, n=1) +def get_dict_with_properties(obj): + property_names = [p for p in dir(obj.__class__) + if isinstance(getattr(obj.__class__, p), property)] + dict_with_properties = obj.__dict__.copy() + for key in property_names: + dict_with_properties[key] = getattr(obj, key) + return dict_with_properties + + def get_sampling_frequency(time_array): """ Calculate sampling frequency from a time series @@ -1022,6 +1031,41 @@ def encode_astropy_quantity(dct): return dct +def move_old_file(filename, overwrite=False): + """ Moves or removes an old file. + + Parameters + ---------- + filename: str + Name of the file to be move + overwrite: bool, optional + Whether or not to remove the file or to change the name + to filename + '.old' + """ + if os.path.isfile(filename): + if overwrite: + logger.debug('Removing existing file {}'.format(filename)) + os.remove(filename) + else: + logger.debug( + 'Renaming existing file {} to {}.old'.format(filename, + filename)) + os.rename(filename, filename + '.old') + logger.debug("Saving result to {}".format(filename)) + + +def load_json(filename, gzip): + if gzip or os.path.splitext(filename)[1].lstrip('.') == 'gz': + import gzip + with gzip.GzipFile(filename, 'r') as file: + json_str = file.read().decode('utf-8') + dictionary = json.loads(json_str, object_hook=decode_bilby_json) + else: + with open(filename, 'r') as file: + dictionary = json.load(file, object_hook=decode_bilby_json) + return dictionary + + def decode_bilby_json(dct): if dct.get("__prior_dict__", False): cls = getattr(import_module(dct['__module__']), dct['__name__']) diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index 786262481..bdfb8de54 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -439,14 +439,7 @@ class GravitationalWaveTransient(Likelihood): if signal_polarizations is None: signal_polarizations = \ self.waveform_generator.frequency_domain_strain(self.parameters) - d_inner_h = 0 - h_inner_h = 0 - for interferometer in self.interferometers: - per_detector_snr = self.calculate_snrs( - signal_polarizations, interferometer) - - d_inner_h += per_detector_snr.d_inner_h - h_inner_h += per_detector_snr.optimal_snr_squared + d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations) d_inner_h_dist = ( d_inner_h * self.parameters['luminosity_distance'] / @@ -472,6 +465,17 @@ class GravitationalWaveTransient(Likelihood): self._rescale_signal(signal_polarizations, new_distance) return new_distance + def _calculate_inner_products(self, signal_polarizations): + d_inner_h = 0 + h_inner_h = 0 + for interferometer in self.interferometers: + per_detector_snr = self.calculate_snrs( + signal_polarizations, interferometer) + + d_inner_h += per_detector_snr.d_inner_h + h_inner_h += per_detector_snr.optimal_snr_squared + return d_inner_h, h_inner_h + def generate_phase_sample_from_marginalized_likelihood( self, signal_polarizations=None): """ @@ -497,14 +501,7 @@ class GravitationalWaveTransient(Likelihood): if signal_polarizations is None: signal_polarizations = \ self.waveform_generator.frequency_domain_strain(self.parameters) - d_inner_h = 0 - h_inner_h = 0 - for interferometer in self.interferometers: - per_detector_snr = self.calculate_snrs( - signal_polarizations, interferometer) - - d_inner_h += per_detector_snr.d_inner_h - h_inner_h += per_detector_snr.optimal_snr_squared + d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations) phases = np.linspace(0, 2 * np.pi, 101) phasor = np.exp(-2j * phases) diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 5feb3d51d..dc5ea7964 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -717,28 +717,15 @@ def lalsim_SimInspiralFD( approximant: int, str """ - [mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, - luminosity_distance, iota, phase, longitude_ascending_nodes, - eccentricity, mean_per_ano, delta_frequency, minimum_frequency, - maximum_frequency, reference_frequency] = convert_args_list_to_float( + args = convert_args_list_to_float( mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, luminosity_distance, iota, phase, longitude_ascending_nodes, eccentricity, mean_per_ano, delta_frequency, minimum_frequency, maximum_frequency, reference_frequency) - if isinstance(approximant, int): - pass - elif isinstance(approximant, str): - approximant = lalsim_GetApproximantFromString(approximant) - else: - raise ValueError("approximant not an int") + approximant = _get_lalsim_approximant(approximant) - return lalsim.SimInspiralFD( - mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, - spin_2z, luminosity_distance, iota, phase, - longitude_ascending_nodes, eccentricity, mean_per_ano, delta_frequency, - minimum_frequency, maximum_frequency, reference_frequency, - waveform_dictionary, approximant) + return lalsim.SimInspiralFD(*args, waveform_dictionary, approximant) def lalsim_SimInspiralChooseFDWaveform( @@ -774,28 +761,25 @@ def lalsim_SimInspiralChooseFDWaveform( approximant: int, str """ - [mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, - luminosity_distance, iota, phase, longitude_ascending_nodes, - eccentricity, mean_per_ano, delta_frequency, minimum_frequency, - maximum_frequency, reference_frequency] = convert_args_list_to_float( + args = convert_args_list_to_float( mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, luminosity_distance, iota, phase, longitude_ascending_nodes, eccentricity, mean_per_ano, delta_frequency, minimum_frequency, maximum_frequency, reference_frequency) + approximant = _get_lalsim_approximant(approximant) + + return lalsim.SimInspiralChooseFDWaveform(*args, waveform_dictionary, approximant) + + +def _get_lalsim_approximant(approximant): if isinstance(approximant, int): pass elif isinstance(approximant, str): approximant = lalsim_GetApproximantFromString(approximant) else: raise ValueError("approximant not an int") - - return lalsim.SimInspiralChooseFDWaveform( - mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, - spin_2z, luminosity_distance, iota, phase, - longitude_ascending_nodes, eccentricity, mean_per_ano, delta_frequency, - minimum_frequency, maximum_frequency, reference_frequency, - waveform_dictionary, approximant) + return approximant def lalsim_SimInspiralChooseFDWaveformSequence( -- GitLab