Skip to content
Snippets Groups Projects
Commit 3d2da9b8 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'add-testing-for-pp-test' into 'master'

Refactor and add testing for the PP-test module

See merge request !112
parents 12f82cea 4b60aa12
No related branches found
No related tags found
1 merge request!112Refactor and add testing for the PP-test module
Pipeline #66268 passed
......@@ -6,17 +6,19 @@ import glob
import json
import os
from bilby.core.result import read_in_result, make_pp_plot
from bilby.core.result import read_in_result, make_pp_plot, ResultList
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import tqdm
import matplotlib as mpl
from .utils import logger
mpl.rcParams.update(mpl.rcParamsDefault)
def main():
def create_parser():
parser = argparse.ArgumentParser(
prog="bilby_pipe PP test",
usage="Generates a pp plot from a directory containing a set of results",
......@@ -32,23 +34,34 @@ def main():
parser.add_argument(
"-n", type=int, help="Number of samples to truncate to", default=None
)
args, _ = parser.parse_known_args()
return parser
def get_results_filenames(args):
results_files = []
for extension in ["json", "h5", "hdf5"]:
glob_string = os.path.join(args.directory, "*result*" + extension)
results_files += glob.glob(glob_string)
results_files = [rf for rf in results_files if os.path.isfile(rf)]
if len(results_files) == 0:
raise ValueError("No results found in path {}".format(args.directory))
raise FileNotFoundError("No results found in path {}".format(args.directory))
if args.n is not None:
results_files = results_files[: args.n]
return results_files
def check_consistency(results):
results._check_consistent_sampler()
results._check_consistent_data()
results._check_consistent_parameters()
results._check_consistent_priors()
def read_in_result_list(args, results_filenames):
print("Reading in results ...")
results = []
for f in tqdm.tqdm(results_files):
for f in tqdm.tqdm(results_filenames):
try:
results.append(read_in_result(f))
except json.decoder.JSONDecodeError:
......@@ -70,32 +83,44 @@ def main():
print(
"List of result-labels: {}".format(sorted([res.label for res in results]))
)
return ResultList(results)
r0 = results[0]
sampler = r0.sampler
def get_basename(args):
if args.outdir is None:
args.outdir = args.directory
basename = "{}/{}".format(args.outdir, sampler)
basename = "{}/".format(args.outdir)
if args.label is not None:
basename += "_{}".format(args.label)
basename += "{}_".format(args.label)
return basename
def main(args=None):
if args is None:
args = create_parser().parse_known_args()
results_filenames = get_results_filenames(args)
results = read_in_result_list(args, results_filenames)
check_consistency(results)
basename = get_basename(args)
print("Create the PP plot")
keys = r0.priors.keys()
print("Parameters = {}".format(keys))
make_pp_plot(results, filename="{}_pp.png".format(basename), keys=keys)
logger.info("Generating PP plot")
keys = results[0].priors.keys()
logger.info("Parameters = {}".format(keys))
make_pp_plot(results, filename="{}pp.png".format(basename), keys=keys)
print("Create sampling-time histogram")
logger.info("Create sampling-time histogram")
stimes = [r.sampling_time for r in results]
fig, ax = plt.subplots()
ax.hist(np.array(stimes) / 3600, bins=50)
ax.set_xlabel("Sampling time [hr]")
fig.tight_layout()
fig.savefig("{}_sampling_times.png".format(basename))
fig.savefig("{}sampling_times.png".format(basename))
print("Create optimal SNR plot")
logger.info("Create optimal SNR plot")
fig, ax = plt.subplots()
snrs = []
for det in ["H1", "L1"]:
detectors = list(results[0].meta_data["likelihood"]["interferometers"].keys())
for det in detectors:
snrs.append(
[
r.meta_data["likelihood"]["interferometers"][det]["optimal_SNR"]
......@@ -107,4 +132,4 @@ def main():
ax.hist(network_snr, bins=50, label=det)
ax.set_xlabel("Network optimal SNR")
fig.tight_layout()
fig.savefig("{}_optimal_SNR.png".format(basename))
fig.savefig("{}optimal_SNR.png".format(basename))
import os
from types import SimpleNamespace
import shutil
import unittest
import bilby
import pandas as pd
import numpy as np
import bilby_pipe
import bilby_pipe.pp_test
class TestPP(unittest.TestCase):
def setUp(self):
self.outdir = "test_outdir"
self.args = SimpleNamespace(
directory=self.outdir, outdir=None, label=None, n=None, print=False
)
os.mkdir(self.outdir)
def tearDown(self):
shutil.rmtree(self.outdir)
del self.outdir
def create_fake_results(self):
self.N_results = 3
self.results_filenames = []
self.priors = bilby.core.prior.PriorDict(
dict(
A=bilby.core.prior.Normal(0, 1, "A"),
B=bilby.core.prior.Normal(0, 1, "B"),
)
)
for i in range(self.N_results):
result = bilby.core.result.Result()
result.outdir = self.outdir
result.label = "label_{}".format(i)
result.search_parameter_keys = ["A", "B"]
result.priors = self.priors
result.posterior = pd.DataFrame(
dict(A=np.random.normal(0, 1, 100), B=np.random.normal(0, 1, 100))
)
result.injection_parameters = dict(A=0, B=0)
result.sampling_time = np.random.uniform(0, 1)
result.meta_data = dict(
likelihood=dict(
interferometers=dict(H1=dict(optimal_SNR=1), L1=dict(optimal_SNR=1))
)
)
filename = "{}/{}_result.json".format(result.outdir, result.label)
result.save_to_file(filename)
self.results_filenames.append(filename)
def test_parser(self):
directory = "directory"
parser = bilby_pipe.pp_test.create_parser()
args = parser.parse_args(
[directory, "--outdir", self.outdir, "--label", "TEST", "-n", "10"]
)
self.assertEqual(args.directory, directory)
self.assertEqual(args.outdir, self.outdir)
self.assertEqual(args.label, "TEST")
self.assertEqual(args.n, 10)
def test_get_results_filename(self):
self.create_fake_results()
results_filenames = bilby_pipe.pp_test.get_results_filenames(self.args)
self.assertEqual(sorted(results_filenames), sorted(self.results_filenames))
def test_get_results_filename_with_n(self):
n = 2
self.create_fake_results()
args = self.args
args.n = n
results_filenames = bilby_pipe.pp_test.get_results_filenames(args)
self.assertEqual(len(results_filenames), n)
def test_get_results_filename_no_file(self):
with self.assertRaises(FileNotFoundError):
bilby_pipe.pp_test.get_results_filenames(self.args)
def test_read_in_result_list(self):
self.create_fake_results()
res = bilby_pipe.pp_test.read_in_result_list(self.args, self.results_filenames)
self.assertEqual(len(res), self.N_results)
self.assertIsInstance(res, bilby.core.result.ResultList)
def test_main(self):
self.create_fake_results()
bilby_pipe.pp_test.main(self.args)
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment