From 900b09d231bd579c03afd658cdd8efb6d8942cf5 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Fri, 2 Aug 2019 14:11:44 +1000
Subject: [PATCH 1/2] allow users to specify external source functions

---
 bilby_pipe/data_analysis.py   |  2 +-
 bilby_pipe/data_generation.py |  4 ++--
 bilby_pipe/input.py           | 12 +++++++++++-
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/bilby_pipe/data_analysis.py b/bilby_pipe/data_analysis.py
index b804dfaa..2307d910 100644
--- a/bilby_pipe/data_analysis.py
+++ b/bilby_pipe/data_analysis.py
@@ -246,7 +246,7 @@ class DataAnalysisInput(Input):
         waveform_arguments = self.get_default_waveform_arguments()
 
         if self.likelihood_type == "GravitationalWaveTransient":
-            waveform_generator = bilby.gw.WaveformGenerator(
+            waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
                 sampling_frequency=self.interferometers.sampling_frequency,
                 duration=self.interferometers.duration,
                 frequency_domain_source_model=self.bilby_frequency_domain_source_model,
diff --git a/bilby_pipe/data_generation.py b/bilby_pipe/data_generation.py
index 71a35cc0..3820dfa5 100644
--- a/bilby_pipe/data_generation.py
+++ b/bilby_pipe/data_generation.py
@@ -333,7 +333,7 @@ class DataGenerationInput(Input):
 
         waveform_arguments = self.get_default_waveform_arguments()
 
-        waveform_generator = bilby.gw.WaveformGenerator(
+        waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
             duration=self.duration,
             start_time=self.start_time,
             sampling_frequency=self.sampling_frequency,
@@ -397,7 +397,7 @@ class DataGenerationInput(Input):
 
         waveform_arguments = self.get_default_waveform_arguments()
 
-        waveform_generator = bilby.gw.WaveformGenerator(
+        waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
             duration=self.duration,
             sampling_frequency=self.sampling_frequency,
             frequency_domain_source_model=self.bilby_frequency_domain_source_model,
diff --git a/bilby_pipe/input.py b/bilby_pipe/input.py
index 3e03d6b8..ab937e80 100644
--- a/bilby_pipe/input.py
+++ b/bilby_pipe/input.py
@@ -7,6 +7,7 @@ from __future__ import division, print_function
 import os
 import glob
 import json
+from importlib import import_module
 
 import numpy as np
 import bilby
@@ -166,11 +167,20 @@ class Input(object):
 
     @property
     def bilby_frequency_domain_source_model(self):
-        """ The bilby function to pass to the waveform_generator """
+        """
+        The bilby function to pass to the waveform_generator
+
+        This can be a function defined in an external package.
+        """
         if self.frequency_domain_source_model in bilby.gw.source.__dict__.keys():
             model = self._frequency_domain_source_model
             logger.info("Using the {} source model".format(model))
             return bilby.gw.source.__dict__[model]
+        elif "." in self.frequency_domain_source_model:
+            split_model = self._frequency_domain_source_model.split(".")
+            module = '.'.join(split_model[:-1])
+            func = split_model[-1]
+            return getattr(import_module(module), func)
         else:
             raise BilbyPipeError(
                 "No source model {} found.".format(self._frequency_domain_source_model)
-- 
GitLab


From a1f34d63e07c6df3d16e8e94d3e88cc28b023c0c Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Fri, 2 Aug 2019 14:19:59 +1000
Subject: [PATCH 2/2] formatting

---
 bilby_pipe/input.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bilby_pipe/input.py b/bilby_pipe/input.py
index ab937e80..caba2c41 100644
--- a/bilby_pipe/input.py
+++ b/bilby_pipe/input.py
@@ -178,7 +178,7 @@ class Input(object):
             return bilby.gw.source.__dict__[model]
         elif "." in self.frequency_domain_source_model:
             split_model = self._frequency_domain_source_model.split(".")
-            module = '.'.join(split_model[:-1])
+            module = ".".join(split_model[:-1])
             func = split_model[-1]
             return getattr(import_module(module), func)
         else:
-- 
GitLab