Skip to content

WIP: Jax backend

Reed Essick requested to merge jax-backend into master

Implements a different backend in gwdistributions.backends.numerics so that users can rely on jax when using gw-distribution's objects. This should allow users to smoothly support HMC sampling with numpyro.

fixes #23 (closed)

To Do

  • implement backend for jax
  • remove backend for theano.tensor
  • change all calls to be.vector and be.matrix to just be.array
  • remove all instances where we set specific indecies of arrays (not supported with trivial syntax in jax)
    • we just have to live with array assignment during sampling? or use lists and then cast?
    • we can eat the little bit of extra cost when instantiating interpolators?
  • run full test suite with each backend
    • numpy
    • jax.numpy
    • cupy (NOT TESTED)

Note: it appears that the following numpy opperations are ordered in this way from fastest to slowest

  • x += y
  • x = x + y
  • x[:] = x + y

not sure about the memory cost, though.

Edited by Reed Essick

Merge request reports

Loading