
JAX leaked tracer when converting to beta parameters

When `compute_rate_posterior` is called, JAX raises an `UnexpectedTracerError`. I think I traced it back to here, where the `mu` and `var` parameters are ultimately converted to beta distribution parameters by this function. That function modifies the parameters in place, which I suspect is what leaks the tracer. When I replace this line with

```python
parameters_original = dict(posterior.iloc[ii])
parameters = convert_to_beta_parameters(parameters_original)[0]
```

`compute_rate_posterior` runs without error.
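To illustrate the suspected mechanism, here is a minimal sketch. It assumes the conversion runs under `jax.jit` and writes its results into a dict that outlives the trace. `to_beta_inplace`, `to_beta_copy`, `shared` and the `mu`/`var`/`alpha`/`beta` keys are hypothetical stand-ins, not the library's actual function. The conversion uses the method-of-moments formulas alpha = mu * (mu * (1 - mu) / var - 1) and beta = (1 - mu) * (mu * (1 - mu) / var - 1), which is also an assumption about what the real converter computes.

```python
import jax


def to_beta_inplace(params):
    """Hypothetical converter that mutates its argument (method of moments)."""
    common = params["mu"] * (1.0 - params["mu"]) / params["var"] - 1.0
    params["alpha"] = params["mu"] * common
    params["beta"] = (1.0 - params["mu"]) * common
    return params


def to_beta_copy(params):
    """Same conversion, but writes into a fresh dict and leaves `params` untouched."""
    out = dict(params)
    common = out["mu"] * (1.0 - out["mu"]) / out["var"] - 1.0
    out["alpha"] = out["mu"] * common
    out["beta"] = (1.0 - out["mu"]) * common
    return out


# Hypothetical dict that outlives the trace, standing in for state shared between calls.
shared = {}


@jax.jit
def beta_mean_inplace(mu, var):
    shared.update(mu=mu, var=var)  # tracers written into outer-scope state
    p = to_beta_inplace(shared)    # in-place conversion stores more tracers there
    return p["alpha"] / (p["alpha"] + p["beta"])


@jax.jit
def beta_mean_copy(mu, var):
    p = to_beta_copy({"mu": mu, "var": var})  # dict is local to this trace
    return p["alpha"] / (p["alpha"] + p["beta"])


mean_copy = beta_mean_copy(0.3, 0.02)        # mean of the Beta equals mu
mean_inplace = beta_mean_inplace(0.3, 0.02)  # same value; the leak is a side effect

# After tracing, shared["alpha"] holds a stale tracer; reusing it outside the trace fails.
leak_error = None
try:
    shared["alpha"] + 1.0
except jax.errors.UnexpectedTracerError as err:
    leak_error = err
```

Both jitted functions return the same value, so the in-place version looks fine until the mutated dict is reused later. Converting a fresh copy, as `dict(posterior.iloc[ii])` does in the workaround above, keeps the traced values confined to the trace.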