Commit e2e56e7d authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add multiprocessing native support

1) Writes the sampler_kwargs, fully expanded (i.e. "default" gets
expanded into its actual values) into the complete config: better
reproducibility
2) Update the sampler_kwargs if request-cpu > 1 based on the sampler
3) Moved sampler_kwargs logic into inputs to enable the above changes
parent 3330984f
......@@ -14,10 +14,8 @@ from bilby_pipe.main import parse_args
from bilby_pipe.parser import create_parser
from bilby_pipe.utils import (
CHECKPOINT_EXIT_CODE,
SAMPLER_SETTINGS,
BilbyPipeError,
DataDump,
convert_string_to_dict,
log_version_information,
logger,
)
......@@ -57,6 +55,7 @@ class DataAnalysisInput(Input):
self.cluster = args.cluster
self.process = args.process
self.periodic_restart_time = args.periodic_restart_time
self.request_cpus = args.request_cpus
# Naming arguments
self.outdir = args.outdir
......@@ -176,24 +175,6 @@ class DataAnalysisInput(Input):
"more than one element: {}. Unable to proceed".format(sampler)
)
@property
def sampler_kwargs(self):
return self._sampler_kwargs
@sampler_kwargs.setter
def sampler_kwargs(self, sampler_kwargs):
if sampler_kwargs is not None:
if sampler_kwargs.lower() == "default":
self._sampler_kwargs = SAMPLER_SETTINGS["Default"]
elif sampler_kwargs.lower() == "fasttest":
self._sampler_kwargs = SAMPLER_SETTINGS["FastTest"]
else:
self._sampler_kwargs = convert_string_to_dict(
sampler_kwargs, "sampler-kwargs"
)
else:
self._sampler_kwargs = dict()
@property
def interferometers(self):
try:
......
......@@ -14,6 +14,7 @@ import bilby
from . import utils
from .utils import (
SAMPLER_SETTINGS,
BilbyPipeError,
BilbyPipeInternalError,
convert_string_to_dict,
......@@ -1143,3 +1144,38 @@ class Input(object):
self._postprocessing_arguments = postprocessing_arguments.split(" ")
else:
self._postprocessing_arguments = postprocessing_arguments
@property
def sampler_kwargs(self):
return self._sampler_kwargs
@sampler_kwargs.setter
def sampler_kwargs(self, sampler_kwargs):
if sampler_kwargs is not None:
if sampler_kwargs.lower() == "default":
self._sampler_kwargs = SAMPLER_SETTINGS["Default"]
elif sampler_kwargs.lower() == "fasttest":
self._sampler_kwargs = SAMPLER_SETTINGS["FastTest"]
else:
self._sampler_kwargs = convert_string_to_dict(
sampler_kwargs, "sampler-kwargs"
)
else:
self._sampler_kwargs = dict()
self.update_sampler_kwargs_conditional_on_request_cpus()
def update_sampler_kwargs_conditional_on_request_cpus(self):
""" If the user adds request-cpu >1, update kwargs based on the sampler """
# Keys are samplers, values are the dictionary inputs to update
parallelisation_dict = dict(
dynesty=dict(queue_size=self.request_cpus),
ptemcee=dict(threads=self.request_cpus),
cpnest=dict(nthreads=self.request_cpus),
)
# Only run if request_cpus > 1
if self.request_cpus > 1:
# Only update if parallelisation_dict contains the sampler
self._sampler_kwargs.update(parallelisation_dict.get(self.sampler, dict()))
......@@ -121,6 +121,7 @@ class MainInput(Input):
self.request_memory = args.request_memory
self.request_memory_generation = args.request_memory_generation
self.request_cpus = args.request_cpus
self.sampler_kwargs = args.sampler_kwargs
if self.create_plots:
for plot_attr in [
......@@ -1019,6 +1020,7 @@ def write_complete_config_file(parser, args, inputs):
if isinstance(val, list):
if isinstance(val[0], str):
setattr(args, key, "[{}]".format(", ".join(val)))
args.sampler_kwargs = str(inputs.sampler_kwargs)
parser.write_to_file(
filename=inputs.complete_ini_file,
args=args,
......
......@@ -72,6 +72,7 @@ class TestSlurm(unittest.TestCase):
n_simulation=1,
log_directory=None,
osg=True,
sampler_kwargs="{}",
)
self.test_unknown_args = ["--argument", "value"]
self.inputs = bilby_pipe.main.MainInput(self.test_args, self.test_unknown_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment