Skip to content
Snippets Groups Projects

Plot horizon distance from ranking statistics

Merged ChiWai Chan requested to merge plot_psd_horizon into master
1 unresolved thread
1 file
+ 223
59
Compare changes
  • Side-by-side
  • Inline
  • * fold in state/DQ vector support in Stream.from_datasource
    * expose setting pipeline state via stream.set_state(state)
    * support operations on groups of keyed streamed via StreamMap
    * allow callbacks via EOS messages where message types are not relevant
    * support x-lal caps when pulling buffers from appsink
    * support AppSync functionality for synchronization via StreamMap.bufsink
+ 223
59
@@ -33,6 +33,7 @@ import functools
import io
import os
import sys
import threading
import uuid
import numpy
@@ -45,10 +46,13 @@ from gi.repository import GObject
from gi.repository import Gst
from gi.repository import GstAudio
from lal import LIGOTimeGPS
from ligo import segments
from gstlal import datasource
from gstlal import pipeparts
from gstlal import pipeio
from gstlal import plugins
from gstlal import simplehandler
@@ -62,18 +66,19 @@ from gstlal import simplehandler
#
SourceElem = namedtuple("SourceElem", "datasource is_live gps_range")
Buffer = namedtuple("Buffer", "t0 data")
SourceElem = namedtuple("SourceElem", "datasource is_live gps_range state_vector dq_vector")
Buffer = namedtuple("Buffer", "name t0 duration data")
MessageType = Gst.MessageType
class Stream(object):
class Stream:
"""Class for building a GStreamer-based pipeline.
"""
_thread_init = False
_has_elements = False
_caps_buffer_map = None
def __init__(self, name=None, mainloop=None, pipeline=None, handler=None, source=None, head=None):
# initialize threads if not set
@@ -87,6 +92,10 @@ class Stream(object):
for elem_name, elem in self._get_registered_elements().items():
self.register_element(elem_name)(elem)
# register caps to buffer mapping
if self._caps_buffer_map is None:
self._load_caps_buffer_map()
# set up gstreamer pipeline
self.name = name if name else str(uuid.uuid1())
self.mainloop = mainloop if mainloop else GObject.MainLoop()
@@ -102,10 +111,10 @@ class Stream(object):
"""
if self.source.is_live:
simplehandler.OneTimeSignalHandler(self.pipeline)
self._set_state(Gst.State.READY)
self.set_state(Gst.State.READY)
if not self.source.is_live:
self._seek_gps()
self._set_state(Gst.State.PLAYING)
self.set_state(Gst.State.PLAYING)
## Debugging output
if os.environ.get("GST_DEBUG_DUMP_DOT_DIR", False):
@@ -128,63 +137,87 @@ class Stream(object):
def register(func):
def wrapped(self, *srcs, **kwargs):
head = func(self.pipeline, self.head, *srcs, **kwargs)
if isinstance(head, Mapping):
new_head = head.__class__()
for key, elem in head.items():
new_head = {
key: cls(
name=self.name,
mainloop=self.mainloop,
pipeline=self.pipeline,
handler=self.handler,
source=self.source,
head=elem,
)
}
return new_head
else:
return cls(
name=self.name,
mainloop=self.mainloop,
pipeline=self.pipeline,
handler=self.handler,
source=self.source,
head=head,
)
return cls(
name=self.name,
mainloop=self.mainloop,
pipeline=self.pipeline,
handler=self.handler,
source=self.source,
head=head,
)
setattr(cls, elem_name, wrapped)
return register
@classmethod
def from_datasource(cls, data_source_info, ifo, verbose=False):
stream = cls()
stream.head, _, _ = datasource.mkbasicsrc(stream.pipeline, data_source_info, ifo, verbose=verbose)
def from_datasource(cls, data_source_info, ifos, verbose=False, state_vector=False, dq_vector=False):
is_live = data_source_info.data_source in data_source_info.live_sources
stream.source = SourceElem(
datasource=data_source_info.data_source,
is_live=is_live,
gps_range=data_source_info.seg,
)
return stream
ref_stream = cls()
if isinstance(ifos, str):
stream = ref_stream
stream.head, state_vector, dq_vector = datasource.mkbasicsrc(
stream.pipeline,
data_source_info,
ifos,
verbose=verbose
)
stream.source = SourceElem(
datasource=data_source_info.data_source,
is_live=is_live,
gps_range=data_source_info.seg,
state_vector=state_vector,
dq_vector=dq_vector
)
return stream
else:
stream_map = {}
state_vectors = {}
dq_vectors = {}
for ifo in ifos:
stream = cls(
name=ref_stream.name,
mainloop=ref_stream.mainloop,
pipeline=ref_stream.pipeline,
handler=ref_stream.handler,
)
stream.head, state_vectors[ifo], dq_vectors[ifo] = datasource.mkbasicsrc(
stream.pipeline,
data_source_info,
ifo,
verbose=verbose
)
stream_map[ifo] = stream
stream = StreamMap.from_dict(stream_map)
stream.source = SourceElem(
datasource=data_source_info.data_source,
is_live=is_live,
gps_range=data_source_info.seg,
state_vector=state_vectors,
dq_vector=dq_vectors
)
return stream
def connect(self, *args, **kwargs):
self.head.connect(*args, **kwargs)
def sink(self, func):
def bufsink(self, func, caps=None):
def sample_handler(elem):
buf = self._pull_buffer(elem)
func(buf)
buf = self._pull_buffer(elem, caps=caps)
if buf:
func(buf)
return Gst.FlowReturn.OK
sink = pipeparts.mkappsink(self.pipeline, self.head, max_buffers=1, sync=False)
sink.connect("new-sample", sample_handler)
sink.connect("new-preroll", self._preroll_handler)
def add_callback(self, msg_type, msg_name, callback):
def add_callback(self, msg_type, *args):
"""
"""
self.handler.add_callback(msg_type, msg_name, callback)
self.handler.add_callback(msg_type, *args)
def _set_state(self, state):
def set_state(self, state):
"""Set pipeline state, checking for errors.
"""
if self.pipeline.set_state(state) == Gst.StateChangeReturn.FAILURE:
@@ -196,19 +229,72 @@ class Stream(object):
start, end = self.source.gps_range
datasource.pipeline_seek_for_gps(self.pipeline, start, end)
@staticmethod
def _pull_buffer(elem):
buf = elem.emit("pull-sample").get_buffer()
buftime = buf.pts // 1e9
result, mapinfo = buf.map(Gst.MapFlags.READ)
if mapinfo.data:
with io.BytesIO(mapinfo.data) as s:
newbuf = Buffer(t0=buftime, data=numpy.loadtxt(s))
@classmethod
def _pull_buffer(cls, elem, caps=None):
# get buffer
sample = elem.emit("pull-sample")
buf = sample.get_buffer()
buftime = LIGOTimeGPS(0, buf.pts)
# check if valid buffer is available
if buf.mini_object.flags & Gst.BufferFlags.GAP or buf.n_memory() == 0:
return
# read from buffer
if caps:
data = []
for i in range(buf.n_memory()):
memory = buf.peek_memory(i)
success, mapinfo = memory.map(Gst.MapFlags.READ)
assert success
if mapinfo.data:
rows = cls._caps_buffer_map[caps.to_string()](mapinfo.data)
data.extend(rows)
memory.unmap(mapinfo)
return Buffer(
t0=buftime,
duration=buf.duration,
data=data,
name=elem.name,
)
else:
newbuf = Buffer(t0=buftime, data=None)
buf.unmap(mapinfo)
del buf
return newbuf
return Buffer(
t0=buftime,
duration=buf.duration,
data=pipeio.array_from_audio_sample(sample),
name=elem.name,
)
@classmethod
def _load_caps_buffer_map(cls):
bufmap = {}
# load table definitions if available
# FIXME: this is really ugly, revisit this with importlib or similar
try:
from gstlal.snglinspiraltable import GSTLALSnglInspiral
except ImportError:
pass
else:
bufmap["application/x-lal-snglinspiral"] = GSTLALSnglInspiral.from_buffer
try:
from gstlal.snglbursttable import GSTLALSnglBurst
except ImportError:
pass
else:
bufmap["application/x-lal-snglburst"] = GSTLALSnglBurst.from_buffer
try:
from gstlal.snax.sngltriggertable import GSTLALSnglTrigger
except ImportError:
pass
else:
bufmap["application/gstlal-sngltrigger"] = GSTLALSnglTrigger.from_buffer
cls._caps_buffer_map = bufmap
@staticmethod
def _preroll_handler(elem):
@@ -230,6 +316,13 @@ class Stream(object):
from gstlal.pipeparts import condition
manager.register(condition)
try:
from gstlal import lloidparts
except ImportError:
pass
else:
manager.register(lloidparts)
# add all registered plugins to registry
registered = {}
for plugin_name in manager.hook.elements():
@@ -239,6 +332,64 @@ class Stream(object):
return registered
class StreamMap(Stream):
def __getitem__(self, key):
return self.__class__(
name=self.name,
mainloop=self.mainloop,
pipeline=self.pipeline,
handler=self.handler,
source=self.source,
head=self.head[key],
)
def __setitem__(self, key, value):
self.head[key] = value
def keys(self):
yield from self.head.keys()
def values(self):
for key in self.keys():
yield self[key]
def items(self):
for key in self.keys():
yield key, self[key]
def bufsink(self, func, caps=None):
def sample_handler(elem):
buf = self._pull_buffer(elem, caps=caps)
if buf:
func(buf)
return Gst.FlowReturn.OK
self._appsync = pipeparts.AppSync(appsink_new_buffer=sample_handler)
for key in self.keys():
self._appsync.add_sink(self.pipeline, self.head[key], name=key)
@classmethod
def from_dict(cls, stream_dict):
# check that stream properties are consistent
ref_key = next(iter(stream_dict.keys()))
ref_stream = stream_dict[ref_key]
for stream in stream_dict.values():
assert stream.name == ref_stream.name
assert stream.mainloop is ref_stream.mainloop
assert stream.pipeline is ref_stream.pipeline
assert stream.handler is ref_stream.handler
assert stream.source is ref_stream.source
return cls(
name=ref_stream.name,
mainloop=ref_stream.mainloop,
pipeline=ref_stream.pipeline,
handler=ref_stream.handler,
source=ref_stream.source,
head={key: stream.head for key, stream in stream_dict.items()}
)
class StreamHandler(simplehandler.Handler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -250,9 +401,16 @@ class StreamHandler(simplehandler.Handler):
Gst.MessageType.EOS: {},
}
def add_callback(self, msg_type, msg_name, callback):
def add_callback(self, msg_type, *args):
"""
"""
# FIXME: would be better to rearrange the method signature so
# this extra step to determine args doesn't need to be done
if len(args) == 1:
msg_name = None
callback = args[0]
else:
msg_name, callback = args
if msg_name in self.callbacks[msg_type]:
raise ValueError("callback already registered for message type/name")
self.callbacks[msg_type][msg_name] = callback
@@ -261,8 +419,14 @@ class StreamHandler(simplehandler.Handler):
"""
"""
if message.type in self.callbacks:
if message.get_structure():
if message.type == Gst.MessageType.EOS:
# EOS messages don't have specific subtypes so we don't
# parse the message's structure to determine how to proceed
message_name = None
elif message.get_structure():
message_name = message.get_structure().get_name()
if message_name in self.callbacks[message.type]:
self.callbacks[message.type][message_name](message)
else:
return False
if message_name in self.callbacks[message.type]:
self.callbacks[message.type][message_name](message)
return False
Loading