Skip to content
Snippets Groups Projects
Commit 01094aef authored by Patrick Godwin's avatar Patrick Godwin
Browse files

stream.py: add type annotations, minor cleanup

parent 35417eab
No related branches found
No related tags found
1 merge request!55Add high-level Stream API to build GStreamer pipelines
......@@ -14,18 +14,9 @@
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
## @file
## @package stream
#
# =============================================================================
#
# Preamble
#
# =============================================================================
#
"""High-level tools to build GStreamer pipelines.
"""
from collections import namedtuple
from collections.abc import Mapping
......@@ -34,6 +25,8 @@ import io
import os
import sys
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Mapping as MappingType
import uuid
import numpy
......@@ -56,15 +49,6 @@ from gstlal import plugins
from gstlal import simplehandler
#
# =============================================================================
#
# Stream
#
# =============================================================================
#
SourceElem = namedtuple("SourceElem", "datasource is_live gps_range state_vector dq_vector")
Buffer = namedtuple("Buffer", "name t0 duration data caps")
......@@ -80,13 +64,13 @@ class Stream:
def __init__(
self,
name=None,
mainloop=None,
pipeline=None,
handler=None,
source=None,
head=None,
):
name: Optional[str] = None,
mainloop: Optional[GLib.MainLoop] = None,
pipeline: Optional[Gst.Pipeline] = None,
handler: Optional["StreamHandler"] = None,
source: Optional["SourceElem"] = None,
head: Union[MappingType[str, Gst.Element], Gst.Element, None] = None,
) -> None:
# initialize GStreamer if needed
if not self._gst_init:
Gst.init(None)
......@@ -111,7 +95,7 @@ class Stream:
# set up source elem properties
self.source = source if source else None
def start(self):
def start(self) -> None:
"""Start up the pipeline.
"""
if self.source.is_live:
......@@ -136,11 +120,11 @@ class Stream:
self.mainloop.run()
@classmethod
def register_element(cls, elem_name):
def register_element(cls, elem_name: str) -> Callable[[Gst.Element], None]:
"""Register an element to the stream, making it callable.
"""
def register(func):
def wrapped(self, *srcs, **kwargs):
def register(func: Callable[..., Gst.Element]) -> None:
def attach_element(self, *srcs, **kwargs) -> "Stream":
head = func(self.pipeline, self.head, *srcs, **kwargs)
return cls(
name=self.name,
......@@ -150,19 +134,19 @@ class Stream:
source=self.source,
head=head,
)
setattr(cls, elem_name, wrapped)
setattr(cls, elem_name, attach_element)
return register
@classmethod
def from_datasource(
cls,
data_source_info,
ifos,
verbose=False,
state_vector=False,
dq_vector=False
):
is_live = data_source_info.data_source in data_source_info.live_sources
data_source_info: datasource.DataSourceInfo,
ifos: Union[str, Iterable[str]],
verbose: bool = False,
state_vector: bool = False,
dq_vector: bool = False
) -> "Stream":
is_live = data_source_info.data_source in datasource.KNOWN_LIVE_DATASOURCES
if isinstance(ifos, str):
ifos = [ifos]
keyed = False
......@@ -200,11 +184,15 @@ class Stream:
else:
return stream[ifos[0]]
def connect(self, *args, **kwargs):
def connect(self, *args, **kwargs) -> None:
self.head.connect(*args, **kwargs)
def bufsink(self, func, caps=None):
def sample_handler(elem):
def bufsink(
self,
func: Callable[[Buffer], None],
caps: Optional[Gst.Caps] = None
) -> None:
def sample_handler(elem: Gst.Element):
buf = self._pull_buffer(elem, caps=caps)
if buf:
func(buf)
......@@ -219,28 +207,28 @@ class Stream:
sink.connect("new-sample", sample_handler)
sink.connect("new-preroll", self._preroll_handler)
def add_callback(self, msg_type, *args):
def add_callback(self, msg_type: Gst.MessageType, *args) -> None:
"""
"""
self.handler.add_callback(msg_type, *args)
def set_state(self, state):
def set_state(self, state: Gst.State) -> None:
"""Set pipeline state, checking for errors.
"""
if self.pipeline.set_state(state) == Gst.StateChangeReturn.FAILURE:
raise RuntimeError(f"pipeline failed to enter {state.value_name}")
def get_element_by_name(self, name):
def get_element_by_name(self, name: str) -> Gst.Element:
return self.pipeline.get_by_name(name)
def post_message(self, msg_name, timestamp=None):
def post_message(self, msg_name: None, timestamp: Optional[int] = None) -> None:
s = Gst.Structure.new_empty(msg_name)
message = Gst.Message.new_application(self.pipeline, s)
if timestamp:
message.timestamp = timestamp
self.pipeline.get_bus().post(message)
def __getitem__(self, key):
def __getitem__(self, key: str) -> "Stream":
return self.__class__(
name=self.name,
mainloop=self.mainloop,
......@@ -250,7 +238,7 @@ class Stream:
head=self.head.setdefault(key, {}),
)
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: "Stream") -> None:
if self.pipeline:
assert self.name == value.name
assert self.mainloop is value.mainloop
......@@ -266,18 +254,18 @@ class Stream:
self.head[key] = value.head
def keys(self):
def keys(self) -> Iterable[str]:
yield from self.head.keys()
def values(self):
def values(self) -> Iterable["Stream"]:
for key in self.keys():
yield self[key]
def items(self):
def items(self) -> Iterable[Tuple[str, "Stream"]]:
for key in self.keys():
yield key, self[key]
def remap(self):
def remap(self) -> "Stream":
return self.__class__(
name=self.name,
mainloop=self.mainloop,
......@@ -287,14 +275,14 @@ class Stream:
head={},
)
def _seek_gps(self):
def _seek_gps(self) -> None:
"""Seek pipeline to the given GPS start/end times.
"""
start, end = self.source.gps_range
datasource.pipeline_seek_for_gps(self.pipeline, start, end)
@classmethod
def _pull_buffer(cls, elem, caps=None):
def _pull_buffer(cls, elem: Gst.Element, caps: Optional[Gst.Caps] = None):
# get buffer
sample = elem.emit("pull-sample")
buf = sample.get_buffer()
......@@ -335,7 +323,7 @@ class Stream:
)
@classmethod
def _load_caps_buffer_map(cls):
def _load_caps_buffer_map(cls) -> None:
bufmap = {}
# load table definitions if available
# FIXME: this is really ugly, revisit this with importlib or similar
......@@ -363,13 +351,13 @@ class Stream:
cls._caps_buffer_map = bufmap
@staticmethod
def _preroll_handler(elem):
def _preroll_handler(elem: Gst.Element) -> Gst.FlowReturn:
buf = elem.emit("pull-preroll")
del buf
return Gst.FlowReturn.OK
@classmethod
def _get_registered_elements(cls):
def _get_registered_elements(cls) -> MappingType[str, Callable[..., Gst.Element]]:
"""Get all registered GStreamer elements.
"""
# set up plugin manager
......@@ -409,7 +397,7 @@ class StreamHandler(simplehandler.Handler):
Gst.MessageType.EOS: {},
}
def add_callback(self, msg_type, *args):
def add_callback(self, msg_type: Gst.MessageType, *args):
"""
"""
# FIXME: would be better to rearrange the method signature so
......@@ -423,7 +411,7 @@ class StreamHandler(simplehandler.Handler):
raise ValueError("callback already registered for message type/name")
self.callbacks[msg_type][msg_name] = callback
def do_on_message(self, bus, message):
def do_on_message(self, bus: Gst.Bus, message: Gst.Message):
"""
"""
if message.type in self.callbacks:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment