Skip to content
Snippets Groups Projects
Commit f043fd0f authored by James Kennington's avatar James Kennington
Browse files

Migrate some tests to pytest

parent 308fd61a
No related branches found
No related tags found
1 merge request!44Migrate some tests to pytest
Pipeline #259040 passed
...@@ -32,3 +32,4 @@ configure ...@@ -32,3 +32,4 @@ configure
*/tests/*.trs */tests/*.trs
libtool libtool
.vscode .vscode
.idea
...@@ -171,7 +171,7 @@ test:gstlal: ...@@ -171,7 +171,7 @@ test:gstlal:
# Run doctests # Run doctests
- cd gstlal - cd gstlal
- python3 -m pytest -v --doctest-modules --ignore gst/python --ignore port-tools --ignore tests --ignore share --ignore python/misc.py --ignore python/pipeparts/__init__.py --ignore python/matplotlibhelper.py --ignore python/dagfile.py --ignore python/httpinterface.py --ignore python/pipeline.py - python3 -m pytest -c pytest.ini
only: only:
- schedules - schedules
- pushes - pushes
......
# Configuration file for pytest within gstlal
[pytest]
norecursedirs = gst/python port-tools tests share tests_pytest/utils
testpaths = tests/tests_pytest python
addopts =
-v
--doctest-modules
# Ignore doctests in specific modules
--ignore python/dagfile.py
--ignore python/httpinterface.py
--ignore python/matplotlibhelper.py
--ignore python/misc.py
--ignore python/pipeline.py
--ignore python/pipeparts/__init__.py
"""Unit tests for drop 01
"""
import numpy
from gstlal import pipeparts
from utils import cmp_nxydumps, common
class TestDrop01:
"""Test class wrapper for drop 01"""
def test_float_32(self):
"""Test 32 bit float drop"""
drop_test_02("drop_test_02a", "float64", length=13147, drop_samples=1337, sample_fuzz=cmp_nxydumps.default_sample_fuzz)
def test_float_64(self):
"""Test 64 bit float drop"""
drop_test_02("drop_test_02a", "float64", length=13147, drop_samples=1337, sample_fuzz=cmp_nxydumps.default_sample_fuzz)
def drop_test_02(name, dtype, length, drop_samples, sample_fuzz=cmp_nxydumps.default_sample_fuzz):
channels_in = 1
numpy.random.seed(0)
# check that the first array is dropped
input_array = numpy.random.random((length, channels_in)).astype(dtype)
output_reference = input_array[drop_samples:]
output_array = numpy.array(common.transform_arrays([input_array], pipeparts.mkdrop, name, drop_samples=drop_samples))
residual = abs((output_array - output_reference))
if residual[residual > sample_fuzz].any():
raise ValueError("incorrect output: expected %s, got %s\ndifference = %s" % (output_reference, output_array, residual))
"""Unit tests for firbank
"""
import numpy
import pytest
from gstlal import pipeparts
from utils import cmp_nxydumps, common
def firbank_test_01(pipeline, name, width, time_domain, gap_frequency):
#
# try changing these. test should still work!
#
rate = 2048 # Hz
gap_frequency = gap_frequency # Hz
gap_threshold = 0.8 # of 1
buffer_length = 1.0 # seconds
test_duration = 10.0 # seconds
fir_length = 21 # samples
latency = (fir_length - 1) // 2 # samples, in [0, fir_length)
#
# build pipeline
#
head = common.gapped_test_src(pipeline, buffer_length=buffer_length, rate=rate, width=width, test_duration=test_duration, gap_frequency=gap_frequency,
gap_threshold=gap_threshold, control_dump_filename="%s_control.dump" % name)
head = tee = pipeparts.mktee(pipeline, head)
fir_matrix = numpy.zeros((1, fir_length), dtype="double")
fir_matrix[0, (fir_matrix.shape[1] - 1) - latency] = 1.0
head = pipeparts.mkfirbank(pipeline, head, fir_matrix=fir_matrix, latency=latency, time_domain=time_domain)
head = pipeparts.mkchecktimestamps(pipeline, head)
pipeparts.mknxydumpsink(pipeline, pipeparts.mkqueue(pipeline, head), "%s_out.dump" % name)
pipeparts.mknxydumpsink(pipeline, pipeparts.mkqueue(pipeline, tee), "%s_in.dump" % name)
#
# done
#
return pipeline
def firbank_test_02(pipeline, name, width, time_domain):
# 1 channel goes into firbank
head = common.test_src(pipeline, buffer_length=10.0, rate=16384, width=width, channels=1, test_duration=200.0, wave=5, verbose=True)
# 200 channels come out
head = pipeparts.mkfirbank(pipeline, head, fir_matrix=numpy.ones((200, 1)), time_domain=time_domain)
pipeparts.mkfakesink(pipeline, head)
#
# done
#
return pipeline
class TestFirbank01:
"""Test class wrapper for Firbank 01
is the firbank element an identity transform when given a unit impulse?
in and out timeseries should be identical modulo start/stop transients
"""
FLAGS = cmp_nxydumps.COMPARE_FLAGS_EXACT_GAPS | cmp_nxydumps.COMPARE_FLAGS_ZERO_IS_GAP | cmp_nxydumps.COMPARE_FLAGS_ALLOW_STARTSTOP_MISALIGN
@pytest.mark.parametrize('gap_frequency', (153.0, 13.0, 0.13))
@pytest.mark.parametrize('width', (32, 64))
@pytest.mark.parametrize('time_domain', (True, False))
def test_firbank_01(self, gap_frequency, width, time_domain):
"""Test for firbank 01"""
name = "firbank_test_01_%d%s_%.2f" % (width, ("TD" if time_domain else "FD"), gap_frequency)
common.build_and_run(firbank_test_01, name, width=width, time_domain=time_domain, gap_frequency=gap_frequency)
if width == 64:
cmp_nxydumps.compare("%s_in.dump" % name, "%s_out.dump" % name, flags=self.FLAGS)
else:
cmp_nxydumps.compare("%s_in.dump" % name, "%s_out.dump" % name, flags=self.FLAGS, sample_fuzz=1e-6)
def test_firbank_02(self):
common.build_and_run(firbank_test_02, "firbank_test_02a", width=64, time_domain=True)
#!/usr/bin/env python3
#
# Copyright (C) 2013--2015 Kipp Cannon
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import itertools
from ligo import segments
from lal import iterutils
from lal import LIGOTimeGPS
default_timestamp_fuzz = 1e-9 # seconds
default_sample_fuzz = 1e-15 # relative
#
# flags
#
# when comparing time series, require gap intervals to be identical
COMPARE_FLAGS_EXACT_GAPS = 1
# consider samples that are all 0 also to be gaps
COMPARE_FLAGS_ZERO_IS_GAP = 2
# don't require the two time series to start and stop at the same time
COMPARE_FLAGS_ALLOW_STARTSTOP_MISALIGN = 4
# the default flags for comparing time series
COMPARE_FLAGS_DEFAULT = 0
#
# tools
#
def load_file(fobj, transients = (0.0, 0.0)):
stream = (line.strip() for line in fobj)
stream = (line.split() for line in stream if line and not line.startswith("#"))
lines = [(LIGOTimeGPS(line[0]),) + tuple(map(float, line[1:])) for line in stream]
assert lines, "no data"
channel_count_plus_1 = len(lines[0])
assert all(len(line) == channel_count_plus_1 for line in lines), "not all lines have the same channel count"
for t1, t2 in zip((line[0] for line in lines), (line[0] for line in lines[1:])):
assert t2 > t1, "timestamps not in order @ t = %s s" % str(t2)
start = lines[0][0] + transients[0]
stop = lines[-1][0] - transients[-1]
iterutils.inplace_filter(lambda line: start <= line[0] <= stop, lines)
assert lines, "transients remove all data"
return lines
def max_abs_sample(lines):
# return the largest of the absolute values of the samples
return max(max(abs(x) for x in line[1:]) for line in lines)
def identify_gaps(lines, timestamp_fuzz = default_timestamp_fuzz, sample_fuzz = default_sample_fuzz, flags = COMPARE_FLAGS_DEFAULT):
# assume the smallest interval bewteen samples indicates the true
# sample rate, and correct for possible round-off by assuming true
# sample rate is an integer number of Hertz
dt = min(float(line1[0] - line0[0]) for line0, line1 in zip(lines, lines[1:]))
dt = 1.0 / round(1.0 / dt)
# convert to absolute fuzz (but don't waste time with this if we
# don't need it)
if flags & COMPARE_FLAGS_ZERO_IS_GAP:
sample_fuzz *= max_abs_sample(lines)
gaps = segments.segmentlist()
for i, line in enumerate(lines):
if i and (line[0] - lines[i - 1][0]) - dt > timestamp_fuzz * 2:
# clock skip. interpret missing timestamps as a
# gap
gaps.append(segments.segment((lines[i - 1][0] + dt, line[0])))
if flags & COMPARE_FLAGS_ZERO_IS_GAP and all(abs(x) <= sample_fuzz for x in line[1:]):
# all samples are "0". the current sample is a gap
gaps.append(segments.segment((line[0], lines[i + 1][0] if i + 1 < len(lines) else line[0] + dt)))
return gaps.protract(timestamp_fuzz).coalesce()
def compare_fobjs(fobj1, fobj2, transients = (0.0, 0.0), timestamp_fuzz = default_timestamp_fuzz, sample_fuzz = default_sample_fuzz, flags = COMPARE_FLAGS_DEFAULT):
timestamp_fuzz = LIGOTimeGPS(timestamp_fuzz)
# load dump files with transients removed
lines1 = load_file(fobj1, transients = transients)
lines2 = load_file(fobj2, transients = transients)
assert len(lines1[0]) == len(lines2[0]), "files do not have same channel count"
# trim lead-in and lead-out if requested
if flags & COMPARE_FLAGS_ALLOW_STARTSTOP_MISALIGN:
lines1 = [line for line in lines1 if lines2[0][0] <= line[0] <= lines2[-1][0]]
assert lines1, "time intervals do not overlap"
lines2 = [line for line in lines2 if lines1[0][0] <= line[0] <= lines1[-1][0]]
assert lines2, "time intervals do not overlap"
# construct segment lists indicating gap intervals
gaps1 = identify_gaps(lines1, timestamp_fuzz = timestamp_fuzz, sample_fuzz = sample_fuzz, flags = flags)
gaps2 = identify_gaps(lines2, timestamp_fuzz = timestamp_fuzz, sample_fuzz = sample_fuzz, flags = flags)
if flags & COMPARE_FLAGS_EXACT_GAPS:
difference = gaps1 ^ gaps2
iterutils.inplace_filter(lambda seg: abs(seg) > timestamp_fuzz, difference)
assert not difference, "gap discrepancy: 1 ^ 2 = %s" % str(difference)
# convert relative sample fuzz to absolute
sample_fuzz *= max_abs_sample(itertools.chain(lines1, lines2))
lines1 = iter(lines1)
lines2 = iter(lines2)
# guaranteeed to be at least 1 line in both lists
line1 = next(lines1)
line2 = next(lines2)
while True:
try:
if abs(line1[0] - line2[0]) <= timestamp_fuzz:
for val1, val2 in zip(line1[1:], line2[1:]):
assert abs(val1 - val2) <= sample_fuzz, "values disagree @ t = %s s" % str(line1[0])
line1 = next(lines1)
line2 = next(lines2)
elif line1[0] < line2[0] and line1[0] in gaps2:
line1 = next(lines1)
elif line2[0] < line1[0] and line2[0] in gaps1:
line2 = next(lines2)
else:
raise AssertionError("timestamp misalignment @ %s s and %s s" % (str(line1[0]), str(line2[0])))
except StopIteration:
break
# FIXME: should check that we're at the end of both series
def compare(filename1, filename2, *args, **kwargs):
try:
compare_fobjs(open(filename1), open(filename2), *args, **kwargs)
except AssertionError as e:
raise AssertionError("%s <--> %s: %s" % (filename1, filename2, str(e)))
#
# main()
#
if __name__ == "__main__":
from optparse import OptionParser
parser = OptionParser()
parser.add_option("--compare-exact-gaps", action = "store_const", const = COMPARE_FLAGS_EXACT_GAPS, default = 0)
parser.add_option("--compare-zero-is-gap", action = "store_const", const = COMPARE_FLAGS_ZERO_IS_GAP, default = 0)
parser.add_option("--compare-allow-startstop-misalign", action = "store_const", const = COMPARE_FLAGS_ALLOW_STARTSTOP_MISALIGN, default = 0)
parser.add_option("--timestamp-fuzz", metavar = "seconds", type = "float", default = default_timestamp_fuzz)
parser.add_option("--sample-fuzz", metavar = "fraction", type = "float", default = default_sample_fuzz)
options, (filename1, filename2) = parser.parse_args()
compare(filename1, filename2, timestamp_fuzz = options.timestamp_fuzz, sample_fuzz = options.sample_fuzz, flags = options.compare_exact_gaps | options.compare_zero_is_gap | options.compare_allow_startstop_misalign)
# Copyright (C) 2009--2011,2013 Kipp Cannon
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# =============================================================================
#
# Preamble
#
# =============================================================================
#
import numpy
import sys
import gi
gi.require_version('Gst', '1.0')
from gi.repository import GObject
from gi.repository import Gst
from gstlal import pipeparts
from gstlal import pipeio
from gstlal import simplehandler
GObject.threads_init()
Gst.init(None)
if sys.byteorder == "little":
BYTE_ORDER = "LE"
else:
BYTE_ORDER = "BE"
#
# =============================================================================
#
# Utilities
#
# =============================================================================
#
def complex_test_src(pipeline, buffer_length = 1.0, rate = 2048, width = 64, test_duration = 10.0, wave = 5, freq = 0, is_live = False, verbose = True):
assert not width % 8
samplesperbuffer = int(round(buffer_length * rate))
head = pipeparts.mkaudiotestsrc(pipeline, wave = wave, freq = freq, volume = 1, blocksize = (width / 8 * 2) * samplesperbuffer, samplesperbuffer = samplesperbuffer, num_buffers = int(round(test_duration / buffer_length)), is_live = is_live)
head = pipeparts.mkcapsfilter(pipeline, head, "audio/x-raw, format=Z%d%s, rate=%d, channels=2" % (width, BYTE_ORDER, rate))
head = pipeparts.mktogglecomplex(pipeline, head)
if verbose:
head = pipeparts.mkprogressreport(pipeline, head, "src")
return head
def test_src(pipeline, buffer_length = 1.0, rate = 2048, width = 64, channels = 1, test_duration = 10.0, wave = 5, freq = 0, is_live = False, verbose = True):
assert not width % 8
if wave == "ligo":
head = pipeparts.mkfakeLIGOsrc(pipeline, instrument = "H1", channel_name = "LSC-STRAIN")
else:
samplesperbuffer = int(round(buffer_length * rate))
head = pipeparts.mkaudiotestsrc(pipeline, wave = wave, freq = freq, volume = 1, blocksize = (width / 8 * channels) * samplesperbuffer, samplesperbuffer = samplesperbuffer, num_buffers = int(round(test_duration / buffer_length)), is_live = is_live)
head = pipeparts.mkcapsfilter(pipeline, head, "audio/x-raw, format=F%d%s, rate=%d, channels=%d" % (width, BYTE_ORDER, rate, channels))
if verbose:
head = pipeparts.mkprogressreport(pipeline, head, "src")
return head
def add_gaps(pipeline, head, buffer_length, rate, test_duration, gap_frequency = None, gap_threshold = None, control_dump_filename = None):
if gap_frequency is None:
return head
samplesperbuffer = int(round(buffer_length * rate))
control = pipeparts.mkcapsfilter(pipeline, pipeparts.mkaudiotestsrc(pipeline, wave = 0, freq = gap_frequency, volume = 1, blocksize = 4 * samplesperbuffer, samplesperbuffer = samplesperbuffer, num_buffers = int(round(test_duration / buffer_length))), "audio/x-raw, format=F32%s, rate=%d, channels=1" % (BYTE_ORDER, rate))
if control_dump_filename is not None:
control = pipeparts.mknxydumpsinktee(pipeline, pipeparts.mkqueue(pipeline, control), control_dump_filename)
control = pipeparts.mkqueue(pipeline, control)
return pipeparts.mkgate(pipeline, head, control = control, threshold = gap_threshold)
def gapped_test_src(pipeline, buffer_length = 1.0, rate = 2048, width = 64, channels = 1, test_duration = 10.0, wave = 5, freq = 0, gap_frequency = None, gap_threshold = None, control_dump_filename = None, tags = None, is_live = False, verbose = True):
src = test_src(pipeline, buffer_length = buffer_length, rate = rate, width = width, channels = channels, test_duration = test_duration, wave = wave, freq = freq, is_live = is_live, verbose = verbose)
if tags is not None:
src = pipeparts.mktaginject(pipeline, src, tags)
return add_gaps(pipeline, src, buffer_length = buffer_length, rate = rate, test_duration = test_duration, gap_frequency = gap_frequency, gap_threshold = gap_threshold, control_dump_filename = control_dump_filename)
def gapped_complex_test_src(pipeline, buffer_length = 1.0, rate = 2048, width = 64, test_duration = 10.0, wave = 5, freq = 0, gap_frequency = None, gap_threshold = None, control_dump_filename = None, tags = None, is_live = False, verbose = True):
src = complex_test_src(pipeline, buffer_length = buffer_length, rate = rate, width = width, test_duration = test_duration, wave = wave, freq = freq, is_live = is_live, verbose = verbose)
if tags is not None:
src = pipeparts.mktaginject(pipeline, src, tags)
return pipeparts.mktogglecomplex(pipeline, add_gaps(pipeline, pipeparts.mktogglecomplex(pipeline, src), buffer_length = buffer_length, rate = rate, test_duration = test_duration, gap_frequency = gap_frequency, gap_threshold = gap_threshold, control_dump_filename = control_dump_filename))
#
# =============================================================================
#
# Pipeline Builder
#
# =============================================================================
#
def build_and_run(pipelinefunc, name, segment = None, **pipelinefunc_kwargs):
print("=== Running Test %s ===" % name, file=sys.stderr)
mainloop = GObject.MainLoop()
pipeline = pipelinefunc(Gst.Pipeline(name = name), name, **pipelinefunc_kwargs)
handler = simplehandler.Handler(mainloop, pipeline)
if segment is not None:
if pipeline.set_state(Gst.State.PAUSED) == Gst.StateChangeReturn.FAILURE:
raise RuntimeError("pipeline failed to enter PLAYING state")
pipeline.seek(1.0, Gst.Format(Gst.Format.TIME), Gst.SeekFlags.FLUSH, Gst.SeekType.SET, segment[0].ns(), Gst.SeekType.SET, segment[1].ns())
if pipeline.set_state(Gst.State.PLAYING) == Gst.StateChangeReturn.FAILURE:
raise RuntimeError("pipeline failed to enter PLAYING state")
mainloop.run()
#
# =============================================================================
#
# Push Arrays Through an Element
#
# =============================================================================
#
def transform_arrays(input_arrays, elemfunc, name, rate = 1, **elemfunc_kwargs):
input_arrays = list(input_arrays) # so we can modify it
output_arrays = []
pipeline = Gst.Pipeline(name = name)
head = pipeparts.mkgeneric(pipeline, None, "appsrc", caps = pipeio.caps_from_array(input_arrays[0], rate = rate))
def need_data(elem, arg, input_array_rate_pair):
input_arrays, rate = input_array_rate_pair
if input_arrays:
arr = input_arrays.pop(0)
elem.set_property("caps", pipeio.caps_from_array(arr, rate))
buf = pipeio.audio_buffer_from_array(arr, 0, 0, rate)
elem.emit("push-buffer", pipeio.audio_buffer_from_array(arr, 0, 0, rate))
return Gst.FlowReturn.OK
else:
elem.emit("end-of-stream")
return Gst.FlowReturn.EOS
head.connect("need-data", need_data, (input_arrays, rate))
head = elemfunc(pipeline, head, **elemfunc_kwargs)
head = pipeparts.mkappsink(pipeline, head)
def appsink_get_array(elem, output_arrays):
output_arrays.append(pipeio.array_from_audio_sample(elem.emit("pull-sample")))
return Gst.FlowReturn.OK
head.connect("new-sample", appsink_get_array, output_arrays)
build_and_run((lambda *args, **kwargs: pipeline), name)
return output_arrays
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment