Commit 0b1aa313 authored by Colm Talbot's avatar Colm Talbot

Merge branch 'read_mvn_prior_from_file' into 'master'

Read and write a MultivariateGaussian prior to file

See merge request !515
parents 8d403f27 917c49a2
Pipeline #70842 passed with stages
in 6 minutes and 20 seconds
......@@ -89,10 +89,23 @@ class PriorDict(OrderedDict):
prior_file = os.path.join(outdir, "{}.prior".format(label))
logger.debug("Writing priors to {}".format(prior_file))
mvgs = []
with open(prior_file, "w") as outfile:
for key in self.keys():
"{} = {}\n".format(key, self[key]))
if isinstance(self[key], MultivariateGaussian):
mvgname = '_'.join(self[key].mvg.names) + '_mvg'
if mvgname not in mvgs:
"{} = {}\n".format(mvgname, self[key].mvg))
mvgstr = repr(self[key].mvg)
priorstr = repr(self[key])
"{} = {}\n".format(key, priorstr.replace(mvgstr,
"{} = {}\n".format(key, self[key]))
def from_file(self, filename):
""" Reads in a prior from a file specification
......@@ -114,6 +127,7 @@ class PriorDict(OrderedDict):
comments = ['#', '\n']
prior = dict()
mvgdict = dict(inf=np.inf) # evaluate inf as np.inf
with open(filename, 'r') as f:
for line in f:
if line[0] in comments:
......@@ -139,6 +153,13 @@ class PriorDict(OrderedDict):
cls = getattr(import_module(module), cls)
if key.lower() == "conversion_function":
setattr(self, key, cls)
elif (cls.__name__ in ['MultivariateGaussianDist',
if key not in mvgdict:
mvgdict[key] = eval(val, None, mvgdict)
elif (cls.__name__ in ['MultivariateGaussian',
prior[key] = eval(val, None, mvgdict)
prior[key] = cls.from_repr(args)
......@@ -2825,6 +2846,67 @@ class MultivariateGaussianDist(object):
def __len__(self):
return len(self.names)
def __repr__(self):
"""Overrides the special method __repr__.
Returns a representation of this instance that resembles how it is instantiated.
Works correctly for all child classes
str: A string representation of this instance
subclass_args = infer_args_from_method(self.__init__)
dist_name = self.__class__.__name__
property_names = [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)]
dict_with_properties = self.__dict__.copy()
for key in property_names:
dict_with_properties[key] = getattr(self, key)
argslist = []
for key in subclass_args:
# make sure lists containing arrays are returned just as lists
if isinstance(dict_with_properties[key], list):
argsval = np.asarray(dict_with_properties[key]).tolist()
argsval = dict_with_properties[key]
argslist.append('{}={}'.format(key, repr(argsval)))
args = ', '.join(argslist)
return "{}({})".format(dist_name, args)
def __eq__(self, other):
if self.__class__ != other.__class__:
return False
if sorted(self.__dict__.keys()) != sorted(other.__dict__.keys()):
return False
for key in self.__dict__:
if key == 'mvn':
if len(self.__dict__[key]) != len(other.__dict__[key]):
return False
for thismvn, othermvn in zip(self.__dict__[key], other.__dict__[key]):
if (not isinstance(thismvn, scipy.stats._multivariate.multivariate_normal_frozen) or
not isinstance(othermvn, scipy.stats._multivariate.multivariate_normal_frozen)):
return False
elif isinstance(self.__dict__[key], (np.ndarray, list)):
thisarr = np.asarray(self.__dict__[key])
otherarr = np.asarray(other.__dict__[key])
if thisarr.dtype == np.float and otherarr.dtype == np.float:
fin1 = np.isfinite(np.asarray(self.__dict__[key]))
fin2 = np.isfinite(np.asarray(other.__dict__[key]))
if not np.array_equal(fin1, fin2):
return False
if not np.allclose(thisarr[fin1], otherarr[fin2], atol=1e-15):
return False
if not np.array_equal(thisarr, otherarr):
return False
if not self.__dict__[key] == other.__dict__[key]:
return False
return True
class MultivariateNormalDist(MultivariateGaussianDist):
......@@ -479,12 +479,16 @@ class TestPriorClasses(unittest.TestCase):
if isinstance(prior, bilby.core.prior.Interped):
continue # we cannot test this because of the numpy arrays
elif isinstance(prior, bilby.core.prior.MultivariateGaussian):
continue # we cannot test this because of the internal objects
repr_prior_string = 'bilby.core.prior.' + repr(prior)
repr_prior_string = repr_prior_string.replace(
elif isinstance(prior,
repr_prior_string = '' + repr(prior)
repr_prior_string = 'bilby.core.prior.' + repr(prior)
repr_prior = eval(repr_prior_string)
repr_prior = eval(repr_prior_string, None, dict(inf=np.inf))
self.assertEqual(prior, repr_prior)
def test_set_maximum_setting(self):
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment