Commit 130740ce authored by Gregory Ashton's avatar Gregory Ashton

Merge branch 'save-roq-weights-npz' into 'master'

allow writing/reading ROQ weights to/from npz

See merge request !536
parents 3f25d4a7 ad4c243e
Pipeline #68136 passed with stages
in 5 minutes and 42 seconds
......@@ -1000,14 +1000,28 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
logger.info("Finished building weights for {}".format(ifo.name))
def save_weights(self, filename):
with open(filename, 'w') as file:
json.dump(self.weights, file, indent=2, cls=BilbyJsonEncoder)
def save_weights(self, filename, format='npz'):
if format not in filename:
filename += "." + format
logger.info("Saving ROQ weights to {}".format(filename))
if format == 'json':
with open(filename, 'w') as file:
json.dump(self.weights, file, indent=2, cls=BilbyJsonEncoder)
elif format == 'npz':
np.savez(filename, **self.weights)
@staticmethod
def load_weights(filename):
with open(filename, 'r') as file:
weights = json.load(file, object_hook=decode_bilby_json)
def load_weights(filename, format=None):
if format is None:
format = filename.split(".")[-1]
if format not in ["json", "npz"]:
raise IOError("Format {} not recongized.".format(format))
logger.info("Loading ROQ weights from {}".format(filename))
if format == "json":
with open(filename, 'r') as file:
weights = json.load(file, object_hook=decode_bilby_json)
elif format == "npz":
weights = np.load(filename)
return weights
def _get_time_resolution(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment