
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (21)
Showing 493 additions and 331 deletions
......@@ -35,7 +35,7 @@ authors:
# Test containers scripts are up to date
containers:
stage: initial
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
script:
- cd containers
- python write_dockerfiles.py #HACK
......@@ -43,7 +43,7 @@ containers:
# write_dockerfiles.py and commit the changes.
- git diff --exit-code
- cp env-template.yml env.yml
- echo " - python=3.10" >> env.yml
- echo " - python=3.11" >> env.yml
- mamba env create -f env.yml -n test --dry-run
.test-python: &test-python
......@@ -70,10 +70,6 @@ containers:
${script} --help;
done
basic-3.9:
<<: *test-python
image: python:3.9
basic-3.10:
<<: *test-python
image: python:3.10
......@@ -82,6 +78,10 @@ basic-3.11:
<<: *test-python
image: python:3.11
basic-3.12:
<<: *test-python
image: python:3.12
.test-samplers-import: &test-samplers-import
stage: initial
script:
......@@ -89,10 +89,6 @@ basic-3.11:
- *list-env
- pytest test/test_samplers_import.py -v
import-samplers-3.9:
<<: *test-samplers-import
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python39
import-samplers-3.10:
<<: *test-samplers-import
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
......@@ -101,6 +97,10 @@ import-samplers-3.11:
<<: *test-samplers-import
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
import-samplers-3.12:
<<: *test-samplers-import
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python312
.precommits: &precommits
stage: initial
script:
......@@ -112,19 +112,19 @@ import-samplers-3.11:
# Run precommits (flake8, spellcheck, isort, no merge conflicts, etc)
- pre-commit run --all-files --verbose --show-diff-on-failure
precommits-py3.10:
precommits-py3.11:
<<: *precommits
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
variables:
CACHE_DIR: ".pip310"
PYVERSION: "python310"
CACHE_DIR: ".pip311"
PYVERSION: "python311"
install:
stage: initial
parallel:
matrix:
- EXTRA: [gw, mcmc, all]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
script:
- pip install .[$EXTRA]
......@@ -138,15 +138,15 @@ install:
- pytest --cov=bilby --durations 10
python-3.9:
<<: *unit-test
needs: ["basic-3.9"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python39
python-3.10:
<<: *unit-test
needs: ["basic-3.10", "precommits-py3.10"]
needs: ["basic-3.10"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
python-3.11:
<<: *unit-test
needs: ["basic-3.11", "precommits-py3.11"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
after_script:
- coverage html
- coverage xml
......@@ -160,10 +160,10 @@ python-3.10:
- htmlcov/
expire_in: 30 days
python-3.11:
python-3.12:
<<: *unit-test
needs: ["basic-3.11"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
needs: ["basic-3.12"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python312
.test-sampler: &test-sampler
stage: test
......@@ -172,25 +172,25 @@ python-3.11:
- *list-env
- pytest test/integration/sampler_run_test.py --durations 10 -v
python-3.9-samplers:
<<: *test-sampler
needs: ["basic-3.9"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python39
python-3.10-samplers:
<<: *test-sampler
needs: ["basic-3.10", "precommits-py3.10"]
needs: ["basic-3.10"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
python-3.11-samplers:
<<: *test-sampler
needs: ["basic-3.11"]
needs: ["basic-3.11", "precommits-py3.11"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
integration-tests-python-3.10:
python-3.12-samplers:
<<: *test-sampler
needs: ["basic-3.12"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python312
integration-tests-python-3.11:
stage: test
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
needs: ["basic-3.10", "precommits-py3.10"]
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
needs: ["basic-3.11", "precommits-py3.11"]
only:
- schedules
script:
......@@ -208,26 +208,26 @@ integration-tests-python-3.10:
- *list-env
- pytest test/gw/plot_test.py
plotting-python-3.9:
<<: *plotting
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python39
needs: ["basic-3.9"]
plotting-python-3.10:
<<: *plotting
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
needs: ["basic-3.10", "precommits-py3.10"]
needs: ["basic-3.10"]
plotting-python-3.11:
<<: *plotting
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
needs: ["basic-3.11"]
needs: ["basic-3.11", "precommits-py3.11"]
plotting-python-3.12:
<<: *plotting
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python312
needs: ["basic-3.12"]
# ------------------- Docs stage -------------------------------------------
docs:
stage: docs
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python311
before_script:
- python -m ipykernel install
script:
......@@ -248,7 +248,7 @@ docs:
pages:
stage: deploy
needs: ["docs", "python-3.10"]
needs: ["docs", "python-3.11"]
script:
- mkdir public/
- mv htmlcov/ public/
......@@ -280,11 +280,6 @@ pages:
- docker image tag v3-bilby-$PYVERSION containers.ligo.org/lscsoft/bilby/v2-bilby-$PYVERSION:latest
- docker image push containers.ligo.org/lscsoft/bilby/v2-bilby-$PYVERSION:latest
build-python39-container:
<<: *build-container
variables:
PYVERSION: "python39"
build-python310-container:
<<: *build-container
variables:
......@@ -295,6 +290,11 @@ build-python311-container:
variables:
PYVERSION: "python311"
build-python312-container:
<<: *build-container
variables:
PYVERSION: "python312"
pypi-release:
stage: deploy
image: containers.ligo.org/lscsoft/bilby/v2-bilby-python310
......
......@@ -16,6 +16,7 @@ Bruce Edelman
Carl-Johan Haster
Cecilio Garcia-Quiros
Charlie Hoy
Chentao Yang
Christopher Philip Luke Berry
Christos Karathanasis
Colm Talbot
......@@ -29,6 +30,7 @@ Gregory Ashton
Hank Hua
Hector Estelles
Ignacio Magaña Hernandez
Isaac McMahon
Isobel Marguarethe Romero-Shaw
Jack Heinzel
Jacob Golomb
......
|pipeline status| |coverage report| |pypi| |conda| |version|
=====
Bilby
Bilby development has moved to `GitHub <https://github.com/bilby-dev/bilby>`__!
=====
Please open any new issues or pull requests there. A full migration guide will be provided soon. Links below here may no longer be active.
====
A user-friendly Bayesian inference library.
Fulfilling all your Bayesian dreams.
......@@ -30,47 +32,7 @@ us directly. For advice on contributing, see `the contributing guide <https://gi
Citation guide
--------------
If you use :code:`bilby` in a scientific publication, please cite
* `Bilby: A user-friendly Bayesian inference library for gravitational-wave
astronomy
<https://ui.adsabs.harvard.edu/#abs/2018arXiv181102042A/abstract>`__
* `Bayesian inference for compact binary coalescences with BILBY: validation and application to the first LIGO-Virgo gravitational-wave transient catalogue <https://ui.adsabs.harvard.edu/abs/2020MNRAS.499.3295R/abstract>`__
The first of these papers introduces the software, while the second introduces advances in the sampling approaches and validation of the software.
If you use the :code:`bilby_mcmc` sampler, please additionally cite
* `BILBY-MCMC: an MCMC sampler for gravitational-wave inference <https://ui.adsabs.harvard.edu/abs/2021MNRAS.507.2037A/abstract>`__
Additionally, :code:`bilby` builds on a number of open-source packages. If you
make use of this functionality in your publications, we recommend you cite them
as requested in their associated documentation.
**Samplers**
* `dynesty <https://github.com/joshspeagle/dynesty>`__
* `nestle <https://github.com/kbarbary/nestle>`__
* `pymultinest <https://github.com/JohannesBuchner/PyMultiNest>`__
* `cpnest <https://github.com/johnveitch/cpnest>`__
* `emcee <https://github.com/dfm/emcee>`__
* `nessai <https://github.com/mj-will/nessai>`_
* `ptemcee <https://github.com/willvousden/ptemcee>`__
* `ptmcmcsampler <https://github.com/jellis18/PTMCMCSampler>`__
* `pypolychord <https://github.com/PolyChord/PolyChordLite>`__
* `PyMC3 <https://github.com/pymc-devs/pymc3>`_
**Gravitational-wave tools**
* `gwpy <https://github.com/gwpy/gwpy>`__
* `lalsuite <https://git.ligo.org/lscsoft/lalsuite>`__
* `astropy <https://github.com/astropy/astropy>`__
**Plotting**
* `corner <https://github.com/dfm/corner.py>`__ for generating corner plots
* `matplotlib <https://github.com/matplotlib/matplotlib>`__ for general plotting routines
Please refer to the `Acknowledging/citing bilby guide <https://lscsoft.docs.ligo.org/bilby/citing-bilby.html>`__.
.. |pipeline status| image:: https://git.ligo.org/lscsoft/bilby/badges/master/pipeline.svg
:target: https://git.ligo.org/lscsoft/bilby/commits/master
......
......@@ -127,6 +127,9 @@ class Bilby_MCMC(MCMCSampler):
initial_sample_dict: dict
A dictionary of the initial sample value. If incomplete, will overwrite
the initial_sample drawn using initial_sample_method.
normalize_prior: bool
When False, disables calculation of the constraint normalization factor
during prior probability computation. Default is True.
verbose: bool
Whether to print diagnostic output during the run.
......@@ -175,6 +178,7 @@ class Bilby_MCMC(MCMCSampler):
resume=True,
exit_code=130,
verbose=True,
normalize_prior=True,
**kwargs,
):
......@@ -194,6 +198,7 @@ class Bilby_MCMC(MCMCSampler):
self.kwargs["target_nsamples"] = self.kwargs["nsamples"]
self.L1steps = self.kwargs["L1steps"]
self.L2steps = self.kwargs["L2steps"]
self.normalize_prior = normalize_prior
self.pt_inputs = ParallelTemperingInputs(
**{key: self.kwargs[key] for key in ParallelTemperingInputs._fields}
)
......@@ -309,6 +314,7 @@ class Bilby_MCMC(MCMCSampler):
evidence_method=self.evidence_method,
initial_sample_method=self.initial_sample_method,
initial_sample_dict=self.initial_sample_dict,
normalize_prior=self.normalize_prior,
)
def get_setup_string(self):
......@@ -382,7 +388,9 @@ class Bilby_MCMC(MCMCSampler):
If true, resume file was successfully loaded, otherwise false
"""
if os.path.isfile(self.resume_file) is False:
if not os.path.isfile(self.resume_file) or not os.path.getsize(
self.resume_file
):
return False
import dill
......@@ -583,11 +591,13 @@ class BilbyPTMCMCSampler(object):
evidence_method,
initial_sample_method,
initial_sample_dict,
normalize_prior=True,
):
self.set_pt_inputs(pt_inputs)
self.use_ratio = use_ratio
self.initial_sample_method = initial_sample_method
self.initial_sample_dict = initial_sample_dict
self.normalize_prior = normalize_prior
self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
self.set_convergence_inputs(convergence_inputs)
self.pt_rejection_sample = pt_rejection_sample
......@@ -658,6 +668,7 @@ class BilbyPTMCMCSampler(object):
use_ratio=self.use_ratio,
initial_sample_method=self.initial_sample_method,
initial_sample_dict=self.initial_sample_dict,
normalize_prior=self.normalize_prior,
)
for Eindex in range(n)
]
......@@ -1152,12 +1163,13 @@ class BilbyMCMCSampler(object):
use_ratio=False,
initial_sample_method="prior",
initial_sample_dict=None,
normalize_prior=True,
):
self.beta = beta
self.Tindex = Tindex
self.Eindex = Eindex
self.use_ratio = use_ratio
self.normalize_prior = normalize_prior
self.parameters = _sampling_convenience_dump.priors.non_fixed_keys
self.ndim = len(self.parameters)
......@@ -1232,7 +1244,10 @@ class BilbyMCMCSampler(object):
return logl
def log_prior(self, sample):
return _sampling_convenience_dump.priors.ln_prob(sample.parameter_only_dict)
return _sampling_convenience_dump.priors.ln_prob(
sample.parameter_only_dict,
normalized=self.normalize_prior,
)
def accept_proposal(self, prop, proposal):
self.chain.append(prop)
......
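A minimal sketch of how the new normalize_prior flag reaches the sampler through bilby.run_sampler (toy likelihood and prior; the names below are illustrative, not part of this diff):

import numpy as np
import bilby

# Toy linear model y = m * x with Gaussian noise.
likelihood = bilby.core.likelihood.GaussianLikelihood(
    x=np.array([1.0, 2.0]), y=np.array([1.1, 1.9]),
    func=lambda x, m: m * x, sigma=1.0)
priors = dict(m=bilby.core.prior.Uniform(0, 2, "m"))
result = bilby.run_sampler(
    likelihood, priors, sampler="bilby_mcmc", nsamples=100,
    normalize_prior=False)  # new flag: skip constraint normalization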
......@@ -537,7 +537,7 @@ class PriorDict(dict):
constrained_prob[keep] = prob[keep] * ratio
return constrained_prob
def ln_prob(self, sample, axis=None):
def ln_prob(self, sample, axis=None, normalized=True):
"""
Parameters
......@@ -546,6 +546,9 @@ class PriorDict(dict):
Dictionary of the samples of which to calculate the log probability
axis: None or int
Axis along which the summation is performed
normalized: bool
When False, disables calculation of the constraint normalization factor
during prior probability computation. Default is True.
Returns
=======
......@@ -554,10 +557,14 @@ class PriorDict(dict):
"""
ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis)
return self.check_ln_prob(sample, ln_prob)
return self.check_ln_prob(sample, ln_prob, normalized=normalized)
def check_ln_prob(self, sample, ln_prob):
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
def check_ln_prob(self, sample, ln_prob, normalized=True):
if normalized:
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
else:
ratio = 1
if np.all(np.isinf(ln_prob)):
return ln_prob
else:
......@@ -785,7 +792,7 @@ class ConditionalPriorDict(PriorDict):
prob = np.prod(res, **kwargs)
return self.check_prob(sample, prob)
def ln_prob(self, sample, axis=None):
def ln_prob(self, sample, axis=None, normalized=True):
"""
Parameters
......@@ -794,6 +801,9 @@ class ConditionalPriorDict(PriorDict):
Dictionary of the samples of which we want to have the log probability of
axis: Union[None, int]
Axis along which the summation is performed
normalized: bool
When False, disables calculation of the constraint normalization factor
during prior probability computation. Default is True.
Returns
=======
......@@ -806,7 +816,8 @@ class ConditionalPriorDict(PriorDict):
for key in sample
]
ln_prob = np.sum(res, axis=axis)
return self.check_ln_prob(sample, ln_prob)
return self.check_ln_prob(sample, ln_prob, normalized=normalized)
def cdf(self, sample):
self._prepare_evaluation(*zip(*sample.items()))
......
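A minimal usage sketch of the new normalized argument on PriorDict.ln_prob (illustrative priors):

from bilby.core.prior import PriorDict, Uniform

priors = PriorDict(dict(a=Uniform(0, 1), b=Uniform(0, 1)))
sample = dict(a=0.5, b=0.5)
# normalized=False skips the constraint normalization factor, giving an
# unnormalized but cheaper-to-evaluate log prior.
ln_p_fast = priors.ln_prob(sample, normalized=False)
ln_p_full = priors.ln_prob(sample)  # default: normalized=True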
......@@ -463,7 +463,7 @@ class Sampler(object):
logger.info("Unable to measure single likelihood time")
else:
logger.info(
f"Single likelihood evaluation took {self._log_likelihood_eval_time:.3e} s"
f"Single likelihood evaluation took {log_likelihood_eval_time:.3e} s"
)
return log_likelihood_eval_time
......
......@@ -3,6 +3,7 @@ import inspect
import os
import sys
import time
import warnings
import numpy as np
from pandas import DataFrame
......@@ -677,12 +678,21 @@ class Dynesty(NestedSampler):
chain of nested samples within dynesty and have to be removed before
restarting the sampler.
"""
logger.debug("Running sampler with checkpointing")
old_ncall = self.sampler.ncall
sampler_kwargs = self.sampler_function_kwargs.copy()
warnings.filterwarnings(
"ignore",
message="The sampling was stopped short due to maxiter/maxcall limit*",
category=UserWarning,
module="dynesty.sampler",
)
while True:
self.finalize_sampler_kwargs(sampler_kwargs)
if getattr(self.sampler, "added_live", False):
self.sampler._remove_live_points()
self.sampler.run_nested(**sampler_kwargs)
if self.sampler.ncall == old_ncall:
break
......@@ -697,8 +707,8 @@ class Dynesty(NestedSampler):
if last_checkpoint_s > self.check_point_delta_t:
self.write_current_state()
self.plot_current_state()
if getattr(self.sampler, "added_live", False):
self.sampler._remove_live_points()
if getattr(self.sampler, "added_live", False):
self.sampler._remove_live_points()
self.sampler.run_nested(**sampler_kwargs)
self.write_current_state()
......@@ -736,7 +746,10 @@ class Dynesty(NestedSampler):
if os.path.isfile(self.resume_file):
logger.info(f"Reading resume file {self.resume_file}")
with open(self.resume_file, "rb") as file:
sampler = dill.load(file)
try:
sampler = dill.load(file)
except EOFError:
sampler = None
if not hasattr(sampler, "versions"):
logger.warning(
......
import os
import shutil
from collections import namedtuple
from shutil import copyfile
import numpy as np
from packaging import version
......@@ -311,7 +310,11 @@ class Emcee(MCMCSampler):
"""
if hasattr(self, "_sampler"):
pass
elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
elif (
self.resume
and os.path.isfile(self.checkpoint_info.sampler_file)
and os.path.getsize(self.checkpoint_info.sampler_file)
):
import dill
logger.info(
......@@ -329,16 +332,19 @@ class Emcee(MCMCSampler):
def write_chains_to_file(self, sample):
chain_file = self.checkpoint_info.chain_file
temp_chain_file = chain_file + ".temp"
if os.path.isfile(chain_file):
copyfile(chain_file, temp_chain_file)
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
else:
points = np.hstack([sample[0], np.array(sample[3])])
with open(temp_chain_file, "a") as ff:
for ii, point in enumerate(points):
ff.write(self.checkpoint_info.chain_template.format(ii, *point))
shutil.move(temp_chain_file, chain_file)
data_to_write = "\n".join(
self.checkpoint_info.chain_template.format(ii, *point)
for ii, point in enumerate(points)
)
with open(temp_chain_file, "w") as ff:
ff.write(data_to_write)
with open(temp_chain_file, "rb") as ftemp, open(chain_file, "ab") as fchain:
shutil.copyfileobj(ftemp, fchain)
os.remove(temp_chain_file)
@property
def _previous_iterations(self):
......
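The previous implementation copied the whole chain file before each append; the new one writes only the fresh points to a temporary file and streams it onto the chain, so checkpointing cost scales with the new points rather than the chain length. The pattern in isolation (hypothetical file names):

import os
import shutil

new_rows = "0 1.0 2.0\n1 1.5 2.5\n"  # freshly sampled points
with open("chain.dat.temp", "w") as ff:
    ff.write(new_rows)
with open("chain.dat.temp", "rb") as ftemp, open("chain.dat", "ab") as fchain:
    shutil.copyfileobj(ftemp, fchain)  # append-only, never truncates
os.remove("chain.dat.temp")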
......@@ -166,7 +166,11 @@ class Kombine(Emcee):
return self.sampler.chain[:nsteps, :, :]
def check_resume(self):
return self.resume and os.path.isfile(self.checkpoint_info.sampler_file)
return (
self.resume
and os.path.isfile(self.checkpoint_info.sampler_file)
and os.path.getsize(self.checkpoint_info.sampler_file) > 0
)
@signal_wrapper
def run_sampler(self):
......
......@@ -415,7 +415,11 @@ class Ptemcee(MCMCSampler):
# This is a very ugly hack to support numpy>=1.24
ptemcee.sampler.np.float = float
if os.path.isfile(self.resume_file) and self.resume is True:
if (
os.path.isfile(self.resume_file)
and os.path.getsize(self.resume_file)
and self.resume is True
):
import dill
logger.info(f"Resume data {self.resume_file} found")
......@@ -513,7 +517,7 @@ class Ptemcee(MCMCSampler):
logger.info("Starting to sample")
while True:
for (pos0, log_posterior, log_likelihood) in sampler.sample(
for pos0, log_posterior, log_likelihood in sampler.sample(
self.pos0,
storechain=False,
iterations=self.convergence_inputs.niterations_per_check,
......
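The same zero-byte guard now appears in Bilby_MCMC, Emcee, Kombine, and Ptemcee: a checkpoint file left empty by an interrupted run is treated as absent rather than handed to the unpickler. The shared pattern as a standalone sketch:

import os

def resume_file_is_usable(path):
    # An empty file cannot hold a valid pickle, so treat it as missing.
    return os.path.isfile(path) and os.path.getsize(path) > 0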
import math
from numbers import Number
import numpy as np
from scipy.interpolate import interp2d
from scipy.interpolate import RectBivariateSpline
from scipy.special import logsumexp
from .log import logger
......@@ -189,79 +189,34 @@ def logtrapzexp(lnf, dx):
return C + logsumexp([logsumexp(lnfdx1), logsumexp(lnfdx2)])
class UnsortedInterp2d(interp2d):
def __call__(self, x, y, dx=0, dy=0, assume_sorted=False):
"""Modified version of the interp2d call method.
This avoids the outer product that is done when two numpy
arrays are passed.
Parameters
==========
x: See superclass
y: See superclass
dx: See superclass
dy: See superclass
assume_sorted: bool, optional
This is just a place holder to prevent a warning.
Overwriting this will not do anything
class BoundedRectBivariateSpline(RectBivariateSpline):
Returns
=======
array_like: See superclass
def __init__(self, x, y, z, bbox=[None] * 4, kx=3, ky=3, s=0, fill_value=None):
self.x_min, self.x_max, self.y_min, self.y_max = bbox
if self.x_min is None:
self.x_min = min(x)
if self.x_max is None:
self.x_max = max(x)
if self.y_min is None:
self.y_min = min(y)
if self.y_max is None:
self.y_max = max(y)
self.fill_value = fill_value
super().__init__(x=x, y=y, z=z, bbox=bbox, kx=kx, ky=ky, s=s)
"""
from scipy.interpolate.dfitpack import bispeu
x, y = self._sanitize_inputs(x, y)
def __call__(self, x, y, dx=0, dy=0, grid=False):
result = super().__call__(x=x, y=y, dx=dx, dy=dy, grid=grid)
out_of_bounds_x = (x < self.x_min) | (x > self.x_max)
out_of_bounds_y = (y < self.y_min) | (y > self.y_max)
bad = out_of_bounds_x | out_of_bounds_y
if isinstance(x, Number) and isinstance(y, Number):
result[bad] = self.fill_value
if result.size == 1:
if bad:
output = self.fill_value
ier = 0
return self.fill_value
else:
output, ier = bispeu(*self.tck, x, y)
output = float(output)
return result.item()
else:
output = np.empty_like(x)
output[bad] = self.fill_value
if np.any(~bad):
output[~bad], ier = bispeu(*self.tck, x[~bad], y[~bad])
else:
ier = 0
if ier == 10:
raise ValueError("Invalid input data")
elif ier:
raise TypeError("An error occurred")
return output
@staticmethod
def _sanitize_inputs(x, y):
if isinstance(x, np.ndarray) and x.size == 1:
x = float(x)
if isinstance(y, np.ndarray) and y.size == 1:
y = float(y)
if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
original_shapes = (x.shape, y.shape)
if x.shape != y.shape:
while x.ndim > y.ndim:
y = np.expand_dims(y, -1)
while y.ndim > x.ndim:
x = np.expand_dims(x, -1)
try:
x = x * np.ones(y.shape)
y = y * np.ones(x.shape)
except ValueError:
raise ValueError(
f"UnsortedInterp2d received incompatibly shaped arrays: {original_shapes}"
)
elif isinstance(x, np.ndarray) and not isinstance(y, np.ndarray):
y = y * np.ones_like(x)
elif not isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
x = x * np.ones_like(y)
return x, y
return result
def round_up_to_power_of_two(x):
......
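UnsortedInterp2d was built on scipy.interpolate.interp2d, which is deprecated and removed in recent SciPy releases; BoundedRectBivariateSpline reproduces the out-of-bounds fill_value behaviour on top of RectBivariateSpline. A quick check of the bounds handling (import path as used later in this diff):

import numpy as np
from bilby.core.utils import BoundedRectBivariateSpline

x = np.linspace(0, 1, 10)
y = np.linspace(0, 1, 12)
z = np.outer(np.sin(x), np.cos(y))  # z[i, j] = f(x[i], y[j])
spline = BoundedRectBivariateSpline(x, y, z, fill_value=-np.inf)
spline(0.5, 0.5)  # interpolated scalar (grid=False by default)
spline(2.0, 0.5)  # -inf: outside the x bounds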
......@@ -2556,3 +2556,59 @@ def fill_sample(args):
likelihood.parameters.update(dict(sample).copy())
new_sample = likelihood.generate_posterior_sample_from_marginalized_likelihood()
return tuple((new_sample[key] for key in marginalized_parameters))
def identity_map_conversion(parameters):
"""An identity map conversion function that makes no changes to the parameters,
but returns the correct signature expected by other conversion functions
(e.g. convert_to_lal_binary_black_hole_parameters)"""
return parameters, []
def identity_map_generation(sample, likelihood=None, priors=None, npool=1):
"""An identity map generation function that handles marginalizations, SNRs, etc. correctly,
but does not attempt e.g. conversions in mass or spins
Parameters
==========
sample: dict or pandas.DataFrame
Samples to fill in with extra parameters, this may be either an
injection or posterior samples.
likelihood: bilby.gw.likelihood.GravitationalWaveTransient, optional
GravitationalWaveTransient used for sampling, used for waveform and
likelihood.interferometers.
priors: dict, optional
Dictionary of prior objects, used to fill in non-sampled parameters.
Returns
=======
output_sample: dict or pandas.DataFrame
The input sample with derived quantities (fixed-prior parameters,
per-detector likelihoods, reconstructed marginalized parameters, SNRs)
filled in where possible.
"""
output_sample = sample.copy()
output_sample = fill_from_fixed_priors(output_sample, priors)
if likelihood is not None:
compute_per_detector_log_likelihoods(
samples=output_sample, likelihood=likelihood, npool=npool)
marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
if len(marginalized_parameters) > 0:
try:
generate_posterior_samples_from_marginalized_likelihood(
samples=output_sample, likelihood=likelihood, npool=npool)
except MarginalizedLikelihoodReconstructionError as e:
logger.warning(
"Marginalised parameter reconstruction failed with message "
"{}. Some parameters may not have the intended "
"interpretation.".format(e)
)
if ("ra" in output_sample.keys() and "dec" in output_sample.keys() and "psi" in output_sample.keys()):
compute_snrs(output_sample, likelihood, npool=npool)
else:
logger.info(
"Skipping SNR computation since samples have insufficient sky location information"
)
return output_sample
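The identity maps can be dropped in wherever a conversion or generation function is expected (for example, the conversion_function argument of a waveform generator or of result post-processing) when no mass or spin conversions are wanted. Trivially:

from bilby.gw.conversion import identity_map_conversion

parameters = dict(mass_1=36.0, mass_2=29.0)
converted, added_keys = identity_map_conversion(parameters)
assert converted == parameters and added_keys == []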
......@@ -597,6 +597,26 @@ class Interferometer(object):
power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask],
duration=self.strain_data.duration)
def template_template_inner_product(self, signal_1, signal_2):
"""A noise weighted inner product between two templates, using this ifo's PSD.
Parameters
==========
signal_1 : array_like
An array containing the first signal
signal_2 : array_like
An array containing the second signal
Returns
=======
float: The noise weighted inner product of the two templates
"""
return gwutils.noise_weighted_inner_product(
aa=signal_1[self.strain_data.frequency_mask],
bb=signal_2[self.strain_data.frequency_mask],
power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask],
duration=self.strain_data.duration)
def matched_filter_snr(self, signal):
"""
......@@ -616,19 +636,95 @@ class Interferometer(object):
power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask],
duration=self.strain_data.duration)
def whiten_frequency_series(self, frequency_series : np.array) -> np.array:
"""Whitens a frequency series with the noise properties of the detector
.. math::
\\tilde{a}_w(f) = \\tilde{a}(f) \\sqrt{\\frac{4}{T S_n(f)}}
Such that
.. math::
Var(n) = \\frac{1}{N} \\sum_{k=0}^N n_W(f_k)n_W^*(f_k) = 2
Where the factor of two is due to the independent real and imaginary
components.
Parameters
==========
frequency_series : np.array
The frequency series to be whitened
Returns
=======
np.array
The frequency series whitened by the ASD
"""
return frequency_series / (self.amplitude_spectral_density_array * np.sqrt(self.duration / 4))
def get_whitened_time_series_from_whitened_frequency_series(
self,
whitened_frequency_series : np.array
) -> np.array:
"""Gets the whitened time series from a whitened frequency series.
This applies an inverse FFT and a windowing correction factor,
since bilby masks the frequency series when f_min and f_max are set.
Per Eqs. 6.2a-b of https://arxiv.org/pdf/gr-qc/0509116, since our window
is just a band pass,
this coefficient is :math:`w/W` where
.. math::
W = \\frac{1}{N} \\sum_{k=0}^N w^2[k]
Since our window :math:`w` is simply 1 or 0, depending on the mask, we get
.. math::
W = \\frac{1}{N} \\sum_{k=0}^N \\Theta(f_{max} - f_k)\\Theta(f_k - f_{min})
and accordingly the termwise window factor is
.. math::
w = \\sqrt{N W} = \\sqrt{\\sum_{k=0}^N \\Theta(f_{max} - f_k)\\Theta(f_k - f_{min})}
"""
frequency_window_factor = (
np.sum(self.frequency_mask)
/ len(self.frequency_mask)
)
whitened_time_series = (
np.fft.irfft(whitened_frequency_series)
* np.sqrt(np.sum(self.frequency_mask)) / frequency_window_factor
)
return whitened_time_series
@property
def whitened_frequency_domain_strain(self):
""" Calculates the whitened data by dividing the frequency domain data by
((amplitude spectral density) * (duration / 4) ** 0.5). The resulting
data will have unit variance.
r"""Whitens the frequency domain data by dividing through by ASD,
with appropriate normalization.
See `whiten_frequency_series()` for details.
Returns
=======
array_like: The whitened data
"""
return self.strain_data.frequency_domain_strain / (
self.amplitude_spectral_density_array * np.sqrt(self.duration / 4)
)
return self.whiten_frequency_series(self.strain_data.frequency_domain_strain)
@property
def whitened_time_domain_strain(self) -> np.array:
"""Calculates the whitened time domain strain
by iffting the whitened frequency domain strain,
with the appropriate normalization.
See `get_whitened_time_series_from_whitened_frequency_series()` for details
Returns
=======
array_like
The whitened data in the time domain
"""
return self.get_whitened_time_series_from_whitened_frequency_series(self.whitened_frequency_domain_strain)
def save_data(self, outdir, label=None):
""" Creates save files for interferometer data in plain text format.
......
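A short sketch of the new whitening helpers on a noise-only interferometer (standard bilby setup calls):

import bilby

ifo = bilby.gw.detector.get_empty_interferometer("H1")
ifo.set_strain_data_from_power_spectral_density(
    sampling_frequency=2048, duration=4)
fd_white = ifo.whitened_frequency_domain_strain  # unit variance by construction
td_white = ifo.whitened_time_domain_strain  # new property: irfft + window factor
ip = ifo.template_template_inner_product(
    ifo.strain_data.frequency_domain_strain,
    ifo.strain_data.frequency_domain_strain)  # new inner-product helper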
......@@ -7,7 +7,7 @@ import numpy as np
from scipy.special import logsumexp
from ...core.likelihood import Likelihood
from ...core.utils import logger, UnsortedInterp2d, create_time_series
from ...core.utils import logger, BoundedRectBivariateSpline, create_time_series
from ...core.prior import Interped, Prior, Uniform, DeltaFunction
from ..detector import InterferometerList, get_empty_interferometer, calibration
from ..prior import BBHPriorDict, Cosmological
......@@ -752,7 +752,7 @@ class GravitationalWaveTransient(Likelihood):
d_inner_h_ref = np.real(d_inner_h_ref)
return self._interp_dist_margd_loglikelihood(
d_inner_h_ref, h_inner_h_ref)
d_inner_h_ref, h_inner_h_ref, grid=False)
def phase_marginalized_likelihood(self, d_inner_h, h_inner_h):
d_inner_h = ln_i0(abs(d_inner_h))
......@@ -891,9 +891,9 @@ class GravitationalWaveTransient(Likelihood):
self._create_lookup_table()
else:
self._create_lookup_table()
self._interp_dist_margd_loglikelihood = UnsortedInterp2d(
self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline(
self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array,
self._dist_margd_loglikelihood_array, kind='cubic', fill_value=-np.inf)
self._dist_margd_loglikelihood_array.T, fill_value=-np.inf)
@property
def cached_lookup_table_filename(self):
......
......@@ -136,13 +136,28 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
linear_matrix['duration_s'][()])],
dtype=[('flow', float), ('fhigh', float), ('seglen', float)]
)
elif is_hdf5_quadratic:
self.roq_params = np.array(
[(quadratic_matrix['minimum_frequency_hz'][()],
quadratic_matrix['maximum_frequency_hz'][()],
quadratic_matrix['duration_s'][()])],
dtype=[('flow', float), ('fhigh', float), ('seglen', float)]
)
if is_hdf5_quadratic:
if self.roq_params is None:
self.roq_params = np.array(
[(quadratic_matrix['minimum_frequency_hz'][()],
quadratic_matrix['maximum_frequency_hz'][()],
quadratic_matrix['duration_s'][()])],
dtype=[('flow', float), ('fhigh', float), ('seglen', float)]
)
else:
self.roq_params['flow'] = max(
self.roq_params['flow'], quadratic_matrix['minimum_frequency_hz'][()]
)
self.roq_params['fhigh'] = min(
self.roq_params['fhigh'], quadratic_matrix['maximum_frequency_hz'][()]
)
self.roq_params['seglen'] = min(
self.roq_params['seglen'], quadratic_matrix['duration_s'][()]
)
if self.roq_params is not None:
for ifo in self.interferometers:
self.perform_roq_params_check(ifo)
self.weights = dict()
self._set_weights(linear_matrix=linear_matrix, quadratic_matrix=quadratic_matrix)
if is_hdf5_linear:
......@@ -158,9 +173,10 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
for basis_type in ['linear', 'quadratic']:
number_of_bases = getattr(self, f'number_of_bases_{basis_type}')
if number_of_bases > 1:
self._verify_prior_ranges_and_frequency_nodes(basis_type)
self._verify_numbers_of_prior_ranges_and_frequency_nodes(basis_type)
else:
self._check_frequency_nodes_exist_for_single_basis(basis_type)
self._verify_prior_ranges(basis_type)
self._set_unique_frequency_nodes_and_inverse()
# need to fill waveform_arguments here if single basis is used, as they will never be updated.
......@@ -171,7 +187,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
self._waveform_generator.waveform_arguments['linear_indices'] = linear_indices
self._waveform_generator.waveform_arguments['quadratic_indices'] = quadratic_indices
def _verify_prior_ranges_and_frequency_nodes(self, basis_type):
def _verify_numbers_of_prior_ranges_and_frequency_nodes(self, basis_type):
"""
Check if self.weights contains lists of prior ranges and frequency nodes, and their sizes are equal to the
number of bases.
......@@ -205,6 +221,35 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
raise ValueError(
f'The number of arrays of frequency nodes does not match the number of {basis_type} bases')
def _verify_prior_ranges(self, basis_type):
"""Check if the union of prior ranges is within the ROQ basis bounds.
Parameters
==========
basis_type: str
"""
key = f'prior_range_{basis_type}'
if key not in self.weights:
return
prior_ranges = self.weights[key]
for param_name, prior_ranges_of_this_param in prior_ranges.items():
prior_minimum = self.priors[param_name].minimum
basis_minimum = np.min(prior_ranges_of_this_param[:, 0])
if prior_minimum < basis_minimum:
raise BilbyROQParamsRangeError(
f"Prior minimum of {param_name} {prior_minimum} less "
f"than ROQ basis bound {basis_minimum}"
)
prior_maximum = self.priors[param_name].maximum
basis_maximum = np.max(prior_ranges_of_this_param[:, 1])
if prior_maximum > basis_maximum:
raise BilbyROQParamsRangeError(
f"Prior maximum of {param_name} {prior_maximum} greater "
f"than ROQ basis bound {basis_maximum}"
)
def _check_frequency_nodes_exist_for_single_basis(self, basis_type):
"""
For a single-basis case, frequency nodes should be contained in self._waveform_generator.waveform_arguments or
......@@ -701,6 +746,8 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
roq_scale_factor = 1.
prior_ranges[param_name] = matrix[key][param_name][()] * roq_scale_factor
selected_idxs, selected_prior_ranges = self._select_prior_ranges(prior_ranges)
if len(selected_idxs) == 0:
raise BilbyROQParamsRangeError(f"There are no {basis_type} ROQ bases within the prior range.")
self.weights[key] = selected_prior_ranges
idxs_in_prior_range[basis_type] = selected_idxs
else:
......@@ -725,7 +772,6 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
ifo_idxs = {}
for ifo in self.interferometers:
if self.roq_params is not None:
self.perform_roq_params_check(ifo)
# Get scaled ROQ quantities
roq_scaled_minimum_frequency = self.roq_params['flow'] * self.roq_scale_factor
roq_scaled_maximum_frequency = self.roq_params['fhigh'] * self.roq_scale_factor
......@@ -1110,11 +1156,11 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
f_high: float
The maximum frequency which must be considered
"""
from scipy.integrate import simps
from scipy.integrate import simpson
integrand1 = np.power(freq, -7. / 3) / psd
integral1 = simps(integrand1, freq)
integral1 = simpson(y=integrand1, x=freq)
integrand3 = np.power(freq, 2. / 3.) / (psd * integral1)
f_3_bar = simps(integrand3, freq)
f_3_bar = simpson(y=integrand3, x=freq)
f_high = scaling * f_3_bar**(1 / 3)
......
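scipy.integrate.simps has been removed from recent SciPy in favour of simpson, which is called here with explicit y and x keywords; a minimal equivalent call:

import numpy as np
from scipy.integrate import simpson

freq = np.linspace(20.0, 1024.0, 1000)
psd = np.ones_like(freq)  # toy flat PSD
integral1 = simpson(y=freq ** (-7.0 / 3.0) / psd, x=freq)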
......@@ -1362,6 +1362,7 @@ class HealPixMapPriorDist(BaseJointPriorDist):
else:
self.distance = False
self.prob = self.hp.read_map(hp_file)
self.prob = self._check_norm(self.prob)
super(HealPixMapPriorDist, self).__init__(names=names, bounds=bounds)
self.distname = "hpmap"
......@@ -1458,7 +1459,7 @@ class HealPixMapPriorDist(BaseJointPriorDist):
self.distance_pdf = lambda r: self.distnorm[pix_idx] * norm(
loc=self.distmu[pix_idx], scale=self.distsigma[pix_idx]
).pdf(r)
pdfs = self.rs ** 2 * norm(loc=self.distmu[pix_idx], scale=self.distsigma[pix_idx]).pdf(self.rs)
pdfs = self.rs ** 2 * self.distance_pdf(self.rs)
cdfs = np.cumsum(pdfs) / np.sum(pdfs)
self.distance_icdf = interp1d(cdfs, self.rs)
......@@ -1501,9 +1502,7 @@ class HealPixMapPriorDist(BaseJointPriorDist):
sample : array_like
sample of ra and dec (and distance if 3D=True)
"""
pixel_choices = np.arange(self.npix)
pixel_probs = self._check_norm(self.prob)
sample_pix = random.rng.choice(pixel_choices, size=size, p=pixel_probs, replace=True)
sample_pix = random.rng.choice(self.npix, size=size, p=self.prob, replace=True)
sample = np.empty((size, self.num_vars))
for samp in range(size):
theta, ra = self.hp.pix2ang(self.nside, sample_pix[samp])
......
......@@ -377,10 +377,6 @@ class CompactBinaryCoalescenceResult(CoreResult):
logger.debug("Downsampling frequency mask to {} values".format(
len(frequency_idxs))
)
frequency_window_factor = (
np.sum(interferometer.frequency_mask)
/ len(interferometer.frequency_mask)
)
plot_times = interferometer.time_array[time_idxs]
plot_times -= interferometer.strain_data.start_time
start_time -= interferometer.strain_data.start_time
......@@ -451,11 +447,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
fig.add_trace(
go.Scatter(
x=plot_times,
y=np.fft.irfft(
interferometer.whitened_frequency_domain_strain
* np.sqrt(np.sum(interferometer.frequency_mask))
/ frequency_window_factor
)[time_idxs],
y=interferometer.whitened_time_domain_strain[time_idxs],
fill=None,
mode='lines', line_color=DATA_COLOR,
opacity=0.5,
......@@ -478,11 +470,7 @@ class CompactBinaryCoalescenceResult(CoreResult):
interferometer.amplitude_spectral_density_array[frequency_idxs],
color=DATA_COLOR, label='ASD')
axs[1].plot(
plot_times, np.fft.irfft(
interferometer.whitened_frequency_domain_strain
* np.sqrt(np.sum(interferometer.frequency_mask))
/ frequency_window_factor
)[time_idxs],
plot_times, interferometer.whitened_time_domain_strain[time_idxs],
color=DATA_COLOR, alpha=0.3)
logger.debug('Plotted interferometer data.')
......@@ -493,10 +481,10 @@ class CompactBinaryCoalescenceResult(CoreResult):
wf_pols = waveform_generator.frequency_domain_strain(params)
fd_waveform = interferometer.get_detector_response(wf_pols, params)
fd_waveforms.append(fd_waveform[frequency_idxs])
td_waveform = infft(
fd_waveform * np.sqrt(2. / interferometer.sampling_frequency) /
interferometer.amplitude_spectral_density_array,
self.sampling_frequency)[time_idxs]
whitened_fd_waveform = interferometer.whiten_frequency_series(fd_waveform)
td_waveform = interferometer.get_whitened_time_series_from_whitened_frequency_series(
whitened_fd_waveform
)[time_idxs]
td_waveforms.append(td_waveform)
fd_waveforms = asd_from_freq_series(
fd_waveforms,
......
......@@ -6,10 +6,15 @@ from .conversion import bilby_to_lalsimulation_spins
from .utils import (lalsim_GetApproximantFromString,
lalsim_SimInspiralFD,
lalsim_SimInspiralChooseFDWaveform,
lalsim_SimInspiralWaveformParamsInsertTidalLambda1,
lalsim_SimInspiralWaveformParamsInsertTidalLambda2,
lalsim_SimInspiralChooseFDWaveformSequence)
UNUSED_KWARGS_MESSAGE = """There are unused waveform kwargs. This is deprecated behavior and will
result in an error in future releases. Make sure all of the waveform kwargs are correctly
spelled.
Unused waveform_kwargs: {waveform_kwargs}
"""
def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1,
phi_12, a_2, tilt_2, phi_jl, theta_jn, phase, **kwargs):
......@@ -480,6 +485,54 @@ def lal_eccentric_binary_black_hole_no_spins(
eccentricity=eccentricity, **waveform_kwargs)
def set_waveform_dictionary(waveform_kwargs, lambda_1=0, lambda_2=0):
"""
Add keyword arguments to the :code:`LALDict` object.
Parameters
==========
waveform_kwargs: dict
A dictionary of waveform kwargs. This is modified in place to remove used arguments.
lambda_1: float
Dimensionless tidal deformability of the primary object.
lambda_2: float
Dimensionless tidal deformability of the secondary object.
Returns
=======
waveform_dictionary: lal.LALDict
The lal waveform dictionary. This is either taken from the waveform_kwargs or created
internally.
"""
import lalsimulation as lalsim
from lal import CreateDict
waveform_dictionary = waveform_kwargs.pop('lal_waveform_dictionary', CreateDict())
waveform_kwargs["TidalLambda1"] = lambda_1
waveform_kwargs["TidalLambda2"] = lambda_2
waveform_kwargs["NumRelData"] = waveform_kwargs.pop("numerical_relativity_data", None)
for key in [
"pn_spin_order", "pn_tidal_order", "pn_phase_order", "pn_amplitude_order"
]:
waveform_kwargs[key[:2].upper() + key[3:].title().replace('_', '')] = waveform_kwargs.pop(key)
for key in list(waveform_kwargs.keys()).copy():
func = getattr(lalsim, f"SimInspiralWaveformParamsInsert{key}", None)
if func is None:
continue
value = waveform_kwargs.pop(key)
if value is not None:
func(waveform_dictionary, value)
mode_array = waveform_kwargs.pop("mode_array", None)
if mode_array is not None:
mode_array_lal = lalsim.SimInspiralCreateModeArray()
for mode in mode_array:
lalsim.SimInspiralModeArrayActivateMode(mode_array_lal, mode[0], mode[1])
lalsim.SimInspiralWaveformParamsInsertModeArray(waveform_dictionary, mode_array_lal)
return waveform_dictionary
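A hedged sketch of the new helper (requires lalsimulation; assumed import from bilby.gw.source): keys with a matching SimInspiralWaveformParamsInsert* setter are consumed into the LALDict, and anything left in the kwargs dict later triggers UNUSED_KWARGS_MESSAGE. The four pn_* orders must be present because they are popped unconditionally:

from bilby.gw.source import set_waveform_dictionary

kwargs = dict(
    pn_spin_order=-1, pn_tidal_order=-1,
    pn_phase_order=-1, pn_amplitude_order=0,
    mode_array=[[2, 2], [2, -2]],
    not_a_real_option=1)  # hypothetical misspelling, left over
wf_dict = set_waveform_dictionary(kwargs)  # kwargs is mutated in place
assert "not_a_real_option" in kwargs  # survives and will be warned about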
def _base_lal_cbc_fd_waveform(
frequency_array, mass_1, mass_2, luminosity_distance, theta_jn, phase,
a_1=0.0, a_2=0.0, tilt_1=0.0, tilt_2=0.0, phi_12=0.0, phi_jl=0.0,
......@@ -525,22 +578,16 @@ def _base_lal_cbc_fd_waveform(
=======
dict: A dictionary with the plus and cross polarisation strain modes
"""
import lal
import lalsimulation as lalsim
waveform_approximant = waveform_kwargs['waveform_approximant']
reference_frequency = waveform_kwargs['reference_frequency']
minimum_frequency = waveform_kwargs['minimum_frequency']
maximum_frequency = waveform_kwargs['maximum_frequency']
catch_waveform_errors = waveform_kwargs['catch_waveform_errors']
pn_spin_order = waveform_kwargs['pn_spin_order']
pn_tidal_order = waveform_kwargs['pn_tidal_order']
pn_phase_order = waveform_kwargs['pn_phase_order']
waveform_approximant = waveform_kwargs.pop('waveform_approximant')
reference_frequency = waveform_kwargs.pop('reference_frequency')
minimum_frequency = waveform_kwargs.pop('minimum_frequency')
maximum_frequency = waveform_kwargs.pop('maximum_frequency')
catch_waveform_errors = waveform_kwargs.pop('catch_waveform_errors')
pn_amplitude_order = waveform_kwargs['pn_amplitude_order']
waveform_dictionary = waveform_kwargs.get(
'lal_waveform_dictionary', lal.CreateDict()
)
waveform_dictionary = set_waveform_dictionary(waveform_kwargs, lambda_1, lambda_2)
approximant = lalsim_GetApproximantFromString(waveform_approximant)
if pn_amplitude_order != 0:
......@@ -567,35 +614,6 @@ def _base_lal_cbc_fd_waveform(
longitude_ascending_nodes = 0.0
mean_per_ano = 0.0
lalsim.SimInspiralWaveformParamsInsertPNSpinOrder(
waveform_dictionary, int(pn_spin_order))
lalsim.SimInspiralWaveformParamsInsertPNTidalOrder(
waveform_dictionary, int(pn_tidal_order))
lalsim.SimInspiralWaveformParamsInsertPNPhaseOrder(
waveform_dictionary, int(pn_phase_order))
lalsim.SimInspiralWaveformParamsInsertPNAmplitudeOrder(
waveform_dictionary, int(pn_amplitude_order))
lalsim_SimInspiralWaveformParamsInsertTidalLambda1(
waveform_dictionary, float(lambda_1))
lalsim_SimInspiralWaveformParamsInsertTidalLambda2(
waveform_dictionary, float(lambda_2))
for key, value in waveform_kwargs.items():
func = getattr(lalsim, "SimInspiralWaveformParamsInsert" + key, None)
if func is not None:
func(waveform_dictionary, value)
if waveform_kwargs.get('numerical_relativity_file', None) is not None:
lalsim.SimInspiralWaveformParamsInsertNumRelData(
waveform_dictionary, waveform_kwargs['numerical_relativity_file'])
if ('mode_array' in waveform_kwargs) and waveform_kwargs['mode_array'] is not None:
mode_array = waveform_kwargs['mode_array']
mode_array_lal = lalsim.SimInspiralCreateModeArray()
for mode in mode_array:
lalsim.SimInspiralModeArrayActivateMode(mode_array_lal, mode[0], mode[1])
lalsim.SimInspiralWaveformParamsInsertModeArray(waveform_dictionary, mode_array_lal)
if lalsim.SimInspiralImplementedFDApproximants(approximant):
wf_func = lalsim_SimInspiralChooseFDWaveform
else:
......@@ -650,6 +668,9 @@ def _base_lal_cbc_fd_waveform(
h_plus[frequency_bounds] *= time_shift
h_cross[frequency_bounds] *= time_shift
if len(waveform_kwargs) > 0:
logger.warning(UNUSED_KWARGS_MESSAGE.format(waveform_kwargs))
return dict(plus=h_plus, cross=h_cross)
......@@ -705,6 +726,7 @@ def lal_binary_black_hole_relative_binning(
waveform_kwargs.update(kwargs)
if fiducial == 1:
_ = waveform_kwargs.pop("frequency_bin_edges", None)
return _base_lal_cbc_fd_waveform(
frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2,
luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase,
......@@ -712,6 +734,8 @@ def lal_binary_black_hole_relative_binning(
phi_12=phi_12, lambda_1=0.0, lambda_2=0.0, **waveform_kwargs)
else:
_ = waveform_kwargs.pop("minimum_frequency", None)
_ = waveform_kwargs.pop("maximum_frequency", None)
waveform_kwargs["frequencies"] = waveform_kwargs.pop("frequency_bin_edges")
return _base_waveform_frequency_sequence(
frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2,
......@@ -748,6 +772,8 @@ def lal_binary_neutron_star_relative_binning(
a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_12=phi_12,
phi_jl=phi_jl, lambda_1=lambda_1, lambda_2=lambda_2, **waveform_kwargs)
else:
_ = waveform_kwargs.pop("minimum_frequency", None)
_ = waveform_kwargs.pop("maximum_frequency", None)
waveform_kwargs["frequencies"] = waveform_kwargs.pop("frequency_bin_edges")
return _base_waveform_frequency_sequence(
frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2,
......@@ -822,19 +848,21 @@ def _base_roq_waveform(
if 'frequency_nodes' not in waveform_arguments:
size_linear = len(waveform_arguments['frequency_nodes_linear'])
frequency_nodes_combined = np.hstack(
(waveform_arguments['frequency_nodes_linear'],
waveform_arguments['frequency_nodes_quadratic'])
(waveform_arguments.pop('frequency_nodes_linear'),
waveform_arguments.pop('frequency_nodes_quadratic'))
)
frequency_nodes_unique, original_indices = np.unique(
frequency_nodes_combined, return_inverse=True
)
linear_indices = original_indices[:size_linear]
quadratic_indices = original_indices[size_linear:]
waveform_arguments['frequency_nodes'] = frequency_nodes_unique
waveform_arguments['linear_indices'] = linear_indices
waveform_arguments['quadratic_indices'] = quadratic_indices
waveform_arguments['frequencies'] = waveform_arguments['frequency_nodes']
waveform_arguments['frequencies'] = frequency_nodes_unique
else:
linear_indices = waveform_arguments.pop("linear_indices")
quadratic_indices = waveform_arguments.pop("quadratic_indices")
for key in ["frequency_nodes_linear", "frequency_nodes_quadratic"]:
_ = waveform_arguments.pop(key, None)
waveform_arguments['frequencies'] = waveform_arguments.pop('frequency_nodes')
waveform_polarizations = _base_waveform_frequency_sequence(
frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2,
luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase,
......@@ -843,12 +871,12 @@ def _base_roq_waveform(
return {
'linear': {
'plus': waveform_polarizations['plus'][waveform_arguments['linear_indices']],
'cross': waveform_polarizations['cross'][waveform_arguments['linear_indices']]
'plus': waveform_polarizations['plus'][linear_indices],
'cross': waveform_polarizations['cross'][linear_indices]
},
'quadratic': {
'plus': waveform_polarizations['plus'][waveform_arguments['quadratic_indices']],
'cross': waveform_polarizations['cross'][waveform_arguments['quadratic_indices']]
'plus': waveform_polarizations['plus'][quadratic_indices],
'cross': waveform_polarizations['cross'][quadratic_indices]
}
}
......@@ -1059,49 +1087,13 @@ def _base_waveform_frequency_sequence(
Dict containing plus and cross modes evaluated at the linear and
quadratic frequency nodes.
"""
from lal import CreateDict
import lalsimulation as lalsim
frequencies = waveform_kwargs['frequencies']
reference_frequency = waveform_kwargs['reference_frequency']
approximant = lalsim_GetApproximantFromString(waveform_kwargs['waveform_approximant'])
catch_waveform_errors = waveform_kwargs['catch_waveform_errors']
pn_spin_order = waveform_kwargs['pn_spin_order']
pn_tidal_order = waveform_kwargs['pn_tidal_order']
pn_phase_order = waveform_kwargs['pn_phase_order']
pn_amplitude_order = waveform_kwargs['pn_amplitude_order']
waveform_dictionary = waveform_kwargs.get(
'lal_waveform_dictionary', CreateDict()
)
lalsim.SimInspiralWaveformParamsInsertPNSpinOrder(
waveform_dictionary, int(pn_spin_order))
lalsim.SimInspiralWaveformParamsInsertPNTidalOrder(
waveform_dictionary, int(pn_tidal_order))
lalsim.SimInspiralWaveformParamsInsertPNPhaseOrder(
waveform_dictionary, int(pn_phase_order))
lalsim.SimInspiralWaveformParamsInsertPNAmplitudeOrder(
waveform_dictionary, int(pn_amplitude_order))
lalsim_SimInspiralWaveformParamsInsertTidalLambda1(
waveform_dictionary, float(lambda_1))
lalsim_SimInspiralWaveformParamsInsertTidalLambda2(
waveform_dictionary, float(lambda_2))
for key, value in waveform_kwargs.items():
func = getattr(lalsim, "SimInspiralWaveformParamsInsert" + key, None)
if func is not None:
func(waveform_dictionary, value)
frequencies = waveform_kwargs.pop('frequencies')
reference_frequency = waveform_kwargs.pop('reference_frequency')
approximant = waveform_kwargs.pop('waveform_approximant')
catch_waveform_errors = waveform_kwargs.pop('catch_waveform_errors')
if waveform_kwargs.get('numerical_relativity_file', None) is not None:
lalsim.SimInspiralWaveformParamsInsertNumRelData(
waveform_dictionary, waveform_kwargs['numerical_relativity_file'])
if ('mode_array' in waveform_kwargs) and waveform_kwargs['mode_array'] is not None:
mode_array = waveform_kwargs['mode_array']
mode_array_lal = lalsim.SimInspiralCreateModeArray()
for mode in mode_array:
lalsim.SimInspiralModeArrayActivateMode(mode_array_lal, mode[0], mode[1])
lalsim.SimInspiralWaveformParamsInsertModeArray(waveform_dictionary, mode_array_lal)
waveform_dictionary = set_waveform_dictionary(waveform_kwargs, lambda_1, lambda_2)
approximant = lalsim_GetApproximantFromString(approximant)
luminosity_distance = luminosity_distance * 1e6 * utils.parsec
mass_1 = mass_1 * utils.solar_mass
......@@ -1135,6 +1127,9 @@ def _base_waveform_frequency_sequence(
else:
raise
if len(waveform_kwargs) > 0:
logger.warning(UNUSED_KWARGS_MESSAGE.format(waveform_kwargs))
return dict(plus=h_plus.data.data, cross=h_cross.data.data)
......
from ..core.utils import infer_args_from_function_except_n_args
class Model(object):
class Model:
r"""
Population model that combines a set of factorizable models.
......@@ -12,18 +12,24 @@ class Model(object):
p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
"""
def __init__(self, model_functions=None):
def __init__(self, model_functions=None, cache=True):
"""
Parameters
==========
model_functions: list
List of callables to compute the probability.
If this includes classes, the `__call__` method should return the
probability.
If this includes classes, the :code:`__call__` method
should return the probability.
The required variables are chosen at run time based on either
inspection or querying a :code:`variable_names` attribute.
cache: bool
Whether to cache the value returned by the model functions,
default=:code:`True`. The caching only looks at the parameters,
not the data, so it should be used with caution. The caching also
breaks :code:`jax` JIT compilation.
"""
self.models = model_functions
self.cache = cache
self._cached_parameters = {model: None for model in self.models}
self._cached_probability = {model: None for model in self.models}
......@@ -48,14 +54,18 @@ class Model(object):
probability = 1.0
for ii, function in enumerate(self.models):
function_parameters = self._get_function_parameters(function)
if self._cached_parameters[function] == function_parameters:
if (
self.cache
and self._cached_parameters[function] == function_parameters
):
new_probability = self._cached_probability[function]
else:
new_probability = function(
data, **self._get_function_parameters(function)
)
self._cached_parameters[function] = function_parameters
self._cached_probability[function] = new_probability
if self.cache:
self._cached_parameters[function] = function_parameters
self._cached_probability[function] = new_probability
probability *= new_probability
return probability
......
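A sketch of the new cache switch with a toy model function; disabling it is needed under jax JIT because the cache compares concrete parameter values between calls:

import numpy as np
from bilby.hyper.model import Model

def powerlaw(dataset, alpha):
    return dataset["x"] ** alpha

model = Model([powerlaw], cache=False)  # no caching, JIT-friendly
model.parameters.update(dict(alpha=-2.0))
model.prob(dict(x=np.linspace(0.1, 1.0, 5)))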
......@@ -5,8 +5,8 @@ LABEL name="bilby CI testing" \
maintainer="Gregory Ashton <gregory.ashton@ligo.org>, Colm Talbot <colm.talbot@ligo.org>"
COPY env-template.yml env.yml
RUN echo " - python=3.9" >> env.yml
ENV conda_env python39
RUN echo " - python=3.12" >> env.yml
ENV conda_env python312
RUN mamba env create -f env.yml -n ${conda_env}
RUN echo "source activate ${conda_env}" > ~/.bashrc
......