Skip to content
Snippets Groups Projects
Commit 0c5a9b45 authored by Gregory Ashton's avatar Gregory Ashton Committed by Avi Vajpeyi
Browse files

Add the option to use the ROQ

Additonally allows users to specify their own likelihood as in
bilby_pipe and improves the meta-data read through.
parent 9a6c526c
No related branches found
No related tags found
No related merge requests found
......@@ -3,12 +3,14 @@
Module to run parallel bilby using MPI
"""
import datetime
import inspect
import json
import logging
import os
import pickle
import sys
import time
from importlib import import_module
import bilby
import dill
......@@ -40,6 +42,115 @@ mpi4py.rc.recv_mprobe = False
logger = bilby.core.utils.logger
def roq_likelihood_kwargs(args):
"""Return the kwargs required for the ROQ setup
Parameters
----------
args: Namespace
The parser arguments
Returns
-------
kwargs: dict
A dictionary of the required kwargs
"""
if hasattr(args, "likelihood_roq_params"):
params = args.likelihood_roq_params
else:
params = np.genfromtxt(args.roq_folder + "/params.dat", names=True)
if hasattr(args, "likelihood_roq_weights"):
weights = args.likelihood_roq_weights
else:
weights = args.weight_file
logger.info("Loading ROQ weights from {}".format(weights))
return dict(
weights=weights, roq_params=params, roq_scale_factor=args.roq_scale_factor
)
def setup_likelihood(interferometers, waveform_generator, priors, args):
"""Takes the kwargs and sets up and returns either an ROQ GW or GW likelihood.
Parameters
----------
interferometers: bilby.gw.detectors.InterferometerList
The pre-loaded bilby IFO
waveform_generator: bilby.gw.waveform_generator.WaveformGenerator
The waveform generation
priors: dict
The priors, used for setting up marginalization
args: Namespace
The parser arguments
Returns
-------
likelihood: bilby.gw.likelihood.GravitationalWaveTransient
The likelihood (either GravitationalWaveTransient or ROQGravitationalWaveTransient)
"""
search_priors = priors.copy()
likelihood_kwargs = dict(
interferometers=interferometers,
waveform_generator=waveform_generator,
priors=search_priors,
phase_marginalization=args.phase_marginalization,
distance_marginalization=args.distance_marginalization,
distance_marginalization_lookup_table=args.distance_marginalization_lookup_table,
time_marginalization=args.time_marginalization,
reference_frame=args.reference_frame,
time_reference=args.time_reference,
)
if args.likelihood_type == "GravitationalWaveTransient":
Likelihood = bilby.gw.likelihood.GravitationalWaveTransient
likelihood_kwargs.update(jitter_time=args.jitter_time)
elif args.likelihood_type == "ROQGravitationalWaveTransient":
Likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient
if args.time_marginalization:
logger.warning(
"Time marginalization not implemented for "
"ROQGravitationalWaveTransient: option ignored"
)
likelihood_kwargs.pop("time_marginalization", None)
likelihood_kwargs.pop("jitter_time", None)
likelihood_kwargs.update(roq_likelihood_kwargs(args))
elif "." in args.likelihood_type:
split_path = args.likelihood_type.split(".")
module = ".".join(split_path[:-1])
likelihood_class = split_path[-1]
Likelihood = getattr(import_module(module), likelihood_class)
likelihood_kwargs.update(args.extra_likelihood_kwargs)
if "roq" in args.likelihood_type.lower():
likelihood_kwargs.pop("time_marginalization", None)
likelihood_kwargs.pop("jitter_time", None)
likelihood_kwargs.update(args.roq_likelihood_kwargs)
else:
raise ValueError("Unknown Likelihood class {}")
likelihood_kwargs = {
key: likelihood_kwargs[key]
for key in likelihood_kwargs
if key in inspect.getfullargspec(Likelihood.__init__).args
}
logger.info(
"Initialise likelihood {} with kwargs: \n{}".format(
Likelihood, likelihood_kwargs
)
)
return Likelihood(**likelihood_kwargs)
def main():
""" Do nothing function to play nicely with MPI """
pass
......@@ -227,6 +338,8 @@ waveform_generator.start_time = ifo_list[0].time_array[0]
args = data_dump["args"]
injection_parameters = data_dump.get("injection_parameters", None)
args.weight_file = data_dump["meta_data"].get("weight_file", None)
outdir = args.outdir
if input_args.outdir is not None:
outdir = input_args.outdir
......@@ -237,17 +350,11 @@ if input_args.label is not None:
priors = bilby.gw.prior.PriorDict.from_json(data_dump["prior_file"])
logger.setLevel(logging.WARNING)
likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
ifo_list,
waveform_generator,
likelihood = setup_likelihood(
interferometers=ifo_list,
waveform_generator=waveform_generator,
priors=priors,
time_marginalization=args.time_marginalization,
phase_marginalization=args.phase_marginalization,
distance_marginalization=args.distance_marginalization,
distance_marginalization_lookup_table=args.distance_marginalization_lookup_table,
jitter_time=args.jitter_time,
reference_frame=args.reference_frame,
time_reference=args.time_reference,
args=args,
)
logger.setLevel(logging.INFO)
......
......@@ -44,6 +44,8 @@ def main():
args = generation_parser.parse_args(args=cli_args)
args = add_extra_args_from_bilby_pipe_namespace(args)
inputs = bilby_pipe_datagen.DataGenerationInput(args, [])
if inputs.likelihood_type == "ROQGravitationalWaveTransient":
inputs.save_roq_weights()
inputs.log_directory = None
shutil.rmtree(inputs.data_generation_log_directory) # Hack to remove unused dir
......@@ -69,13 +71,16 @@ def main():
data_dump_file = f"{data_dir}/{label}_data_dump.pickle"
meta_data = dict(
config_file=args.ini,
data_dump_file=data_dump_file,
bilby_version=bilby.__version__,
bilby_pipe_version=bilby_pipe.__version__,
parallel_bilby_version=__version__,
dynesty_version=dynesty.__version__,
meta_data = inputs.meta_data
meta_data.update(
dict(
config_file=args.ini,
data_dump_file=data_dump_file,
bilby_version=bilby.__version__,
bilby_pipe_version=bilby_pipe.__version__,
parallel_bilby_version=__version__,
dynesty_version=dynesty.__version__,
)
)
logger.info("Initial meta_data = {}".format(meta_data))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment