From f709f23ab7e59665eb2bda553e4382b40cda3540 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Tue, 26 Jun 2018 14:57:50 +1000
Subject: [PATCH] Move plot_walkers to a results method

Moves the plot_walker method to the Results object. The benefit here is
that we can generate walker plots after the run and mitigate the chances
of a plot error spoiling a run.
---
 tupak/core/result.py  | 26 ++++++++++++++++++++++++++
 tupak/core/sampler.py | 31 +++----------------------------
 2 files changed, 29 insertions(+), 28 deletions(-)

diff --git a/tupak/core/result.py b/tupak/core/result.py
index 6e6a6ddab..a515ffb94 100644
--- a/tupak/core/result.py
+++ b/tupak/core/result.py
@@ -5,6 +5,7 @@ import deepdish
 import pandas as pd
 import corner
 import matplotlib
+import matplotlib.pyplot as plt
 
 
 def result_file_name(outdir, label):
@@ -259,6 +260,31 @@ class Result(dict):
 
         return fig
 
+    def plot_walkers(self, save=True, **kwargs):
+        """ Method to plot the trace of the walkers in an ensmble MCMC plot """
+        if hasattr(self, 'walkers') is False:
+            logging.warning("Cannot plot_walkers as no walkers are saved")
+            return
+
+        nwalkers, nsteps, ndim = self.walkers.shape
+        idxs = np.arange(nsteps)
+        fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3*ndim))
+        walkers = self.walkers[:, :, :]
+        for i, ax in enumerate(axes):
+            ax.plot(idxs[:self.nburn+1], walkers[:, :self.nburn+1, i].T,
+                    lw=0.1, color='r')
+            ax.set_ylabel(self.parameter_labels[i])
+
+        for i, ax in enumerate(axes):
+            ax.plot(idxs[self.nburn:], walkers[:, self.nburn:, i].T, lw=0.1,
+                    color='k')
+            ax.set_ylabel(self.parameter_labels[i])
+
+        fig.tight_layout()
+        filename = '{}/{}_walkers.png'.format(self.outdir, self.label)
+        logging.debug('Saving walkers plot to {}'.format('filename'))
+        fig.savefig(filename)
+
     def plot_walks(self, save=True, **kwargs):
         """DEPRECATED"""
         logging.warning("plot_walks deprecated")
diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py
index 70835ea9e..50646a6bb 100644
--- a/tupak/core/sampler.py
+++ b/tupak/core/sampler.py
@@ -5,7 +5,6 @@ import logging
 import os
 import sys
 import numpy as np
-import matplotlib.pyplot as plt
 import datetime
 import deepdish
 
@@ -775,10 +774,10 @@ class Emcee(Sampler):
         self.result.samples = sampler.chain[:, self.nburn:, :].reshape(
             (-1, self.ndim))
         self.result.walkers = sampler.chain[:, :, :]
+        self.result.nburn = self.nburn
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
-        if self.plot:
-            self.plot_walkers()
+
         try:
             logging.info("Max autocorr time = {}".format(
                          np.max(sampler.get_autocorr_time())))
@@ -789,29 +788,6 @@ class Emcee(Sampler):
     def lnpostfn(self, theta):
         return self.log_likelihood(theta) + self.log_prior(theta)
 
-    def _get_walkers_to_plot(self):
-        return self.result.walkers[:, :, :]
-
-    def plot_walkers(self, save=True, **kwargs):
-        nwalkers, nsteps, ndim = self.result.walkers.shape
-        idxs = np.arange(nsteps)
-        fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3*self.ndim))
-        walkers = self._get_walkers_to_plot()
-        for i, ax in enumerate(axes):
-            ax.plot(idxs[:self.nburn+1], walkers[:, :self.nburn+1, i].T,
-                    lw=0.1, color='r')
-            ax.set_ylabel(self.result.parameter_labels[i])
-
-        for i, ax in enumerate(axes):
-            ax.plot(idxs[self.nburn:], walkers[:, self.nburn:, i].T, lw=0.1,
-                    color='k')
-            ax.set_ylabel(self.result.parameter_labels[i])
-
-        fig.tight_layout()
-        filename = '{}/{}_walkers.png'.format(self.outdir, self.label)
-        logging.debug('Saving walkers plot to {}'.format('filename'))
-        fig.savefig(filename)
-
 
 class Ptemcee(Emcee):
     """ https://github.com/willvousden/ptemcee """
@@ -843,8 +819,7 @@ class Ptemcee(Emcee):
         self.result.walkers = sampler.chain[0, :, :, :]
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
-        if self.plot:
-            self.plot_walkers()
+
         logging.info("Max autocorr time = {}"
                      .format(np.max(sampler.get_autocorr_time())))
         logging.info("Tswap frac = {}"
-- 
GitLab