Commit 76f4e5b0 authored by Gregory Ashton's avatar Gregory Ashton

Merge branch 'waveform-plot' into 'master'

add function to make waveform plot

See merge request !484
parents d678a13e b537889d
Pipeline #64120 passed with stages
in 6 minutes and 5 seconds
......@@ -9,9 +9,12 @@ from matplotlib import rcParams
import numpy as np
from ..core.result import Result as CoreResult
from ..core.utils import logger, check_directory_exists_and_if_not_mkdir
from .utils import (plot_spline_pos,
spline_angle_xform)
from ..core.utils import infft, logger, check_directory_exists_and_if_not_mkdir
from .utils import plot_spline_pos, spline_angle_xform
from .waveform_generator import WaveformGenerator
from .detector import get_empty_interferometer, Interferometer
from .source import lal_binary_black_hole
from .conversion import convert_to_lal_binary_black_hole_parameters
class CompactBinaryCoalescenceResult(CoreResult):
......@@ -66,12 +69,24 @@ class CompactBinaryCoalescenceResult(CoreResult):
return self.__get_from_nested_meta_data(
'likelihood', 'distance_marginalization')
@property
def interferometers(self):
""" List of interferometer names """
return [name for name in self.__get_from_nested_meta_data(
'likelihood', 'interferometers')]
@property
def waveform_approximant(self):
""" String of the waveform approximant """
return self.__get_from_nested_meta_data(
'likelihood', 'waveform_arguments', 'waveform_approximant')
@property
def waveform_arguments(self):
""" Dict of waveform arguments """
return self.__get_from_nested_meta_data(
'likelihood', 'waveform_arguments')
@property
def reference_frequency(self):
""" Float of the reference frequency """
......@@ -184,6 +199,161 @@ class CompactBinaryCoalescenceResult(CoreResult):
fig.savefig(filename, bbox_inches='tight')
plt.close(fig)
def plot_waveform_posterior(
self, interferometers=None, level=0.9, n_samples=None):
"""
Plot the posterior for the waveform in the frequency domain and
whitened time domain for all detectors.
If the strain data is passed that will be plotted.
If injection parameters can be found, the injection will be plotted.
Parameters
----------
interferometers: (list, bilby.gw.detector.InterferometerList, optional)
level: float, optional
symmetric confidence interval to show, default is 90%
n_samples: int, optional
number of samples to use to calculate the median/interval
default is all
Returns
-------
fig: figure-handle, only is save=False
"""
if interferometers is None:
interferometers = self.interferometers
elif not isinstance(interferometers, list):
raise TypeError(
'interferometers must be a list of InterferometerList')
for ifo in interferometers:
self.plot_interferometer_waveform_posterior(
interferometer=ifo, level=level, n_samples=n_samples, save=True)
def plot_interferometer_waveform_posterior(
self, interferometer, level=0.9, n_samples=None, save=True):
"""
Plot the posterior for the waveform in the frequency domain and
whitened time domain.
If the strain data is passed that will be plotted.
If injection parameters can be found, the injection will be plotted.
Parameters
----------
interferometer: (str, bilby.gw.detector.interferometer.Interferometer)
detector to use, if an Interferometer object is passed the data
will be overlaid on the posterior
level: float, optional
symmetric confidence interval to show, default is 90%
n_samples: int, optional
number of samples to use to calculate the median/interval
default is all
save: bool, optional
whether to save the image, default=True
if False, figure handle is returned
Returns
-------
fig: figure-handle, only is save=False
"""
if isinstance(interferometer, str):
interferometer = get_empty_interferometer(interferometer)
interferometer.set_strain_data_from_zero_noise(
sampling_frequency=self.sampling_frequency,
duration=self.duration, start_time=self.start_time)
elif not isinstance(interferometer, Interferometer):
raise TypeError(
'interferometer must be either str or Interferometer')
if n_samples is None:
n_samples = len(self.posterior)
waveform_generator = WaveformGenerator(
duration=self.duration, sampling_frequency=self.sampling_frequency,
start_time=self.start_time,
frequency_domain_source_model=lal_binary_black_hole,
parameter_conversion=convert_to_lal_binary_black_hole_parameters,
waveform_arguments=self.waveform_arguments)
fig, axs = plt.subplots(2, 1)
if self.injection_parameters is not None:
try:
hf_inj = waveform_generator.frequency_domain_strain(
self.injection_parameters)
hf_inj_det = interferometer.get_detector_response(
hf_inj, self.injection_parameters)
axs[0].loglog(interferometer.frequency_array, abs(hf_inj_det),
color='k', label='injected', linestyle='--')
axs[1].plot(
interferometer.time_array,
infft(hf_inj_det / interferometer.amplitude_spectral_density_array,
self.sampling_frequency),
color='k', linestyle='--')
except IndexError:
logger.info('Failed to plot injection.')
fd_waveforms = list()
for ii in range(n_samples):
params = dict(self.posterior.iloc[ii])
wf_pols = waveform_generator.frequency_domain_strain(params)
fd_waveforms.append(
interferometer.get_detector_response(wf_pols, params))
fd_waveforms = np.array(fd_waveforms)
td_waveforms = infft(
fd_waveforms / interferometer.amplitude_spectral_density_array,
self.sampling_frequency)
lower_percentile = level * 100 / 2
upper_percentile = 100 - lower_percentile
axs[0].loglog(
interferometer.frequency_array,
np.median(abs(fd_waveforms), axis=0), color='r', label='Median')
axs[0].fill_between(
interferometer.frequency_array, np.percentile(
abs(fd_waveforms), lower_percentile, axis=0),
np.percentile(abs(fd_waveforms), upper_percentile, axis=0),
color='r', label='{} % Interval'.format(int(level * 100)),
alpha=0.5)
axs[1].plot(
interferometer.time_array, np.median(td_waveforms, axis=0),
color='r')
axs[1].fill_between(
interferometer.time_array, np.percentile(
td_waveforms, lower_percentile, axis=0),
np.percentile(td_waveforms, upper_percentile, axis=0), color='r',
alpha=0.5)
try:
axs[0].loglog(
interferometer.frequency_array,
abs(interferometer.frequency_domain_strain),
color='b', label='Data', alpha=0.5)
axs[1].plot(
interferometer.time_array, infft(
interferometer.whitened_frequency_domain_strain,
sampling_frequency=interferometer.strain_data.sampling_frequency),
color='b', alpha=0.5)
except AttributeError:
pass
axs[0].legend()
axs[0].set_xlim(interferometer.minimum_frequency,
interferometer.maximum_frequency)
axs[1].set_xlim(
np.mean(self.posterior.geocent_time) - 0.5,
np.mean(self.posterior.geocent_time) + 0.5)
plt.tight_layout()
if save:
plt.savefig(os.path.join(
self.outdir,
self.label + '_{}_waveform.png'.format(interferometer.name)))
plt.close()
else:
return fig
def plot_skymap(
self, maxpts=None, trials=5, jobs=1, enable_multiresolution=True,
objid=None, instruments=None, geo=False, dpi=600,
......
......@@ -92,6 +92,11 @@ class TestCBCResult(unittest.TestCase):
with self.assertRaises(AttributeError):
self.result.waveform_approximant
def test_waveform_arguments(self):
self.assertEqual(
self.result.waveform_arguments,
self.meta_data['likelihood']['waveform_arguments'])
def test_frequency_domain_source_model(self):
self.assertEqual(
self.result.frequency_domain_source_model,
......@@ -102,6 +107,11 @@ class TestCBCResult(unittest.TestCase):
with self.assertRaises(AttributeError):
self.result.frequency_domain_source_model
def test_interferometer_names(self):
self.assertEqual(
self.result.interferometers,
[name for name in self.meta_data['likelihood']['interferometers']])
def test_detector_injection_properties(self):
self.assertEqual(
self.result.detector_injection_properties('H1'),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment