From 0bb1113d0d1c50e2a0b37e05c92bd49d8d0be00e Mon Sep 17 00:00:00 2001
From: Moritz Huebner <moritz.huebner@ligo.org>
Date: Tue, 28 Apr 2020 01:14:46 -0500
Subject: [PATCH] Reduce redundant code

---
 bilby/core/grid.py                    | 32 ++-------
 bilby/core/prior/base.py              | 14 ++--
 bilby/core/prior/joint.py             |  8 +--
 bilby/core/result.py                  | 23 +------
 bilby/core/sampler/dynamic_dynesty.py | 16 +----
 bilby/core/sampler/dynesty.py         | 19 +++---
 bilby/core/sampler/emcee.py           | 18 +++--
 bilby/core/sampler/kombine.py         | 18 ++---
 bilby/core/sampler/pymc3.py           | 95 +++++++++++----------------
 bilby/core/utils.py                   | 44 +++++++++++++
 bilby/gw/likelihood.py                | 29 ++++----
 bilby/gw/utils.py                     | 38 ++++-------
 12 files changed, 148 insertions(+), 206 deletions(-)

diff --git a/bilby/core/grid.py b/bilby/core/grid.py
index ee09e2fec..42fb58dd0 100644
--- a/bilby/core/grid.py
+++ b/bilby/core/grid.py
@@ -8,7 +8,7 @@ from collections import OrderedDict
 from .prior import Prior, PriorDict
 from .utils import (logtrapzexp, check_directory_exists_and_if_not_mkdir,
                     logger)
-from .utils import BilbyJsonEncoder, decode_bilby_json
+from .utils import BilbyJsonEncoder, load_json, move_old_file
 from .result import FileMovedError
 
 
@@ -397,18 +397,7 @@ class Grid(object):
 
             filename = grid_file_name(outdir, self.label, gzip)
 
-        if os.path.isfile(filename):
-            if overwrite:
-                logger.debug('Removing existing file {}'.format(filename))
-                os.remove(filename)
-            else:
-                logger.debug(
-                    'Renaming existing file {} to {}.old'.format(filename,
-                                                                 filename))
-                os.rename(filename, filename + '.old')
-
-        logger.debug("Saving result to {}".format(filename))
-
+        move_old_file(filename, overwrite)
         dictionary = self._get_save_data_dictionary()
 
         try:
@@ -452,23 +441,14 @@ class Grid(object):
 
         """
 
-        if filename is not None:
-            fname = filename
-        else:
+        if filename is None:
             if (outdir is None) and (label is None):
                 raise ValueError("No information given to load file")
             else:
-                fname = grid_file_name(outdir, label, gzip)
+                filename = grid_file_name(outdir, label, gzip)
 
-        if os.path.isfile(fname):
-            if gzip or os.path.splitext(fname)[1].lstrip('.') == 'gz':
-                import gzip
-                with gzip.GzipFile(fname, 'r') as file:
-                    json_str = file.read().decode('utf-8')
-                dictionary = json.loads(json_str, object_hook=decode_bilby_json)
-            else:
-                with open(fname, 'r') as file:
-                    dictionary = json.load(file, object_hook=decode_bilby_json)
+        if os.path.isfile(filename):
+            dictionary = load_json(filename, gzip)
             try:
                 grid = cls(likelihood=None, priors=dictionary['priors'],
                            grid_size=dictionary['sample_points'],
diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py
index 29efe925d..db5880da1 100644
--- a/bilby/core/prior/base.py
+++ b/bilby/core/prior/base.py
@@ -8,7 +8,8 @@ import scipy.stats
 from scipy.integrate import cumtrapz
 from scipy.interpolate import interp1d
 
-from bilby.core.utils import infer_args_from_method, BilbyJsonEncoder, decode_bilby_json, logger
+from bilby.core.utils import infer_args_from_method, BilbyJsonEncoder, decode_bilby_json, logger, \
+    get_dict_with_properties
 
 
 class Prior(object):
@@ -280,15 +281,8 @@ class Prior(object):
 
     def get_instantiation_dict(self):
         subclass_args = infer_args_from_method(self.__init__)
-        property_names = [p for p in dir(self.__class__)
-                          if isinstance(getattr(self.__class__, p), property)]
-        dict_with_properties = self.__dict__.copy()
-        for key in property_names:
-            dict_with_properties[key] = getattr(self, key)
-        instantiation_dict = dict()
-        for key in subclass_args:
-            instantiation_dict[key] = dict_with_properties[key]
-        return instantiation_dict
+        dict_with_properties = get_dict_with_properties(self)
+        return {key: dict_with_properties[key] for key in subclass_args}
 
     @property
     def boundary(self):
diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py
index 613183f87..bac8fceb5 100644
--- a/bilby/core/prior/joint.py
+++ b/bilby/core/prior/joint.py
@@ -3,7 +3,7 @@ import scipy.stats
 from scipy.special import erfinv
 
 from .base import Prior, PriorException
-from bilby.core.utils import logger, infer_args_from_method
+from bilby.core.utils import logger, infer_args_from_method, get_dict_with_properties
 
 
 class BaseJointPriorDist(object):
@@ -105,11 +105,7 @@ class BaseJointPriorDist(object):
 
     def get_instantiation_dict(self):
         subclass_args = infer_args_from_method(self.__init__)
-        property_names = [p for p in dir(self.__class__)
-                          if isinstance(getattr(self.__class__, p), property)]
-        dict_with_properties = self.__dict__.copy()
-        for key in property_names:
-            dict_with_properties[key] = getattr(self, key)
+        dict_with_properties = get_dict_with_properties(self)
         instantiation_dict = dict()
         for key in subclass_args:
             if isinstance(dict_with_properties[key], list):
diff --git a/bilby/core/result.py b/bilby/core/result.py
index 5bc278e58..8236f24b8 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -22,7 +22,7 @@ from .utils import (
     check_directory_exists_and_if_not_mkdir,
     latex_plot_format, safe_save_figure,
 )
-from .utils import BilbyJsonEncoder, decode_bilby_json
+from .utils import BilbyJsonEncoder, load_json, move_old_file
 from .prior import Prior, PriorDict, DeltaFunction
 
 
@@ -259,14 +259,7 @@ class Result(object):
         filename = _determine_file_name(filename, outdir, label, 'json', gzip)
 
         if os.path.isfile(filename):
-            if gzip or os.path.splitext(filename)[1].lstrip('.') == 'gz':
-                import gzip
-                with gzip.GzipFile(filename, 'r') as file:
-                    json_str = file.read().decode('utf-8')
-                dictionary = json.loads(json_str, object_hook=decode_bilby_json)
-            else:
-                with open(filename, 'r') as file:
-                    dictionary = json.load(file, object_hook=decode_bilby_json)
+            dictionary = load_json(filename, gzip)
             try:
                 return cls(**dictionary)
             except TypeError as e:
@@ -468,17 +461,7 @@ class Result(object):
         if filename is None:
             filename = result_file_name(outdir, self.label, extension, gzip)
 
-        if os.path.isfile(filename):
-            if overwrite:
-                logger.debug('Removing existing file {}'.format(filename))
-                os.remove(filename)
-            else:
-                logger.debug(
-                    'Renaming existing file {} to {}.old'.format(filename,
-                                                                 filename))
-                os.rename(filename, filename + '.old')
-
-        logger.debug("Saving result to {}".format(filename))
+        move_old_file(filename, overwrite)
 
         # Convert the prior to a string representation for saving on disk
         dictionary = self._get_save_data_dictionary()
diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py
index 4e7e1f56d..48d726e4a 100644
--- a/bilby/core/sampler/dynamic_dynesty.py
+++ b/bilby/core/sampler/dynamic_dynesty.py
@@ -5,7 +5,6 @@ import dill as pickle
 import signal
 
 import numpy as np
-from pandas import DataFrame
 
 from ..utils import logger, check_directory_exists_and_if_not_mkdir
 from .base_sampler import Sampler
@@ -139,20 +138,7 @@ class DynamicDynesty(Dynesty):
             print("")
 
         # self.result.sampler_output = out
-        weights = np.exp(out['logwt'] - out['logz'][-1])
-        nested_samples = DataFrame(
-            out.samples, columns=self.search_parameter_keys)
-        nested_samples['weights'] = weights
-        nested_samples['log_likelihood'] = out.logl
-
-        self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
-        self.result.nested_samples = nested_samples
-        self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
-            unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples,
-            sorted_samples=self.result.samples)
-        self.result.log_evidence = out.logz[-1]
-        self.result.log_evidence_err = out.logzerr[-1]
-
+        self._generate_result(out)
         if self.plot:
             self.generate_trace_plots(out)
 
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 688efc417..9c2c6683c 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -357,26 +357,29 @@ class Dynesty(NestedSampler):
         with open(dynesty_result, 'wb') as file:
             pickle.dump(out, file)
 
+        self._generate_result(out)
+        self.calc_likelihood_count()
+        self.result.sampling_time = self.sampling_time
+
+        if self.plot:
+            self.generate_trace_plots(out)
+
+        return self.result
+
+    def _generate_result(self, out):
+        import dynesty
         weights = np.exp(out['logwt'] - out['logz'][-1])
         nested_samples = DataFrame(
             out.samples, columns=self.search_parameter_keys)
         nested_samples['weights'] = weights
         nested_samples['log_likelihood'] = out.logl
-
         self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
         self.result.nested_samples = nested_samples
         self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
             unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples,
             sorted_samples=self.result.samples)
-        self.calc_likelihood_count()
         self.result.log_evidence = out.logz[-1]
         self.result.log_evidence_err = out.logzerr[-1]
-        self.result.sampling_time = self.sampling_time
-
-        if self.plot:
-            self.generate_trace_plots(out)
-
-        return self.result
 
     def _run_nested_wrapper(self, kwargs):
         """ Wrapper function to run_nested
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 144bede6d..33f21fcf9 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -373,20 +373,26 @@ class Emcee(MCMCSampler):
         self.calculate_autocorrelation(
             self.sampler.chain.reshape((-1, self.ndim)))
         self.print_nburn_logging_info()
-        self.calc_likelihood_count()
+
+        self._generate_result()
+
+        self.result.samples = self.sampler.chain[:, self.nburn:, :].reshape(
+            (-1, self.ndim))
+        self.result.walkers = self.sampler.chain
+        return self.result
+
+    def _generate_result(self):
         self.result.nburn = self.nburn
+        self.calc_likelihood_count()
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
-                "`nburn < nsteps`. Try increasing the number of steps.")
-        self.result.samples = self.sampler.chain[:, self.nburn:, :].reshape(
-            (-1, self.ndim))
+                "`nburn < nsteps` ({} < {}). Try increasing the "
+                "number of steps.".format(self.result.nburn, self.nsteps))
         blobs = np.array(self.sampler.blobs)
         blobs_trimmed = blobs[self.nburn:, :, :].reshape((-1, 2))
         log_likelihoods, log_priors = blobs_trimmed.T
         self.result.log_likelihood_evaluations = log_likelihoods
         self.result.log_prior_evaluations = log_priors
-        self.result.walkers = self.sampler.chain
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
-        return self.result
diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py
index 48e85342a..f7c7768ec 100644
--- a/bilby/core/sampler/kombine.py
+++ b/bilby/core/sampler/kombine.py
@@ -3,7 +3,6 @@ from ..utils import logger, get_progress_bar
 import numpy as np
 import os
 from .emcee import Emcee
-from .base_sampler import SamplerError
 
 
 class Kombine(Emcee):
@@ -157,20 +156,11 @@ class Kombine(Emcee):
             tmp_chain = self.sampler.chain.copy()
             self.calculate_autocorrelation(tmp_chain.reshape((-1, self.ndim)))
             self.print_nburn_logging_info()
-        self.result.nburn = self.nburn
-        if self.result.nburn > self.nsteps:
-            raise SamplerError(
-                "The run has finished, but the chain is not burned in: `nburn < nsteps` ({} < {}). Try increasing the "
-                "number of steps.".format(self.result.nburn, self.nsteps))
+
+        self._generate_result()
+        self.result.log_evidence_err = np.nan
+
         tmp_chain = self.sampler.chain[self.nburn:, :, :].copy()
         self.result.samples = tmp_chain.reshape((-1, self.ndim))
-        blobs = np.array(self.sampler.blobs)
-        blobs_trimmed = blobs[self.nburn:, :, :].reshape((-1, 2))
-        self.calc_likelihood_count()
-        log_likelihoods, log_priors = blobs_trimmed.T
-        self.result.log_likelihood_evaluations = log_likelihoods
-        self.result.log_prior_evaluations = log_priors
         self.result.walkers = self.sampler.chain.reshape((self.nwalkers, self.nsteps, self.ndim))
-        self.result.log_evidence = np.nan
-        self.result.log_evidence_err = np.nan
         return self.result
diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
index 7a8745c8f..824f294ce 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc3.py
@@ -469,43 +469,13 @@ class Pymc3(MCMCSampler):
                         for sms in self.step_method[key]:
                             curmethod = sms.lower()
                             methodslist.append(curmethod)
-                            args = {}
-                            if curmethod == 'nuts':
-                                if nuts_kwargs is not None:
-                                    args = nuts_kwargs
-                                elif step_kwargs is not None:
-                                    args = step_kwargs.pop('nuts', {})
-                                    # add values into nuts_kwargs
-                                    nuts_kwargs = args
-                                else:
-                                    args = {}
-                            else:
-                                if step_kwargs is not None:
-                                    args = step_kwargs.get(curmethod, {})
-                                else:
-                                    args = {}
-                            self.kwargs['step'].append(
-                                pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
+                            nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs,
+                                                                   step_methods)
                     else:
                         curmethod = self.step_method[key].lower()
                         methodslist.append(curmethod)
-                        args = {}
-                        if curmethod == 'nuts':
-                            if nuts_kwargs is not None:
-                                args = nuts_kwargs
-                            elif step_kwargs is not None:
-                                args = step_kwargs.pop('nuts', {})
-                                # add values into nuts_kwargs
-                                nuts_kwargs = args
-                            else:
-                                args = {}
-                        else:
-                            if step_kwargs is not None:
-                                args = step_kwargs.get(curmethod, {})
-                            else:
-                                args = {}
-                        self.kwargs['step'].append(
-                            pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
+                        nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs,
+                                                               step_methods)
         else:
             with self.pymc3_model:
                 # check for a compound step list
@@ -514,18 +484,7 @@ class Pymc3(MCMCSampler):
                     for sms in self.step_method:
                         curmethod = sms.lower()
                         methodslist.append(curmethod)
-                        args = {}
-                        if curmethod == 'nuts':
-                            if nuts_kwargs is not None:
-                                args = nuts_kwargs
-                            elif step_kwargs is not None:
-                                args = step_kwargs.pop('nuts', {})
-                                # add values into nuts_kwargs
-                                nuts_kwargs = args
-                            else:
-                                args = {}
-                        else:
-                            args = step_kwargs.get(curmethod, {})
+                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs)
                         compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
                         self.kwargs['step'] = compound
                 else:
@@ -533,18 +492,7 @@ class Pymc3(MCMCSampler):
                     if self.step_method is not None:
                         curmethod = self.step_method.lower()
                         methodslist.append(curmethod)
-                        args = {}
-                        if curmethod == 'nuts':
-                            if nuts_kwargs is not None:
-                                args = nuts_kwargs
-                            elif step_kwargs is not None:
-                                args = step_kwargs.pop('nuts', {})
-                                # add values into nuts_kwargs
-                                nuts_kwargs = args
-                            else:
-                                args = {}
-                        else:
-                            args = step_kwargs.get(curmethod, {})
+                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs)
                         self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
                     else:
                         # re-add step_kwargs if no step methods are set
@@ -582,6 +530,37 @@ class Pymc3(MCMCSampler):
         self.calc_likelihood_count()
         return self.result
 
+    def _create_args_and_nuts_kwargs(self, curmethod, nuts_kwargs, step_kwargs):
+        if curmethod == 'nuts':
+            args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs)
+        else:
+            args = step_kwargs.get(curmethod, {})
+        return args, nuts_kwargs
+
+    def _create_nuts_kwargs(self, curmethod, key, nuts_kwargs, pymc3, step_kwargs, step_methods):
+        if curmethod == 'nuts':
+            args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs)
+        else:
+            if step_kwargs is not None:
+                args = step_kwargs.get(curmethod, {})
+            else:
+                args = {}
+        self.kwargs['step'].append(
+            pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
+        return nuts_kwargs
+
+    @staticmethod
+    def _get_nuts_args(nuts_kwargs, step_kwargs):
+        if nuts_kwargs is not None:
+            args = nuts_kwargs
+        elif step_kwargs is not None:
+            args = step_kwargs.pop('nuts', {})
+            # add values into nuts_kwargs
+            nuts_kwargs = args
+        else:
+            args = {}
+        return args, nuts_kwargs
+
     def set_prior(self):
         """
         Set the PyMC3 prior distributions.
diff --git a/bilby/core/utils.py b/bilby/core/utils.py
index 71b0b5028..01bd97dca 100644
--- a/bilby/core/utils.py
+++ b/bilby/core/utils.py
@@ -131,6 +131,15 @@ def _infer_args_from_function_except_for_first_arg(func):
     return infer_args_from_function_except_n_args(func=func, n=1)
 
 
+def get_dict_with_properties(obj):
+    property_names = [p for p in dir(obj.__class__)
+                      if isinstance(getattr(obj.__class__, p), property)]
+    dict_with_properties = obj.__dict__.copy()
+    for key in property_names:
+        dict_with_properties[key] = getattr(obj, key)
+    return dict_with_properties
+
+
 def get_sampling_frequency(time_array):
     """
     Calculate sampling frequency from a time series
@@ -1022,6 +1031,41 @@ def encode_astropy_quantity(dct):
     return dct
 
 
+def move_old_file(filename, overwrite=False):
+    """ Moves or removes an old file.
+
+    Parameters
+    ----------
+    filename: str
+        Name of the file to be move
+    overwrite: bool, optional
+        Whether or not to remove the file or to change the name
+        to filename + '.old'
+    """
+    if os.path.isfile(filename):
+        if overwrite:
+            logger.debug('Removing existing file {}'.format(filename))
+            os.remove(filename)
+        else:
+            logger.debug(
+                'Renaming existing file {} to {}.old'.format(filename,
+                                                             filename))
+            os.rename(filename, filename + '.old')
+    logger.debug("Saving result to {}".format(filename))
+
+
+def load_json(filename, gzip):
+    if gzip or os.path.splitext(filename)[1].lstrip('.') == 'gz':
+        import gzip
+        with gzip.GzipFile(filename, 'r') as file:
+            json_str = file.read().decode('utf-8')
+        dictionary = json.loads(json_str, object_hook=decode_bilby_json)
+    else:
+        with open(filename, 'r') as file:
+            dictionary = json.load(file, object_hook=decode_bilby_json)
+    return dictionary
+
+
 def decode_bilby_json(dct):
     if dct.get("__prior_dict__", False):
         cls = getattr(import_module(dct['__module__']), dct['__name__'])
diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index 786262481..bdfb8de54 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -439,14 +439,7 @@ class GravitationalWaveTransient(Likelihood):
         if signal_polarizations is None:
             signal_polarizations = \
                 self.waveform_generator.frequency_domain_strain(self.parameters)
-        d_inner_h = 0
-        h_inner_h = 0
-        for interferometer in self.interferometers:
-            per_detector_snr = self.calculate_snrs(
-                signal_polarizations, interferometer)
-
-            d_inner_h += per_detector_snr.d_inner_h
-            h_inner_h += per_detector_snr.optimal_snr_squared
+        d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations)
 
         d_inner_h_dist = (
             d_inner_h * self.parameters['luminosity_distance'] /
@@ -472,6 +465,17 @@ class GravitationalWaveTransient(Likelihood):
         self._rescale_signal(signal_polarizations, new_distance)
         return new_distance
 
+    def _calculate_inner_products(self, signal_polarizations):
+        d_inner_h = 0
+        h_inner_h = 0
+        for interferometer in self.interferometers:
+            per_detector_snr = self.calculate_snrs(
+                signal_polarizations, interferometer)
+
+            d_inner_h += per_detector_snr.d_inner_h
+            h_inner_h += per_detector_snr.optimal_snr_squared
+        return d_inner_h, h_inner_h
+
     def generate_phase_sample_from_marginalized_likelihood(
             self, signal_polarizations=None):
         """
@@ -497,14 +501,7 @@ class GravitationalWaveTransient(Likelihood):
         if signal_polarizations is None:
             signal_polarizations = \
                 self.waveform_generator.frequency_domain_strain(self.parameters)
-        d_inner_h = 0
-        h_inner_h = 0
-        for interferometer in self.interferometers:
-            per_detector_snr = self.calculate_snrs(
-                signal_polarizations, interferometer)
-
-            d_inner_h += per_detector_snr.d_inner_h
-            h_inner_h += per_detector_snr.optimal_snr_squared
+        d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations)
 
         phases = np.linspace(0, 2 * np.pi, 101)
         phasor = np.exp(-2j * phases)
diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py
index 5feb3d51d..dc5ea7964 100644
--- a/bilby/gw/utils.py
+++ b/bilby/gw/utils.py
@@ -717,28 +717,15 @@ def lalsim_SimInspiralFD(
     approximant: int, str
     """
 
-    [mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z,
-     luminosity_distance, iota, phase, longitude_ascending_nodes,
-     eccentricity, mean_per_ano, delta_frequency, minimum_frequency,
-     maximum_frequency, reference_frequency] = convert_args_list_to_float(
+    args = convert_args_list_to_float(
         mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z,
         luminosity_distance, iota, phase, longitude_ascending_nodes,
         eccentricity, mean_per_ano, delta_frequency, minimum_frequency,
         maximum_frequency, reference_frequency)
 
-    if isinstance(approximant, int):
-        pass
-    elif isinstance(approximant, str):
-        approximant = lalsim_GetApproximantFromString(approximant)
-    else:
-        raise ValueError("approximant not an int")
+    approximant = _get_lalsim_approximant(approximant)
 
-    return lalsim.SimInspiralFD(
-        mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y,
-        spin_2z, luminosity_distance, iota, phase,
-        longitude_ascending_nodes, eccentricity, mean_per_ano, delta_frequency,
-        minimum_frequency, maximum_frequency, reference_frequency,
-        waveform_dictionary, approximant)
+    return lalsim.SimInspiralFD(*args, waveform_dictionary, approximant)
 
 
 def lalsim_SimInspiralChooseFDWaveform(
@@ -774,28 +761,25 @@ def lalsim_SimInspiralChooseFDWaveform(
     approximant: int, str
     """
 
-    [mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z,
-     luminosity_distance, iota, phase, longitude_ascending_nodes,
-     eccentricity, mean_per_ano, delta_frequency, minimum_frequency,
-     maximum_frequency, reference_frequency] = convert_args_list_to_float(
+    args = convert_args_list_to_float(
         mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z,
         luminosity_distance, iota, phase, longitude_ascending_nodes,
         eccentricity, mean_per_ano, delta_frequency, minimum_frequency,
         maximum_frequency, reference_frequency)
 
+    approximant = _get_lalsim_approximant(approximant)
+
+    return lalsim.SimInspiralChooseFDWaveform(*args, waveform_dictionary, approximant)
+
+
+def _get_lalsim_approximant(approximant):
     if isinstance(approximant, int):
         pass
     elif isinstance(approximant, str):
         approximant = lalsim_GetApproximantFromString(approximant)
     else:
         raise ValueError("approximant not an int")
-
-    return lalsim.SimInspiralChooseFDWaveform(
-        mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y,
-        spin_2z, luminosity_distance, iota, phase,
-        longitude_ascending_nodes, eccentricity, mean_per_ano, delta_frequency,
-        minimum_frequency, maximum_frequency, reference_frequency,
-        waveform_dictionary, approximant)
+    return approximant
 
 
 def lalsim_SimInspiralChooseFDWaveformSequence(
-- 
GitLab