Skip to content
Snippets Groups Projects
Commit 642120c9 authored by Avi Vajpeyi's avatar Avi Vajpeyi :alien:
Browse files

Merge branch 'add-roq' into 'master'

Add the option to use the ROQ

See merge request !60
parents 9a6c526c 0c5a9b45
No related branches found
No related tags found
1 merge request!60Add the option to use the ROQ
Pipeline #153834 passed with warnings
......@@ -3,12 +3,14 @@
Module to run parallel bilby using MPI
"""
import datetime
import inspect
import json
import logging
import os
import pickle
import sys
import time
from importlib import import_module
import bilby
import dill
......@@ -40,6 +42,115 @@ mpi4py.rc.recv_mprobe = False
logger = bilby.core.utils.logger
def roq_likelihood_kwargs(args):
"""Return the kwargs required for the ROQ setup
Parameters
----------
args: Namespace
The parser arguments
Returns
-------
kwargs: dict
A dictionary of the required kwargs
"""
if hasattr(args, "likelihood_roq_params"):
params = args.likelihood_roq_params
else:
params = np.genfromtxt(args.roq_folder + "/params.dat", names=True)
if hasattr(args, "likelihood_roq_weights"):
weights = args.likelihood_roq_weights
else:
weights = args.weight_file
logger.info("Loading ROQ weights from {}".format(weights))
return dict(
weights=weights, roq_params=params, roq_scale_factor=args.roq_scale_factor
)
def setup_likelihood(interferometers, waveform_generator, priors, args):
"""Takes the kwargs and sets up and returns either an ROQ GW or GW likelihood.
Parameters
----------
interferometers: bilby.gw.detectors.InterferometerList
The pre-loaded bilby IFO
waveform_generator: bilby.gw.waveform_generator.WaveformGenerator
The waveform generation
priors: dict
The priors, used for setting up marginalization
args: Namespace
The parser arguments
Returns
-------
likelihood: bilby.gw.likelihood.GravitationalWaveTransient
The likelihood (either GravitationalWaveTransient or ROQGravitationalWaveTransient)
"""
search_priors = priors.copy()
likelihood_kwargs = dict(
interferometers=interferometers,
waveform_generator=waveform_generator,
priors=search_priors,
phase_marginalization=args.phase_marginalization,
distance_marginalization=args.distance_marginalization,
distance_marginalization_lookup_table=args.distance_marginalization_lookup_table,
time_marginalization=args.time_marginalization,
reference_frame=args.reference_frame,
time_reference=args.time_reference,
)
if args.likelihood_type == "GravitationalWaveTransient":
Likelihood = bilby.gw.likelihood.GravitationalWaveTransient
likelihood_kwargs.update(jitter_time=args.jitter_time)
elif args.likelihood_type == "ROQGravitationalWaveTransient":
Likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient
if args.time_marginalization:
logger.warning(
"Time marginalization not implemented for "
"ROQGravitationalWaveTransient: option ignored"
)
likelihood_kwargs.pop("time_marginalization", None)
likelihood_kwargs.pop("jitter_time", None)
likelihood_kwargs.update(roq_likelihood_kwargs(args))
elif "." in args.likelihood_type:
split_path = args.likelihood_type.split(".")
module = ".".join(split_path[:-1])
likelihood_class = split_path[-1]
Likelihood = getattr(import_module(module), likelihood_class)
likelihood_kwargs.update(args.extra_likelihood_kwargs)
if "roq" in args.likelihood_type.lower():
likelihood_kwargs.pop("time_marginalization", None)
likelihood_kwargs.pop("jitter_time", None)
likelihood_kwargs.update(args.roq_likelihood_kwargs)
else:
raise ValueError("Unknown Likelihood class {}")
likelihood_kwargs = {
key: likelihood_kwargs[key]
for key in likelihood_kwargs
if key in inspect.getfullargspec(Likelihood.__init__).args
}
logger.info(
"Initialise likelihood {} with kwargs: \n{}".format(
Likelihood, likelihood_kwargs
)
)
return Likelihood(**likelihood_kwargs)
def main():
""" Do nothing function to play nicely with MPI """
pass
......@@ -227,6 +338,8 @@ waveform_generator.start_time = ifo_list[0].time_array[0]
args = data_dump["args"]
injection_parameters = data_dump.get("injection_parameters", None)
args.weight_file = data_dump["meta_data"].get("weight_file", None)
outdir = args.outdir
if input_args.outdir is not None:
outdir = input_args.outdir
......@@ -237,17 +350,11 @@ if input_args.label is not None:
priors = bilby.gw.prior.PriorDict.from_json(data_dump["prior_file"])
logger.setLevel(logging.WARNING)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
ifo_list,
waveform_generator,
likelihood = setup_likelihood(
interferometers=ifo_list,
waveform_generator=waveform_generator,
priors=priors,
time_marginalization=args.time_marginalization,
phase_marginalization=args.phase_marginalization,
distance_marginalization=args.distance_marginalization,
distance_marginalization_lookup_table=args.distance_marginalization_lookup_table,
jitter_time=args.jitter_time,
reference_frame=args.reference_frame,
time_reference=args.time_reference,
args=args,
)
logger.setLevel(logging.INFO)
......
......@@ -44,6 +44,8 @@ def main():
args = generation_parser.parse_args(args=cli_args)
args = add_extra_args_from_bilby_pipe_namespace(args)
inputs = bilby_pipe_datagen.DataGenerationInput(args, [])
if inputs.likelihood_type == "ROQGravitationalWaveTransient":
inputs.save_roq_weights()
inputs.log_directory = None
shutil.rmtree(inputs.data_generation_log_directory) # Hack to remove unused dir
......@@ -69,13 +71,16 @@ def main():
data_dump_file = f"{data_dir}/{label}_data_dump.pickle"
meta_data = dict(
config_file=args.ini,
data_dump_file=data_dump_file,
bilby_version=bilby.__version__,
bilby_pipe_version=bilby_pipe.__version__,
parallel_bilby_version=__version__,
dynesty_version=dynesty.__version__,
meta_data = inputs.meta_data
meta_data.update(
dict(
config_file=args.ini,
data_dump_file=data_dump_file,
bilby_version=bilby.__version__,
bilby_pipe_version=bilby_pipe.__version__,
parallel_bilby_version=__version__,
dynesty_version=dynesty.__version__,
)
)
logger.info("Initial meta_data = {}".format(meta_data))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment