From 9868f2de1c800de5cc403b1cc696d56242613188 Mon Sep 17 00:00:00 2001 From: Aaron Viets <aaron.viets@ligo.org> Date: Mon, 4 Feb 2019 13:57:18 -0800 Subject: [PATCH] lal_matrixsolver: bug fixes - now it runs. --- .../gst/lal/gstlal_matrixsolver.c | 270 +++++++++++++----- .../gst/lal/gstlal_matrixsolver.h | 15 + 2 files changed, 216 insertions(+), 69 deletions(-) diff --git a/gstlal-calibration/gst/lal/gstlal_matrixsolver.c b/gstlal-calibration/gst/lal/gstlal_matrixsolver.c index 110e2adf42..d77dc11f6a 100644 --- a/gstlal-calibration/gst/lal/gstlal_matrixsolver.c +++ b/gstlal-calibration/gst/lal/gstlal_matrixsolver.c @@ -78,27 +78,28 @@ */ -#define GST_CAT_DEFAULT gstlal_matrixsolver_debug -GST_DEBUG_CATEGORY_STATIC(GST_CAT_DEFAULT); - - -G_DEFINE_TYPE_WITH_CODE( - GSTLALMatrixSolver, - gstlal_matrixsolver, - GST_TYPE_BASE_TRANSFORM, - GST_DEBUG_CATEGORY_INIT(GST_CAT_DEFAULT, "lal_matrixsolver", 0, "lal_matrixsolver element") -); +#define INCAPS \ + "audio/x-raw, " \ + "format = (string) {"GST_AUDIO_NE(F32)", "GST_AUDIO_NE(F64)", "GST_AUDIO_NE(Z64)", "GST_AUDIO_NE(Z128)"}, " \ + "rate = " GST_AUDIO_RATE_RANGE ", " \ + "channels = (int) [1, MAX], " \ + "layout = (string) interleaved, " \ + "channel-mask = (bitmask) 0" + +#define OUTCAPS \ + "audio/x-raw, " \ + "format = (string) {"GST_AUDIO_NE(F32)", "GST_AUDIO_NE(F64)", "GST_AUDIO_NE(Z64)", "GST_AUDIO_NE(Z128)"}, " \ + "rate = " GST_AUDIO_RATE_RANGE ", " \ + "channels = (int) [1, MAX], " \ + "layout = (string) interleaved, " \ + "channel-mask = (bitmask) 0" static GstStaticPadTemplate sink_factory = GST_STATIC_PAD_TEMPLATE( GST_BASE_TRANSFORM_SINK_NAME, GST_PAD_SINK, GST_PAD_ALWAYS, - GST_STATIC_CAPS( - GST_AUDIO_CAPS_MAKE("{" GST_AUDIO_NE(F32) ", " GST_AUDIO_NE(F64) ", " GST_AUDIO_NE(Z64) ", " GST_AUDIO_NE(Z128) "}") ", " \ - "layout = (string) interleaved, " \ - "channel-mask = (bitmask) 0" - ) + GST_STATIC_CAPS(INCAPS) ); @@ -106,11 +107,14 @@ static GstStaticPadTemplate src_factory = GST_STATIC_PAD_TEMPLATE( GST_BASE_TRANSFORM_SRC_NAME, GST_PAD_SRC, GST_PAD_ALWAYS, - GST_STATIC_CAPS( - GST_AUDIO_CAPS_MAKE("{" GST_AUDIO_NE(F32) ", " GST_AUDIO_NE(F64) ", " GST_AUDIO_NE(Z64) ", " GST_AUDIO_NE(Z128) "}") ", " \ - "layout = (string) interleaved, " \ - "channel-mask = (bitmask) 0" - ) + GST_STATIC_CAPS(OUTCAPS) +); + + +G_DEFINE_TYPE( + GSTLALMatrixSolver, + gstlal_matrixsolver, + GST_TYPE_BASE_TRANSFORM ); @@ -179,14 +183,10 @@ static complex double get_complexdouble_from_gsl_vector(gsl_vector_complex *vec, #define DEFINE_SOLVE_SYSTEM(COMPLEX, DTYPE, UNDERSCORE) \ -static void solve_system_ ## COMPLEX ## DTYPE(const COMPLEX DTYPE *src, COMPLEX DTYPE *dst, guint64 dst_size, int channels_in, int channels_out) { \ +static void solve_system_ ## COMPLEX ## DTYPE(const COMPLEX DTYPE *src, COMPLEX DTYPE *dst, guint64 dst_size, int channels_in, int channels_out, gsl_vector ## UNDERSCORE ## COMPLEX *invec, gsl_vector ## UNDERSCORE ## COMPLEX *outvec, gsl_matrix ## UNDERSCORE ## COMPLEX *matrix, gsl_permutation *permutation) { \ \ guint64 i; \ int j, signum; \ - gsl_vector ## UNDERSCORE ## COMPLEX *invec = gsl_vector ## UNDERSCORE ## COMPLEX ## _alloc(channels_out); \ - gsl_vector ## UNDERSCORE ## COMPLEX *outvec = gsl_vector ## UNDERSCORE ## COMPLEX ## _alloc(channels_out); \ - gsl_matrix ## UNDERSCORE ## COMPLEX *matrix = gsl_matrix ## UNDERSCORE ## COMPLEX ## _alloc(channels_out, channels_out); \ - gsl_permutation *permutation = gsl_permutation_alloc(channels_out); \ \ for(i = 0; i < dst_size; i++) { \ /* Set the elements of the GSL vector invec using the first N channels of input data */ \ @@ -194,8 +194,8 @@ static void solve_system_ ## COMPLEX ## DTYPE(const COMPLEX DTYPE *src, COMPLEX gsl_vector_ ## COMPLEX ## UNDERSCORE ## set(invec, j, make_gsl_input ## COMPLEX((COMPLEX double) src[channels_in * i + j])); \ \ /* Set the elements of the GSL matrix using the remaining channels of input data */ \ - for(j = channels_out; j < channels_in; j++) \ - gsl_matrix_ ## COMPLEX ## UNDERSCORE ## set(matrix, j / channels_out, j % channels_out, make_gsl_input ## COMPLEX((COMPLEX double) src[channels_in * i + j])); \ + for(j = 0; j < channels_in - channels_out; j++) \ + gsl_matrix_ ## COMPLEX ## UNDERSCORE ## set(matrix, j / channels_out, j % channels_out, make_gsl_input ## COMPLEX((COMPLEX double) src[channels_in * i + channels_out + j])); \ \ /* Now solve [matrix] [outvec] = [invec] for [outvec] using gsl */ \ gsl_linalg_ ## COMPLEX ## UNDERSCORE ## LU_decomp(matrix, permutation, &signum); \ @@ -252,7 +252,6 @@ static gboolean get_unit_size(GstBaseTransform *trans, GstCaps *caps, gsize *siz static GstCaps *transform_caps(GstBaseTransform *trans, GstPadDirection direction, GstCaps *caps, GstCaps *filter) { guint n; - int channels; caps = gst_caps_normalize(gst_caps_copy(caps)); switch(direction) { @@ -264,10 +263,24 @@ static GstCaps *transform_caps(GstBaseTransform *trans, GstPadDirection directio */ for(n = 0; n < gst_caps_get_size(caps); n++) { GstStructure *str = gst_caps_get_structure(caps, n); - if(!gst_structure_get_int(str, "channels", &channels)) - GST_DEBUG_OBJECT(trans, "unable to get number of channels from caps %" GST_PTR_FORMAT, caps); - channels = channels * (channels + 1); - gst_structure_set(str, "channels", G_TYPE_INT, channels, NULL); + const GValue *v = gst_structure_get_value(gst_caps_get_structure(caps, 0), "channels"); + if(GST_VALUE_HOLDS_INT_RANGE(v)) { + gint channels_in_min, channels_out_min, channels_out_max; + guint64 channels_in_max; + channels_out_min = gst_value_get_int_range_min(v); + channels_out_max = gst_value_get_int_range_max(v); + channels_in_min = channels_out_min * (channels_out_min + 1); + channels_in_max = channels_out_max * ((guint64) channels_out_max + 1); + /* In case channels_in_max is greater than G_MAXINT */ + channels_in_max = channels_in_max < (guint64) G_MAXINT ? channels_in_max : (guint64) G_MAXINT; + gst_structure_set(str, "channels", GST_TYPE_INT_RANGE, channels_in_min, (gint) channels_in_max, NULL); + } else if(G_VALUE_HOLDS_INT(v)) { + gint channels_in, channels_out; + channels_out = g_value_get_int(v); + channels_in = channels_out * (channels_out + 1); + gst_structure_set(str, "channels", G_TYPE_INT, channels_in, NULL); + } else + GST_ELEMENT_ERROR(trans, CORE, NEGOTIATION, (NULL), ("invalid type for channels in caps")); } break; @@ -279,10 +292,21 @@ static GstCaps *transform_caps(GstBaseTransform *trans, GstPadDirection directio */ for(n = 0; n < gst_caps_get_size(caps); n++) { GstStructure *str = gst_caps_get_structure(caps, n); - if(!gst_structure_get_int(str, "channels", &channels)) - GST_DEBUG_OBJECT(trans, "unable to get number of channels from caps %" GST_PTR_FORMAT, caps); - channels = (int) pow((double) channels, 0.5); - gst_structure_set(str, "channels", G_TYPE_INT, channels, NULL); + const GValue *v = gst_structure_get_value(gst_caps_get_structure(caps, 0), "channels"); + if(GST_VALUE_HOLDS_INT_RANGE(v)) { + int channels_in_min, channels_in_max, channels_out_min, channels_out_max; + channels_in_min = gst_value_get_int_range_min(v); + channels_in_max = gst_value_get_int_range_max(v); + channels_out_min = (int) pow((double) channels_in_min, 0.5); + channels_out_max = (int) pow((double) channels_in_max, 0.5); + gst_structure_set(str, "channels", GST_TYPE_INT_RANGE, channels_out_min, channels_out_max, NULL); + } else if(G_VALUE_HOLDS_INT(v)) { + int channels_in, channels_out; + channels_in = g_value_get_int(v); + channels_out = (int) pow((double) channels_in, 0.5); + gst_structure_set(str, "channels", G_TYPE_INT, channels_out, NULL); + } else + GST_ELEMENT_ERROR(trans, CORE, NEGOTIATION, (NULL), ("invalid type for channels in caps")); } break; @@ -311,31 +335,48 @@ static GstCaps *transform_caps(GstBaseTransform *trans, GstPadDirection directio static gboolean set_caps(GstBaseTransform *trans, GstCaps *incaps, GstCaps *outcaps) { GSTLALMatrixSolver *element = GSTLAL_MATRIXSOLVER(trans); - gint rate, channels; - gsize unit_size; + gint rate, channels_in, channels_out; + gsize unit_size_in, unit_size_out; /* * parse the caps */ - GstStructure *str = gst_caps_get_structure(outcaps, 0); - const gchar *name = gst_structure_get_string(str, "format"); + GstStructure *outstr = gst_caps_get_structure(outcaps, 0); + GstStructure *instr = gst_caps_get_structure(incaps, 0); + const gchar *name = gst_structure_get_string(outstr, "format"); if(!name) { GST_DEBUG_OBJECT(element, "unable to parse format from %" GST_PTR_FORMAT, outcaps); return FALSE; } - if(!get_unit_size(trans, outcaps, &unit_size)) { + if(!get_unit_size(trans, outcaps, &unit_size_out)) { GST_DEBUG_OBJECT(element, "function 'get_unit_size' failed"); return FALSE; } - if(!gst_structure_get_int(str, "channels", &channels)) { - GST_DEBUG_OBJECT(element, "unable to parse channels from %" GST_PTR_FORMAT, outcaps); + if(!get_unit_size(trans, incaps, &unit_size_in)) { + GST_DEBUG_OBJECT(element, "function 'get_unit_size' failed"); return FALSE; } - if(!gst_structure_get_int(str, "rate", &rate)) { + if(!gst_structure_get_int(outstr, "rate", &rate)) { GST_DEBUG_OBJECT(element, "unable to parse rate from %" GST_PTR_FORMAT, outcaps); return FALSE; } + if(!gst_structure_get_int(outstr, "channels", &channels_out)) { + GST_DEBUG_OBJECT(element, "unable to parse channels from %" GST_PTR_FORMAT, outcaps); + return FALSE; + } + if(!gst_structure_get_int(instr, "channels", &channels_in)) { + GST_DEBUG_OBJECT(element, "unable to parse channels from %" GST_PTR_FORMAT, incaps); + return FALSE; + } + + /* Require that there are N(N+1) input channels for N output channels */ + if(channels_in != channels_out * (channels_out + 1)) + GST_ERROR_OBJECT(element, "For N output channels, there must be N(N+1) input channels. input caps = %" GST_PTR_FORMAT " output caps = %" GST_PTR_FORMAT, incaps, outcaps); + + /* Require the input unit size to be larger than the output unit size by a factor N+1 */ + if(unit_size_in != unit_size_out * (channels_out + 1)) + GST_ERROR_OBJECT(element, "Input unit size must be N+1 times as large as output unit size, where N is the number of output channels. input caps = %" GST_PTR_FORMAT " output caps = %" GST_PTR_FORMAT, incaps, outcaps); /* * record stream parameters @@ -343,23 +384,65 @@ static gboolean set_caps(GstBaseTransform *trans, GstCaps *incaps, GstCaps *outc if(!strcmp(name, GST_AUDIO_NE(F32))) { element->data_type = GSTLAL_MATRIXSOLVER_F32; - g_assert_cmpuint(unit_size, ==, 4 * (guint) channels); + g_assert_cmpuint(unit_size_out, ==, 4 * (guint) channels_out); } else if(!strcmp(name, GST_AUDIO_NE(F64))) { element->data_type = GSTLAL_MATRIXSOLVER_F64; - g_assert_cmpuint(unit_size, ==, 8 * (guint) channels); + g_assert_cmpuint(unit_size_out, ==, 8 * (guint) channels_out); } else if(!strcmp(name, GST_AUDIO_NE(Z64))) { element->data_type = GSTLAL_MATRIXSOLVER_Z64; - g_assert_cmpuint(unit_size, ==, 8 * (guint) channels); + g_assert_cmpuint(unit_size_out, ==, 8 * (guint) channels_out); } else if(!strcmp(name, GST_AUDIO_NE(Z128))) { element->data_type = GSTLAL_MATRIXSOLVER_Z128; - g_assert_cmpuint(unit_size, ==, 16 * (guint) channels); + g_assert_cmpuint(unit_size_out, ==, 16 * (guint) channels_out); } else g_assert_not_reached(); element->rate = rate; - element->channels_out = channels; - element->channels_in = channels * (channels + 1); - element->unit_size_out = unit_size; + element->channels_out = channels_out; + element->channels_in = channels_in; + element->unit_size_out = unit_size_out; + + /* Allocate memory for gsl vectors, matrices, and permutations */ + if(element->data_type == GSTLAL_MATRIXSOLVER_F32 || element->data_type == GSTLAL_MATRIXSOLVER_F64) { + if(element->workspace.real.invec) { + gsl_vector_free(element->workspace.real.invec); + element->workspace.real.invec = NULL; + } + element->workspace.real.invec = gsl_vector_alloc(channels_out); + if(element->workspace.real.outvec) { + gsl_vector_free(element->workspace.real.outvec); + element->workspace.real.outvec = NULL; + } + element->workspace.real.outvec = gsl_vector_alloc(channels_out); + if(element->workspace.real.matrix) { + gsl_matrix_free(element->workspace.real.matrix); + element->workspace.real.matrix = NULL; + } + element->workspace.real.matrix = gsl_matrix_alloc(channels_out, channels_out); + } else if (element->data_type == GSTLAL_MATRIXSOLVER_Z64 || element->data_type == GSTLAL_MATRIXSOLVER_Z128) { + if(element->workspace.cplx.invec) { + gsl_vector_complex_free(element->workspace.cplx.invec); + element->workspace.cplx.invec = NULL; + } + element->workspace.cplx.invec = gsl_vector_complex_alloc(channels_out); + if(element->workspace.cplx.outvec) { + gsl_vector_complex_free(element->workspace.cplx.outvec); + element->workspace.cplx.outvec = NULL; + } + element->workspace.cplx.outvec = gsl_vector_complex_alloc(channels_out); + if(element->workspace.cplx.matrix) { + gsl_matrix_complex_free(element->workspace.cplx.matrix); + element->workspace.cplx.matrix = NULL; + } + element->workspace.cplx.matrix = gsl_matrix_complex_alloc(channels_out, channels_out); + } else + g_assert_not_reached(); + + if(element->permutation) { + gsl_permutation_free(element->permutation); + element->permutation = NULL; + } + element->permutation = gsl_permutation_alloc(channels_out); return TRUE; } @@ -370,8 +453,8 @@ static gboolean set_caps(GstBaseTransform *trans, GstCaps *incaps, GstCaps *outc */ -static gboolean transform_size(GstBaseTransform *trans, GstPadDirection direction, GstCaps *caps, gsize size, GstCaps *othercaps, gsize *othersize) -{ +static gboolean transform_size(GstBaseTransform *trans, GstPadDirection direction, GstCaps *caps, gsize size, GstCaps *othercaps, gsize *othersize) { + GSTLALMatrixSolver *element = GSTLAL_MATRIXSOLVER(trans); gsize unit_size; @@ -381,16 +464,6 @@ static gboolean transform_size(GstBaseTransform *trans, GstPadDirection directio return FALSE; } - /* - * convert byte count to samples - */ - - if(G_UNLIKELY(size % unit_size)) { - GST_DEBUG_OBJECT(element, "buffer size %" G_GSIZE_FORMAT " is not a multiple of %" G_GSIZE_FORMAT, size, unit_size); - return FALSE; - } - size /= unit_size; - /* * The data types of inputs and outputs are the same, but the number of channels differs. * For N output channels, there are N(N+1) input channels. @@ -418,11 +491,13 @@ static gboolean transform_size(GstBaseTransform *trans, GstPadDirection directio * The size of the output buffer should be a multiple of unit_size * (N+1). */ - if(G_UNLIKELY(size % (unit_size * element->channels_out + 1))) { - GST_DEBUG_OBJECT(element, "buffer size %" G_GSIZE_FORMAT " is not a multiple of %" G_GSIZE_FORMAT, size, unit_size * (element->channels_out + 1)); + if(G_UNLIKELY(size % unit_size)) { + GST_ERROR_OBJECT(element, "buffer size %" G_GSIZE_FORMAT " is not a multiple of %" G_GSIZE_FORMAT, size, unit_size); return FALSE; } + g_assert_cmpint(unit_size, ==, element->unit_size_out * (element->channels_out + 1)); + *othersize = size / (element->channels_out + 1); break; @@ -492,16 +567,16 @@ static GstFlowReturn transform(GstBaseTransform *trans, GstBuffer *inbuf, GstBuf gst_buffer_map(outbuf, &outmap, GST_MAP_WRITE); switch(element->data_type) { case GSTLAL_MATRIXSOLVER_F32: - solve_system_float((const void *) inmap.data, (void *) outmap.data, outmap.size / element->unit_size_out, element->channels_in, element->channels_out); + solve_system_float((const void *) inmap.data, (void *) outmap.data, outmap.size / element->unit_size_out, element->channels_in, element->channels_out, element->workspace.real.invec, element->workspace.real.outvec, element->workspace.real.matrix, element->permutation); break; case GSTLAL_MATRIXSOLVER_F64: - solve_system_double((const void *) inmap.data, (void *) outmap.data, outmap.size / element->unit_size_out, element->channels_in, element->channels_out); + solve_system_double((const void *) inmap.data, (void *) outmap.data, outmap.size / element->unit_size_out, element->channels_in, element->channels_out, element->workspace.real.invec, element->workspace.real.outvec, element->workspace.real.matrix, element->permutation); break; case GSTLAL_MATRIXSOLVER_Z64: - solve_system_complexfloat((const void *) inmap.data, (void *) outmap.data, outmap.size / element->unit_size_out, element->channels_in, element->channels_out); + solve_system_complexfloat((const void *) inmap.data, (void *) outmap.data, outmap.size / element->unit_size_out, element->channels_in, element->channels_out, element->workspace.cplx.invec, element->workspace.cplx.outvec, element->workspace.cplx.matrix, element->permutation); break; case GSTLAL_MATRIXSOLVER_Z128: - solve_system_complexdouble((const void *) inmap.data, (void *) outmap.data, outmap.size / element->unit_size_out, element->channels_in, element->channels_out); + solve_system_complexdouble((const void *) inmap.data, (void *) outmap.data, outmap.size / element->unit_size_out, element->channels_in, element->channels_out, element->workspace.cplx.invec, element->workspace.cplx.outvec, element->workspace.cplx.matrix, element->permutation); break; default: g_assert_not_reached(); @@ -539,6 +614,60 @@ static GstFlowReturn transform(GstBaseTransform *trans, GstBuffer *inbuf, GstBuf */ +/* + * finalize() + */ + + +static void finalize(GObject *object) { + + GSTLALMatrixSolver *element = GSTLAL_MATRIXSOLVER(object); + + /* + * free resources + */ + + if(element->data_type == GSTLAL_MATRIXSOLVER_F32 || element->data_type == GSTLAL_MATRIXSOLVER_F64) { + if(element->workspace.real.invec) { + gsl_vector_free(element->workspace.real.invec); + element->workspace.real.invec = NULL; + } + if(element->workspace.real.outvec) { + gsl_vector_free(element->workspace.real.outvec); + element->workspace.real.outvec = NULL; + } + if(element->workspace.real.matrix) { + gsl_matrix_free(element->workspace.real.matrix); + element->workspace.real.matrix = NULL; + } + } else if(element->data_type == GSTLAL_MATRIXSOLVER_Z64 || element->data_type == GSTLAL_MATRIXSOLVER_Z128) { + if(element->workspace.cplx.invec) { + gsl_vector_complex_free(element->workspace.cplx.invec); + element->workspace.cplx.invec = NULL; + } + if(element->workspace.cplx.outvec) { + gsl_vector_complex_free(element->workspace.cplx.outvec); + element->workspace.cplx.outvec = NULL; + } + if(element->workspace.cplx.matrix) { + gsl_matrix_complex_free(element->workspace.cplx.matrix); + element->workspace.cplx.matrix = NULL; + } + } + + if(element->permutation) { + gsl_permutation_free(element->permutation); + element->permutation = NULL; + } + + /* + * chain to parent class' finalize() method + */ + + G_OBJECT_CLASS(gstlal_matrixsolver_parent_class)->finalize(object); +} + + /* * class_init() */ @@ -548,6 +677,7 @@ static void gstlal_matrixsolver_class_init(GSTLALMatrixSolverClass *klass) { GstBaseTransformClass *transform_class = GST_BASE_TRANSFORM_CLASS(klass); GstElementClass *element_class = GST_ELEMENT_CLASS(klass); + GObjectClass *gobject_class = G_OBJECT_CLASS(klass); gst_element_class_set_details_simple( element_class, @@ -577,6 +707,7 @@ static void gstlal_matrixsolver_class_init(GSTLALMatrixSolverClass *klass) transform_class->transform_size = GST_DEBUG_FUNCPTR(transform_size); transform_class->start = GST_DEBUG_FUNCPTR(start); transform_class->transform = GST_DEBUG_FUNCPTR(transform); + gobject_class->finalize = GST_DEBUG_FUNCPTR(finalize); } @@ -592,4 +723,5 @@ static void gstlal_matrixsolver_init(GSTLALMatrixSolver *element) element->channels_in = 0; element->channels_out = 0; element->unit_size_out = 0; + element->permutation = NULL; } diff --git a/gstlal-calibration/gst/lal/gstlal_matrixsolver.h b/gstlal-calibration/gst/lal/gstlal_matrixsolver.h index 7e50dc461f..88b3516aa7 100644 --- a/gstlal-calibration/gst/lal/gstlal_matrixsolver.h +++ b/gstlal-calibration/gst/lal/gstlal_matrixsolver.h @@ -70,6 +70,21 @@ struct _GSTLALMatrixSolver { guint64 next_in_offset; guint64 next_out_offset; gboolean need_discont; + + /* gsl stuff for solving systens of linear equations */ + union { + struct { + gsl_vector *invec; + gsl_vector *outvec; + gsl_matrix *matrix; + } real; + struct { + gsl_vector_complex *invec; + gsl_vector_complex *outvec; + gsl_matrix_complex *matrix; + } cplx; + } workspace; + gsl_permutation *permutation; }; -- GitLab