Commit 992d2b7b authored by Daniel Williams's avatar Daniel Williams 🤖
Browse files

Merge branch 'science-tests' into 'master'

Science tests

See merge request !10
parents ecdc689c 3f223ecd
Pipeline #130376 failed with stages
in 4 minutes and 41 seconds
......@@ -44,6 +44,15 @@ It's also fairly easy to use the existing framework to implement a new model, us
george
new-model
Testing and verifying models
============================
.. toctree::
:maxdepth: 2
:caption: Verification
verification
Indices and tables
==================
......
==================
Model verification
==================
The ``heron`` package includes tools for verifying the output of a model against either other models or waveforms in a numerical relativity catalogue.
These allow the easy calculation of both in-sample and out-of-sample tests.
.. autoclass:: heron.testing
......@@ -32,6 +32,48 @@ def match(a, b, psd=None):
return pycbc.filter.match(data_a, data_b)
def sample_match(generator, times, p, comparison, psd=None):
"""
Calculate the match between the output of a model and a canonical waveform.
Parameters
----------
generator : `heron.model`
The heron model to be tested.
times : ndarray
An array of times at which the model should be evaluated.
p : dict
A dictionary of parameters for the waveform.
comparison : `elk.waveform`
The waveform which should be compared to the model output
psd : `pycbc.psd`, optional
The PSD which should be used to evaluate the waveform match.
"""
ts_data = generator.mean(p=p.copy(), times=times.copy())[0]
return match(ts_data, comparison, psd)[0]
def nrcat_match(generator, catalogue):
"""
Calculate the matches between each waveform in a given waveform catalogue
and the generator model.
"""
matches = {}
for waveform in catalogue.waveforms:
spins = ["spin 1x", "spin 1y", "spin 1z", "spin 2x", "spin 2y", "spin 2z"]
pars = dict(zip(spins, waveform.spins))
pars['mass ratio'] = waveform.mass_ratio
nr_data = waveform.timeseries(total_mass=60, f_low=70, t_max=0.02, t_min=-0.015)
matches[waveform.tag] = heron.testing.sample_match(generator,
nr_data[0].times,
pars,
nr_data[0])
return matches
def outsample_retrain(generator, catalogue = NRCatalogue('GeorgiaTech')):
"""
Calculate the out-sample matches between a given heron model
......@@ -49,18 +91,12 @@ def outsample_retrain(generator, catalogue = NRCatalogue('GeorgiaTech')):
for waveform in catalogue.waveforms:
try:
print(waveform.tag)
if waveform.tag in ["GT063{}".format(i) for i in range(10)] + ["GT0701", "GT0881", "GT0833"]:
continue
if waveform.tag in ["GT063{}".format(i) for i in range(10)] + ["GT0701"]:
continue
if waveform.tag in ["GT0548", "GT0639"]: continue
spins = ["spin 1x", "spin 1y", "spin 1z", "spin 2x", "spin 2y", "spin 2z"]
pars = dict(zip(spins, waveform.spins))
pars['mass ratio'] = waveform.mass_ratio
waveform_nr = waveform.timeseries(total_mass=60, f_low=70, t_max=0.02, t_min=-0.015)
waveform_nr = waveform.timeseries(total_mass=60, f_low=70, t_max=0.02, t_min=-0.015)
times = waveform_nr[0].times
new_catalogue = copy.copy(catalogue)
......
"""
Tests for the science testing code.
"""
import unittest
import numpy as np
from heron import testing
from elk.waveform import NRWaveform, Timeseries
from elk.catalogue import NRCatalogue
from heron.models.georgebased import HeronHodlr
class MockWaveform(NRWaveform):
def timeseries(self,
total_mass,
sample_rate=4096,
f_low=None,
distance=1,
coa_phase=0,
ma=None,
t_min=None,
t_max=None,
f_ref=None,
t_align=True):
return (Timeseries(data=np.random.randn(1000)*1e-19, times=np.linspace(t_min, t_max, 1000)),
Timeseries(data=np.random.randn(1000)*1e-19, times=np.linspace(t_min, t_max, 1000)))
class TestTests(unittest.TestCase):
"""
Test the science testing code.
"""
def setUp(self):
self.model = HeronHodlr()
self.samples_catalogue = NRCatalogue("GeorgiaTech")
mock_waveforms = [
MockWaveform("spam", {"q": 1.0,
"tag": "test",
"mass_ratio": 1.0,
"spin_1x": 0, "spin_1y": 0, "spin_1z": 0,
"spin_2x": 0, "spin_2y": 0, "spin_2z": 0,
"s1x": 0, "s1y": 0, "s1z": 0,
"s2x": 0, "s2y": 0, "s2z": 0
}),
MockWaveform("eggs", {"q": 0.8,
"tag": "test2",
"mass_ratio": 1.0,
"spin_1x": 0, "spin_1y": 0, "spin_1z": 0,
"spin_2x": 0, "spin_2y": 0, "spin_2z": 0,
"s1x": 0, "s1y": 0, "s1z": 0,
"s2x": 0, "s2y": 0, "s2z": 0
})
]
self.samples_catalogue.waveforms = mock_waveforms
def test_nrcat_match(self):
"""Test the NR catalogue matcher."""
matches = testing.nrcat_match(self.model, self.samples_catalogue)
self.assertEqual(len(matches.values()), 2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment