From 6491d7086675d3d7b5a05823e3a97416ee5e54c4 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 4 May 2020 12:37:31 +1000
Subject: [PATCH] Adds a converter for latex labels

Currently, corner/matplotlib chokes on labels based on
filenames/parameter names. This does a sanity check first to fix this.
---
 bilby/core/result.py | 12 ++++++++++++
 test/result_test.py  |  7 +++++++
 2 files changed, 19 insertions(+)

diff --git a/bilby/core/result.py b/bilby/core/result.py
index cefef401d..a35b13991 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -927,6 +927,8 @@ class Result(object):
             'labels', self.get_latex_labels_from_parameter_keys(
                 plot_parameter_keys))
 
+        kwargs["labels"] = sanity_check_labels(kwargs["labels"])
+
         # Unless already set, set the range to include all samples
         # This prevents ValueErrors being raised for parameters with no range
         kwargs['range'] = kwargs.get('range', [1] * len(plot_parameter_keys))
@@ -1584,6 +1586,8 @@ def plot_multiple(results, filename=None, labels=None, colours=None,
     if labels is None:
         labels = default_labels
 
+    labels = sanity_check_labels(labels)
+
     if evidences:
         if np.isnan(results[0].log_bayes_factor):
             template = ' $\mathrm{{ln}}(Z)={lnz:1.3g}$'
@@ -1720,6 +1724,14 @@ def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0
     return fig, pvals
 
 
+def sanity_check_labels(labels):
+    """ Check labels for plotting to remove matplotlib errors """
+    for ii, lab in enumerate(labels):
+        if "_" in lab and "$" not in lab:
+            labels[ii] = lab.replace("_", "-")
+    return labels
+
+
 class ResultError(Exception):
     """ Base exception for all Result related errors """
 
diff --git a/test/result_test.py b/test/result_test.py
index b49f73435..7ef0c4047 100644
--- a/test/result_test.py
+++ b/test/result_test.py
@@ -722,5 +722,12 @@ class TestResultListError(unittest.TestCase):
             self.nested_results.combine()
 
 
+class TestMiscResults(unittest.TestCase):
+    def test_sanity_check_labels(self):
+        labels = ["a", "$a$", "a_1", "$a_1$"]
+        labels_checked = bilby.core.result.sanity_check_labels(labels)
+        self.assertEqual(labels_checked, ["a", "$a$", "a-1", "$a_1$"])
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab