Commit 9e4e480a authored by Benjamin Farr's avatar Benjamin Farr
Browse files

lalinference_mcmc: switched tempMax setting to SNR-based criteria

Original: a141eba661ef047bbdde0e9532545aaf4ca5d89e
parent ba843a9d
......@@ -147,6 +147,7 @@ void initializeMCMC(LALInferenceRunState *runState)
char help[]="\
(--Niter N) Number of iterations(2*10^6)\n\
(--Nskip n) Number of iterations between disk save(100)\n\
(--trigSNR SNR) Network SNR from trigger, used to calculate tempMax\n\
(--tempMin T) Lowest temperature for parallel tempering(1.0)\n\
(--tempMax T) Highest temperature for parallel tempering(50.0)\n\
(--randomseed seed) Random seed of sampling distribution(random)\n\
......@@ -165,6 +166,7 @@ void initializeMCMC(LALInferenceRunState *runState)
INT4 verbose=0,tmpi=0;
unsigned int randomseed=0;
REAL8 trigSNR = 0.0;
REAL8 tempMin = 1.0;
REAL8 tempMax = 50.0;
ProcessParamsTable *commandLine=runState->commandLine;
......@@ -269,6 +271,14 @@ void initializeMCMC(LALInferenceRunState *runState)
}
LALInferenceAddVariable(runState->algorithmParams,"Nskip",&tmpi, LALINFERENCE_UINT4_t,LALINFERENCE_PARAM_FIXED);
printf("set trigger SNR.\n");
/* Network SNR of trigger */
ppt=LALInferenceGetProcParamVal(commandLine,"--trigSNR");
if(ppt){
trigSNR=strtod(ppt->value,(char **)NULL);
}
LALInferenceAddVariable(runState->algorithmParams,"trigSNR",&trigSNR,LALINFERENCE_REAL8_t,LALINFERENCE_PARAM_FIXED);
printf("set lowest temperature.\n");
/* Minimum temperature of the temperature ladder */
ppt=LALInferenceGetProcParamVal(commandLine,"--tempMin");
......
......@@ -192,6 +192,7 @@ void PTMCMCAlgorithm(struct tagLALInferenceRunState *runState)
REAL8 tempCurrentPrior = 0.0;
REAL8 tempCurrentLikelihood = 0.0;
REAL8 priorMin, priorMax, dprior;
REAL8 tempDelta = 0.0;
REAL8Vector * parameters = NULL;
LALInferenceVariables tempCurrentParams;
LALInferenceVariables flatPriorTestParams;
......@@ -240,11 +241,18 @@ void PTMCMCAlgorithm(struct tagLALInferenceRunState *runState)
ptr=ptr->next;
}
/* Determine network SNR if injection was done */
REAL8 networkSNRsqrd = 0.0;
LALInferenceIFOData *IFO = runState->data;
while (IFO != NULL) {
networkSNRsqrd += IFO->SNR * IFO->SNR;
IFO = IFO->next;
}
/* Adaptation settings */
REAL8 s_gamma = 1.0;
INT4 adaptStart = 0;
INT4 adaptTau = *((INT4 *)LALInferenceGetVariable(runState->proposalArgs, "adaptTau"));
REAL8 s_gamma = 1.0; // Sets the size of changes to jump size during adaptation
INT4 adaptStart = 0; // Keeps track of last iteration adaptation was restarted
INT4 adaptTau = *((INT4 *)LALInferenceGetVariable(runState->proposalArgs, "adaptTau")); // Sets the length and slope of adaption function
INT4 adaptResetBuffer = 100; // Number of iterations before adapting after a restart
INT4 adaptationLength = pow(10,adaptTau); // Number of iterations to adapt before turning off
ppt=LALInferenceGetProcParamVal(runState->commandLine, "--acceptanceRatio");
......@@ -253,23 +261,38 @@ void PTMCMCAlgorithm(struct tagLALInferenceRunState *runState)
}
ppt=LALInferenceGetProcParamVal(runState->commandLine, "--adapt");
if (ppt) {
adaptationOn = 1;
adapting = 1;
adaptationOn = 1; // Flag to indicate adaptation is being used during the run
adapting = 1; // Flag to indicate the current steps are being adapted
}
/* Temperature ladder settings */
INT4 checksOfPrior = 10000; // Number of random values to check for each parameter to test prior function for flatness
INT4 hotThreshold = 2; // If MPIrank > hotThreshold, use different "hot" jump proposals
INT4 TmaxSearchLen = 10000; // Lentgh of mini-MCMC testing for ideal maximum temperature
REAL8 tempMin = *(REAL8*) LALInferenceGetVariable(runState->algorithmParams, "tempMin"); //min temperature in the temperature ladder
REAL8 tempMax = *(REAL8*) LALInferenceGetVariable(runState->algorithmParams, "tempMax"); //max temperature in the temperature ladder
REAL8 tempSearchLow = tempMin; // Lower bound for maximum temperature search
REAL8 tempSearchHigh = tempMax; // Upper bound for maximum temperature search
INT4 nFlatPriorBins = 10; // Number of bins for params w/ flat priors for max temp test
REAL8 flatPriorTolerance = .3; // Percentage of tolerance allowed in each bin when testing for flatness
REAL8 tempDelta = (tempSearchHigh-tempSearchLow)/(REAL8)(nChain-1);
INT4 flatBinLow = (INT4) (TmaxSearchLen / nFlatPriorBins * (1.0 - flatPriorTolerance));
INT4 flatBinHigh = (INT4) (TmaxSearchLen / nFlatPriorBins * (1.0 + flatPriorTolerance));
REAL8 tempMin = *(REAL8*) LALInferenceGetVariable(runState->algorithmParams, "tempMin"); // Min temperature in ladder
REAL8 tempMax = 0.0;
REAL8 trigSNR = 0.0;
REAL8 targetHotLike = 25; // Targeted max 'experienced' log(likelihood) of hottest chain
INT4 hotThreshold = nChain/2; // If MPIrank > hotThreshold, use proposals with higher acceptance rates for hot chains
/* Set maximum temperature (command line value take precidence) */
if (LALInferenceGetProcParamVal(runState->commandLine,"--tempMax")) {
tempMax = *(REAL8*) LALInferenceGetVariable(runState->algorithmParams, "tempMax");
} else if (LALInferenceGetProcParamVal(runState->commandLine,"--trigSNR")) {
trigSNR = *(REAL8*) LALInferenceGetVariable(runState->algorithmParams, "trigSNR");
networkSNRsqrd = trigSNR * trigSNR;
tempMax = networkSNRsqrd/(2*targetHotLike); // If trigSNR specified, choose max temp so targetHotLike is achieved
if(MPIrank==0)
fprintf(stdout,"Trigger SNR of %f specified, setting tempMax to %f.\n", trigSNR, tempMax);
} else if (networkSNRsqrd > 0.0) {
tempMax = networkSNRsqrd/(2*targetHotLike); // If injection, choose max temp so targetHotLike is achieved
if(MPIrank==0)
fprintf(stdout,"Injecting SNR of %f, setting tempMax to %f.\n", sqrt(networkSNRsqrd), tempMax);
} else {
tempMax = *(REAL8*) LALInferenceGetVariable(runState->algorithmParams, "tempMax"); // Otherwise use default
if(MPIrank==0)
fprintf(stdout,"No --trigSNR or --tempMax specified, and not injecting a signal. Setting tempMax to default of %f.\n", tempMax);
}
LALInferenceSetVariable(runState->algorithmParams, "tempMax", &tempMax);
/* Annealing settings */
INT4 startAnnealing = 500000; // Iteration where annealing starts
......@@ -277,13 +300,13 @@ void PTMCMCAlgorithm(struct tagLALInferenceRunState *runState)
ppt=LALInferenceGetProcParamVal(runState->commandLine, "--noAnneal");
if (ppt) {
annealingOn = 0;
annealingOn = 0; // Flag to indicate annealing is being used during the run
}
/* Parallel tempering settings */
INT4 nSwaps = (nChain-1)*nChain/2; // Number of proposed swaps between temperatures
INT4 Tskip = 100; // Number of iterations between temperature swaps proposals
INT4 Tkill = startAnnealing+annealLength; // Iterature where parallel tempering ends
INT4 nSwaps = (nChain-1)*nChain/2; // Number of proposed swaps between temperatures in one swap iteration
INT4 Tskip = 100; // Number of iterations between proposed temperature swaps
INT4 Tkill = startAnnealing+annealLength; // Iterature where parallel tempering ends
if (LALInferenceGetProcParamVal(runState->commandLine,"--tempSkip"))
Tskip = atoi(LALInferenceGetProcParamVal(runState->commandLine,"--tempSkip")->value);
......@@ -342,177 +365,6 @@ void PTMCMCAlgorithm(struct tagLALInferenceRunState *runState)
LALInferenceAddVariable(runState->proposalArgs, "hotChain", &hotChain, LALINFERENCE_UINT4_t, LALINFERENCE_PARAM_OUTPUT);
REAL8 logLAtAdaptStart = runState->currentLikelihood;
if (!LALInferenceGetProcParamVal(runState->commandLine, "--noTempSearch") && nChain > 1) {
/*
* Determine how high of a temperature is needed to recover the prior.
*
* A linear temperature ladder is constructed, and the parameters with
* flat priors are binned and checked for flatness.
*/
tMaxSearch = 1;
LALInferenceSetVariable(runState->proposalArgs, "tMaxSearch", &(tMaxSearch));
/* Save values for after temperature testing */
tempCurrentParams.head = NULL;
tempCurrentParams.dimension = 0;
LALInferenceCopyVariables(runState->currentParams, &tempCurrentParams);
tempCurrentPrior = runState->currentPrior;
tempCurrentLikelihood = runState->currentLikelihood;
tempPrior = runState->prior;
tempProposal = runState->proposal;
/* Find parameters with flat prior */
INT4 nFlatPar = nPar;
flatPriorTestParams.head = NULL;
flatPriorTestParams.dimension = 0;
flatPriorParams.head = NULL;
flatPriorParams.dimension = 0;
runState->prior = &LALInferenceInspiralPriorNormalised;
runState->currentPrior = runState->prior(runState, runState->currentParams);
flatPriorTestVal = runState->currentPrior;
LALInferenceCopyVariables(runState->currentParams, &flatPriorParams);
for(p=0;p<nPar;++p){
LALInferenceCopyVariables(runState->currentParams, &flatPriorTestParams);
name = LALInferenceGetVariableName(runState->currentParams, (p+1));
sprintf(nameMin, "%s_min", name);
sprintf(nameMax, "%s_max", name);
priorMin = *((REAL8 *)LALInferenceGetVariable(runState->priorArgs, nameMin));
priorMax = *((REAL8 *)LALInferenceGetVariable(runState->priorArgs, nameMax));
dprior = priorMax - priorMin;
for(x=0;x<checksOfPrior;++x){
randVal = gsl_rng_uniform(runState->GSLrandom);
paramVal = priorMin + randVal * (priorMax - priorMin);
LALInferenceSetVariable(&flatPriorTestParams, name, &paramVal);
flatPriorTestVal = runState->prior(runState, &flatPriorTestParams);
if(flatPriorTestVal != runState->currentPrior){
LALInferenceRemoveVariable(&flatPriorParams, name);
nFlatPar -= 1;
break;
}
}
}
/* Construct temporary linear temperature ladder to probe for best max temp */
for(t=0; t<nChain; ++t){
tempLadder[t]=tempSearchLow+t*tempDelta;
}
LALInferenceSetVariable(runState->proposalArgs, "temperature", &(tempLadder[MPIrank]));
/* Use specialized jump proposal and run a short MCMC */
if(MPIrank==0)
fprintf(stdout,"Running exploratory MCMC to determine best temperature ladder.\n");
runState->proposal = &LALInferencePTTempTestProposal;
runState->prior = &LALInferenceInspiralPriorNormalised;
pdf=(INT4**)calloc(nPar,sizeof(INT4 *));
while (tMaxSearch == 1) {
for(p=0;p<nPar;++p){
pdf[p]=calloc(10,sizeof(INT4));
for(x=0;x<10;++x){
pdf[p][x]=0;
}
}
for(i=0;i<TmaxSearchLen;++i){
PTMCMCOneStep(runState);
ptr=runState->currentParams->head;
p=0;
while(ptr!=NULL) {
if (ptr->vary != LALINFERENCE_PARAM_FIXED) {
parameters->data[p]=*(REAL8 *)ptr->value;
p++;
}
ptr=ptr->next;
}
/* Bin parameterm values */
for (p=0;p<nPar;++p){
name = LALInferenceGetVariableName(runState->currentParams, (p+1));
sprintf(nameMin, "%s_min", name);
sprintf(nameMax, "%s_max", name);
priorMin = *((REAL8 *)LALInferenceGetVariable(runState->priorArgs, nameMin));
priorMax = *((REAL8 *)LALInferenceGetVariable(runState->priorArgs, nameMax));
dprior = priorMax - priorMin;
x=(int)(((parameters->data[p] - priorMin)/dprior)*nFlatPriorBins);
if(x<0) x=0;
if(x>nFlatPriorBins-1) x=nFlatPriorBins-1;
pdf[p][x]++;
}
}//for(i=0;i<TmaxSearchLen;++i)
/* Check for flat PDFs in parameters w/ flat priors */
param_count=0;
for (p=0;p<nPar;++p){
name = LALInferenceGetVariableName(runState->currentParams, (p+1));
if(LALInferenceCheckVariable(&flatPriorParams, name)){
pdf_count=0;
for(x=0;x<nFlatPriorBins;++x){
if(pdf[p][x]<flatBinLow || pdf[p][x]>flatBinHigh) pdf_count++;
}
if(pdf_count==0) param_count++;
}
}
if (param_count == nFlatPar) {
acceptanceCount = 1;
} else {
acceptanceCount = 0;
}
MPI_Allgather(&acceptanceCount, 1, MPI_INT, acceptanceCountLadder, 1, MPI_INT, MPI_COMM_WORLD);
UINT4 recoveredPrior = 0;
tempMax = tempLadder[nChain-1];
for (i=0;i<nChain;++i) {
if (acceptanceCountLadder[i]) {
if (!recoveredPrior) {
recoveredPrior = 1;
tempMax = tempLadder[i];
tMaxSearch = 0;
}
} else {
if (recoveredPrior) {
if(MPIrank==0)
fprintf(stdout,"Inconsistent temperature performance, possibly due to stuck chain. Re-running exploritory MCMC");
recoveredPrior = 0;
tempMax = tempLadder[nChain-1];
tMaxSearch = 1;
break;
}
}
}
MPI_Barrier(MPI_COMM_WORLD);
} //while (tMaxSearch == 1)
if (tempMax == tempLadder[nChain-1])
fprintf(stdout,"WARNING: The search set max temperature to the maximum allowed temperature (%f). \
This may be insufficient. Recommend allowing higher temperatures using --Tmax=<Tmax>.\n",tempMax);
LALInferenceSetVariable(runState->algorithmParams, "tempMax", &(tempMax));
LALInferenceSetVariable(runState->proposalArgs, "tMaxSearch", &(tMaxSearch));
/* Reset values to those before temperature search */
LALInferenceCopyVariables(&tempCurrentParams, runState->currentParams);
runState->currentLikelihood = tempCurrentLikelihood;
runState->prior = tempPrior;
runState->currentPrior = tempCurrentPrior;
runState->proposal = tempProposal;
acceptanceCount = 0;
LALInferenceSetVariable(runState->proposalArgs, "acceptanceCount", &(acceptanceCount));
LALInferenceDeleteProposalCycle(runState);
for (t=0; t<nChain; ++t) {
acceptanceCountLadder[t] = 0;
}
}//if(!LALInferenceGetProcParamVal(runState->commandLine, "--noTempSearch") && nChain>1)
/* Construct temperature ladder */
if(nChain > 1){
if(LALInferenceGetProcParamVal(runState->commandLine, "--inverseLadder")){ //temperature spacing uniform in 1/T
......@@ -523,11 +375,9 @@ void PTMCMCAlgorithm(struct tagLALInferenceRunState *runState)
}
else if(LALInferenceGetProcParamVal(runState->commandLine, "--geomLadder")){ //Geometric spacing (most efficient so far. Should become default?
tempDelta=pow(tempMax,1.0/(REAL8)(nChain-1));
//tempDelta=pow(tempMax,1.0/(REAL8)(nChain+4-1));
for (t=0;t<nChain; ++t) {
tempLadder[t]=pow(tempDelta,t);
annealDecay[t] = (tempLadder[t]-1)/(REAL8)annealLength;
//annealDecay[t] = t*(REAL8)annealLength / (((t-1)/(nChain-1))*log(tempMax - 1)+1);
}
}
else{ //epxonential spacing
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment