JAX leaked tracer when converting to beta parameters
When compute_rate_posterior is called, JAX raises an UnexpectedTracerError. I think I traced it back to here, which ultimately converts the mu and var parameters to the beta distribution parameters in this function. That function modifies the parameters in place, which I think is the cause of the leaked tracer object.
When I replace this line with
parameters_original = dict(posterior.iloc[ii])
parameters = convert_to_beta_parameters(parameters_original)[0]
compute_rate_posterior runs without error.
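
To illustrate what I think is going on, here is a minimal pure-Python sketch of the in-place pattern and the copy workaround. Note that convert_in_place is a hypothetical stand-in I wrote for this report, not the library's convert_to_beta_parameters, and the keys mu, var, alpha and beta are placeholder names. It uses the standard moment-matching formulas for a beta distribution. No JAX is involved here; my assumption is that under jax.jit the values written back into the caller's dict would be tracers, which is how they could leak out of the traced function.

```python
def convert_in_place(parameters):
    """Hypothetical stand-in: convert mu/var to alpha/beta, mutating the input dict."""
    mu = parameters.pop("mu")
    var = parameters.pop("var")
    # Moment matching for a beta distribution with mean mu and variance var.
    common = mu * (1 - mu) / var - 1
    parameters["alpha"] = mu * common
    parameters["beta"] = (1 - mu) * common
    return parameters


# Passing the dict directly: the caller's object is rewritten.
original = {"mu": 0.5, "var": 0.05}
convert_in_place(original)
mutated_keys = sorted(original)

# Passing a shallow copy: the caller's object is left untouched.
original = {"mu": 0.5, "var": 0.05}
converted = convert_in_place(dict(original))
preserved_keys = sorted(original)

print(mutated_keys, preserved_keys)
```

In the first call the caller's dict ends up holding the converted values, while in the second call only the copy is changed. That matches the workaround above, where dict(posterior.iloc[ii]) creates a fresh dict before the conversion.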