Skip to content

DEV: add variable backend and jax

Colm Talbot requested to merge add-jax-backend into main

This MR is a fairly large refactoring to enable us to switch between different optimized backends. It is a sufficiently large change to possibly warrant renaming the project.

I used jax as an example as that enables JIT compilation, autodiff, and automatic vectorization.

The tests were completely refactored to make it easier for them all to be parameterized.

The default backend is determined by looking for an environment variable BILBY_ARRAY_DEFAULT_BACKEND with a fallback to the current default (cython).

The backend can then be set using bilby_cython.set_backend. I noticed that some functions that are explicitly imported in Bilby don't play very well with the changes, but being able to specify the default via env var should provide enough flexibility.

The base-level geometry and time modules are now pass-throughs to the specific versions which can also be explicitly imported to avoid the overhead, but I suspect that is minimal.

As an aside, I experimented with mlx, but the lack of float64 support makes the gmst calculation too imprecise.

Merge request reports