Skip to content
Snippets Groups Projects

Add the option to use the ROQ

Merged Gregory Ashton requested to merge add-roq into master
Files
2
+ 117
10
@@ -3,12 +3,14 @@
Module to run parallel bilby using MPI
"""
import datetime
import inspect
import json
import logging
import os
import pickle
import sys
import time
from importlib import import_module
import bilby
import dill
@@ -40,6 +42,115 @@ mpi4py.rc.recv_mprobe = False
logger = bilby.core.utils.logger
def roq_likelihood_kwargs(args):
"""Return the kwargs required for the ROQ setup
Parameters
----------
args: Namespace
The parser arguments
Returns
-------
kwargs: dict
A dictionary of the required kwargs
"""
if hasattr(args, "likelihood_roq_params"):
params = args.likelihood_roq_params
else:
params = np.genfromtxt(args.roq_folder + "/params.dat", names=True)
if hasattr(args, "likelihood_roq_weights"):
weights = args.likelihood_roq_weights
else:
weights = args.weight_file
logger.info("Loading ROQ weights from {}".format(weights))
return dict(
weights=weights, roq_params=params, roq_scale_factor=args.roq_scale_factor
)
def setup_likelihood(interferometers, waveform_generator, priors, args):
"""Takes the kwargs and sets up and returns either an ROQ GW or GW likelihood.
Parameters
----------
interferometers: bilby.gw.detectors.InterferometerList
The pre-loaded bilby IFO
waveform_generator: bilby.gw.waveform_generator.WaveformGenerator
The waveform generation
priors: dict
The priors, used for setting up marginalization
args: Namespace
The parser arguments
Returns
-------
likelihood: bilby.gw.likelihood.GravitationalWaveTransient
The likelihood (either GravitationalWaveTransient or ROQGravitationalWaveTransient)
"""
search_priors = priors.copy()
likelihood_kwargs = dict(
interferometers=interferometers,
waveform_generator=waveform_generator,
priors=search_priors,
phase_marginalization=args.phase_marginalization,
distance_marginalization=args.distance_marginalization,
distance_marginalization_lookup_table=args.distance_marginalization_lookup_table,
time_marginalization=args.time_marginalization,
reference_frame=args.reference_frame,
time_reference=args.time_reference,
)
if args.likelihood_type == "GravitationalWaveTransient":
Likelihood = bilby.gw.likelihood.GravitationalWaveTransient
likelihood_kwargs.update(jitter_time=args.jitter_time)
elif args.likelihood_type == "ROQGravitationalWaveTransient":
Likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient
if args.time_marginalization:
logger.warning(
"Time marginalization not implemented for "
"ROQGravitationalWaveTransient: option ignored"
)
likelihood_kwargs.pop("time_marginalization", None)
likelihood_kwargs.pop("jitter_time", None)
likelihood_kwargs.update(roq_likelihood_kwargs(args))
elif "." in args.likelihood_type:
split_path = args.likelihood_type.split(".")
module = ".".join(split_path[:-1])
likelihood_class = split_path[-1]
Likelihood = getattr(import_module(module), likelihood_class)
likelihood_kwargs.update(args.extra_likelihood_kwargs)
if "roq" in args.likelihood_type.lower():
likelihood_kwargs.pop("time_marginalization", None)
likelihood_kwargs.pop("jitter_time", None)
likelihood_kwargs.update(args.roq_likelihood_kwargs)
else:
raise ValueError("Unknown Likelihood class {}")
likelihood_kwargs = {
key: likelihood_kwargs[key]
for key in likelihood_kwargs
if key in inspect.getfullargspec(Likelihood.__init__).args
}
logger.info(
"Initialise likelihood {} with kwargs: \n{}".format(
Likelihood, likelihood_kwargs
)
)
return Likelihood(**likelihood_kwargs)
def main():
""" Do nothing function to play nicely with MPI """
pass
@@ -227,6 +338,8 @@ waveform_generator.start_time = ifo_list[0].time_array[0]
args = data_dump["args"]
injection_parameters = data_dump.get("injection_parameters", None)
args.weight_file = data_dump["meta_data"].get("weight_file", None)
outdir = args.outdir
if input_args.outdir is not None:
outdir = input_args.outdir
@@ -237,17 +350,11 @@ if input_args.label is not None:
priors = bilby.gw.prior.PriorDict.from_json(data_dump["prior_file"])
logger.setLevel(logging.WARNING)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
ifo_list,
waveform_generator,
likelihood = setup_likelihood(
interferometers=ifo_list,
waveform_generator=waveform_generator,
priors=priors,
time_marginalization=args.time_marginalization,
phase_marginalization=args.phase_marginalization,
distance_marginalization=args.distance_marginalization,
distance_marginalization_lookup_table=args.distance_marginalization_lookup_table,
jitter_time=args.jitter_time,
reference_frame=args.reference_frame,
time_reference=args.time_reference,
args=args,
)
logger.setLevel(logging.INFO)
Loading