Commit b737add1 authored by Marcella Wijngaarden's avatar Marcella Wijngaarden
Browse files

CBC extrinsic burnin seperated

parent 941f7aa5
/*******************************************************************************************
Copyright (c) 2019 Neil Cornish
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
**********************************************************************************************/
#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <omp.h>
#include "BayesCBC.h"
#include "BayesCBC_internals.h"
#include "Constants.h"
#include "IMRPhenomD.h"
#include <lal/LALDetectors.h>
#include <lal/Date.h>
#include <lal/TimeDelay.h>
#include <lal/DetectorSite.h>
#include <lal/DetResponse.h>
// #include "BayesWave.h"
/*******************************************************************************************
This file contains the extrinsic routines for the BayesCBC code. Unlike main BayesCBC, these
routines do depend on BW/LAL because we want to use the BW/LAL timedelays/projections/etc.- routines
**********************************************************************************************/
void skymcmc(struct Net *net, int MCX, int *mxc, FILE *chain, double **paramx, double **skyx, double **pallx, int *who, double *heat, double dtx, int nt, double *DD, double **WW, double ***DHc, double ***DHs, double ***HH, double Tobs, gsl_rng * r, struct bayesCBC *rundata)
{
int i, j, k, q, ic, id1, id2;
int scount, sacc, hold, mcount;
int ac, rc, rca, clc, cla, PRcnt, POcnt;
int sdx, Ax, mc;
double alpha, beta;
double Mchirp, Mtot, eta, dm, m1, m2, chieff, ciota;
double qxy, qyx, Jack, phi;
int rflag, cflag;
double **sky, **skyy, **skyh;
double x, y, z, DL, scale, logLy;
double *logLx;
double pAy, logH, pAx;
double *param;
double *dtimes, *dtimes2;
double ***fishskyx, ***fishskyy;
double ***skyvecsx, ***skyvecsy;
double **skyevalsx, **skyevalsy;
double Fp, Fc, Fs, ps;
double *jump, *sqH;
double ldetx, ldety;
double scmax, scmin;
double DLx, DLy;
int fflag, fc, fac;
int uflag, uc, uac;
double Ap, Ac, Fcross, Fplus, lambda, lambda2, Fs2;
double sindelta, psi;
int NX = rundata->NX;
int NP = rundata->NP;
int NS = rundata->NS;
int NC = rundata->NC; // number of chains
int NCC = rundata->NCC; // number of cold chains
int NH = rundata->NH; // length of history
int NQ = rundata->NQ; // number of mass ratios in global proposal
int NM = rundata->NM; // number of chirp masses in global proposal
dtimes = (double*)malloc(sizeof(double)*5);
dtimes2 = (double*)malloc(sizeof(double)*5);
param = (double*)malloc(sizeof(double)*(NX+3*net->Nifo));
// sky parameter order
//[0] alpha, [1] sin(delta) [2] psi [3] ciota [4] scale [5] phi0 [6] dt
// max and min of rescaling parameter
scmin = 0.1;
scmax = 10.0;
sky = double_matrix(NC,NS);
skyh = double_matrix(NC,NS);
skyy = double_matrix(NC,NS);
logLx = double_vector(NC);
sqH = double_vector(NC);
for(k = 0; k < NC; k++)
{
for(i = 0; i < NS; i++) sky[k][i] = skyx[k][i];
for(i = 0; i < NS; i++) skyh[k][i] = skyx[k][i];
}
for(k = 0; k < NC; k++) sqH[k] = sqrt(heat[k]);
ic = who[0];
for(k = 0; k < NC; k++) logLx[k] = skylike(net, skyx[k], DD, WW[k], DHc[k], DHs[k], dtx, nt, 0, rundata);
fishskyx = double_tensor(NC,NS,NS);
fishskyy = double_tensor(NC,NS,NS);
skyvecsx = double_tensor(NC,NS,NS);
skyvecsy = double_tensor(NC,NS,NS);
skyevalsx = double_matrix(NC,NS);
skyevalsy = double_matrix(NC,NS);
jump = double_vector(NS);
for(k = 0; k < NC; k++)
{
fisher_matrix_fastsky(net, skyx[k], fishskyx[k], HH[k], NS, rundata->gmst);
FisherEvec(fishskyx[k], skyevalsx[k], skyvecsx[k], NS);
}
printf("Extrinsic MCMC\n");
ac = 0;
rc = 1;
rca = 0;
fc = 1;
uc = 1;
clc = 0;
cla = 0;
fac = 0;
uac = 0;
PRcnt = 0;
POcnt = 0;
sdx = 0.0;
Ax = 0.0;
scount = 0;
sacc = 0;
mcount = 0;
for(mc = 0; mc < MCX; mc++)
{
if(mc > 1 && mc%1000==0)
{
// update the Fisher matrices
for(k = 0; k < NC; k++)
{
fisher_matrix_fastsky(net, skyx[k], fishskyx[k], HH[k], NS, rundata->gmst);
FisherEvec(fishskyx[k], skyevalsx[k], skyvecsx[k], NS);
}
}
alpha = gsl_rng_uniform(r);
if((NC > 1) && (alpha < 0.2)) // decide if we are doing a MCMC update of all the chains or a PT swap
{
// chain swap
scount++;
alpha = (double)(NC-1)*gsl_rng_uniform(r);
j = (int)(alpha);
beta = exp((logLx[who[j]]-logLx[who[j+1]])/heat[j+1] - (logLx[who[j]]-logLx[who[j+1]])/heat[j]);
alpha = gsl_rng_uniform(r);
if(beta > alpha)
{
hold = who[j];
who[j] = who[j+1];
who[j+1] = hold;
sacc++;
}
}
else // MCMC update
{
mcount++;
for(k = 0; k < NC; k++)
{
for(i = 0; i < NS; i++) skyy[k][i] = skyx[k][i];
}
for(k=0; k < NC; k++)
{
q = who[k];
qxy = 0.0;
qyx = 0.0;
Jack = 0.0;
rflag = 0;
cflag = 0;
fflag = 0;
uflag = 0;
alpha = gsl_rng_uniform(r);
// if(alpha > 0.80 && net->Nifo > 1) // ring
if(alpha > 1.0 && net->Nifo > 1) // ring
{
id1 = 0;
id2 = 1;
// Pick a pair of interferometers to define sky ring
if(net->Nifo > 2)
{
id1 = (int)((double)(net->Nifo)*gsl_rng_uniform(r));
do
{
id2 = (int)((double)(net->Nifo)*gsl_rng_uniform(r));
}while(id1==id2);
}
// map these labels to actual detectors
id1 = net->labels[id1];
id2 = net->labels[id2];
// get new sky locations with the same time delays
Ring(net, skyx[q], skyy[q], id1, id2, r, rundata->gmst);
// for(i = 0; i < NS; i++) skyy[q][i] = skyx[q][i];
qyx = 1.0;
qxy = 1.0;
skymap(net, skyx[q], skyy[q], id1, id2, net->labels[0], rundata->gmst);
Jack = log(skydensity(net, skyx[q], skyy[q], id1, id2, net->labels[0], NS, rundata->gmst));
// The mapping only overs half the phi, psi space. Can cover it all by radomly shifting both by a half period
x = gsl_rng_uniform(r);
if(x > 0.5)
{
skyy[q][5] += PI;
skyy[q][2] += PI/2.0;
}
if(k==0) rc++;
rflag = 1;
}
else if (alpha > 0.2) // Fisher matrix
{
if(k==0) fc++;
fflag = 1;
fisher_skyproposal(r, skyvecsx[q], skyevalsx[q], jump, rundata);
for(i=0; i< NS; i++) skyy[q][i] = skyx[q][i]+sqH[k]*jump[i];
// If the Fisher matrix was updated after each Fisher jump we would
// need these proposal densities. Since Fisher held fixed for blocks
// of iterations, we don't need the densities
// pfishxy = fisher_density(fishskyx, ldetx, skyx, skyy);
// fisher_matrix_fastsky(net, skyy, fishskyy, HH);
// fisher_skyvectors(fishskyy, skyvecsy, skyevalsy, &ldety);
// pfishyx = fisher_density(fishskyy, ldety, skyy, skyx);
}
else // jiggle (most useful early when Fisher not effective)
{
uflag = 1;
if(k==0) uc++;
beta = 0.01*pow(10.0, -floor(3.0*gsl_rng_uniform(r)))*sqH[k];
for(i = 0; i < NS-1; i++) skyy[q][i] = skyx[q][i]+beta*gsl_ran_gaussian(r,1.0);
skyy[q][6] = skyx[q][6]+0.01*beta*gsl_ran_gaussian(r,1.0);
}
if(skyy[q][0] > TPI) skyy[q][0] -= TPI;
if(skyy[q][0] < 0.0) skyy[q][0] += TPI;
if(skyy[q][2] > PI) skyy[q][2] -= PI;
if(skyy[q][2] < 0.0) skyy[q][2] += PI;
if(skyy[q][5] > TPI) skyy[q][5] -= TPI;
if(skyy[q][5] < 0.0) skyy[q][5] += TPI;
//[0] alpha, [1] sin(delta) [2] psi [3] cos(iota) [4] scale [5] phi0 [6] dt
DLy = exp(pallx[q][6])/(skyy[q][4]*PC_SI);
if(DLy < rundata->DLmin || DLy > rundata->DLmax || fabs(skyy[q][1]) > 1.0 || fabs(skyy[q][3]) > 1.0 || fabs(skyy[q][6]) > dtmax)
{
logLy = -1.0e60;
pAy = 0.0;
pAx = 0.0;
}
else
{
logLy = skylike(net, skyy[q], DD, WW[q], DHc[q], DHs[q], dtx, nt, 0, rundata);
// Need a Jacobian a factor here since we sample uniformly in amplitude.
// Jacobian between D cos(theta) phi cos(iota) psi and A cos(theta) phi cos(iota) psi
// Since D = D_0/A, boils down to just D^2 |dD/dA| = D_0^3/A^4 = D^3/A
DLx = exp(pallx[q][6])/(skyx[q][4]);
pAx = 3.0*log(DLx)-log(skyx[q][4]);
DLy = exp(pallx[q][6])/(skyy[q][4]);
pAy = 3.0*log(DLy)- log(skyy[q][4]);
}
logH = Jack + (logLy-logLx[q])/heat[k] +pAy-qyx-pAx+qxy;
alpha = log(gsl_rng_uniform(r));
if(logH > alpha)
{
for(i=0; i< NS; i++) skyx[q][i] = skyy[q][i];
logLx[q] = logLy;
if(k==0)
{
ac++;
if(rflag == 1) rca++;
if(fflag == 1) fac++;
if(uflag == 1) uac++;
}
}
} // ends loop over chains
} // ends choice of update
/*
if(mc%100 == 0)
{
ic = who[1];
phi = skyx[ic][0];
if(phi > TPI) phi -= TPI;
if(phi < 0.0) phi += TPI;
skyx[ic][0] = phi;
Mchirp = exp(paramx[ic][0])/MSUN_SI;
Mtot = exp(paramx[ic][1])/MSUN_SI;
eta = pow((Mchirp/Mtot), (5.0/3.0));
dm = sqrt(1.0-4.0*eta);
m1 = Mtot*(1.0+dm)/2.0;
m2 = Mtot*(1.0-dm)/2.0;
chieff = (m1*paramx[ic][2]+m2*paramx[ic][3])/Mtot;
// counter, log likelihood, chirp mass, total mass, effective spin, phase shift , time shift, distance, RA , sine of DEC,
// polarization angle, cos inclination
DL = exp(pallx[ic][6])/(1.0e6*PC_SI*skyx[ic][4]);
z = z_DL(DL);
// Note that skyx[ic][5], skyx[ic][6], hold different quantities than what is printed by the other MCMCs
fprintf(chain,"%d %e %e %e %e %e %e %e %e %e %e %e %e %e %e %e %e %e %e\n", mxc[1], logLx[ic], Mchirp, Mtot, chieff, skyx[ic][5], \
skyx[ic][6], DL, skyx[ic][0], skyx[ic][1], \
skyx[ic][2], skyx[ic][3], z, Mchirp/(1.0+z), Mtot/(1.0+z), m1/(1.0+z), m2/(1.0+z), m1, m2);
mxc[1] += 1;
} */
if(mc%10000 == 0)
{
ic = who[0];
DL = exp(pallx[ic][6])/(1.0e6*PC_SI*skyx[ic][4]);
printf("%d %i %f %f %f %f %f %f\n", mc, ic, logLx[ic], DL, skyx[ic][0], skyx[ic][1], skyx[ic][2], skyx[ic][3]);
}
}
// update the amplitude, time and phase shifts between detectors in preparation for extrinsic updates
for(k=0; k < NC; k++) dshifts(net, skyx[k], paramx[k], NX, rundata);
// sky [0] alpha, [1] sin(delta) [2] psi [3] ciota [4] scale [5] dphi [6] dt
// param [0] log(Mc) [1] log(Mt) [2] chi1 [3] chi2 [4] phi0 [5] tp0 [6] log(DL0) then relative amplitudes, time, phases
// update the extrinsic parameters
for(k = 0; k < NC; k++)
{
// move reference point from geocenter to ref detector
// Note that sky[4],sky[5], sky[6] hold shifts relative to the reference geocenter waveform
// To map back to the reference detector
ciota = skyh[k][3];
Ap = (1.0+ciota*ciota)/2.0;
Ac = ciota;
alpha = skyh[k][0];
sindelta = skyh[k][1];
psi = skyh[k][2];
ComputeDetFant(net, psi, alpha, sindelta, &Fplus, &Fcross, net->labels[0], rundata->gmst);
Fs = sqrt(Ap*Ap*Fplus*Fplus+Ac*Ac*Fcross*Fcross);
lambda = atan2(Ac*Fcross,Ap*Fplus);
if(lambda < 0.0) lambda += TPI;
TimeDelays(net, alpha, sindelta, dtimes, rundata->gmst);
ciota = skyx[k][3];
Ap = (1.0+ciota*ciota)/2.0;
Ac = ciota;
alpha = skyx[k][0];
sindelta = skyx[k][1];
psi = skyx[k][2];
ComputeDetFant(net, psi, alpha, sindelta, &Fplus, &Fcross, net->labels[0], rundata->gmst);
Fs2 = sqrt(Ap*Ap*Fplus*Fplus+Ac*Ac*Fcross*Fcross);
lambda2 = atan2(Ac*Fcross,Ap*Fplus);
if(lambda2 < 0.0) lambda2 += TPI;
TimeDelays(net, alpha, sindelta, dtimes2, rundata->gmst);
paramx[k][4] += 0.5*(skyx[k][5]+lambda2-lambda);
while(paramx[k][4] > PI) paramx[k][4] -= PI;
while(paramx[k][4] < 0.0) paramx[k][4] += PI;
paramx[k][5] += skyx[k][6]+dtimes2[net->labels[0]]-dtimes[net->labels[0]];
paramx[k][6] -= log(skyx[k][4]*Fs2/Fs);
// sky will be re-aligned with geocenter so reset
skyx[k][4] = 1.0;
skyx[k][5] = 0.0;
skyx[k][6] = 0.0;
}
printf("Swap Acceptance = %f\n", (double)sacc/(double)(scount));
printf("MCMC Acceptance = %f\n", (double)ac/(double)(mcount));
printf("Ring Acceptance = %f\n", (double)rca/(double)(rc));
printf("Fisher Acceptance = %f\n", (double)fac/(double)(fc));
printf("Jiggle Acceptance = %f\n", (double)uac/(double)(uc));
free_double_matrix(skyy,NC);
free_double_matrix(skyh,NC);
free_double_matrix(sky,NC);
free_double_vector(logLx);
free_double_vector(sqH);
free_double_tensor(fishskyx,NC,NS);
free_double_tensor(fishskyy,NC,NS);
free_double_tensor(skyvecsx,NC,NS);
free_double_tensor(skyvecsy,NC,NS);
free_double_matrix(skyevalsx,NC);
free_double_matrix(skyevalsy,NC);
free_double_vector(jump);
free(dtimes);
free(dtimes2);
free(param);
return;
}
double skylike(struct Net *net, double *params, double *D, double *H, double **DHc, double **DHs, double dt, int nt, int flag, struct bayesCBC *rundata)
{
int i, j, k, l, id;
double alpha, sindelta, dphi, t0, A, FA, ecc, ciota, psi;
double *dtimes, *F, *lambda;
double tdelay, tx, toff, dc, ds, DH;
double dcc, dss;
double Fplus, Fcross;
double Ap, Ac;
double clam, slam, x;
double cphi, sphi;
double logL;
logL = 0.0;
if(rundata->constantLogLFlag == 0)
{
dtimes = (double*)malloc(sizeof(double)* 5);
F = (double*)malloc(sizeof(double)* 5);
lambda = (double*)malloc(sizeof(double)* 5);
k = (nt-1)/2;
alpha = params[0];
sindelta = params[1];
psi = params[2];
ciota = params[3];
Ap = (1.0+ciota*ciota)/2.0;
Ac = ciota;
A = params[4];
dphi = params[5];
toff = params[6];
cphi = cos(dphi);
sphi = sin(dphi);
TimeDelays(net, alpha, sindelta, dtimes, rundata->gmst);
for(id = 0; id<net->Nifo; id++)
{
ComputeDetFant(net, psi, alpha, sindelta, &Fplus, &Fcross, net->labels[id], rundata->gmst);
// XLALComputeDetAMResponse(&Fplus, &Fcross, (const REAL4(*)[3]) net->response[id], alpha, asin(sindelta), psi, rundata->gmst);
F[id] = sqrt(Ap*Ap*Fplus*Fplus+Ac*Ac*Fcross*Fcross); // magnitude of response
lambda[id] = atan2(Ac*Fcross,Ap*Fplus);
if(lambda[id] < 0.0) lambda[id] += TPI;
}
logL = 0.0;
// Note that F[] and lambda[] have already used the label array to get the correct detector ordering while dtimes uses a fixed ordering
// so needs the label array to get the correct delays
for(id = 0; id<net->Nifo; id++)
{
// everything is reference to geocenter
tdelay = toff+dtimes[net->labels[id]];
i = (int)(floor(tdelay/dt));
tx = (tdelay/dt - (double)(i));
// if (flag == 1) printf("%d %d %d %f %f\n", i, i+k, nt, tx, tdelay/dt);
// printf("toff %f, dtimes[%i] = %f (label= %i)\n", toff, id, dtimes[net->labels[id]], net->labels[id]);
// printf("%d %d %d %f %f\n", i, i+k, nt, tx, tdelay/dt);
// linear interpolation
i += k;
slam = sin(lambda[id]);
clam = cos(lambda[id]);
FA = A*F[id];
if(i >= 0 && i < nt-2)
{
dc = DHc[id][i]*(1.0-tx)+DHc[id][i+1]*tx;
ds = DHs[id][i]*(1.0-tx)+DHs[id][i+1]*tx;
// put in phase rotation
dcc = cphi*dc+sphi*ds;
dss = -sphi*dc+cphi*ds;
DH = clam*dcc+slam*dss;
// relative likelihood
x = -(FA*FA*H[id]-2.0*FA*DH)/2.0;
// if (flag == 1) printf("%d %f %f\n", id, tdelay, x);
// printf("%d %f %f\n", id, tdelay, x);
logL += x;
}
else
{
logL -= 1.0e10;
}
}
free(dtimes);
free(F);
free(lambda);
}
return(logL);