import matplotlib as mpl
mpl.use('Agg')
import bilby
from pandas import read_csv
import corner
import matplotlib.pyplot as plt
import numpy as np
import json
import sys
import pdb
from scipy.stats import ks_2samp

def _read_lalinference_samples(path):
    """Read a whitespace-delimited samples table whose first row is the header."""
    # sep=r'\s+' is the documented equivalent of the deprecated
    # delim_whitespace=True option.
    return read_csv(path, header=0, sep=r'\s+')


def _ks_summary(samples_a, samples_b):
    """Return the two-sample KS result as a JSON-friendly [statistic, pvalue]."""
    statistic, pvalue = ks_2samp(samples_a, samples_b)
    return [float(statistic), float(pvalue)]


def _save_corner_overlay(layers, labels, filename, plotting_kwargs):
    """Overlay several sample sets on one corner plot and save it.

    ``layers`` is a sequence of ``(samples, color)`` pairs, where ``samples``
    is a list holding one sequence of samples per parameter.  The first layer
    creates the figure and carries the axis labels; later layers are drawn
    onto that same figure.
    """
    fig = None
    for samples, color in layers:
        extra = dict(labels=labels) if fig is None else dict(fig=fig)
        fig = corner.corner(np.transpose(samples), color=color,
                hist_kwargs=dict(density=True, color=color),
                **extra, **plotting_kwargs)
    # Save the explicit figure (not pyplot's implicit "current" one) and
    # release it so figures do not pile up across detectors.
    fig.savefig(filename)
    plt.close(fig)


def _dump_json(obj, filename):
    """Write ``obj`` to ``filename`` as JSON, closing the file afterwards."""
    with open(filename, 'w') as handle:
        json.dump(obj, handle)


def main(bilby_result, lalinference_result, bilby_prior_samples=None,
        lalinference_prior_samples=None):
    """Compare calibration spline posteriors from a bilby and a LALInference run.

    For every detector listed in the bilby result's
    ``meta_data['command_line_args']['detectors']`` this overlays the
    amplitude and phase posteriors of both runs (plus prior samples, when
    given) in corner plots and runs a two-sample KS test per spline node.

    Parameters
    ----------
    bilby_result : str
        Path to the bilby result JSON file.
    lalinference_result : str
        Path to the whitespace-delimited LALInference posterior samples file
        (first row is the header).
    bilby_prior_samples : str, optional
        Path to a JSON file that maps the bilby parameter names to prior
        samples.
    lalinference_prior_samples : str, optional
        Path to a whitespace-delimited LALInference prior samples file.

    Outputs (written relative to the current working directory)
    -----------------------------------------------------------
    ``<ifo>_amplitude_post.png`` and ``<ifo>_phase_post.png`` for each
    detector, ``ks_stats.json``, ``spline_points.json`` and, only when both
    prior files are given, ``ks_stats_prior.json``.

    Returns ``None``.  If the two runs used a different number of spline
    points a message is printed and nothing is written.
    """
    have_bilby_prior = bilby_prior_samples is not None
    have_lalinf_prior = lalinference_prior_samples is not None

    #load in the data
    bilby_data = bilby.gw.result.CBCResult.from_json(filename=bilby_result)
    lalinf_data = _read_lalinference_samples(lalinference_result)
    if have_bilby_prior:
        with open(bilby_prior_samples, 'r') as handle:
            bilby_prior_data = json.load(handle)
    if have_lalinf_prior:
        lalinf_prior_data = _read_lalinference_samples(
                lalinference_prior_samples)

    #check if the two runs have the same number of spline points
    # NOTE: the count is taken from the H1 columns only and is then reused
    # for every detector in the loop below.
    n_spline_points_bilby = len([i for i in bilby_data.posterior.keys() if
        i.startswith('recalib_H1_frequency')])
    n_spline_points_lalinf = len([i for i in lalinf_data.keys() if
        i.startswith('h1_spcal_freq')])
    if n_spline_points_bilby != n_spline_points_lalinf:
        print('Bilby run used {} spline points while LALInference had {}.\n'.format(n_spline_points_bilby,n_spline_points_lalinf)+
                'Unable to continue with comparison.')
        return

    #set the default plotting settings
    plotting_kwargs = dict(bins=50, smooth=0.9, label_kwargs=dict(fontsize=16),
            title_kwargs=dict(fontsize=16),
            truth_color='tab:orange', quantiles=[0.16, 0.84],
            levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)),
            plot_density=False, plot_datapoints=False, no_fill_contours=True,
            max_n_ticks=3, show_titles=True)

    #get the calibration posteriors by detector
    spline_points_bilby = dict()
    spline_points_lalinf = dict()
    ks_stats_amplitude = dict()
    ks_stats_phase = dict()
    ks_stats_amplitude_prior = dict()
    ks_stats_phase_prior = dict()
    for ifo in bilby_data.meta_data['command_line_args']['detectors']:
        bilby_prefix = 'recalib_' + ifo + '_'
        lalinf_prefix = ifo.lower() + '_spcal_'

        spline_points_bilby[ifo] = []
        spline_points_lalinf[ifo] = []
        ks_stats_amplitude[ifo] = []
        ks_stats_phase[ifo] = []
        ks_stats_amplitude_prior[ifo] = []
        ks_stats_phase_prior[ifo] = []
        amplitude_post_bilby = []
        amplitude_post_lalinf = []
        phase_post_bilby = []
        phase_post_lalinf = []
        amplitude_prior_bilby = []
        phase_prior_bilby = []
        amplitude_prior_lalinf = []
        phase_prior_lalinf = []
        amplitude_labels = []
        phase_labels = []
        for i in range(n_spline_points_bilby):
            node = str(i)
            # The node frequency is read from the first sample row
            # (presumably it is identical in every row -- confirm).  .iloc is
            # positional, so it does not depend on the index labels, and
            # float() keeps the value JSON-serializable.
            spline_points_bilby[ifo].append(float(
                bilby_data.posterior[bilby_prefix + 'frequency_' + node].iloc[0]))
            spline_points_lalinf[ifo].append(float(
                lalinf_data[lalinf_prefix + 'freq_' + node].iloc[0]))
            amplitude_post_bilby.append(
                bilby_data.posterior[bilby_prefix + 'amplitude_' + node])
            amplitude_post_lalinf.append(
                lalinf_data[lalinf_prefix + 'amp_' + node])
            phase_post_bilby.append(
                bilby_data.posterior[bilby_prefix + 'phase_' + node])
            phase_post_lalinf.append(
                lalinf_data[lalinf_prefix + 'phase_' + node])
            if have_bilby_prior:
                amplitude_prior_bilby.append(
                    bilby_prior_data[bilby_prefix + 'amplitude_' + node])
                phase_prior_bilby.append(
                    bilby_prior_data[bilby_prefix + 'phase_' + node])
            if have_lalinf_prior:
                amplitude_prior_lalinf.append(
                    lalinf_prior_data[lalinf_prefix + 'amp_' + node])
                phase_prior_lalinf.append(
                    lalinf_prior_data[lalinf_prefix + 'phase_' + node])
            amplitude_labels.append(ifo + '_amplitude_' + node)
            phase_labels.append(ifo + '_phase_' + node)

            #perform ks test
            ks_stats_amplitude[ifo].append(_ks_summary(
                amplitude_post_bilby[-1], amplitude_post_lalinf[-1]))
            ks_stats_phase[ifo].append(_ks_summary(
                phase_post_bilby[-1], phase_post_lalinf[-1]))

            if have_bilby_prior and have_lalinf_prior:
                ks_stats_amplitude_prior[ifo].append(_ks_summary(
                    amplitude_prior_bilby[-1], amplitude_prior_lalinf[-1]))
                ks_stats_phase_prior[ifo].append(_ks_summary(
                    phase_prior_bilby[-1], phase_prior_lalinf[-1]))

        #plot the two posteriors for amplitude and phase separately
        # Colors: bilby posterior red, LALInference posterior blue,
        # LALInference prior cyan, bilby prior green.
        amplitude_layers = [(amplitude_post_bilby, 'r'),
                (amplitude_post_lalinf, 'b')]
        phase_layers = [(phase_post_bilby, 'r'), (phase_post_lalinf, 'b')]
        if have_lalinf_prior:
            amplitude_layers.append((amplitude_prior_lalinf, 'c'))
            phase_layers.append((phase_prior_lalinf, 'c'))
        if have_bilby_prior:
            amplitude_layers.append((amplitude_prior_bilby, 'g'))
            phase_layers.append((phase_prior_bilby, 'g'))
        _save_corner_overlay(amplitude_layers, amplitude_labels,
                ifo + '_amplitude_post.png', plotting_kwargs)
        _save_corner_overlay(phase_layers, phase_labels,
                ifo + '_phase_post.png', plotting_kwargs)

    #save the spline points and stats files
    _dump_json({'amplitude': ks_stats_amplitude, 'phase': ks_stats_phase},
            'ks_stats.json')
    if have_bilby_prior and have_lalinf_prior:
        # Write the *prior* statistics here, not the posterior ones.
        _dump_json({'amplitude': ks_stats_amplitude_prior,
            'phase': ks_stats_phase_prior}, 'ks_stats_prior.json')
    _dump_json({'bilby': spline_points_bilby, 'lalinf': spline_points_lalinf},
            'spline_points.json')

if __name__ == "__main__":
    # Command line:
    #   BILBY_RESULT LALINFERENCE_RESULT [BILBY_PRIOR [LALINFERENCE_PRIOR]]
    # The two result files are required; up to two prior-sample files may
    # follow and are forwarded positionally to main().
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        # Exit with a readable message instead of an IndexError.
        sys.exit('usage: {} BILBY_RESULT LALINFERENCE_RESULT '
                 '[BILBY_PRIOR_SAMPLES [LALINFERENCE_PRIOR_SAMPLES]]'.format(sys.argv[0]))
    # Anything beyond the fourth argument is ignored, as before.
    main(*cli_args[:4])