Commit d94adeb8 authored by Richard O'Shaughnessy's avatar Richard O'Shaughnessy

convergence_test_samples.py: fix 1d case with 'lame' option

parent 67864704
......@@ -50,6 +50,13 @@ def calc_kl(mu_1, mu_2, sigma_1, sigma_2, sigma_1_inv, sigma_2_inv):
"""
return 0.5*(np.trace(np.dot(sigma_2_inv,sigma_1))+np.dot(np.dot((mu_2-mu_1).T, sigma_2_inv), (mu_2-mu_1))-len(mu_1)+np.log(la.det(sigma_2)/la.det(sigma_1)))
def calc_kl_scalar(mu_1, mu_2, sigma_1, sigma_2):
"""
calc_kl : KL divergence for two gaussians. sigma_1, and sigma_2 are the covariance matricies.
"""
return np.log(sigma_2/sigma_1) -0.5 +( (mu_1-mu_2)**2 + sigma_1**2)/(2*sigma_2**2)
def test_lame(dat1,dat2):
"""
Compute a multivariate gaussian estimate (sample mean and variance), and then use KL divergence between them !
......@@ -58,8 +65,11 @@ def test_lame(dat1,dat2):
mu_2 = np.mean(dat2,axis=0)
sigma_1 = np.cov(dat1.T)
sigma_2 = np.cov(dat2.T)
sigma_1_inv = np.linalg.inv(sigma_1)
sigma_2_inv = np.linalg.inv(sigma_2)
if np.isscalar(mu_1) or len(mu_1)==1:
return np.asscalar(calc_kl_scalar(mu_1, mu_2, sigma_1, sigma_2))
else:
sigma_1_inv = np.linalg.inv(sigma_1)
sigma_2_inv = np.linalg.inv(sigma_2)
return calc_kl(mu_1,mu_2, sigma_1, sigma_2, sigma_1_inv, sigma_2_inv)
def test_ks1d(dat1_1d, dat2_1d):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment