From f043fd0f7dec3d8625b680939607f89d6ec325e4 Mon Sep 17 00:00:00 2001 From: James Kennington <james.kennington@ligo.org> Date: Mon, 19 Jul 2021 16:27:27 +0000 Subject: [PATCH] Migrate some tests to pytest --- .gitignore | 1 + .gitlab-ci.yml | 2 +- gstlal/pytest.ini | 14 ++ gstlal/tests/tests_pytest/test_drop_01.py | 32 ++++ gstlal/tests/tests_pytest/test_firbank_01.py | 83 +++++++++ .../tests/tests_pytest/utils/cmp_nxydumps.py | 168 +++++++++++++++++ gstlal/tests/tests_pytest/utils/common.py | 171 ++++++++++++++++++ 7 files changed, 470 insertions(+), 1 deletion(-) create mode 100644 gstlal/pytest.ini create mode 100644 gstlal/tests/tests_pytest/test_drop_01.py create mode 100755 gstlal/tests/tests_pytest/test_firbank_01.py create mode 100755 gstlal/tests/tests_pytest/utils/cmp_nxydumps.py create mode 100644 gstlal/tests/tests_pytest/utils/common.py diff --git a/.gitignore b/.gitignore index 4f3e09f93f..fc6484fbb4 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ configure */tests/*.trs libtool .vscode +.idea diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 66a6f721d2..04d9f8fefb 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -171,7 +171,7 @@ test:gstlal: # Run doctests - cd gstlal - - python3 -m pytest -v --doctest-modules --ignore gst/python --ignore port-tools --ignore tests --ignore share --ignore python/misc.py --ignore python/pipeparts/__init__.py --ignore python/matplotlibhelper.py --ignore python/dagfile.py --ignore python/httpinterface.py --ignore python/pipeline.py + - python3 -m pytest -c pytest.ini only: - schedules - pushes diff --git a/gstlal/pytest.ini b/gstlal/pytest.ini new file mode 100644 index 0000000000..975a5af8c5 --- /dev/null +++ b/gstlal/pytest.ini @@ -0,0 +1,14 @@ +# Configuration file for pytest within gstlal +[pytest] +norecursedirs = gst/python port-tools tests share tests_pytest/utils +testpaths = tests/tests_pytest python +addopts = + -v + --doctest-modules + # Ignore doctests in specific modules + --ignore python/dagfile.py + --ignore python/httpinterface.py + --ignore python/matplotlibhelper.py + --ignore python/misc.py + --ignore python/pipeline.py + --ignore python/pipeparts/__init__.py diff --git a/gstlal/tests/tests_pytest/test_drop_01.py b/gstlal/tests/tests_pytest/test_drop_01.py new file mode 100644 index 0000000000..857fab0e0f --- /dev/null +++ b/gstlal/tests/tests_pytest/test_drop_01.py @@ -0,0 +1,32 @@ +"""Unit tests for drop 01 + +""" + +import numpy + +from gstlal import pipeparts +from utils import cmp_nxydumps, common + + +class TestDrop01: + """Test class wrapper for drop 01""" + + def test_float_32(self): + """Test 32 bit float drop""" + drop_test_02("drop_test_02a", "float64", length=13147, drop_samples=1337, sample_fuzz=cmp_nxydumps.default_sample_fuzz) + + def test_float_64(self): + """Test 64 bit float drop""" + drop_test_02("drop_test_02a", "float64", length=13147, drop_samples=1337, sample_fuzz=cmp_nxydumps.default_sample_fuzz) + + +def drop_test_02(name, dtype, length, drop_samples, sample_fuzz=cmp_nxydumps.default_sample_fuzz): + channels_in = 1 + numpy.random.seed(0) + # check that the first array is dropped + input_array = numpy.random.random((length, channels_in)).astype(dtype) + output_reference = input_array[drop_samples:] + output_array = numpy.array(common.transform_arrays([input_array], pipeparts.mkdrop, name, drop_samples=drop_samples)) + residual = abs((output_array - output_reference)) + if residual[residual > sample_fuzz].any(): + raise ValueError("incorrect output: expected %s, got %s\ndifference = %s" % (output_reference, output_array, residual)) diff --git a/gstlal/tests/tests_pytest/test_firbank_01.py b/gstlal/tests/tests_pytest/test_firbank_01.py new file mode 100755 index 0000000000..29784dede5 --- /dev/null +++ b/gstlal/tests/tests_pytest/test_firbank_01.py @@ -0,0 +1,83 @@ +"""Unit tests for firbank + +""" + +import numpy +import pytest + +from gstlal import pipeparts +from utils import cmp_nxydumps, common + + +def firbank_test_01(pipeline, name, width, time_domain, gap_frequency): + # + # try changing these. test should still work! + # + + rate = 2048 # Hz + gap_frequency = gap_frequency # Hz + gap_threshold = 0.8 # of 1 + buffer_length = 1.0 # seconds + test_duration = 10.0 # seconds + fir_length = 21 # samples + latency = (fir_length - 1) // 2 # samples, in [0, fir_length) + + # + # build pipeline + # + + head = common.gapped_test_src(pipeline, buffer_length=buffer_length, rate=rate, width=width, test_duration=test_duration, gap_frequency=gap_frequency, + gap_threshold=gap_threshold, control_dump_filename="%s_control.dump" % name) + head = tee = pipeparts.mktee(pipeline, head) + + fir_matrix = numpy.zeros((1, fir_length), dtype="double") + fir_matrix[0, (fir_matrix.shape[1] - 1) - latency] = 1.0 + + head = pipeparts.mkfirbank(pipeline, head, fir_matrix=fir_matrix, latency=latency, time_domain=time_domain) + head = pipeparts.mkchecktimestamps(pipeline, head) + pipeparts.mknxydumpsink(pipeline, pipeparts.mkqueue(pipeline, head), "%s_out.dump" % name) + pipeparts.mknxydumpsink(pipeline, pipeparts.mkqueue(pipeline, tee), "%s_in.dump" % name) + + # + # done + # + + return pipeline + + +def firbank_test_02(pipeline, name, width, time_domain): + # 1 channel goes into firbank + head = common.test_src(pipeline, buffer_length=10.0, rate=16384, width=width, channels=1, test_duration=200.0, wave=5, verbose=True) + # 200 channels come out + head = pipeparts.mkfirbank(pipeline, head, fir_matrix=numpy.ones((200, 1)), time_domain=time_domain) + pipeparts.mkfakesink(pipeline, head) + + # + # done + # + + return pipeline + + +class TestFirbank01: + """Test class wrapper for Firbank 01 + is the firbank element an identity transform when given a unit impulse? + in and out timeseries should be identical modulo start/stop transients + """ + + FLAGS = cmp_nxydumps.COMPARE_FLAGS_EXACT_GAPS | cmp_nxydumps.COMPARE_FLAGS_ZERO_IS_GAP | cmp_nxydumps.COMPARE_FLAGS_ALLOW_STARTSTOP_MISALIGN + + @pytest.mark.parametrize('gap_frequency', (153.0, 13.0, 0.13)) + @pytest.mark.parametrize('width', (32, 64)) + @pytest.mark.parametrize('time_domain', (True, False)) + def test_firbank_01(self, gap_frequency, width, time_domain): + """Test for firbank 01""" + name = "firbank_test_01_%d%s_%.2f" % (width, ("TD" if time_domain else "FD"), gap_frequency) + common.build_and_run(firbank_test_01, name, width=width, time_domain=time_domain, gap_frequency=gap_frequency) + if width == 64: + cmp_nxydumps.compare("%s_in.dump" % name, "%s_out.dump" % name, flags=self.FLAGS) + else: + cmp_nxydumps.compare("%s_in.dump" % name, "%s_out.dump" % name, flags=self.FLAGS, sample_fuzz=1e-6) + + def test_firbank_02(self): + common.build_and_run(firbank_test_02, "firbank_test_02a", width=64, time_domain=True) diff --git a/gstlal/tests/tests_pytest/utils/cmp_nxydumps.py b/gstlal/tests/tests_pytest/utils/cmp_nxydumps.py new file mode 100755 index 0000000000..51ff82e773 --- /dev/null +++ b/gstlal/tests/tests_pytest/utils/cmp_nxydumps.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2013--2015 Kipp Cannon +# +# This program is free software; you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the +# Free Software Foundation; either version 2 of the License, or (at your +# option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General +# Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + + +import itertools + + +from ligo import segments +from lal import iterutils +from lal import LIGOTimeGPS + + +default_timestamp_fuzz = 1e-9 # seconds +default_sample_fuzz = 1e-15 # relative + + +# +# flags +# + + +# when comparing time series, require gap intervals to be identical +COMPARE_FLAGS_EXACT_GAPS = 1 +# consider samples that are all 0 also to be gaps +COMPARE_FLAGS_ZERO_IS_GAP = 2 +# don't require the two time series to start and stop at the same time +COMPARE_FLAGS_ALLOW_STARTSTOP_MISALIGN = 4 + +# the default flags for comparing time series +COMPARE_FLAGS_DEFAULT = 0 + + +# +# tools +# + + +def load_file(fobj, transients = (0.0, 0.0)): + stream = (line.strip() for line in fobj) + stream = (line.split() for line in stream if line and not line.startswith("#")) + lines = [(LIGOTimeGPS(line[0]),) + tuple(map(float, line[1:])) for line in stream] + assert lines, "no data" + channel_count_plus_1 = len(lines[0]) + assert all(len(line) == channel_count_plus_1 for line in lines), "not all lines have the same channel count" + for t1, t2 in zip((line[0] for line in lines), (line[0] for line in lines[1:])): + assert t2 > t1, "timestamps not in order @ t = %s s" % str(t2) + start = lines[0][0] + transients[0] + stop = lines[-1][0] - transients[-1] + iterutils.inplace_filter(lambda line: start <= line[0] <= stop, lines) + assert lines, "transients remove all data" + return lines + + +def max_abs_sample(lines): + # return the largest of the absolute values of the samples + return max(max(abs(x) for x in line[1:]) for line in lines) + + +def identify_gaps(lines, timestamp_fuzz = default_timestamp_fuzz, sample_fuzz = default_sample_fuzz, flags = COMPARE_FLAGS_DEFAULT): + # assume the smallest interval bewteen samples indicates the true + # sample rate, and correct for possible round-off by assuming true + # sample rate is an integer number of Hertz + dt = min(float(line1[0] - line0[0]) for line0, line1 in zip(lines, lines[1:])) + dt = 1.0 / round(1.0 / dt) + + # convert to absolute fuzz (but don't waste time with this if we + # don't need it) + if flags & COMPARE_FLAGS_ZERO_IS_GAP: + sample_fuzz *= max_abs_sample(lines) + + gaps = segments.segmentlist() + for i, line in enumerate(lines): + if i and (line[0] - lines[i - 1][0]) - dt > timestamp_fuzz * 2: + # clock skip. interpret missing timestamps as a + # gap + gaps.append(segments.segment((lines[i - 1][0] + dt, line[0]))) + if flags & COMPARE_FLAGS_ZERO_IS_GAP and all(abs(x) <= sample_fuzz for x in line[1:]): + # all samples are "0". the current sample is a gap + gaps.append(segments.segment((line[0], lines[i + 1][0] if i + 1 < len(lines) else line[0] + dt))) + return gaps.protract(timestamp_fuzz).coalesce() + + +def compare_fobjs(fobj1, fobj2, transients = (0.0, 0.0), timestamp_fuzz = default_timestamp_fuzz, sample_fuzz = default_sample_fuzz, flags = COMPARE_FLAGS_DEFAULT): + timestamp_fuzz = LIGOTimeGPS(timestamp_fuzz) + + # load dump files with transients removed + lines1 = load_file(fobj1, transients = transients) + lines2 = load_file(fobj2, transients = transients) + assert len(lines1[0]) == len(lines2[0]), "files do not have same channel count" + + # trim lead-in and lead-out if requested + if flags & COMPARE_FLAGS_ALLOW_STARTSTOP_MISALIGN: + lines1 = [line for line in lines1 if lines2[0][0] <= line[0] <= lines2[-1][0]] + assert lines1, "time intervals do not overlap" + lines2 = [line for line in lines2 if lines1[0][0] <= line[0] <= lines1[-1][0]] + assert lines2, "time intervals do not overlap" + + # construct segment lists indicating gap intervals + gaps1 = identify_gaps(lines1, timestamp_fuzz = timestamp_fuzz, sample_fuzz = sample_fuzz, flags = flags) + gaps2 = identify_gaps(lines2, timestamp_fuzz = timestamp_fuzz, sample_fuzz = sample_fuzz, flags = flags) + if flags & COMPARE_FLAGS_EXACT_GAPS: + difference = gaps1 ^ gaps2 + iterutils.inplace_filter(lambda seg: abs(seg) > timestamp_fuzz, difference) + assert not difference, "gap discrepancy: 1 ^ 2 = %s" % str(difference) + + # convert relative sample fuzz to absolute + sample_fuzz *= max_abs_sample(itertools.chain(lines1, lines2)) + + lines1 = iter(lines1) + lines2 = iter(lines2) + # guaranteeed to be at least 1 line in both lists + line1 = next(lines1) + line2 = next(lines2) + while True: + try: + if abs(line1[0] - line2[0]) <= timestamp_fuzz: + for val1, val2 in zip(line1[1:], line2[1:]): + assert abs(val1 - val2) <= sample_fuzz, "values disagree @ t = %s s" % str(line1[0]) + line1 = next(lines1) + line2 = next(lines2) + elif line1[0] < line2[0] and line1[0] in gaps2: + line1 = next(lines1) + elif line2[0] < line1[0] and line2[0] in gaps1: + line2 = next(lines2) + else: + raise AssertionError("timestamp misalignment @ %s s and %s s" % (str(line1[0]), str(line2[0]))) + except StopIteration: + break + # FIXME: should check that we're at the end of both series + + +def compare(filename1, filename2, *args, **kwargs): + try: + compare_fobjs(open(filename1), open(filename2), *args, **kwargs) + except AssertionError as e: + raise AssertionError("%s <--> %s: %s" % (filename1, filename2, str(e))) + + +# +# main() +# + + +if __name__ == "__main__": + from optparse import OptionParser + parser = OptionParser() + parser.add_option("--compare-exact-gaps", action = "store_const", const = COMPARE_FLAGS_EXACT_GAPS, default = 0) + parser.add_option("--compare-zero-is-gap", action = "store_const", const = COMPARE_FLAGS_ZERO_IS_GAP, default = 0) + parser.add_option("--compare-allow-startstop-misalign", action = "store_const", const = COMPARE_FLAGS_ALLOW_STARTSTOP_MISALIGN, default = 0) + parser.add_option("--timestamp-fuzz", metavar = "seconds", type = "float", default = default_timestamp_fuzz) + parser.add_option("--sample-fuzz", metavar = "fraction", type = "float", default = default_sample_fuzz) + options, (filename1, filename2) = parser.parse_args() + compare(filename1, filename2, timestamp_fuzz = options.timestamp_fuzz, sample_fuzz = options.sample_fuzz, flags = options.compare_exact_gaps | options.compare_zero_is_gap | options.compare_allow_startstop_misalign) diff --git a/gstlal/tests/tests_pytest/utils/common.py b/gstlal/tests/tests_pytest/utils/common.py new file mode 100644 index 0000000000..5842024b1c --- /dev/null +++ b/gstlal/tests/tests_pytest/utils/common.py @@ -0,0 +1,171 @@ +# Copyright (C) 2009--2011,2013 Kipp Cannon +# +# This program is free software; you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the +# Free Software Foundation; either version 2 of the License, or (at your +# option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General +# Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +# +# ============================================================================= +# +# Preamble +# +# ============================================================================= +# + + +import numpy +import sys + + +import gi +gi.require_version('Gst', '1.0') +from gi.repository import GObject +from gi.repository import Gst + + +from gstlal import pipeparts +from gstlal import pipeio +from gstlal import simplehandler + + +GObject.threads_init() +Gst.init(None) + + +if sys.byteorder == "little": + BYTE_ORDER = "LE" +else: + BYTE_ORDER = "BE" + + +# +# ============================================================================= +# +# Utilities +# +# ============================================================================= +# + + +def complex_test_src(pipeline, buffer_length = 1.0, rate = 2048, width = 64, test_duration = 10.0, wave = 5, freq = 0, is_live = False, verbose = True): + assert not width % 8 + samplesperbuffer = int(round(buffer_length * rate)) + head = pipeparts.mkaudiotestsrc(pipeline, wave = wave, freq = freq, volume = 1, blocksize = (width / 8 * 2) * samplesperbuffer, samplesperbuffer = samplesperbuffer, num_buffers = int(round(test_duration / buffer_length)), is_live = is_live) + head = pipeparts.mkcapsfilter(pipeline, head, "audio/x-raw, format=Z%d%s, rate=%d, channels=2" % (width, BYTE_ORDER, rate)) + head = pipeparts.mktogglecomplex(pipeline, head) + if verbose: + head = pipeparts.mkprogressreport(pipeline, head, "src") + return head + + +def test_src(pipeline, buffer_length = 1.0, rate = 2048, width = 64, channels = 1, test_duration = 10.0, wave = 5, freq = 0, is_live = False, verbose = True): + assert not width % 8 + if wave == "ligo": + head = pipeparts.mkfakeLIGOsrc(pipeline, instrument = "H1", channel_name = "LSC-STRAIN") + else: + samplesperbuffer = int(round(buffer_length * rate)) + head = pipeparts.mkaudiotestsrc(pipeline, wave = wave, freq = freq, volume = 1, blocksize = (width / 8 * channels) * samplesperbuffer, samplesperbuffer = samplesperbuffer, num_buffers = int(round(test_duration / buffer_length)), is_live = is_live) + head = pipeparts.mkcapsfilter(pipeline, head, "audio/x-raw, format=F%d%s, rate=%d, channels=%d" % (width, BYTE_ORDER, rate, channels)) + if verbose: + head = pipeparts.mkprogressreport(pipeline, head, "src") + return head + + +def add_gaps(pipeline, head, buffer_length, rate, test_duration, gap_frequency = None, gap_threshold = None, control_dump_filename = None): + if gap_frequency is None: + return head + samplesperbuffer = int(round(buffer_length * rate)) + control = pipeparts.mkcapsfilter(pipeline, pipeparts.mkaudiotestsrc(pipeline, wave = 0, freq = gap_frequency, volume = 1, blocksize = 4 * samplesperbuffer, samplesperbuffer = samplesperbuffer, num_buffers = int(round(test_duration / buffer_length))), "audio/x-raw, format=F32%s, rate=%d, channels=1" % (BYTE_ORDER, rate)) + if control_dump_filename is not None: + control = pipeparts.mknxydumpsinktee(pipeline, pipeparts.mkqueue(pipeline, control), control_dump_filename) + control = pipeparts.mkqueue(pipeline, control) + return pipeparts.mkgate(pipeline, head, control = control, threshold = gap_threshold) + + +def gapped_test_src(pipeline, buffer_length = 1.0, rate = 2048, width = 64, channels = 1, test_duration = 10.0, wave = 5, freq = 0, gap_frequency = None, gap_threshold = None, control_dump_filename = None, tags = None, is_live = False, verbose = True): + src = test_src(pipeline, buffer_length = buffer_length, rate = rate, width = width, channels = channels, test_duration = test_duration, wave = wave, freq = freq, is_live = is_live, verbose = verbose) + if tags is not None: + src = pipeparts.mktaginject(pipeline, src, tags) + return add_gaps(pipeline, src, buffer_length = buffer_length, rate = rate, test_duration = test_duration, gap_frequency = gap_frequency, gap_threshold = gap_threshold, control_dump_filename = control_dump_filename) + + +def gapped_complex_test_src(pipeline, buffer_length = 1.0, rate = 2048, width = 64, test_duration = 10.0, wave = 5, freq = 0, gap_frequency = None, gap_threshold = None, control_dump_filename = None, tags = None, is_live = False, verbose = True): + src = complex_test_src(pipeline, buffer_length = buffer_length, rate = rate, width = width, test_duration = test_duration, wave = wave, freq = freq, is_live = is_live, verbose = verbose) + if tags is not None: + src = pipeparts.mktaginject(pipeline, src, tags) + return pipeparts.mktogglecomplex(pipeline, add_gaps(pipeline, pipeparts.mktogglecomplex(pipeline, src), buffer_length = buffer_length, rate = rate, test_duration = test_duration, gap_frequency = gap_frequency, gap_threshold = gap_threshold, control_dump_filename = control_dump_filename)) + + +# +# ============================================================================= +# +# Pipeline Builder +# +# ============================================================================= +# + + +def build_and_run(pipelinefunc, name, segment = None, **pipelinefunc_kwargs): + print("=== Running Test %s ===" % name, file=sys.stderr) + mainloop = GObject.MainLoop() + pipeline = pipelinefunc(Gst.Pipeline(name = name), name, **pipelinefunc_kwargs) + handler = simplehandler.Handler(mainloop, pipeline) + if segment is not None: + if pipeline.set_state(Gst.State.PAUSED) == Gst.StateChangeReturn.FAILURE: + raise RuntimeError("pipeline failed to enter PLAYING state") + pipeline.seek(1.0, Gst.Format(Gst.Format.TIME), Gst.SeekFlags.FLUSH, Gst.SeekType.SET, segment[0].ns(), Gst.SeekType.SET, segment[1].ns()) + if pipeline.set_state(Gst.State.PLAYING) == Gst.StateChangeReturn.FAILURE: + raise RuntimeError("pipeline failed to enter PLAYING state") + mainloop.run() + + +# +# ============================================================================= +# +# Push Arrays Through an Element +# +# ============================================================================= +# + + +def transform_arrays(input_arrays, elemfunc, name, rate = 1, **elemfunc_kwargs): + input_arrays = list(input_arrays) # so we can modify it + output_arrays = [] + + pipeline = Gst.Pipeline(name = name) + + head = pipeparts.mkgeneric(pipeline, None, "appsrc", caps = pipeio.caps_from_array(input_arrays[0], rate = rate)) + def need_data(elem, arg, input_array_rate_pair): + input_arrays, rate = input_array_rate_pair + if input_arrays: + arr = input_arrays.pop(0) + elem.set_property("caps", pipeio.caps_from_array(arr, rate)) + buf = pipeio.audio_buffer_from_array(arr, 0, 0, rate) + elem.emit("push-buffer", pipeio.audio_buffer_from_array(arr, 0, 0, rate)) + return Gst.FlowReturn.OK + else: + elem.emit("end-of-stream") + return Gst.FlowReturn.EOS + head.connect("need-data", need_data, (input_arrays, rate)) + + head = elemfunc(pipeline, head, **elemfunc_kwargs) + + head = pipeparts.mkappsink(pipeline, head) + def appsink_get_array(elem, output_arrays): + output_arrays.append(pipeio.array_from_audio_sample(elem.emit("pull-sample"))) + return Gst.FlowReturn.OK + + head.connect("new-sample", appsink_get_array, output_arrays) + build_and_run((lambda *args, **kwargs: pipeline), name) + + return output_arrays -- GitLab