Commit ac609ca1 authored by Colm Talbot's avatar Colm Talbot
Browse files

inital attempt at adding compute roq weights

parent b0486b4e
#!/usr/bin/env python
"""
Script to analyse the stored data
"""
from __future__ import division, print_function
import sys
import os
import numpy as np
import matplotlib
matplotlib.use("agg") # noqa
import bilby
from bilby_pipe.utils import logger
from bilby_pipe.main import DataDump, parse_args
from bilby_pipe.parser import create_parser
from bilby_pipe.input import Input
class ComputeROQWeightsInput(Input):
""" Handles user-input and analysis of intermediate ifo list
Parameters
----------
parser: BilbyArgParser, optional
The parser containing the command line / ini file inputs
args_list: list, optional
A list of the arguments to parse. Defauts to `sys.argv[1:]`
"""
def __init__(self, args, unknown_args):
logger.info("Command line arguments: {}".format(args))
self.ini = args.ini
self.idx = args.idx
self.cluster = args.cluster
self.process = args.process
self.detectors = args.detectors
self.prior_file = args.prior_file
self._priors = None
self.deltaT = args.deltaT
self.reference_frequency = args.reference_frequency
self.waveform_approximant = args.waveform_approximant
self.distance_marginalization = args.distance_marginalization
self.phase_marginalization = args.phase_marginalization
self.time_marginalization = args.time_marginalization
self.sampling_seed = args.sampling_seed
self.sampler = args.sampler
self.sampler_kwargs = args.sampler_kwargs
self.outdir = args.outdir
self.label = args.label
self.data_label = args.data_label
self.default_prior = args.default_prior
self.frequency_domain_source_model = args.frequency_domain_source_model
self.likelihood_type = args.likelihood_type
self.roq_folder = args.roq_folder
self.result = None
@property
def cluster(self):
return self._cluster
@cluster.setter
def cluster(self, cluster):
try:
self._cluster = int(cluster)
except (ValueError, TypeError):
logger.debug("Unable to convert input `cluster` to type int")
self._cluster = cluster
@property
def process(self):
return self._process
@process.setter
def process(self, process):
try:
self._process = int(process)
except (ValueError, TypeError):
logger.debug("Unable to convert input `process` to type int")
self._process = process
@property
def interferometers(self):
try:
return self._interferometers
except AttributeError:
ifos = self.data_dump.interferometers
names = [ifo.name for ifo in ifos]
logger.info("Found data for detectors = {}".format(names))
ifos_to_use = [ifo for ifo in ifos if ifo.name in self.detectors]
names_to_use = [ifo.name for ifo in ifos_to_use]
logger.info("Using data for detectors = {}".format(names_to_use))
self._interferometers = bilby.gw.detector.InterferometerList(ifos_to_use)
self.print_detector_information(self._interferometers)
return self._interferometers
@staticmethod
def print_detector_information(interferometers):
for ifo in interferometers:
logger.info(
"{}: sampling-frequency={}, segment-start-time={}, duration={}".format(
ifo.name,
ifo.strain_data.sampling_frequency,
ifo.strain_data.start_time,
ifo.strain_data.duration,
)
)
@property
def meta_data(self):
return self.data_dump.meta_data
@property
def trigger_time(self):
return self.data_dump.trigger_time
@property
def data_dump(self):
try:
return self._data_dump
except AttributeError:
filename = DataDump.get_filename(
self.data_directory, self.data_label, str(self.idx)
)
self._data_dump = DataDump.from_pickle(filename)
return self._data_dump
@property
def priors(self):
if self._priors is None:
if self.default_prior in bilby.core.prior.__dict__.keys():
self._priors = bilby.core.prior.__dict__[self.default_prior](
filename=self.prior_file
)
elif self.default_prior in bilby.gw.prior.__dict__.keys():
self._priors = bilby.gw.prior.__dict__[self.default_prior](
filename=self.prior_file
)
else:
logger.info("No prior {} found.").format(self.default_prior)
logger.info("Defaulting to BBHPriorDict")
self._priors = bilby.gw.prior.BBHPriorDict(filename=self.prior_file)
if isinstance(
self._priors, (bilby.gw.prior.BBHPriorDict, bilby.gw.prior.BNSPriorDict)
):
self._priors["geocent_time"] = bilby.core.prior.Uniform(
minimum=self.trigger_time - self.deltaT / 2,
maximum=self.trigger_time + self.deltaT / 2,
name="geocent_time",
latex_label="$t_c$",
unit="$s$",
)
return self._priors
@property
def waveform_generator(self):
logger.info(
"Using the ROQ likelihood with roq-folder={}".format(self.roq_folder)
)
freq_nodes_linear = np.load(self.roq_folder + "/fnodes_linear.npy")
freq_nodes_quadratic = np.load(self.roq_folder + "/fnodes_quadratic.npy")
waveform_arguments = self.waveform_arguments.copy()
waveform_arguments["frequency_nodes_linear"] = freq_nodes_linear
waveform_arguments["frequency_nodes_quadratic"] = freq_nodes_quadratic
waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
sampling_frequency=self.interferometers.sampling_frequency,
duration=self.interferometers.duration,
frequency_domain_source_model=bilby.gw.source.roq,
start_time=self.interferometers.start_time,
parameter_conversion=self.parameter_conversion,
waveform_arguments=waveform_arguments,
)
return waveform_generator
@property
def waveform_arguments(self):
return dict(
reference_frequency=self.reference_frequency,
waveform_approximant=self.waveform_approximant,
minimum_frequency=self.interferometers[0].minimum_frequency,
) # FIXME
@property
def likelihood(self):
logger.info(
"Using the ROQ likelihood with roq-folder={}".format(self.roq_folder)
)
if self.time_marginalization:
logger.warning(
"Time marginalization not implemented for "
"ROQGravitationalWaveTransient: option ignored"
)
return bilby.gw.likelihood.ROQGravitationalWaveTransient(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
priors=self.priors,
weights=self.roq_weights,
phase_marginalization=self.phase_marginalization,
distance_marginalization=self.distance_marginalization,
)
@property
def roq_weights_file(self):
return os.path.join(
self.data_directory, "roq_weights_{}.json".format(self.idx))
def create_roq_weight_parser():
return create_parser(
pipe_args=False,
job_args=True,
run_spec=True,
pe_summary=False,
injection=False,
data_gen=False,
waveform=True,
generation=False,
analysis=False,
)
def main():
args, unknown_args = parse_args(sys.argv[1:], create_roq_weight_parser())
if args.roq_folder is not None:
roq_weights = ComputeROQWeightsInput(args, unknown_args)
roq_weights.likelihood.save_weights(roq_weights.roq_weights_file)
......@@ -286,11 +286,8 @@ class DataAnalysisInput(Input):
elif self.likelihood_type == "ROQGravitationalWaveTransient":
logger.info(
"Using the ROQ likelihood with roq-folder={}".format(self.roq_folder)
"Loading ROQ weights from {}".format(self.roq_weights)
)
basis_matrix_linear = np.load(self.roq_folder + "/B_linear.npy").T
basic_matrix_quadratic = np.load(self.roq_folder + "/B_quadratic.npy").T
if self.time_marginalization:
logger.warning(
"Time marginalization not implemented for "
......@@ -300,8 +297,7 @@ class DataAnalysisInput(Input):
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
priors=self.priors,
linear_matrix=basis_matrix_linear,
quadratic_matrix=basic_matrix_quadratic,
weights=self.roq_weights,
phase_marginalization=self.phase_marginalization,
distance_marginalization=self.distance_marginalization,
)
......@@ -357,6 +353,7 @@ def create_analysis_parser():
data_gen=False,
waveform=True,
generation=False,
roq_weights=False,
analysis=True,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment