From f52cdc87f11782619f82ad9ab1357743d77b3f64 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 1 Oct 2018 01:00:29 -0500
Subject: [PATCH] Resolve 197

---
 test/result_test.py  | 73 ++++++++++++++++++++++++++++++++++++++++++++
 tupak/core/result.py | 14 +++++----
 2 files changed, 81 insertions(+), 6 deletions(-)
 create mode 100644 test/result_test.py

diff --git a/test/result_test.py b/test/result_test.py
new file mode 100644
index 000000000..5f8904b6b
--- /dev/null
+++ b/test/result_test.py
@@ -0,0 +1,73 @@
+from __future__ import absolute_import, division
+
+import tupak
+import unittest
+import numpy as np
+import pandas as pd
+import shutil
+
+
+class TestResult(unittest.TestCase):
+
+    def setUp(self):
+        tupak.utils.command_line_args.test = False
+        result = tupak.core.result.Result()
+        test_directory = 'test_directory'
+        result.outdir = test_directory
+        result.label = 'test'
+
+        N = 100
+        posterior = pd.DataFrame(dict(x=np.random.normal(0, 1, N),
+                                      y=np.random.normal(0, 1, N)))
+        result.search_parameter_keys = ['x', 'y']
+        result.parameter_labels_with_unit = ['x', 'y']
+        result.posterior = posterior
+        self.result = result
+        pass
+
+    def tearDown(self):
+        tupak.utils.command_line_args.test = True
+        try:
+            shutil.rmtree(self.result.outdir)
+        except OSError:
+            pass
+        del self.result
+        pass
+
+    def test_plot_corner(self):
+        self.result.injection_parameters = dict(x=0.8, y=1.1)
+        self.result.plot_corner()
+        self.result.plot_corner(parameters=['x', 'y'])
+        self.result.plot_corner(parameters=['x', 'y'], truths=[1, 1])
+        self.result.plot_corner(parameters=dict(x=1, y=1))
+        self.result.plot_corner(truths=dict(x=1, y=1))
+        self.result.plot_corner(truth=dict(x=1, y=1))
+        with self.assertRaises(ValueError):
+            self.result.plot_corner(truths=dict(x=1, y=1),
+                                    parameters=dict(x=1, y=1))
+        with self.assertRaises(ValueError):
+            self.result.plot_corner(truths=[1, 1],
+                                    parameters=dict(x=1, y=1))
+        with self.assertRaises(ValueError):
+            self.result.plot_corner(parameters=['x', 'y'],
+                                    truths=dict(x=1, y=1))
+
+    def test_plot_corner_with_injection_parameters(self):
+        self.result.plot_corner()
+        self.result.plot_corner(parameters=['x', 'y'])
+        self.result.plot_corner(parameters=['x', 'y'], truths=[1, 1])
+        self.result.plot_corner(parameters=dict(x=1, y=1))
+
+    def test_plot_corner_with_priors(self):
+        priors = tupak.core.prior.PriorSet()
+        priors['x'] = tupak.core.prior.Uniform(-1, 1, 'x')
+        priors['y'] = tupak.core.prior.Uniform(-1, 1, 'y')
+        self.result.plot_corner(priors=priors)
+        self.result.priors = priors
+        self.result.plot_corner(priors=True)
+        with self.assertRaises(ValueError):
+            self.result.plot_corner(priors='test')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tupak/core/result.py b/tupak/core/result.py
index a2b952f5f..018c9a120 100644
--- a/tupak/core/result.py
+++ b/tupak/core/result.py
@@ -362,12 +362,6 @@ class Result(dict):
         defaults_kwargs.update(kwargs)
         kwargs = defaults_kwargs
 
-        # If injection parameters where stored, use these as truth values
-        if getattr(self, 'injection_parameters', None) is not None:
-            injection_parameters = [self.injection_parameters.get(key, None)
-                                    for key in self.search_parameter_keys]
-            kwargs['truths'] = kwargs.get('truths', injection_parameters)
-
         # Handle if truths was passed in
         if 'truth' in kwargs:
             kwargs['truths'] = kwargs.pop('truth')
@@ -383,6 +377,14 @@ class Result(dict):
                 raise ValueError(
                     "Combination of parameters and truths not understood")
 
+        # If injection parameters where stored, use these as parameter values
+        # but do not overwrite input parameters (or truths)
+        cond1 = getattr(self, 'injection_parameters', None) is not None
+        cond2 = parameters is None
+        if cond1 and cond2:
+            parameters = {key: self.injection_parameters[key] for key in
+                          self.search_parameter_keys}
+
         # If parameters is a dictionary, use the keys to determine which
         # parameters to plot and the values as truths.
         if isinstance(parameters, dict):
-- 
GitLab