From 46f1403da2ec4620e843360a4e0335868be017e9 Mon Sep 17 00:00:00 2001 From: John Veitch <john.veitch@ligo.org> Date: Thu, 16 Mar 2023 02:16:16 +0000 Subject: [PATCH] Allow de-marginalisation of calibration along with time --- bilby/core/result.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index 7c8715594..d5584e527 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -6,6 +6,8 @@ from collections import namedtuple from copy import copy from importlib import import_module from itertools import product +import multiprocessing +from functools import partial import numpy as np import pandas as pd @@ -29,6 +31,9 @@ from .prior import Prior, PriorDict, DeltaFunction, ConditionalDeltaFunction EXTENSIONS = ["json", "hdf5", "h5", "pickle", "pkl"] +def __eval_l(l, p): + l.parameters.update(p) + return l.log_likelihood() def result_file_name(outdir, label, extension='json', gzip=False): """ Returns the standard filename used for a result file @@ -104,7 +109,7 @@ def read_in_result(filename=None, outdir=None, label=None, extension='json', gzi def get_weights_for_reweighting( result, new_likelihood=None, new_prior=None, old_likelihood=None, - old_prior=None, resume_file=None, n_checkpoint=5000): + old_prior=None, resume_file=None, n_checkpoint=5000, npool=1): """ Calculate the weights for reweight() See bilby.core.result.reweight() for help with the inputs @@ -145,30 +150,31 @@ def get_weights_for_reweighting( starting_index = 0 - for ii, sample in tqdm(result.posterior.iloc[starting_index:].iterrows()): - # Convert sample to dictionary - par_sample = {key: sample[key] for key in result.posterior} + dict_samples = [{key: sample[key] for key in result.posterior} for i,sample in result.posterior.iterrows()] + + with multiprocessing.Pool(processes=npool) as pool: + logger.info( + "Using a pool with size {} for nsamples={}".format(npool, len(dict_samples)) + ) if old_likelihood is not None: - old_likelihood.parameters.update(par_sample) - old_log_likelihood_array[ii] = old_likelihood.log_likelihood() + old_log_likelihood_array[starting_index:] = pool.starmap(__eval_l, [(old_likelihood, s) for s in dict_samples[starting_index:]]) else: - old_log_likelihood_array[ii] = sample["log_likelihood"] - + old_log_likelihood_array[starting_index:] = sample["log_likelihood"] if new_likelihood is not None: - new_likelihood.parameters.update(par_sample) - new_log_likelihood_array[ii] = new_likelihood.log_likelihood() + new_log_likelihood_array[starting_index:] = pool.starmap(__eval_l, [(new_likelihood, s) for s in dict_samples[starting_index:]]) else: # Don't perform likelihood reweighting (i.e. likelihood isn't updated) - new_log_likelihood_array[ii] = old_log_likelihood_array[ii] + new_log_likelihood_array[starting_index:] = old_log_likelihood_array[starting_index:] + for ii, sample in enumerate(tqdm(dict_samples[starting_index:]), start = starting_index): if old_prior is not None: - old_log_prior_array[ii] = old_prior.ln_prob(par_sample) + old_log_prior_array[ii] = old_prior.ln_prob(dict_samples[ii]) else: old_log_prior_array[ii] = sample["log_prior"] if new_prior is not None: - new_log_prior_array[ii] = new_prior.ln_prob(par_sample) + new_log_prior_array[ii] = new_prior.ln_prob(dict_samples[ii]) else: # Don't perform prior reweighting (i.e. prior isn't updated) new_log_prior_array[ii] = old_log_prior_array[ii] @@ -272,7 +278,7 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None, get_weights_for_reweighting( result, new_likelihood=new_likelihood, new_prior=new_prior, old_likelihood=old_likelihood, old_prior=old_prior, - resume_file=resume_file, n_checkpoint=n_checkpoint) + resume_file=resume_file, n_checkpoint=n_checkpoint, npool=npool) if use_nested_samples: ln_weights += np.log(result.posterior["weights"]) @@ -300,6 +306,7 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None, if conversion_function is not None: data_frame = result.posterior if "npool" in inspect.signature(conversion_function).parameters: + logger.info(f"Convertng with {npool} threads") data_frame = conversion_function(data_frame, new_likelihood, new_prior, npool=npool) else: data_frame = conversion_function(data_frame, new_likelihood, new_prior) -- GitLab