Commit 8ec9ed5c authored by Matthew David Pitkin's avatar Matthew David Pitkin

LALInferenceGenerateROQ.c: change to using iterated Gram-Schmidt algorithm from greedycpp

 - in addition to changing the to the use the iterated Gram-Schmidt
   algorithm from the greedycpp code, this also changes the validation
   and enrichment method, making it more efficient
 - the code in LALapps's ppe_roq.c has be modified to reflect these
   changes
parent 01b510fb
......@@ -279,7 +279,8 @@ void generate_interpolant( LALInferenceRunState *runState ){
ts = generate_training_set( tmpRS, ntraining ); /* generate the training set */
COMPLEX16Array *RB = NULL; /* the reduced basis */
maxprojerr = LALInferenceGenerateCOMPLEX16OrthonormalBasis(&RB, deltas, tolerance, ts); /* generate reduced basis */
UINT4Vector *greedypts = NULL; /* the points in the training set used to produce the reduced basis */
maxprojerr = LALInferenceGenerateCOMPLEX16OrthonormalBasis(&RB, deltas, tolerance, &ts, &greedypts); /* generate reduced basis */
XLAL_CHECK_VOID( RB != NULL, XLAL_EFUNC, "Could not produce linear basis set" );
nbases->data[i] = RB->dimLength->data[0];
......@@ -292,9 +293,11 @@ void generate_interpolant( LALInferenceRunState *runState ){
do {
/* create a new training set to try and "enrich" the original one */
COMPLEX16Array *tsenrich = NULL;
REAL8 newmaxprojerr = maxprojerr;
tsenrich = generate_training_set( tmpRS, ntraining );
maxprojerr = LALInferenceEnrichCOMPLEX16Basis(deltas, tolerance, &RB, &ts, tsenrich); /* regenerate reduced basis */
newmaxprojerr = LALInferenceEnrichCOMPLEX16Basis(deltas, tolerance, &RB, &greedypts, ts, &tsenrich); /* regenerate reduced basis */
if ( newmaxprojerr != 0. ) { maxprojerr = newmaxprojerr; } // update only if a new reduced basis has been generated
/* check if no new bases have been added */
if ( nbases->data[i] == RB->dimLength->data[0] ){
......@@ -313,12 +316,26 @@ void generate_interpolant( LALInferenceRunState *runState ){
}
if ( verbose ){ fprintf(stderr, "...%d", nbases->data[i]); }
// copy tsenrich into ts (as this is the new enriched training set from which the reduced basis was calculated)
if ( nbases->data[i] != RB->dimLength->data[0] ){
XLALDestroyCOMPLEX16Array( ts );
UINT4Vector *dimstmp = NULL;
dimstmp = XLALCreateUINT4Vector( 2 );
dimstmp->data[0] = tsenrich->dimLength->data[0];
dimstmp->data[1] = tsenrich->dimLength->data[1];
ts = XLALCreateCOMPLEX16Array( dimstmp );
for ( UINT4 didx = 0; didx < (dimstmp->data[0]*dimstmp->data[1]); didx++ ){ ts->data[didx] = tsenrich->data[didx]; }
XLALDestroyUINT4Vector( dimstmp );
}
XLALDestroyCOMPLEX16Array( tsenrich );
enrichcounts++;
} while( enrichcounts < nenrichmax && conseccount < nconsec );
if ( verbose ){ fprintf(stderr, "\n"); }
}
XLALDestroyUINT4Vector( greedypts );
XLALDestroyCOMPLEX16Array( ts );
if ( verbose ){
......@@ -346,7 +363,8 @@ void generate_interpolant( LALInferenceRunState *runState ){
tsquad = generate_training_set_quad( tmpRS, ntraining ); /* generate the training set */
REAL8Array *RBquad = NULL; /* the reduced basis */
maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&RBquad, deltas, tolerance, tsquad); /* generate reduced basis */
greedypts = NULL;
maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&RBquad, deltas, tolerance, &tsquad, &greedypts); /* generate reduced basis */
XLAL_CHECK_VOID( RBquad != NULL, XLAL_EFUNC, "Could not produce quadratic basis set" );
nbasesquad->data[i] = RBquad->dimLength->data[0];
......@@ -359,9 +377,11 @@ void generate_interpolant( LALInferenceRunState *runState ){
do {
/* create a new training set to try and "enrich" the original one */
REAL8Array *tsenrich = NULL;
REAL8 newmaxprojerr = maxprojerr;
tsenrich = generate_training_set_quad( tmpRS, ntraining );
maxprojerr = LALInferenceEnrichREAL8Basis(deltas, tolerance, &RBquad, &tsquad, tsenrich); /* regenerate reduced basis */
newmaxprojerr = LALInferenceEnrichREAL8Basis(deltas, tolerance, &RBquad, &greedypts, tsquad, &tsenrich); /* regenerate reduced basis */
if ( newmaxprojerr != 0. ) { maxprojerr = newmaxprojerr; } // update only if a new reduced basis has been generated
/* check if no new bases have been added */
if ( nbasesquad->data[i] == RBquad->dimLength->data[0] ){
......@@ -379,6 +399,18 @@ void generate_interpolant( LALInferenceRunState *runState ){
break;
}
// copy tsenrich into tsquad (as this is the new enriched training set from which the reduced basis was calculated)
if ( nbasesquad->data[i] != RBquad->dimLength->data[0] ){
XLALDestroyREAL8Array( tsquad );
UINT4Vector *dimstmp = NULL;
dimstmp = XLALCreateUINT4Vector( 2 );
dimstmp->data[0] = tsenrich->dimLength->data[0];
dimstmp->data[1] = tsenrich->dimLength->data[1];
tsquad = XLALCreateREAL8Array( dimstmp );
for ( UINT4 didx = 0; didx < (dimstmp->data[0]*dimstmp->data[1]); didx++ ){ tsquad->data[didx] = tsenrich->data[didx]; }
XLALDestroyUINT4Vector( dimstmp );
}
if ( verbose ){ fprintf(stderr, "...%d", nbasesquad->data[i]); }
XLALDestroyREAL8Array( tsenrich );
enrichcounts++;
......@@ -386,6 +418,7 @@ void generate_interpolant( LALInferenceRunState *runState ){
if ( verbose ){ fprintf(stderr, "\n"); }
}
XLALDestroyUINT4Vector(greedypts);
XLALDestroyREAL8Array(tsquad);
if ( verbose ){ fprintf(stderr, "Number of quadratic reduced bases for ROQ generation is %d, with a maximum projection error of %le\n", nbasesquad->data[i], maxprojerr);}
......
......@@ -22,3 +22,16 @@
approximation and analysis},
primaryclass = {gr-qc}
}
@ARTICLE{Hoffmann1989,
author = {{Hoffmann}, W.},
title = {{Iterative algorithms for Gram-Schmidt orthogonalization}},
journal = {Computing},
year = 1989,
month = dec,
volume = 41,
number = 4,
pages = 335--348,
doi = 10.1007/BF02241222,
url = {https://doi.org/10.1007/BF02241222}
}
This diff is collapsed.
......@@ -54,39 +54,52 @@ typedef struct tagLALInferenceCOMPLEXROQInterpolant{
/* function to create or enrich a real orthonormal basis set from a training set of models */
REAL8 LALInferenceGenerateREAL8OrthonormalBasis(REAL8Array **RB,
REAL8Vector *delta,
const REAL8Vector *delta,
REAL8 tolerance,
REAL8Array *TS);
REAL8Array **TS,
UINT4Vector **greedypoints);
REAL8 LALInferenceGenerateCOMPLEX16OrthonormalBasis(COMPLEX16Array **RB,
REAL8Vector *delta,
const REAL8Vector *delta,
REAL8 tolerance,
COMPLEX16Array *TS);
COMPLEX16Array **TS,
UINT4Vector **greedypoints);
/* functions to test the basis */
INT4 LALInferenceTestREAL8OrthonormalBasis(REAL8Vector *delta,
void LALInferenceValidateREAL8OrthonormalBasis(REAL8Vector **projerr,
const REAL8Vector *delta,
const REAL8Array *RB,
REAL8Array **testmodels);
void LALInferenceValidateCOMPLEX16OrthonormalBasis(REAL8Vector **projerr,
const REAL8Vector *delta,
const COMPLEX16Array *RB,
COMPLEX16Array **testmodels);
INT4 LALInferenceTestREAL8OrthonormalBasis(const REAL8Vector *delta,
REAL8 tolerance,
REAL8Array *RB,
REAL8Array *testmodels);
const REAL8Array *RB,
REAL8Array **testmodels);
INT4 LALInferenceTestCOMPLEX16OrthonormalBasis(REAL8Vector *delta,
INT4 LALInferenceTestCOMPLEX16OrthonormalBasis(const REAL8Vector *delta,
REAL8 tolerance,
COMPLEX16Array *RB,
COMPLEX16Array *testmodels);
const COMPLEX16Array *RB,
COMPLEX16Array **testmodels);
/* functions to enrich the training model set and basis set */
REAL8 LALInferenceEnrichREAL8Basis(REAL8Vector *delta,
REAL8 tolerance,
REAL8 LALInferenceEnrichREAL8Basis(const REAL8Vector *delta,
const REAL8 tolerance,
REAL8Array **RB,
REAL8Array **testmodels,
REAL8Array *testmodelsnew);
UINT4Vector **greedypoints,
const REAL8Array *testmodels,
REAL8Array **testmodelsnew);
REAL8 LALInferenceEnrichCOMPLEX16Basis(REAL8Vector *delta,
REAL8 tolerance,
REAL8 LALInferenceEnrichCOMPLEX16Basis(const REAL8Vector *delta,
const REAL8 tolerance,
COMPLEX16Array **RB,
COMPLEX16Array **testmodels,
COMPLEX16Array *testmodelsnew);
UINT4Vector **greedypoints,
const COMPLEX16Array *testmodels,
COMPLEX16Array **testmodelsnew);
/* function to create the empirical interpolant */
LALInferenceREALROQInterpolant *LALInferenceGenerateREALROQInterpolant(REAL8Array *RB);
......
......@@ -45,6 +45,7 @@ COMPLEX16 imag_model(double frequency, double Mchirp, double modperiod){
int main(void) {
REAL8Array *TS = NULL, *TSquad = NULL, *cTSquad = NULL; /* the training set of real waveforms (and quadratic model) */
COMPLEX16Array *cTS = NULL; /* the training set of complex waveforms */
UINT4Vector *gdpts = NULL; /* the greedy points used for the reduced basis generation */
size_t TSsize; /* the size of the training set (number of waveforms) */
size_t wl; /* the length of each waveform */
......@@ -120,13 +121,17 @@ int main(void) {
/* create reduced orthonormal basis from training set for linear part */
REAL8 maxprojerr = 0.;
maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&RBlinear, fweights, tolerance, TS);
maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&RBlinear, fweights, tolerance, &TS, &gdpts);
XLALDestroyUINT4Vector( gdpts );
fprintf(stderr, "No. linear nodes (real) = %d, %d x %d; Maximum projection err. = %le\n", RBlinear->dimLength->data[0], RBlinear->dimLength->data[0], RBlinear->dimLength->data[1], maxprojerr);
maxprojerr = LALInferenceGenerateCOMPLEX16OrthonormalBasis(&cRBlinear, fweights, tolerance, cTS);
maxprojerr = LALInferenceGenerateCOMPLEX16OrthonormalBasis(&cRBlinear, fweights, tolerance, &cTS, &gdpts);
XLALDestroyUINT4Vector( gdpts );
fprintf(stderr, "No. linear nodes (complex) = %d, %d x %d; Maximum projection err. = %le\n", cRBlinear->dimLength->data[0], cRBlinear->dimLength->data[0], cRBlinear->dimLength->data[1], maxprojerr);
maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&RBquad, fweights, tolerance, TSquad);
maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&RBquad, fweights, tolerance, &TSquad, &gdpts);
XLALDestroyUINT4Vector( gdpts );
fprintf(stderr, "No. quadratic nodes (real) = %d, %d x %d; Maximum projection err. = %le\n", RBquad->dimLength->data[0], RBquad->dimLength->data[0], RBquad->dimLength->data[1], maxprojerr);
maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&cRBquad, fweights, tolerance, cTSquad);
maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&cRBquad, fweights, tolerance, &cTSquad, &gdpts);
XLALDestroyUINT4Vector( gdpts );
fprintf(stderr, "No. quadratic nodes (complex) = %d, %d x %d; Maximum projection err. = %le\n", cRBquad->dimLength->data[0], cRBquad->dimLength->data[0], cRBquad->dimLength->data[1], maxprojerr);
/* free the training set */
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment