Commit f709f23a authored by Gregory Ashton's avatar Gregory Ashton

Move plot_walkers to a results method

Moves the plot_walker method to the Results object. The benefit here is
that we can generate walker plots after the run and mitigate the chances
of a plot error spoiling a run.
parent 560dd832
......@@ -5,6 +5,7 @@ import deepdish
import pandas as pd
import corner
import matplotlib
import matplotlib.pyplot as plt
def result_file_name(outdir, label):
......@@ -259,6 +260,31 @@ class Result(dict):
return fig
def plot_walkers(self, save=True, **kwargs):
""" Method to plot the trace of the walkers in an ensmble MCMC plot """
if hasattr(self, 'walkers') is False:
logging.warning("Cannot plot_walkers as no walkers are saved")
return
nwalkers, nsteps, ndim = self.walkers.shape
idxs = np.arange(nsteps)
fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3*ndim))
walkers = self.walkers[:, :, :]
for i, ax in enumerate(axes):
ax.plot(idxs[:self.nburn+1], walkers[:, :self.nburn+1, i].T,
lw=0.1, color='r')
ax.set_ylabel(self.parameter_labels[i])
for i, ax in enumerate(axes):
ax.plot(idxs[self.nburn:], walkers[:, self.nburn:, i].T, lw=0.1,
color='k')
ax.set_ylabel(self.parameter_labels[i])
fig.tight_layout()
filename = '{}/{}_walkers.png'.format(self.outdir, self.label)
logging.debug('Saving walkers plot to {}'.format('filename'))
fig.savefig(filename)
def plot_walks(self, save=True, **kwargs):
"""DEPRECATED"""
logging.warning("plot_walks deprecated")
......
......@@ -5,7 +5,6 @@ import logging
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import datetime
import deepdish
......@@ -775,10 +774,10 @@ class Emcee(Sampler):
self.result.samples = sampler.chain[:, self.nburn:, :].reshape(
(-1, self.ndim))
self.result.walkers = sampler.chain[:, :, :]
self.result.nburn = self.nburn
self.result.log_evidence = np.nan
self.result.log_evidence_err = np.nan
if self.plot:
self.plot_walkers()
try:
logging.info("Max autocorr time = {}".format(
np.max(sampler.get_autocorr_time())))
......@@ -789,29 +788,6 @@ class Emcee(Sampler):
def lnpostfn(self, theta):
return self.log_likelihood(theta) + self.log_prior(theta)
def _get_walkers_to_plot(self):
return self.result.walkers[:, :, :]
def plot_walkers(self, save=True, **kwargs):
nwalkers, nsteps, ndim = self.result.walkers.shape
idxs = np.arange(nsteps)
fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3*self.ndim))
walkers = self._get_walkers_to_plot()
for i, ax in enumerate(axes):
ax.plot(idxs[:self.nburn+1], walkers[:, :self.nburn+1, i].T,
lw=0.1, color='r')
ax.set_ylabel(self.result.parameter_labels[i])
for i, ax in enumerate(axes):
ax.plot(idxs[self.nburn:], walkers[:, self.nburn:, i].T, lw=0.1,
color='k')
ax.set_ylabel(self.result.parameter_labels[i])
fig.tight_layout()
filename = '{}/{}_walkers.png'.format(self.outdir, self.label)
logging.debug('Saving walkers plot to {}'.format('filename'))
fig.savefig(filename)
class Ptemcee(Emcee):
""" https://github.com/willvousden/ptemcee """
......@@ -843,8 +819,7 @@ class Ptemcee(Emcee):
self.result.walkers = sampler.chain[0, :, :, :]
self.result.log_evidence = np.nan
self.result.log_evidence_err = np.nan
if self.plot:
self.plot_walkers()
logging.info("Max autocorr time = {}"
.format(np.max(sampler.get_autocorr_time())))
logging.info("Tswap frac = {}"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment