Commit 7f049938 authored by Charlie Hoy

Allow for the Jensen Shannon method to accept different kde classes

parent f695ce22
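In practice the change means the divergence can be evaluated with any KDE implementation, not just scipy's. A minimal sketch of the new call pattern, assuming the function is importable from pesummary.utils.utils (as the test module's `utils` alias suggests) and using illustrative bounds:

    import numpy as np
    from pesummary.utils.utils import jension_shannon_divergence
    from pesummary.core.plots.bounded_1d_kde import Bounded_1d_kde

    samples = [np.random.uniform(4, 5, 1000), np.random.uniform(4, 5, 1000)]

    # previous behaviour, still the default: scipy.stats.gaussian_kde
    jsd = jension_shannon_divergence(samples)

    # new behaviour: pass an alternative KDE class; extra kwargs (here the
    # illustrative bounds xlow/xhigh) are forwarded to that class
    jsd_bounded = jension_shannon_divergence(
        samples, kde=Bounded_1d_kde, xlow=4., xhigh=5.
    )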
@@ -71,9 +71,9 @@ class Bounded_1d_kde(kde):
         out_of_bounds = np.zeros(pts.shape[0], dtype='bool')
         if self.xlow is not None:
-            out_of_bounds[pts[:, 0] < self.xlow] = True
+            out_of_bounds[pts < self.xlow] = True
         if self.xhigh is not None:
-            out_of_bounds[pts[:, 0] > self.xhigh] = True
+            out_of_bounds[pts > self.xhigh] = True

         results = self.evaluate(pts)
         results[out_of_bounds] = 0.
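The hunk above drops the two-dimensional `pts[:, 0]` indexing: for the 1d KDE the evaluation points arrive as a flat array, so the bounds mask can be built on `pts` directly. A standalone sketch of that masking logic, with illustrative values:

    import numpy as np

    pts = np.array([9.4, 9.8, 10.2, 10.6])   # points at which the KDE is evaluated
    xlow, xhigh = 9.5, 10.5                   # support boundaries of the KDE

    # mark every point that falls outside [xlow, xhigh] ...
    out_of_bounds = np.zeros(pts.shape[0], dtype='bool')
    out_of_bounds[pts < xlow] = True
    out_of_bounds[pts > xhigh] = True

    # ... so the estimated density can be zeroed there
    print(out_of_bounds)  # [ True False False  True]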
@@ -742,12 +742,29 @@ def kolmogorov_smirnov_test(samples, decimal=5):
     return np.round(stats.ks_2samp(*samples)[1], decimal)


-def jension_shannon_divergence(samples, decimal=5):
+def jension_shannon_divergence(
+    samples, kde=stats.gaussian_kde, decimal=5, base=np.e, **kwargs
+):
     """Calculate the JS divergence between two sets of samples

     Parameters
     ----------
     samples: list
         2d list containing the samples drawn from two pdfs
+    kde: func
+        function to use when calculating the kde of the samples
     decimal: int, float
         number of decimal places to round the JS divergence to
+    base: float, optional
+        optional base to use for the scipy.stats.entropy function. Default
+        np.e
+    kwargs: dict
+        all kwargs are passed to the kde function
     """
     try:
-        kernel = [stats.gaussian_kde(i) for i in samples]
+        kernel = [kde(i, **kwargs) for i in samples]
     except np.linalg.LinAlgError:
         return float("nan")
     a, b = kernel
     x = np.linspace(
         np.min([np.min(i) for i in samples]),
         np.max([np.max(i) for i in samples]),
@@ -759,8 +776,8 @@ def jension_shannon_divergence(samples, decimal=5):
     a /= a.sum()
     b /= b.sum()
     m = 1. / 2 * (a + b)
-    kl_forward = stats.entropy(a, qk=m)
-    kl_backward = stats.entropy(b, qk=m)
+    kl_forward = stats.entropy(a, qk=m, base=base)
+    kl_backward = stats.entropy(b, qk=m, base=base)
     return np.round(kl_forward / 2. + kl_backward / 2., decimal)
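Read end-to-end, the function builds a KDE per sample set, evaluates both on a shared grid, normalises them into discrete probability vectors, and averages the two KL divergences against their mixture. A self-contained sketch of that recipe, assuming scipy's default KDE; the grid size is a placeholder since the hunk truncates before the `np.linspace` call closes:

    import numpy as np
    from scipy import stats

    def js_divergence_sketch(samples, base=np.e, npoints=100):
        # one KDE per sample set (stats.gaussian_kde here, matching the
        # new `kde` keyword's default)
        a, b = [stats.gaussian_kde(i) for i in samples]
        # shared evaluation grid spanning both sample sets; npoints is an
        # assumption, the diff hunk cuts off before np.linspace ends
        x = np.linspace(
            np.min([np.min(i) for i in samples]),
            np.max([np.max(i) for i in samples]),
            npoints
        )
        a, b = a(x), b(x)
        # normalise to discrete probability vectors and form the mixture
        a /= a.sum()
        b /= b.sum()
        m = (a + b) / 2.
        # JSD = mean of the two KL divergences against the mixture
        return (
            stats.entropy(a, qk=m, base=base)
            + stats.entropy(b, qk=m, base=base)
        ) / 2.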
# Copyright (C) 2018 Charlie Hoy <charlie.hoy@ligo.org>
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

from pesummary.core.plots.bounded_1d_kde import Bounded_1d_kde
from pesummary.gw.plots.bounded_2d_kde import Bounded_2d_kde
from scipy.stats import gaussian_kde
import numpy as np


class TestBounded_kde(object):
    """Test the Bounded_1d_kde and Bounded_2d_kde classes
    """
    def test_bounded_1d_kde(self):
        samples = np.array(np.random.uniform(10, 5, 1000))
        x_low = 9.5
        x_high = 10.5
        scipy = gaussian_kde(samples)
        bounded = Bounded_1d_kde(samples, xlow=x_low, xhigh=x_high)
        assert scipy(9.45) != 0.
        assert bounded(9.45) == 0.
        assert scipy(10.55) != 0.
        assert bounded(10.55) == 0.

    def test_bounded_2d_kde(self):
        samples = np.array([
            np.random.uniform(10, 5, 1000),
            np.random.uniform(5, 2, 1000)
        ])
        x_low = 9.5
        x_high = 10.5
        y_low = 4.5
        y_high = 5.5
        scipy = gaussian_kde(samples)
        bounded = Bounded_2d_kde(
            samples.T, xlow=x_low, xhigh=x_high, ylow=y_low, yhigh=y_high
        )
        assert scipy([9.45, 4.45]) != 0.
        assert bounded([9.45, 4.45]) == 0.
        assert scipy([9.45, 5.55]) != 0.
        assert bounded([9.45, 5.55]) == 0.
        assert scipy([10.55, 4.45]) != 0.
        assert bounded([10.55, 4.45]) == 0.
        assert scipy([10.55, 5.55]) != 0.
        assert bounded([10.55, 5.55]) == 0.
@@ -563,6 +563,31 @@ class TestArray(object):
     )


+def test_jensen_shannon_divergence():
+    """Test that the `jension_shannon_divergence` method returns the same
+    values as the scipy function
+    """
+    from scipy.spatial.distance import jensenshannon
+    from scipy import stats
+
+    samples = [
+        np.random.uniform(5, 4, 100),
+        np.random.uniform(5, 4, 100)
+    ]
+    x = np.linspace(np.min(samples), np.max(samples), 100)
+    kde = [stats.gaussian_kde(i)(x) for i in samples]
+    _scipy = jensenshannon(*kde)**2
+    _pesummary = utils.jension_shannon_divergence(samples, decimal=9)
+    np.testing.assert_almost_equal(_scipy, _pesummary)
+
+    from pesummary.core.plots.bounded_1d_kde import Bounded_1d_kde
+    _pesummary = utils.jension_shannon_divergence(
+        samples, decimal=9, kde=Bounded_1d_kde, xlow=4.5, xhigh=5.5
+    )
+
+
 def test_make_cache_style_file():
     """Test that the `make_cache_style_file` works as expected
     """
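A detail worth keeping in mind when reading the cross-check above: `scipy.spatial.distance.jensenshannon` returns the Jensen-Shannon distance, the square root of the divergence, hence the `**2` before comparison; both it and the new `base=np.e` default work in natural-log units. A quick self-contained check of that relationship, with illustrative probability vectors:

    import numpy as np
    from scipy.spatial.distance import jensenshannon
    from scipy.stats import entropy

    p = np.array([0.1, 0.4, 0.5])
    q = np.array([0.3, 0.3, 0.4])
    m = (p + q) / 2.
    # divergence as the mean of the two KL terms against the mixture
    jsd = (entropy(p, qk=m) + entropy(q, qk=m)) / 2.
    # jensenshannon returns the distance, i.e. the sqrt of the divergence
    np.testing.assert_almost_equal(jensenshannon(p, q) ** 2, jsd)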