From bb076348e315e5c40d80de8a59e2624109f43cf6 Mon Sep 17 00:00:00 2001
From: Rhiannon Udall <rhiannon.udall@ligo.org>
Date: Thu, 27 Jul 2023 14:21:12 +0000
Subject: [PATCH] DEV: Add linestyle option for plot_multiple

---
 bilby/core/result.py | 30 ++++++++++++++++++++++++------
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/bilby/core/result.py b/bilby/core/result.py
index 7bfe9d585..eafb97814 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -1177,22 +1177,33 @@ class Result(object):
         if utils.command_line_args.bilby_test_mode:
             return
 
-        # bilby default corner kwargs. Overwritten by anything passed to kwargs
         defaults_kwargs = dict(
-            bins=50, smooth=0.9, label_kwargs=dict(fontsize=16),
+            bins=50, smooth=0.9,
             title_kwargs=dict(fontsize=16), color='#0072C1',
             truth_color='tab:orange', quantiles=[0.16, 0.84],
             levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)),
             plot_density=False, plot_datapoints=True, fill_contours=True,
-            max_n_ticks=3, hist_kwargs=dict(density=True))
+            max_n_ticks=3)
 
         if 'lionize' in kwargs and kwargs['lionize'] is True:
             defaults_kwargs['truth_color'] = 'tab:blue'
             defaults_kwargs['color'] = '#FF8C00'
 
+        label_kwargs_defaults = dict(fontsize=16)
+        hist_kwargs_defaults = dict(density=True)
+
+        label_kwargs_input = kwargs.get("label_kwargs", dict())
+        hist_kwargs_input = kwargs.get("hist_kwargs", dict())
+
+        label_kwargs_defaults.update(label_kwargs_input)
+        hist_kwargs_defaults.update(hist_kwargs_input)
+
         defaults_kwargs.update(kwargs)
         kwargs = defaults_kwargs
 
+        kwargs["label_kwargs"] = label_kwargs_defaults
+        kwargs["hist_kwargs"] = hist_kwargs_defaults
+
         # Handle if truths was passed in
         if 'truth' in kwargs:
             kwargs['truths'] = kwargs.pop('truth')
@@ -1967,7 +1978,8 @@ class ResultList(list):
 
 @latex_plot_format
 def plot_multiple(results, filename=None, labels=None, colours=None,
-                  save=True, evidences=False, corner_labels=None, **kwargs):
+                  save=True, evidences=False, corner_labels=None, linestyles=None,
+                  **kwargs):
     """ Generate a corner plot overlaying two sets of results
 
     Parameters
@@ -2021,11 +2033,17 @@ def plot_multiple(results, filename=None, labels=None, colours=None,
             c = colours[i]
         else:
             c = 'C{}'.format(i)
+        if linestyles is not None:
+            linestyle = linestyles[i]
+        else:
+            linestyle = 'solid'
         hist_kwargs = kwargs.get('hist_kwargs', dict())
         hist_kwargs['color'] = c
-        fig = result.plot_corner(fig=fig, save=False, color=c, **kwargs)
+        hist_kwargs["linestyle"] = linestyle
+        kwargs["hist_kwargs"] = hist_kwargs
+        fig = result.plot_corner(fig=fig, save=False, color=c, contour_kwargs={"linestyle": linestyle}, **kwargs)
         default_filename += '_{}'.format(result.label)
-        lines.append(mpllines.Line2D([0], [0], color=c))
+        lines.append(mpllines.Line2D([0], [0], color=c, linestyle=linestyle))
         default_labels.append(result.label)
 
     # Rescale the axes
-- 
GitLab