Skip to content
Snippets Groups Projects
Commit 756e9897 authored by Patrick Godwin
Browse files

Merge branch 'schema_metadata' into 'main'

Embed per-channel sample rate metadata in stream schema

Closes #16

See merge request ngdd/arrakis-server!40
parents 1894cddc 949e9307
Branches main
No related tags found
No related merge requests found
Pipeline #686570 passed
......@@ -9,19 +9,17 @@
from collections.abc import Iterable
import numpy
import pyarrow
from arrakis import Channel
def stream(channels: Iterable[str], dtypes: Iterable[str]) -> pyarrow.Schema:
def stream(channels: Iterable[Channel]) -> pyarrow.Schema:
"""Create an Arrow Flight schema for `stream`.
Parameters
----------
channels : Iterable[str]
The list of channels for the fetch request.
dtypes : Iterable[str]
The list of data types, one per channel.
channels : Iterable[Channel]
The list of channels for the stream request.
Returns
-------
......@@ -30,9 +28,12 @@ def stream(channels: Iterable[str], dtypes: Iterable[str]) -> pyarrow.Schema:
"""
columns = [pyarrow.field("time", pyarrow.int64(), nullable=False)]
pyarrow_dtypes = [pyarrow.from_numpy_dtype(numpy.dtype(dtype)) for dtype in dtypes]
for channel, dtype in zip(channels, pyarrow_dtypes):
columns.append(pyarrow.field(channel, pyarrow.list_(dtype), nullable=False))
for channel in channels:
dtype = pyarrow.from_numpy_dtype(channel.data_type)
field = pyarrow.field(channel.name, pyarrow.list_(dtype)).with_metadata(
{"rate": str(channel.sample_rate)}
)
columns.append(field)
return pyarrow.schema(columns)
......
......@@ -317,12 +317,9 @@ class ArrakisFlightServer(flight.FlightServerBase):
match request:
case RequestType.Stream:
# FIXME: this should query the backend to map channel
# metadata to their corresponding data types. does
# this imply that a fetch request requires a
# find_channels request?
dtypes = ["float32" for _ in args["channels"]]
schema = schemas.stream(args["channels"], dtypes)
assert isinstance(self._backend, traits.ServerBackend)
metadata = self._backend.describe(channels=args["channels"])
schema = schemas.stream(metadata)
case RequestType.Describe:
schema = schemas.describe()
case RequestType.Find:
......@@ -356,8 +353,7 @@ class ArrakisFlightServer(flight.FlightServerBase):
# extract first block to determine schema
block = next(blocks)
metadata = [block.channels[name] for name in channels]
dtypes = [channel.data_type.name for channel in metadata]
schema = schemas.stream(channels, dtypes)
schema = schemas.stream(metadata)
# stitch back together and convert to record batches
blocks = itertools.chain([block], blocks)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment