#!/usr/bin/env python

import sys
import numpy
import gi
gi.require_version('Gst','1.0')
from gi.repository import GObject
GObject.threads_init()
from gi.repository import Gst
Gst.init(None)

from gstlal import datasource
from gstlal import pipeio
from gstlal import pipeparts
from gstlal import simplehandler
from gstlal import snglbursttable 
from lal import LIGOTimeGPS
from optparse import OptionParser

from ligo.lw import ligolw
from ligo.lw import lsctables
from ligo.lw import utils as ligolw_utils
from ligo.lw.utils import process as ligolw_process

import lal
import lalsimulation


#
# ================================================================================ 
#
#                                  Command Line
#
# ================================================================================ 
#


def parse_command_line():
	"""
	Parse the command line for the cosmic string search pipeline.

	Returns the (options, filenames) pair produced by OptionParser.

	Raises ValueError if any required option is missing.
	--injection-file, --user-tag and --verbose are optional; all
	other options are required.
	"""
	parser = OptionParser(
		description = "GstLAL-based cosmic string search pipeline."
	)

	parser.add_option("--sample-rate", metavar = "rate", type = "float", help = "Desired sample rate (Hz).")
	parser.add_option("--frame-cache", metavar = "filename", help = "The frame cache file to load as input data.")
	parser.add_option("--output", metavar = "filename", help = "Name of output xml file.")
	parser.add_option("--injection-file", metavar = "filename", help = "Name of xml injection file.")
	parser.add_option("--channel", metavar = "channel", type = "string", help = "Name of channel.")
	parser.add_option("--template-bank", metavar = "filename", help = "Name of template file.")
	parser.add_option("--gps-start-time", metavar = "start_time", type = "int", help = "GPS start time.")
	parser.add_option("--gps-end-time", metavar = "end_time", type = "int", help = "GPS end time.")
	parser.add_option("--threshold", metavar = "snr_threshold", type = "float", help = "SNR threshold.")
	parser.add_option("--cluster-events", metavar = "cluster_events", type = "float", help = "Cluster events with input timescale.")
	parser.add_option("--user-tag", metavar = "user_tag", type = "string", help = "User tag set in the search summary and process tables")
	parser.add_option("--verbose", action = "store_true", help = "Be verbose.")

	options, filenames = parser.parse_args()

	required_options = ["sample_rate", "frame_cache", "output", "channel", "template_bank", "gps_start_time", "gps_end_time", "threshold", "cluster_events"]
	missing_options = [option for option in required_options if getattr(options, option) is None]
	if missing_options:
		# parenthesized raise form works on both Python 2 and 3 (the
		# original "raise ValueError, ..." is Python-2-only syntax)
		raise ValueError("missing required options %s" % ", ".join(sorted("--%s" % option.replace("_", "-") for option in missing_options)))

	return options, filenames

#
# ================================================================================ 
#
#				      Main
#
# ================================================================================ 
#


#
# parse the command line
#

options, filenames = parse_command_line()


#
# handler for obtaining psd
#

class PSDHandler(simplehandler.Handler):
	"""
	Pipeline message handler that watches for "spectrum" messages
	from the whitener.  Once the PSD estimate is judged stable, it
	builds whitened string-cusp templates from the template bank,
	loads them into the FIR bank, and loads the matching template
	autocorrelation functions into the trigger generator.  While the
	PSD is still settling (burn-in), all-zero templates are installed
	so that no triggers are produced.
	"""
	def __init__(self, mainloop, pipeline, firbank):
		simplehandler.Handler.__init__(self, mainloop, pipeline)
		self.firbank = firbank
		# NOTE(review): triggergen is taken from module scope here,
		# unlike firbank which is passed in -- confirm this is
		# intentional.
		self.triggergen = triggergen

	def do_on_message(self, bus, message):
		"""
		Handle bus messages; returns True when a "spectrum" message
		was consumed, False otherwise so the base class handles it.
		"""
		if message.type == Gst.MessageType.ELEMENT and message.get_structure().get_name() == "spectrum":
			psd = pipeio.parse_spectrum_message(message)
			timestamp = psd.epoch
			# fraction of the whitener's averaging history
			# accumulated so far
			stability = float(message.src.get_property("n-samples")) / message.src.get_property("average-samples")

			if stability > 0.3:
				# same output as the original
				# "print >> sys.stderr, ..." but portable syntax
				sys.stderr.write("At GPS time %s PSD stable\n" % timestamp)
				template_t = [None] * len(template_bank_table)
				autocorr = [None] * len(template_bank_table)
				# make templates, whiten, put into firbank
				for i, row in enumerate(template_bank_table):
					# linearly polarized, so just use plus mode time series
					template_t[i], _ = lalsimulation.GenerateStringCusp(1.0, row.central_freq, 1.0 / options.sample_rate)
					# zero-pad it to 32 seconds to obtain same deltaF as the PSD
					template_t[i] = lal.ResizeREAL8TimeSeries(template_t[i], -int(32 * options.sample_rate - template_t[i].data.length) // 2, int(32 * options.sample_rate))
					# setup of frequency domain
					length = template_t[i].data.length
					duration = float(length) / options.sample_rate
					epoch = -(length - 1) // 2 / options.sample_rate
					template_f = lal.CreateCOMPLEX16FrequencySeries("template_freq", LIGOTimeGPS(epoch), psd.f0, 1.0 / duration, lal.Unit("strain s"), length // 2 + 1)
					fplan = lal.CreateForwardREAL8FFTPlan(length, 0)
					# FFT to frequency domain
					lal.REAL8TimeFreqFFT(template_f, template_t[i], fplan)
					# set DC and Nyquist to zero
					template_f.data.data[0] = 0.0
					template_f.data.data[template_f.data.length - 1] = 0.0
					# whiten
					template_f = lal.WhitenCOMPLEX16FrequencySeries(template_f, psd)
					# obtain autocorr time series by squaring template and inverse FFT it
					template_f_squared = lal.CreateCOMPLEX16FrequencySeries("whitened template_freq squared", LIGOTimeGPS(epoch), psd.f0, 1.0 / duration, lal.Unit("s"), length // 2 + 1)
					autocorr_t = lal.CreateREAL8TimeSeries("autocorr_time", LIGOTimeGPS(epoch), psd.f0, 1.0 / options.sample_rate, lal.Unit("strain"), length)
					rplan = lal.CreateReverseREAL8FFTPlan(length, 0)
					template_f_squared.data.data = abs(template_f.data.data)**2
					lal.REAL8FreqTimeFFT(autocorr_t, template_f_squared, rplan)
					# normalize autocorrelation by central (maximum) value
					autocorr_t.data.data /= numpy.max(autocorr_t.data.data)
					autocorr_t = autocorr_t.data.data
					max_index = numpy.argmax(autocorr_t)
					# find the index of the third extremum for the templates, making them all have the same length.
					# FIXME we do this because we know that the template with the lowest high-f cutoff will have
					# the largest chi2_index.
					if i == 0:
						extr_ctr = 0
						chi2_index = 0
						# stop at len - 1 so the look-ahead
						# autocorr_t[j + 1] stays in bounds (the
						# original ran to len and could raise
						# IndexError if the 3rd extremum were
						# never found before the end)
						for j in range(max_index + 1, len(autocorr_t) - 1):
							slope1 = autocorr_t[j + 1] - autocorr_t[j]
							slope0 = autocorr_t[j] - autocorr_t[j - 1]
							chi2_index += 1
							if slope1 * slope0 < 0:
								extr_ctr += 1
								if extr_ctr == 2:
									break
					assert extr_ctr == 2, 'could not find 3rd extremum'
					# extract the part within the third extremum, setting the peak to be the center.
					autocorr[i] = numpy.concatenate((autocorr_t[1:(chi2_index + 1)][::-1], autocorr_t[:(chi2_index + 1)]))
					assert len(autocorr[i]) % 2 == 1, 'autocorr must have odd number of samples'
					# Inverse FFT template bank back to time domain
					template_t[i] = lal.CreateREAL8TimeSeries("whitened template time", LIGOTimeGPS(epoch), psd.f0, 1.0 / options.sample_rate, lal.Unit("strain"), length)
					lal.REAL8FreqTimeFFT(template_t[i], template_f, rplan)
					# normalize
					template_t[i].data.data /= numpy.sqrt(numpy.dot(template_t[i].data.data, template_t[i].data.data))
					template_t[i] = template_t[i].data.data
				firbank.set_property("latency", -(len(template_t[0]) - 1) // 2)
				firbank.set_property("fir_matrix", template_t)
				triggergen.set_property("autocorrelation_matrix", autocorr)
				self.firbank = firbank
				self.triggergen = triggergen
			else:
				# use templates with all zeros during burn-in period, that way we won't get any triggers.
				sys.stderr.write("At GPS time %s burn in period\n" % timestamp)
				template = [None] * len(template_bank_table)
				autocorr = [None] * len(template_bank_table)
				for i, row in enumerate(template_bank_table):
					template[i], _ = lalsimulation.GenerateStringCusp(1.0, 30, 1.0 / options.sample_rate)
					template[i] = lal.ResizeREAL8TimeSeries(template[i], -int(32 * options.sample_rate - template[i].data.length) // 2, int(32 * options.sample_rate))
					template[i] = template[i].data.data
					template[i] *= 0.0
					# Set autocorrelation to zero vectors as well.
					# The length is set to be similar to that obtained when the PSD is stable, but probably the length doesn't matter
					autocorr[i] = numpy.zeros(403)
				firbank.set_property("latency", -(len(template[0]) - 1) // 2)
				firbank.set_property("fir_matrix", template)
				triggergen.set_property("autocorrelation_matrix", autocorr)
				self.firbank = firbank
				self.triggergen = triggergen
			return True
		return False


#
# get data and insert injections if injection file is given
#

pipeline = Gst.Pipeline(name="pipeline")

head = pipeparts.mklalcachesrc(pipeline, options.frame_cache)
head = pipeparts.mkframecppchanneldemux(pipeline, head)
pipeparts.framecpp_channeldemux_set_units(head, {options.channel:"strain"})

# demux pads appear dynamically, so defer the link to the converter
elem = pipeparts.mkaudioconvert(pipeline, None)
pipeparts.src_deferred_link(head, options.channel, elem.get_static_pad("sink"))
head = elem


#
# injections
#

if options.injection_file is not None:
	head = pipeparts.mkinjections(pipeline, head, options.injection_file)


#
# whiten
#

head = pipeparts.mkwhiten(pipeline, head, fft_length = 32)


#
# resampler and caps filter
#

head = pipeparts.mkaudioconvert(pipeline, head)
head = pipeparts.mkresample(pipeline, head)
# FIXME check later if it's okay for filters that are exactly half of sample rate.
# FIXME NO hardcoding original sample rate!
head = pipeparts.mkaudioamplify(pipeline, head, 1. / numpy.sqrt(options.sample_rate / 16384.0))
head = pipeparts.mkcapsfilter(pipeline, head, "audio/x-raw, format=F32LE, rate=%d" % options.sample_rate)
head = pipeparts.mkqueue(pipeline, head)

#
# load xml file and find single burst table
#

@lsctables.use_in
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	# content handler with lsctables row types registered
	pass

xmldoc = ligolw_utils.load_filename(options.template_bank, contenthandler = LIGOLWContentHandler, verbose = True)
template_bank_table = lsctables.SnglBurstTable.get_table(xmldoc)


#
# filter bank
#

# start with an all-zero FIR matrix; PSDHandler installs the real
# whitened templates once the PSD is stable
head = firbank = pipeparts.mkfirbank(pipeline, head, fir_matrix = numpy.zeros((len(template_bank_table), int(32 * options.sample_rate)), dtype = numpy.float64), block_stride = 4 * options.sample_rate)


#
# format output xml file for putting triggers
#

xmldoc = ligolw.Document()
xmldoc.appendChild(ligolw.LIGO_LW())
process = ligolw_process.register_to_xmldoc(xmldoc, "StringSearch", options.__dict__)

sngl_burst_table = lsctables.New(lsctables.SnglBurstTable, ["process_id", "event_id","ifo","search","channel","start_time","start_time_ns","peak_time","peak_time_ns","duration","central_freq","bandwidth","amplitude","snr","confidence","chisq","chisq_dof"])
xmldoc.childNodes[-1].appendChild(sngl_burst_table)


#
# trigger generator
#

head = triggergen = pipeparts.mkgeneric(pipeline, head, "lal_string_triggergen", threshold = options.threshold, cluster = options.cluster_events, bank_filename = options.template_bank)


#
# appsync
#

def appsink_new_buffer(elem):
	buf = elem.emit("pull-sample").get_buffer()
	events = []
	for i in range(buf.n_memory()):
		memory = buf.peek_memory(i)
		result, mapinfo = memory.map(Gst.MapFlags.READ)
		assert result
		if mapinfo.data:
			events.extend(snglbursttable.GSTLALSnglBurst.from_buffer(mapinfo.data))
		memory.unmap(mapinfo)
277 278 279 280
	for event in events:
		event.process_id = process.process_id
		event.event_id = sngl_burst_table.get_next_id()
		sngl_burst_table.append(event)
281 282 283 284 285 286 287 288 289 290 291 292 293

appsync = pipeparts.AppSync(appsink_new_buffer = appsink_new_buffer)
appsync.add_sink(pipeline, head, caps = Gst.Caps.from_string("application/x-lal-snglburst"))


if pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
	raise RuntimeError("pipeline did not enter ready state")


#
# seek
#

options.gps_start_time = LIGOTimeGPS(options.gps_start_time)
options.gps_end_time = LIGOTimeGPS(options.gps_end_time)
datasource.pipeline_seek_for_gps(pipeline, options.gps_start_time, options.gps_end_time)


if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
	raise RuntimeError("pipeline did not enter playing state")


mainloop = GObject.MainLoop()
handler = PSDHandler(mainloop, pipeline, firbank)
mainloop.run()


#
# write output to disk
#

ligolw_utils.write_filename(xmldoc, options.output, gz = (options.output or "stdout").endswith(".gz"), verbose = options.verbose)