diff --git a/bilby_pipe/data_generation.py b/bilby_pipe/data_generation.py index 93096dcc799f755384b1fcf381d594a74d849b4c..78971e3e1c03de7ac7240a689cd574aa3ad57b6d 100644 --- a/bilby_pipe/data_generation.py +++ b/bilby_pipe/data_generation.py @@ -1227,122 +1227,29 @@ class DataGenerationInput(Input): data_dump.to_pickle() def save_roq_weights(self): - waveform_arguments = self.get_default_waveform_arguments() - - if self.roq_folder is not None: - logger.info( - f"Using ROQ likelihood with roq-folder={self.roq_folder} and " - f"roq-scale-factor={self.roq_scale_factor}" - ) - params = np.genfromtxt(self.roq_folder + "/params.dat", names=True) - - freq_nodes_linear = np.load(self.roq_folder + "/fnodes_linear.npy") - freq_nodes_quadratic = np.load(self.roq_folder + "/fnodes_quadratic.npy") - freq_nodes_linear *= self.roq_scale_factor - freq_nodes_quadratic *= self.roq_scale_factor - - basis_matrix_linear = np.load(self.roq_folder + "/B_linear.npy").T - basis_matrix_quadratic = np.load(self.roq_folder + "/B_quadratic.npy").T - - waveform_arguments["frequency_nodes_linear"] = freq_nodes_linear - waveform_arguments["frequency_nodes_quadratic"] = freq_nodes_quadratic - - if self.roq_weight_format is None: - weight_format = "npz" - else: - weight_format = self.roq_weight_format - elif ( - self.roq_linear_matrix is not None and self.roq_quadratic_matrix is not None - ): - logger.info( - f"Using ROQ likelihood with linear-matrix={self.roq_linear_matrix}, " - f"quadratic-matrix={self.roq_quadratic_matrix}, and roq-scale-factor={self.roq_scale_factor}" - ) - params = None - basis_matrix_linear = self.roq_linear_matrix - basis_matrix_quadratic = self.roq_quadratic_matrix - if self.roq_weight_format is None: - weight_format = "hdf5" - else: - weight_format = self.roq_weight_format - else: - raise AttributeError( - "For the use of ROQ likelihood, roq folder or both linear and " - "quadratic matrices are required." + if self.likelihood_type != "ROQGravitationalWaveTransient": + raise ValueError( + "ROQ weights can only be saved for ROQGravitationalWaveTransient" ) - - waveform_generator = self.waveform_generator_class( - sampling_frequency=self.interferometers.sampling_frequency, - duration=self.interferometers.duration, - frequency_domain_source_model=self.bilby_roq_frequency_domain_source_model, - parameter_conversion=self.parameter_conversion, - start_time=self.interferometers.start_time, - waveform_arguments=waveform_arguments, - ) - - likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient( - interferometers=self.interferometers, - priors=self.priors, - roq_params=params, - roq_scale_factor=self.roq_scale_factor, - waveform_generator=waveform_generator, - linear_matrix=basis_matrix_linear, - quadratic_matrix=basis_matrix_quadratic, - reference_frame=self.reference_frame, - time_reference=self.time_reference, - weights=self.roq_weights, - ) - - del basis_matrix_linear, basis_matrix_quadratic - + if getattr(self, "likelihood_roq_weights", None) is not None: + self.setup_roq_weights() weight_file = os.path.join( - self.data_directory, f"{self.label}_roq_weights.{weight_format}" + self.data_directory, f"{self.label}_roq_weights.{self.roq_weight_format}" ) self.meta_data["weight_file"] = weight_file - likelihood.save_weights(weight_file, format=weight_format) + self.likelihood.save_weights(weight_file, format=self.roq_weight_format) def save_multiband_weights(self): - waveform_arguments = self.get_default_waveform_arguments() - - waveform_generator = self.waveform_generator_class( - sampling_frequency=self.interferometers.sampling_frequency, - duration=self.interferometers.duration, - frequency_domain_source_model=self.bilby_multiband_frequency_domain_source_model, - parameter_conversion=self.parameter_conversion, - start_time=self.interferometers.start_time, - waveform_arguments=waveform_arguments, - ) - - likelihood = bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.interferometers, - priors=self.priors, - waveform_generator=waveform_generator, - reference_frame=self.reference_frame, - time_reference=self.time_reference, - reference_chirp_mass=self.extra_likelihood_kwargs.get( - "reference_chirp_mass", None - ), - highest_mode=self.extra_likelihood_kwargs.get("highest_mode", 2), - linear_interpolation=self.extra_likelihood_kwargs.get( - "linear_interpolation", True - ), - accuracy_factor=self.extra_likelihood_kwargs.get("accuracy_factor", 5), - time_offset=self.extra_likelihood_kwargs.get("time_offset", None), - delta_f_end=self.extra_likelihood_kwargs.get("delta_f_end", None), - maximum_banding_frequency=self.extra_likelihood_kwargs.get( - "maximum_banding_frequency", None - ), - minimum_banding_duration=self.extra_likelihood_kwargs.get( - "minimum_banding_duration", 0 - ), - weights=self.extra_likelihood_kwargs.get("weights", None), - ) - + if not self.is_likelihood_multiband: + raise ValueError( + "Multiband weights can only be saved for " + "MultibandGravitationalWaveTransient" + ) weight_file = os.path.join( self.data_directory, f"{self.label}_multiband_weights.hdf5" ) self.meta_data["weight_file"] = weight_file - likelihood.save_weights(weight_file) + self.likelihood.save_weights(weight_file) def create_generation_parser(): @@ -1355,9 +1262,5 @@ def main(): args, unknown_args = parse_args(sys.argv[1:], create_generation_parser()) log_version_information() data = DataGenerationInput(args, unknown_args) - if args.likelihood_type == "ROQGravitationalWaveTransient": - data.save_roq_weights() - if data.is_likelihood_multiband: - data.save_multiband_weights() data.save_data_dump() logger.info("Completed data generation") diff --git a/bilby_pipe/input.py b/bilby_pipe/input.py index 2dba9b2b4bc48715c0fefc7f21537c43b1250930..8d608ced0fb70f794a2067ab4237dee1a70e11d2 100644 --- a/bilby_pipe/input.py +++ b/bilby_pipe/input.py @@ -1254,30 +1254,37 @@ class Input(object): @property def roq_likelihood_kwargs(self): + kwargs = dict(roq_scale_factor=self.roq_scale_factor) if hasattr(self, "likelihood_roq_params"): - params = self.likelihood_roq_params + kwargs["roq_params"] = self.likelihood_roq_params elif self.roq_folder is not None: - params = np.genfromtxt(self.roq_folder + "/params.dat", names=True) - else: - params = None + kwargs["roq_params"] = np.genfromtxt( + f"{self.roq_folder}/params.dat", names=True + ) if hasattr(self, "likelihood_roq_weights"): - weights = self.likelihood_roq_weights + kwargs["weights"] = self.likelihood_roq_weights + elif "weight_file" in self.meta_data: + kwargs["weights"] = self.meta_data["weight_file"] + logger.debug(f"Loading ROQ weights from {kwargs['weights']}") + elif self.roq_folder is not None: + kwargs["linear_matrix"] = np.load(f"{self.roq_folder}/B_linear.npy").T + kwargs["quadratic_matrix"] = np.load(f"{self.roq_folder}/B_quadratic.npy").T else: - weights = self.meta_data["weight_file"] - logger.debug(f"Loading ROQ weights from {weights}") - - return dict( - weights=weights, roq_params=params, roq_scale_factor=self.roq_scale_factor - ) + kwargs["linear_matrix"] = self.roq_linear_matrix + kwargs["quadratic_matrix"] = self.roq_quadratic_matrix + return kwargs @property def multiband_likelihood_kwargs(self): if hasattr(self, "likelihood_multiband_weights"): weights = self.likelihood_multiband_weights - else: + elif "weight_file" in self.meta_data: weights = self.meta_data["weight_file"] logger.info(f"Loading multiband weights from {weights}") + else: + weights = None + logger.info("No multiband weights found, these will be calculated now") return dict(weights=weights) @property diff --git a/bilby_pipe/parser.py b/bilby_pipe/parser.py index b3293f81a05aa486fdd043ab64db4bc81345baef..21af55a35e2377d02ee42733b3b5f520e3c6c9ac 100644 --- a/bilby_pipe/parser.py +++ b/bilby_pipe/parser.py @@ -749,13 +749,11 @@ def create_parser(top_level=True): ) likelihood_parser.add( "--roq-weight-format", - type=nonestr, - default=None, + type=str, + default="hdf5", help=( "File format of roq weights. This should be npz, hdf5, or json. " - "If not specified, it is set to npz if basis file is specified " - "through roq-folder, and hdf5 if through roq-linear-matrix and " - "roq-quadratic-matrix" + "If not specified, it is set to hdf5." ), ) likelihood_parser.add(