From 346566dc423c9f64695b0993e910499abfa25915 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Tue, 13 Dec 2022 22:48:14 +0000 Subject: [PATCH] MAINT: enforce safe_file_dump --- bilby/core/result.py | 9 +++------ bilby/core/sampler/dynesty.py | 9 ++------- bilby/core/sampler/emcee.py | 22 +++++++++++----------- bilby/core/sampler/ptemcee.py | 7 ++----- bilby/core/utils/io.py | 7 ++++--- bilby/gw/conversion.py | 5 ++--- bilby/gw/detector/interferometer.py | 6 ++---- bilby/gw/detector/networks.py | 7 ++----- bilby/gw/result.py | 5 ++--- 9 files changed, 30 insertions(+), 47 deletions(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index 6bcf8c81c..4124c197d 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -22,6 +22,7 @@ from .utils import ( recursively_save_dict_contents_to_group, recursively_load_dict_contents_from_group, recursively_decode_bilby_json, + safe_file_dump, ) from .prior import Prior, PriorDict, DeltaFunction, ConditionalDeltaFunction @@ -735,16 +736,12 @@ class Result(object): with h5py.File(filename, 'w') as h5file: recursively_save_dict_contents_to_group(h5file, '/', dictionary) elif extension == 'pkl': - import dill - with open(filename, "wb") as ff: - dill.dump(self, ff) + safe_file_dump(self, filename, "dill") else: raise ValueError("Extension type {} not understood".format(extension)) except Exception as e: - import dill filename = ".".join(filename.split(".")[:-1]) + ".pkl" - with open(filename, "wb") as ff: - dill.dump(self, ff) + safe_file_dump(self, filename, "dill") logger.error( "\n\nSaving the data has failed with the following message:\n" "{}\nData has been dumped to {}.\n\n".format(e, filename) diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index a4e459f59..2b636bd0f 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -343,14 +343,11 @@ class Dynesty(NestedSampler): self.kwargs[key] = selected def nestcheck_data(self, out_file): - import pickle - import nestcheck.data_processing ns_run = nestcheck.data_processing.process_dynesty_run(out_file) nestcheck_result = f"{self.outdir}/{self.label}_nestcheck.pickle" - with open(nestcheck_result, "wb") as file_nest: - pickle.dump(ns_run, file_nest) + safe_file_dump(ns_run, nestcheck_result, "pickle") @property def nlive(self): @@ -370,7 +367,6 @@ class Dynesty(NestedSampler): @signal_wrapper def run_sampler(self): - import dill import dynesty logger.info(f"Using dynesty version {dynesty.__version__}") @@ -436,8 +432,7 @@ class Dynesty(NestedSampler): self.nestcheck_data(out) dynesty_result = f"{self.outdir}/{self.label}_dynesty.pickle" - with open(dynesty_result, "wb") as file: - dill.dump(out, file) + safe_file_dump(out, dynesty_result, "dill") self._generate_result(out) self.result.sampling_time = self.sampling_time diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 18a36fd13..112b56dd7 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -7,7 +7,7 @@ from shutil import copyfile import numpy as np from pandas import DataFrame -from ..utils import check_directory_exists_and_if_not_mkdir, logger +from ..utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump from .base_sampler import MCMCSampler, SamplerError, signal_wrapper from .ptemcee import LikePriorEvaluator @@ -285,20 +285,20 @@ class Emcee(MCMCSampler): return self.sampler.chain[:, :nsteps, :] def write_current_state(self): - """Writes a pickle file of the sampler to disk using dill""" - import dill + """ + Writes a pickle file of the sampler to disk using dill + Overwrites the stored sampler chain with one that is truncated + to only the completed steps + """ logger.info( f"Checkpointing sampler to file {self.checkpoint_info.sampler_file}" ) - with open(self.checkpoint_info.sampler_file, "wb") as f: - # Overwrites the stored sampler chain with one that is truncated - # to only the completed steps - self.sampler._chain = self.sampler_chain - _pool = self.sampler.pool - self.sampler.pool = None - dill.dump(self._sampler, f) - self.sampler.pool = _pool + self.sampler._chain = self.sampler_chain + _pool = self.sampler.pool + self.sampler.pool = None + safe_file_dump(self._sampler, self.checkpoint_info.sampler_file, "dill") + self.sampler.pool = _pool def _initialise_sampler(self): from emcee import EnsembleSampler diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 85b253b1e..e9da5371c 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -8,7 +8,7 @@ from collections import namedtuple import numpy as np import pandas as pd -from ..utils import check_directory_exists_and_if_not_mkdir, logger +from ..utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump from .base_sampler import ( MCMCSampler, SamplerError, @@ -1189,8 +1189,6 @@ def checkpoint( Q_list, time_per_check, ): - import dill - logger.info("Writing checkpoint and diagnostics") ndim = sampler.dim @@ -1223,8 +1221,7 @@ def checkpoint( pos0=pos0, ) - with open(resume_file, "wb") as file: - dill.dump(data, file, protocol=4) + safe_file_dump(data, resume_file, "dill") del data, sampler_copy logger.info("Finished writing checkpoint") diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index 6d1482714..d6b66750d 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -370,10 +370,11 @@ def safe_file_dump(data, filename, module): data to dump filename: str The file to dump to - module: pickle, dill - The python module to use + module: pickle, dill, str + The python module to use. If a string, the module will be imported """ - + if isinstance(module, str): + module = import_module(module) temp_filename = filename + ".temp" with open(temp_filename, "wb") as file: module.dump(data, file) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 3b3b255b3..3c57c7df4 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -12,7 +12,7 @@ import numpy as np from pandas import DataFrame, Series from ..core.likelihood import MarginalizedLikelihoodReconstructionError -from ..core.utils import logger, solar_mass, command_line_args +from ..core.utils import logger, solar_mass, command_line_args, safe_file_dump from ..core.prior import DeltaFunction from .utils import lalsim_SimInspiralTransformPrecessingNewInitialConditions from .eos.eos import SpectralDecompositionEOS, EOSFamily, IntegrateTOV @@ -1687,8 +1687,7 @@ def generate_posterior_samples_from_marginalized_likelihood( cached_samples_dict[ii] = subset_samples if use_cache: - with open(cache_filename, "wb") as f: - pickle.dump(cached_samples_dict, f) + safe_file_dump(cached_samples_dict, cache_filename, "pickle") ii += block pbar.update(len(subset_samples)) diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 9f8c29147..86c5c03da 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -8,7 +8,7 @@ from bilby_cython.geometry import ( ) from ...core import utils -from ...core.utils import docstring, logger, PropertyAccessor +from ...core.utils import docstring, logger, PropertyAccessor, safe_file_dump from .. import utils as gwutils from .calibration import Recalibrate from .geometry import InterferometerGeometry @@ -798,11 +798,9 @@ class Interferometer(object): format="pickle", extra=".. versionadded:: 1.1.0" )) def to_pickle(self, outdir="outdir", label=None): - import dill utils.check_directory_exists_and_if_not_mkdir('outdir') filename = self._filename_from_outdir_label_extension(outdir, label, extension="pkl") - with open(filename, "wb") as ff: - dill.dump(self, ff) + safe_file_dump(self, filename, "dill") @classmethod @docstring(_load_docstring.format(format="pickle")) diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 22092649d..00a8e14cf 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -4,7 +4,7 @@ import numpy as np import math from ...core import utils -from ...core.utils import logger +from ...core.utils import logger, safe_file_dump from .interferometer import Interferometer from .psd import PowerSpectralDensity @@ -273,15 +273,12 @@ class InterferometerList(list): """ def to_pickle(self, outdir="outdir", label="ifo_list"): - import dill - utils.check_directory_exists_and_if_not_mkdir("outdir") label = label + "_" + "".join(ifo.name for ifo in self) filename = self._filename_from_outdir_label_extension( outdir, label, extension="pkl" ) - with open(filename, "wb") as ff: - dill.dump(self, ff) + safe_file_dump(self, filename, "dill") @classmethod def from_pickle(cls, filename=None): diff --git a/bilby/gw/result.py b/bilby/gw/result.py index 154053b25..197173b7a 100644 --- a/bilby/gw/result.py +++ b/bilby/gw/result.py @@ -7,7 +7,7 @@ import numpy as np from ..core.result import Result as CoreResult from ..core.utils import ( infft, logger, check_directory_exists_and_if_not_mkdir, - latex_plot_format, safe_save_figure + latex_plot_format, safe_file_dump, safe_save_figure, ) from .utils import plot_spline_pos, spline_angle_xform, asd_from_freq_series from .detector import get_empty_interferometer, Interferometer @@ -781,8 +781,7 @@ class CompactBinaryCoalescenceResult(CoreResult): logger.info('Initialising skymap class') skypost = confidence_levels(pts, trials=trials, jobs=jobs) logger.info('Pickling skymap to {}'.format(default_obj_filename)) - with open(default_obj_filename, 'wb') as out: - pickle.dump(skypost, out) + safe_file_dump(skypost, default_obj_filename, "pickle") else: if isinstance(load_pickle, str): -- GitLab