From 08cddc461b80e83e3951e11a3606e32bd591cf87 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Sun, 29 May 2022 02:12:11 +0000
Subject: [PATCH] Add pp plot test

---
 bilby/core/result.py     |  5 +++--
 test/core/result_test.py | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/bilby/core/result.py b/bilby/core/result.py
index 45ffebba5..137936887 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -1977,6 +1977,8 @@ def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0
         The font size for the legend
     keys: list
         A list of keys to use, if None defaults to search_parameter_keys
+    title: bool
+        Whether to add the number of results and total p-value as a plot title
     confidence_interval_alpha: float, list, optional
         The transparency for the background condifence interval
     weight_list: list, optional
@@ -2001,8 +2003,7 @@ def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0
     credible_levels = list()
     for i, result in enumerate(results):
         credible_levels.append(
-            result.get_all_injection_credible_levels(keys, weights=weight_list[i]),
-            ignore_index=True,
+            result.get_all_injection_credible_levels(keys, weights=weight_list[i])
         )
     credible_levels = pd.DataFrame(credible_levels)
 
diff --git a/test/core/result_test.py b/test/core/result_test.py
index fe83c8201..d70deb20a 100644
--- a/test/core/result_test.py
+++ b/test/core/result_test.py
@@ -678,5 +678,38 @@ class TestMiscResults(unittest.TestCase):
         self.assertEqual(labels_checked, ["a", "$a$", "a-1", "$a_1$"])
 
 
+class TestPPPlots(unittest.TestCase):
+
+    def setUp(self):
+        priors = bilby.core.prior.PriorDict(dict(
+            a=bilby.core.prior.Uniform(0, 1, latex_label="$a$"),
+            b=bilby.core.prior.Uniform(0, 1, latex_label="$b$"),
+        ))
+        self.results = [
+            bilby.core.result.Result(
+                label=str(ii),
+                outdir='.',
+                search_parameter_keys=list(priors.keys()),
+                priors=priors,
+                injection_parameters=priors.sample(),
+                posterior=pd.DataFrame(priors.sample(500)),
+            )
+            for ii in range(10)
+        ]
+
+    def test_make_pp_plot(self):
+        _ = bilby.core.result.make_pp_plot(self.results, save=False)
+
+    def test_pp_plot_raises_error_with_wrong_number_of_lines(self):
+        with self.assertRaises(ValueError):
+            _ = bilby.core.result.make_pp_plot(self.results, save=False, lines=["-"])
+
+    def test_pp_plot_raises_error_with_wrong_number_of_confidence_intervals(self):
+        with self.assertRaises(ValueError):
+            _ = bilby.core.result.make_pp_plot(
+                self.results, save=False, confidence_interval_alpha=[0.1]
+            )
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab