From 346566dc423c9f64695b0993e910499abfa25915 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Tue, 13 Dec 2022 22:48:14 +0000
Subject: [PATCH] MAINT: enforce safe_file_dump

---
 bilby/core/result.py                |  9 +++------
 bilby/core/sampler/dynesty.py       |  9 ++-------
 bilby/core/sampler/emcee.py         | 22 +++++++++++-----------
 bilby/core/sampler/ptemcee.py       |  7 ++-----
 bilby/core/utils/io.py              |  7 ++++---
 bilby/gw/conversion.py              |  5 ++---
 bilby/gw/detector/interferometer.py |  6 ++----
 bilby/gw/detector/networks.py       |  7 ++-----
 bilby/gw/result.py                  |  5 ++---
 9 files changed, 30 insertions(+), 47 deletions(-)

diff --git a/bilby/core/result.py b/bilby/core/result.py
index 6bcf8c81c..4124c197d 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -22,6 +22,7 @@ from .utils import (
     recursively_save_dict_contents_to_group,
     recursively_load_dict_contents_from_group,
     recursively_decode_bilby_json,
+    safe_file_dump,
 )
 from .prior import Prior, PriorDict, DeltaFunction, ConditionalDeltaFunction
 
@@ -735,16 +736,12 @@ class Result(object):
                 with h5py.File(filename, 'w') as h5file:
                     recursively_save_dict_contents_to_group(h5file, '/', dictionary)
             elif extension == 'pkl':
-                import dill
-                with open(filename, "wb") as ff:
-                    dill.dump(self, ff)
+                safe_file_dump(self, filename, "dill")
             else:
                 raise ValueError("Extension type {} not understood".format(extension))
         except Exception as e:
-            import dill
             filename = ".".join(filename.split(".")[:-1]) + ".pkl"
-            with open(filename, "wb") as ff:
-                dill.dump(self, ff)
+            safe_file_dump(self, filename, "dill")
             logger.error(
                 "\n\nSaving the data has failed with the following message:\n"
                 "{}\nData has been dumped to {}.\n\n".format(e, filename)
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index a4e459f59..2b636bd0f 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -343,14 +343,11 @@ class Dynesty(NestedSampler):
         self.kwargs[key] = selected
 
     def nestcheck_data(self, out_file):
-        import pickle
-
         import nestcheck.data_processing
 
         ns_run = nestcheck.data_processing.process_dynesty_run(out_file)
         nestcheck_result = f"{self.outdir}/{self.label}_nestcheck.pickle"
-        with open(nestcheck_result, "wb") as file_nest:
-            pickle.dump(ns_run, file_nest)
+        safe_file_dump(ns_run, nestcheck_result, "pickle")
 
     @property
     def nlive(self):
@@ -370,7 +367,6 @@ class Dynesty(NestedSampler):
 
     @signal_wrapper
     def run_sampler(self):
-        import dill
         import dynesty
 
         logger.info(f"Using dynesty version {dynesty.__version__}")
@@ -436,8 +432,7 @@ class Dynesty(NestedSampler):
             self.nestcheck_data(out)
 
         dynesty_result = f"{self.outdir}/{self.label}_dynesty.pickle"
-        with open(dynesty_result, "wb") as file:
-            dill.dump(out, file)
+        safe_file_dump(out, dynesty_result, "dill")
 
         self._generate_result(out)
         self.result.sampling_time = self.sampling_time
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 18a36fd13..112b56dd7 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -7,7 +7,7 @@ from shutil import copyfile
 import numpy as np
 from pandas import DataFrame
 
-from ..utils import check_directory_exists_and_if_not_mkdir, logger
+from ..utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump
 from .base_sampler import MCMCSampler, SamplerError, signal_wrapper
 from .ptemcee import LikePriorEvaluator
 
@@ -285,20 +285,20 @@ class Emcee(MCMCSampler):
         return self.sampler.chain[:, :nsteps, :]
 
     def write_current_state(self):
-        """Writes a pickle file of the sampler to disk using dill"""
-        import dill
+        """
+        Writes a pickle file of the sampler to disk using dill
 
+        Overwrites the stored sampler chain with one that is truncated
+        to only the completed steps
+        """
         logger.info(
             f"Checkpointing sampler to file {self.checkpoint_info.sampler_file}"
         )
-        with open(self.checkpoint_info.sampler_file, "wb") as f:
-            # Overwrites the stored sampler chain with one that is truncated
-            # to only the completed steps
-            self.sampler._chain = self.sampler_chain
-            _pool = self.sampler.pool
-            self.sampler.pool = None
-            dill.dump(self._sampler, f)
-            self.sampler.pool = _pool
+        self.sampler._chain = self.sampler_chain
+        _pool = self.sampler.pool
+        self.sampler.pool = None
+        safe_file_dump(self._sampler, self.checkpoint_info.sampler_file, "dill")
+        self.sampler.pool = _pool
 
     def _initialise_sampler(self):
         from emcee import EnsembleSampler
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 85b253b1e..e9da5371c 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -8,7 +8,7 @@ from collections import namedtuple
 import numpy as np
 import pandas as pd
 
-from ..utils import check_directory_exists_and_if_not_mkdir, logger
+from ..utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump
 from .base_sampler import (
     MCMCSampler,
     SamplerError,
@@ -1189,8 +1189,6 @@ def checkpoint(
     Q_list,
     time_per_check,
 ):
-    import dill
-
     logger.info("Writing checkpoint and diagnostics")
     ndim = sampler.dim
 
@@ -1223,8 +1221,7 @@ def checkpoint(
         pos0=pos0,
     )
 
-    with open(resume_file, "wb") as file:
-        dill.dump(data, file, protocol=4)
+    safe_file_dump(data, resume_file, "dill")
     del data, sampler_copy
     logger.info("Finished writing checkpoint")
 
diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py
index 6d1482714..d6b66750d 100644
--- a/bilby/core/utils/io.py
+++ b/bilby/core/utils/io.py
@@ -370,10 +370,11 @@ def safe_file_dump(data, filename, module):
         data to dump
     filename: str
         The file to dump to
-    module: pickle, dill
-        The python module to use
+    module: pickle, dill, str
+        The python module to use. If a string, the module will be imported
     """
-
+    if isinstance(module, str):
+        module = import_module(module)
     temp_filename = filename + ".temp"
     with open(temp_filename, "wb") as file:
         module.dump(data, file)
diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 3b3b255b3..3c57c7df4 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -12,7 +12,7 @@ import numpy as np
 from pandas import DataFrame, Series
 
 from ..core.likelihood import MarginalizedLikelihoodReconstructionError
-from ..core.utils import logger, solar_mass, command_line_args
+from ..core.utils import logger, solar_mass, command_line_args, safe_file_dump
 from ..core.prior import DeltaFunction
 from .utils import lalsim_SimInspiralTransformPrecessingNewInitialConditions
 from .eos.eos import SpectralDecompositionEOS, EOSFamily, IntegrateTOV
@@ -1687,8 +1687,7 @@ def generate_posterior_samples_from_marginalized_likelihood(
         cached_samples_dict[ii] = subset_samples
 
         if use_cache:
-            with open(cache_filename, "wb") as f:
-                pickle.dump(cached_samples_dict, f)
+            safe_file_dump(cached_samples_dict, cache_filename, "pickle")
 
         ii += block
         pbar.update(len(subset_samples))
diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py
index 9f8c29147..86c5c03da 100644
--- a/bilby/gw/detector/interferometer.py
+++ b/bilby/gw/detector/interferometer.py
@@ -8,7 +8,7 @@ from bilby_cython.geometry import (
 )
 
 from ...core import utils
-from ...core.utils import docstring, logger, PropertyAccessor
+from ...core.utils import docstring, logger, PropertyAccessor, safe_file_dump
 from .. import utils as gwutils
 from .calibration import Recalibrate
 from .geometry import InterferometerGeometry
@@ -798,11 +798,9 @@ class Interferometer(object):
         format="pickle", extra=".. versionadded:: 1.1.0"
     ))
     def to_pickle(self, outdir="outdir", label=None):
-        import dill
         utils.check_directory_exists_and_if_not_mkdir('outdir')
         filename = self._filename_from_outdir_label_extension(outdir, label, extension="pkl")
-        with open(filename, "wb") as ff:
-            dill.dump(self, ff)
+        safe_file_dump(self, filename, "dill")
 
     @classmethod
     @docstring(_load_docstring.format(format="pickle"))
diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py
index 22092649d..00a8e14cf 100644
--- a/bilby/gw/detector/networks.py
+++ b/bilby/gw/detector/networks.py
@@ -4,7 +4,7 @@ import numpy as np
 import math
 
 from ...core import utils
-from ...core.utils import logger
+from ...core.utils import logger, safe_file_dump
 from .interferometer import Interferometer
 from .psd import PowerSpectralDensity
 
@@ -273,15 +273,12 @@ class InterferometerList(list):
     """
 
     def to_pickle(self, outdir="outdir", label="ifo_list"):
-        import dill
-
         utils.check_directory_exists_and_if_not_mkdir("outdir")
         label = label + "_" + "".join(ifo.name for ifo in self)
         filename = self._filename_from_outdir_label_extension(
             outdir, label, extension="pkl"
         )
-        with open(filename, "wb") as ff:
-            dill.dump(self, ff)
+        safe_file_dump(self, filename, "dill")
 
     @classmethod
     def from_pickle(cls, filename=None):
diff --git a/bilby/gw/result.py b/bilby/gw/result.py
index 154053b25..197173b7a 100644
--- a/bilby/gw/result.py
+++ b/bilby/gw/result.py
@@ -7,7 +7,7 @@ import numpy as np
 from ..core.result import Result as CoreResult
 from ..core.utils import (
     infft, logger, check_directory_exists_and_if_not_mkdir,
-    latex_plot_format, safe_save_figure
+    latex_plot_format, safe_file_dump, safe_save_figure,
 )
 from .utils import plot_spline_pos, spline_angle_xform, asd_from_freq_series
 from .detector import get_empty_interferometer, Interferometer
@@ -781,8 +781,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
             logger.info('Initialising skymap class')
             skypost = confidence_levels(pts, trials=trials, jobs=jobs)
             logger.info('Pickling skymap to {}'.format(default_obj_filename))
-            with open(default_obj_filename, 'wb') as out:
-                pickle.dump(skypost, out)
+            safe_file_dump(skypost, default_obj_filename, "pickle")
 
         else:
             if isinstance(load_pickle, str):
-- 
GitLab