From 4d3d51e63775ad65770558ded0115deb1a1dc947 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Thu, 21 Jun 2018 12:44:29 +1000
Subject: [PATCH] Add optional label to data saving

Adds the option to add a label to the saved data and plots associated
with Interferometer
---
 tupak/gw/detector.py | 39 +++++++++++++++++++++++++++------------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/tupak/gw/detector.py b/tupak/gw/detector.py
index 54e20d245..bf554706d 100644
--- a/tupak/gw/detector.py
+++ b/tupak/gw/detector.py
@@ -862,7 +862,7 @@ class Interferometer(object):
         """
         return self.strain_data.frequency_domain_strain / self.amplitude_spectral_density_array
 
-    def save_data(self, outdir):
+    def save_data(self, outdir, label=None):
         """ Creates a save file for the data in plain text format
 
         Parameters
@@ -876,13 +876,17 @@ class Interferometer(object):
                         self.frequency_domain_strain.real,
                         self.frequency_domain_strain.imag]).T,
                    header='f real_h(f) imag_h(f)')
-        np.savetxt('{}/{}_psd.dat'.format(outdir, self.name),
+        if label is None:
+            filename = '{}/{}_psd.dat'.format(outdir, self.name)
+        else:
+            filename = '{}/{}_{}_psd.dat'.format(outdir, self.name, label)
+        np.savetxt(filename,
                    np.array(
                        [self.frequency_array,
                         self.amplitude_spectral_density_array]).T,
                    header='f h(f)')
 
-    def plot_data(self, signal=None, outdir='.'):
+    def plot_data(self, signal=None, outdir='.', label=None):
         fig, ax = plt.subplots()
         ax.loglog(self.frequency_array,
                   np.abs(self.frequency_domain_strain),
@@ -898,8 +902,13 @@ class Interferometer(object):
         ax.set_xlabel(r'frequency [Hz]')
         ax.set_xlim(20, 2000)
         ax.legend(loc='best')
-        fig.savefig(
-            '{}/{}_frequency_domain_data.png'.format(outdir, self.name))
+        if label is None:
+            fig.savefig(
+                '{}/{}_frequency_domain_data.png'.format(outdir, self.name))
+        else:
+            fig.savefig(
+                '{}/{}_{}_frequency_domain_data.png'.format(
+                    outdir, self.name, label))
 
 
 class PowerSpectralDensity(object):
@@ -1146,7 +1155,7 @@ def load_interferometer(filename):
 
 def get_interferometer_with_open_data(
         name, trigger_time, time_duration=4, start_time=None, alpha=0.25, psd_offset=-1024,
-        psd_duration=100, cache=True, outdir='outdir', plot=True, filter_freq=1024,
+        psd_duration=100, cache=True, outdir='outdir', label=None, plot=True, filter_freq=1024,
         raw_data_file=None, **kwargs):
     """
     Helper function to obtain an Interferometer instance with appropriate
@@ -1175,6 +1184,8 @@ def get_interferometer_with_open_data(
         Name of a raw data file if this supposed to be read from a local file
     outdir: str
         Directory where the psd files are saved
+    label: str
+        If given, an identifying label used in generating file names.
     plot: bool
         If true, create an ASD + strain plot
     filter_freq: float
@@ -1245,7 +1256,7 @@ def get_interferometer_with_open_data(
         start_time=strain.epoch.value)
 
     if plot:
-        interferometer.plot_data(outdir=outdir)
+        interferometer.plot_data(outdir=outdir, label=label)
 
     return interferometer
 
@@ -1253,7 +1264,7 @@ def get_interferometer_with_open_data(
 def get_interferometer_with_fake_noise_and_injection(
         name, injection_polarizations, injection_parameters,
         sampling_frequency=4096, time_duration=4, start_time=None,
-        outdir='outdir', plot=True, save=True, zero_noise=False):
+        outdir='outdir', label=None, plot=True, save=True, zero_noise=False):
     """
     Helper function to obtain an Interferometer instance with appropriate
     power spectral density and data, given an center_time.
@@ -1276,6 +1287,8 @@ def get_interferometer_with_fake_noise_and_injection(
         end of segment.
     outdir: str
         directory in which to store output
+    label: str
+        If given, an identifying label used in generating file names.
     plot: bool
         If true, create an ASD + strain plot
     save: bool
@@ -1311,10 +1324,10 @@ def get_interferometer_with_fake_noise_and_injection(
         injection_polarizations, injection_parameters)
 
     if plot:
-        interferometer.plot_data(signal=signal, outdir=outdir)
+        interferometer.plot_data(signal=signal, outdir=outdir, label=label)
 
     if save:
-        interferometer.save_data(outdir)
+        interferometer.save_data(outdir, label=label)
 
     return interferometer
 
@@ -1322,7 +1335,7 @@ def get_interferometer_with_fake_noise_and_injection(
 def get_event_data(
         event, interferometer_names=None, time_duration=4, alpha=0.25,
         psd_offset=-1024, psd_duration=100, cache=True, outdir='outdir',
-        plot=True, filter_freq=1024, raw_data_file=None, **kwargs):
+        label=None, plot=True, filter_freq=1024, raw_data_file=None, **kwargs):
     """
     Get open data for a specified event.
 
@@ -1347,6 +1360,8 @@ def get_event_data(
         If we want to read the event data from a local file.
     outdir: str
         Directory where the psd files are saved
+    label: str
+        If given, an identifying label used in generating file names.
     plot: bool
         If true, create an ASD + strain plot
     filter_freq: float
@@ -1374,7 +1389,7 @@ def get_event_data(
             interferometers.append(get_interferometer_with_open_data(
                 name, trigger_time=event_time, time_duration=time_duration, alpha=alpha,
                 psd_offset=psd_offset, psd_duration=psd_duration, cache=cache,
-                outdir=outdir, plot=plot, filter_freq=filter_freq,
+                outdir=outdir, label=label, plot=plot, filter_freq=filter_freq,
                 raw_data_file=raw_data_file, **kwargs))
         except ValueError:
             logging.warning('No data found for {}.'.format(name))
-- 
GitLab