Skip to content
Snippets Groups Projects
Commit be95d907 authored by Michael Williams's avatar Michael Williams
Browse files

Update tests for sampling seed changes

parent 8a028473
No related branches found
No related tags found
1 merge request!459Fix 240 - Sampling seed not passed to samplers
import os
import shutil
import unittest
from unittest.mock import patch
import bilby
from bilby_pipe.data_analysis import DataAnalysisInput, create_analysis_parser
......@@ -61,7 +62,9 @@ class TestDataAnalysisInput(unittest.TestCase):
self.assertEqual(self.inputs.sampler, "dynesty")
def test_set_sampling_kwargs_ini(self):
self.assertEqual(self.inputs.sampler_kwargs, dict(a=1, b=2))
self.assertEqual(
self.inputs.sampler_kwargs, dict(a=1, b=2, sampling_seed=150914)
)
def test_set_sampling_kwargs_direct(self):
self.inputs.sampler_kwargs = "{'a':5, 'b':5}"
......@@ -70,8 +73,11 @@ class TestDataAnalysisInput(unittest.TestCase):
def test_unset_sampling_kwargs(self):
args, unknown_args = parse_args(self.default_args_list, self.parser)
args.sampler_kwargs = None
inputs = DataAnalysisInput(args, unknown_args, test=True)
self.assertEqual(inputs.sampler_kwargs, dict())
args.sampling_seed = None
# This tests the case where the sampling seed is not set
with patch("numpy.random.randint", return_value=170817):
inputs = DataAnalysisInput(args, unknown_args, test=True)
self.assertEqual(inputs.sampler_kwargs, dict(sampling_seed=170817))
def test_set_sampler_kwargs_fail(self):
with self.assertRaises(BilbyPipeError):
......
......@@ -17,3 +17,4 @@ label = label
sampler = nestle
sampler-kwargs = {'a': 1, 'b': 2}
data-label = "DATA"
sampling-seed = 150914
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment