Skip to content
Snippets Groups Projects
Commit 3d257816 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'html-waveform-plot' into 'master'

allow html writing of waveform plots

See merge request !641
parents 13fadd20 68bc433f
No related branches found
No related tags found
1 merge request!641allow html writing of waveform plots
Pipeline #89168 passed
......@@ -289,6 +289,18 @@ class CompactBinaryCoalescenceResult(CoreResult):
waveforms to have ~4000 entries. This should be sufficient for decent
resolution.
"""
if format == "html":
try:
import plotly.graph_objects as go
from plotly.offline import plot
from plotly.subplots import make_subplots
except ImportError:
logger.warning(
"HTML plotting requested, but plotly cannot be imported, "
"falling back to png format for waveform plot.")
format = "png"
if isinstance(interferometer, str):
interferometer = get_empty_interferometer(interferometer)
interferometer.set_strain_data_from_zero_noise(
......@@ -317,6 +329,9 @@ class CompactBinaryCoalescenceResult(CoreResult):
if end_time is None:
end_time = 0.2
end_time = np.mean(self.posterior.geocent_time) + end_time
if format == "html":
start_time = - np.inf
end_time = np.inf
time_idxs = (
(interferometer.time_array >= start_time) &
(interferometer.time_array <= end_time)
......@@ -329,7 +344,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
logger.debug("Downsampling frequency mask to {} values".format(
len(frequency_idxs))
)
plot_times = interferometer.time_array[time_idxs]
plot_times = interferometer.time_array[time_idxs] - interferometer.strain_data.start_time
plot_frequencies = interferometer.frequency_array[frequency_idxs]
waveform_generator = WaveformGenerator(
......@@ -338,25 +353,73 @@ class CompactBinaryCoalescenceResult(CoreResult):
frequency_domain_source_model=self.frequency_domain_source_model,
parameter_conversion=self.parameter_conversion,
waveform_arguments=self.waveform_arguments)
fig, axs = plt.subplots(2, 1)
if format == "html":
fig = make_subplots(
rows=2, cols=1,
row_heights=[0.5, 0.5],
)
fig.update_layout(
template='plotly_white',
font=dict(
family="Computer Modern",
)
)
else:
fig, axs = plt.subplots(2, 1)
if self.injection_parameters is not None:
try:
hf_inj = waveform_generator.frequency_domain_strain(
self.injection_parameters)
hf_inj_det = interferometer.get_detector_response(
hf_inj, self.injection_parameters)
axs[0].loglog(
plot_frequencies,
asd_from_freq_series(
hf_inj_det[frequency_idxs],
1 / interferometer.strain_data.duration),
color='k', label='injected', linestyle='--')
axs[1].plot(
plot_times,
infft(hf_inj_det /
interferometer.amplitude_spectral_density_array,
self.sampling_frequency)[time_idxs],
color='k', linestyle='--')
ht_inj_det = infft(
hf_inj_det /
interferometer.amplitude_spectral_density_array,
self.sampling_frequency)[time_idxs],
if format == "html":
fig.add_trace(
go.Scatter(
x=plot_frequencies,
y=asd_from_freq_series(
hf_inj_det[frequency_idxs],
1 / interferometer.strain_data.duration),
fill=None,
mode='lines',
line_color='black',
name="Injection",
legendgroup='injection',
),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(
x=plot_times, y=ht_inj_det,
fill=None,
mode='lines',
line_color='black',
name="Injection",
legendgroup='injection',
showlegend=False,
),
row=2,
col=1,
)
else:
axs[0].loglog(
plot_frequencies,
asd_from_freq_series(
hf_inj_det[frequency_idxs],
1 / interferometer.strain_data.duration),
color='k', label='injected', linestyle='--')
axs[1].plot(
plot_times,
infft(hf_inj_det /
interferometer.amplitude_spectral_density_array,
self.sampling_frequency)[time_idxs],
color='k', linestyle='--')
logger.debug('Plotted injection.')
except IndexError:
logger.info('Failed to plot injection.')
......@@ -386,64 +449,203 @@ class CompactBinaryCoalescenceResult(CoreResult):
)
)
axs[0].loglog(
plot_frequencies,
np.median(fd_waveforms, axis=0), color='r', label='Median')
axs[0].fill_between(
plot_frequencies,
np.percentile(fd_waveforms, lower_percentile, axis=0),
np.percentile(fd_waveforms, upper_percentile, axis=0),
color='r', label='{} % Interval'.format(
int(upper_percentile - lower_percentile)),
alpha=0.3)
axs[1].plot(
plot_times, np.median(td_waveforms, axis=0),
color='r')
axs[1].fill_between(
plot_times, np.percentile(
td_waveforms, lower_percentile, axis=0),
np.percentile(td_waveforms, upper_percentile, axis=0), color='r',
alpha=0.3)
try:
if format == "html":
fig.add_trace(
go.Scatter(
x=plot_frequencies, y=np.median(fd_waveforms, axis=0),
fill=None,
mode='lines', line_color='crimson',
opacity=1,
name="Median reconstructed",
legendgroup='median',
),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(
x=plot_frequencies, y=np.percentile(fd_waveforms, lower_percentile, axis=0),
fill=None,
mode='lines',
line_color='crimson',
opacity=0.1,
name="{:.2f}% credible interval".format(upper_percentile - lower_percentile),
legendgroup='uncertainty',
),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(
x=plot_frequencies, y=np.percentile(fd_waveforms, upper_percentile, axis=0),
fill='tonexty',
mode='lines',
line_color='crimson',
opacity=0.1,
name="{:.2f}% credible interval".format(upper_percentile - lower_percentile),
legendgroup='uncertainty',
showlegend=False,
),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(
x=plot_times, y=np.median(td_waveforms, axis=0),
fill=None,
mode='lines', line_color='crimson',
opacity=1,
name="Median reconstructed",
legendgroup='median',
showlegend=False,
),
row=2,
col=1,
)
fig.add_trace(
go.Scatter(
x=plot_times, y=np.percentile(td_waveforms, lower_percentile, axis=0),
fill=None,
mode='lines',
line_color='crimson',
opacity=0.1,
name="{:.2f}% credible interval".format(upper_percentile - lower_percentile),
legendgroup='uncertainty',
showlegend=False,
),
row=2,
col=1,
)
fig.add_trace(
go.Scatter(
x=plot_times, y=np.percentile(td_waveforms, upper_percentile, axis=0),
fill='tonexty',
mode='lines',
line_color='crimson',
opacity=0.1,
name="{:.2f}% credible interval".format(upper_percentile - lower_percentile),
legendgroup='uncertainty',
showlegend=False,
),
row=2,
col=1,
)
else:
axs[0].loglog(
plot_frequencies,
asd_from_freq_series(
interferometer.frequency_domain_strain[frequency_idxs],
1 / interferometer.strain_data.duration),
color='b', label='Data', alpha=0.3)
axs[0].loglog(
np.median(fd_waveforms, axis=0), color='r', label='Median')
axs[0].fill_between(
plot_frequencies,
interferometer.amplitude_spectral_density_array[frequency_idxs],
color='b', label='PSD')
np.percentile(fd_waveforms, lower_percentile, axis=0),
np.percentile(fd_waveforms, upper_percentile, axis=0),
color='r', label='{} % Interval'.format(
int(upper_percentile - lower_percentile)),
alpha=0.3)
axs[1].plot(
plot_times, infft(
interferometer.whitened_frequency_domain_strain,
sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs],
color='b', alpha=0.3)
plot_times, np.median(td_waveforms, axis=0),
color='r')
axs[1].fill_between(
plot_times, np.percentile(
td_waveforms, lower_percentile, axis=0),
np.percentile(td_waveforms, upper_percentile, axis=0), color='r',
alpha=0.3)
try:
if format == "html":
fig.add_trace(
go.Scatter(
x=plot_frequencies,
y=asd_from_freq_series(
interferometer.frequency_domain_strain[frequency_idxs],
1 / interferometer.strain_data.duration
),
fill=None,
mode='lines', line_color='darkblue',
opacity=0.5,
name="Data",
legendgroup='data',
),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(
x=plot_frequencies,
y=interferometer.amplitude_spectral_density_array[frequency_idxs],
fill=None,
mode='lines', line_color='darkblue',
opacity=0.8,
name="ASD",
legendgroup='asd',
),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(
x=plot_times,
y=infft(
interferometer.whitened_frequency_domain_strain,
sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs],
fill=None,
mode='lines', line_color='darkblue',
opacity=0.5,
name="Data",
legendgroup='data',
showlegend=False,
),
row=2,
col=1,
)
else:
axs[0].loglog(
plot_frequencies,
asd_from_freq_series(
interferometer.frequency_domain_strain[frequency_idxs],
1 / interferometer.strain_data.duration),
color='b', label='Data', alpha=0.3)
axs[0].loglog(
plot_frequencies,
interferometer.amplitude_spectral_density_array[frequency_idxs],
color='b', label='PSD')
axs[1].plot(
plot_times, infft(
interferometer.whitened_frequency_domain_strain,
sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs],
color='b', alpha=0.3)
logger.debug('Plotted interferometer data.')
except AttributeError:
pass
axs[0].set_xlim(interferometer.minimum_frequency,
interferometer.maximum_frequency)
axs[1].set_xlim(start_time, end_time)
if format == "html":
fig.update_xaxes(title_text="Frequency [Hz]", type="log", row=1)
fig.update_xaxes(title_text="Time [s] - {}".format(
interferometer.strain_data.start_time), type="linear", row=2)
fig.update_yaxes(
title_text="$\mathrm{ASD}\,\left[\mathrm{Hz}^{-1/2}\\right]$", type="log", row=1)
fig.update_yaxes(title_text="Whitened Strain", type="linear", row=2)
else:
axs[0].set_xlim(interferometer.minimum_frequency,
interferometer.maximum_frequency)
axs[1].set_xlim(start_time, end_time)
axs[0].set_xlabel('$f$ [$Hz$]')
axs[1].set_xlabel('$t$ [$s$]')
axs[0].set_ylabel('$\\tilde{h}(f)$ [Hz$^{-\\frac{1}{2}}$]')
axs[1].set_ylabel('Whitened strain')
axs[0].legend(loc='lower left')
axs[0].set_xlabel('$f$ [$Hz$]')
axs[1].set_xlabel('$t$ [$s$]')
axs[0].set_ylabel('$\\tilde{h}(f)$ [Hz$^{-\\frac{1}{2}}$]')
axs[1].set_ylabel('Whitened strain')
axs[0].legend(loc='lower left')
plt.tight_layout()
if save:
filename = os.path.join(
self.outdir,
self.label + '_{}_waveform.{}'.format(
interferometer.name, format))
plt.savefig(filename, format=format, dpi=600)
if format == 'html':
plot(fig, filename=filename, include_mathjax='cdn', auto_open=False)
else:
plt.savefig(filename, format=format, dpi=600)
plt.close()
logger.debug("Figure saved to {}".format(filename))
plt.close()
else:
return fig
......
......@@ -2,3 +2,4 @@ astropy
lalsuite
gwpy
theano
plotly
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment