diff --git a/bilby/gw/result.py b/bilby/gw/result.py index 8f405cc99ec7623a98e3f193706061c59728e6fe..ed9125a94840b6f48518d0761ba18c3dc675f08c 100644 --- a/bilby/gw/result.py +++ b/bilby/gw/result.py @@ -289,6 +289,18 @@ class CompactBinaryCoalescenceResult(CoreResult): waveforms to have ~4000 entries. This should be sufficient for decent resolution. """ + + if format == "html": + try: + import plotly.graph_objects as go + from plotly.offline import plot + from plotly.subplots import make_subplots + except ImportError: + logger.warning( + "HTML plotting requested, but plotly cannot be imported, " + "falling back to png format for waveform plot.") + format = "png" + if isinstance(interferometer, str): interferometer = get_empty_interferometer(interferometer) interferometer.set_strain_data_from_zero_noise( @@ -317,6 +329,9 @@ class CompactBinaryCoalescenceResult(CoreResult): if end_time is None: end_time = 0.2 end_time = np.mean(self.posterior.geocent_time) + end_time + if format == "html": + start_time = - np.inf + end_time = np.inf time_idxs = ( (interferometer.time_array >= start_time) & (interferometer.time_array <= end_time) @@ -329,7 +344,7 @@ class CompactBinaryCoalescenceResult(CoreResult): logger.debug("Downsampling frequency mask to {} values".format( len(frequency_idxs)) ) - plot_times = interferometer.time_array[time_idxs] + plot_times = interferometer.time_array[time_idxs] - interferometer.strain_data.start_time plot_frequencies = interferometer.frequency_array[frequency_idxs] waveform_generator = WaveformGenerator( @@ -338,25 +353,73 @@ class CompactBinaryCoalescenceResult(CoreResult): frequency_domain_source_model=self.frequency_domain_source_model, parameter_conversion=self.parameter_conversion, waveform_arguments=self.waveform_arguments) - fig, axs = plt.subplots(2, 1) + + if format == "html": + fig = make_subplots( + rows=2, cols=1, + row_heights=[0.5, 0.5], + ) + fig.update_layout( + template='plotly_white', + font=dict( + family="Computer Modern", + ) + ) + else: + fig, axs = plt.subplots(2, 1) + if self.injection_parameters is not None: try: hf_inj = waveform_generator.frequency_domain_strain( self.injection_parameters) hf_inj_det = interferometer.get_detector_response( hf_inj, self.injection_parameters) - axs[0].loglog( - plot_frequencies, - asd_from_freq_series( - hf_inj_det[frequency_idxs], - 1 / interferometer.strain_data.duration), - color='k', label='injected', linestyle='--') - axs[1].plot( - plot_times, - infft(hf_inj_det / - interferometer.amplitude_spectral_density_array, - self.sampling_frequency)[time_idxs], - color='k', linestyle='--') + ht_inj_det = infft( + hf_inj_det / + interferometer.amplitude_spectral_density_array, + self.sampling_frequency)[time_idxs], + if format == "html": + fig.add_trace( + go.Scatter( + x=plot_frequencies, + y=asd_from_freq_series( + hf_inj_det[frequency_idxs], + 1 / interferometer.strain_data.duration), + fill=None, + mode='lines', + line_color='black', + name="Injection", + legendgroup='injection', + ), + row=1, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_times, y=ht_inj_det, + fill=None, + mode='lines', + line_color='black', + name="Injection", + legendgroup='injection', + showlegend=False, + ), + row=2, + col=1, + ) + else: + axs[0].loglog( + plot_frequencies, + asd_from_freq_series( + hf_inj_det[frequency_idxs], + 1 / interferometer.strain_data.duration), + color='k', label='injected', linestyle='--') + axs[1].plot( + plot_times, + infft(hf_inj_det / + interferometer.amplitude_spectral_density_array, + self.sampling_frequency)[time_idxs], + color='k', linestyle='--') logger.debug('Plotted injection.') except IndexError: logger.info('Failed to plot injection.') @@ -386,64 +449,203 @@ class CompactBinaryCoalescenceResult(CoreResult): ) ) - axs[0].loglog( - plot_frequencies, - np.median(fd_waveforms, axis=0), color='r', label='Median') - axs[0].fill_between( - plot_frequencies, - np.percentile(fd_waveforms, lower_percentile, axis=0), - np.percentile(fd_waveforms, upper_percentile, axis=0), - color='r', label='{} % Interval'.format( - int(upper_percentile - lower_percentile)), - alpha=0.3) - axs[1].plot( - plot_times, np.median(td_waveforms, axis=0), - color='r') - axs[1].fill_between( - plot_times, np.percentile( - td_waveforms, lower_percentile, axis=0), - np.percentile(td_waveforms, upper_percentile, axis=0), color='r', - alpha=0.3) - - try: + if format == "html": + fig.add_trace( + go.Scatter( + x=plot_frequencies, y=np.median(fd_waveforms, axis=0), + fill=None, + mode='lines', line_color='crimson', + opacity=1, + name="Median reconstructed", + legendgroup='median', + ), + row=1, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_frequencies, y=np.percentile(fd_waveforms, lower_percentile, axis=0), + fill=None, + mode='lines', + line_color='crimson', + opacity=0.1, + name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), + legendgroup='uncertainty', + ), + row=1, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_frequencies, y=np.percentile(fd_waveforms, upper_percentile, axis=0), + fill='tonexty', + mode='lines', + line_color='crimson', + opacity=0.1, + name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), + legendgroup='uncertainty', + showlegend=False, + ), + row=1, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_times, y=np.median(td_waveforms, axis=0), + fill=None, + mode='lines', line_color='crimson', + opacity=1, + name="Median reconstructed", + legendgroup='median', + showlegend=False, + ), + row=2, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_times, y=np.percentile(td_waveforms, lower_percentile, axis=0), + fill=None, + mode='lines', + line_color='crimson', + opacity=0.1, + name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), + legendgroup='uncertainty', + showlegend=False, + ), + row=2, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_times, y=np.percentile(td_waveforms, upper_percentile, axis=0), + fill='tonexty', + mode='lines', + line_color='crimson', + opacity=0.1, + name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), + legendgroup='uncertainty', + showlegend=False, + ), + row=2, + col=1, + ) + else: axs[0].loglog( plot_frequencies, - asd_from_freq_series( - interferometer.frequency_domain_strain[frequency_idxs], - 1 / interferometer.strain_data.duration), - color='b', label='Data', alpha=0.3) - axs[0].loglog( + np.median(fd_waveforms, axis=0), color='r', label='Median') + axs[0].fill_between( plot_frequencies, - interferometer.amplitude_spectral_density_array[frequency_idxs], - color='b', label='PSD') + np.percentile(fd_waveforms, lower_percentile, axis=0), + np.percentile(fd_waveforms, upper_percentile, axis=0), + color='r', label='{} % Interval'.format( + int(upper_percentile - lower_percentile)), + alpha=0.3) axs[1].plot( - plot_times, infft( - interferometer.whitened_frequency_domain_strain, - sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs], - color='b', alpha=0.3) + plot_times, np.median(td_waveforms, axis=0), + color='r') + axs[1].fill_between( + plot_times, np.percentile( + td_waveforms, lower_percentile, axis=0), + np.percentile(td_waveforms, upper_percentile, axis=0), color='r', + alpha=0.3) + + try: + if format == "html": + fig.add_trace( + go.Scatter( + x=plot_frequencies, + y=asd_from_freq_series( + interferometer.frequency_domain_strain[frequency_idxs], + 1 / interferometer.strain_data.duration + ), + fill=None, + mode='lines', line_color='darkblue', + opacity=0.5, + name="Data", + legendgroup='data', + ), + row=1, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_frequencies, + y=interferometer.amplitude_spectral_density_array[frequency_idxs], + fill=None, + mode='lines', line_color='darkblue', + opacity=0.8, + name="ASD", + legendgroup='asd', + ), + row=1, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_times, + y=infft( + interferometer.whitened_frequency_domain_strain, + sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs], + fill=None, + mode='lines', line_color='darkblue', + opacity=0.5, + name="Data", + legendgroup='data', + showlegend=False, + ), + row=2, + col=1, + ) + else: + axs[0].loglog( + plot_frequencies, + asd_from_freq_series( + interferometer.frequency_domain_strain[frequency_idxs], + 1 / interferometer.strain_data.duration), + color='b', label='Data', alpha=0.3) + axs[0].loglog( + plot_frequencies, + interferometer.amplitude_spectral_density_array[frequency_idxs], + color='b', label='PSD') + axs[1].plot( + plot_times, infft( + interferometer.whitened_frequency_domain_strain, + sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs], + color='b', alpha=0.3) logger.debug('Plotted interferometer data.') except AttributeError: pass - axs[0].set_xlim(interferometer.minimum_frequency, - interferometer.maximum_frequency) - axs[1].set_xlim(start_time, end_time) + if format == "html": + fig.update_xaxes(title_text="Frequency [Hz]", type="log", row=1) + fig.update_xaxes(title_text="Time [s] - {}".format( + interferometer.strain_data.start_time), type="linear", row=2) + fig.update_yaxes( + title_text="$\mathrm{ASD}\,\left[\mathrm{Hz}^{-1/2}\\right]$", type="log", row=1) + fig.update_yaxes(title_text="Whitened Strain", type="linear", row=2) + else: + axs[0].set_xlim(interferometer.minimum_frequency, + interferometer.maximum_frequency) + axs[1].set_xlim(start_time, end_time) - axs[0].set_xlabel('$f$ [$Hz$]') - axs[1].set_xlabel('$t$ [$s$]') - axs[0].set_ylabel('$\\tilde{h}(f)$ [Hz$^{-\\frac{1}{2}}$]') - axs[1].set_ylabel('Whitened strain') - axs[0].legend(loc='lower left') + axs[0].set_xlabel('$f$ [$Hz$]') + axs[1].set_xlabel('$t$ [$s$]') + axs[0].set_ylabel('$\\tilde{h}(f)$ [Hz$^{-\\frac{1}{2}}$]') + axs[1].set_ylabel('Whitened strain') + axs[0].legend(loc='lower left') - plt.tight_layout() if save: filename = os.path.join( self.outdir, self.label + '_{}_waveform.{}'.format( interferometer.name, format)) - plt.savefig(filename, format=format, dpi=600) + if format == 'html': + plot(fig, filename=filename, include_mathjax='cdn', auto_open=False) + else: + plt.savefig(filename, format=format, dpi=600) + plt.close() logger.debug("Figure saved to {}".format(filename)) - plt.close() else: return fig diff --git a/optional_requirements.txt b/optional_requirements.txt index 243f6563c97a0fa6c40b01fd0282340fc8bc65f7..cd16a99f5eef1d8839dffb60db9623b4771b04e3 100644 --- a/optional_requirements.txt +++ b/optional_requirements.txt @@ -2,3 +2,4 @@ astropy lalsuite gwpy theano +plotly