Commit ef6bbe06 authored by Gregory Ashton's avatar Gregory Ashton

Merge branch '191-prior-dict-not-working-with-multi-line' into 'master'

Resolve "prior-dict not working with multi-line"

Closes #191

See merge request lscsoft/bilby_pipe!369
parents b382bc1f c7921d7c
......@@ -203,11 +203,17 @@ class BilbyConfigFileParser(configargparse.DefaultConfigFileParser):
def parse(self, stream):
"""Parses the keys + values from a config file."""
# Pre-process lines to put multi-line dicts on single linesj
lines = "".join(list(stream)) # Form single string
lines = lines.replace(",\n", ", ") # Multiline args on single lines
lines = lines.replace("\n}\n", "}\n") # Trailing } on single lines
lines = lines.split("\n")
items = dict()
numbers = dict()
comments = dict()
inline_comments = dict()
for ii, line in enumerate(stream):
for ii, line in enumerate(lines):
line = line.strip()
if not line:
continue
......
......@@ -737,7 +737,14 @@ def create_parser(top_level=True):
"--prior-file", type=nonestr, default=None, help="The prior file"
)
prior_parser_main.add(
"--prior-dict", type=nonestr, default=None, help="A dictionary of priors"
"--prior-dict",
type=nonestr,
default=None,
help=(
"A dictionary of priors (alternative to prior-file). Multiline "
"dictionaries are supported, but each line must contain a single"
"parameter specification and finish with a comma."
),
)
prior_parser.add(
"--convert-to-flat-in-component-mass",
......
......@@ -19,8 +19,8 @@ n-parallel = 4
prior-dict = {
mass_1 = Constraint(name='mass_1', minimum=10, maximum=80),
mass_2 = Constraint(name='mass_2', minimum=10, maximum=80),
mass_ratio = Uniform(name='mass_ratio', minimum=0.125, maximum=1),
chirp_mass = Uniform(name='chirp_mass', minimum=25, maximum=40),
mass_ratio = Uniform(name='mass_ratio', minimum=0.125, maximum=1, latex_label="$q$"),
chirp_mass = Uniform(name='chirp_mass', minimum=25, maximum=40, latex_label="$M_{c}$"),
a_1 = Uniform(name='a_1', minimum=0, maximum=0.99),
a_2 = Uniform(name='a_2', minimum=0, maximum=0.99),
tilt_1 = Sine(name='tilt_1'),
......
......@@ -4,11 +4,12 @@ import sys
import unittest
from unittest.mock import patch
import bilby
from bilby_pipe.bilbyargparser import BilbyArgParser
from bilby_pipe.data_analysis import create_analysis_parser
from bilby_pipe.main import parse_args
from bilby_pipe.parser import create_parser
from bilby_pipe.utils import convert_string_to_dict
from bilby_pipe.utils import convert_prior_string_input, convert_string_to_dict
class TestBilbyArgParser(unittest.TestCase):
......@@ -150,6 +151,7 @@ class TestBilbyConfigFileParser(unittest.TestCase):
self.write_tempory_ini_file(lines)
args, unknown_args = parse_args([self.test_ini_filename], self.parser)
self.assertEqual(args.prior_dict, kwargs_str)
self.assertEqual(unknown_args, [])
def test_prior_dict_multiline(self):
kwargs_str = "{a: Uniform(name='a', minimum=0, maximum=1), b: 1}"
......@@ -157,6 +159,117 @@ class TestBilbyConfigFileParser(unittest.TestCase):
self.write_tempory_ini_file(lines)
args, unknown_args = parse_args([self.test_ini_filename], self.parser)
self.assertEqual(args.prior_dict, kwargs_str)
self.assertEqual(unknown_args, [])
def test_prior_dict_multiline_complicated1(self):
expected_prior = bilby.core.prior.PriorDict(
dict(
a=bilby.core.prior.Uniform(name="a", minimum=0, maximum=1),
b=1,
c=2,
redshift=bilby.gw.prior.UniformSourceFrame(
name="redshift",
minimum=1,
maximum=10,
latex_label=r"$\rm{log}_{10}(M_{Lz}/\rm M_\odot))$",
),
)
)
lines = [
"prior-dict: {a: Uniform(name='a', minimum=0, maximum=1),",
"b: 1,",
" c: 2,",
r"redshift: bilby.gw.prior.UniformSourceFrame(name='redshift',"
+ r"minimum=1, maximum=10, latex_label='$\rm{log}_{10}(M_{Lz}/\rm M_\odot))$'",
"}",
]
self.write_tempory_ini_file(lines)
args, unknown_args = parse_args([self.test_ini_filename], self.parser)
prior = bilby.core.prior.PriorDict(convert_prior_string_input(args.prior_dict))
self.assertEqual(expected_prior, prior)
self.assertEqual(unknown_args, [])
def test_prior_dict_multiline_complicated2(self):
expected_prior = bilby.core.prior.PriorDict(
dict(
a=bilby.core.prior.Uniform(name="a", minimum=0, maximum=1),
b=1,
c=2,
redshift=bilby.gw.prior.UniformSourceFrame(
name="redshift",
minimum=1,
maximum=10,
latex_label=r"$\rm{log}_{10}(M_{Lz}/\rm M_\odot))$",
),
)
)
lines = [
"prior-dict: {a: Uniform(name='a', minimum=0, maximum=1),",
"b: 1,",
" c: 2,",
r"redshift: bilby.gw.prior.UniformSourceFrame(name='redshift',"
+ r"minimum=1, maximum=10, latex_label='$\rm{log}_{10}(M_{Lz}/\rm M_\odot))$',",
"}",
]
self.write_tempory_ini_file(lines)
args, unknown_args = parse_args([self.test_ini_filename], self.parser)
prior = bilby.core.prior.PriorDict(convert_prior_string_input(args.prior_dict))
self.assertEqual(expected_prior, prior)
self.assertEqual(unknown_args, [])
def test_prior_dict_multiline_complicated3(self):
expected_prior = bilby.core.prior.PriorDict(
dict(
a=bilby.core.prior.Uniform(name="a", minimum=0, maximum=1),
b=1,
c=2,
redshift=bilby.gw.prior.UniformSourceFrame(
name="redshift",
minimum=1,
maximum=10,
latex_label=r"$\rm{log}_{10}(M_{Lz}/\rm M_\odot))$",
),
)
)
lines = [
"prior-dict: {a: Uniform(name='a', minimum=0, maximum=1),",
"b: 1,",
" c: 2,",
r"redshift: bilby.gw.prior.UniformSourceFrame(name='redshift',"
+ r"minimum=1, maximum=10, latex_label='$\rm{log}_{10}(M_{Lz}/\rm M_\odot))$'}",
]
self.write_tempory_ini_file(lines)
args, unknown_args = parse_args([self.test_ini_filename], self.parser)
prior = bilby.core.prior.PriorDict(convert_prior_string_input(args.prior_dict))
self.assertEqual(expected_prior, prior)
self.assertEqual(unknown_args, [])
def test_prior_dict_multiline_complicated4(self):
expected_prior = bilby.core.prior.PriorDict(
dict(
a=bilby.core.prior.Uniform(name="a", minimum=0, maximum=1),
b=1,
c=2,
redshift=bilby.gw.prior.UniformSourceFrame(
name="redshift",
minimum=1,
maximum=10,
latex_label=r"$\rm{log}_{10}(M_{Lz}/\rm M_\odot))$",
),
)
)
lines = [
"prior-dict: {a: Uniform(name='a', minimum=0, maximum=1),",
"b: 1,",
r"redshift: bilby.gw.prior.UniformSourceFrame(name='redshift',"
+ r"minimum=1, maximum=10, latex_label='$\rm{log}_{10}(M_{Lz}/\rm M_\odot))$',"
" c: 2}",
]
self.write_tempory_ini_file(lines)
args, unknown_args = parse_args([self.test_ini_filename], self.parser)
prior = bilby.core.prior.PriorDict(convert_prior_string_input(args.prior_dict))
self.assertEqual(expected_prior, prior)
self.assertEqual(unknown_args, [])
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment