diff --git a/bilby/core/result.py b/bilby/core/result.py index e2c5894f87146880f6f7b9d49b7ab52750b11847..9fb59d2adcc58c664bcf2ec0e09953ac91fef890 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -16,7 +16,8 @@ from matplotlib import lines as mpllines from . import utils from .utils import (logger, infer_parameters_from_function, - check_directory_exists_and_if_not_mkdir) + check_directory_exists_and_if_not_mkdir, + BilbyJsonEncoder, decode_bilby_json) from .prior import Prior, PriorDict, DeltaFunction @@ -230,7 +231,7 @@ class Result(object): if os.path.isfile(filename): with open(filename, 'r') as file: - dictionary = json.load(file, object_hook=decode_bilby_json_result) + dictionary = json.load(file, object_hook=decode_bilby_json) for key in dictionary.keys(): # Convert the loaded priors to bilby prior type if key == 'priors': @@ -421,7 +422,7 @@ class Result(object): try: if extension == 'json': with open(file_name, 'w') as file: - json.dump(dictionary, file, indent=2, cls=BilbyResultJsonEncoder) + json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder) elif extension == 'hdf5': deepdish.io.save(file_name, dictionary) else: @@ -1122,27 +1123,6 @@ class Result(object): return outdir -class BilbyResultJsonEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, np.ndarray): - return {'__array__': True, 'content': obj.tolist()} - if isinstance(obj, complex): - return {'__complex__': True, 'real': obj.real, 'imag': obj.imag} - if isinstance(obj, pd.core.frame.DataFrame): - return {'__dataframe__': True, 'content': obj.to_dict(orient='list')} - return json.JSONEncoder.default(self, obj) - - -def decode_bilby_json_result(dct): - if dct.get("__array__", False): - return np.asarray(dct["content"]) - if dct.get("__complex__", False): - return complex(dct["real"], dct["imag"]) - if dct.get("__dataframe__", False): - return pd.DataFrame(dct['content']) - return dct - - def plot_multiple(results, filename=None, labels=None, colours=None, save=True, evidences=False, **kwargs): """ Generate a corner plot overlaying two sets of results diff --git a/bilby/core/utils.py b/bilby/core/utils.py index 1b3b1cfab96b45bd15b597b8b1339321767b6d18..dfb2117f84689cc06618ae6e393432a644aac701 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -7,9 +7,11 @@ import argparse import traceback import inspect import subprocess +import json import numpy as np from scipy.interpolate import interp2d +import pandas as pd logger = logging.getLogger('bilby') @@ -751,5 +753,26 @@ else: print(traceback.format_exc()) +class BilbyJsonEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return {'__array__': True, 'content': obj.tolist()} + if isinstance(obj, complex): + return {'__complex__': True, 'real': obj.real, 'imag': obj.imag} + if isinstance(obj, pd.core.frame.DataFrame): + return {'__dataframe__': True, 'content': obj.to_dict(orient='list')} + return json.JSONEncoder.default(self, obj) + + +def decode_bilby_json(dct): + if dct.get("__array__", False): + return np.asarray(dct["content"]) + if dct.get("__complex__", False): + return complex(dct["real"], dct["imag"]) + if dct.get("__dataframe__", False): + return pd.DataFrame(dct['content']) + return dct + + class IllegalDurationAndSamplingFrequencyException(Exception): pass diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index d6e2f127f8d568e929e7e2635adc4553c568b977..e3ff0733879bfe57ee863ca2e3a6750e67e323cb 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -1,4 +1,7 @@ from __future__ import division + +import json + import numpy as np import scipy.integrate as integrate from scipy.interpolate import interp1d @@ -10,7 +13,8 @@ except ImportError: from scipy.special import i0e from ..core import likelihood -from ..core.utils import logger, UnsortedInterp2d +from ..core.utils import ( + logger, UnsortedInterp2d, BilbyJsonEncoder, decode_bilby_json) from ..core.prior import Prior, Uniform from .detector import InterferometerList from .prior import BBHPriorDict @@ -411,8 +415,8 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): A dictionary of priors containing at least the geocent_time prior """ - def __init__(self, interferometers, waveform_generator, - linear_matrix, quadratic_matrix, priors, + def __init__(self, interferometers, waveform_generator, priors, + weights=None, linear_matrix=None, quadratic_matrix=None, distance_marginalization=False, phase_marginalization=False): GravitationalWaveTransient.__init__( self, interferometers=interferometers, @@ -420,18 +424,27 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): distance_marginalization=distance_marginalization, phase_marginalization=phase_marginalization) - if isinstance(linear_matrix, str): - logger.info("Loading linear matrix from {}".format(linear_matrix)) - linear_matrix = np.load(linear_matrix).T - if isinstance(quadratic_matrix, str): - logger.info("Loading quadratic_matrix from {}".format(quadratic_matrix)) - quadratic_matrix = np.load(quadratic_matrix).T - - self.linear_matrix = linear_matrix - self.quadratic_matrix = quadratic_matrix - self.time_samples = None - self.weights = dict() - self._set_weights() + self.time_samples = np.arange( + self.priors['geocent_time'].minimum - 0.045, + self.priors['geocent_time'].maximum + 0.045, + self._get_time_resolution()) - self.interferometers.start_time + + if isinstance(weights, dict): + self.weights = weights + elif isinstance(weights, str): + self.weights = self.load_weights(weights) + else: + self.weights = dict() + if isinstance(linear_matrix, str): + logger.info( + "Loading linear matrix from {}".format(linear_matrix)) + linear_matrix = np.load(linear_matrix).T + if isinstance(quadratic_matrix, str): + logger.info( + "Loading quadratic_matrix from {}".format(quadratic_matrix)) + quadratic_matrix = np.load(quadratic_matrix).T + self._set_weights(linear_matrix=linear_matrix, + quadratic_matrix=quadratic_matrix) self.frequency_nodes_linear =\ waveform_generator.waveform_arguments['frequency_nodes_linear'] @@ -465,7 +478,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): h_cross_quadratic = f_cross * waveform['quadratic']['cross'] indices, in_bounds = self._closest_time_indices( - ifo_time, self.time_samples[ifo.name]) + ifo_time, self.time_samples) if not in_bounds: return np.nan_to_num(-np.inf) @@ -474,7 +487,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): self.weights[ifo.name + '_linear'][indices]) d_inner_h += interp1d( - self.time_samples[ifo.name][indices], + self.time_samples[indices], d_inner_h_tc_array, kind='cubic')(ifo_time) optimal_snr_squared += \ @@ -517,7 +530,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size) return indices, in_bounds - def _set_weights(self): + def _set_weights(self, linear_matrix, quadratic_matrix): """ Setup the time-dependent ROQ weights. This follows FIXME: Smith et al. @@ -525,38 +538,29 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): The times are chosen to allow all the merger times allows in the time prior. """ - self.time_samples = dict() for ifo in self.interferometers: # only get frequency components up to maximum_frequency - self.linear_matrix = \ - self.linear_matrix[:, :sum(ifo.frequency_mask)] - self.quadratic_matrix = \ - self.quadratic_matrix[:, :sum(ifo.frequency_mask)] + linear_matrix = linear_matrix[:, :sum(ifo.frequency_mask)] + quadratic_matrix = quadratic_matrix[:, :sum(ifo.frequency_mask)] # array of relative time shifts to be applied to the data # 0.045s comes from time for GW to traverse the Earth - self.time_samples[ifo.name] = np.arange( - self.priors['geocent_time'].minimum - 0.045, - self.priors['geocent_time'].maximum + 0.045, - self._get_time_resolution(ifo)) - self.time_samples[ifo.name] -= ifo.strain_data.start_time - time_space = (self.time_samples[ifo.name][1] - - self.time_samples[ifo.name][0]) + time_space = (self.time_samples[1] - + self.time_samples[0]) # array to be filled with data, shifted by discrete time_samples tc_shifted_data = np.zeros([ - len(self.time_samples[ifo.name]), - len(ifo.frequency_array[ifo.frequency_mask])], dtype=complex) + len(self.time_samples), sum(ifo.frequency_mask)], dtype=complex) # shift data to beginning of the prior increment by the time step shifted_data =\ ifo.frequency_domain_strain[ifo.frequency_mask] * \ np.exp(2j * np.pi * ifo.frequency_array[ifo.frequency_mask] * - self.time_samples[ifo.name][0]) + self.time_samples[0]) single_time_shift = np.exp( 2j * np.pi * ifo.frequency_array[ifo.frequency_mask] * time_space) - for j in range(len(self.time_samples[ifo.name])): + for j in range(len(self.time_samples)): tc_shifted_data[j] = shifted_data shifted_data *= single_time_shift @@ -568,16 +572,25 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): self.weights[ifo.name + '_linear'] = blockwise_dot_product( tc_shifted_data / ifo.power_spectral_density_array[ifo.frequency_mask], - self.linear_matrix, max_elements) * 4 / ifo.strain_data.duration + linear_matrix, max_elements) * 4 / ifo.strain_data.duration del tc_shifted_data self.weights[ifo.name + '_quadratic'] = build_roq_weights( 1 / ifo.power_spectral_density_array[ifo.frequency_mask], - self.quadratic_matrix.real, 1 / ifo.strain_data.duration) + quadratic_matrix.real, 1 / ifo.strain_data.duration) + + def save_weights(self, filename): + with open(filename, 'w') as file: + json.dump(self.weights, file, indent=2, cls=BilbyJsonEncoder) @staticmethod - def _get_time_resolution(ifo): + def load_weights(filename): + with open(filename, 'r') as file: + weights = json.load(file, object_hook=decode_bilby_json) + return weights + + def _get_time_resolution(self): """ This method estimates the time resolution given the optimal SNR of the signal in the detector. This is then used when constructing the weights @@ -622,12 +635,13 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): def c_f_scaling(snr): return (np.pi**2 * snr**2 / 6)**(1 / 3) - psd = ifo.power_spectral_density_array[ifo.frequency_mask] - - freq = ifo.frequency_array[ifo.frequency_mask] + inj_snr = 0 + for ifo in self.interferometers: - inj_snr = getattr(ifo.meta_data, 'optimal_SNR', 30) + inj_snr += getattr(ifo.meta_data, 'optimal_SNR', 30) + psd = ifo.power_spectral_density_array[ifo.frequency_mask] + freq = ifo.frequency_array[ifo.frequency_mask] fhigh = calc_fhigh(freq, psd, scaling=c_f_scaling(inj_snr)) delta_t = fhigh**-1 diff --git a/examples/injection_examples/roq_example.py b/examples/injection_examples/roq_example.py index 2ad8c99d4d020289e1e148f63829c1dbcc0c6560..319251dccd89e975c76d503d88e8c753a1b51300 100644 --- a/examples/injection_examples/roq_example.py +++ b/examples/injection_examples/roq_example.py @@ -23,7 +23,7 @@ basis_matrix_linear = np.load("B_linear.npy").T freq_nodes_linear = np.load("fnodes_linear.npy") # Load in the pieces for the quadratic part of the ROQ -basic_matrix_quadratic = np.load("B_quadratic.npy").T +basis_matrix_quadratic = np.load("B_quadratic.npy").T freq_nodes_quadratic = np.load("fnodes_quadratic.npy") np.random.seed(170808) @@ -77,11 +77,22 @@ priors['geocent_time'] = bilby.core.prior.Uniform( likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=ifos, waveform_generator=search_waveform_generator, - linear_matrix=basis_matrix_linear, quadratic_matrix=basic_matrix_quadratic, - prior=priors) + linear_matrix=basis_matrix_linear, quadratic_matrix=basis_matrix_quadratic, + priors=priors) + +# write the weights to file so they can be loaded multiple times +likelihood.save_weights('weights.json') + +# remove the basis matrices as these are big for longer bases +del basis_matrix_linear, basis_matrix_quadratic + +# load the weights from the file +likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient( + interferometers=ifos, waveform_generator=search_waveform_generator, + weights='weights.json', priors=priors) result = bilby.run_sampler( - likelihood=likelihood, priors=priors, sampler='dynesty', npoints=500, + likelihood=likelihood, priors=priors, sampler='pymultinest', npoints=500, injection_parameters=injection_parameters, outdir=outdir, label=label) # Make a corner plot. diff --git a/test/result_test.py b/test/result_test.py index 00b6c7b67eeb8951f787e07bcd0f5626cb9478f5..297d22b4e6f4360deffe6e670e93466f8c10f894 100644 --- a/test/result_test.py +++ b/test/result_test.py @@ -11,38 +11,43 @@ import bilby class TestJson(unittest.TestCase): + + def setUp(self): + self.encoder = bilby.core.utils.BilbyJsonEncoder + self.decoder = bilby.core.utils.decode_bilby_json + def test_list_encoding(self): data = dict(x=[1, 2, 3.4]) - encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder) - decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result) + encoded = json.dumps(data, cls=self.encoder) + decoded = json.loads(encoded, object_hook=self.decoder) self.assertEqual(data.keys(), decoded.keys()) self.assertEqual(type(data['x']), type(decoded['x'])) - self.assertTrue(np.all(data['x']==decoded['x'])) + self.assertTrue(np.all(data['x'] == decoded['x'])) def test_array_encoding(self): data = dict(x=np.array([1, 2, 3.4])) - encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder) - decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result) + encoded = json.dumps(data, cls=self.encoder) + decoded = json.loads(encoded, object_hook=self.decoder) self.assertEqual(data.keys(), decoded.keys()) self.assertEqual(type(data['x']), type(decoded['x'])) - self.assertTrue(np.all(data['x']==decoded['x'])) + self.assertTrue(np.all(data['x'] == decoded['x'])) def test_complex_encoding(self): data = dict(x=1 + 3j) - encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder) - decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result) + encoded = json.dumps(data, cls=self.encoder) + decoded = json.loads(encoded, object_hook=self.decoder) self.assertEqual(data.keys(), decoded.keys()) self.assertEqual(type(data['x']), type(decoded['x'])) - self.assertTrue(np.all(data['x']==decoded['x'])) + self.assertTrue(np.all(data['x'] == decoded['x'])) def test_dataframe_encoding(self): data = dict(data=pd.DataFrame(dict(x=[3, 4, 5], y=[5, 6, 7]))) - encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder) - decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result) + encoded = json.dumps(data, cls=self.encoder) + decoded = json.loads(encoded, object_hook=self.decoder) self.assertEqual(data.keys(), decoded.keys()) self.assertEqual(type(data['data']), type(decoded['data'])) - self.assertTrue(np.all(data['data']['x']==decoded['data']['x'])) - self.assertTrue(np.all(data['data']['y']==decoded['data']['y'])) + self.assertTrue(np.all(data['data']['x'] == decoded['data']['x'])) + self.assertTrue(np.all(data['data']['y'] == decoded['data']['y'])) class TestResult(unittest.TestCase):