diff --git a/tupak/core/result.py b/tupak/core/result.py index 6e6a6ddabb46f234f4d02f5dd8c2b0a411f1fbd7..a515ffb94afd7f5fb38caf85e127e5c95dcd9da2 100644 --- a/tupak/core/result.py +++ b/tupak/core/result.py @@ -5,6 +5,7 @@ import deepdish import pandas as pd import corner import matplotlib +import matplotlib.pyplot as plt def result_file_name(outdir, label): @@ -259,6 +260,31 @@ class Result(dict): return fig + def plot_walkers(self, save=True, **kwargs): + """ Method to plot the trace of the walkers in an ensmble MCMC plot """ + if hasattr(self, 'walkers') is False: + logging.warning("Cannot plot_walkers as no walkers are saved") + return + + nwalkers, nsteps, ndim = self.walkers.shape + idxs = np.arange(nsteps) + fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3*ndim)) + walkers = self.walkers[:, :, :] + for i, ax in enumerate(axes): + ax.plot(idxs[:self.nburn+1], walkers[:, :self.nburn+1, i].T, + lw=0.1, color='r') + ax.set_ylabel(self.parameter_labels[i]) + + for i, ax in enumerate(axes): + ax.plot(idxs[self.nburn:], walkers[:, self.nburn:, i].T, lw=0.1, + color='k') + ax.set_ylabel(self.parameter_labels[i]) + + fig.tight_layout() + filename = '{}/{}_walkers.png'.format(self.outdir, self.label) + logging.debug('Saving walkers plot to {}'.format('filename')) + fig.savefig(filename) + def plot_walks(self, save=True, **kwargs): """DEPRECATED""" logging.warning("plot_walks deprecated") diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py index 70835ea9e0789779c78d433b8c58407e935df9d2..50646a6bb4e030a98b5f278351d889772ab7294a 100644 --- a/tupak/core/sampler.py +++ b/tupak/core/sampler.py @@ -5,7 +5,6 @@ import logging import os import sys import numpy as np -import matplotlib.pyplot as plt import datetime import deepdish @@ -775,10 +774,10 @@ class Emcee(Sampler): self.result.samples = sampler.chain[:, self.nburn:, :].reshape( (-1, self.ndim)) self.result.walkers = sampler.chain[:, :, :] + self.result.nburn = self.nburn self.result.log_evidence = np.nan self.result.log_evidence_err = np.nan - if self.plot: - self.plot_walkers() + try: logging.info("Max autocorr time = {}".format( np.max(sampler.get_autocorr_time()))) @@ -789,29 +788,6 @@ class Emcee(Sampler): def lnpostfn(self, theta): return self.log_likelihood(theta) + self.log_prior(theta) - def _get_walkers_to_plot(self): - return self.result.walkers[:, :, :] - - def plot_walkers(self, save=True, **kwargs): - nwalkers, nsteps, ndim = self.result.walkers.shape - idxs = np.arange(nsteps) - fig, axes = plt.subplots(nrows=ndim, figsize=(6, 3*self.ndim)) - walkers = self._get_walkers_to_plot() - for i, ax in enumerate(axes): - ax.plot(idxs[:self.nburn+1], walkers[:, :self.nburn+1, i].T, - lw=0.1, color='r') - ax.set_ylabel(self.result.parameter_labels[i]) - - for i, ax in enumerate(axes): - ax.plot(idxs[self.nburn:], walkers[:, self.nburn:, i].T, lw=0.1, - color='k') - ax.set_ylabel(self.result.parameter_labels[i]) - - fig.tight_layout() - filename = '{}/{}_walkers.png'.format(self.outdir, self.label) - logging.debug('Saving walkers plot to {}'.format('filename')) - fig.savefig(filename) - class Ptemcee(Emcee): """ https://github.com/willvousden/ptemcee """ @@ -843,8 +819,7 @@ class Ptemcee(Emcee): self.result.walkers = sampler.chain[0, :, :, :] self.result.log_evidence = np.nan self.result.log_evidence_err = np.nan - if self.plot: - self.plot_walkers() + logging.info("Max autocorr time = {}" .format(np.max(sampler.get_autocorr_time()))) logging.info("Tswap frac = {}"