diff --git a/bilby_pipe/data_generation.py b/bilby_pipe/data_generation.py index 1b90aada86634625c5b317eef1dba150ff66d398..159735f3a590a0624c6becba739f516b4255863d 100644 --- a/bilby_pipe/data_generation.py +++ b/bilby_pipe/data_generation.py @@ -580,11 +580,16 @@ class DataGenerationInput(Input): self._set_psd_from_file(ifo) else: logger.info(f"Setting PSD for {det} from data") - psd_data = self.__get_psd_data(det) - psd = self.__generate_psd(psd_data, roll_off) + before_psd_data = self.__get_psd_data(det, before=True) + before_psd = self.__generate_psd(before_psd_data, roll_off) + after_psd_data = self.__get_psd_data(det, before=False) + after_psd = self.__generate_psd(after_psd_data, roll_off) + psd = (before_psd.value + after_psd.value) / 2 ifo.power_spectral_density = PowerSpectralDensity( - frequency_array=psd.frequencies.value, psd_array=psd.value + frequency_array=before_psd.frequencies.value, + psd_array=psd, ) + psd_data = [before_psd_data, after_psd_data] logger.info(f"Getting analysis-segment data for {det}") data = self._get_data( @@ -603,12 +608,17 @@ class DataGenerationInput(Input): self.interferometers = bilby.gw.detector.InterferometerList(ifo_list) - def __get_psd_data(self, det): + def __get_psd_data(self, det, before=True): # psd_start_time is given relative to the segment start time # so here we calculate the actual start time - actual_psd_start_time = self.start_time + self.psd_start_time + if before: + actual_psd_start_time = self.start_time + self.psd_start_time + label = "before" + else: + actual_psd_start_time = self.start_time + self.duration + label = "after" actual_psd_end_time = actual_psd_start_time + self.psd_duration - logger.info(f"Getting psd-segment data for {det}") + logger.info(f"Getting {label} analysis segment psd data for {det}") psd_data = self._get_data( det, self.get_channel_type(det), actual_psd_start_time, actual_psd_end_time ) @@ -655,6 +665,11 @@ class DataGenerationInput(Input): else: plot_psd = True + if isinstance(psd_strain_data, list): + psd_strain_data, after_psd_data = psd_strain_data + else: + after_psd_data = None + plot_kwargs = dict( det=det, data_directory=self.data_directory, @@ -680,6 +695,8 @@ class DataGenerationInput(Input): # plot psd_strain_data+strain_data and zoom into strain_data segment data_with_psd = psd_strain_data.append(strain_data, inplace=False) + if after_psd_data is not None: + data_with_psd.append(after_psd_data, inplace=False) strain_spectrogram_plot( data=data_with_psd, extra_label=f"D{int(time[1] - time[0])}", **plot_kwargs )