Skip to content
Snippets Groups Projects
Commit 7c080692 authored by Patrick Godwin's avatar Patrick Godwin Committed by Madeline Wade
Browse files

gstlal_inspiral_calc_snr, svd_bank_snr.py: convert to Stream API, python3 fixes

parent 79b63052
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python3
#
# Copyright (C) 2019-2020 ChiWai Chan
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
Typical Usages:
......@@ -93,12 +109,17 @@ $ gstlal_inspiral_calc_snr \
--verbose
"""
from collections import defaultdict
from functools import reduce
from optparse import OptionParser, OptionGroup, IndentedHelpFormatter
import math
import os
import sys
import time
import gi
gi.require_version('Gst', '1.0')
from gi.repository import Gst
from gstlal import datasource
from gstlal import far
from gstlal import inspiral
......@@ -111,13 +132,7 @@ from gstlal import svd_bank
from gstlal import svd_bank_snr
from gstlal.psd import read_psd
from gstlal.stats.inspiral_lr import LnLRDensity
import gi
gi.require_version('Gst', '1.0')
gi.require_version('GstAudio', '1.0')
from gi.repository import GObject, Gst, GstAudio
GObject.threads_init()
Gst.init(None)
from gstlal.stream import MessageType, Stream
import lal
from lal import LIGOTimeGPS
......@@ -407,9 +422,6 @@ else:
if options.verbose:
sys.stderr.write("Building pipeline...\n")
pipeline = Gst.Pipeline(name = "gstlal_inspiral_SNR")
mainloop = GObject.MainLoop()
#
# Construct Default Pipeline Handler
#
......@@ -424,83 +436,95 @@ for instrument in gw_data_source_info.channel_dict:
snr_document = svd_bank_snr.SignalNoiseRatioDocument(bank_snrs_dict, verbose = options.verbose)
if options.coinc_output == None:
handler = svd_bank_snr.SimpleSNRHandler(pipeline, mainloop, snr_document, verbose = options.verbose)
else:
handler = svd_bank_snr.Handler(snr_document, verbose = options.verbose)
snr_appsync = pipeparts.AppSync(appsink_new_buffer = handler.appsink_new_snr_buffer)
#
# Construct Pipeline
#
itacac_dict = {}
for instrument in gw_data_source_info.channel_dict:
src, statevector, dqvector = datasource.mkbasicsrc(pipeline, gw_data_source_info, instrument, options.verbose)
hoft = multirate_datasource.mkwhitened_multirate_src(
pipeline,
src,
set(rate for bank_SNRs in bank_snrs_dict[instrument] for rate in bank_SNRs.bank.get_rates()),
stream = Stream.from_datasource(
gw_data_source_info,
gw_data_source_info.channel_dict.keys(),
state_vector=True,
dq_vector=True,
verbose=options.verbose,
)
for instrument, head in stream.items():
rates = set(rate for bank_SNRs in bank_snrs_dict[instrument] for rate in bank_SNRs.bank.get_rates())
head = head.condition(
max(rates),
instrument,
dqvector = dqvector,
fir_whiten_reference_psd = bank.processed_psd,
ht_gate_threshold = options.ht_gate_threshold,
statevector = stream.source.state_vector[instrument],
dqvector = stream.source.dq_vector[instrument],
psd = psds_dict[instrument],
psd_fft_length = options.psd_fft_length,
statevector = statevector,
track_psd = options.track_psd,
fir_whiten_reference_psd = bank.processed_psd,
ht_gate_threshold = options.ht_gate_threshold,
veto_segments = veto_segments,
width = options.output_width
)
for index, bank_SNR in enumerate(bank_snrs_dict[instrument]):
bank = bank_SNR.bank
stream[instrument] = head.multiband(rates, instrument=instrument)
snrs = stream.remap()
if options.coinc_output is not None:
triggers = stream.remap()
for instrument, bank_snrs in bank_snrs_dict.items():
for index, bank_snr in enumerate(bank_snrs):
bank = bank_snr.bank
if options.mode == 0:
snr = lloidparts.mkLLOIDhoftToSnrSlices(
pipeline,
hoft,
snr = stream[instrument].create_snr_slices(
bank,
(None, None),
1 * Gst.SECOND,
control_peak_time = options.control_peak_time,
fir_stride = options.fir_stride,
logname = instrument,
nxydump_segment = None,
reconstruction_segment_list = reconstruction_segment_list,
snrslices = None,
verbose = options.verbose
)
else:
fir_matrix = []
for template in bank.templates:
fir_matrix += [template.real, template.imag]
snr = pipeparts.mktogglecomplex(
pipeline,
pipeparts.mkfirbank(
pipeline,
hoft[bank.sample_rate],
latency = 0,
fir_matrix = fir_matrix,
block_stride = 16 * bank.sample_rate,
time_domain = False
)
)
# Construct SNR handler by default and terminate the pipeline at here
if options.coinc_output == None:
snr_appsync.add_sink(pipeline, snr, name = "%s_%d" % (instrument, index))
snr = stream[instrument][bank.sample_rate].firbank(
latency = 0,
fir_matrix = fir_matrix,
block_stride = 16 * bank.sample_rate,
time_domain = False
)
snr.togglecomplex()
# Construct SNR handler by default and terminate the pipeline here
if options.coinc_output is not None:
snr = snr.tee()
snrs[f"{instrument}_{index:d}"] = snr.queue()
# Construct optional trigger generator
else:
snr = pipeparts.mktee(pipeline, snr)
snr_appsync.add_sink(pipeline, pipeparts.mkqueue(pipeline, snr), name = "%s_%d" % (instrument, index))
if options.coinc_output is not None:
triggers[bank.bank_id][instrument] = snr.queue()
if options.coinc_output is not None:
itacac_props = defaultdict(dict)
for instrument, bank_snrs in bank_snrs_dict.items():
for index, bank_snr in enumerate(bank_snrs):
bank = bank_snr.bank
nsamps_window = 1 * max(bank.get_rates())
if bank.bank_id not in itacac_dict:
itacac_dict[bank.bank_id] = pipeparts.mkgeneric(pipeline, None, "lal_itacac")
head = itacac_dict[bank.bank_id]
pad = head.get_request_pad("sink%d" % len(head.sinkpads))
for prop, val in [("n", nsamps_window), ("snr-thresh", LnLRDensity.snr_min), ("bank_filename", bank.template_bank_filename), ("sigmasq", bank.sigmasq), ("autocorrelation_matrix", pipeio.repack_complex_array_to_real(bank.autocorrelation_bank)), ("autocorrelation_mask", bank.autocorrelation_mask)]:
pad.set_property(prop, val)
pipeparts.mkqueue(pipeline, snr).srcpads[0].link(pad)
itacac_props[bank.bank_id][instrument] = {
"n": nsamps_window,
"snr-thresh": LnLRDensity.snr_min,
"bank_filename": bank.template_bank_filename,
"sigmasq": bank.sigmasq,
"autocorrelation_matrix": pipeio.repack_complex_array_to_real(bank.autocorrelation_bank),
"autocorrelation_mask": bank.autocorrelation_mask,
}
for bank_id, head in triggers.items():
triggers[bank_id] = head.itacac(**itacac_props[bank_id])
if options.coinc_output == None:
tracker = svd_bank_snr.SimpleSNRTracker(snr_document, verbose = options.verbose)
else:
tracker = svd_bank_snr.Tracker(snr_document, verbose = options.verbose)
snrs.bufsink(tracker.on_snr_buffer)
#
# Construct optional LLOID handler instead if --coinc-output is provided
......@@ -562,9 +586,8 @@ if options.coinc_output != None:
verbose = options.verbose
)
handler.init(
mainloop,
pipeline,
tracker.init(
stream,
coincs_document,
rankingstat,
list(banks_dict.values())[0][options.bank_number].horizon_distance_func,
......@@ -589,37 +612,24 @@ if options.coinc_output != None:
verbose = options.verbose
)
assert len(itacac_dict.keys()) >= 1
trigger_appsync = pipeparts.AppSync(appsink_new_buffer = handler.appsink_new_buffer)
trigger_appsinks = set(trigger_appsync.add_sink(pipeline, src, caps = Gst.Caps.from_string("application/x-lal-snglinspiral"), name = "bank_%s_sink" % bank_id) for bank_id, src in itacac_dict.items())
assert len(list(triggers.keys())) >= 1
triggers.bufsink(tracker.on_buffer, caps=Gst.Caps.from_string("application/x-lal-snglinspiral"))
stream.add_callback(MessageType.EOS, tracker.on_eos)
#
# Run pipeline
#
if options.verbose:
sys.stderr.write("Setting pipeline state to READY...\n")
if pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter ready state.")
datasource.pipeline_seek_for_gps(pipeline, *gw_data_source_info.seg)
if options.verbose:
sys.stderr.write("Seting pipeline state to PLAYING...\n")
if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter playing state.")
if options.verbose:
sys.stderr.write("Calculating SNR...\n")
mainloop.run()
stream.start()
if options.verbose:
sys.stderr.write("Calculation done.\n")
if pipeline.set_state(Gst.State.NULL) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline could not be set to NULL.")
# Write outputs
if options.coinc_output != None:
handler.write_output_url(url = options.coinc_output)
handler.write_snrs(options.outdir, row_number = options.row_number, counts = options.row_counts, COMPLEX = options.complex)
tracker.write_output_url(url = options.coinc_output)
tracker.write_snrs(options.outdir, row_number = options.row_number, counts = options.row_counts, COMPLEX = options.complex)
......@@ -271,24 +271,20 @@ class Bank_SNR(object):
tseries.data.data = array
return tseries
class SNRHandlerMixin(object):
def __init__(self, *arg, **kwargs):
super(SNRHandlerMixin, self).__init__(*arg, **kwargs)
class SNRTrackerMixin(object):
#def __init__(self, *arg, **kwargs):
# ...
def appsink_new_snr_buffer(self, elem):
def on_snr_buffer(self, buf):
"""Callback function for SNR appsink."""
with self.lock:
# Note: be sure to set property="%s_%d" % (instrument, index) for appsink element
instrument = elem.name.split("_")[0]
index = int(elem.name.split("_")[1])
instrument = buf.name.split("_")[0]
index = int(buf.name.split("_")[1])
cur_bank = self.snr_document.bank_snrs_dict[instrument][index]
sample = elem.emit("pull-sample")
if sample is None:
return Gst.FlowReturn.OK
success, rate = sample.get_caps().get_structure(0).get_int("rate")
assert success == True
success, rate = buf.caps.get_structure(0).get_int("rate")
assert success
if cur_bank.deltaT is None:
cur_bank.deltaT = 1. / rate
......@@ -296,22 +292,15 @@ class SNRHandlerMixin(object):
# sampling rate should not be changing
assert cur_bank.deltaT == 1. / rate, "Data has different sampling rate."
buf = sample.get_buffer()
if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
return Gst.FlowReturn.OK
# add the time offset of template end time here, this offset should be the same for each templates
cur_time_stamp = LIGOTimeGPS(0, sample.get_buffer().pts) + cur_bank.sngl_inspiral_table[0].end
cur_time_stamp = buf.t0 + cur_bank.sngl_inspiral_table[0].end
if cur_bank.s >= cur_time_stamp and cur_bank.e > cur_time_stamp:
# record the first timestamp closet to start time
cur_bank.epoch = cur_time_stamp
cur_bank.data = [pipeio.array_from_audio_sample(sample)]
cur_bank.data = [buf.data]
elif cur_bank.s <= cur_time_stamp < cur_bank.e:
cur_bank.data.append(pipeio.array_from_audio_sample(sample))
else:
Gst.FlowReturn.OK
return Gst.FlowReturn.OK
cur_bank.data.append(buf.data)
def write_snrs(self, outdir, row_number=None, counts=1, COMPLEX=False):
"""Writing SNRs timeseries to LIGO_LW xml files."""
......@@ -322,23 +311,22 @@ class SNRHandlerMixin(object):
self.snr_document.write_output_url(outdir, row_number=row_number, counts=counts)
class SimpleSNRHandler(SNRHandlerMixin, simplehandler.Handler):
class SimpleSNRTracker(SNRTrackerMixin):
"""Simple SNR pipeline handler.
This is the SNR pipeline handler derived from simplehandler. It
This is the SNR pipeline handler, which
only implements the controls for collecting SNR timeseries.
"""
def __init__(self, pipeline, mainloop, snr_document, verbose=False):
super(SimpleSNRHandler, self).__init__(mainloop, pipeline)
def __init__(self, snr_document, verbose=False):
self.lock = threading.Lock()
self.snr_document = snr_document
self.verbose = verbose
class Handler(SNRHandlerMixin, lloidhandler.Handler):
"""Simplified version of lloidhandler.Handler.
class Tracker(SNRTrackerMixin, lloidhandler.LLOIDTracker):
"""Simplified version of lloidhandler.LLOIDTracker.
This is the SNR pipeline handler derived from lloidhandler. In
This is the SNR pipeline handler derived from LLOIDTracker. In
addition to the control for collecting SNR timeseries, it
implements controls for trigger generator.
......@@ -349,10 +337,9 @@ class Handler(SNRHandlerMixin, lloidhandler.Handler):
self.verbose = verbose
# Explictly delay the class initialization.
def init(self, mainloop, pipeline, coincs_document, rankingstat, horizon_distance_func, gracedbwrapper, zerolag_rankingstatpdf_url=None, rankingstatpdf_url=None, ranking_stat_output_url=None, ranking_stat_input_url=None, likelihood_snapshot_interval=None, sngls_snr_threshold=None, FAR_trialsfactor=1.0, verbose=False):
super(Handler, self).__init__(
mainloop,
pipeline,
def init(self, stream, coincs_document, rankingstat, horizon_distance_func, gracedbwrapper, zerolag_rankingstatpdf_url=None, rankingstatpdf_url=None, ranking_stat_output_url=None, ranking_stat_input_url=None, likelihood_snapshot_interval=None, sngls_snr_threshold=None, FAR_trialsfactor=1.0, verbose=False):
super(Tracker, self).__init__(
stream,
coincs_document,
rankingstat,
horizon_distance_func,
......@@ -366,7 +353,7 @@ class Handler(SNRHandlerMixin, lloidhandler.Handler):
FAR_trialsfactor = FAR_trialsfactor,
kafka_server = None,
cluster = True,
tag = "0000",
job_tag = "0000",
verbose = verbose
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment