Commit 9b940cf1 authored by Colm Talbot's avatar Colm Talbot Committed by Paul Lasky

Reduce roq memory usage

parent a2e87c80
......@@ -16,7 +16,8 @@ from matplotlib import lines as mpllines
from . import utils
from .utils import (logger, infer_parameters_from_function,
check_directory_exists_and_if_not_mkdir)
check_directory_exists_and_if_not_mkdir,
BilbyJsonEncoder, decode_bilby_json)
from .prior import Prior, PriorDict, DeltaFunction
......@@ -230,7 +231,7 @@ class Result(object):
if os.path.isfile(filename):
with open(filename, 'r') as file:
dictionary = json.load(file, object_hook=decode_bilby_json_result)
dictionary = json.load(file, object_hook=decode_bilby_json)
for key in dictionary.keys():
# Convert the loaded priors to bilby prior type
if key == 'priors':
......@@ -421,7 +422,7 @@ class Result(object):
try:
if extension == 'json':
with open(file_name, 'w') as file:
json.dump(dictionary, file, indent=2, cls=BilbyResultJsonEncoder)
json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder)
elif extension == 'hdf5':
deepdish.io.save(file_name, dictionary)
else:
......@@ -1122,27 +1123,6 @@ class Result(object):
return outdir
class BilbyResultJsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return {'__array__': True, 'content': obj.tolist()}
if isinstance(obj, complex):
return {'__complex__': True, 'real': obj.real, 'imag': obj.imag}
if isinstance(obj, pd.core.frame.DataFrame):
return {'__dataframe__': True, 'content': obj.to_dict(orient='list')}
return json.JSONEncoder.default(self, obj)
def decode_bilby_json_result(dct):
if dct.get("__array__", False):
return np.asarray(dct["content"])
if dct.get("__complex__", False):
return complex(dct["real"], dct["imag"])
if dct.get("__dataframe__", False):
return pd.DataFrame(dct['content'])
return dct
def plot_multiple(results, filename=None, labels=None, colours=None,
save=True, evidences=False, **kwargs):
""" Generate a corner plot overlaying two sets of results
......
......@@ -7,9 +7,11 @@ import argparse
import traceback
import inspect
import subprocess
import json
import numpy as np
from scipy.interpolate import interp2d
import pandas as pd
logger = logging.getLogger('bilby')
......@@ -751,5 +753,26 @@ else:
print(traceback.format_exc())
class BilbyJsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return {'__array__': True, 'content': obj.tolist()}
if isinstance(obj, complex):
return {'__complex__': True, 'real': obj.real, 'imag': obj.imag}
if isinstance(obj, pd.core.frame.DataFrame):
return {'__dataframe__': True, 'content': obj.to_dict(orient='list')}
return json.JSONEncoder.default(self, obj)
def decode_bilby_json(dct):
if dct.get("__array__", False):
return np.asarray(dct["content"])
if dct.get("__complex__", False):
return complex(dct["real"], dct["imag"])
if dct.get("__dataframe__", False):
return pd.DataFrame(dct['content'])
return dct
class IllegalDurationAndSamplingFrequencyException(Exception):
pass
from __future__ import division
import json
import numpy as np
import scipy.integrate as integrate
from scipy.interpolate import interp1d
......@@ -10,7 +13,8 @@ except ImportError:
from scipy.special import i0e
from ..core import likelihood
from ..core.utils import logger, UnsortedInterp2d
from ..core.utils import (
logger, UnsortedInterp2d, BilbyJsonEncoder, decode_bilby_json)
from ..core.prior import Prior, Uniform
from .detector import InterferometerList
from .prior import BBHPriorDict
......@@ -411,8 +415,8 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
A dictionary of priors containing at least the geocent_time prior
"""
def __init__(self, interferometers, waveform_generator,
linear_matrix, quadratic_matrix, priors,
def __init__(self, interferometers, waveform_generator, priors,
weights=None, linear_matrix=None, quadratic_matrix=None,
distance_marginalization=False, phase_marginalization=False):
GravitationalWaveTransient.__init__(
self, interferometers=interferometers,
......@@ -420,18 +424,27 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
distance_marginalization=distance_marginalization,
phase_marginalization=phase_marginalization)
if isinstance(linear_matrix, str):
logger.info("Loading linear matrix from {}".format(linear_matrix))
linear_matrix = np.load(linear_matrix).T
if isinstance(quadratic_matrix, str):
logger.info("Loading quadratic_matrix from {}".format(quadratic_matrix))
quadratic_matrix = np.load(quadratic_matrix).T
self.linear_matrix = linear_matrix
self.quadratic_matrix = quadratic_matrix
self.time_samples = None
self.weights = dict()
self._set_weights()
self.time_samples = np.arange(
self.priors['geocent_time'].minimum - 0.045,
self.priors['geocent_time'].maximum + 0.045,
self._get_time_resolution()) - self.interferometers.start_time
if isinstance(weights, dict):
self.weights = weights
elif isinstance(weights, str):
self.weights = self.load_weights(weights)
else:
self.weights = dict()
if isinstance(linear_matrix, str):
logger.info(
"Loading linear matrix from {}".format(linear_matrix))
linear_matrix = np.load(linear_matrix).T
if isinstance(quadratic_matrix, str):
logger.info(
"Loading quadratic_matrix from {}".format(quadratic_matrix))
quadratic_matrix = np.load(quadratic_matrix).T
self._set_weights(linear_matrix=linear_matrix,
quadratic_matrix=quadratic_matrix)
self.frequency_nodes_linear =\
waveform_generator.waveform_arguments['frequency_nodes_linear']
......@@ -465,7 +478,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
h_cross_quadratic = f_cross * waveform['quadratic']['cross']
indices, in_bounds = self._closest_time_indices(
ifo_time, self.time_samples[ifo.name])
ifo_time, self.time_samples)
if not in_bounds:
return np.nan_to_num(-np.inf)
......@@ -474,7 +487,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
self.weights[ifo.name + '_linear'][indices])
d_inner_h += interp1d(
self.time_samples[ifo.name][indices],
self.time_samples[indices],
d_inner_h_tc_array, kind='cubic')(ifo_time)
optimal_snr_squared += \
......@@ -517,7 +530,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size)
return indices, in_bounds
def _set_weights(self):
def _set_weights(self, linear_matrix, quadratic_matrix):
"""
Setup the time-dependent ROQ weights.
This follows FIXME: Smith et al.
......@@ -525,38 +538,29 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
The times are chosen to allow all the merger times allows in the time
prior.
"""
self.time_samples = dict()
for ifo in self.interferometers:
# only get frequency components up to maximum_frequency
self.linear_matrix = \
self.linear_matrix[:, :sum(ifo.frequency_mask)]
self.quadratic_matrix = \
self.quadratic_matrix[:, :sum(ifo.frequency_mask)]
linear_matrix = linear_matrix[:, :sum(ifo.frequency_mask)]
quadratic_matrix = quadratic_matrix[:, :sum(ifo.frequency_mask)]
# array of relative time shifts to be applied to the data
# 0.045s comes from time for GW to traverse the Earth
self.time_samples[ifo.name] = np.arange(
self.priors['geocent_time'].minimum - 0.045,
self.priors['geocent_time'].maximum + 0.045,
self._get_time_resolution(ifo))
self.time_samples[ifo.name] -= ifo.strain_data.start_time
time_space = (self.time_samples[ifo.name][1] -
self.time_samples[ifo.name][0])
time_space = (self.time_samples[1] -
self.time_samples[0])
# array to be filled with data, shifted by discrete time_samples
tc_shifted_data = np.zeros([
len(self.time_samples[ifo.name]),
len(ifo.frequency_array[ifo.frequency_mask])], dtype=complex)
len(self.time_samples), sum(ifo.frequency_mask)], dtype=complex)
# shift data to beginning of the prior increment by the time step
shifted_data =\
ifo.frequency_domain_strain[ifo.frequency_mask] * \
np.exp(2j * np.pi * ifo.frequency_array[ifo.frequency_mask] *
self.time_samples[ifo.name][0])
self.time_samples[0])
single_time_shift = np.exp(
2j * np.pi * ifo.frequency_array[ifo.frequency_mask] *
time_space)
for j in range(len(self.time_samples[ifo.name])):
for j in range(len(self.time_samples)):
tc_shifted_data[j] = shifted_data
shifted_data *= single_time_shift
......@@ -568,16 +572,25 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
self.weights[ifo.name + '_linear'] = blockwise_dot_product(
tc_shifted_data /
ifo.power_spectral_density_array[ifo.frequency_mask],
self.linear_matrix, max_elements) * 4 / ifo.strain_data.duration
linear_matrix, max_elements) * 4 / ifo.strain_data.duration
del tc_shifted_data
self.weights[ifo.name + '_quadratic'] = build_roq_weights(
1 / ifo.power_spectral_density_array[ifo.frequency_mask],
self.quadratic_matrix.real, 1 / ifo.strain_data.duration)
quadratic_matrix.real, 1 / ifo.strain_data.duration)
def save_weights(self, filename):
with open(filename, 'w') as file:
json.dump(self.weights, file, indent=2, cls=BilbyJsonEncoder)
@staticmethod
def _get_time_resolution(ifo):
def load_weights(filename):
with open(filename, 'r') as file:
weights = json.load(file, object_hook=decode_bilby_json)
return weights
def _get_time_resolution(self):
"""
This method estimates the time resolution given the optimal SNR of the
signal in the detector. This is then used when constructing the weights
......@@ -622,12 +635,13 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
def c_f_scaling(snr):
return (np.pi**2 * snr**2 / 6)**(1 / 3)
psd = ifo.power_spectral_density_array[ifo.frequency_mask]
freq = ifo.frequency_array[ifo.frequency_mask]
inj_snr = 0
for ifo in self.interferometers:
inj_snr = getattr(ifo.meta_data, 'optimal_SNR', 30)
inj_snr += getattr(ifo.meta_data, 'optimal_SNR', 30)
psd = ifo.power_spectral_density_array[ifo.frequency_mask]
freq = ifo.frequency_array[ifo.frequency_mask]
fhigh = calc_fhigh(freq, psd, scaling=c_f_scaling(inj_snr))
delta_t = fhigh**-1
......
......@@ -23,7 +23,7 @@ basis_matrix_linear = np.load("B_linear.npy").T
freq_nodes_linear = np.load("fnodes_linear.npy")
# Load in the pieces for the quadratic part of the ROQ
basic_matrix_quadratic = np.load("B_quadratic.npy").T
basis_matrix_quadratic = np.load("B_quadratic.npy").T
freq_nodes_quadratic = np.load("fnodes_quadratic.npy")
np.random.seed(170808)
......@@ -77,11 +77,22 @@ priors['geocent_time'] = bilby.core.prior.Uniform(
likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient(
interferometers=ifos, waveform_generator=search_waveform_generator,
linear_matrix=basis_matrix_linear, quadratic_matrix=basic_matrix_quadratic,
prior=priors)
linear_matrix=basis_matrix_linear, quadratic_matrix=basis_matrix_quadratic,
priors=priors)
# write the weights to file so they can be loaded multiple times
likelihood.save_weights('weights.json')
# remove the basis matrices as these are big for longer bases
del basis_matrix_linear, basis_matrix_quadratic
# load the weights from the file
likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient(
interferometers=ifos, waveform_generator=search_waveform_generator,
weights='weights.json', priors=priors)
result = bilby.run_sampler(
likelihood=likelihood, priors=priors, sampler='dynesty', npoints=500,
likelihood=likelihood, priors=priors, sampler='pymultinest', npoints=500,
injection_parameters=injection_parameters, outdir=outdir, label=label)
# Make a corner plot.
......
......@@ -11,38 +11,43 @@ import bilby
class TestJson(unittest.TestCase):
def setUp(self):
self.encoder = bilby.core.utils.BilbyJsonEncoder
self.decoder = bilby.core.utils.decode_bilby_json
def test_list_encoding(self):
data = dict(x=[1, 2, 3.4])
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
encoded = json.dumps(data, cls=self.encoder)
decoded = json.loads(encoded, object_hook=self.decoder)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['x']), type(decoded['x']))
self.assertTrue(np.all(data['x']==decoded['x']))
self.assertTrue(np.all(data['x'] == decoded['x']))
def test_array_encoding(self):
data = dict(x=np.array([1, 2, 3.4]))
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
encoded = json.dumps(data, cls=self.encoder)
decoded = json.loads(encoded, object_hook=self.decoder)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['x']), type(decoded['x']))
self.assertTrue(np.all(data['x']==decoded['x']))
self.assertTrue(np.all(data['x'] == decoded['x']))
def test_complex_encoding(self):
data = dict(x=1 + 3j)
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
encoded = json.dumps(data, cls=self.encoder)
decoded = json.loads(encoded, object_hook=self.decoder)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['x']), type(decoded['x']))
self.assertTrue(np.all(data['x']==decoded['x']))
self.assertTrue(np.all(data['x'] == decoded['x']))
def test_dataframe_encoding(self):
data = dict(data=pd.DataFrame(dict(x=[3, 4, 5], y=[5, 6, 7])))
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
encoded = json.dumps(data, cls=self.encoder)
decoded = json.loads(encoded, object_hook=self.decoder)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['data']), type(decoded['data']))
self.assertTrue(np.all(data['data']['x']==decoded['data']['x']))
self.assertTrue(np.all(data['data']['y']==decoded['data']['y']))
self.assertTrue(np.all(data['data']['x'] == decoded['data']['x']))
self.assertTrue(np.all(data['data']['y'] == decoded['data']['y']))
class TestResult(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment