WIP: Jax backend
Implements a different backend in gwdistributions.backends.numerics
so that users can rely on jax when using gw-distribution's objects. This should allow users to smoothly support HMC sampling with numpyro.
fixes #23 (closed)
To Do
-
implement backend for jax -
remove backend for theano.tensor -
change all calls to be.vector
andbe.matrix
to justbe.array
-
remove all instances where we set specific indecies of arrays (not supported with trivial syntax in jax) -
we just have to live with array assignment during sampling? or use lists and then cast? -
we can eat the little bit of extra cost when instantiating interpolators?
-
-
run full test suite with each backend -
numpy
-
jax.numpy
-
cupy
(NOT TESTED)
-
Note: it appears that the following numpy opperations are ordered in this way from fastest to slowest
x += y
x = x + y
x[:] = x + y
not sure about the memory cost, though.
Edited by Reed Essick