From 8778bfd3805609642797186e53bd4b4a3095daae Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Wed, 20 Nov 2019 21:21:00 -0600
Subject: [PATCH] Resolve "Add checking to the ROQ usage"

---
 bilby/gw/likelihood.py                        |  89 +++++++++-
 bilby/gw/prior.py                             |  62 ++++++-
 .../injection_examples/roq_example.py         |  11 +-
 test/gw_likelihood_test.py                    | 166 +++++++++++++++++-
 4 files changed, 312 insertions(+), 16 deletions(-)

diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index 2af5d6201..804dab655 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -22,7 +22,7 @@ from ..core.utils import (
     speed_of_light, radius_of_earth)
 from ..core.prior import Interped, Prior, Uniform
 from .detector import InterferometerList
-from .prior import BBHPriorDict
+from .prior import BBHPriorDict, CBCPriorDict
 from .source import lal_binary_black_hole
 from .utils import noise_weighted_inner_product, build_roq_weights, blockwise_dot_product
 from .waveform_generator import WaveformGenerator
@@ -824,6 +824,12 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
         quadratic_matrix array, or the array itself.
     roq_params: str, array_like
         Parameters describing the domain of validity of the ROQ basis.
+    roq_params_check: bool
+        If true, run tests using the roq_params to check the prior and data are
+        valid for the ROQ
+    roq_scale_factor: float
+        The ROQ scale factor used. WARNING: this does not apply the scaling,
+        but is only used for checking that the ROQ basis is appropriate.
     priors: dict, bilby.prior.PriorDict
         A dictionary of priors containing at least the geocent_time prior
     distance_marginalization_lookup_table: (dict, str), optional
@@ -838,7 +844,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
     """
     def __init__(self, interferometers, waveform_generator, priors,
                  weights=None, linear_matrix=None, quadratic_matrix=None,
-                 roq_params=None,
+                 roq_params=None, roq_params_check=True, roq_scale_factor=1,
                  distance_marginalization=False, phase_marginalization=False,
                  distance_marginalization_lookup_table=None):
         super(ROQGravitationalWaveTransient, self).__init__(
@@ -850,9 +856,12 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
             distance_marginalization_lookup_table=distance_marginalization_lookup_table,
             jitter_time=False)
 
+        self.roq_params_check = roq_params_check
+        self.roq_scale_factor = roq_scale_factor
         if isinstance(roq_params, np.ndarray) or roq_params is None:
             self.roq_params = roq_params
         elif isinstance(roq_params, str):
+            self.roq_params_file = roq_params
             self.roq_params = np.genfromtxt(roq_params, names=True)
         else:
             raise TypeError("roq_params should be array or str")
@@ -968,6 +977,75 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
         in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size)
         return indices, in_bounds
 
+    def perform_roq_params_check(self, ifo=None):
+        """ Perform checking that the prior and data are valid for the ROQ
+
+        Parameters
+        ----------
+        ifo: bilby.gw.detector.Interferometer
+            The interferometer
+        """
+        if self.roq_params_check is False:
+            logger.warning("No ROQ params checking performed")
+            return
+        else:
+            if getattr(self, "roq_params_file", None) is not None:
+                msg = ("Check ROQ params {} with roq_scale_factor={}"
+                       .format(self.roq_params_file, self.roq_scale_factor))
+            else:
+                msg = ("Check ROQ params with roq_scale_factor={}"
+                       .format(self.roq_scale_factor))
+            logger.info(msg)
+
+        roq_params = self.roq_params
+        roq_params['flow'] *= self.roq_scale_factor
+        roq_params['fhigh'] *= self.roq_scale_factor
+        roq_params['seglen'] /= self.roq_scale_factor
+        roq_params['chirpmassmin'] /= self.roq_scale_factor
+        roq_params['chirpmassmax'] /= self.roq_scale_factor
+        roq_params['compmin'] /= self.roq_scale_factor
+
+        if ifo.maximum_frequency > roq_params['fhigh']:
+            raise BilbyROQParamsRangeError(
+                "Requested maximum frequency {} larger than ROQ basis fhigh {}"
+                .format(ifo.maximum_frequency, roq_params['fhigh']))
+        if ifo.minimum_frequency < roq_params['flow']:
+            raise BilbyROQParamsRangeError(
+                "Requested minimum frequency {} lower than ROQ basis flow {}"
+                .format(ifo.minimum_frequency, roq_params['flow']))
+        if ifo.strain_data.duration != roq_params['seglen']:
+            raise BilbyROQParamsRangeError(
+                "Requested duration differs from ROQ basis seglen")
+
+        priors = self.priors
+        if isinstance(priors, CBCPriorDict) is False:
+            logger.warning("Unable to check ROQ parameter bounds: priors not understood")
+            return
+
+        if priors.minimum_chirp_mass is None:
+            logger.warning("Unable to check minimum chirp mass ROQ bounds")
+        elif priors.minimum_chirp_mass < roq_params["chirpmassmin"]:
+            raise BilbyROQParamsRangeError(
+                "Prior minimum chirp mass {} less than ROQ basis bound {}"
+                .format(priors.minimum_chirp_mass,
+                        roq_params["chirpmassmin"]))
+
+        if priors.maximum_chirp_mass is None:
+            logger.warning("Unable to check maximum_chirp mass ROQ bounds")
+        elif priors.maximum_chirp_mass > roq_params["chirpmassmax"]:
+            raise BilbyROQParamsRangeError(
+                "Prior maximum chirp mass {} greater than ROQ basis bound {}"
+                .format(priors.maximum_chirp_mass,
+                        roq_params["chirpmassmax"]))
+
+        if priors.minimum_component_mass is None:
+            logger.warning("Unable to check minimum component mass ROQ bounds")
+        elif priors.minimum_component_mass < roq_params["compmin"]:
+            raise BilbyROQParamsRangeError(
+                "Prior minimum component mass {} less than ROQ basis bound {}"
+                .format(priors.minimum_component_mass,
+                        roq_params["compmin"]))
+
     def _set_weights(self, linear_matrix, quadratic_matrix):
         """ Setup the time-dependent ROQ weights.
 
