Commit f52cdc87 authored by Gregory Ashton's avatar Gregory Ashton Committed by Colm Talbot

Resolve 197

parent cd8795f0
from __future__ import absolute_import, division
import tupak
import unittest
import numpy as np
import pandas as pd
import shutil
class TestResult(unittest.TestCase):
def setUp(self):
tupak.utils.command_line_args.test = False
result = tupak.core.result.Result()
test_directory = 'test_directory'
result.outdir = test_directory
result.label = 'test'
N = 100
posterior = pd.DataFrame(dict(x=np.random.normal(0, 1, N),
y=np.random.normal(0, 1, N)))
result.search_parameter_keys = ['x', 'y']
result.parameter_labels_with_unit = ['x', 'y']
result.posterior = posterior
self.result = result
pass
def tearDown(self):
tupak.utils.command_line_args.test = True
try:
shutil.rmtree(self.result.outdir)
except OSError:
pass
del self.result
pass
def test_plot_corner(self):
self.result.injection_parameters = dict(x=0.8, y=1.1)
self.result.plot_corner()
self.result.plot_corner(parameters=['x', 'y'])
self.result.plot_corner(parameters=['x', 'y'], truths=[1, 1])
self.result.plot_corner(parameters=dict(x=1, y=1))
self.result.plot_corner(truths=dict(x=1, y=1))
self.result.plot_corner(truth=dict(x=1, y=1))
with self.assertRaises(ValueError):
self.result.plot_corner(truths=dict(x=1, y=1),
parameters=dict(x=1, y=1))
with self.assertRaises(ValueError):
self.result.plot_corner(truths=[1, 1],
parameters=dict(x=1, y=1))
with self.assertRaises(ValueError):
self.result.plot_corner(parameters=['x', 'y'],
truths=dict(x=1, y=1))
def test_plot_corner_with_injection_parameters(self):
self.result.plot_corner()
self.result.plot_corner(parameters=['x', 'y'])
self.result.plot_corner(parameters=['x', 'y'], truths=[1, 1])
self.result.plot_corner(parameters=dict(x=1, y=1))
def test_plot_corner_with_priors(self):
priors = tupak.core.prior.PriorSet()
priors['x'] = tupak.core.prior.Uniform(-1, 1, 'x')
priors['y'] = tupak.core.prior.Uniform(-1, 1, 'y')
self.result.plot_corner(priors=priors)
self.result.priors = priors
self.result.plot_corner(priors=True)
with self.assertRaises(ValueError):
self.result.plot_corner(priors='test')
if __name__ == '__main__':
unittest.main()
......@@ -362,12 +362,6 @@ class Result(dict):
defaults_kwargs.update(kwargs)
kwargs = defaults_kwargs
# If injection parameters where stored, use these as truth values
if getattr(self, 'injection_parameters', None) is not None:
injection_parameters = [self.injection_parameters.get(key, None)
for key in self.search_parameter_keys]
kwargs['truths'] = kwargs.get('truths', injection_parameters)
# Handle if truths was passed in
if 'truth' in kwargs:
kwargs['truths'] = kwargs.pop('truth')
......@@ -383,6 +377,14 @@ class Result(dict):
raise ValueError(
"Combination of parameters and truths not understood")
# If injection parameters where stored, use these as parameter values
# but do not overwrite input parameters (or truths)
cond1 = getattr(self, 'injection_parameters', None) is not None
cond2 = parameters is None
if cond1 and cond2:
parameters = {key: self.injection_parameters[key] for key in
self.search_parameter_keys}
# If parameters is a dictionary, use the keys to determine which
# parameters to plot and the values as truths.
if isinstance(parameters, dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment