From ab6bd0cc3fc69f1d80bff7ec47ce95664a7b60d5 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Fri, 8 Oct 2021 11:59:43 +0000
Subject: [PATCH] Resolve "Add caching of the
 `generate_posterior_sample_from_marginalized_likelihood` method"
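
Reconstructing the time/distance/phase posterior from a marginalised
likelihood can dominate post-processing for large sample sets, and a
failure there previously discarded the completed sampling. This patch
caches the reconstruction in blocks so that it can resume from the cache
on restart. Supporting changes:

- run_sampler stores the label and outdir on the likelihood so the cache
  file can be located, reuses a cached result without rerunning the
  sampler, and skips samples_to_posterior when a posterior already exists.
- The cached-result check compares the likelihood meta_data instead of
  the sampler kwargs.
- sampling_time survives JSON round-trips via a timedelta encoder and
  decoder.
- normalize_constraint_factor averages repeated estimates over larger
  sampling chunks and rounds to the precision supported by their scatter;
  prob and ln_prob now apply the constraint normalization factor.
- Dynesty's write_current_state returns early if the sampler is not yet
  initialized.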

---
 bilby/core/prior/dict.py           | 51 +++++++++++--------
 bilby/core/result.py               | 27 +++++------
 bilby/core/sampler/__init__.py     | 73 ++++++++++++++++------------
 bilby/core/sampler/base_sampler.py |  5 +-
 bilby/core/sampler/dynesty.py      |  5 ++
 bilby/core/utils/io.py             |  8 ++++
 bilby/gw/conversion.py             | 78 +++++++++++++++++++++++++++---
 7 files changed, 173 insertions(+), 74 deletions(-)

diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index 275965906..15e6199da 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -423,27 +423,38 @@ class PriorDict(dict):
             }
             return all_samples
 
-    def normalize_constraint_factor(self, keys):
+    def normalize_constraint_factor(self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10):
         if keys in self._cached_normalizations.keys():
             return self._cached_normalizations[keys]
         else:
-            min_accept = 1000
-            sampling_chunk = 5000
+            factor_estimates = [
+                self._estimate_normalization(keys, min_accept, sampling_chunk)
+                for _ in range(nrepeats)
+            ]
+            factor = np.mean(factor_estimates)
+            if np.std(factor_estimates) > 0:
+                decimals = int(-np.floor(np.log10(3 * np.std(factor_estimates))))
+                factor_rounded = np.round(factor, decimals)
+            else:
+                factor_rounded = factor
+            self._cached_normalizations[keys] = factor_rounded
+            return factor_rounded
+
+    def _estimate_normalization(self, keys, min_accept, sampling_chunk):
+        samples = self.sample_subset(keys=keys, size=sampling_chunk)
+        keep = np.atleast_1d(self.evaluate_constraints(samples))
+        if len(keep) == 1:
+            self._cached_normalizations[keys] = 1
+            return 1
+        all_samples = {key: np.array([]) for key in keys}
+        while np.count_nonzero(keep) < min_accept:
             samples = self.sample_subset(keys=keys, size=sampling_chunk)
-            keep = np.atleast_1d(self.evaluate_constraints(samples))
-            if len(keep) == 1:
-                self._cached_normalizations[keys] = 1
-                return 1
-            all_samples = {key: np.array([]) for key in keys}
-            while np.count_nonzero(keep) < min_accept:
-                samples = self.sample_subset(keys=keys, size=sampling_chunk)
-                for key in samples:
-                    all_samples[key] = np.hstack(
-                        [all_samples[key], samples[key].flatten()])
-                keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
-            factor = len(keep) / np.count_nonzero(keep)
-            self._cached_normalizations[keys] = factor
-            return factor
+            for key in samples:
+                all_samples[key] = np.hstack(
+                    [all_samples[key], samples[key].flatten()])
+            keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
+        factor = len(keep) / np.count_nonzero(keep)
+        return factor
 
     def prob(self, sample, **kwargs):
         """
@@ -468,11 +479,11 @@ class PriorDict(dict):
     def check_prob(self, sample, prob):
         ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         if np.all(prob == 0.):
-            return prob
+            return prob * ratio
         else:
             if isinstance(prob, float):
                 if self.evaluate_constraints(sample):
-                    return prob
+                    return prob * ratio
                 else:
                     return 0.
             else:
@@ -508,7 +519,7 @@ class PriorDict(dict):
         else:
             if isinstance(ln_prob, float):
                 if self.evaluate_constraints(sample):
-                    return ln_prob
+                    return ln_prob + np.log(ratio)
                 else:
                     return -np.inf
             else:
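
The normalization logic above, in isolation: the factor is the reciprocal
acceptance fraction of the constraint, averaged over nrepeats estimates and
rounded to the precision the scatter supports. A minimal standalone sketch
with illustrative numbers (not bilby output):

    import numpy as np

    # Four hypothetical estimates of the constraint normalization factor
    estimates = [1.502, 1.497, 1.505, 1.499]
    factor = np.mean(estimates)
    std = np.std(estimates)
    if std > 0:
        # Keep only digits stable to ~3 standard deviations of the estimates
        decimals = int(-np.floor(np.log10(3 * std)))
        factor = np.round(factor, decimals)
    print(factor)  # 1.501: a scatter of ~0.003 keeps three decimal places
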
diff --git a/bilby/core/result.py b/bilby/core/result.py
index dbe7e29ac..060ce6430 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -1429,20 +1429,19 @@ class Result(object):
             Function which adds in extra parameters to the data frame,
             should take the data_frame, likelihood and prior as arguments.
         """
-        try:
-            data_frame = self.posterior
-        except ValueError:
-            data_frame = pd.DataFrame(
-                self.samples, columns=self.search_parameter_keys)
-            data_frame = self._add_prior_fixed_values_to_posterior(
-                data_frame, priors)
-            data_frame['log_likelihood'] = getattr(
-                self, 'log_likelihood_evaluations', np.nan)
-            if self.log_prior_evaluations is None and priors is not None:
-                data_frame['log_prior'] = priors.ln_prob(
-                    dict(data_frame[self.search_parameter_keys]), axis=0)
-            else:
-                data_frame['log_prior'] = self.log_prior_evaluations
+
+        data_frame = pd.DataFrame(
+            self.samples, columns=self.search_parameter_keys)
+        data_frame = self._add_prior_fixed_values_to_posterior(
+            data_frame, priors)
+        data_frame['log_likelihood'] = getattr(
+            self, 'log_likelihood_evaluations', np.nan)
+        if self.log_prior_evaluations is None and priors is not None:
+            data_frame['log_prior'] = priors.ln_prob(
+                dict(data_frame[self.search_parameter_keys]), axis=0)
+        else:
+            data_frame['log_prior'] = self.log_prior_evaluations
+
         if conversion_function is not None:
             if "npool" in inspect.getargspec(conversion_function).args:
                 data_frame = conversion_function(data_frame, likelihood, priors, npool=npool)
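
For reference, a conversion_function compatible with the call above takes
the data frame, likelihood and priors, and may optionally accept npool
(detected via getargspec). A minimal sketch; the column names are
illustrative, not part of this patch:

    def example_conversion(data_frame, likelihood, priors, npool=1):
        # Hypothetical derived parameter appended to the posterior
        data_frame["total_mass"] = data_frame["mass_1"] + data_frame["mass_2"]
        return data_frame
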
diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index eb52cf289..77e481e95 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -173,7 +173,9 @@ def run_sampler(
     # Generate the meta-data if not given and append the likelihood meta_data
     if meta_data is None:
         meta_data = dict()
-    meta_data["likelihood"] = likelihood.meta_data
+    likelihood.label = label
+    likelihood.outdir = outdir
+    meta_data["likelihood"] = likelihood.meta_data
     meta_data["loaded_modules"] = loaded_modules_dict()
 
     if command_line_args.bilby_zero_likelihood_mode:
@@ -223,45 +225,52 @@ def run_sampler(
 
     if sampler.cached_result:
         logger.warning("Using cached result")
-        return sampler.cached_result
-
-    start_time = datetime.datetime.now()
-    if command_line_args.bilby_test_mode:
-        result = sampler._run_test()
+        result = sampler.cached_result
     else:
-        result = sampler.run_sampler()
-    end_time = datetime.datetime.now()
+        # Run the sampler
+        start_time = datetime.datetime.now()
+        if command_line_args.bilby_test_mode:
+            result = sampler._run_test()
+        else:
+            result = sampler.run_sampler()
+        end_time = datetime.datetime.now()
 
-    # Some samplers calculate the sampling time internally
-    if result.sampling_time is None:
-        result.sampling_time = end_time - start_time
-    elif isinstance(result.sampling_time, float):
-        result.sampling_time = datetime.timedelta(result.sampling_time)
-    logger.info("Sampling time: {}".format(result.sampling_time))
-    # Convert sampling time into seconds
-    result.sampling_time = result.sampling_time.total_seconds()
+        # Some samplers calculate the sampling time internally
+        if result.sampling_time is None:
+            result.sampling_time = end_time - start_time
+        elif isinstance(result.sampling_time, (float, int)):
+            result.sampling_time = datetime.timedelta(result.sampling_time)
 
-    if sampler.use_ratio:
-        result.log_noise_evidence = likelihood.noise_log_likelihood()
-        result.log_bayes_factor = result.log_evidence
-        result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
-    else:
-        result.log_noise_evidence = likelihood.noise_log_likelihood()
-        result.log_bayes_factor = result.log_evidence - result.log_noise_evidence
+        logger.info("Sampling time: {}".format(result.sampling_time))
+        # Convert sampling time into seconds
+        result.sampling_time = result.sampling_time.total_seconds()
 
-    # Initial save of the sampler in case of failure in post-processing
-    if save:
-        result.save_to_file(extension=save, gzip=gzip)
+        if sampler.use_ratio:
+            result.log_noise_evidence = likelihood.noise_log_likelihood()
+            result.log_bayes_factor = result.log_evidence
+            result.log_evidence = \
+                result.log_bayes_factor + result.log_noise_evidence
+        else:
+            result.log_noise_evidence = likelihood.noise_log_likelihood()
+            result.log_bayes_factor = \
+                result.log_evidence - result.log_noise_evidence
+
+        if None not in [result.injection_parameters, conversion_function]:
+            result.injection_parameters = conversion_function(
+                result.injection_parameters)
+
+        # Initial save of the result in case of failure in samples_to_posterior
+        if save:
+            result.save_to_file(extension=save, gzip=gzip)
 
     if None not in [result.injection_parameters, conversion_function]:
         result.injection_parameters = conversion_function(result.injection_parameters)
 
-    result.samples_to_posterior(
-        likelihood=likelihood,
-        priors=result.priors,
-        conversion_function=conversion_function,
-        npool=npool,
-    )
+    # Check if the posterior has already been created
+    if getattr(result, "_posterior", None) is None:
+        result.samples_to_posterior(likelihood=likelihood, priors=result.priors,
+                                    conversion_function=conversion_function,
+                                    npool=npool)
 
     if save:
         # The overwrite here ensures we overwrite the initially stored data
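
With these changes, rerunning an identical call reuses the cached result
and skips rebuilding the posterior when one already exists. A minimal
sketch of such a call, using a toy linear-model likelihood (names and
settings are illustrative):

    import numpy as np
    import bilby

    x = np.linspace(0, 1, 10)
    likelihood = bilby.core.likelihood.GaussianLikelihood(
        x=x, y=2 * x, func=lambda x, m: m * x, sigma=0.1
    )
    priors = dict(m=bilby.core.prior.Uniform(0, 4, "m"))
    # A second invocation with the same outdir/label should hit the cache
    result = bilby.run_sampler(
        likelihood, priors, sampler="dynesty", nlive=100,
        outdir="outdir", label="cache_demo",
    )
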
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 5e25a7e54..e590ee242 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -505,14 +505,15 @@ class Sampler(object):
 
         logger.debug("Checking cached data")
         if self.cached_result:
-            check_keys = ['search_parameter_keys', 'fixed_parameter_keys',
-                          'kwargs']
+            check_keys = ['search_parameter_keys', 'fixed_parameter_keys']
             use_cache = True
             for key in check_keys:
                 if self.cached_result._check_attribute_match_to_other_object(
                         key, self) is False:
                     logger.debug("Cached value {} is unmatched".format(key))
                     use_cache = False
+            if self.meta_data["likelihood"] != self.cached_result.meta_data["likelihood"]:
+                use_cache = False
             if use_cache is False:
                 self.cached_result = None
 
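
The likelihood check above is a plain dict comparison, so any changed
meta_data entry invalidates the cache. An illustrative sketch (the keys
are hypothetical):

    cached = {"waveform_arguments": {"reference_frequency": 50.0}}
    current = {"waveform_arguments": {"reference_frequency": 20.0}}
    use_cache = cached == current  # False: the cached result is discarded
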
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index bad9636be..f2d9051d5 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -610,6 +610,11 @@ class Dynesty(NestedSampler):
         from ... import __version__ as bilby_version
         from dynesty import __version__ as dynesty_version
         import dill
+
+        if getattr(self, "sampler", None) is None:
+            # Sampler not yet initialized; unable to write the current state
+            return
+
         check_directory_exists_and_if_not_mkdir(self.outdir)
         end_time = datetime.datetime.now()
         if hasattr(self, 'start_time'):
diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py
index 85095dc6a..880abbfe6 100644
--- a/bilby/core/utils/io.py
+++ b/bilby/core/utils/io.py
@@ -4,6 +4,7 @@ import os
 import shutil
 from importlib import import_module
 from pathlib import Path
+from datetime import timedelta
 
 import numpy as np
 import pandas as pd
@@ -76,6 +77,11 @@ class BilbyJsonEncoder(json.JSONEncoder):
                 "__module__": obj.__module__,
                 "__name__": obj.__name__,
             }
+        if isinstance(obj, timedelta):
+            return {
+                "__timedelta__": True,
+                "__total_seconds__": obj.total_seconds()
+            }
         return json.JSONEncoder.default(self, obj)
 
 
@@ -171,6 +178,8 @@ def decode_bilby_json(dct):
     if dct.get("__function__", False) or dct.get("__class__", False):
         default = ".".join([dct["__module__"], dct["__name__"]])
         return getattr(import_module(dct["__module__"]), dct["__name__"], default)
+    if dct.get("__timedelta__", False):
+        return timedelta(seconds=dct["__total_seconds__"])
     return dct
 
 
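
A round-trip sketch for the timedelta support added above (assumes bilby
with this patch applied):

    import json
    from datetime import timedelta

    from bilby.core.utils.io import BilbyJsonEncoder, decode_bilby_json

    blob = json.dumps({"sampling_time": timedelta(seconds=90.5)},
                      cls=BilbyJsonEncoder)
    restored = json.loads(blob, object_hook=decode_bilby_json)
    assert restored["sampling_time"] == timedelta(seconds=90.5)
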
diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 758b0fdbe..492a6201f 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -1,11 +1,13 @@
+import os
 import sys
 import multiprocessing
+import pickle
 
 import numpy as np
 from pandas import DataFrame
 
 from ..core.likelihood import MarginalizedLikelihoodReconstructionError
-from ..core.utils import logger, solar_mass
+from ..core.utils import logger, solar_mass, command_line_args
 from ..core.prior import DeltaFunction
 from .utils import lalsim_SimInspiralTransformPrecessingNewInitialConditions
 from .eos.eos import SpectralDecompositionEOS, EOSFamily, IntegrateTOV
@@ -1201,7 +1203,7 @@ def _compute_snrs(args):
 
 
 def generate_posterior_samples_from_marginalized_likelihood(
-        samples, likelihood, npool=1):
+        samples, likelihood, npool=1, block=10, use_cache=True):
     """
     Reconstruct the distance posterior from a run which used a likelihood which
     explicitly marginalised over time/distance/phase.
@@ -1216,6 +1218,11 @@ def generate_posterior_samples_from_marginalized_likelihood(
         Likelihood used during sampling.
     npool: int, (default=1)
         If given, perform generation (where possible) using a multiprocessing pool
+    block: int, (default=10)
+        Number of samples to process per block; the cache is written after
+        each block
+    use_cache: bool, (default=True)
+        If True, cache the reconstruction so it can resume on restart.
 
     Returns
     =======
@@ -1237,23 +1244,82 @@ def generate_posterior_samples_from_marginalized_likelihood(
 
     logger.info('Reconstructing marginalised parameters.')
 
-    fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()]
+    try:
+        cache_filename = f"{likelihood.outdir}/.{likelihood.label}_generate_posterior_cache.pickle"
+    except AttributeError:
+        logger.warning("Likelihood has no outdir and label attribute: caching disabled")
+        use_cache = False
+
+    if use_cache and os.path.exists(cache_filename) and not command_line_args.clean:
+        with open(cache_filename, "rb") as f:
+            cached_samples_dict = pickle.load(f)
+
+        # Check the samples are identical between the cache and current
+        if cached_samples_dict["_samples"].equals(samples):
+            # Calculate reconstruction percentage and print a log message
+            nsamples_converted = np.sum(
+                [len(val) for key, val in cached_samples_dict.items() if key != "_samples"]
+            )
+            perc = 100 * nsamples_converted / len(cached_samples_dict["_samples"])
+            logger.info(f'Using cached reconstruction with {perc:0.1f}% converted.')
+        else:
+            logger.info("Cached samples dict out of date, ignoring")
+            cached_samples_dict = dict(_samples=samples)
+
+    else:
+        # Initialize cache dict
+        cached_samples_dict = dict()
+
+        # Store samples to convert for checking
+        cached_samples_dict["_samples"] = samples
+
+    # Set up the multiprocessing
     if npool > 1:
         pool = multiprocessing.Pool(processes=npool)
         logger.info(
             "Using a pool with size {} for nsamples={}"
             .format(npool, len(samples))
         )
-        new_samples = np.array(pool.map(fill_sample, tqdm(fill_args, file=sys.stdout)))
-        pool.close()
     else:
-        new_samples = np.array([fill_sample(xx) for xx in tqdm(fill_args, file=sys.stdout)])
+        pool = None
+
+    fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()]
+    ii = 0
+    pbar = tqdm(total=len(samples), file=sys.stdout)
+    while ii < len(samples):
+        if ii in cached_samples_dict:
+            ii += block
+            pbar.update(block)
+            continue
+
+        if pool is not None:
+            subset_samples = pool.map(fill_sample, fill_args[ii: ii + block])
+        else:
+            subset_samples = [list(fill_sample(xx)) for xx in fill_args[ii: ii + block]]
+
+        cached_samples_dict[ii] = subset_samples
+
+        if use_cache:
+            with open(cache_filename, "wb") as f:
+                pickle.dump(cached_samples_dict, f)
+
+        ii += block
+        pbar.update(len(subset_samples))
+    pbar.close()
+
+    if pool is not None:
+        pool.close()
+
+    new_samples = np.concatenate(
+        [np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"]
+    )
 
     samples['geocent_time'] = new_samples[:, 0]
     samples['luminosity_distance'] = new_samples[:, 1]
     samples['phase'] = new_samples[:, 2]
     if likelihood.calibration_marginalization:
         samples['recalib_index'] = new_samples[:, 3]
+
     return samples
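
The cache written above maps each block's starting index to the list of
reconstructed rows ([geocent_time, luminosity_distance, phase], plus
recalib_index under calibration marginalization), with the input samples
stored under "_samples" for staleness checks. An in-memory sketch of the
layout and the progress computation (values are illustrative):

    from pandas import DataFrame

    samples = DataFrame({"a": range(4)})       # stand-in for the posterior
    cache = {"_samples": samples,
             0: [[1.2e9, 400.0, 1.3]] * 2}     # one partially filled block
    done = sum(len(v) for k, v in cache.items() if k != "_samples")
    print(f"{100 * done / len(cache['_samples']):0.1f}% converted")  # 50.0%
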
 
 
-- 
GitLab