Skip to content
Snippets Groups Projects
Commit 7b53b132 authored by Daichi Tsuna's avatar Daichi Tsuna
Browse files

cs_triggergen: whiten templates, get autocorrelation

parent 7a3aa77c
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python
import sys
import math
import numpy
import gi
gi.require_version('Gst','1.0')
......@@ -10,9 +9,10 @@ GObject.threads_init()
from gi.repository import Gst
Gst.init(None)
from gstlal import simplehandler
from gstlal import pipeparts
from gstlal import datasource
from gstlal import pipeio
from gstlal import pipeparts
from gstlal import simplehandler
from gstlal import snglbursttable
from lal import LIGOTimeGPS
from optparse import OptionParser
......@@ -22,8 +22,10 @@ from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue.ligolw.utils import process as ligolw_process
import lal
import lalsimulation
#
# ================================================================================
#
......@@ -53,6 +55,11 @@ def parse_command_line():
options, filenames = parser.parse_args()
required_options = ["sample_rate", "frame_cache", "output", "channel", "template_bank", "gps_start_time", "gps_end_time", "threshold", "cluster_events"]
missing_options = [option for option in required_options if getattr(options, option) is None]
if missing_options:
raise ValueError, "missing required options %s" % ", ".join(sorted("--%s" % option.replace("_", "-") for option in missing_options))
return options, filenames
#
......@@ -70,13 +77,78 @@ def parse_command_line():
options, filenames = parse_command_line()
#
# handler for obtaining psd
#
class PSDHandler(simplehandler.Handler):
def __init__(self, mainloop, pipeline, firbank):
simplehandler.Handler.__init__(self,mainloop, pipeline)
self.firbank = firbank
def do_on_message(self, bus, message):
if message.type == Gst.MessageType.ELEMENT and message.get_structure().get_name() == "spectrum":
instrument = message.src.get_name().split("_")[-1]
psd = pipeio.parse_spectrum_message(message)
timestamp = psd.epoch
stability = float(message.src.get_property("n-samples")) / message.src.get_property("average-samples")
if stability > 0.3:
template_bank = [None] * len(template_bank_table)
template_t = [None] * len(template_bank_table)
autocorr = [None] * len(template_bank_table)
# make templates, whiten, put into firbank
for i, row in enumerate(template_bank_table):
# linearly polarized, so just use plus mode time series
template_bank[i], _ = lalsimulation.GenerateStringCusp(1.0,row.central_freq,1.0/options.sample_rate)
# zero-pad it to 32 seconds to obtain same deltaF as the PSD
template_bank[i] = lal.ResizeREAL8TimeSeries(template_bank[i],-int(32*options.sample_rate - template_bank[i].data.length)//2,int(32*options.sample_rate))
# setup of frequency domain
length = template_bank[i].data.length
duration = float(length) / options.sample_rate
epoch = -(length-1) // 2 /options.sample_rate
template_f = lal.CreateCOMPLEX16FrequencySeries("template_freq", LIGOTimeGPS(epoch), psd.f0, 1.0/duration, lal.Unit("strain s"), length // 2 + 1)
fplan = lal.CreateForwardREAL8FFTPlan(length,0)
# FFT to frequency domain
lal.REAL8TimeFreqFFT(template_f,template_bank[i],fplan)
# set DC and Nyquist to zero
template_f.data.data[0] = 0.0
template_f.data.data[template_f.data.length-1] = 0.0
# whiten
template_f = lal.WhitenCOMPLEX16FrequencySeries(template_f,psd)
# obtain autocorrelation time series by
# squaring the template and inverse FFTing it
template_f_squared = lal.CreateCOMPLEX16FrequencySeries("whitened template_freq squared", LIGOTimeGPS(epoch), psd.f0, 1.0/duration, lal.Unit("s"), length // 2 + 1)
autocorr[i] = lal.CreateREAL8TimeSeries("autocorr_time", LIGOTimeGPS(epoch), psd.f0, 1.0 / options.sample_rate, lal.Unit("strain"), length)
rplan = lal.CreateReverseREAL8FFTPlan(length,0)
template_f_squared.data.data = abs(template_f.data.data)**2
lal.REAL8FreqTimeFFT(autocorr[i],template_f_squared,rplan)
# normalize autocorrelation by central (maximum) value
autocorr[i].data.data /= numpy.max(autocorr[i].data.data)
# Inverse FFT template bank back to time domain
template_t[i] = lal.CreateREAL8TimeSeries("whitened template time", LIGOTimeGPS(epoch), psd.f0, 1.0 / options.sample_rate, lal.Unit("strain"), length)
lal.REAL8FreqTimeFFT(template_t[i],template_f,rplan)
# normalize
template_t[i].data.data /= numpy.sqrt(numpy.dot(template_t[i].data.data, template_t[i].data.data))
template_t[i] = template_t[i].data.data
firbank.set_property("latency",-(len(template_t[0]) - 1) // 2)
firbank.set_property("fir_matrix", template_t)
self.firbank = firbank
else:
template, _ = lalsimulation.GenerateStringCusp(1.0,30,1.0/options.sample_rate)
firbank.set_property("fir_matrix",numpy.zeros((template.data.length,len(template_bank_table)), dtype=float))
firbank.set_property("latency",-(template.data.length - 1) // 2)
self.firbank = firbank
return True
return False
#
# get data and insert injections if injection file is given
#
pipeline = Gst.Pipeline(name="pipeline")
mainloop = GObject.MainLoop()
handler = simplehandler.Handler(mainloop, pipeline)
head = pipeparts.mklalcachesrc(pipeline, options.frame_cache)
head = pipeparts.mkframecppchanneldemux(pipeline, head)
......@@ -98,7 +170,7 @@ if options.injection_file is not None:
# whiten
#
head = pipeparts.mkwhiten(pipeline, head)
head = pipeparts.mkwhiten(pipeline, head, fft_length = 32)
#
......@@ -107,6 +179,9 @@ head = pipeparts.mkwhiten(pipeline, head)
head = pipeparts.mkaudioconvert(pipeline,head)
head = pipeparts.mkresample(pipeline,head)
# FIXME check later if it's okay for filters that are exactly half of sample rate.
# FIXME NO hardcoding original sample rate!
head = pipeparts.mkaudioamplify(pipeline,head,1./numpy.sqrt(options.sample_rate/16384.0))
head = pipeparts.mkcapsfilter(pipeline,head,"audio/x-raw, format=F32LE, rate=%d" % options.sample_rate)
head = pipeparts.mkqueue(pipeline,head)
......@@ -121,19 +196,14 @@ class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
xmldoc = ligolw_utils.load_filename(options.template_bank, contenthandler = LIGOLWContentHandler, verbose = True)
sngl_burst_table = lsctables.SnglBurstTable.get_table(xmldoc)
template_bank_table = lsctables.SnglBurstTable.get_table(xmldoc)
#
# filter bank
#
template_bank = [None] * len(sngl_burst_table)
for i, row in enumerate(sngl_burst_table):
template_bank[i], _ = lalsimulation.GenerateStringCusp(1.0,row.central_freq,1.0/options.sample_rate)
template_bank[i] = template_bank[i].data.data
template_bank[i] /= math.sqrt(numpy.dot(template_bank[i], template_bank[i]))
head = pipeparts.mkfirbank(pipeline, head, latency = -(len(template_bank[0]) - 1) // 2, fir_matrix = template_bank, block_stride = 4 * options.sample_rate)
head = firbank = pipeparts.mkfirbank(pipeline, head, time_domain = False, block_stride = 4 * options.sample_rate)
#
......@@ -152,7 +222,6 @@ xmldoc.childNodes[-1].appendChild(sngl_burst_table)
# trigger generator
#
head = pipeparts.mkgeneric(pipeline, head, "lal_string_triggergen", threshold = options.threshold, cluster = options.cluster_events, bank_filename = options.template_bank)
......@@ -196,8 +265,11 @@ if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
raise RuntimeError("pipeline did not enter playing state")
mainloop = GObject.MainLoop()
handler = PSDHandler(mainloop, pipeline, firbank)
mainloop.run()
#
# write output to disk
#
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment