Commit 12a394ce authored by Matthew David Pitkin's avatar Matthew David Pitkin

ppe_init.[c,h],ppe_likelihood.c: allow use of multi-variate GMM prior

Original: 1770da7a3f4e36e1e04eea4e57e34c536435d3ec
parent 2d1d5919
......@@ -623,8 +623,21 @@ void initialise_prior( LALInferenceRunState *runState )
strtoupper( tempPar ); /* convert tempPar to all uppercase letters */
tempPrior = XLALStringDuplicate( tline->tokens[1] );
/* check if there is more than one parameter in tempPar, separated by a ':', for use in GMM prior */
TokenList *parnames = NULL;
XLALCreateTokenList( &parnames, tempPar, ":" ); /* find number of parameters used by GMM (parameters should be ':' separated */
UINT4 npars = parnames->nTokens;
/* gaussian, uniform, loguniform and fermi-dirac priors should all have four values to a line */
if ( nvals == 4 ){
if ( !strcmp(tempPrior, "uniform") || !strcmp(tempPrior, "loguniform") || !strcmp(tempPrior, "gaussian") || !strcmp(tempPrior, "fermidirac") ){
if ( npars > 1 ){
XLAL_ERROR_VOID( XLAL_EINVAL, "Error... 'uniform', 'loguniform', 'gaussian', or 'fermidirac' priors must only be given for single parameters." );
}
if ( nvals != 4 ){
XLAL_ERROR_VOID( XLAL_EINVAL, "Error... 'uniform', 'loguniform', 'gaussian', or 'fermidirac' priors must specify four values." );
}
low = atof(tline->tokens[2]);
high = atof(tline->tokens[3]);
......@@ -652,79 +665,119 @@ void initialise_prior( LALInferenceRunState *runState )
LALInferenceAddFermiDiracPrior( runState->priorArgs, tempPar, &low, &high, LALINFERENCE_REAL8_t );
}
}
else if ( nvals > 4 ){
/* check if using a 1D Gaussian Mixture Model prior e.g.:
* iota gmm 2 0.5 0.01 1.0 1.5 0.01 1.0 0.0 3.14
* where the third value is the number of modes, followed by tuples of means, standard deviations and weights for each mode
* and finally (if required two values giving minimum and maximum limits for the prior).
else if ( !strcmp(tempPrior, "gmm") ){
/* check if using a 1D/multi-dimensional Gaussian Mixture Model prior e.g.:
* H0:COSIOTA gmm 2 [[1e-23,0.3],[2e-23,0.4]] [[[1e-45,2e-25],[2e-25,0.02]],[[4e-46,2e-24],[2e-24,0.1]]] [0.2,0.4] [0.,1e-22] [-1.,1.]
* where the third value is the number of modes followed by: a list of means for each parameter for each mode; the covariance
* matrices for each mode; the weights for each mode; and if required sets of pairs of minimum and maximum prior ranges for each
* parameter.
*/
if ( !strcmp(tempPrior, "gmm") ){
UINT4 nmodes = (UINT4)atoi( tline->tokens[2] );
if ( nvals < 6 ){
fprintf(stderr, "Warning... number of values ('%d') on line '%d' in prior file is different than expected:\n\t'%s'", nvals, k+1, tlist->tokens[k]);
XLALDestroyTokenList( tline );
continue;
}
if ( nvals - 2 < 3*nmodes ){
XLAL_ERROR_VOID( XLAL_EINVAL, "Error... GMM prior is not properly defined" );
}
UINT4 nmodes = (UINT4)atoi( tline->tokens[2] ); /* get the number of modes */
/* get Gaussian mode means, standard deviations and weights */
REAL8Vector *gmmsigmas = NULL, *gmmmus = NULL, *gmmweights = NULL;
gmmsigmas = XLALCreateREAL8Vector( nmodes );
gmmmus = XLALCreateREAL8Vector( nmodes );
gmmweights = XLALCreateREAL8Vector( nmodes );
for ( i = 0; i < nmodes; i++ ){
gmmmus->data[i] = atof( tline->tokens[3+3*i] );
gmmsigmas->data[i] = atof( tline->tokens[3+3*i+1] );
gmmweights->data[i] = atof( tline->tokens[3+3*i+2] );
}
/* get means of modes for each parameter */
REAL8Vector **gmmmus;
gmmmus = parse_gmm_means(tline->tokens[3], npars, nmodes);
if ( !gmmmus ){
XLAL_ERROR_VOID(XLAL_EINVAL, "Error... problem parsing GMM prior mean values for '%s'.", tempPar);
}
REAL8 minval = -INFINITY, maxval = INFINITY;
/* check if minimum and maximum bounds are specified */
if ( nvals > 3+3*nmodes ) { minval = atof(tline->tokens[3+3*nmodes]); }
if ( nvals > 3+3*nmodes+1 ) { maxval = atof(tline->tokens[3+3*nmodes+1]); }
if ( nvals > 3+3*nmodes+2 ){
fprintf(stderr, "Warning... additional unnecessary values given in GMM prior\n");
}
/* get the covariance matrices for the modes */
gsl_matrix **gmmcovs;
gmmcovs = parse_gmm_covs(tline->tokens[4], npars, nmodes);
if ( !gmmcovs ){
XLAL_ERROR_VOID(XLAL_EINVAL, "Error... problem parsing GMM prior covariance matrix values for '%s'.", tempPar);
}
/* get weights for the modes */
REAL8Vector *gmmweights = NULL;
gmmweights = XLALCreateREAL8Vector( nmodes );
LALInferenceAdd1DGMMPrior( runState->priorArgs, tempPar, &gmmmus, &gmmsigmas, &gmmweights, &minval, &maxval );
CHAR strpart[8192];
CHAR UNUSED *nextpart;
nextpart = get_bracketed_string(strpart, tline->tokens[5], '[', ']');
if ( !strpart[0] ){
XLAL_ERROR_VOID(XLAL_EINVAL, "Error... problem parsing GMM prior weights values for '%s'.", tempPar);
}
else{
XLAL_ERROR_VOID( XLAL_EINVAL, "Error... prior type '%s' not recognised", tempPrior );
/* parse comma separated weights */
TokenList *weightvals = NULL;
XLALCreateTokenList( &weightvals, strpart, "," );
if ( weightvals->nTokens != nmodes ){
XLAL_ERROR_VOID(XLAL_EINVAL, "Error... problem parsing GMM prior weights values for '%s'.", tempPar);
}
for ( UINT4 j=0; j < nmodes; j++ ){ gmmweights->data[j] = atof(weightvals->tokens[j]); }
XLALDestroyTokenList( weightvals );
REAL8 minval = -INFINITY, maxval = INFINITY;
REAL8Vector *minvals = XLALCreateREAL8Vector( npars ), *maxvals = XLALCreateREAL8Vector( npars );
/* check if minimum and maximum bounds are specified (otherwise set to +/- infinity) */
/* there are minimum and maximum values, e.g. [h0min,h0max] [cosiotamin,cosiotamax] */
for ( UINT4 j=0; j < npars; j++ ){
REAL8 thismin = minval, thismax = maxval;
if ( tline->nTokens > 6+j ){
nextpart = get_bracketed_string(strpart, tline->tokens[6+j], '[', ']');
if ( !strpart[0] ){
XLAL_ERROR_VOID(XLAL_EINVAL, "Error... problem parsing GMM prior limit values for '%s'.", tempPar);
}
TokenList *minmaxvals = NULL;
XLALCreateTokenList( &minmaxvals, strpart, "," );
if ( minmaxvals->nTokens == 2 ){
if ( isfinite(atof(minmaxvals->tokens[0])) && isfinite(atof(minmaxvals->tokens[1])) ){
thismin = atof(minmaxvals->tokens[0]);
thismax = atof(minmaxvals->tokens[1]);
}
}
XLALDestroyTokenList( minmaxvals );
}
minvals->data[j] = thismin;
maxvals->data[j] = thismax;
}
LALInferenceAddGMMPrior( runState->priorArgs, tempPar, &gmmmus, &gmmcovs, &gmmweights, &minvals, &maxvals );
}
else{
fprintf(stderr, "Warning... number of values ('%d') on line '%d' in prior file is different than expected:\n\t'%s'\n", nvals, k+1, tlist->tokens[k]);
XLALDestroyTokenList( tline );
continue;
XLAL_ERROR_VOID( XLAL_EINVAL, "Error... prior type '%s' not recognised", tempPrior );
}
/* set variable type to LINEAR (as they are initialised as FIXED) */
varyType = LALINFERENCE_PARAM_LINEAR;
LALInferenceSetParamVaryType( threadState->currentParams, tempPar, varyType );
/* if there is a phase parameter defined in the proposal then set varyphase to 1 */
for ( i = 0; i < NUMAMPPARS; i++ ){
if ( !strcmp(tempPar, amppars[i]) ){
isthere = 1;
break;
for ( UINT4 j = 0; j < npars; j++ ){
for ( i = 0; i < NUMAMPPARS; i++ ){
if ( !strcmp(parnames->tokens[j], amppars[i]) ){
isthere = 1;
break;
}
}
}
if ( !isthere ) { varyphase = 1; }
if ( !isthere ) { varyphase = 1; }
/* check if there are sky position parameters that will be searched over */
for ( i = 0; i < NUMSKYPARS; i++ ){
if ( !strcmp(tempPar, skypars[i]) ){
varyskypos = 1;
break;
/* check if there are sky position parameters that will be searched over */
for ( i = 0; i < NUMSKYPARS; i++ ){
if ( !strcmp(parnames->tokens[j], skypars[i]) ){
varyskypos = 1;
break;
}
}
}
/* check if there are any binary parameters that will be searched over */
for ( i = 0; i < NUMBINPARS; i++ ){
if ( !strcmp(tempPar, binpars[i]) ){
varybinary = 1;
break;
/* check if there are any binary parameters that will be searched over */
for ( i = 0; i < NUMBINPARS; i++ ){
if ( !strcmp(parnames->tokens[j], binpars[i]) ){
varybinary = 1;
break;
}
}
/* set variable type to LINEAR (as they are initialised as FIXED) */
varyType = LALINFERENCE_PARAM_LINEAR;
LALInferenceSetParamVaryType( threadState->currentParams, parnames->tokens[j], varyType );
}
XLALDestroyTokenList( parnames );
XLALDestroyTokenList( tline );
}
......@@ -1373,6 +1426,173 @@ void sum_data( LALInferenceRunState *runState ){
return;
}
/**
 * \brief Parse data from a prior file containing Gaussian Mixture Model mean values
 *
 * If a Gaussian Mixture Model prior has been specified then this function will parse
 * the means for each parameter for each mode given. E.g. if the GMM provides multivariate
 * Gaussian modes for two parameters, x and y, then the means would be specified in a string
 * of the form "[[mux1,muy1],[mux2,muy2],....]". The string should have no whitespace between
 * values, and mean values for a given mode must be separated by a comma.
 *
 * These values are returned in an array of \c nmodes REAL8Vectors, each of length \c npars.
 * The array and the vectors it holds are allocated here (with XLALCalloc and
 * XLALCreateREAL8Vector) and are handed to the caller on success.
 *
 * NULL is returned (and anything allocated here is released) if:
 *  - \c meanstr is NULL, contains no '[', or is too long for the internal scratch buffer;
 *  - \c npars or \c nmodes is zero;
 *  - any mode does not contain exactly \c npars comma separated values;
 *  - the number of modes found is not exactly \c nmodes;
 *  - a memory allocation fails.
 *
 * \param meanstr [in] A string containing the mean values
 * \param npars [in] The number of parameters
 * \param nmodes [in] The number of modes
 */
REAL8Vector** parse_gmm_means(CHAR *meanstr, UINT4 npars, UINT4 nmodes){
  UINT4 modecount = 0;
  CHAR strpart[16384]; /* string to hold the contents of one bracketed mode */
  REAL8Vector **meanmat = NULL;
  TokenList *meantoc = NULL;

  if ( !meanstr || !npars || !nmodes ){ return NULL; }

  /* every bracketed sub-string is shorter than the whole string, so this single check
   * guarantees that get_bracketed_string cannot overrun strpart */
  if ( strlen(meanstr) >= sizeof(strpart) ){
    XLAL_PRINT_WARNING("Warning... GMM means string is too long to be parsed.\n");
    return NULL;
  }

  /* find location of the first (outermost) '[' and step inside it */
  CHAR *cursor = strchr(meanstr, '[');
  if ( !cursor ){ return NULL; }
  cursor++;

  /* allocate (zeroed) memory for returned value */
  meanmat = XLALCalloc(nmodes, sizeof(*meanmat));
  if ( !meanmat ){ return NULL; }

  /* move to the next '[' before each call so that the helper is always handed a string
   * that starts with an opening bracket (and is never handed a pointer past the end) */
  while ( (cursor = strchr(cursor, '[')) != NULL ){
    CHAR *closeloc = get_bracketed_string(strpart, cursor, '[', ']');

    if ( !closeloc || !strpart[0] ){ break; } /* break when no more bracketed items are found */

    /* more modes in the string than there is space for */
    if ( modecount >= nmodes ){
      XLAL_PRINT_WARNING("Warning... number of means values specified for GMM is not consistent with number of modes.\n");
      goto fail;
    }

    /* get mean values (separated by commas) */
    XLALCreateTokenList( &meantoc, strpart, "," );
    if ( !meantoc || meantoc->nTokens != npars ){
      XLAL_PRINT_WARNING("Warning... number of means parameters specified for GMM is not consistent with number of parameters.\n");
      goto fail;
    }

    meanmat[modecount] = XLALCreateREAL8Vector( npars );
    if ( !meanmat[modecount] ){ goto fail; }

    for ( UINT4 j = 0; j < npars; j++ ){ meanmat[modecount]->data[j] = atof(meantoc->tokens[j]); }

    XLALDestroyTokenList( meantoc );
    meantoc = NULL;

    modecount++;
    cursor = closeloc; /* continue from just after the closing bracket */
  }

  if ( modecount != nmodes ){
    XLAL_PRINT_WARNING("Warning... number of means values specified for GMM is not consistent with number of modes.\n");
    goto fail;
  }

  return meanmat;

fail:
  /* release the token list of the mode being parsed (if any) and all completed modes */
  if ( meantoc ){ XLALDestroyTokenList( meantoc ); }
  for ( UINT4 k = 0; k < modecount; k++ ){ XLALDestroyREAL8Vector(meanmat[k]); }
  XLALFree(meanmat);

  return NULL;
}
/**
 * \brief Parse data from a prior file containing Gaussian Mixture Model covariance matrix values
 *
 * If a Gaussian Mixture Model prior has been specified then this function will parse
 * the covariance matrices for each mode given. E.g. if the GMM provides multivariate
 * Gaussian modes for two parameters, x and y, then the covariances for each mode would
 * be specified in a string of the form "[[[covxx1,covxy1],[covyx1,covyy1]],[[covxx2,covxy2],[covyx2,covyy2]],...]".
 * The string should have no whitespace between values, and covariance values within a row must be
 * separated by a comma.
 *
 * These values are returned in an array of \c nmodes GSL matrices, each of size \c npars x \c npars.
 * The array (XLALCalloc) and the matrices (gsl_matrix_alloc) are allocated here and are handed
 * to the caller on success.
 *
 * NULL is returned (and anything allocated here is released) if:
 *  - \c covstr is NULL, contains no '[', or is too long for the internal scratch buffers;
 *  - \c npars or \c nmodes is zero;
 *  - any mode does not contain exactly \c npars rows of exactly \c npars comma separated values;
 *  - the number of modes found is not exactly \c nmodes;
 *  - a memory allocation fails.
 *
 * \param covstr [in] A string containing the covariance matrix values
 * \param npars [in] The number of parameters
 * \param nmodes [in] The number of modes
 */
gsl_matrix** parse_gmm_covs(CHAR *covstr, UINT4 npars, UINT4 nmodes){
  UINT4 modecount = 0;
  CHAR strpart[16384]; /* holds the rows of one mode, e.g. "[a,b],[c,d]" */
  CHAR rowstr[16384];  /* holds the contents of one row, e.g. "a,b" (same size as strpart so a row always fits) */
  gsl_matrix **covmat = NULL;
  TokenList *covtoc = NULL;

  if ( !covstr || !npars || !nmodes ){ return NULL; }

  /* every sub-string copied below is shorter than the whole string, so this single check
   * guarantees that neither scratch buffer can be overrun */
  if ( strlen(covstr) >= sizeof(strpart) ){
    XLAL_PRINT_WARNING("Warning... GMM covariance string is too long to be parsed.\n");
    return NULL;
  }

  /* parse covariance string */
  CHAR *startloc = strchr(covstr, '['); /* find location of first '[' */
  if ( !startloc ){ return NULL; }

  /* allocate (zeroed) memory for returned value, so that unused entries are NULL */
  covmat = XLALCalloc(nmodes, sizeof(*covmat));
  if ( !covmat ){ return NULL; }

  while ( 1 ){
    CHAR *openloc = strstr(startloc+1, "[["); /* find next "[[" */
    /* break if there are no more opening brackets */
    if ( !openloc ){ break; }

    CHAR *closeloc = strstr(openloc+1, "]]"); /* find next "]]" */
    if ( !closeloc ){ break; }

    /* more modes in the string than there is space for */
    if ( modecount >= nmodes ){
      XLAL_PRINT_WARNING("Warning... number of covariance matrices specified for GMM is not consistent with number of modes.\n");
      goto fail;
    }

    /* copy from the second '[' of "[[" up to and including the first ']' of "]]" */
    size_t partlen = (size_t)(closeloc - openloc);
    memcpy(strpart, openloc+1, partlen);
    strpart[partlen] = '\0'; /* add null terminating character */

    covmat[modecount] = gsl_matrix_alloc(npars, npars);
    if ( !covmat[modecount] ){ goto fail; }

    UINT4 parcount = 0; /* number of matrix rows read for this mode */
    CHAR *rowloc = strpart;

    /* move to the next '[' before each call so that the helper is always handed a string
     * that starts with an opening bracket */
    while ( (rowloc = strchr(rowloc, '[')) != NULL ){
      CHAR *rowend = get_bracketed_string(rowstr, rowloc, '[', ']');

      if ( !rowend || !rowstr[0] ){ break; } /* read all of covariance matrix for this mode */

      /* row index parcount must stay below npars to be a valid matrix row */
      if ( parcount >= npars ){
        XLAL_PRINT_WARNING("Warning... number of covariance parameters specified for GMM is not consistent with number of parameters.\n");
        goto fail;
      }

      /* get covariance values (separated by commas) */
      XLALCreateTokenList( &covtoc, rowstr, "," );
      if ( !covtoc || covtoc->nTokens != npars ){
        XLAL_PRINT_WARNING("Warning... number of covariance parameters specified for GMM is not consistent with number of parameters.\n");
        goto fail;
      }

      for ( UINT4 j = 0; j < npars; j++ ){ gsl_matrix_set(covmat[modecount], parcount, j, atof(covtoc->tokens[j])); }

      XLALDestroyTokenList( covtoc );
      covtoc = NULL;

      parcount++;
      rowloc = rowend; /* continue from just after the closing bracket */
    }

    /* gsl_matrix_alloc does not initialise its elements, so a short matrix must be rejected */
    if ( parcount != npars ){
      XLAL_PRINT_WARNING("Warning... number of covariance parameters specified for GMM is not consistent with number of parameters.\n");
      goto fail;
    }

    startloc = closeloc;
    modecount++;
  }

  if ( modecount != nmodes ){
    XLAL_PRINT_WARNING("Warning... number of covariance matrices specified for GMM is not consistent with number of modes.\n");
    goto fail;
  }

  return covmat;

fail:
  /* release the token list of the row being parsed (if any) and every allocated matrix */
  if ( covtoc ){ XLALDestroyTokenList( covtoc ); }
  for ( UINT4 k = 0; k < nmodes; k++ ){
    if ( covmat[k] ){ gsl_matrix_free(covmat[k]); }
  }
  XLALFree(covmat);

  return NULL;
}
/**
 * \brief Extract the text between the first pair of brackets in a string
 *
 * The first occurrence of \c openbracket in \c bstr is located, followed by the first
 * occurrence of \c closebracket after it. The characters strictly between the two are
 * copied into \c dest as a null-terminated string.
 *
 * If either bracket cannot be found then \c dest is set to the empty string and NULL is
 * returned. Note that an empty pair of brackets (e.g. "[]") also leaves \c dest empty,
 * but returns a non-NULL pointer.
 *
 * \param dest [out] Buffer to hold the extracted string. No length is passed in, so the
 * caller must make sure it can hold the bracketed text plus a terminating null character.
 * \param bstr [in] The null-terminated string to search
 * \param openbracket [in] The opening bracket character, e.g. '['
 * \param closebracket [in] The closing bracket character, e.g. ']'
 *
 * \return A pointer to the character in \c bstr immediately after the closing bracket,
 * or NULL if no bracketed string was found.
 */
CHAR* get_bracketed_string(CHAR *dest, const CHAR *bstr, int openbracket, int closebracket){
  /* get positions of opening and closing brackets: the closing bracket is only looked for
   * after the opening one, so that it can never precede it (which would give a negative
   * copy length) and so that an empty input string is never read beyond its terminator */
  CHAR *openpar = strchr(bstr, openbracket);
  CHAR *closepar = openpar ? strchr(openpar+1, closebracket) : NULL;

  if ( !closepar ){
    dest[0] = '\0';
    return NULL;
  }

  /* number of characters strictly between the brackets */
  size_t len = (size_t)(closepar - openpar) - 1;

  memcpy(dest, openpar+1, len);
  dest[len] = '\0';

  /* return pointer to the location after the closing bracket */
  return closepar+1;
}
void initialise_threads(LALInferenceRunState *state, INT4 nthreads){
INT4 i,randomseed;
......
......@@ -51,6 +51,9 @@ void add_correlation_matrix( LALInferenceVariables *ini,
void sum_data( LALInferenceRunState *runState );
void LogSampleToFile(LALInferenceVariables *algorithmParams, LALInferenceVariables *vars);
void LogSampleToArray(LALInferenceVariables *algorithmParams, LALInferenceVariables *vars);
REAL8Vector** parse_gmm_means(CHAR *meanstr, UINT4 npars, UINT4 nmodes);
gsl_matrix** parse_gmm_covs(CHAR *covstr, UINT4 npars, UINT4 nmodes);
CHAR* get_bracketed_string(CHAR *dest, const CHAR *bstr, int openbracket, int closebracket);
void initialise_threads(LALInferenceRunState *state, INT4 nthreads);
#ifdef __cplusplus
......
......@@ -512,9 +512,9 @@ REAL8 priorFunction( LALInferenceRunState *runState, LALInferenceVariables *para
}
}
}
/* check if using a 1d Gaussian Mixture Model prior */
else if( LALInferenceCheck1DGMMPrior(runState->priorArgs, item->name) ){
prior += LALInference1DGMMPrior( runState->priorArgs, item->name, value );
/* check if using a Gaussian Mixture Model prior */
else if( LALInferenceCheckGMMPrior(runState->priorArgs, item->name) ){
prior += LALInferenceGMMPrior( runState->priorArgs, item->name, value );
}
/* check for log(uniform) prior */
else if( LALInferenceCheckLogUniformPrior(runState->priorArgs, item->name) ){
......@@ -551,7 +551,7 @@ REAL8 priorFunction( LALInferenceRunState *runState, LALInferenceVariables *para
LALInferenceGetCorrelatedPrior( runState->priorArgs, corlist->data[0], &cor, &invcor, &mu, &sigma, &idx );
/* get the log prior (this only works properly if the parameter values have been prescaled so as to be from a
* Gaussian of zero mean and unit variance, which happens on line 473) */
* Gaussian of zero mean and unit variance, which happens on line 510) */
vals = gsl_vector_view_array( corVals->data, corVals->length );
XLAL_CALLGSL( gsl_blas_dgemv(CblasNoTrans, 1., invcor, &vals.vector, 0., vm) );
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment