diff --git a/bilby/core/result.py b/bilby/core/result.py index 718f921b449d181896a503218272c3613e4bb9ed..380e584a0fca3caa8e773f9ee0db5c134c431011 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -499,11 +499,44 @@ class Result(object): logger.error("\n\n Saving the data has failed with the " "following message:\n {} \n\n".format(e)) - def save_posterior_samples(self, outdir=None): - """Saves posterior samples to a file""" - outdir = self._safe_outdir_creation(outdir, self.save_posterior_samples) - filename = '{}/{}_posterior_samples.txt'.format(outdir, self.label) - self.posterior.to_csv(filename, index=False, header=True) + def save_posterior_samples(self, filename=None, outdir=None, label=None): + """ Saves posterior samples to a file + + Generates a .dat file containing the posterior samples and auxillary + data saved in the posterior. Note, strings in the posterior are + removed while complex numbers will be given as absolute values with + abs appended to the column name + + Parameters + ---------- + filename: str + Alternative filename to use. Defaults to + outdir/label_posterior_samples.dat + outdir, label: str + Alternative outdir and label to use + + """ + if filename is None: + if label is None: + label = self.label + outdir = self._safe_outdir_creation(outdir, self.save_posterior_samples) + filename = '{}/{}_posterior_samples.dat'.format(outdir, label) + else: + outdir = os.path.dirname(filename) + self._safe_outdir_creation(outdir, self.save_posterior_samples) + + # Drop non-numeric columns + df = self.posterior.select_dtypes([np.number]).copy() + + # Convert complex columns to abs + for key in df.keys(): + if np.any(np.iscomplex(df[key])): + complex_term = df.pop(key) + df.loc[:, key + "_abs"] = np.abs(complex_term) + df.loc[:, key + "_angle"] = np.angle(complex_term) + + logger.info("Writing samples file to {}".format(filename)) + df.to_csv(filename, index=False, header=True, sep=' ') def get_latex_labels_from_parameter_keys(self, keys): """ Returns a list of latex_labels corresponding to the given keys diff --git a/test/result_test.py b/test/result_test.py index d504063ca49ae090f459fd79b9692ee485c5a28b..32a7e34ca4ea3f5ad234fbb28ca2bb9aa04e4c73 100644 --- a/test/result_test.py +++ b/test/result_test.py @@ -261,11 +261,26 @@ class TestResult(unittest.TestCase): def test_save_samples(self): self.result.save_posterior_samples() - filename = '{}/{}_posterior_samples.txt'.format(self.result.outdir, self.result.label) + filename = '{}/{}_posterior_samples.dat'.format(self.result.outdir, self.result.label) self.assertTrue(os.path.isfile(filename)) - df = pd.read_csv(filename) + df = pd.read_csv(filename, sep=' ') self.assertTrue(np.allclose(self.result.posterior.values, df.values)) + def test_save_samples_from_filename(self): + filename = '{}/{}_posterior_samples_OTHER.dat'.format(self.result.outdir, self.result.label) + self.result.save_posterior_samples(filename=filename) + self.assertTrue(os.path.isfile(filename)) + df = pd.read_csv(filename, sep=' ') + self.assertTrue(np.allclose(self.result.posterior.values, df.values)) + + def test_save_samples_numpy_load(self): + self.result.save_posterior_samples() + filename = '{}/{}_posterior_samples.dat'.format(self.result.outdir, self.result.label) + self.assertTrue(os.path.isfile(filename)) + data = np.genfromtxt(filename, names=True) + df = pd.read_csv(filename, sep=' ') + self.assertTrue(len(data.dtype) == len(df.keys())) + def test_samples_to_posterior_simple(self): self.result.posterior = None x = [1, 2, 3]