Skip to content
Snippets Groups Projects
Commit d5535984 authored by ChiWai Chan's avatar ChiWai Chan
Browse files

svd_bank_snr.py: simplify the program and change FIR_SNR to calculate the...

svd_bank_snr.py: simplify the program and change FIR_SNR to calculate the modulus of the complex snr

gstlal_inspiral_calc_snr: make corresponding changes
parent 088e2692
No related branches found
No related tags found
No related merge requests found
......@@ -52,9 +52,6 @@ Typical Usages:
--reference-psd (optional)
--track-psd (default = False. If --reference-psd is not given, this will be set to True)
--psd-fft-length (default = 32s)
--average-samples (default = 64)
--median-samples (default = 7)
--zero-pad (default = 0s)
4. Output options:
--output-width (default = 32bits)
......@@ -120,7 +117,7 @@ def parse_command_line():
group.add_option("--row-number", type = "int", help = "The row number of the template (optional). All the SNRs will be outputed if it is not given.")
group.add_option("--table", metavar = "filename", help = "A LIGO light-weight xml.gz file containing SnglInspiral Table. Expecting one template for each instrument only.")
group.add_option("--approximant", metavar = "name", type = "str", help = "Name of the Waveform model (require).")
group.add_option("--template-duration", metavar = "seconds", type = "int", help = "Duration of the template")
group.add_option("--template-duration", metavar = "seconds", type = "float", help = "Duration of the template")
group.add_option("--sample-rate", metavar = "Hz", default = 2048, type = "int", help = "Sampling rate of the template and SNR for mode 1")
group.add_option("--f-low", metavar = "Hz", default = 10, type = "float", help = "The minimum frequency of GW signal")
group.add_option("--f-high", metavar = "Hz", type = "float", help = "The maximum frequency of GW signal")
......@@ -250,7 +247,7 @@ def parse_command_line():
options, gw_data_source_info, template, psd = parse_command_line()
#====================================================================================================
#
# main
# main
#
#====================================================================================================
......@@ -259,6 +256,10 @@ if options.mode == 0:
num_of_row = bank.bank_fragments[0].mix_matrix.shape[1] / 2
lloid_snr = svd_bank_snr.LLOID_SNR(
gw_data_source_info,
bank,
options.instrument,
psd = psd,
psd_fft_length = options.psd_fft_length,
ht_gate_threshold = options.ht_gate_threshold,
veto_segments = options.veto_segments,
......@@ -266,31 +267,32 @@ if options.mode == 0:
width = options.output_width,
verbose = options.verbose
)
snr_info = lloid_snr(gw_data_source_info, bank, options.instrument, psd = psd)
if options.row_number is None:
# FIXME: put it in one xmldoc
for row in range(num_of_row):
snrdict = {options.instrument : svd_bank_snr.make_snr_series(snr_info, row_number = row, drop_first = options.drop_first, drop_last = options.drop_last)}
snrdict = {options.instrument : lloid_snr(row_number = row, drop_first = options.drop_first, drop_last = options.drop_last)}
svd_bank_snr.write_url(svd_bank_snr.make_xmldoc(snrdict), os.path.join(options.outdir, "snr_%d.xml.gz" % row), verbose = options.verbose)
else:
snrdict = {options.instrument : svd_bank_snr.make_snr_series(snr_info, row_number = options.row_number, drop_first = options.drop_first, drop_last = options.drop_last)}
snrdict = {options.instrument : lloid_snr(row_number = options.row_number, drop_first = options.drop_first, drop_last = options.drop_last)}
svd_bank_snr.write_url(svd_bank_snr.make_xmldoc(snrdict), os.path.join(options.outdir, "snr.xml.gz"), verbose = options.verbose)
elif options.mode == 1:
#FIXME: proper handle for latency
fir_snr = svd_bank_snr.FIR_SNR(
gw_data_source_info,
template,
options.instrument,
options.sample_rate,
0,
psd = psd,
psd_fft_length = options.psd_fft_length,
average_samples = options.average_samples,
median_samples = options.median_samples,
track_psd = options.track_psd,
rate = options.sample_rate,
width = options.output_width,
track_psd = options.track_psd,
verbose = options.verbose
)
#FIXME: allow multiple instruments
#FIXME: proper handle for latency
# drop the quadrature phase component
snr_info = fir_snr(gw_data_source_info, template.real, options.instrument, 0, psd = psd)
snrdict = {options.instrument:svd_bank_snr.make_snr_series(snr_info, drop_first = options.drop_first, drop_last = options.drop_last)}
#FIXME: allow multiple instruments
snrdict = {options.instrument : fir_snr(drop_first = options.drop_first, drop_last = options.drop_last)}
svd_bank_snr.write_url(svd_bank_snr.make_xmldoc(snrdict),os.path.join(options.outdir, "snr.xml.gz"), verbose = options.verbose)
"""
Short cutting gstlal inspiral pipeline to produce SNR for template(s)
Short cutting gstlal inspiral pipeline to produce SNR for gstlal_svd_bank.
A gstlal-based direct matched filter in time domain is also implemented.
"""
import sys
......@@ -35,69 +36,142 @@ from ligo.lw import utils as ligolw_utils
class SNRContentHandler(ligolw.LIGOLWContentHandler):
pass
class LLOID_SNR(object):
"""
The options for SNR calculation, please refer to multirate_datasource.mkwhitened_src()
and llloidparts.mkLLOIDhoftToSnrSlices() for more information. Defaults are:
"psd_fft_length": 32,
"ht_gate_threshold": None,
"veto_segments": None,
"track_psd": False,
"width": 32,
"verbose": False
"""
def __init__(self, psd_fft_length = 32, ht_gate_threshold = None, veto_segments = None, track_psd = False, width = 32, verbose = False):
self.psd_fft_length = psd_fft_length
self.ht_gate_threshold = ht_gate_threshold
self.veto_segments = veto_segments
self.track_psd = track_psd
self.width = width
class SNR_Pipeline(object):
def __init__(self, name = "gstlal_inspiral_SNR", verbose = False):
self.pipeline = Gst.Pipeline(name = name)
self.mainloop = GObject.MainLoop()
self.handler = simplehandler.Handler(self.mainloop, self.pipeline)
self.verbose = verbose
self.lock = threading.Lock()
self.snr_info = {
"timestamps": [],
"epoch": None,
"instrument": None,
"deltaT": None,
"data": [],
}
def __call__(self, gw_data_source_info, bank, instrument, psd = None):
pipeline = Gst.Pipeline(name = "gstlal_inspiral_LLOID_SNR")
mainloop = GObject.MainLoop()
handler = simplehandler.Handler(mainloop, pipeline)
def run(self, segments):
if self.verbose:
sys.stderr.write("Setting pipeline state to READY...\n")
if self.pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter ready state.")
datasource.pipeline_seek_for_gps(self.pipeline, *segments)
if self.verbose:
sys.stderr.write("Seting pipeline state to PLAYING...\n")
if self.pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter playing state.")
if self.verbose:
sys.stderr.write("Calculating SNR...\n")
self.mainloop.run()
if self.verbose:
sys.stderr.write("Calculation done.\n")
if self.pipeline.set_state(Gst.State.NULL) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline could not be set to NULL.")
def get_snr_series(self, row_number = 0, drop_first = 0, drop_last = 0):
assert drop_first >= 0, "must drop positive number of data"
assert drop_last >= 0, "must drop positive number of data"
bps = drop_first * int(round(1 / self.snr_info["deltaT"]))
bpe = -drop_last * int(round(1 / self.snr_info["deltaT"])) if drop_last != 0 else None
data = numpy.abs(self.snr_info["data"])[:,row_number][bps:bpe]
if data.dtype == numpy.float32:
tseries = lal.CreateREAL4TimeSeries(
name = self.snr_info["instrument"],
epoch = self.snr_info["epoch"] + drop_first,
deltaT = self.snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
elif data.dtype == numpy.float64:
tseries = lal.CreateREAL8TimeSeries(
name = self.snr_info["instrument"],
epoch = self.snr_info["epoch"] + drop_first,
deltaT = self.snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
else:
raise ValueError("unsupported type : %s " % data.dtype)
return tseries
def new_preroll_handler(self, elem):
with self.lock:
# ignore preroll buffers
elem.emit("pull-preroll")
return Gst.FlowReturn.OK
def pull_snr_buffer(self, elem):
with self.lock:
sample = elem.emit("pull-sample")
if sample is None:
return Gst.FlowReturn.OK
success, rate = sample.get_caps().get_structure(0).get_int("rate")
assert success == True
# make sure the sampling rate is the same for all data
if self.snr_info["deltaT"] is None:
self.snr_info["deltaT"] = 1. / rate
else:
assert self.snr_info["deltaT"] == 1. / rate, "data have different sampling rate."
# record the first timestamp
if self.snr_info["epoch"] is None:
self.snr_info["epoch"] = LIGOTimeGPS(0, sample.get_buffer().pts)
buf = sample.get_buffer()
if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
return Gst.FlowReturn.OK
# FIXME: check timestamps
data = pipeio.array_from_audio_sample(sample)
if data is not None:
self.snr_info["data"].append(data)
return Gst.FlowReturn.OK
class LLOID_SNR(SNR_Pipeline):
def __init__(self, gw_data_source_info, bank, instrument, psd = None, psd_fft_length = 32, ht_gate_threshold = None, veto_segments = None, track_psd = False, width = 32, verbose = False):
SNR_Pipeline.__init__(self, name = "gstlal_inspiral_lloid_snr", verbose = verbose)
self.snr_info["instrument"] = instrument
# sanity check
if psd is not None:
assert instrument in set(psd)
assert instrument in set(gw_data_source_info.channel_dict)
if not (instrument in set(psd)):
raise ValueError("No psd for instrument %s." % instrument)
if self.verbose:
sys.stderr.write("Building pipeline to calculate SNR...\n")
src, statevector, dqvector = datasource.mkbasicsrc(pipeline, gw_data_source_info, instrument, self.verbose)
src, statevector, dqvector = datasource.mkbasicsrc(self.pipeline, gw_data_source_info, instrument, self.verbose)
hoftdict = multirate_datasource.mkwhitened_multirate_src(
pipeline,
self.pipeline,
src = src,
rates = set(rate for rate in bank.get_rates()),
instrument = instrument,
psd = psd[instrument],
psd_fft_length = self.psd_fft_length,
ht_gate_threshold = self.ht_gate_threshold,
veto_segments = self.veto_segments,
track_psd = self.track_psd,
width = self.width,
psd_fft_length = psd_fft_length,
ht_gate_threshold = ht_gate_threshold,
veto_segments = veto_segments,
track_psd = track_psd,
width = width,
statevector = statevector,
dqvector = dqvector,
fir_whiten_reference_psd = bank.processed_psd
)
snr = lloidparts.mkLLOIDhoftToSnrSlices(
pipeline,
self.pipeline,
hoftdict = hoftdict,
bank = bank,
control_snksrc = (None, None),
......@@ -107,248 +181,72 @@ class LLOID_SNR(object):
logname = instrument
)
appsink = pipeparts.mkappsink(pipeline, snr, drop = False)
appsink = pipeparts.mkappsink(self.pipeline, snr, drop = False)
handler_id = appsink.connect("new-preroll", self.new_preroll_handler)
assert handler_id > 0
handler_id = appsink.connect("new-sample", self.pull_buffer)
handler_id = appsink.connect("new-sample", self.pull_snr_buffer)
assert handler_id > 0
handler_id = appsink.connect("eos", self.pull_buffer)
handler_id = appsink.connect("eos", self.pull_snr_buffer)
assert handler_id > 0
if self.verbose:
sys.stderr.write("Setting pipeline state to READY...\n")
if pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter ready state.")
datasource.pipeline_seek_for_gps(pipeline, *gw_data_source_info.seg)
if self.verbose:
sys.stderr.write("Seting pipeline state to PLAYING...\n")
if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter playing state.")
if self.verbose:
sys.stderr.write("Calculating SNR...\n")
mainloop.run()
if self.verbose:
sys.stderr.write("Calculation done.\n")
if pipeline.set_state(Gst.State.NULL) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline could not be set to NULL.")
del pipeline, mainloop, handler
# return snr_info containing information to construct snr for all template in the template bank
# see make_SNR_series() to make SNR LAL Series from snr_info
self.snr_info["data"] = numpy.concatenate(numpy.array(self.snr_info["data"]), axis = 0)
return self.snr_info
#===============================================================================================
#
# internal functions
#
#===============================================================================================
def new_preroll_handler(self, elem):
with self.lock:
# ignore preroll buffers
elem.emit("pull-preroll")
return Gst.FlowReturn.OK
def pull_buffer(self, elem):
with self.lock:
sample = elem.emit("pull-sample")
if sample is None:
return Gst.FlowReturn.OK
else:
success, rate = sample.get_caps().get_structure(0).get_int("rate")
assert success == True
# make sure the sampling rate is the same for all data
if self.snr_info["deltaT"] is None:
self.snr_info["deltaT"] = 1. / rate
else:
assert self.snr_info["deltaT"] == 1. / rate, "data have different sampling rate."
buf = sample.get_buffer()
if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
return Gst.FlowReturn.OK
# FIXME: check timestamps
data = pipeio.array_from_audio_sample(sample)
if data is not None:
self.snr_info["data"].append(data)
self.snr_info["timestamps"].append(LIGOTimeGPS(0, sample.get_buffer().pts))
return Gst.FlowReturn.OK
class FIR_SNR(object):
"""
Required arguments:
-gw_data_source_info:
-template:
-psd:
-instrument:
Optional arguments:
-psd_fft_length:
-zero_pad: Hanning Window's zero padding (seconds)
-average_samples:
-median_samples:
-rate:
-verbose:
"""
def __init__(self, rate, psd_fft_length = 32, zero_pad = 0, average_samples = 64, median_samples = 7, width = 32, track_psd = False, verbose = False):
self.average_samples = average_samples
self.lock = threading.Lock()
self.median_samples = median_samples
self.psd_fft_length = psd_fft_length
self.rate = rate
self.track_psd = track_psd
self.verbose = verbose
self.width = width
self.zero_pad = zero_pad
self.run(gw_data_source_info.seg)
self.snr_info["data"] = numpy.concatenate(numpy.array(self.snr_info["data"]), axis = 0)
self.snr_info = {
"timestamps": [],
"instrument": None,
"deltaT": 1./rate,
"data": [],
}
def __call__(self, row_number = 0, drop_first = 0, drop_last = 0):
return self.get_snr_series(row_number, drop_first, drop_last)
def __call__(self, gw_data_source_info, template, instrument, latency, psd = None):
class FIR_SNR(SNR_Pipeline):
def __init__(self, gw_data_source_info, template, instrument, rate, latency, psd = None, psd_fft_length = 32, width = 32, track_psd = False, verbose = False):
SNR_Pipeline.__init__(self, name = "gstlal_inspiral_fir_snr", verbose = verbose)
self.snr_info["instrument"] = instrument
pipeline = Gst.Pipeline("gstlal_inspiral_simple_SNR")
mainloop = GObject.MainLoop()
handler = simplehandler.Handler(mainloop, pipeline)
# sanity check
if psd is not None:
if not (instrument in set(psd)):
raise ValueError("No psd for instrument %s." % instrument)
if self.verbose:
sys.stderr.write("Building pipeline to calculate SNR\n")
src, statevector, dqvector = datasource.mkbasicsrc(pipeline, gw_data_source_info, instrument, verbose = self.verbose)
src, statevector, dqvector = datasource.mkbasicsrc(self.pipeline, gw_data_source_info, instrument, verbose = self.verbose)
hoftdict = multirate_datasource.mkwhitened_multirate_src(
pipeline,
self.pipeline,
src = src,
rates = [self.rate],
rates = [rate],
instrument = instrument,
psd = psd[instrument],
psd_fft_length = self.psd_fft_length,
track_psd = self.track_psd,
width = self.width,
psd_fft_length = psd_fft_length,
track_psd = track_psd,
width = width,
statevector = statevector,
dqvector = dqvector
)
#FIXME: how to set latency
head = pipeparts.mkfirbank(pipeline, hoftdict[self.rate], latency = latency, fir_matrix = [template], block_stride = 16 * self.rate, time_domain = False)
head = pipeparts.mkfirbank(self.pipeline, hoftdict[rate], latency = latency, fir_matrix = [template.real, template.imag], block_stride = 16 * rate, time_domain = False)
appsink = pipeparts.mkappsink(pipeline, head, drop = False)
appsink = pipeparts.mkappsink(self.pipeline, head, drop = False)
handler_id = appsink.connect("new-preroll", self.new_preroll_handler)
assert handler_id > 0
handler_id = appsink.connect("new-sample", self.pull_buffer)
handler_id = appsink.connect("new-sample", self.pull_snr_buffer)
assert handler_id > 0
handler_id = appsink.connect("eos", self.pull_buffer)
handler_id = appsink.connect("eos", self.pull_snr_buffer)
assert handler_id > 0
if self.verbose:
sys.stderr.write("Setting pipeline state to READY...\n")
if pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter ready state.")
datasource.pipeline_seek_for_gps(pipeline, *gw_data_source_info.seg)
self.run(gw_data_source_info.seg)
self.snr_info["data"] = numpy.concatenate(numpy.array(self.snr_info["data"]), axis = 0)
self.snr_info["data"] = numpy.vectorize(complex)(self.snr_info["data"][:,0], self.snr_info["data"][:,1])
self.snr_info["data"].shape = len(self.snr_info["data"]), 1
if self.verbose:
sys.stderr.write("Setting pipeline state to PLAYING...\n")
if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline cannot enter playing state.")
if self.verbose:
sys.stderr.write("Calculating SNR...\n")
mainloop.run()
if self.verbose:
sys.stderr.write("Calculation done.\n")
if pipeline.set_state(Gst.State.NULL) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline could not be set to NULL.")
del pipeline, mainloop, handler
self.snr_info["data"] = numpy.concatenate(numpy.array(self.snr_info["data"]), axis = 0)
return self.snr_info
#===============================================================================================
#
# internal functions
#
#===============================================================================================
def new_preroll_handler(self, elem):
with self.lock:
# ignore preroll buffers
elem.emit("pull-preroll")
return Gst.FlowReturn.OK
def pull_buffer(self, elem):
with self.lock:
sample = elem.emit("pull-sample")
if sample is None:
return Gst.FlowReturn.OK
else:
success, rate = sample.get_caps().get_structure(0).get_int("rate")
assert success == True
# make sure the sampling rate is the same for all data
if self.snr_info["deltaT"] is not None:
assert self.snr_info["deltaT"] == 1. / rate, "data have different sampling rate."
self.snr_info["deltaT"] = 1. / rate
buf = sample.get_buffer()
if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
return Gst.FlowReturn.OK
data = pipeio.array_from_audio_sample(sample)
if data is not None:
self.snr_info["data"].append(data)
self.snr_info["timestamps"].append(LIGOTimeGPS(0, sample.get_buffer().pts))
return Gst.FlowReturn.OK
def __call__(self, row_number = 0 , drop_first = 0, drop_last = 0):
return self.get_snr_series(row_number, drop_first, drop_last)
#=============================================================================================
#
# Output Utilities
# Output Utilities
#
#=============================================================================================
def make_snr_series(snr_info, row_number = 0, drop_first = 0, drop_last = 0):
assert drop_first >= 0, "must drop positive number of data"
assert drop_last >= 0, "must drop positive number of data"
bps = drop_first * int(round(1 / snr_info["deltaT"]))
bpe = -drop_last * int(round(1 / snr_info["deltaT"])) if drop_last != 0 else None
data = numpy.abs(snr_info["data"])[:,row_number][bps:bpe]
if data.dtype == numpy.float32:
tseries = lal.CreateREAL4TimeSeries(
name = snr_info["instrument"],
epoch = snr_info["timestamps"][0] + drop_first,
deltaT = snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
elif data.dtype == numpy.float64:
tseries = lal.CreateREAL8TimeSeries(
name = snr_info["instrument"],
epoch = snr_info["timestamps"][0] + drop_first,
deltaT = snr_info["deltaT"],
f0 = 0,
sampleUnits = lal.DimensionlessUnit,
length = len(data)
)
tseries.data.data = data
else:
raise ValueError("unsupported type : %s " % data.dtype)
return tseries
def make_xmldoc(snrdict, xmldoc = None, root_name = u"gstlal_inspiral_snr"):
if xmldoc is None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment