diff --git a/examples/injection_examples/basic_tutorial.py b/examples/injection_examples/basic_tutorial.py
index 82dbc13d3214bb225956254ad9cf9054841eb1e6..f9946da0039934f222a9072026a2f773c67d0ea1 100644
--- a/examples/injection_examples/basic_tutorial.py
+++ b/examples/injection_examples/basic_tutorial.py
@@ -63,6 +63,4 @@ result = tupak.sampler.run_sampler(likelihood=likelihood, priors=priors, sampler
 
 # make some plots of the outputs
 result.plot_corner()
-result.plot_walks()
-result.plot_distributions()
 print(result)
diff --git a/examples/injection_examples/change_sampled_parameters.py b/examples/injection_examples/change_sampled_parameters.py
index d24c7e4c2bd1cb5dd1310c87b300b3c140e65f56..6d7734b51b60d73f1782762904a02ab7550ba039 100644
--- a/examples/injection_examples/change_sampled_parameters.py
+++ b/examples/injection_examples/change_sampled_parameters.py
@@ -50,6 +50,4 @@ result = tupak.sampler.run_sampler(likelihood=likelihood, priors=priors, sampler
                                    injection_parameters=injection_parameters, label='DifferentParameters',
                                    outdir=outdir, conversion_function=tupak.conversion.generate_all_bbh_parameters)
 result.plot_corner()
-result.plot_walks()
-result.plot_distributions()
 print(result)
diff --git a/examples/injection_examples/create_your_own_source_model.py b/examples/injection_examples/create_your_own_source_model.py
index 64b74ded3ec33b7ade097de89daa4150fa0cd8d8..51d548526404169d69c671121c864206de4dbdf9 100644
--- a/examples/injection_examples/create_your_own_source_model.py
+++ b/examples/injection_examples/create_your_own_source_model.py
@@ -50,7 +50,5 @@ likelihood = tupak.likelihood.GravitationalWaveTransient(IFOs, waveform_generato
 result = tupak.sampler.run_sampler(
     likelihood, prior, sampler='dynesty', outdir=outdir, label=label,
     resume=False, sample='unif', injection_parameters=injection_parameters)
-result.plot_walks()
-result.plot_distributions()
 result.plot_corner()
 print(result)
diff --git a/examples/injection_examples/create_your_own_time_domain_source_model.py b/examples/injection_examples/create_your_own_time_domain_source_model.py
index eb95e7f379cfd73e30b872a274e3ea4b18ed2f2f..af51c5812b5cfc7f49f57f1fb5c3e689c128e1c4 100644
--- a/examples/injection_examples/create_your_own_time_domain_source_model.py
+++ b/examples/injection_examples/create_your_own_time_domain_source_model.py
@@ -71,8 +71,5 @@ result = tupak.sampler.run_sampler(likelihood, prior, sampler='dynesty', npoints
                                     injection_parameters=injection_parameters,
                                     outdir=outdir, label=label)
 
-result.plot_walks()
-result.plot_distributions()
 result.plot_corner()
-
 print(result)
diff --git a/examples/injection_examples/how_to_specify_the_prior.py b/examples/injection_examples/how_to_specify_the_prior.py
index abd2237330dceea7c8dfea31dcde9a076c43ac1d..7a3a0743e4826a4714507f0bca99cf4eb77ad07d 100644
--- a/examples/injection_examples/how_to_specify_the_prior.py
+++ b/examples/injection_examples/how_to_specify_the_prior.py
@@ -64,6 +64,4 @@ likelihood = tupak.likelihood.GravitationalWaveTransient(interferometers=IFOs, w
 result = tupak.sampler.run_sampler(likelihood=likelihood, priors=priors, sampler='dynesty',
                                    injection_parameters=injection_parameters, outdir=outdir, label='specify_prior')
 result.plot_corner()
-result.plot_walks()
-result.plot_distributions()
 print(result)
diff --git a/examples/injection_examples/marginalized_likelihood.py b/examples/injection_examples/marginalized_likelihood.py
index b39f62614b4150d43ea4ca99d7cb6b62bb824ae1..2cd028c550526dbb455d2da63ed146e89f616325 100644
--- a/examples/injection_examples/marginalized_likelihood.py
+++ b/examples/injection_examples/marginalized_likelihood.py
@@ -46,6 +46,4 @@ likelihood = tupak.likelihood.GravitationalWaveTransient(
 result = tupak.sampler.run_sampler(likelihood=likelihood, priors=priors, sampler='dynesty',
                                    injection_parameters=injection_parameters, outdir=outdir, label='BasicTutorial')
 result.plot_corner()
-result.plot_walks()
-result.plot_distributions()
 print(result)
diff --git a/examples/open_data_examples/GW150914.py b/examples/open_data_examples/GW150914.py
index e9a66a0784eae92c6e85d74d3943684aa5e89814..56ea1cde47d0274b68c2088ccd79ca571937d5a2 100644
--- a/examples/open_data_examples/GW150914.py
+++ b/examples/open_data_examples/GW150914.py
@@ -60,6 +60,4 @@ likelihood = tupak.likelihood.GravitationalWaveTransient(interferometers, wavefo
 result = tupak.sampler.run_sampler(likelihood, prior, sampler='dynesty',
                                    outdir=outdir, label=label)
 result.plot_corner()
-result.plot_walks()
-result.plot_distributions()
 print(result)
diff --git a/examples/other_examples/hyper_parameter_example.py b/examples/other_examples/hyper_parameter_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..936ef26833c2005ba4f614cbc154b0c93c09784e
--- /dev/null
+++ b/examples/other_examples/hyper_parameter_example.py
@@ -0,0 +1,82 @@
+#!/bin/python
+"""
+An example of how to use tupak to perform paramater estimation for hyperparams
+"""
+from __future__ import division
+import tupak
+import numpy as np
+
+tupak.utils.setup_logger()
+outdir = 'outdir'
+
+
+class GaussianLikelihood(tupak.likelihood.Likelihood):
+    def __init__(self, x, y, waveform_generator):
+        self.x = x
+        self.y = y
+        self.N = len(x)
+        self.waveform_generator = waveform_generator
+        self.parameters = waveform_generator.parameters
+
+    def log_likelihood(self):
+        sigma = 1
+        res = self.y - self.waveform_generator.time_domain_strain()
+        return -0.5 * (np.sum((res / sigma)**2)
+                       + self.N*np.log(2*np.pi*sigma**2))
+
+
+def model(time, m):
+    return m * time
+
+
+sampling_frequency = 10
+time_duration = 100
+time = np.arange(0, time_duration, 1/sampling_frequency)
+
+true_mu_m = 5
+true_sigma_m = 0.1
+sigma = 0.1
+Nevents = 10
+samples = []
+
+# Make the sample sets
+for i in range(Nevents):
+    m = np.random.normal(true_mu_m, true_sigma_m)
+    injection_parameters = dict(m=m)
+
+    N = len(time)
+    data = model(time, **injection_parameters) + np.random.normal(0, sigma, N)
+
+    waveform_generator = tupak.waveform_generator.WaveformGenerator(
+        time_duration=time_duration, sampling_frequency=sampling_frequency,
+        time_domain_source_model=model)
+
+    likelihood = GaussianLikelihood(time, data, waveform_generator)
+
+    priors = dict(m=tupak.prior.Uniform(-10, 10, 'm'))
+
+    result = tupak.sampler.run_sampler(
+        likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
+        injection_parameters=injection_parameters, outdir=outdir,
+        verbose=False, label='individual_{}'.format(i), use_ratio=False,
+        sample='unif')
+    result.plot_corner()
+    samples.append(result.samples)
+
+# Now run the hyperparameter inference
+run_prior = tupak.prior.Uniform(minimum=-10, maximum=10, name='mu_m')
+hyper_prior = tupak.prior.Gaussian(mu=0, sigma=1, name='hyper')
+
+hp_likelihood = tupak.likelihood.HyperparameterLikelihood(
+        samples, hyper_prior, run_prior)
+
+hp_priors = dict(
+    mu=tupak.prior.Uniform(-10, 10, 'mu', '$\mu_m$'),
+    sigma=tupak.prior.Uniform(0, 10, 'sigma', '$\sigma_m$'))
+
+# And run sampler
+result = tupak.sampler.run_sampler(
+    likelihood=hp_likelihood, priors=hp_priors, sampler='dynesty',
+    npoints=1000, outdir=outdir, label='hyperparameter', use_ratio=False,
+    sample='unif', verbose=True)
+result.plot_corner(truth=dict(mu=true_mu_m, sigma=true_sigma_m))
diff --git a/requirements.txt b/requirements.txt
index f71732b9408956a5cac6709a0079630222d20fde..fe8a629b9a049938e04c73d9c49034b19bffc522 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 future
 dynesty
 corner
-numpy
+numpy>=1.9
 matplotlib>=2.0
 scipy
 gwpy
diff --git a/test/sampler_tests.py b/test/sampler_tests.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcbdfdf255e1c47e84e74230b11c5740e773e5c7
--- /dev/null
+++ b/test/sampler_tests.py
@@ -0,0 +1,130 @@
+from context import tupak
+from tupak import prior
+from tupak.result import Result
+import unittest
+from mock import MagicMock
+import numpy as np
+import inspect
+import os
+import copy
+
+
+class TestSampler(unittest.TestCase):
+
+    def setUp(self):
+        likelihood = tupak.likelihood.Likelihood()
+        likelihood.parameters = dict(a=1, b=2, c=3)
+        delta_prior = prior.DeltaFunction(peak=0)
+        delta_prior.rescale = MagicMock(return_value=prior.DeltaFunction(peak=1))
+        delta_prior.prob = MagicMock(return_value=1)
+        delta_prior.sample = MagicMock(return_value=0)
+        uniform_prior = prior.Uniform(0, 1)
+        uniform_prior.rescale = MagicMock(return_value=prior.Uniform(0, 2))
+        uniform_prior.prob = MagicMock(return_value=1)
+        uniform_prior.sample = MagicMock(return_value=0.5)
+
+        priors = dict(a=delta_prior, b='string', c=uniform_prior)
+        likelihood.log_likelihood_ratio = MagicMock(return_value=1)
+        likelihood.log_likelihood = MagicMock(return_value=2)
+        test_directory = 'test_directory'
+        if os.path.isdir(test_directory):
+            os.rmdir(test_directory)
+        self.sampler = tupak.sampler.Sampler(likelihood=likelihood,
+                                             priors=priors,
+                                             external_sampler='nestle',
+                                             outdir=test_directory,
+                                             use_ratio=False)
+
+    def tearDown(self):
+        os.rmdir(self.sampler.outdir)
+        del self.sampler
+
+    def test_search_parameter_keys(self):
+        expected_search_parameter_keys = ['c']
+        self.assertListEqual(self.sampler.search_parameter_keys, expected_search_parameter_keys)
+
+    def test_fixed_parameter_keys(self):
+        expected_fixed_parameter_keys = ['a']
+        self.assertListEqual(self.sampler.fixed_parameter_keys, expected_fixed_parameter_keys)
+
+    def test_ndim(self):
+        self.assertEqual(self.sampler.ndim, 1)
+
+    def test_kwargs(self):
+        self.assertDictEqual(self.sampler.kwargs, {})
+
+    def test_label(self):
+        self.assertEqual(self.sampler.label, 'label')
+
+    def test_if_external_sampler_is_module(self):
+        self.assertTrue(inspect.ismodule(self.sampler.external_sampler))
+
+    def test_if_external_sampler_has_the_correct_module_name(self):
+        expected_name = 'nestle'
+        self.assertEqual(self.sampler.external_sampler.__name__, expected_name)
+
+    def test_external_sampler_raises_if_sampler_not_installed(self):
+        with self.assertRaises(ImportError):
+            self.sampler.external_sampler = 'unexpected_sampler'
+
+    def test_setting_custom_sampler(self):
+        other_sampler = tupak.sampler.Sampler(self.sampler.likelihood,
+                                             self.sampler.priors)
+        self.sampler.external_sampler = other_sampler
+        self.assertEqual(self.sampler.external_sampler, other_sampler)
+
+    def test_setting_external_sampler_to_something_else_raises_error(self):
+        with self.assertRaises(TypeError):
+            self.sampler.external_sampler = object()
+
+    def test_result(self):
+        expected_result = Result()
+        expected_result.search_parameter_keys = ['c']
+        expected_result.fixed_parameter_keys = ['a']
+        expected_result.parameter_labels = ['c']
+        expected_result.label = 'label'
+        expected_result.outdir = 'outdir'
+        expected_result.kwargs = {}
+        self.assertDictEqual(self.sampler.result.__dict__, expected_result.__dict__)
+
+    def test_make_outdir_if_no_outdir_exists(self):
+        self.assertTrue(os.path.isdir(self.sampler.outdir))
+
+    def test_prior_transform_transforms_search_parameter_keys(self):
+        self.sampler.prior_transform([0])
+        expected_prior = prior.Uniform(0, 1)
+        self.assertListEqual([self.sampler.priors['c'].minimum,
+                              self.sampler.priors['c'].maximum],
+                             [expected_prior.minimum,
+                              expected_prior.maximum])
+
+    def test_prior_transform_does_not_transform_fixed_parameter_keys(self):
+        self.sampler.prior_transform([0])
+        self.assertEqual(self.sampler.priors['a'].peak,
+                         prior.DeltaFunction(peak=0).peak)
+
+    def test_log_prior(self):
+        self.assertEqual(self.sampler.log_prior({1}), 0.0)
+
+    def test_log_likelihood_with_use_ratio(self):
+        self.sampler.use_ratio = True
+        self.assertEqual(self.sampler.log_likelihood([0]), 1)
+
+    def test_log_likelihood_without_use_ratio(self):
+        self.sampler.use_ratio = False
+        self.assertEqual(self.sampler.log_likelihood([0]), 2)
+
+    def test_log_likelihood_correctly_sets_parameters(self):
+        expected_dict = dict(a=0,
+                             b=2,
+                             c=0)
+        _ = self.sampler.log_likelihood([0])
+        self.assertDictEqual(self.sampler.likelihood.parameters, expected_dict)
+
+    def test_get_random_draw(self):
+        self.assertEqual(self.sampler.get_random_draw_from_prior(), np.array([0.5]))
+
+    def test_base_run_sampler(self):
+        sampler_copy = copy.copy(self.sampler)
+        self.sampler.run_sampler()
+        self.assertDictEqual(sampler_copy.__dict__, self.sampler.__dict__)
\ No newline at end of file
diff --git a/tupak/likelihood.py b/tupak/likelihood.py
index 382c1aea8cddabdb010ee507f0855f06154ea749..47e70944d59a2220f19ddf1fcb93ce6322fd21bf 100644
--- a/tupak/likelihood.py
+++ b/tupak/likelihood.py
@@ -14,6 +14,9 @@ import logging
 class Likelihood(object):
     """ Empty likelihood class to be subclassed by other likelihoods """
 
+    def __init__(self, parameters=None):
+        self.parameters = parameters
+
     def log_likelihood(self):
         return np.nan
 
@@ -51,16 +54,15 @@ class GravitationalWaveTransient(Likelihood):
     Returns
     -------
     Likelihood: `tupak.likelihood.Likelihood`
-        A likehood object, able to compute the likelihood of the data given
+        A likelihood object, able to compute the likelihood of the data given
         some model parameters
 
     """
     def __init__(self, interferometers, waveform_generator, distance_marginalization=False, phase_marginalization=False,
                  prior=None):
-        # GravitationalWaveTransient.__init__(self, interferometers, waveform_generator)
+        Likelihood.__init__(self, waveform_generator.parameters)
         self.interferometers = interferometers
         self.waveform_generator = waveform_generator
-        self.parameters = self.waveform_generator.parameters
         self.non_standard_sampling_parameter_keys = self.waveform_generator.non_standard_sampling_parameter_keys
         self.distance_marginalization = distance_marginalization
         self.phase_marginalization = phase_marginalization
@@ -153,7 +155,7 @@ class GravitationalWaveTransient(Likelihood):
 
 
 class BasicGravitationalWaveTransient(Likelihood):
-    """ A basic gravitaitonal wave transient likelihood
+    """ A basic gravitational wave transient likelihood
 
     The simplest frequency-domain gravitational wave transient likelihood. Does
     not include distance/phase marginalization.
@@ -170,11 +172,12 @@ class BasicGravitationalWaveTransient(Likelihood):
     Returns
     -------
     Likelihood: `tupak.likelihood.Likelihood`
-        A likehood object, able to compute the likelihood of the data given
+        A likelihood object, able to compute the likelihood of the data given
         some model parameters
 
     """
     def __init__(self, interferometers, waveform_generator):
+        Likelihood.__init__(self, waveform_generator.parameters)
         self.interferometers = interferometers
         self.waveform_generator = waveform_generator
 
@@ -230,3 +233,53 @@ def get_binary_black_hole_likelihood(interferometers):
     likelihood = tupak.likelihood.GravitationalWaveTransient(interferometers, waveform_generator)
     return likelihood
 
+
+class HyperparameterLikelihood(Likelihood):
+    """ A likelihood for infering hyperparameter posterior distributions
+
+    See Eq. (1) of https://arxiv.org/abs/1801.02699 for a definition.
+
+    Parameters
+    ----------
+    samples: list
+        An N-dimensional list of individual sets of samples. Each set may have
+        a different size.
+    hyper_prior: `tupak.prior.Prior`
+        A prior distribution with a `parameters` argument pointing to the
+        hyperparameters to infer from the samples. These may need to be
+        initialized to any arbitrary value, but this will not effect the
+        result.
+    run_prior: `tupak.prior.Prior`
+        The prior distribution used in the inidivudal inferences which resulted
+        in the set of samples.
+
+    """
+
+    def __init__(self, samples, hyper_prior, run_prior):
+        Likelihood.__init__(self, parameters=hyper_prior.__dict__)
+        self.samples = samples
+        self.hyper_prior = hyper_prior
+        self.run_prior = run_prior
+        if hasattr(hyper_prior, 'lnprob') and hasattr(run_prior, 'lnprob'):
+            logging.info("Using log-probabilities in likelihood")
+            self.log_likelihood = self.log_likelihood_using_lnprob
+        else:
+            logging.info("Using probabilities in likelihood")
+            self.log_likelihood = self.log_likelihood_using_prob
+
+    def log_likelihood_using_lnprob(self):
+        L = []
+        self.hyper_prior.__dict__.update(self.parameters)
+        for samp in self.samples:
+            f = self.hyper_prior.lnprob(samp) - self.run_prior.lnprob(samp)
+            L.append(logsumexp(f))
+        return np.sum(L)
+
+    def log_likelihood_using_prob(self):
+        L = []
+        self.hyper_prior.__dict__.update(self.parameters)
+        for samp in self.samples:
+            L.append(
+                np.sum(self.hyper_prior.prob(samp) /
+                       self.run_prior.prob(samp)))
+        return np.sum(np.log(L))
diff --git a/tupak/prior.py b/tupak/prior.py
index 680c38602176f6c90bb077c24c8222e90104a967..b25c23cc11a8a0b58c65d0c28fa2f573b7eed784 100644
--- a/tupak/prior.py
+++ b/tupak/prior.py
@@ -173,6 +173,12 @@ class PowerLaw(Prior):
             return np.nan_to_num(val ** self.alpha * (1 + self.alpha) / (self.maximum ** (1 + self.alpha)
                                                                          - self.minimum ** (1 + self.alpha))) * in_prior
 
+    def lnprob(self, val):
+        in_prior = (val >= self.minimum) & (val <= self.maximum)
+        normalising = (1+self.alpha)/(self.maximum ** (1 + self.alpha)
+                                      - self.minimum ** (1 + self.alpha))
+        return self.alpha * np.log(val) * np.log(normalising) * in_prior
+
 
 class Uniform(PowerLaw):
     """Uniform prior"""
@@ -254,6 +260,9 @@ class Gaussian(Prior):
         """Return the prior probability of val"""
         return np.exp(-(self.mu - val)**2 / (2 * self.sigma**2)) / (2 * np.pi)**0.5 / self.sigma
 
+    def lnprob(self, val):
+        return -0.5*((self.mu - val)**2 / self.sigma**2 + np.log(2 * np.pi * self.sigma**2))
+
 
 class TruncatedGaussian(Prior):
     """
diff --git a/tupak/result.py b/tupak/result.py
index 07023b03cace71d61d721e7fffcd2817e547c1b0..681eb03bd5cf7a59c4d89011c9286e8d45fa48aa 100644
--- a/tupak/result.py
+++ b/tupak/result.py
@@ -3,14 +3,7 @@ import os
 import numpy as np
 import deepdish
 import pandas as pd
-
-try:
-    from chainconsumer import ChainConsumer
-except ImportError:
-    def ChainConsumer():
-        logging.warning(
-            "You do not have the optional module chainconsumer installed"
-            " unable to generate a corner plot")
+import corner
 
 
 def result_file_name(outdir, label):
@@ -34,10 +27,12 @@ def read_in_result(outdir=None, label=None, filename=None):
     """
     if filename is None:
         filename = result_file_name(outdir, label)
+    elif (outdir is None or label is None) and filename is None:
+        raise ValueError("No information given to load file")
     if os.path.isfile(filename):
         return Result(deepdish.io.load(filename))
     else:
-        raise ValueError("No information given to load file")
+        raise ValueError("No result found")
 
 
 class Result(dict):
@@ -102,103 +97,78 @@ class Result(dict):
                                  .format(k))
         return return_list
 
-    def plot_corner(self, save=True, **kwargs):
-        """ Plot a corner-plot using chain-consumer
+    def plot_corner(self, parameters=None, save=True, dpi=300, **kwargs):
+        """ Plot a corner-plot using corner
+
+        See https://corner.readthedocs.io/en/latest/ for a detailed API.
 
         Parameters
         ----------
+        parameters: list
+            If given, a list of the parameter names to include
         save: bool
             If true, save the image using the given label and outdir
+        **kwargs:
+            Other keyword arguments are passed to `corner.corner`. We set some
+            defaults to improve the basic look and feel, but these can all be
+            overridden.
 
         Returns
         -------
         fig:
             A matplotlib figure instance
+
         """
 
-        # Set some defaults (unless already set)
-        kwargs['figsize'] = kwargs.get('figsize', 'GROW')
-        if save:
-            filename = '{}/{}_corner.png'.format(self.outdir, self.label)
-            kwargs['filename'] = kwargs.get('filename', filename)
-            logging.info('Saving corner plot to {}'.format(kwargs['filename']))
+        defaults_kwargs = dict(
+            bins=50, smooth=0.9, label_kwargs=dict(fontsize=16),
+            title_kwargs=dict(fontsize=16), color='#0072C1',
+            truth_color='tab:orange', show_titles=True,
+            quantiles=[0.025, 0.975], levels=(0.39, 0.8, 0.97),
+            plot_density=False, plot_datapoints=True, fill_contours=True,
+            max_n_ticks=3)
+
+        defaults_kwargs.update(kwargs)
+        kwargs = defaults_kwargs
+
+        if 'truth' in kwargs:
+            kwargs['truths'] = kwargs.pop('truth')
+
         if getattr(self, 'injection_parameters', None) is not None:
-            # If no truth argument given, set these to the injection params
             injection_parameters = [self.injection_parameters[key]
                                     for key in self.search_parameter_keys]
-            kwargs['truth'] = kwargs.get('truth', injection_parameters)
-
-        if type(kwargs.get('truth')) == dict:
-            old_keys = kwargs['truth'].keys()
-            new_keys = self.get_latex_labels_from_parameter_keys(old_keys)
-            for old, new in zip(old_keys, new_keys):
-                kwargs['truth'][new] = kwargs['truth'].pop(old)
-        if 'parameters' in kwargs:
-            kwargs['parameters'] = self.get_latex_labels_from_parameter_keys(
-                kwargs['parameters'])
-
-        # Check all parameter_labels are a valid string
-        for i, label in enumerate(self.parameter_labels):
-            if label is None:
-                self.parameter_labels[i] = 'Unknown'
-        c = ChainConsumer()
-        if c:
-            c.add_chain(self.samples, parameters=self.parameter_labels,
-                        name=self.label)
-            fig = c.plotter.plot(**kwargs)
-            return fig
+            kwargs['truths'] = kwargs.get('truths', injection_parameters)
 
-    def plot_walks(self, save=True, **kwargs):
-        """ Plot the chain walks using chain-consumer
+        if parameters is None:
+            parameters = self.search_parameter_keys
 
-        Parameters
-        ----------
-        save: bool
-            If true, save the image using the given label and outdir
+        xs = self.posterior[parameters].values
+        kwargs['labels'] = kwargs.get(
+            'labels', self.get_latex_labels_from_parameter_keys(
+                parameters))
 
-        Returns
-        -------
-        fig:
-            A matplotlib figure instance
-        """
+        if type(kwargs.get('truths')) == dict:
+            truths = [kwargs['truths'][k] for k in parameters]
+            kwargs['truths'] = truths
 
-        # Set some defaults (unless already set)
-        if save:
-            kwargs['filename'] = '{}/{}_walks.png'.format(self.outdir, self.label)
-            logging.info('Saving walker plot to {}'.format(kwargs['filename']))
-        if getattr(self, 'injection_parameters', None) is not None:
-            kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys]
-        c = ChainConsumer()
-        if c:
-            c.add_chain(self.samples, parameters=self.parameter_labels)
-            fig = c.plotter.plot_walks(**kwargs)
-            return fig
+        fig = corner.corner(xs, **kwargs)
 
-    def plot_distributions(self, save=True, **kwargs):
-        """ Plot the chain walks using chain-consumer
+        if save:
+            filename = '{}/{}_corner.png'.format(self.outdir, self.label)
+            logging.info('Saving corner plot to {}'.format(filename))
+            fig.savefig(filename, dpi=dpi)
 
-        Parameters
-        ----------
-        save: bool
-            If true, save the image using the given label and outdir
+        return fig
 
-        Returns
-        -------
-        fig:
-            A matplotlib figure instance
+    def plot_walks(self, save=True, **kwargs):
+        """
         """
+        logging.warning("plot_walks deprecated")
 
-        # Set some defaults (unless already set)
-        if save:
-            kwargs['filename'] = '{}/{}_distributions.png'.format(self.outdir, self.label)
-            logging.info('Saving distributions plot to {}'.format(kwargs['filename']))
-        if getattr(self, 'injection_parameters', None) is not None:
-            kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys]
-        c = ChainConsumer()
-        if c:
-            c.add_chain(self.samples, parameters=self.parameter_labels)
-            fig = c.plotter.plot_distributions(**kwargs)
-            return fig
+    def plot_distributions(self, save=True, **kwargs):
+        """
+        """
+        logging.warning("plot_distributions deprecated")
 
     def write_prior_to_file(self, outdir):
         """
diff --git a/tupak/sampler.py b/tupak/sampler.py
index 16f7a34fe8d2ca72c66f1c67b238911224e03bd6..0b6c0d316722affc0552513facb204a9e00e81db 100644
--- a/tupak/sampler.py
+++ b/tupak/sampler.py
@@ -19,7 +19,7 @@ class Sampler(object):
 
     Parameters
     ----------
-    likelihood: likelihood.GravitationalWaveTransient
+    likelihood: likelihood.Likelihood
         A  object with a log_l method
     prior: dict
         The prior to be used in the search. Elements can either be floats
@@ -36,8 +36,10 @@ class Sampler(object):
 
     """
 
-    def __init__(self, likelihood, priors, external_sampler='nestle', outdir='outdir', label='label', use_ratio=False,
-                 **kwargs):
+    def __init__(
+            self, likelihood, priors, external_sampler='nestle',
+            outdir='outdir', label='label', use_ratio=False, plot=False,
+            **kwargs):
         self.likelihood = likelihood
         self.priors = priors
         self.label = label
@@ -45,6 +47,7 @@ class Sampler(object):
         self.use_ratio = use_ratio
         self.external_sampler = external_sampler
         self.external_sampler_function = None
+        self.plot = plot
 
         self.__search_parameter_keys = []
         self.__fixed_parameter_keys = []
@@ -148,10 +151,12 @@ class Sampler(object):
             except AttributeError as e:
                 logging.warning('Cannot sample from {}, {}'.format(key, e))
         try:
-            self.likelihood.log_likelihood_ratio()
-        except TypeError:
-            raise TypeError('GravitationalWaveTransient evaluation failed. Have you definitely specified all the parameters?\n{}'.format(
-                self.likelihood.parameters))
+            self.likelihood.log_likelihood()
+        except TypeError as e:
+            raise TypeError(
+                "Likelihood evaluation failed with message: \n'{}'\n"
+                "Have you specified all the parameters:\n{}"
+                .format(e, self.likelihood.parameters))
 
     def prior_transform(self, theta):
         return [self.priors[key].rescale(t) for key, t in zip(self.__search_parameter_keys, theta)]
@@ -325,8 +330,20 @@ class Dynesty(Sampler):
             out.samples, weights)
         self.result.logz = out.logz[-1]
         self.result.logzerr = out.logzerr[-1]
+
+        if self.plot:
+            self.generate_trace_plots(out)
         return self.result
 
+    def generate_trace_plots(self, dynesty_results):
+        filename = '{}/{}_trace.png'.format(self.outdir, self.label)
+        logging.info("Writing trace plot to {}".format(filename))
+        from dynesty import plotting as dyplot
+        fig, axes = dyplot.traceplot(dynesty_results,
+                                     labels=self.result.parameter_labels)
+        fig.tight_layout()
+        fig.savefig(filename)
+
     def _run_test(self):
         dynesty = self.external_sampler
         nested_sampler = dynesty.NestedSampler(
@@ -434,7 +451,7 @@ class Ptemcee(Sampler):
 
 def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
                 sampler='nestle', use_ratio=True, injection_parameters=None,
-                conversion_function=None, **kwargs):
+                conversion_function=None, plot=False, **kwargs):
     """
     The primary interface to easy parameter estimation
 
@@ -459,7 +476,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
     injection_parameters: dict
         A dictionary of injection parameters used in creating the data (if
         using simulated data). Appended to the result object and saved.
-
+    plot: bool
+        If true, generate a corner plot and, if applicable diagnostic plots
     conversion_function: function, optional
         Function to apply to posterior to generate additional parameters.
     **kwargs:
@@ -482,7 +500,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
     if implemented_samplers.__contains__(sampler.title()):
         sampler_class = globals()[sampler.title()]
         sampler = sampler_class(likelihood, priors, sampler, outdir=outdir,
-                                label=label, use_ratio=use_ratio,
+                                label=label, use_ratio=use_ratio, plot=plot,
                                 **kwargs)
 
         if sampler.cached_result:
@@ -509,6 +527,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         result.samples_to_data_frame(likelihood=likelihood, priors=priors, conversion_function=conversion_function)
         result.kwargs = sampler.kwargs
         result.save_to_file(outdir=outdir, label=label)
+        if plot:
+            result.plot_corner()
         return result
     else:
         raise ValueError(