From 6100eae9e93a5b8df69276566d5114cc89c44890 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Wed, 23 Dec 2020 00:17:40 -0600
Subject: [PATCH] Resolve "plot_waveform_posterior uses non representative
 samples"

---
 bilby/gw/result.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/bilby/gw/result.py b/bilby/gw/result.py
index 03f5d28d2..27f01c6fd 100644
--- a/bilby/gw/result.py
+++ b/bilby/gw/result.py
@@ -337,7 +337,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
             interferometer.name))
 
         if n_samples is None:
-            n_samples = len(self.posterior)
+            samples = self.posterior
         elif n_samples > len(self.posterior):
             logger.debug(
                 "Requested more waveform samples ({}) than we have "
@@ -345,14 +345,16 @@ class CompactBinaryCoalescenceResult(CoreResult):
                     n_samples, len(self.posterior)
                 )
             )
-            n_samples = len(self.posterior)
+            samples = self.posterior
+        else:
+            samples = self.posterior.sample(n_samples, replace=False)
 
         if start_time is None:
             start_time = - 0.4
-        start_time = np.mean(self.posterior.geocent_time) + start_time
+        start_time = np.mean(samples.geocent_time) + start_time
         if end_time is None:
             end_time = 0.2
-        end_time = np.mean(self.posterior.geocent_time) + end_time
+        end_time = np.mean(samples.geocent_time) + end_time
         if format == "html":
             start_time = - np.inf
             end_time = np.inf
@@ -470,8 +472,8 @@ class CompactBinaryCoalescenceResult(CoreResult):
 
         fd_waveforms = list()
         td_waveforms = list()
-        for ii in range(n_samples):
-            params = dict(self.posterior.iloc[ii])
+        for _, params in samples.iterrows():
+            params = dict(params)
             wf_pols = waveform_generator.frequency_domain_strain(params)
             fd_waveform = interferometer.get_detector_response(wf_pols, params)
             fd_waveforms.append(fd_waveform[frequency_idxs])
-- 
GitLab