Skip to content
Snippets Groups Projects
Commit dd02cacc authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'fix-240-sampling-seed' into 'master'

Fix 240 - Sampling seed not passed to samplers

Closes #240

See merge request !459
parents 7600db53 bad2ca68
No related branches found
No related tags found
1 merge request!459Fix 240 - Sampling seed not passed to samplers
Pipeline #450551 passed
......@@ -115,20 +115,23 @@ class DataAnalysisInput(Input):
@property
def sampling_seed(self):
return self._samplng_seed
return self._sampling_seed
@sampling_seed.setter
def sampling_seed(self, sampling_seed):
if sampling_seed is None:
sampling_seed = np.random.randint(1, 1e6)
self._samplng_seed = sampling_seed
self._sampling_seed = sampling_seed
np.random.seed(sampling_seed)
logger.info(f"Sampling seed set to {sampling_seed}")
if self.sampler == "cpnest":
self.sampler_kwargs["seed"] = self.sampler_kwargs.get(
"seed", self._samplng_seed
)
if not any(
[
k in self.sampler_kwargs
for k in bilby.core.sampler.Sampler.sampling_seed_equiv_kwargs
]
):
self.sampler_kwargs["sampling_seed"] = self._sampling_seed
@property
def interferometers(self):
......
import os
import shutil
import unittest
from unittest.mock import patch
import bilby
from bilby_pipe.data_analysis import DataAnalysisInput, create_analysis_parser
......@@ -61,7 +62,9 @@ class TestDataAnalysisInput(unittest.TestCase):
self.assertEqual(self.inputs.sampler, "dynesty")
def test_set_sampling_kwargs_ini(self):
self.assertEqual(self.inputs.sampler_kwargs, dict(a=1, b=2))
self.assertEqual(
self.inputs.sampler_kwargs, dict(a=1, b=2, sampling_seed=150914)
)
def test_set_sampling_kwargs_direct(self):
self.inputs.sampler_kwargs = "{'a':5, 'b':5}"
......@@ -70,8 +73,11 @@ class TestDataAnalysisInput(unittest.TestCase):
def test_unset_sampling_kwargs(self):
args, unknown_args = parse_args(self.default_args_list, self.parser)
args.sampler_kwargs = None
inputs = DataAnalysisInput(args, unknown_args, test=True)
self.assertEqual(inputs.sampler_kwargs, dict())
args.sampling_seed = None
# This tests the case where the sampling seed is not set
with patch("numpy.random.randint", return_value=170817):
inputs = DataAnalysisInput(args, unknown_args, test=True)
self.assertEqual(inputs.sampler_kwargs, dict(sampling_seed=170817))
def test_set_sampler_kwargs_fail(self):
with self.assertRaises(BilbyPipeError):
......
......@@ -17,3 +17,4 @@ label = label
sampler = nestle
sampler-kwargs = {'a': 1, 'b': 2}
data-label = "DATA"
sampling-seed = 150914
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment