From ab6bd0cc3fc69f1d80bff7ec47ce95664a7b60d5 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Fri, 8 Oct 2021 11:59:43 +0000
Subject: [PATCH] Resolve "Add caching of the `generate_posterior_sample_from_marginalized_likelihood` method"

---
 bilby/core/prior/dict.py           | 51 +++++++++++--------
 bilby/core/result.py               | 27 +++++------
 bilby/core/sampler/__init__.py     | 73 ++++++++++++++++------------
 bilby/core/sampler/base_sampler.py |  5 +-
 bilby/core/sampler/dynesty.py      |  5 ++
 bilby/core/utils/io.py             |  9 ++++
 bilby/gw/conversion.py             | 78 +++++++++++++++++++++++++++---
 7 files changed, 174 insertions(+), 74 deletions(-)

diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index 275965906..15e6199da 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -423,27 +423,38 @@ class PriorDict(dict):
             }
         return all_samples
 
-    def normalize_constraint_factor(self, keys):
+    def normalize_constraint_factor(self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10):
         if keys in self._cached_normalizations.keys():
             return self._cached_normalizations[keys]
         else:
-            min_accept = 1000
-            sampling_chunk = 5000
+            factor_estimates = [
+                self._estimate_normalization(keys, min_accept, sampling_chunk)
+                for _ in range(nrepeats)
+            ]
+            factor = np.mean(factor_estimates)
+            if np.std(factor_estimates) > 0:
+                decimals = int(-np.floor(np.log10(3 * np.std(factor_estimates))))
+                factor_rounded = np.round(factor, decimals)
+            else:
+                factor_rounded = factor
+            self._cached_normalizations[keys] = factor_rounded
+            return factor_rounded
+
+    def _estimate_normalization(self, keys, min_accept, sampling_chunk):
+        samples = self.sample_subset(keys=keys, size=sampling_chunk)
+        keep = np.atleast_1d(self.evaluate_constraints(samples))
+        if len(keep) == 1:
+            self._cached_normalizations[keys] = 1
+            return 1
+        all_samples = {key: np.array([]) for key in keys}
+        while np.count_nonzero(keep) < min_accept:
             samples = self.sample_subset(keys=keys, size=sampling_chunk)
-            keep = np.atleast_1d(self.evaluate_constraints(samples))
-            if len(keep) == 1:
-                self._cached_normalizations[keys] = 1
-                return 1
-            all_samples = {key: np.array([]) for key in keys}
-            while np.count_nonzero(keep) < min_accept:
-                samples = self.sample_subset(keys=keys, size=sampling_chunk)
-                for key in samples:
-                    all_samples[key] = np.hstack(
-                        [all_samples[key], samples[key].flatten()])
-                keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
-            factor = len(keep) / np.count_nonzero(keep)
-            self._cached_normalizations[keys] = factor
-            return factor
+            for key in samples:
+                all_samples[key] = np.hstack(
+                    [all_samples[key], samples[key].flatten()])
+            keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
+        factor = len(keep) / np.count_nonzero(keep)
+        return factor
 
     def prob(self, sample, **kwargs):
         """
@@ -468,11 +479,11 @@ class PriorDict(dict):
     def check_prob(self, sample, prob):
         ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         if np.all(prob == 0.):
-            return prob
+            return prob * ratio
         else:
             if isinstance(prob, float):
                 if self.evaluate_constraints(sample):
-                    return prob
+                    return prob * ratio
                 else:
                     return 0.
             else:
@@ -508,7 +519,7 @@ class PriorDict(dict):
         else:
             if isinstance(ln_prob, float):
                 if self.evaluate_constraints(sample):
-                    return ln_prob
+                    return ln_prob + np.log(ratio)
                 else:
                     return -np.inf
             else:
diff --git a/bilby/core/result.py b/bilby/core/result.py
index dbe7e29ac..060ce6430 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -1429,20 +1429,19 @@ class Result(object):
             Function which adds in extra parameters to the data frame, should
             take the data_frame, likelihood and prior as arguments.
         """
-        try:
-            data_frame = self.posterior
-        except ValueError:
-            data_frame = pd.DataFrame(
-                self.samples, columns=self.search_parameter_keys)
-            data_frame = self._add_prior_fixed_values_to_posterior(
-                data_frame, priors)
-            data_frame['log_likelihood'] = getattr(
-                self, 'log_likelihood_evaluations', np.nan)
-            if self.log_prior_evaluations is None and priors is not None:
-                data_frame['log_prior'] = priors.ln_prob(
-                    dict(data_frame[self.search_parameter_keys]), axis=0)
-            else:
-                data_frame['log_prior'] = self.log_prior_evaluations
+
+        data_frame = pd.DataFrame(
+            self.samples, columns=self.search_parameter_keys)
+        data_frame = self._add_prior_fixed_values_to_posterior(
+            data_frame, priors)
+        data_frame['log_likelihood'] = getattr(
+            self, 'log_likelihood_evaluations', np.nan)
+        if self.log_prior_evaluations is None and priors is not None:
+            data_frame['log_prior'] = priors.ln_prob(
+                dict(data_frame[self.search_parameter_keys]), axis=0)
+        else:
+            data_frame['log_prior'] = self.log_prior_evaluations
+
         if conversion_function is not None:
             if "npool" in inspect.getargspec(conversion_function).args:
                 data_frame = conversion_function(data_frame, likelihood, priors, npool=npool)
diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index eb52cf289..77e481e95 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -173,7 +173,9 @@ def run_sampler(
     # Generate the meta-data if not given and append the likelihood meta_data
     if meta_data is None:
         meta_data = dict()
-    meta_data["likelihood"] = likelihood.meta_data
+    likelihood.label = label
+    likelihood.outdir = outdir
+    meta_data['likelihood'] = likelihood.meta_data
     meta_data["loaded_modules"] = loaded_modules_dict()
 
     if command_line_args.bilby_zero_likelihood_mode:
@@ -223,45 +225,52 @@ def run_sampler(
 
     if sampler.cached_result:
         logger.warning("Using cached result")
-        return sampler.cached_result
-
-    start_time = datetime.datetime.now()
-    if command_line_args.bilby_test_mode:
-        result = sampler._run_test()
+        result = sampler.cached_result
     else:
-        result = sampler.run_sampler()
-    end_time = datetime.datetime.now()
+        # Run the sampler
+        start_time = datetime.datetime.now()
+        if command_line_args.bilby_test_mode:
+            result = sampler._run_test()
+        else:
+            result = sampler.run_sampler()
+        end_time = datetime.datetime.now()
 
-    # Some samplers calculate the sampling time internally
-    if result.sampling_time is None:
-        result.sampling_time = end_time - start_time
-    elif isinstance(result.sampling_time, float):
-        result.sampling_time = datetime.timedelta(result.sampling_time)
-    logger.info("Sampling time: {}".format(result.sampling_time))
-    # Convert sampling time into seconds
-    result.sampling_time = result.sampling_time.total_seconds()
+        # Some samplers calculate the sampling time internally
+        if result.sampling_time is None:
+            result.sampling_time = end_time - start_time
+        elif isinstance(result.sampling_time, (float, int)):
+            result.sampling_time = datetime.timedelta(result.sampling_time)
 
-    if sampler.use_ratio:
-        result.log_noise_evidence = likelihood.noise_log_likelihood()
-        result.log_bayes_factor = result.log_evidence
-        result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
-    else:
-        result.log_noise_evidence = likelihood.noise_log_likelihood()
-        result.log_bayes_factor = result.log_evidence - result.log_noise_evidence
+        logger.info('Sampling time: {}'.format(result.sampling_time))
+        # Convert sampling time into seconds
+        result.sampling_time = result.sampling_time.total_seconds()
 
-    # Initial save of the sampler in case of failure in post-processing
-    if save:
-        result.save_to_file(extension=save, gzip=gzip)
+        if sampler.use_ratio:
+            result.log_noise_evidence = likelihood.noise_log_likelihood()
+            result.log_bayes_factor = result.log_evidence
+            result.log_evidence = \
+                result.log_bayes_factor + result.log_noise_evidence
+        else:
+            result.log_noise_evidence = likelihood.noise_log_likelihood()
+            result.log_bayes_factor = \
+                result.log_evidence - result.log_noise_evidence
+
+        if None not in [result.injection_parameters, conversion_function]:
+            result.injection_parameters = conversion_function(
+                result.injection_parameters)
+
+        # Initial save of the sampler in case of failure in samples_to_posterior
+        if save:
+            result.save_to_file(extension=save, gzip=gzip)
 
     if None not in [result.injection_parameters, conversion_function]:
         result.injection_parameters = conversion_function(result.injection_parameters)
 
-    result.samples_to_posterior(
-        likelihood=likelihood,
-        priors=result.priors,
-        conversion_function=conversion_function,
-        npool=npool,
-    )
+    # Check if the posterior has already been created
+    if getattr(result, "_posterior", None) is None:
+        result.samples_to_posterior(likelihood=likelihood, priors=result.priors,
+                                    conversion_function=conversion_function,
+                                    npool=npool)
 
     if save:
         # The overwrite here ensures we overwrite the initially stored data
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 5e25a7e54..e590ee242 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -505,14 +505,15 @@ class Sampler(object):
 
         logger.debug("Checking cached data")
         if self.cached_result:
-            check_keys = ['search_parameter_keys', 'fixed_parameter_keys',
-                          'kwargs']
+            check_keys = ['search_parameter_keys', 'fixed_parameter_keys']
             use_cache = True
             for key in check_keys:
                 if self.cached_result._check_attribute_match_to_other_object(
                         key, self) is False:
                     logger.debug("Cached value {} is unmatched".format(key))
                     use_cache = False
+            if self.meta_data["likelihood"] != self.cached_result.meta_data["likelihood"]:
+                use_cache = False
             if use_cache is False:
                 self.cached_result = None
 
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index bad9636be..f2d9051d5 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -610,6 +610,11 @@ class Dynesty(NestedSampler):
         from ... import __version__ as bilby_version
         from dynesty import __version__ as dynesty_version
         import dill
+
+        if getattr(self, "sampler", None) is None:
+            # Sampler not initialized, not able to write current state
+            return
+
         check_directory_exists_and_if_not_mkdir(self.outdir)
         end_time = datetime.datetime.now()
         if hasattr(self, 'start_time'):
diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py
index 85095dc6a..880abbfe6 100644
--- a/bilby/core/utils/io.py
+++ b/bilby/core/utils/io.py
@@ -4,6 +4,7 @@ import os
 import shutil
 from importlib import import_module
 from pathlib import Path
+from datetime import timedelta
 
 import numpy as np
 import pandas as pd
@@ -76,6 +77,12 @@ class BilbyJsonEncoder(json.JSONEncoder):
                 "__module__": obj.__module__,
                 "__name__": obj.__name__,
             }
+        if isinstance(obj, (timedelta)):
+            return {
+                "__timedelta__": True,
+                "__total_seconds__": obj.total_seconds()
+            }
+            return obj.isoformat()
         return json.JSONEncoder.default(self, obj)
 
 
@@ -171,6 +178,8 @@ def decode_bilby_json(dct):
     if dct.get("__function__", False) or dct.get("__class__", False):
         default = ".".join([dct["__module__"], dct["__name__"]])
         return getattr(import_module(dct["__module__"]), dct["__name__"], default)
+    if dct.get("__timedelta__", False):
+        return timedelta(seconds=dct["__total_seconds__"])
     return dct
 
 
diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 758b0fdbe..492a6201f 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -1,11 +1,13 @@
+import os
 import sys
 import multiprocessing
+import pickle
 
 import numpy as np
 from pandas import DataFrame
 
 from ..core.likelihood import MarginalizedLikelihoodReconstructionError
-from ..core.utils import logger, solar_mass
+from ..core.utils import logger, solar_mass, command_line_args
 from ..core.prior import DeltaFunction
 from .utils import lalsim_SimInspiralTransformPrecessingNewInitialConditions
 from .eos.eos import SpectralDecompositionEOS, EOSFamily, IntegrateTOV
@@ -1201,7 +1203,7 @@ def _compute_snrs(args):
 
 
 def generate_posterior_samples_from_marginalized_likelihood(
-        samples, likelihood, npool=1):
+        samples, likelihood, npool=1, block=10, use_cache=True):
     """
     Reconstruct the distance posterior from a run which used a likelihood
     which explicitly marginalised over time/distance/phase.
@@ -1216,6 +1218,11 @@ def generate_posterior_samples_from_marginalized_likelihood(
         Likelihood used during sampling.
     npool: int, (default=1)
        If given, perform generation (where possible) using a multiprocessing pool
+    block: int, (default=10)
+        Size of the blocks to use in multiprocessing
+    use_cache: bool, (default=True)
+        If true, cache the generation so that reconstruction can begin from the
+        cache on restart.
 
     Returns
     =======
@@ -1237,23 +1244,82 @@ def generate_posterior_samples_from_marginalized_likelihood(
 
     logger.info('Reconstructing marginalised parameters.')
 
-    fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()]
+    try:
+        cache_filename = f"{likelihood.outdir}/.{likelihood.label}_generate_posterior_cache.pickle"
+    except AttributeError:
+        logger.warning("Likelihood has no outdir and label attribute: caching disabled")
+        use_cache = False
+
+    if use_cache and os.path.exists(cache_filename) and not command_line_args.clean:
+        with open(cache_filename, "rb") as f:
+            cached_samples_dict = pickle.load(f)
+
+        # Check the samples are identical between the cache and current
+        if cached_samples_dict["_samples"].equals(samples):
+            # Calculate reconstruction percentage and print a log message
+            nsamples_converted = np.sum(
+                [len(val) for key, val in cached_samples_dict.items() if key != "_samples"]
+            )
+            perc = 100 * nsamples_converted / len(cached_samples_dict["_samples"])
+            logger.info(f'Using cached reconstruction with {perc:0.1f}% converted.')
+        else:
+            logger.info("Cached samples dict out of date, ignoring")
+            cached_samples_dict = dict(_samples=samples)
+
+    else:
+        # Initialize cache dict
+        cached_samples_dict = dict()
+
+        # Store samples to convert for checking
+        cached_samples_dict["_samples"] = samples
+
+    # Set up the multiprocessing
     if npool > 1:
         pool = multiprocessing.Pool(processes=npool)
         logger.info(
             "Using a pool with size {} for nsamples={}"
             .format(npool, len(samples))
         )
-        new_samples = np.array(pool.map(fill_sample, tqdm(fill_args, file=sys.stdout)))
-        pool.close()
     else:
-        new_samples = np.array([fill_sample(xx) for xx in tqdm(fill_args, file=sys.stdout)])
+        pool = None
+
+    fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()]
+    ii = 0
+    pbar = tqdm(total=len(samples), file=sys.stdout)
+    while ii < len(samples):
+        if ii in cached_samples_dict:
+            ii += block
+            pbar.update(block)
+            continue
+
+        if pool is not None:
+            subset_samples = pool.map(fill_sample, fill_args[ii: ii + block])
+        else:
+            subset_samples = [list(fill_sample(xx)) for xx in fill_args[ii: ii + block]]
+
+        cached_samples_dict[ii] = subset_samples
+
+        if use_cache:
+            with open(cache_filename, "wb") as f:
+                pickle.dump(cached_samples_dict, f)
+
+        ii += block
+        pbar.update(len(subset_samples))
+    pbar.close()
+
+    if pool is not None:
+        pool.close()
+
+    new_samples = np.concatenate(
+        [np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"]
+    )
 
     samples['geocent_time'] = new_samples[:, 0]
     samples['luminosity_distance'] = new_samples[:, 1]
     samples['phase'] = new_samples[:, 2]
     if likelihood.calibration_marginalization:
         samples['recalib_index'] = new_samples[:, 3]
+
     return samples
 
 
-- 
GitLab
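
A note on the io.py hunks above: because a cached run can now carry sampling_time as a datetime.timedelta, the JSON encoder and decoder gain a matching round-trip for that type. The standalone sketch below only mirrors those two branches for illustration; TimedeltaEncoder and decode_timedelta are hypothetical stand-ins written here, not bilby API.

# Illustrative sketch of the timedelta round-trip that the patch adds to
# BilbyJsonEncoder and decode_bilby_json; this does not import bilby itself.
import json
from datetime import timedelta


class TimedeltaEncoder(json.JSONEncoder):
    def default(self, obj):
        # Same representation the patch writes: a tagged dict of total seconds
        if isinstance(obj, timedelta):
            return {"__timedelta__": True, "__total_seconds__": obj.total_seconds()}
        return json.JSONEncoder.default(self, obj)


def decode_timedelta(dct):
    # Same check the patch adds to decode_bilby_json
    if dct.get("__timedelta__", False):
        return timedelta(seconds=dct["__total_seconds__"])
    return dct


encoded = json.dumps({"sampling_time": timedelta(seconds=90.5)}, cls=TimedeltaEncoder)
decoded = json.loads(encoded, object_hook=decode_timedelta)
assert decoded["sampling_time"] == timedelta(seconds=90.5)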
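
For the conversion.py changes, the reconstruction cache is a pickle file named after the likelihood's outdir and label, holding the input samples under the "_samples" key plus one list of converted rows per block start index. The minimal sketch below inspects such a file under assumed outdir/label values (placeholders, not part of the patch) and mirrors the percentage message the patched function logs:

# Illustrative only: reads the cache written block-by-block by the patched
# generate_posterior_samples_from_marginalized_likelihood.
import pickle

outdir, label = "outdir", "my_run"  # hypothetical values for likelihood.outdir/label
cache_filename = f"{outdir}/.{label}_generate_posterior_cache.pickle"

with open(cache_filename, "rb") as f:
    cached_samples_dict = pickle.load(f)

total = len(cached_samples_dict["_samples"])
converted = sum(
    len(val) for key, val in cached_samples_dict.items() if key != "_samples"
)
print(f"Reconstruction progress: {100 * converted / total:0.1f}% of {total} samples")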