From 5ef8f99f98b091704194b793edb6f9f51f8ed1e9 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 14 May 2018 22:29:26 +1000
Subject: [PATCH] Closes #62

If a label is given to setup_logger, the output is written to a file.
---
 examples/injection_examples/basic_tutorial.py |  5 ++--
 examples/open_data_examples/GW150914.py       |  6 ++---
 tupak/utils.py                                | 25 +++++++++++++++----
 3 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/examples/injection_examples/basic_tutorial.py b/examples/injection_examples/basic_tutorial.py
index d5eeb5881..c105e80d0 100644
--- a/examples/injection_examples/basic_tutorial.py
+++ b/examples/injection_examples/basic_tutorial.py
@@ -9,11 +9,12 @@ from __future__ import division, print_function
 import tupak
 import numpy as np
 
-tupak.utils.setup_logger(log_level="info")
 
 time_duration = 4.
 sampling_frequency = 2048.
 outdir = 'outdir'
+label = 'basic_tutorial'
+tupak.utils.setup_logger(outdir=outdir, label=label, log_level="info")
 
 np.random.seed(170809)
 
@@ -45,7 +46,7 @@ likelihood = tupak.likelihood.Likelihood(interferometers=IFOs, waveform_generato
 
 # Run sampler
 result = tupak.sampler.run_sampler(likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
-                                   injection_parameters=injection_parameters, outdir=outdir, label='BasicTutorial')
+                                   injection_parameters=injection_parameters, outdir=outdir, label=label)
 result.plot_corner()
 result.plot_walks()
 result.plot_distributions()
diff --git a/examples/open_data_examples/GW150914.py b/examples/open_data_examples/GW150914.py
index 760271e02..4d0271477 100644
--- a/examples/open_data_examples/GW150914.py
+++ b/examples/open_data_examples/GW150914.py
@@ -9,14 +9,14 @@ commonly used prior distributions.  This will take a few hours to run.
 from __future__ import division, print_function
 import tupak
 
-# This sets up logging output to understand what tupak is doing
-tupak.utils.setup_logger()
-
 # Define some convienence labels and the trigger time of the event
 outdir = 'outdir'
 label = 'GW150914'
 time_of_event = tupak.utils.get_event_time(label)
 
+# This sets up logging output to understand what tupak is doing
+tupak.utils.setup_logger(outdir=outdir, label=label)
+
 # Here we import the detector data. This step downloads data from the
 # LIGO/Virgo open data archives. The data is saved to an `Interferometer`
 # object (here called `H1` and `L1`). A Power Spectral Density (PSD) estimate
diff --git a/tupak/utils.py b/tupak/utils.py
index 023a0550a..eb1e20025 100644
--- a/tupak/utils.py
+++ b/tupak/utils.py
@@ -281,19 +281,17 @@ def get_vertex_position_geocentric(latitude, longitude, elevation):
     return np.array([x_comp, y_comp, z_comp])
 
 
-def setup_logger(log_level='info'):
+def setup_logger(outdir=None, label=None, log_level='info'):
     """ Setup logging output: call at the start of the script to use
 
     Parameters
     ----------
+    outdir, label: str
+        If supplied, write the logging output to outdir/label.log
     log_level = ['debug', 'info', 'warning']
         Either a string from the list above, or an interger as specified
         in https://docs.python.org/2/library/logging.html#logging-levels
     """
-    logger = logging.getLogger()
-    stream_handler = logging.StreamHandler()
-    stream_handler.setFormatter(logging.Formatter(
-        '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M'))
 
     if type(log_level) is str:
         try:
@@ -303,10 +301,27 @@ def setup_logger(log_level='info'):
     else:
         LEVEL = int(log_level)
 
+    logger = logging.getLogger()
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(logging.Formatter(
+        '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M'))
     logger.setLevel(LEVEL)
     stream_handler.setLevel(LEVEL)
     logger.addHandler(stream_handler)
 
+    if label:
+        if outdir:
+            check_directory_exists_and_if_not_mkdir(outdir)
+        else:
+            outdir = '.'
+        log_file = '{}/{}.log'.format(outdir, label)
+        file_handler = logging.FileHandler(log_file)
+        file_handler.setFormatter(logging.Formatter(
+            '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M'))
+
+        file_handler.setLevel(LEVEL)
+        logger.addHandler(file_handler)
+
 
 def get_progress_bar(module='tqdm'):
     if module in ['tqdm']:
-- 
GitLab