diff --git a/bilby/core/prior.py b/bilby/core/prior.py index cd1fa31baec5e08ba5876db7bb05e0127a5ad328..48cdd527d1fa59cf6030abd391e80962951658e4 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -3,7 +3,6 @@ from __future__ import division import re from importlib import import_module import os -from collections import OrderedDict from future.utils import iteritems import json from io import open as ioopen @@ -24,7 +23,7 @@ from .utils import ( ) -class PriorDict(OrderedDict): +class PriorDict(dict): def __init__(self, dictionary=None, filename=None, conversion_function=None): """ A set of priors @@ -472,7 +471,7 @@ class PriorDict(OrderedDict): We have to overwrite the copy method as it fails due to the presence of defaults. """ - return self.__class__(dictionary=OrderedDict(self)) + return self.__class__(dictionary=dict(self)) class PriorSet(PriorDict): @@ -777,7 +776,7 @@ class Prior(object): dict_with_properties = self.__dict__.copy() for key in property_names: dict_with_properties[key] = getattr(self, key) - instantiation_dict = OrderedDict() + instantiation_dict = dict() for key in subclass_args: instantiation_dict[key] = dict_with_properties[key] return instantiation_dict @@ -2677,11 +2676,11 @@ class MultivariateGaussianDist(object): self.add_mode(mu, sigma, corrcoef, cov, weight) # a dictionary of the parameters as requested by the prior - self.requested_parameters = OrderedDict() + self.requested_parameters = dict() self.reset_request() # a dictionary of the rescaled parameters - self.rescale_parameters = OrderedDict() + self.rescale_parameters = dict() self.reset_rescale() # a list of sampled parameters @@ -2977,7 +2976,7 @@ class MultivariateGaussianDist(object): dict_with_properties = self.__dict__.copy() for key in property_names: dict_with_properties[key] = getattr(self, key) - instantiation_dict = OrderedDict() + instantiation_dict = dict() for key in subclass_args: if isinstance(dict_with_properties[key], list): value = np.asarray(dict_with_properties[key]).tolist() diff --git a/test/gw_prior_test.py b/test/gw_prior_test.py index ed4cee33c9288498aa0348ba5cd02ade84fa8bb7..d594d2971f813d5a9d0570eed07da497557ded18 100644 --- a/test/gw_prior_test.py +++ b/test/gw_prior_test.py @@ -3,6 +3,7 @@ from collections import OrderedDict import unittest import os import sys +import pickle import numpy as np from astropy import cosmology @@ -110,6 +111,16 @@ class TestBBHPriorDict(unittest.TestCase): minimum=20, maximum=40, name='chirp_mass') self.assertFalse(self.bbh_prior_dict.test_has_redundant_keys()) + def test_pickle_prior(self): + priors = dict(chirp_mass=bilby.core.prior.Uniform(10, 20), + mass_ratio=bilby.core.prior.Uniform(0.125, 1)) + priors = bilby.gw.prior.BBHPriorDict(priors) + with open("test.pickle", "wb") as file: + pickle.dump(priors, file) + with open("test.pickle", "rb") as file: + priors_loaded = pickle.load(file) + self.assertEqual(priors, priors_loaded) + class TestPackagedPriors(unittest.TestCase): """ Test that the prepackaged priors load """ diff --git a/test/prior_test.py b/test/prior_test.py index e9a06d6531314e9051092564288b1554ff219b25..a0123982de0e92ee0f07f0c7f92e8fba54d72e6e 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -658,8 +658,8 @@ class TestPriorDict(unittest.TestCase): priors_set = bilby.core.prior.PriorSet(self.priors) self.assertEqual(priors_dict, priors_set) - def test_prior_set_is_ordered_dict(self): - self.assertIsInstance(self.prior_set_from_dict, OrderedDict) + def test_prior_set_is_dict(self): + self.assertIsInstance(self.prior_set_from_dict, dict) def test_prior_set_has_correct_length(self): self.assertEqual(3, len(self.prior_set_from_dict))