Skip to content
Snippets Groups Projects
Forked from lscsoft / GstLAL
2155 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
svd_bank_snr.py 12.69 KiB
"""
Short cutting gstlal inspiral pipeline to produce SNR for template(s)
"""

import sys
import numpy
import threading

import gi
gi.require_version('Gst', '1.0')
gi.require_version('GstAudio', '1.0')
from gi.repository import GObject, Gst, GstAudio
GObject.threads_init()
Gst.init(None)

from gstlal import datasource
from gstlal import lloidparts
from gstlal import multirate_datasource
from gstlal import pipeparts
from gstlal import pipeio
from gstlal import reference_psd
from gstlal import simplehandler

import lal
from lal import LIGOTimeGPS
import lal.series

from ligo.lw import array as ligolw_array
from ligo.lw import ligolw
from ligo.lw import param as ligolw_param
from ligo.lw import utils as ligolw_utils

@ligolw_array.use_in
@ligolw_param.use_in
class SNRContentHandler(ligolw.LIGOLWContentHandler):
	pass

class LLOID_SNR(object):
	"""
	 The options for SNR calculation, please refer to multirate_datasource.mkwhitened_src()
	 and llloidparts.mkLLOIDhoftToSnrSlices() for more information. Defaults are:

		"psd_fft_length": 32,
		"ht_gate_threshold": None,
		"veto_segments": None,
		"track_psd": False,
		"width": 32,
		"verbose": False
	"""

	def __init__(self, psd_fft_length = 32, ht_gate_threshold = None, veto_segments = None, track_psd = False, width = 32, verbose = False):
		self.psd_fft_length = psd_fft_length
		self.ht_gate_threshold = ht_gate_threshold
		self.veto_segments = veto_segments
		self.track_psd = track_psd
		self.width = width
		self.verbose = verbose

		self.lock = threading.Lock()
		self.snr_info = {
			"timestamps": [],
			"instrument": None,
			"deltaT": None,
			"data": [],
		}

	def __call__(self, gw_data_source_info, bank, instrument, psd = None):
		pipeline = Gst.Pipeline(name = "gstlal_inspiral_LLOID_SNR")
		mainloop = GObject.MainLoop()
		handler = simplehandler.Handler(mainloop, pipeline)
		self.snr_info["instrument"] = instrument

		# sanity check
		if psd is not None:
			assert instrument in set(psd)
		assert instrument in set(gw_data_source_info.channel_dict)

		if self.verbose:
			sys.stderr.write("Building pipeline to calculate SNR...\n")

		src, statevector, dqvector = datasource.mkbasicsrc(pipeline, gw_data_source_info, instrument, self.verbose)

		hoftdict = multirate_datasource.mkwhitened_multirate_src(
			pipeline,
			src = src,
			rates = set(rate for rate in bank.get_rates()),
			instrument = instrument,
			psd = psd[instrument],
			psd_fft_length = self.psd_fft_length,
			ht_gate_threshold = self.ht_gate_threshold,
			veto_segments = self.veto_segments,
			track_psd = self.track_psd,
			width = self.width,
			statevector = statevector,
			dqvector = dqvector,
			fir_whiten_reference_psd = bank.processed_psd
			)

		snr = lloidparts.mkLLOIDhoftToSnrSlices(
			pipeline,
			hoftdict = hoftdict,
			bank = bank,
			control_snksrc = (None, None),
			block_duration = 8 * Gst.SECOND,
			fir_stride = 16,
			verbose = self.verbose,
			logname = instrument
			)

		appsink = pipeparts.mkappsink(pipeline, snr, drop = False)
		handler_id = appsink.connect("new-preroll", self.new_preroll_handler)
		assert handler_id > 0
		handler_id = appsink.connect("new-sample", self.pull_buffer)
		assert handler_id > 0
		handler_id = appsink.connect("eos", self.pull_buffer)
		assert handler_id > 0

		if self.verbose:
			sys.stderr.write("Setting pipeline state to READY...\n")
		if pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
			raise RuntimeError("pipeline cannot enter ready state.")

		datasource.pipeline_seek_for_gps(pipeline, *gw_data_source_info.seg)

		if self.verbose:
			sys.stderr.write("Seting pipeline state to PLAYING...\n")
		if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
			raise RuntimeError("pipeline cannot enter playing state.")
		if self.verbose:
			sys.stderr.write("Calculating SNR...\n")

		mainloop.run()

		if self.verbose:
			sys.stderr.write("Calculation done.\n")
		if pipeline.set_state(Gst.State.NULL) != Gst.StateChangeReturn.SUCCESS:
			raise RuntimeError("pipeline could not be set to NULL.")

		del pipeline, mainloop, handler

		# return snr_info containing information to construct snr for all template in the template bank
		# see make_SNR_series() to make SNR LAL Series from snr_info
		self.snr_info["data"] = numpy.concatenate(numpy.array(self.snr_info["data"]), axis = 0)
		return self.snr_info

	#===============================================================================================
	#
	#									internal functions
	#
	#===============================================================================================
	def new_preroll_handler(self, elem):
		with self.lock:
			# ignore preroll buffers
			elem.emit("pull-preroll")
			return Gst.FlowReturn.OK

	def pull_buffer(self, elem):
		with self.lock:
			sample = elem.emit("pull-sample")
			if sample is None:
				return Gst.FlowReturn.OK
			else:
				success, rate = sample.get_caps().get_structure(0).get_int("rate")

				assert success == True
				# make sure the sampling rate is the same for all data
				if self.snr_info["deltaT"] is None:
					self.snr_info["deltaT"] = 1. / rate
				else:
					assert self.snr_info["deltaT"] == 1. / rate, "data have different sampling rate."

				buf = sample.get_buffer()
				if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
					return Gst.FlowReturn.OK
				# FIXME: check timestamps
				data = pipeio.array_from_audio_sample(sample)
				if data is not None:
					self.snr_info["data"].append(data)
					self.snr_info["timestamps"].append(LIGOTimeGPS(0, sample.get_buffer().pts))
				return Gst.FlowReturn.OK