@@ -991,8 +1069,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
 
         for ifo in self.interferometers:
             if self.roq_params is not None:
-                if ifo.maximum_frequency > self.roq_params['fhigh']:
-                    raise ValueError("Requested maximum frequency larger than ROQ basis fhigh")
+                self.perform_roq_params_check(ifo)
                 # Generate frequencies for the ROQ
                 roq_frequencies = create_frequency_series(
                     sampling_frequency=self.roq_params['fhigh'] * 2,
@@ -1167,3 +1244,7 @@ def get_binary_black_hole_likelihood(interferometers):
         waveform_arguments={'waveform_approximant': 'IMRPhenomPv2',
                             'reference_frequency': 50})
     return GravitationalWaveTransient(interferometers, waveform_generator)
+
+
+class BilbyROQParamsRangeError(Exception):
+    pass
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index ae8be2a11..5789a76b5 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -10,7 +10,9 @@ from ..core.utils import infer_args_from_method, logger
 from .conversion import (
     convert_to_lal_binary_black_hole_parameters,
     convert_to_lal_binary_neutron_star_parameters, generate_mass_parameters,
-    generate_tidal_parameters, fill_from_fixed_priors)
+    generate_tidal_parameters, fill_from_fixed_priors,
+    chirp_mass_and_mass_ratio_to_total_mass,
+    total_mass_and_mass_ratio_to_component_masses)
 from .cosmology import get_cosmology
 
 try:
@@ -305,7 +307,61 @@ class AlignedSpin(Interped):
                                           boundary=boundary)
 
 
