diff --git a/bilby/gw/result.py b/bilby/gw/result.py index 52af660fcf7cce6ad9f6158cdbb9f9799f3d6a1b..14ee2a066446580ecad0f755f7cfc988bc917b62 100644 --- a/bilby/gw/result.py +++ b/bilby/gw/result.py @@ -295,6 +295,10 @@ class CompactBinaryCoalescenceResult(CoreResult): resolution. """ + DATA_COLOR = "#ff7f0e" + WAVEFORM_COLOR = "#1f77b4" + INJECTION_COLOR = "#000000" + if format == "html": try: import plotly.graph_objects as go @@ -305,15 +309,25 @@ class CompactBinaryCoalescenceResult(CoreResult): "HTML plotting requested, but plotly cannot be imported, " "falling back to png format for waveform plot.") format = "png" + else: + _old_tex = rcParams["text.usetex"] + _old_serif = rcParams["font.serif"] + _old_family = rcParams["font.family"] + rcParams["text.usetex"] = True + rcParams["font.serif"] = "Computer Modern Roman" + rcParams["font.family"] = "Serif" if isinstance(interferometer, str): interferometer = get_empty_interferometer(interferometer) interferometer.set_strain_data_from_zero_noise( sampling_frequency=self.sampling_frequency, duration=self.duration, start_time=self.start_time) + PLOT_DATA = False elif not isinstance(interferometer, Interferometer): raise TypeError( 'interferometer must be either str or Interferometer') + else: + PLOT_DATA = True logger.info("Generating waveform figure for {}".format( interferometer.name)) @@ -349,7 +363,11 @@ class CompactBinaryCoalescenceResult(CoreResult): logger.debug("Downsampling frequency mask to {} values".format( len(frequency_idxs)) ) - plot_times = interferometer.time_array[time_idxs] - interferometer.strain_data.start_time + plot_times = interferometer.time_array[time_idxs] + # if format == "html": + plot_times -= interferometer.strain_data.start_time + start_time -= interferometer.strain_data.start_time + end_time -= interferometer.strain_data.start_time plot_frequencies = interferometer.frequency_array[frequency_idxs] waveform_generator = self.waveform_generator_class( @@ -373,61 +391,70 @@ class CompactBinaryCoalescenceResult(CoreResult): else: fig, axs = plt.subplots(2, 1) - if self.injection_parameters is not None: - try: - hf_inj = waveform_generator.frequency_domain_strain( - self.injection_parameters) - hf_inj_det = interferometer.get_detector_response( - hf_inj, self.injection_parameters) - ht_inj_det = infft( - hf_inj_det / - interferometer.amplitude_spectral_density_array, - self.sampling_frequency)[time_idxs], - if format == "html": - fig.add_trace( - go.Scatter( - x=plot_frequencies, - y=asd_from_freq_series( - hf_inj_det[frequency_idxs], - 1 / interferometer.strain_data.duration), - fill=None, - mode='lines', - line_color='black', - name="Injection", - legendgroup='injection', - ), - row=1, - col=1, - ) - fig.add_trace( - go.Scatter( - x=plot_times, y=ht_inj_det, - fill=None, - mode='lines', - line_color='black', - name="Injection", - legendgroup='injection', - showlegend=False, + if PLOT_DATA: + if format == "html": + fig.add_trace( + go.Scatter( + x=plot_frequencies, + y=asd_from_freq_series( + interferometer.frequency_domain_strain[frequency_idxs], + 1 / interferometer.strain_data.duration ), - row=2, - col=1, - ) - else: - axs[0].loglog( - plot_frequencies, - asd_from_freq_series( - hf_inj_det[frequency_idxs], - 1 / interferometer.strain_data.duration), - color='k', label='injected', linestyle='--') - axs[1].plot( - plot_times, - infft(hf_inj_det / - interferometer.amplitude_spectral_density_array, - self.sampling_frequency)[time_idxs], - color='k', linestyle='--') - logger.debug('Plotted injection.') - except IndexError: - logger.info('Failed to plot injection.') + fill=None, + mode='lines', line_color=DATA_COLOR, + opacity=0.5, + name="Data", + legendgroup='data', + ), + row=1, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_frequencies, + y=interferometer.amplitude_spectral_density_array[frequency_idxs], + fill=None, + mode='lines', line_color=DATA_COLOR, + opacity=0.8, + name="ASD", + legendgroup='asd', + ), + row=1, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_times, + y=infft( + interferometer.whitened_frequency_domain_strain, + sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs], + fill=None, + mode='lines', line_color=DATA_COLOR, + opacity=0.5, + name="Data", + legendgroup='data', + showlegend=False, + ), + row=2, + col=1, + ) + else: + axs[0].loglog( + plot_frequencies, + asd_from_freq_series( + interferometer.frequency_domain_strain[frequency_idxs], + 1 / interferometer.strain_data.duration), + color=DATA_COLOR, label='Data', alpha=0.3) + axs[0].loglog( + plot_frequencies, + interferometer.amplitude_spectral_density_array[frequency_idxs], + color=DATA_COLOR, label='ASD') + axs[1].plot( + plot_times, infft( + interferometer.whitened_frequency_domain_strain, + sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs], + color=DATA_COLOR, alpha=0.3) + logger.debug('Plotted interferometer data.') fd_waveforms = list() td_waveforms = list() @@ -459,7 +486,7 @@ class CompactBinaryCoalescenceResult(CoreResult): go.Scatter( x=plot_frequencies, y=np.median(fd_waveforms, axis=0), fill=None, - mode='lines', line_color='crimson', + mode='lines', line_color=WAVEFORM_COLOR, opacity=1, name="Median reconstructed", legendgroup='median', @@ -472,7 +499,7 @@ class CompactBinaryCoalescenceResult(CoreResult): x=plot_frequencies, y=np.percentile(fd_waveforms, lower_percentile, axis=0), fill=None, mode='lines', - line_color='crimson', + line_color=WAVEFORM_COLOR, opacity=0.1, name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), legendgroup='uncertainty', @@ -485,7 +512,7 @@ class CompactBinaryCoalescenceResult(CoreResult): x=plot_frequencies, y=np.percentile(fd_waveforms, upper_percentile, axis=0), fill='tonexty', mode='lines', - line_color='crimson', + line_color=WAVEFORM_COLOR, opacity=0.1, name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), legendgroup='uncertainty', @@ -498,7 +525,7 @@ class CompactBinaryCoalescenceResult(CoreResult): go.Scatter( x=plot_times, y=np.median(td_waveforms, axis=0), fill=None, - mode='lines', line_color='crimson', + mode='lines', line_color=WAVEFORM_COLOR, opacity=1, name="Median reconstructed", legendgroup='median', @@ -512,7 +539,7 @@ class CompactBinaryCoalescenceResult(CoreResult): x=plot_times, y=np.percentile(td_waveforms, lower_percentile, axis=0), fill=None, mode='lines', - line_color='crimson', + line_color=WAVEFORM_COLOR, opacity=0.1, name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), legendgroup='uncertainty', @@ -526,7 +553,7 @@ class CompactBinaryCoalescenceResult(CoreResult): x=plot_times, y=np.percentile(td_waveforms, upper_percentile, axis=0), fill='tonexty', mode='lines', - line_color='crimson', + line_color=WAVEFORM_COLOR, opacity=0.1, name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), legendgroup='uncertainty', @@ -538,106 +565,95 @@ class CompactBinaryCoalescenceResult(CoreResult): else: axs[0].loglog( plot_frequencies, - np.median(fd_waveforms, axis=0), color='r', label='Median') + np.median(fd_waveforms, axis=0), color=WAVEFORM_COLOR, label='Median reconstructed') axs[0].fill_between( plot_frequencies, np.percentile(fd_waveforms, lower_percentile, axis=0), np.percentile(fd_waveforms, upper_percentile, axis=0), - color='r', label='{} % Interval'.format( + color=WAVEFORM_COLOR, label='{}\% credible interval'.format( int(upper_percentile - lower_percentile)), alpha=0.3) axs[1].plot( plot_times, np.median(td_waveforms, axis=0), - color='r') + color=WAVEFORM_COLOR) axs[1].fill_between( plot_times, np.percentile( td_waveforms, lower_percentile, axis=0), - np.percentile(td_waveforms, upper_percentile, axis=0), color='r', + np.percentile(td_waveforms, upper_percentile, axis=0), + color=WAVEFORM_COLOR, alpha=0.3) - try: - if format == "html": - fig.add_trace( - go.Scatter( - x=plot_frequencies, - y=asd_from_freq_series( - interferometer.frequency_domain_strain[frequency_idxs], - 1 / interferometer.strain_data.duration + if self.injection_parameters is not None: + try: + hf_inj = waveform_generator.frequency_domain_strain( + self.injection_parameters) + hf_inj_det = interferometer.get_detector_response( + hf_inj, self.injection_parameters) + ht_inj_det = infft( + hf_inj_det / + interferometer.amplitude_spectral_density_array, + self.sampling_frequency)[time_idxs] + if format == "html": + fig.add_trace( + go.Scatter( + x=plot_frequencies, + y=asd_from_freq_series( + hf_inj_det[frequency_idxs], + 1 / interferometer.strain_data.duration), + fill=None, + mode='lines', + line=dict(color=INJECTION_COLOR, dash='dot'), + name="Injection", + legendgroup='injection', ), - fill=None, - mode='lines', line_color='darkblue', - opacity=0.5, - name="Data", - legendgroup='data', - ), - row=1, - col=1, - ) - fig.add_trace( - go.Scatter( - x=plot_frequencies, - y=interferometer.amplitude_spectral_density_array[frequency_idxs], - fill=None, - mode='lines', line_color='darkblue', - opacity=0.8, - name="ASD", - legendgroup='asd', - ), - row=1, - col=1, - ) - fig.add_trace( - go.Scatter( - x=plot_times, - y=infft( - interferometer.whitened_frequency_domain_strain, - sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs], - fill=None, - mode='lines', line_color='darkblue', - opacity=0.5, - name="Data", - legendgroup='data', - showlegend=False, - ), - row=2, - col=1, - ) - else: - axs[0].loglog( - plot_frequencies, - asd_from_freq_series( - interferometer.frequency_domain_strain[frequency_idxs], - 1 / interferometer.strain_data.duration), - color='b', label='Data', alpha=0.3) - axs[0].loglog( - plot_frequencies, - interferometer.amplitude_spectral_density_array[frequency_idxs], - color='b', label='PSD') - axs[1].plot( - plot_times, infft( - interferometer.whitened_frequency_domain_strain, - sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs], - color='b', alpha=0.3) - logger.debug('Plotted interferometer data.') - except AttributeError: - pass + row=1, + col=1, + ) + fig.add_trace( + go.Scatter( + x=plot_times, y=ht_inj_det, + fill=None, + mode='lines', + line=dict(color=INJECTION_COLOR, dash='dot'), + name="Injection", + legendgroup='injection', + showlegend=False, + ), + row=2, + col=1, + ) + else: + axs[0].loglog( + plot_frequencies, + asd_from_freq_series( + hf_inj_det[frequency_idxs], + 1 / interferometer.strain_data.duration), + color=INJECTION_COLOR, label='Injection', linestyle=':') + axs[1].plot( + plot_times, ht_inj_det, + color=INJECTION_COLOR, linestyle=':') + logger.debug('Plotted injection.') + except IndexError as e: + logger.info('Failed to plot injection with message {}.'.format(e)) + f_domain_x_label = "$f [\\mathrm{Hz}]$" + f_domain_y_label = "$\\mathrm{ASD} \\left[\\mathrm{Hz}^{-1/2}\\right]$" + t_domain_x_label = "$t - {} [s]$".format(interferometer.strain_data.start_time) + t_domain_y_label = "Whitened Strain" if format == "html": - fig.update_xaxes(title_text="Frequency [Hz]", type="log", row=1) - fig.update_xaxes(title_text="Time [s] - {}".format( - interferometer.strain_data.start_time), type="linear", row=2) - fig.update_yaxes( - title_text="$\mathrm{ASD}\,\left[\mathrm{Hz}^{-1/2}\\right]$", type="log", row=1) - fig.update_yaxes(title_text="Whitened Strain", type="linear", row=2) + fig.update_xaxes(title_text=f_domain_x_label, type="log", row=1) + fig.update_yaxes(title_text=f_domain_y_label, type="log", row=1) + fig.update_xaxes(title_text=t_domain_x_label, type="linear", row=2) + fig.update_yaxes(title_text=t_domain_x_label, type="linear", row=2) else: axs[0].set_xlim(interferometer.minimum_frequency, interferometer.maximum_frequency) axs[1].set_xlim(start_time, end_time) - axs[0].set_xlabel('$f$ [$Hz$]') - axs[1].set_xlabel('$t$ [$s$]') - axs[0].set_ylabel('$\\tilde{h}(f)$ [Hz$^{-\\frac{1}{2}}$]') - axs[1].set_ylabel('Whitened strain') + axs[0].set_xlabel(f_domain_x_label) + axs[0].set_ylabel(f_domain_y_label) + axs[1].set_xlabel(t_domain_x_label) + axs[1].set_ylabel(t_domain_y_label) axs[0].legend(loc='lower left') if save: @@ -648,8 +664,19 @@ class CompactBinaryCoalescenceResult(CoreResult): if format == 'html': plot(fig, filename=filename, include_mathjax='cdn', auto_open=False) else: - plt.savefig(filename, format=format, dpi=600) + plt.tight_layout() + try: + plt.savefig(filename, format=format, dpi=600) + except RuntimeError: + logger.debug( + "Failed to save waveform with tex labels turning off tex." + ) + rcParams["text.usetex"] = False + plt.savefig(filename, format=format, dpi=600) plt.close() + rcParams["text.usetex"] = _old_tex + rcParams["font.serif"] = _old_serif + rcParams["font.family"] = _old_family logger.debug("Figure saved to {}".format(filename)) else: return fig