class FIR_SNR(object):
	"""
	Required arguments:
	-gw_data_source_info:
	-template:
	-psd:
	-instrument:

	Optional arguments:
	-psd_fft_length:
	-zero_pad: Hanning Window's zero padding (seconds)
	-average_samples:
	-median_samples:
	-rate:
	-verbose:
	"""

	def __init__(self, rate, psd_fft_length = 32, zero_pad = 0, average_samples = 64, median_samples = 7, width = 32, track_psd = False, verbose = False):
		self.average_samples = average_samples
		self.lock = threading.Lock()
		self.median_samples = median_samples
		self.psd_fft_length = psd_fft_length
		self.rate = rate
		self.track_psd = track_psd
		self.verbose = verbose
		self.width = width
		self.zero_pad = zero_pad

		self.snr_info = {
			"timestamps": [],
			"instrument": None,
			"deltaT": 1./rate,
			"data": [],
		}

	def __call__(self, gw_data_source_info, template, instrument, latency, psd = None):
		self.snr_info["instrument"] = instrument

		pipeline = Gst.Pipeline("gstlal_inspiral_simple_SNR")
		mainloop = GObject.MainLoop()
		handler = simplehandler.Handler(mainloop, pipeline)

		if self.verbose:
			sys.stderr.write("Building pipeline to calculate SNR\n")

		src, statevector, dqvector = datasource.mkbasicsrc(pipeline, gw_data_source_info, instrument, verbose = self.verbose)

		hoftdict = multirate_datasource.mkwhitened_multirate_src(
					pipeline,
					src = src,
					rates = [self.rate],
					instrument = instrument,
					psd = psd[instrument],
					psd_fft_length = self.psd_fft_length,
					track_psd = self.track_psd,
					width = self.width,
					statevector = statevector,
					dqvector = dqvector
					)

		#FIXME: how to set latency
		head = pipeparts.mkfirbank(pipeline, hoftdict[self.rate], latency = latency, fir_matrix = [template], block_stride = 16 * self.rate, time_domain = False)

		appsink = pipeparts.mkappsink(pipeline, head, drop = False)
		handler_id = appsink.connect("new-preroll", self.new_preroll_handler)
		assert handler_id > 0
		handler_id = appsink.connect("new-sample", self.pull_buffer)
		assert handler_id > 0
		handler_id = appsink.connect("eos", self.pull_buffer)
		assert handler_id > 0

		if self.verbose:
			sys.stderr.write("Setting pipeline state to READY...\n")
		if pipeline.set_state(Gst.State.READY) != Gst.StateChangeReturn.SUCCESS:
			raise RuntimeError("pipeline cannot enter ready state.")

		datasource.pipeline_seek_for_gps(pipeline, *gw_data_source_info.seg)

		if self.verbose:
			sys.stderr.write("Setting pipeline state to PLAYING...\n")
		if pipeline.set_state(Gst.State.PLAYING) != Gst.StateChangeReturn.SUCCESS:
			raise RuntimeError("pipeline cannot enter playing state.")
		if self.verbose:
			sys.stderr.write("Calculating SNR...\n")

		mainloop.run()

		if self.verbose:
			sys.stderr.write("Calculation done.\n")
		if pipeline.set_state(Gst.State.NULL) != Gst.StateChangeReturn.SUCCESS:
			raise RuntimeError("pipeline could not be set to NULL.")

		del pipeline, mainloop, handler

		self.snr_info["data"] = numpy.concatenate(numpy.array(self.snr_info["data"]), axis = 0)
		return self.snr_info

	#===============================================================================================
	#
	#									internal functions
	#
	#===============================================================================================

	def new_preroll_handler(self, elem):
		with self.lock:
			# ignore preroll buffers
			elem.emit("pull-preroll")
			return Gst.FlowReturn.OK

	def pull_buffer(self, elem):
		with self.lock:
			sample = elem.emit("pull-sample")
			if sample is None:
				return Gst.FlowReturn.OK
			else:
				success, rate = sample.get_caps().get_structure(0).get_int("rate")

				assert success == True
				# make sure the sampling rate is the same for all data
				if self.snr_info["deltaT"] is not None:
					assert self.snr_info["deltaT"] == 1. / rate, "data have different sampling rate."
				self.snr_info["deltaT"] = 1. / rate

				buf = sample.get_buffer()
				if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
					return Gst.FlowReturn.OK

				data = pipeio.array_from_audio_sample(sample)
				if data is not None:
					self.snr_info["data"].append(data)
					self.snr_info["timestamps"].append(LIGOTimeGPS(0, sample.get_buffer().pts))
				return Gst.FlowReturn.OK

#=============================================================================================
#
# 										Output Utilities
#
#=============================================================================================
def make_snr_series(snr_info, row_number = 0, drop_first = 0, drop_last = 0):
	assert drop_first >= 0, "must drop positive number of data"
	assert drop_last >= 0, "must drop positive number of data"
	bps = drop_first * int(round(1 / snr_info["deltaT"]))
	bpe = -drop_last * int(round(1 / snr_info["deltaT"])) if drop_last != 0 else None

	data = numpy.abs(snr_info["data"])[:,row_number][bps:bpe]

	if data.dtype == numpy.float32:
		tseries = lal.CreateREAL4TimeSeries(
				name = snr_info["instrument"],
				epoch = snr_info["timestamps"][0] + drop_first,
				deltaT = snr_info["deltaT"],
				f0 = 0,
				sampleUnits = lal.DimensionlessUnit,
				length = len(data)
				)
		tseries.data.data = data
	elif data.dtype == numpy.float64:
		tseries = lal.CreateREAL8TimeSeries(
				name = snr_info["instrument"],
				epoch = snr_info["timestamps"][0] + drop_first,
				deltaT = snr_info["deltaT"],
				f0 = 0,
				sampleUnits = lal.DimensionlessUnit,
				length = len(data)
				)
		tseries.data.data = data
	else:
		raise ValueError("unsupported type : %s " % data.dtype)

	return tseries

def make_xmldoc(snrdict, xmldoc = None, root_name = u"gstlal_inspiral_snr"):
	if xmldoc is None:
		xmldoc = ligolw.Document()

	root = xmldoc.appendChild(ligolw.LIGO_LW())
	root.Name = root_name
	for instrument, snr in snrdict.items():
		if snr.data.data.dtype == numpy.float32:
			tseries = root.appendChild(lal.series.build_REAL4TimeSeries(snr))
		elif snr.data.data.dtype == numpy.float64:
			tseries = root.appendChild(lal.series.build_REAL8TimeSeries(snr))
		else:
			raise ValueError("unsupported type : %s" % snr.data.data.dtype)
		if instrument is not None:
			tseries.appendChild(ligolw_param.Param.from_pyvalue(u"instrument", instrument))

	return xmldoc

def read_xmldoc(xmldoc, root_name = u"gstlal_inspiral_snr"):
	if root_name is not None:
		xmldoc, = (elem for elem in xmldoc.getElementsByTagName(ligolw.LIGO_LW.tagName) if elem.hasAttribute(u"Name") if elem.Name == root_name)

	result = []
	for elem in (elem for elem in xmldoc.getElementsByTagName(ligolw.LIGO_LW.tagName) if elem.hasAttribute(u"Name")):
		if elem.Name == u"REAL4TimeSeries":
			result.append((ligolw_param.get_pyvalue(elem, u"instrument"), lal.series.parse_REAL4TimeSeries(elem)))
		elif elem.Name == u"REAL8TimeSeries":
			result.append((ligolw_param.get_pyvalue(elem, u"instrument"), lal.series.parse_REAL8TimeSeries(elem)))

	assert result is not None, "xmldoc contains no LAL Series or LAL Series is unsupported"

	return dict(result)

# wrapper for writing snr series to URL
def write_url(xmldoc, filename, verbose = False):
	ligolw_utils.write_filename(xmldoc, filename, gz = filename.endswith(".gz"), verbose = verbose)

# wrapper for reading snr series from URL
def read_url(filename, contenthandler, verbose = False):
	return ligolw_utils.load_url(filename, verbose = verbose, contenthandler = contenthandler)