Skip to content

Added model.prob before jitting

I am currently getting an error from common_format, where jax says jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method array() was called on traced array with shape float64[1000]. when constructing the self._q_interpolant. I'm not sure why, but it seems to have began when I made the changes in https://git.ligo.org/jaxen.godfrey/o4a-astro-dist-model-comparison-study/-/merge_requests/29, which for some reason makes self.m1s traced in the extract_choices function in a way it was not before...

I was able to fix this by making sure the model is called once before the first call of the jitted function, so the jitted function never encounters the caching step.

@colm.talbot

Merge request reports

Loading