Forked from
lscsoft / GstLAL
2155 commits behind the upstream repository.
-
ChiWai Chan authored
svd_bank_snr: In mode 1, use mkwhitened_multirate_src instead of mkwhiten
ChiWai Chan authoredsvd_bank_snr: In mode 1, use mkwhitened_multirate_src instead of mkwhiten
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
svd_bank_snr.py 12.69 KiB
"""
Short cutting gstlal inspiral pipeline to produce SNR for template(s)
"""
import sys
import numpy
import threading
import gi
gi.require_version('Gst', '1.0')
gi.require_version('GstAudio', '1.0')
from gi.repository import GObject, Gst, GstAudio
GObject.threads_init()
Gst.init(None)
from gstlal import datasource
from gstlal import lloidparts
from gstlal import multirate_datasource
from gstlal import pipeparts
from gstlal import pipeio
from gstlal import reference_psd
from gstlal import simplehandler
import lal
from lal import LIGOTimeGPS
import lal.series
from ligo.lw import array as ligolw_array
from ligo.lw import ligolw
from ligo.lw import param as ligolw_param
from ligo.lw import utils as ligolw_utils
@ligolw_array.use_in
@ligolw_param.use_in
class SNRContentHandler(ligolw.LIGOLWContentHandler):
pass
class LLOID_SNR(object):
"""
The options for SNR calculation, please refer to multirate_datasource.mkwhitened_src()
and llloidparts.mkLLOIDhoftToSnrSlices() for more information. Defaults are:
"psd_fft_length": 32,
"ht_gate_threshold": None,
"veto_segments": None,
"track_psd": False,
"width": 32,
"verbose": False
"""
def __init__(self, psd_fft_length = 32, ht_gate_threshold = None, veto_segments = None, track_psd = False, width = 32, verbose = False):
self.psd_fft_length = psd_fft_length
self.ht_gate_threshold = ht_gate_threshold
self.veto_segments = veto_segments
self.track_psd = track_psd
self.width = width
self.verbose = verbose
self.lock = threading.Lock()
self.snr_info = {
"timestamps": [],
"instrument": None,
"deltaT": None,
"data": [],
}
def __call__(self, gw_data_source_info, bank, instrument, psd = None):
pipeline = Gst.Pipeline(name = "gstlal_inspiral_LLOID_SNR")
mainloop = GObject.MainLoop()
handler = simplehandler.Handler(mainloop, pipeline)
self.snr_info["instrument"] = instrument
# sanity check
if psd is not None:
assert instrument in set(psd)
assert instrument in set(gw_data_source_info.channel_dict)
if self.verbose:
sys.stderr.write("Building pipeline to calculate SNR...\n")
src, statevector, dqvector = datasource.mkbasicsrc(pipeline, gw_data_source_info, instrument, self.verbose)
hoftdict = multirate_datasource.mkwhitened_multirate_src(
pipeline,
src = src,
rates = set(rate for rate in bank.get_rates()),
instrument = instrument,
psd = psd[instrument],
psd_fft_length = self.psd_fft_length,
ht_gate_threshold = self.ht_gate_threshold,
veto_segments = self.veto_segments,
track_psd = self.track_psd,
width = self.width,
statevector = statevector,
dqvector = dqvector,
fir_whiten_reference_psd = bank.processed_psd
)
snr = lloidparts.mkLLOIDhoftToSnrSlices(
pipeline,
hoftdict = hoftdict,
bank = bank,
control_snksrc = (None, None),
block_duration = 8 * Gst.SECOND,
fir_stride = 16,
verbose = self.verbose,
logname = instrument
)
appsink = pipeparts.mkappsink(pipeline, snr, drop = False)
handler_id = appsink.connect("new-preroll", self.new_preroll_handler)
assert handler_id > 0
handler_id = appsink.connect("new-sample", self.pull_buffer)
assert handler_id > 0
handler_id = appsink.connect("eos", self.pull_buffer)
assert handler_id > 0
if self.verbose:
sys.stderr.write("Setting pipeline state to READY...\n")
if pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter ready state.")
datasource.pipeline_seek_for_gps(pipeline, *gw_data_source_info.seg)
if self.verbose:
sys.stderr.write("Seting pipeline state to PLAYING...\n")
if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter playing state.")
if self.verbose:
sys.stderr.write("Calculating SNR...\n")
mainloop.run()
if self.verbose:
sys.stderr.write("Calculation done.\n")
if pipeline.set_state(Gst.State.NULL) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline could not be set to NULL.")
del pipeline, mainloop, handler
# return snr_info containing information to construct snr for all template in the template bank
# see make_SNR_series() to make SNR LAL Series from snr_info
self.snr_info["data"] = numpy.concatenate(numpy.array(self.snr_info["data"]), axis = 0)
return self.snr_info
#===============================================================================================
#
# internal functions
#
#===============================================================================================
def new_preroll_handler(self, elem):
with self.lock:
# ignore preroll buffers
elem.emit("pull-preroll")
return Gst.FlowReturn.OK
def pull_buffer(self, elem):
with self.lock:
sample = elem.emit("pull-sample")
if sample is None:
return Gst.FlowReturn.OK
else:
success, rate = sample.get_caps().get_structure(0).get_int("rate")
assert success == True
# make sure the sampling rate is the same for all data
if self.snr_info["deltaT"] is None:
self.snr_info["deltaT"] = 1. / rate
else:
assert self.snr_info["deltaT"] == 1. / rate, "data have different sampling rate."
buf = sample.get_buffer()
if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
return Gst.FlowReturn.OK
# FIXME: check timestamps
data = pipeio.array_from_audio_sample(sample)
if data is not None:
self.snr_info["data"].append(data)
self.snr_info["timestamps"].append(LIGOTimeGPS(0, sample.get_buffer().pts))
return Gst.FlowReturn.OK
class FIR_SNR(object):
"""
Required arguments:
-gw_data_source_info:
-template:
-psd:
-instrument:
Optional arguments:
-psd_fft_length:
-zero_pad: Hanning Window's zero padding (seconds)
-average_samples:
-median_samples:
-rate:
-verbose:
"""
def __init__(self, rate, psd_fft_length = 32, zero_pad = 0, average_samples = 64, median_samples = 7, width = 32, track_psd = False, verbose = False):
self.average_samples = average_samples
self.lock = threading.Lock()
self.median_samples = median_samples
self.psd_fft_length = psd_fft_length
self.rate = rate
self.track_psd = track_psd
self.verbose = verbose
self.width = width
self.zero_pad = zero_pad
self.snr_info = {
"timestamps": [],
"instrument": None,
"deltaT": 1./rate,
"data": [],
}
def __call__(self, gw_data_source_info, template, instrument, latency, psd = None):
self.snr_info["instrument"] = instrument
pipeline = Gst.Pipeline("gstlal_inspiral_simple_SNR")
mainloop = GObject.MainLoop()
handler = simplehandler.Handler(mainloop, pipeline)
if self.verbose:
sys.stderr.write("Building pipeline to calculate SNR\n")
src, statevector, dqvector = datasource.mkbasicsrc(pipeline, gw_data_source_info, instrument, verbose = self.verbose)
hoftdict = multirate_datasource.mkwhitened_multirate_src(
pipeline,
src = src,
rates = [self.rate],
instrument = instrument,
psd = psd[instrument],
psd_fft_length = self.psd_fft_length,
track_psd = self.track_psd,
width = self.width,
statevector = statevector,
dqvector = dqvector
)
#FIXME: how to set latency
head = pipeparts.mkfirbank(pipeline, hoftdict[self.rate], latency = latency, fir_matrix = [template], block_stride = 16 * self.rate, time_domain = False)
appsink = pipeparts.mkappsink(pipeline, head, drop = False)
handler_id = appsink.connect("new-preroll", self.new_preroll_handler)
assert handler_id > 0
handler_id = appsink.connect("new-sample", self.pull_buffer)
assert handler_id > 0
handler_id = appsink.connect("eos", self.pull_buffer)
assert handler_id > 0
if self.verbose:
sys.stderr.write("Setting pipeline state to READY...\n")
if pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter ready state.")
datasource.pipeline_seek_for_gps(pipeline, *gw_data_source_info.seg)
if self.verbose:
sys.stderr.write("Setting pipeline state to PLAYING...\n")
if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter playing state.")
if self.verbose:
sys.stderr.write("Calculating SNR...\n")
mainloop.run()
if self.verbose:
sys.stderr.write("Calculation done.\n")
if pipeline.set_state(Gst.State.NULL) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline could not be set to NULL.")
del pipeline, mainloop, handler
self.snr_info["data"] = numpy.concatenate(numpy.array(self.snr_info["data"]), axis = 0)
return self.snr_info
#===============================================================================================
#
# internal functions
#
#===============================================================================================
def new_preroll_handler(self, elem):
with self.lock:
# ignore preroll buffers
elem.emit("pull-preroll")
return Gst.FlowReturn.OK
def pull_buffer(self, elem):
with self.lock:
sample = elem.emit("pull-sample")
if sample is None:
return Gst.FlowReturn.OK
else:
success, rate = sample.get_caps().get_structure(0).get_int("rate")
assert success == True
# make sure the sampling rate is the same for all data
if self.snr_info["deltaT"] is not None:
assert self.snr_info["deltaT"] == 1. / rate, "data have different sampling rate."
self.snr_info["deltaT"] = 1. / rate
buf = sample.get_buffer()
if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
return Gst.FlowReturn.OK
data = pipeio.array_from_audio_sample(sample)
if data is not None:
self.snr_info["data"].append(data)
self.snr_info["timestamps"].append(LIGOTimeGPS(0, sample.get_buffer().pts))
return Gst.FlowReturn.OK
#=============================================================================================
#
# Output Utilities
#
#=============================================================================================
def make_snr_series(snr_info, row_number = 0, drop_first = 0, drop_last = 0):
assert drop_first >= 0, "must drop positive number of data"
assert drop_last >= 0, "must drop positive number of data"
bps = drop_first * int(round(1 / snr_info["deltaT"]))
bpe = -drop_last * int(round(1 / snr_info["deltaT"])) if drop_last != 0 else None
data = numpy.abs(snr_info["data"])[:,row_number][bps:bpe]
if data.dtype == numpy.float32:
tseries = lal.CreateREAL4TimeSeries(
name = snr_info["instrument"],
epoch = snr_info["timestamps"][0] + drop_first,
deltaT = snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
elif data.dtype == numpy.float64:
tseries = lal.CreateREAL8TimeSeries(
name = snr_info["instrument"],
epoch = snr_info["timestamps"][0] + drop_first,
deltaT = snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
else:
raise ValueError("unsupported type : %s " % data.dtype)
return tseries
def make_xmldoc(snrdict, xmldoc = None, root_name = u"gstlal_inspiral_snr"):
if xmldoc is None:
xmldoc = ligolw.Document()
root = xmldoc.appendChild(ligolw.LIGO_LW())
root.Name = root_name
for instrument, snr in snrdict.items():
if snr.data.data.dtype == numpy.float32:
tseries = root.appendChild(lal.series.build_REAL4TimeSeries(snr))
elif snr.data.data.dtype == numpy.float64:
tseries = root.appendChild(lal.series.build_REAL8TimeSeries(snr))
else:
raise ValueError("unsupported type : %s" % snr.data.data.dtype)
if instrument is not None:
tseries.appendChild(ligolw_param.Param.from_pyvalue(u"instrument", instrument))
return xmldoc
def read_xmldoc(xmldoc, root_name = u"gstlal_inspiral_snr"):
if root_name is not None:
xmldoc, = (elem for elem in xmldoc.getElementsByTagName(ligolw.LIGO_LW.tagName) if elem.hasAttribute(u"Name") if elem.Name == root_name)
result = []
for elem in (elem for elem in xmldoc.getElementsByTagName(ligolw.LIGO_LW.tagName) if elem.hasAttribute(u"Name")):
if elem.Name == u"REAL4TimeSeries":
result.append((ligolw_param.get_pyvalue(elem, u"instrument"), lal.series.parse_REAL4TimeSeries(elem)))
elif elem.Name == u"REAL8TimeSeries":
result.append((ligolw_param.get_pyvalue(elem, u"instrument"), lal.series.parse_REAL8TimeSeries(elem)))
assert result is not None, "xmldoc contains no LAL Series or LAL Series is unsupported"
return dict(result)
# wrapper for writing snr series to URL
def write_url(xmldoc, filename, verbose = False):
ligolw_utils.write_filename(xmldoc, filename, gz = filename.endswith(".gz"), verbose = verbose)
# wrapper for reading snr series from URL
def read_url(filename, contenthandler, verbose = False):
return ligolw_utils.load_url(filename, verbose = verbose, contenthandler = contenthandler)