From 1d92131ca88489906e98cf541b17bd217246029d Mon Sep 17 00:00:00 2001
From: John Veitch <john.veitch@ligo.org>
Date: Thu, 16 Mar 2023 11:26:27 +0000
Subject: [PATCH] Ad the progress bar back in when computing LogLs

---
 bilby/core/result.py | 44 ++++++++++++++++++++++++++++----------------
 1 file changed, 28 insertions(+), 16 deletions(-)

diff --git a/bilby/core/result.py b/bilby/core/result.py
index de8a965e0..97a42a267 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -7,7 +7,7 @@ from copy import copy
 from importlib import import_module
 from itertools import product
 import multiprocessing
-
+from functools import partial
 import numpy as np
 import pandas as pd
 import scipy.stats
@@ -150,23 +150,35 @@ def get_weights_for_reweighting(
         starting_index = 0
 
     dict_samples = [{key: sample[key] for key in result.posterior} for i,sample in result.posterior.iterrows()]
+    n = len(dict_samples) - starting_index
+
+    # Helper function to compute likelihoods in parallel
+    def eval_pool(l):
+        with multiprocessing.Pool(processes=npool) as pool:
+            chunksize = max(100,n//(2*npool))
+            return list(tqdm(
+                    pool.imap(partial(__eval_l,l),
+                            dict_samples[starting_index:], chunksize=chunksize),
+                    desc = 'Computing likelihoods',
+                    total = n
+                ))
+
+    if old_likelihood is None:
+        old_log_likelihood_array[starting_index:] = sample["log_likelihood"]
+    else:
+        old_log_likelihood_array[starting_index:] = eval_pool(old_likelihood)
 
-    with multiprocessing.Pool(processes=npool) as pool:
-        logger.info(
-            "Using a pool with size {} for nsamples={}".format(npool, len(dict_samples))
-        )
-
-        if old_likelihood is not None:
-            old_log_likelihood_array[starting_index:] = pool.starmap(__eval_l, [(old_likelihood, s) for s in dict_samples[starting_index:]])
-        else:
-            old_log_likelihood_array[starting_index:] = sample["log_likelihood"]
-        if new_likelihood is not None:
-            new_log_likelihood_array[starting_index:] = pool.starmap(__eval_l, [(new_likelihood, s) for s in dict_samples[starting_index:]])
-        else:
-            # Don't perform likelihood reweighting (i.e. likelihood isn't updated)
-            new_log_likelihood_array[starting_index:] = old_log_likelihood_array[starting_index:]
+    if new_likelihood is None:
+        # Don't perform likelihood reweighting (i.e. likelihood isn't updated)
+        new_log_likelihood_array[starting_index:] = old_log_likelihood_array[starting_index:]
+    else:
+        new_log_likelihood_array[starting_index:] = eval_pool(new_likelihood)
 
-    for ii, sample in enumerate(tqdm(dict_samples[starting_index:]), start = starting_index):
+    # Compute priors
+    for ii, sample in enumerate(tqdm(dict_samples[starting_index:],
+                                     desc = 'Computing priors',
+                                     total = n),
+                                start = starting_index):
         if old_prior is not None:
             old_log_prior_array[ii] = old_prior.ln_prob(dict_samples[ii])
         else:
-- 
GitLab