[math] support jax.disable_jit()
for debugging
#389
Merged
jax.disable_jit()
for debugging
#389