Skip to content
Snippets Groups Projects
Commit b212ec65 authored by kipp's avatar kipp
Browse files

implement frequency-domain convolution

parent 74b2e220
No related branches found
No related tags found
No related merge requests found
......@@ -73,18 +73,11 @@
/*
* stuff from fftw
* stuff from FFTW and GSL
*/
#include <fftw3.h>
/*
* stuff from GSL
*/
#include <gsl/gsl_vector.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_blas.h>
......@@ -121,6 +114,17 @@ static int fir_length(const GSTLALFIRBank *element)
}
/*
* return the number of time-domain samples in each FFT
*/
static int fft_block_length(const GSTLALFIRBank *element)
{
return fir_length(element) * element->block_length_factor;
}
/*
* construct a buffer of zeros and push into adapter
*/
......@@ -172,8 +176,79 @@ static guint64 get_available_samples(GSTLALFIRBank *element)
/*
* transform input samples to output samples using a purely time-domain
* algorithm
* create/free FFT workspace. call with fir_matrix_lock held
*/
static int create_fft_workspace(GSTLALFIRBank *element)
{
int i;
int length_fd = fft_block_length(element) / 2 + 1;
g_mutex_lock(gstlal_fftw_lock);
/*
* frequency-domain input
*/
element->input_fd = (complex double *) fftw_malloc(length_fd * sizeof(*element->input_fd));
element->in_plan = fftw_plan_dft_r2c_1d(fft_block_length(element), (double *) element->input_fd, element->input_fd, FFTW_MEASURE);
/*
* frequency-domain workspace
*/
element->workspace_fd = (complex double *) fftw_malloc(length_fd * sizeof(*element->workspace_fd));
element->out_plan = fftw_plan_dft_c2r_1d(fft_block_length(element), element->workspace_fd, (double *) element->workspace_fd, FFTW_MEASURE);
/*
* loop over filters. copy each time-domain filter to input_fd,
* zero-pad, transform to frequency domain, and save. the
* frequency-domain filters are pre-scaled by 1/n and conjugated to
* save those operations inside the fitlering loop.
*/
element->fir_matrix_fd = (complex double *) fftw_malloc(fir_channels(element) * length_fd * sizeof(*element->fir_matrix_fd));
for(i = 0; i < fir_channels(element); i++) {
int j;
memset(element->input_fd, 0, length_fd * sizeof(*element->input_fd));
for(j = 0; j < fir_length(element); j++)
((double *) element->input_fd)[j] = gsl_matrix_get(element->fir_matrix, i, j) / fft_block_length(element);
fftw_execute(element->in_plan);
for(j = 0; j < length_fd; j++)
element->fir_matrix_fd[i * length_fd + j] = conj(element->input_fd[j]);
}
/*
* done
*/
g_mutex_unlock(gstlal_fftw_lock);
return 0;
}
static void free_fft_workspace(GSTLALFIRBank *element)
{
g_mutex_lock(gstlal_fftw_lock);
fftw_free(element->fir_matrix_fd);
element->fir_matrix_fd = NULL;
fftw_free(element->input_fd);
element->input_fd = NULL;
fftw_destroy_plan(element->in_plan);
element->in_plan = NULL;
fftw_free(element->workspace_fd);
element->workspace_fd = NULL;
fftw_destroy_plan(element->out_plan);
element->out_plan = NULL;
g_mutex_unlock(gstlal_fftw_lock);
}
/*
* transform input samples to output samples using a time-domain algorithm
*/
......@@ -263,6 +338,151 @@ static GstFlowReturn tdfilter(GSTLALFIRBank *element, GstBuffer *outbuf)
}
/*
* transform input samples to output samples using a frequency-domain
* algorithm
*/
static GstFlowReturn fdfilter(GSTLALFIRBank *element, GstBuffer *outbuf)
{
int i;
int fft_block_stride;
int fft_blocks;
int input_length;
int output_length;
double *input;
gsl_vector_view workspace;
/*
* how many FFT blocks can we construct from the contents of the
* adapter?
*/
input_length = get_available_samples(element);
if(input_length < fft_block_length(element))
return GST_BASE_TRANSFORM_FLOW_DROPPED;
fft_block_stride = fft_block_length(element) - fir_length(element) + 1;
fft_blocks = (input_length - fft_block_length(element)) / fft_block_stride + 1;
input_length = (fft_blocks - 1) * fft_block_stride + fft_block_length(element);
output_length = input_length - fir_length(element) + 1;
/*
* retrieve input samples
*/
input = (double *) gst_adapter_peek(element->adapter, input_length * sizeof(double));
/*
* wrap workspace (as real numbers) in a GSL vector view. note
* that vector is fft_block_stride in length to affect the
* requisite transient clipping
*/
workspace = gsl_vector_view_array((double *) element->workspace_fd, fft_block_stride);
/*
* loop over FFT blocks
*/
for(i = 0; i < fft_blocks; i++) {
gsl_matrix_view output;
complex double *filter;
int j;
/*
* wrap output buffer in a GSL matrix view.
*/
output = gsl_matrix_view_array(((double *) GST_BUFFER_DATA(outbuf)) + i * fft_block_stride * fir_channels(element), fft_block_stride, fir_channels(element));
/*
* copy a block-length of data to input workspace and
* transform to frequency-domain
*/
memcpy(element->input_fd, input, fft_block_length(element) * sizeof(*input));
fftw_execute(element->in_plan);
/*
* loop over filters
*/
filter = element->fir_matrix_fd;
for(j = 0; j < fir_channels(element); j++) {
int k;
/*
* multiply input by filter, transform to
* time-domain
*/
for(k = 0; k < fft_block_length(element) / 2 + 1; k++)
element->workspace_fd[k] = element->input_fd[k] * *(filter++);
fftw_execute(element->out_plan);
/*
* copy to output
*/
gsl_matrix_set_col(&output.matrix, j, &workspace.vector);
}
/*
* advance to next FFT block
*/
input += fft_block_stride;
}
/*
* flush the data from the adapter
*/
gst_adapter_flush(element->adapter, output_length * sizeof(double));
if(output_length > input_length - element->zeros_in_adapter)
/*
* some trailing zeros have been flushed from the adapter
*/
element->zeros_in_adapter -= output_length - (input_length - element->zeros_in_adapter);
/*
* set buffer metadata
*/
set_metadata(element, outbuf, output_length);
/*
* done
*/
return GST_FLOW_OK;
}
/*
* select a filtering algorithm
*/
static GstFlowReturn filter(GSTLALFIRBank *element, GstBuffer *outbuf)
{
#if 0
return tdfilter(element, outbuf);
#else
/*
* get frequency-domain filters if needed
*/
if(!element->fir_matrix_fd)
create_fft_workspace(element);
return fdfilter(element, outbuf);
#endif
}
/*
* ============================================================================
*
......@@ -276,12 +496,11 @@ static GstStaticPadTemplate sink_factory = GST_STATIC_PAD_TEMPLATE(
"sink",
GST_PAD_SINK,
GST_PAD_ALWAYS,
/* FIXME: BYTEORDER */
GST_STATIC_CAPS(
"audio/x-raw-float, " \
"rate = (int) [1, MAX], " \
"channels = (int) 1, " \
"endianness = (int) 1234, " \
"endianness = (int) BYTE_ORDER, " \
"width = (int) 64"
)
);
......@@ -291,12 +510,11 @@ static GstStaticPadTemplate src_factory = GST_STATIC_PAD_TEMPLATE(
"src",
GST_PAD_SRC,
GST_PAD_ALWAYS,
/* FIXME: BYTEORDER */
GST_STATIC_CAPS(
"audio/x-raw-float, " \
"rate = (int) [1, MAX], " \
"channels = (int) [1, MAX], " \
"endianness = (int) 1234, " \
"endianness = (int) BYTE_ORDER, " \
"width = (int) 64"
)
);
......@@ -536,7 +754,7 @@ static GstFlowReturn transform(GstBaseTransform *trans, GstBuffer *inbuf, GstBuf
gst_buffer_ref(inbuf); /* don't let the adapter free it */
gst_adapter_push(element->adapter, inbuf);
element->zeros_in_adapter = 0;
result = tdfilter(element, outbuf);
result = filter(element, outbuf);
} else if(element->zeros_in_adapter >= fir_length(element) - 1) {
/*
* input is 0s and we are past the tail of the impulse
......@@ -559,7 +777,7 @@ static GstFlowReturn transform(GstBaseTransform *trans, GstBuffer *inbuf, GstBuf
*/
push_zeros(element, length);
result = tdfilter(element, outbuf);
result = filter(element, outbuf);
} else {
/*
* input is 0s, we are not yet past the tail of the impulse
......@@ -586,7 +804,7 @@ static GstFlowReturn transform(GstBaseTransform *trans, GstBuffer *inbuf, GstBuf
result = gst_pad_alloc_buffer(srcpad, element->next_out_offset, available_samples * fir_channels(element) * sizeof(double), GST_PAD_CAPS(srcpad), &buf);
if(result != GST_FLOW_OK)
goto done;
result = tdfilter(element, buf);
result = filter(element, buf);
g_assert(result == GST_FLOW_OK);
result = gst_pad_push(srcpad, buf);
if(result != GST_FLOW_OK)
......@@ -635,7 +853,15 @@ static void set_property(GObject *object, enum property prop_id, const GValue *v
switch (prop_id) {
case ARG_BLOCK_LENGTH_FACTOR:
g_mutex_lock(element->fir_matrix_lock);
element->block_length_factor = g_value_get_int(value);
/*
* invalidate frequency-domain filters
*/
free_fft_workspace(element);
g_mutex_unlock(element->fir_matrix_lock);
break;
case ARG_FIR_MATRIX: {
......@@ -653,6 +879,17 @@ static void set_property(GObject *object, enum property prop_id, const GValue *v
* renegotiation
*/
gst_pad_set_caps(GST_BASE_TRANSFORM_SRC_PAD(GST_BASE_TRANSFORM(object)), NULL);
/*
* invalidate frequency-domain filters
*/
free_fft_workspace(element);
/*
* signal availability of new time-domain filters
*/
g_cond_signal(element->fir_matrix_available);
g_mutex_unlock(element->fir_matrix_lock);
break;
......@@ -725,10 +962,8 @@ static void finalize(GObject *object)
gsl_matrix_free(element->fir_matrix);
element->fir_matrix = NULL;
}
if(element->fir_matrix_fd) {
gsl_matrix_complex_free(element->fir_matrix_fd);
element->fir_matrix_fd = NULL;
}
free_fft_workspace(element);
G_OBJECT_CLASS(parent_class)->finalize(object);
}
......@@ -780,9 +1015,9 @@ static void gstlal_firbank_class_init(GSTLALFIRBankClass *klass)
ARG_BLOCK_LENGTH_FACTOR,
g_param_spec_int(
"block-length-factor",
"Convolution block size in multiples of the FIR size",
"Convolution block size in multiples of the FIR length",
"When using FFT convolutions, use this many times the number of samples in each FIR vector for the convolution block size.",
1, G_MAXINT, DEFAULT_BLOCK_LENGTH_FACTOR,
2, G_MAXINT, DEFAULT_BLOCK_LENGTH_FACTOR,
G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS
)
);
......@@ -837,5 +1072,9 @@ static void gstlal_firbank_init(GSTLALFIRBank *filter, GSTLALFIRBankClass *kclas
filter->fir_matrix_available = g_cond_new();
filter->fir_matrix = NULL;
filter->fir_matrix_fd = NULL;
filter->input_fd = NULL;
filter->workspace_fd = NULL;
filter->in_plan = NULL;
filter->out_plan = NULL;
gst_base_transform_set_gap_aware(GST_BASE_TRANSFORM(filter), TRUE);
}
......@@ -45,12 +45,16 @@
#define __GST_LAL_FIRBANK_H__
#include <complex.h>
#include <glib.h>
#include <gst/gst.h>
#include <gst/base/gstadapter.h>
#include <gst/base/gstbasetransform.h>
#include <fftw3.h>
#include <gsl/gsl_matrix.h>
......@@ -75,18 +79,37 @@ typedef struct {
typedef struct {
GstBaseTransform element;
gint block_length_factor;
gint64 latency;
/*
* input stream
*/
gint rate;
GstAdapter *adapter;
gint zeros_in_adapter;
/*
* filter info
*/
GMutex *fir_matrix_lock;
GCond *fir_matrix_available;
gsl_matrix *fir_matrix;
gsl_matrix_complex *fir_matrix_fd;
gint64 latency;
/*
* FFT work space
*/
gint block_length_factor;
complex double *fir_matrix_fd;
complex double *input_fd;
complex double *workspace_fd;
fftw_plan in_plan;
fftw_plan out_plan;
/*
* timestamp book-keeping
*/
GstClockTime t0;
guint64 offset0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment