Skip to content
Snippets Groups Projects
Commit ad4c243e authored by Colm Talbot's avatar Colm Talbot Committed by Gregory Ashton
Browse files

allow writing/reading ROQ weights to/from npz

parent 3f25d4a7
No related branches found
No related tags found
No related merge requests found
......@@ -1000,14 +1000,28 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
logger.info("Finished building weights for {}".format(ifo.name))
def save_weights(self, filename):
with open(filename, 'w') as file:
json.dump(self.weights, file, indent=2, cls=BilbyJsonEncoder)
def save_weights(self, filename, format='npz'):
if format not in filename:
filename += "." + format
logger.info("Saving ROQ weights to {}".format(filename))
if format == 'json':
with open(filename, 'w') as file:
json.dump(self.weights, file, indent=2, cls=BilbyJsonEncoder)
elif format == 'npz':
np.savez(filename, **self.weights)
@staticmethod
def load_weights(filename):
with open(filename, 'r') as file:
weights = json.load(file, object_hook=decode_bilby_json)
def load_weights(filename, format=None):
if format is None:
format = filename.split(".")[-1]
if format not in ["json", "npz"]:
raise IOError("Format {} not recongized.".format(format))
logger.info("Loading ROQ weights from {}".format(filename))
if format == "json":
with open(filename, 'r') as file:
weights = json.load(file, object_hook=decode_bilby_json)
elif format == "npz":
weights = np.load(filename)
return weights
def _get_time_resolution(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment