Memory issues with grad and lax.scan
#25181
Unanswered
vboussange asked this question in Q&A
Hey there,
I am looking to implement a differentiable least-cost path algorithm that works on large graphs. I have been considering the Bellman-Ford algorithm, since I am especially interested in one-to-all vertex distances. Here is a custom implementation that seems to work beautifully:
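For readers who want a concrete reference point, a minimal Bellman-Ford sketch in JAX could look like the one below. It is an illustrative assumption, not the implementation referred to above: the function name `bellman_ford_sketch`, the edge-list layout (`senders`, `receivers`, `weights`) and the synchronous relaxation (all edges relaxed from the previous round's distances) are hypothetical choices. `lax.scan` runs the V - 1 relaxation rounds, and `.at[...].min(...)` keeps, for each vertex, the smallest of its current distance and all incoming candidates.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames=("num_nodes",))
def bellman_ford_sketch(weights, senders, receivers, source, num_nodes):
    """One-to-all shortest-path distances from `source` (hypothetical sketch).

    Assumed edge-list layout: edge k runs from senders[k] to receivers[k]
    with cost weights[k]. Assumes there are no negative cycles.
    """
    # Every vertex starts unreachable (inf) except the source.
    init = jnp.full((num_nodes,), jnp.inf, dtype=weights.dtype).at[source].set(0.0)

    def relax(dist, _):
        # Candidate distance through each edge: dist[u] + w(u, v).
        cand = dist[senders] + weights
        # For each receiver v, keep the minimum of its current distance and
        # all incoming candidates (duplicate indices are reduced with min).
        return dist.at[receivers].min(cand), None

    # Round k finds the best paths with at most k edges, so V - 1 rounds
    # cover every simple path.
    dist, _ = lax.scan(relax, init, None, length=num_nodes - 1)
    return dist


# Small directed example with 4 vertices and 5 edges.
weights = jnp.array([4.0, 1.0, 2.0, 1.0, 5.0])
senders = jnp.array([0, 0, 2, 1, 2])
receivers = jnp.array([1, 2, 1, 3, 3])
dist = bellman_ford_sketch(weights, senders, receivers, 0, num_nodes=4)
# Expected distances from vertex 0: 0, 3, 1, 4
```

`num_nodes` is marked static because `lax.scan` needs a concrete Python integer for `length`; vertices that cannot be reached keep a distance of `inf`.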
However, when I try to differentiate `bellman_ford` with respect to the edge weights, I get a `RESOURCE_EXHAUSTED: Out of memory` error.

Adding `@jax.checkpoint` seems to reduce memory usage, but it is not enough to prevent memory build-up, and I suspect a memory leak.

Am I doing something wrong here? Is there a better implementation that would avoid exhausting memory?
Side note: I’m considering creating a JAX-based package for graph utilities (all JIT-compatible and differentiable). If anyone is interested, let me know!
Thanks in advance for your help!