-class BBHPriorDict(PriorDict):
+class CBCPriorDict(PriorDict):
+    @property
+    def minimum_chirp_mass(self):
+        mass_1 = None
+        mass_2 = None
+        if "chirp_mass" in self:
+            return self["chirp_mass"].minimum
+        elif "mass_1" in self:
+            mass_1 = self['mass_1'].minimum
+            if "mass_2" in self:
+                mass_2 = self['mass_2'].minimum
+            elif "mass_ratio" in self:
+                mass_2 = mass_1 * self["mass_ratio"].minimum
+        if mass_1 is not None and mass_2 is not None:
+            s = generate_mass_parameters(dict(mass_1=mass_1, mass_2=mass_2))
+            return s["chirp_mass"]
+        else:
+            logger.warning("Unable to determine minimum chirp mass")
+            return None
+
+    @property
+    def maximum_chirp_mass(self):
+        mass_1 = None
+        mass_2 = None
+        if "chirp_mass" in self:
+            return self["chirp_mass"].maximum
+        elif "mass_1" in self:
+            mass_1 = self['mass_1'].maximum
+            if "mass_2" in self:
+                mass_2 = self['mass_2'].maximum
+            elif "mass_ratio" in self:
+                mass_2 = mass_1 * self["mass_ratio"].maximum
+        if mass_1 is not None and mass_2 is not None:
+            s = generate_mass_parameters(dict(mass_1=mass_1, mass_2=mass_2))
+            return s["chirp_mass"]
+        else:
+            logger.warning("Unable to determine maximum chirp mass")
+            return None
+
+    @property
+    def minimum_component_mass(self):
+        if "mass_2" in self:
+            return self["mass_2"].minimum
+        if "chirp_mass" in self and "mass_ratio" in self:
+            total_mass = chirp_mass_and_mass_ratio_to_total_mass(
+                self["chirp_mass"].minimum, self["mass_ratio"].minimum)
+            _, mass_2 = total_mass_and_mass_ratio_to_component_masses(
+                self["mass_ratio"].minimum, total_mass)
+            return mass_2
+        else:
+            logger.warning("Unable to determine minimum component mass")
+            return None
+
+
+class BBHPriorDict(CBCPriorDict):
     def __init__(self, dictionary=None, filename=None, aligned_spin=False,
                  conversion_function=None):
         """ Initialises a Prior set for Binary Black holes
@@ -411,7 +467,7 @@ class BBHPriorDict(PriorDict):
         return False
 
 
-class BNSPriorDict(PriorDict):
+class BNSPriorDict(CBCPriorDict):
 
     def __init__(self, dictionary=None, filename=None, aligned_spin=True,
                  conversion_function=None):
diff --git a/examples/gw_examples/injection_examples/roq_example.py b/examples/gw_examples/injection_examples/roq_example.py
index 947d85a52..da639805e 100644
--- a/examples/gw_examples/injection_examples/roq_example.py
+++ b/examples/gw_examples/injection_examples/roq_example.py
@@ -50,8 +50,7 @@ injection_parameters = dict(
     phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
 
 waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=20. * scale_factor,
-                          minimum_frequency=20. * scale_factor)
+                          reference_frequency=20. * scale_factor)
 
 waveform_generator = bilby.gw.WaveformGenerator(
     duration=duration, sampling_frequency=sampling_frequency,
@@ -65,6 +64,8 @@ ifos.set_strain_data_from_power_spectral_densities(
     start_time=injection_parameters['geocent_time'] - 3)
 ifos.inject_signal(waveform_generator=waveform_generator,
                    parameters=injection_parameters)
+for ifo in ifos:
+    ifo.minimum_frequency = 20 * scale_factor
 
 # make ROQ waveform generator
 search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
@@ -98,7 +99,7 @@ likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient(
     priors=priors, roq_params=params)
 
 # write the weights to file so they can be loaded multiple times
-likelihood.save_weights('weights.json')
+likelihood.save_weights('weights.npz')
 
 # remove the basis matrices as these are big for longer bases
 del basis_matrix_linear, basis_matrix_quadratic
@@ -106,10 +107,10 @@ del basis_matrix_linear, basis_matrix_quadratic
 # load the weights from the file
 likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient(
     interferometers=ifos, waveform_generator=search_waveform_generator,
-    weights='weights.json', priors=priors)
+    weights='weights.npz', priors=priors)
 
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='pymultinest', npoints=500,
+    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=500,
     injection_parameters=injection_parameters, outdir=outdir, label=label)
 
 # Make a corner plot.
diff --git a/test/gw_likelihood_test.py b/test/gw_likelihood_test.py
index d9961cf11..91966dc45 100644
--- a/test/gw_likelihood_test.py
+++ b/test/gw_likelihood_test.py
@@ -1,7 +1,10 @@
 from __future__ import division, absolute_import
 import unittest
-import bilby
+import os
+
 import numpy as np
+import bilby
+from bilby.gw.likelihood import BilbyROQParamsRangeError
 
 
 class TestBasicGWTransient(unittest.TestCase):
@@ -533,7 +536,18 @@ class TestROQLikelihood(unittest.TestCase):
         self.duration = 4
         self.sampling_frequency = 2048
 
-        roq_dir = '/roq_basis'
+        # Possible locations for the ROQ: in the docker image, local, or on CIT
+        trial_roq_paths = [
+            "/roq_basis",
+            os.path.join(os.path.expanduser("~"), 'ROQ_data/IMRPhenomPv2/4s'),
+            "/home/cbc/ROQ_data/IMRPhenomPv2/4s"]
+        roq_dir = None
+        for path in trial_roq_paths:
+            if os.path.isdir(path):
+                roq_dir = path
+                break
+        if roq_dir is None:
+            raise Exception("Unable to load ROQ basis: cannot proceed with tests")
 
         linear_matrix_file = "{}/B_linear.npy".format(roq_dir)
         quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir)
@@ -556,6 +570,11 @@ class TestROQLikelihood(unittest.TestCase):
             sampling_frequency=self.sampling_frequency, duration=self.duration)
 
         self.priors = bilby.gw.prior.BBHPriorDict()
+        self.priors.pop("mass_1")
+        self.priors.pop("mass_2")
+        # Testing is done with the 4s IMRPhenomPV2 ROQ basis
+        self.priors["chirp_mass"] = bilby.core.prior.Uniform(12.299703, 45)
+        self.priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1)
         self.priors['geocent_time'] = bilby.core.prior.Uniform(1.19, 1.21)
 
         non_roq_wfg = bilby.gw.WaveformGenerator(
@@ -637,8 +656,9 @@ class TestROQLikelihood(unittest.TestCase):
             roq.log_likelihood_ratio(), self.roq.log_likelihood_ratio())
 
     def test_create_roq_weights_frequency_mismatch_works_with_params(self):
+
         self.ifos[0].maximum_frequency = self.ifos[0].maximum_frequency / 2
-        _ = bilby.gw.likelihood.ROQGravitationalWaveTransient(
+        bilby.gw.likelihood.ROQGravitationalWaveTransient(
             interferometers=self.ifos, waveform_generator=self.roq_wfg,
             linear_matrix=self.linear_matrix_file, roq_params=self.params_file,
             quadratic_matrix=self.quadratic_matrix_file, priors=self.priors)
@@ -646,11 +666,149 @@ class TestROQLikelihood(unittest.TestCase):
     def test_create_roq_weights_frequency_mismatch_fails_without_params(self):
         self.ifos[0].maximum_frequency = self.ifos[0].maximum_frequency / 2
         with self.assertRaises(ValueError):
-            _ = bilby.gw.likelihood.ROQGravitationalWaveTransient(
+            bilby.gw.likelihood.ROQGravitationalWaveTransient(
                 interferometers=self.ifos, waveform_generator=self.roq_wfg,
                 linear_matrix=self.linear_matrix_file,
                 quadratic_matrix=self.quadratic_matrix_file, priors=self.priors)
 
+    def test_create_roq_weights_fails_with_min_chirp_mass_outside_bounds(self):
+        self.ifos[0].maximum_frequency = self.ifos[0].maximum_frequency / 2
+        self.priors["chirp_mass"] = bilby.core.prior.Uniform(10, 45)
+        with self.assertRaises(BilbyROQParamsRangeError):
+            bilby.gw.likelihood.ROQGravitationalWaveTransient(
+                interferometers=self.ifos, waveform_generator=self.roq_wfg,
+                linear_matrix=self.linear_matrix_file,
+                roq_params=self.params_file,
+                quadratic_matrix=self.quadratic_matrix_file,
+                priors=self.priors)
+
+    def test_create_roq_weights_fails_with_max_chirp_mass_outside_bounds(self):
+        self.ifos[0].maximum_frequency = self.ifos[0].maximum_frequency / 2
+        self.priors["chirp_mass"] = bilby.core.prior.Uniform(12.299703, 50)
+        with self.assertRaises(BilbyROQParamsRangeError):
+            bilby.gw.likelihood.ROQGravitationalWaveTransient(
+                interferometers=self.ifos, waveform_generator=self.roq_wfg,
+                linear_matrix=self.linear_matrix_file,
+                roq_params=self.params_file,
+                quadratic_matrix=self.quadratic_matrix_file,
+                priors=self.priors)
+
+    def test_create_roq_weights_fails_with_min_component_mass_outside_bounds(self):
+        self.ifos[0].maximum_frequency = self.ifos[0].maximum_frequency / 2
+        self.priors["chirp_mass"] = bilby.core.prior.Uniform(12.299703, 45)
+        self.priors["mass_ratio"] = bilby.core.prior.Uniform(1e-5, 1)
+        with self.assertRaises(BilbyROQParamsRangeError):
+            bilby.gw.likelihood.ROQGravitationalWaveTransient(
+                interferometers=self.ifos, waveform_generator=self.roq_wfg,
+                linear_matrix=self.linear_matrix_file,
+                roq_params=self.params_file,
+                quadratic_matrix=self.quadratic_matrix_file,
+                priors=self.priors)
+
+    def test_create_roq_weights_fails_with_max_frequency(self):
+        ifos = bilby.gw.detector.InterferometerList(['H1'])
+        ifos.set_strain_data_from_power_spectral_densities(
+            sampling_frequency=2**14, duration=4)
+        ifos[0].maximum_frequency = 2**13
+        with self.assertRaises(BilbyROQParamsRangeError):
+            bilby.gw.likelihood.ROQGravitationalWaveTransient(
+                interferometers=ifos, waveform_generator=self.roq_wfg,
+                linear_matrix=self.linear_matrix_file,
+                roq_params=self.params_file,
+                quadratic_matrix=self.quadratic_matrix_file,
+                priors=self.priors)
+
+    def test_create_roq_weights_fails_due_to_min_frequency(self):
+        self.ifos[0].minimum_frequency = 15
+        with self.assertRaises(BilbyROQParamsRangeError):
+            bilby.gw.likelihood.ROQGravitationalWaveTransient(
+                interferometers=self.ifos, waveform_generator=self.roq_wfg,
+                linear_matrix=self.linear_matrix_file,
+                roq_params=self.params_file,
+                quadratic_matrix=self.quadratic_matrix_file,
+                priors=self.priors)
+
+    def test_create_roq_weights_fails_due_to_duration(self):
+        ifos = bilby.gw.detector.InterferometerList(['H1'])
+        ifos.set_strain_data_from_power_spectral_densities(
+            sampling_frequency=self.sampling_frequency, duration=16)
+        with self.assertRaises(BilbyROQParamsRangeError):
+            bilby.gw.likelihood.ROQGravitationalWaveTransient(
+                interferometers=ifos, waveform_generator=self.roq_wfg,
+                linear_matrix=self.linear_matrix_file,
+                roq_params=self.params_file,
+                quadratic_matrix=self.quadratic_matrix_file,
+                priors=self.priors)
+
+
+class TestRescaledROQLikelihood(unittest.TestCase):
+
+    def test_rescaling(self):
+
+        # Possible locations for the ROQ: in the docker image, local, or on CIT
+        trial_roq_paths = [
+            "/roq_basis",
+            os.path.join(os.path.expanduser("~"), 'ROQ_data/IMRPhenomPv2/4s'),
+            "/home/cbc/ROQ_data/IMRPhenomPv2/4s"]
+        roq_dir = None
+        for path in trial_roq_paths:
+            if os.path.isdir(path):
+                roq_dir = path
+                break
+        if roq_dir is None:
+            raise Exception("Unable to load ROQ basis: cannot proceed with tests")
+
+        linear_matrix_file = "{}/B_linear.npy".format(roq_dir)
+        quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir)
+
+        fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir)
+        fnodes_linear = np.load(fnodes_linear_file).T
+        fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir)
+        fnodes_quadratic = np.load(fnodes_quadratic_file).T
+        self.linear_matrix_file = "{}/B_linear.npy".format(roq_dir)
+        self.quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir)
+        self.params_file = "{}/params.dat".format(roq_dir)
+
+        scale_factor = 0.5
+        params = np.genfromtxt(self.params_file, names=True)
+        params['flow'] *= scale_factor
+        params['fhigh'] *= scale_factor
+        params['seglen'] /= scale_factor
+        params['chirpmassmin'] /= scale_factor
+        params['chirpmassmax'] /= scale_factor
+        params['compmin'] /= scale_factor
+
+        self.duration = 4 / scale_factor
+        self.sampling_frequency = 2048 * scale_factor
+
+        ifos = bilby.gw.detector.InterferometerList(['H1'])
+        ifos.set_strain_data_from_power_spectral_densities(
+            sampling_frequency=self.sampling_frequency, duration=self.duration)
+        self.ifos = ifos
+
+        self.priors = bilby.gw.prior.BBHPriorDict()
+        self.priors.pop("mass_1")
+        self.priors.pop("mass_2")
+        # Testing is done with the 4s IMRPhenomPV2 ROQ basis
+        self.priors["chirp_mass"] = bilby.core.prior.Uniform(
+            12.299703 / scale_factor, 45 / scale_factor)
+        self.priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1)
+        self.priors['geocent_time'] = bilby.core.prior.Uniform(1.19, 1.21)
+
+        self.roq_wfg = bilby.gw.waveform_generator.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.roq,
+            waveform_arguments=dict(
+                frequency_nodes_linear=fnodes_linear,
+                frequency_nodes_quadratic=fnodes_quadratic,
+                reference_frequency=20., minimum_frequency=20.,
+                approximant='IMRPhenomPv2'))
+
+        self.roq = bilby.gw.likelihood.ROQGravitationalWaveTransient(
+            interferometers=ifos, waveform_generator=self.roq_wfg,
+            linear_matrix=linear_matrix_file, roq_params=params,
+            quadratic_matrix=quadratic_matrix_file, priors=self.priors)
+
 
 class TestBBHLikelihoodSetUp(unittest.TestCase):
 
-- 
GitLab