Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
gstlal_cs_triggergen 23.88 KiB
#!/usr/bin/env python

import math
import numpy
import sys
import threading
import gi
gi.require_version('Gst','1.0')
from gi.repository import GObject
GObject.threads_init()
from gi.repository import Gst
Gst.init(None)

from gstlal import datasource
from gstlal import pipeio
from gstlal import pipeparts
from gstlal import simplehandler
from gstlal import snglbursttable 
from gstlal import streamburca
from optparse import OptionParser

from ligo import segments
from ligo.lw.utils import segments as ligolw_segments
from ligo.lw.utils import ligolw_add
from ligo.lw.utils import process as ligolw_process
from ligo.lw import ligolw
from ligo.lw import lsctables
from ligo.lw import utils as ligolw_utils

import lal
from lal import series
from lal import LIGOTimeGPS
import lalsimulation


#
# ================================================================================ 
#
#                                  Command Line
#
# ================================================================================ 
#


def parse_command_line():
	"""
	Parse and validate the command line.

	Returns the tuple (options, filenames), where options is the
	optparse values object and filenames is the list of left-over
	positional arguments.

	Raises ValueError if a required option is missing, if the number of
	--template-bank options differs from the number of --channel
	options, or if a segments/vetoes file is given without the
	corresponding segment list name.
	"""
	parser = OptionParser(
		description = "GstLAL-based cosmic string search pipeline."
	)

	parser.add_option("--sample-rate", metavar = "rate", type = "float", help = "Desired sample rate (Hz).")
	parser.add_option("--frame-cache", metavar = "filename", help = "Frame cache file to load as input data.")
	parser.add_option("--reference-psd", metavar = "filename", help = "Reference psd files as input to obtain the template and SNR. Can be given for multiple detectors, but must be in one file. If None, the PSD will be measured on the fly, but there will be some burn-in time where the data will not be analyzed until the PSD converges.")
	parser.add_option("--output", metavar = "filename", help = "Name of output xml file.")
	parser.add_option("--segments-file", metavar = "filename", help = "Set the name of the LIGO Light-Weight XML file with segment lists that are science mode, for the trigger generator to enable gating.  See also --segments-name.")
	parser.add_option("--segments-name", metavar = "name", help = "Set the name of the segment lists to retrieve from the segments file.  See also --segments-file.")
	parser.add_option("--vetoes-file", metavar = "filename", help = "Set the name of the LIGO Light-Weight XML file with segment lists that are vetoed, for the trigger generator to enable gating.  See also --vetoes-name.")
	parser.add_option("--vetoes-name", metavar = "name", help = "Set the name of the veto segment lists to retrieve from the veto segments file.  See also --vetoes-file.")
	parser.add_option("--injection-file", metavar = "filename", help = "Name of xml file with injections.")
	parser.add_option("--time-slide-file", metavar = "filename", help = "Name of xml file with time slides for each detector.")
	# NOTE:  this line used to be indented with spaces instead of a
	# tab, which Python 3 rejects with TabError
	parser.add_option("--channel", metavar = "channel", action = "append", type = "string", help = "Name of channel. Can be given multiple inputs, but must be one for each detector.")
	parser.add_option("--template-bank", metavar = "filename", action = "append", help = "Name of template file. Template bank for all the detectors involved should be given.")
	parser.add_option("--gps-start-time", metavar = "start_time", type = "int",  help = "GPS start time.")
	parser.add_option("--gps-end-time", metavar = "end_time", type = "int", help = "GPS end time.")
	parser.add_option("--threshold", metavar = "snr_threshold", type = "float", help = "SNR threshold.")
	parser.add_option("--cluster-events", metavar = "cluster_events", type = "float", help = "Cluster events with input timescale (in seconds).")
	parser.add_option("--user-tag", metavar = "user_tag", type = "string", help = "User tag set in the search summary and process tables")
	parser.add_option("--deltat", metavar = "deltat", type = "float", default = 0.008, help = "Maximum time difference in seconds for coincidence, excluding the light-travel time between the detectors. Default: 0.008")
	parser.add_option("--verbose", action = "store_true", help = "Be verbose.")

	options, filenames = parser.parse_args()

	# options without a default are None when not given on the command line
	required_options = ["sample_rate", "frame_cache", "output", "time_slide_file", "channel", "template_bank", "gps_start_time", "gps_end_time", "threshold", "cluster_events"]
	missing_options = [option for option in required_options if getattr(options, option) is None]
	if missing_options:
		raise ValueError("missing required options %s" % ", ".join(sorted("--%s" % option.replace("_", "-") for option in missing_options)))
	if len(options.template_bank) != len(options.channel):
		raise ValueError("number of --template-bank options must equal number of --channel options")
	if options.segments_file is not None and options.segments_name is None:
		raise ValueError("segments name should be specified for the input segments file")
	if options.vetoes_file is not None and options.vetoes_name is None:
		raise ValueError("vetoes name should be specified for the input vetoes file")

	return options, filenames


#
# parse command line
#

# done at module level:  the handler class and the script code below read
# the resulting "options" object as a global
options, filenames = parse_command_line()


#
# handler for updating templates using psd and putting triggers for coincidence
#

class PipelineHandler(simplehandler.Handler):
	"""
	Bus-message and appsink handler for the string search pipeline.

	Rebuilds the whitened templates and autocorrelation matrices when a
	"spectrum" element message arrives, collects the triggers pulled
	from the appsinks, and pushes them into the StreamBurca coincidence
	engine.

	NOTE:  the methods read the module-level names "options", "process"
	and "xmldoc", which must be defined before an instance is created.
	"""
	def __init__(self, mainloop, pipeline, xmldoc, template_banks, sngl_burst, analyzed_seglistdict, reference_psds, firbanks, triggergens):
		"""
		mainloop, pipeline:  passed on to simplehandler.Handler.
		xmldoc:  output document handed to StreamBurca.
		template_banks:  dictionary of template bank tables keyed by
			instrument.
		sngl_burst:  table object used to allocate event IDs.
		analyzed_seglistdict:  segmentlistdict that accumulates the
			segments covered by the buffers received.
		reference_psds:  dictionary of PSDs keyed by instrument, or
			None to use the PSDs carried by the spectrum messages.
		firbanks, triggergens:  dictionaries of the FIR bank and
			trigger generator elements keyed by instrument.
		"""
		simplehandler.Handler.__init__(self, mainloop, pipeline)
		self.lock = threading.Lock()
		self.template_bank = template_banks
		self.sngl_burst = sngl_burst
		self.analyzed_seglistdict = analyzed_seglistdict
		self.firbank = firbanks
		self.triggergen = triggergens
		# template normalization.  it is computed from each
		# instrument's own PSD, so it is kept per instrument;  within
		# an instrument central_freq uniquely identifies the template
		self.sigma = dict((instrument, dict((row.central_freq, 0.0) for row in bank)) for instrument, bank in template_banks.items())
		# for PSD
		self.update_psd = dict.fromkeys(triggergens, 0)
		self.reference_psd = reference_psds
		# create a StreamBurca instance, initialized with the XML document and the coincidence parameters
		self.streamburca = streamburca.StreamBurca(xmldoc, process.process_id, options.deltat, min_instruments = 2, verbose = options.verbose)


	def appsink_new_buffer(self, elem):
		"""
		Pull one buffer of triggers from the appsink elem, record the
		time span it covers, fill in the remaining trigger columns and
		push the triggers into the coincidence engine.
		"""
		with self.lock:
			buf = elem.emit("pull-sample").get_buffer()
			events = []
			for i in range(buf.n_memory()):
				memory = buf.peek_memory(i)
				result, mapinfo = memory.map(Gst.MapFlags.READ)
				assert result
				# release the mapping even if decoding fails
				try:
					if mapinfo.data:
						events.extend(snglbursttable.GSTLALSnglBurst.from_buffer(mapinfo.data))
				finally:
					memory.unmap(mapinfo)
			# get ifo from the appsink name property
			instrument = elem.get_property("name")
			# extract segment.  move the segment's upper
			# boundary to include all triggers.
			buf_timestamp = LIGOTimeGPS(0, buf.pts)
			buf_seg = {instrument: segments.segmentlist([segments.segment(buf_timestamp, buf_timestamp + LIGOTimeGPS(0, buf.duration))])}
			if events:
				buf_seg[instrument] |= segments.segmentlist([segments.segment(buf_timestamp, max(event.peak for event in events if event.ifo == instrument))])
			# obtain union of this segment and the previously added segments
			self.analyzed_seglistdict |= buf_seg
			# put info of each event in the sngl burst table
			if options.verbose:
				sys.stderr.write("at %s got %d in %s\n" % (buf_timestamp, len(events), set([event.ifo for event in events])))
			# normalization of the instrument this appsink belongs to
			sigma = self.sigma[instrument]
			for event in events:
				event.process_id = process.process_id
				event.event_id = self.sngl_burst.get_next_id()
				event.amplitude = event.snr / sigma[event.central_freq]
			# push the single detector triggers into the StreamBurca instance
			# the push method returns True if the coincidence engine has new results. in that case, call the pull() method to run the coincidence engine.
			if events:
				if self.streamburca.push(instrument, events, buf_timestamp):
					self.streamburca.pull()

	def flush(self):
		"""
		Write the accumulated analyzed segments to the output
		document and run the coincidence engine on the leftover
		triggers.
		"""
		with self.lock:
			# dump segmentlistdict to segment table
			with ligolw_segments.LigolwSegments(xmldoc, process) as llwsegment:
				llwsegment.insert_from_segmentlistdict(self.analyzed_seglistdict, name = u"StringSearch", comment="triggergen")
			# leftover triggers
			self.streamburca.pull(flush = True)

	def update_templates(self, instrument, psd):
		"""
		Build the whitened time-domain templates and the truncated
		autocorrelation series for instrument from psd, store the
		template normalizations in self.sigma[instrument], and load
		the results into the instrument's FIR bank and trigger
		generator elements.
		"""
		template_t = [None] * len(self.template_bank[instrument])
		autocorr = [None] * len(self.template_bank[instrument])
		# make templates, whiten, put into firbank
		# NOTE Currently works only for cusps. this for-loop needs to be revisited when searching for other sources (kinks, ...)
		for i, row in enumerate(self.template_bank[instrument]):
			# Obtain cusp waveform. A cusp signal is linearly polarized, so just use plus mode time series
			template_t[i], _ = lalsimulation.GenerateStringCusp(1.0,row.central_freq,1.0/options.sample_rate)
			# zero-pad it to 32 seconds to obtain same deltaF as the PSD
			# we have to make the number of samples in the template odd, but if we do that here deltaF of freq domain template will be different from psd's deltaF, and whitening cannot be done. So we keep it exactly 32 seconds, and after getting a whitened template we add a sample of 0 in the tail.
			template_t[i] = lal.ResizeREAL8TimeSeries(template_t[i], -int(32*options.sample_rate - template_t[i].data.length) // 2, int(32*options.sample_rate))
			# setup of frequency domain
			length = template_t[i].data.length
			duration = float(length) / options.sample_rate
			epoch = - float(length // 2) / options.sample_rate
			template_f = lal.CreateCOMPLEX16FrequencySeries("template_freq", LIGOTimeGPS(epoch), psd.f0, 1.0/duration, lal.Unit("strain s"), length // 2 + 1)
			fplan = lal.CreateForwardREAL8FFTPlan(length,0)
			# FFT to frequency domain
			lal.REAL8TimeFreqFFT(template_f,template_t[i],fplan)
			# set DC and Nyquist to zero
			template_f.data.data[0] = 0.0
			template_f.data.data[template_f.data.length-1] = 0.0
			# whiten
			assert template_f.deltaF == psd.deltaF, "freq interval not same between template and PSD"
			template_f = lal.WhitenCOMPLEX16FrequencySeries(template_f,psd)
			# Obtain the normalization for getting the amplitude of signal from SNR
			# Integrate over frequency range covered by template. Note that template_f is already whitened.
			# NOTE(review): the complex samples are squared (not
			# |h|^2) and only the real part is kept -- confirm that
			# this is intended
			sigmasq = numpy.trapz(4.0 * template_f.data.data**2, dx = psd.deltaF)
			self.sigma[instrument][row.central_freq] = numpy.sqrt(sigmasq.real)
			# obtain autocorr time series by squaring template and inverse FFT it
			template_f_squared = lal.CreateCOMPLEX16FrequencySeries("whitened template_freq squared", LIGOTimeGPS(epoch), psd.f0, 1.0/duration, lal.Unit("strain s"), length // 2 + 1)
			autocorr_t = lal.CreateREAL8TimeSeries("autocorr_time", LIGOTimeGPS(epoch), psd.f0, 1.0 / options.sample_rate, lal.Unit("strain"), length)
			rplan = lal.CreateReverseREAL8FFTPlan(length, 0)
			template_f_squared.data.data = abs(template_f.data.data)**2
			lal.REAL8FreqTimeFFT(autocorr_t,template_f_squared,rplan)
			# normalize autocorrelation by central (maximum) value
			autocorr_t.data.data /= numpy.max(autocorr_t.data.data)
			autocorr_t = autocorr_t.data.data
			max_index = numpy.argmax(autocorr_t)
			# find the index of the third extremum for the template with lowest high-f cutoff.
			# we don't want to do this for all templates, because we know that
			# the template with the lowest high-f cutoff will have the largest chi2_index.
			if i == 0:
				extr_ctr = 0
				chi2_index = 0
				# stop one sample short of the end:  slope1 reads
				# sample j+1, which would otherwise run off the
				# array before the assert below can report the
				# failure
				for j in range(max_index+1, len(autocorr_t) - 1):
					slope1 = autocorr_t[j+1] - autocorr_t[j]
					slope0 = autocorr_t[j] - autocorr_t[j-1]
					chi2_index += 1
					if(slope1 * slope0 < 0):
						extr_ctr += 1
						if(extr_ctr == 2):
							break
			# extr_ctr and chi2_index are carried over from the
			# first template for all the others
			assert extr_ctr == 2, 'could not find 3rd extremum'
			# extract the part within the third extremum, setting the peak to be the center.
			autocorr[i] = numpy.concatenate((autocorr_t[1:(chi2_index+1)][::-1], autocorr_t[:(chi2_index+1)]))
			assert len(autocorr[i])%2==1, 'autocorr must have odd number of samples'
			# Inverse FFT template bank back to time domain
			template_t[i] = lal.CreateREAL8TimeSeries("whitened template_time", LIGOTimeGPS(epoch), psd.f0, 1.0 / options.sample_rate, lal.Unit("strain"), length)
			lal.REAL8FreqTimeFFT(template_t[i],template_f,rplan)
			# normalize
			template_t[i] = template_t[i].data.data
			template_t[i] /= numpy.sqrt(numpy.dot(template_t[i], template_t[i]))
			# to make the sample number odd we add 1 sample in the end here
			template_t[i] = numpy.append(template_t[i], 0.0)
			assert len(template_t[i])%2==1, 'template must have odd number of samples'
		self.firbank[instrument].set_property("latency", (len(template_t[0]) - 1) // 2)
		self.firbank[instrument].set_property("fir_matrix", template_t)
		self.triggergen[instrument].set_property("autocorrelation_matrix", autocorr)


	def do_on_message(self, bus, message):
		"""
		Handle "spectrum" element messages:  install whitened
		templates once the PSD is usable, or all-zero templates
		during the burn-in period.  Returns True if the message was
		a spectrum message, False otherwise.
		"""
		if message.type == Gst.MessageType.ELEMENT and  message.get_structure().get_name() == "spectrum":
			# the instrument is the last "_"-separated field of the
			# name of the element that posted the message
			instrument = message.src.get_name().split("_")[-1]
			if self.reference_psd is None:
				psd = pipeio.parse_spectrum_message(message)
				timestamp = psd.epoch
			else:
				psd = self.reference_psd[instrument]

			stability = float(message.src.get_property("n-samples")) / message.src.get_property("average-samples")

			# the logic should be neater here, but this is hoped to be temporary until we wipe out everything when finishing transition to offline way
			if stability > 0.3 or self.reference_psd is not None:
				if self.update_psd[instrument] != 0:
					# do nothing, just decrease the counter
					self.update_psd[instrument] -= 1
				else:
					# PSD counter reached zero
					if options.verbose:
						sys.stderr.write("setting whitened templates for %s\n" % (instrument,))
					# if you don't give the reference psd, how often psd is updated is decided by the integer given here. Larger number, less often.
					# if you give the reference psd, you need to make the template banks only once, so make the counter negative
					if self.reference_psd is None:
						self.update_psd[instrument] = 10
					else:
						self.update_psd[instrument] = -1
					self.update_templates(instrument, psd)
			else:
				# Burn-in period. Use templates with all zeros so that we won't get any triggers.
				# this branch is only reached without a reference
				# PSD, so timestamp has been set above
				if options.verbose:
					sys.stderr.write("At GPS time %s burn in period\n" % (timestamp,))
				template = [None] * len(self.template_bank[instrument])
				autocorr = [None] * len(self.template_bank[instrument])
				for i, row in enumerate(self.template_bank[instrument]):
					template[i] = numpy.zeros(int(32*options.sample_rate+1))
					# The length of autocorr is set to be similar to that for non-zero templates, but probably the length doesn't matter
					autocorr[i] = numpy.zeros(403)
				self.firbank[instrument].set_property("latency", (len(template[0]) - 1) // 2)
				self.firbank[instrument].set_property("fir_matrix", template)
				self.triggergen[instrument].set_property("autocorrelation_matrix", autocorr)
			return True
		return False


#
# =============================================================================
#
#                          Input and output files
#
# =============================================================================
#


#
# from the given channels make a dict like {"H1":"H1:channelname", "L1":"L1:channelname", ...}
# so that we can easily obtain channel names valid for demuxer etc., and there is easy mapping with the psd for each IFO
#

# the text before the first ':' of each channel name is used as the instrument
channel_dict = dict((channel.split(':')[0], channel) for channel in options.channel)
all_ifos = channel_dict.keys()


#
# load reference psds (if there are files given), and sort by instruments
# this gives a dictionary similar to one above like {"H1":"freq series", "L1":"freq series", ...}
#

if options.reference_psd is not None:
	psd = series.read_psd_xmldoc(ligolw_utils.load_filename(options.reference_psd, verbose = options.verbose, contenthandler = series.PSDContentHandler))
	# check for detector mismatch with channels.  compare as sets so
	# that the result does not depend on the key order of the two
	# dictionaries, and raise instead of assert so that the check is
	# not stripped by "python -O"
	if set(psd) != set(all_ifos):
		raise ValueError('ifo mismatch between psds and channels')
else:
	psd = None


@lsctables.use_in
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	"""
	Content handler passed to the document loaders below;  it is the
	ligolw base handler decorated with lsctables.use_in.
	"""
	pass


#
# load the segment file with specific segment name (if there is one) for gating
#

if options.segments_file is not None:
	seglists = ligolw_segments.segmenttable_get_by_name(ligolw_utils.load_filename(options.segments_file, contenthandler = ligolw_segments.LIGOLWContentHandler, verbose = options.verbose), options.segments_name).coalesce()
	# compare as sets so that the result does not depend on key order,
	# and raise instead of assert so the check survives "python -O"
	if set(seglists) != set(all_ifos):
		raise ValueError('ifo mismatch between segments and channels')
	# clip to the requested analysis interval
	for ifo in all_ifos:
		seglists[ifo] &= segments.segmentlist([segments.segment(LIGOTimeGPS(options.gps_start_time), LIGOTimeGPS(options.gps_end_time))])


#
# load the vetoes file too (if there is one)
#

if options.vetoes_file is not None:
	vetolists = ligolw_segments.segmenttable_get_by_name(ligolw_utils.load_filename(options.vetoes_file, contenthandler = ligolw_segments.LIGOLWContentHandler, verbose = options.verbose), options.vetoes_name).coalesce()
	if set(vetolists) != set(all_ifos):
		raise ValueError('ifo mismatch between vetoes and channels')
	# clip to the requested analysis interval
	for ifo in all_ifos:
		vetolists[ifo] &= segments.segmentlist([segments.segment(LIGOTimeGPS(options.gps_start_time), LIGOTimeGPS(options.gps_end_time))])


#
# load template bank file and find the template bank table
# Mapping is done from instrument to sngl_burst table & xml file
#

template_file = dict.fromkeys(all_ifos, None)
template_bank_table = dict.fromkeys(all_ifos, None)

for filename in options.template_bank:
	xmldoc = ligolw_utils.load_filename(filename, contenthandler = LIGOLWContentHandler, verbose = options.verbose)
	table = lsctables.SnglBurstTable.get_table(xmldoc)
	# the instrument of a bank is taken from its first row
	if len(table) == 0:
		raise ValueError("template bank '%s' contains no templates" % filename)
	bank_ifo = table[0].ifo
	if bank_ifo not in template_bank_table:
		raise ValueError("template bank '%s' is for instrument '%s', which has no --channel" % (filename, bank_ifo))
	if template_bank_table[bank_ifo] is not None:
		raise ValueError("more than one template bank given for instrument '%s'" % bank_ifo)
	template_bank_table[bank_ifo] = table
	template_file[bank_ifo] = filename

# fail here rather than with a TypeError on a None entry further down
ifos_without_bank = sorted(key for key, value in template_bank_table.items() if value is None)
if ifos_without_bank:
	raise ValueError("no template bank given for %s" % ", ".join(ifos_without_bank))


#
# format output xml file for putting triggers
#

xmldoc = ligolw.Document()
xmldoc.appendChild(ligolw.LIGO_LW())
process = ligolw_process.register_to_xmldoc(xmldoc, "StringSearch", options.__dict__)


#
# append search_summary table
#

search_summary_table = lsctables.New(lsctables.SearchSummaryTable, ["process:process_id", "comment", "ifos", "in_start_time", "in_start_time_ns", "in_end_time", "in_end_time_ns", "out_start_time", "out_start_time_ns", "out_end_time", "out_end_time_ns", "nevents", "nnodes"])
xmldoc.childNodes[-1].appendChild(search_summary_table)
search_summary = lsctables.SearchSummary()
search_summary.process_id = process.process_id
# the table declares a "comment" column, so always assign the attribute;
# it used to be left unset when --user-tag was not given.
# NOTE(review): presumably an unset row attribute makes the document
# writer fail -- confirm against ligo.lw
search_summary.comment = options.user_tag if options.user_tag else None
search_summary.ifos = ",".join(all_ifos)
search_summary.out_start = search_summary.in_start = LIGOTimeGPS(options.gps_start_time)
search_summary.out_end = search_summary.in_end = LIGOTimeGPS(options.gps_end_time)
search_summary.nnodes = 1
# placeholder;  the real count is filled in just before the document
# is written
search_summary.nevents = 0
search_summary_table.append(search_summary)


#
# append the injection file and time slide file (ligolw_add job in previous pipeline)
# the injection file already has a time slide table in it.
# FIXME we can require NOT to have time-slide file as argument when injection-file is given.
#

# merge exactly one extra document into the output:  the injection file
# when one is given, the time slide file otherwise
extra_document = options.time_slide_file if options.injection_file is None else options.injection_file
xmldoc = ligolw_add.ligolw_add(
	xmldoc,
	[extra_document],
	contenthandler = LIGOLWContentHandler,
	verbose = options.verbose
)

time_slide_table = lsctables.TimeSlideTable.get_table(xmldoc)


#
# table for single-detector triggers
#

# columns written for each single-detector trigger
sngl_burst_columns = [
	"process:process_id",
	"event_id",
	"ifo",
	"search",
	"channel",
	"start_time",
	"start_time_ns",
	"peak_time",
	"peak_time_ns",
	"duration",
	"central_freq",
	"bandwidth",
	"amplitude",
	"snr",
	"confidence",
	"chisq",
	"chisq_dof",
]
sngl_burst_table = lsctables.New(lsctables.SnglBurstTable, sngl_burst_columns)
# attach the new table to the last top-level element of the document
xmldoc.childNodes[-1].appendChild(sngl_burst_table)


#
# dictionary of the segment lists that were analyzed (i.e. that
# contribute to the live time).  starts out empty
#

analyzed_seglistdict = segments.segmentlistdict()


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


mainloop = GObject.MainLoop()
pipeline = Gst.Pipeline(name="pipeline")

# per-instrument FIR bank and trigger generator elements, filled in by
# the loop below and later handed to the PipelineHandler
firbank = dict.fromkeys(all_ifos, None)
triggergen = dict.fromkeys(all_ifos, None)


# build one processing chain per instrument.  "head" always refers to the
# most recently added element of the chain
for ifo in all_ifos:
	# frame cache source;  the first letter of the instrument name is
	# used as the cache source pattern and the full name as the
	# description pattern
	head = pipeparts.mklalcachesrc(pipeline, options.frame_cache, cache_src_regex = ifo[0], cache_dsc_regex = ifo)
	head = pipeparts.mkframecppchanneldemux(pipeline, head, channel_list = [channel_dict[ifo]])
	pipeparts.framecpp_channeldemux_set_units(head, {channel_dict[ifo]:"strain"})
	# the demuxer is linked to the audioconvert through a deferred link
	elem = pipeparts.mkaudioconvert(pipeline, None)
	pipeparts.src_deferred_link(head, channel_dict[ifo], elem.get_static_pad("sink"))
	head = elem
	# put gate for the segments and vetoes
	# currently with the leaky option on, so that step function-like discontinuities in the data do not affect the PSD.
	if options.segments_file is not None:
		head = datasource.mksegmentsrcgate(pipeline, head, seglists[ifo], invert_output = False, leaky = True)
	if options.vetoes_file is not None:
		head = datasource.mksegmentsrcgate(pipeline, head, vetolists[ifo], invert_output = True, leaky = True)
	# limit the maximum buffer duration.  keeps RAM use under control
	# in the event that we are loading gigantic frame files
	# FIXME currently needs to be >= fft_length (= 32s) for mkwhiten to work. (I think)
	# when the reference_psd can be used for mkwhiten, change the block duration to shorter time.
	head = pipeparts.mkreblock(pipeline, head, block_duration = 64 * 1000000000)


	#
	# injections
	#

	if options.injection_file is not None:
		head = pipeparts.mkinjections(pipeline, head, options.injection_file)


	#
	# whitener, resampler and caps filter
	#

	# FIXME if reference psd is available use that for whitening data
	# the below code doesn't work...
	#if options.reference_psd is not None:
		#head = pipeparts.mkwhiten(pipeline, head, fft_length = 32, name = "lal_whiten_%s" % ifo, psd_mode = 1, mean_psd = psd[ifo].data.data)
	#else:
		#head = pipeparts.mkwhiten(pipeline, head, fft_length = 32, name = "lal_whiten_%s" % ifo)
	# the element name matters:  PipelineHandler.do_on_message takes
	# the instrument from the text after the last "_" of the name
	head = pipeparts.mkwhiten(pipeline, head, fft_length = 32, name = "lal_whiten_%s" % ifo)
	head = pipeparts.mkaudioconvert(pipeline, head)
	head = pipeparts.mkresample(pipeline, head)
	# FIXME NO hardcoding original sample rate!
	head = pipeparts.mkaudioamplify(pipeline, head, math.sqrt(16384./options.sample_rate))
	head = pipeparts.mkcapsfilter(pipeline, head, "audio/x-raw, format=F32LE, rate=%d" % options.sample_rate)
	head = pipeparts.mkqueue(pipeline, head, max_size_buffers = 8)


	#
	# filter bank
	#

	# starts with all-zero templates of 32 s plus one sample each;  the
	# handler replaces them when spectrum messages arrive
	# NOTE(review): options.sample_rate is parsed as a float, so
	# block_stride is a float here -- confirm the element accepts it
	head = firbank[ifo] = pipeparts.mkfirbank(pipeline, head, fir_matrix = numpy.zeros((len(template_bank_table[ifo]),int(32*options.sample_rate)+1),dtype=numpy.float64), block_stride = 4 * options.sample_rate, latency = int(16*options.sample_rate))

	#
	# trigger generator
	#

	# 403 is the same placeholder autocorrelation length that the
	# handler uses during the burn-in period
	triggergen[ifo] = pipeparts.mkgeneric(pipeline, head, "lal_string_triggergen", threshold = options.threshold, cluster = options.cluster_events, bank_filename = template_file[ifo], autocorrelation_matrix = numpy.zeros((len(template_bank_table[ifo]), 403),dtype=numpy.float64))


#
# handler
#

handler = PipelineHandler(
	mainloop,
	pipeline,
	xmldoc,
	template_bank_table,
	sngl_burst_table,
	analyzed_seglistdict,
	psd,
	firbank,
	triggergen
)


#
# appsync:  one appsink per instrument, named after the instrument and
# attached to that instrument's trigger generator
#

appsync = pipeparts.AppSync(appsink_new_buffer = handler.appsink_new_buffer)
appsinks = set()
for ifo in all_ifos:
	appsinks.add(appsync.add_sink(pipeline, triggergen[ifo], caps = Gst.Caps.from_string("application/x-lal-snglburst"), name = ifo))


#
# seek
#

if pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
	raise RuntimeError("pipeline did not enter ready state")
# from here on the GPS boundaries in "options" are LIGOTimeGPS objects
options.gps_start_time = LIGOTimeGPS(options.gps_start_time)
options.gps_end_time = LIGOTimeGPS(options.gps_end_time)
datasource.pipeline_seek_for_gps(pipeline, options.gps_start_time, options.gps_end_time)


#
# run
#

if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
	raise RuntimeError("pipeline did not enter playing state")
if options.verbose:
	# sys.stderr.write() instead of the Python-2-only "print >>" form
	sys.stderr.write("running pipeline ...\n")
mainloop.run()

# once the main loop has returned, write out the analyzed segments and
# process the leftover triggers
handler.flush()


#
# obtain nevents from the coinc event table, and write output to disk
# FIXME vetoes table also needs to be dumped here with ligolw_add
#

# the number of events reported is the number of rows in the coinc table
coinc_table = lsctables.CoincTable.get_table(xmldoc)
search_summary.nevents = len(coinc_table)

# compress the output when the file name asks for it
compress_output = (options.output or "stdout").endswith(".gz")
ligolw_utils.write_filename(xmldoc, options.output, gz = compress_output, verbose = options.verbose)