From 756a578b678ad1f143eb5542868392fe01558501 Mon Sep 17 00:00:00 2001
From: Patrick Godwin <patrick.godwin@ligo.org>
Date: Sun, 7 Feb 2021 15:04:26 -0800
Subject: [PATCH] reference_psd.measure_psd(): convert to stream API

---
 gstlal/python/psd.py | 63 +++++++++++++++++++++++---------------------
 1 file changed, 33 insertions(+), 30 deletions(-)

diff --git a/gstlal/python/psd.py b/gstlal/python/psd.py
index 7ed1167ae4..88c8d125f8 100644
--- a/gstlal/python/psd.py
+++ b/gstlal/python/psd.py
@@ -55,6 +55,7 @@ from gstlal import datasource
 from gstlal import pipeparts
 from gstlal import pipeio
 from gstlal import simplehandler
+from gstlal.stream import MessageType, Stream
 
 
 __doc__ = """
@@ -95,6 +96,15 @@ class PSDHandler(simplehandler.Handler):
 		return False
 
 
+class PSDTracker:
+	def __init__(self):
+		self.psd = None
+
+	def on_spectrum_message(self, message):
+		self.psd = pipeio.parse_spectrum_message(message)
+		return True
+
+
 #
 # measure_psd()
 #
@@ -138,53 +148,46 @@ def measure_psd(gw_data_source_info, instrument, rate, psd_fft_length = 8, verbo
 		raise ValueError("segment %s too short" % str(gw_data_source_info.seg))
 
 	#
-	# build pipeline
+	# calculate number of samples to average over
 	#
 
-	if verbose:
-		print("measuring PSD in segment %s" % str(gw_data_source_info.seg), file=sys.stderr)
-		print("building pipeline ...", file=sys.stderr)
-	mainloop = GObject.MainLoop()
-	pipeline = Gst.Pipeline(name="psd")
-	handler = PSDHandler(mainloop, pipeline)
-
-	head, _, _ = datasource.mkbasicsrc(pipeline, gw_data_source_info, instrument, verbose = verbose)
-	head = pipeparts.mkcapsfilter(pipeline, head, "audio/x-raw, rate=[%d,MAX]" % rate)	# disallow upsampling
-	head = pipeparts.mkresample(pipeline, head, quality = 9)
-	head = pipeparts.mkcapsfilter(pipeline, head, "audio/x-raw, rate=%d" % rate)
-	head = pipeparts.mkqueue(pipeline, head, max_size_buffers = 8)
 	if gw_data_source_info.seg is not None:
 		average_samples = int(round(float(abs(gw_data_source_info.seg)) / (psd_fft_length / 2.) - 1.))
 	else:
 		#FIXME maybe let the user specify this
 		average_samples = 64
-	head = pipeparts.mkwhiten(pipeline, head, psd_mode = 0, zero_pad = 0, fft_length = psd_fft_length, average_samples = average_samples, median_samples = 7)
-	pipeparts.mkfakesink(pipeline, head)
 
 	#
-	# setup signal handler to shutdown pipeline for live data
+	# initialize PSD tracker
 	#
 
-	if gw_data_source_info.data_source in ("lvshm", "framexmit"):# FIXME what about nds online?
-		simplehandler.OneTimeSignalHandler(pipeline)
+	tracker = PSDTracker()
 
 	#
-	# process segment
+	# build pipeline
 	#
 
 	if verbose:
-		print("putting pipeline into READY state ...", file=sys.stderr)
-	if pipeline.set_state(Gst.State.READY) == Gst.StateChangeReturn.FAILURE:
-		raise RuntimeError("pipeline failed to enter READY state")
-	if gw_data_source_info.data_source not in ("lvshm", "framexmit"):# FIXME what about nds online?
-		datasource.pipeline_seek_for_gps(pipeline, *gw_data_source_info.seg)
-	if verbose:
-		print("putting pipeline into PLAYING state ...", file=sys.stderr)
-	if pipeline.set_state(Gst.State.PLAYING) == Gst.StateChangeReturn.FAILURE:
-		raise RuntimeError("pipeline failed to enter PLAYING state")
+		print("measuring PSD in segment %s" % str(gw_data_source_info.seg), file=sys.stderr)
+		print("building pipeline ...", file=sys.stderr)
+
+	stream = Stream.from_datasource(gw_data_source_info, instrument, verbose=verbose)
+	stream.add_callback(MessageType.ELEMENT, "spectrum", tracker.on_spectrum_message)
+
+	stream = stream.capsfilter(f"audio/x-raw, rate=[{rate:d},MAX]")  # disallow upsampling
+	stream.resample(quality=9) \
+		.capsfilter(f"audio/x-raw, rate={rate:d}") \
+		.queue(max_size_buffers=8) \
+		.whiten(psd_mode=0, zero_pad=0, fft_length=psd_fft_length, average_samples=average_samples, median_samples=7) \
+		.fakesink()
+
+	#
+	# process segment
+	#
+
 	if verbose:
 		print("running pipeline ...", file=sys.stderr)
-	mainloop.run()
+	stream.start()
 
 	#
 	# done
@@ -192,7 +195,7 @@ def measure_psd(gw_data_source_info, instrument, rate, psd_fft_length = 8, verbo
 
 	if verbose:
 		print("PSD measurement complete", file=sys.stderr)
-	return handler.psd
+	return tracker.psd
 
 
 def read_psd(filename: str, verbose: Optional[bool] = False) -> Dict[str, lal.REAL8FrequencySeries]:
-- 
GitLab