-
It's certainly possible to write code for which a JVP would materialize an array of that size. Without seeing the code you're running, it's hard to say more than that. Can you include a minimal reproduction of the issue?
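One way to see whether a JVP carries a logits-sized tangent is to inspect the jaxpr of the forward-mode computation. Below is a minimal sketch with a toy LM-head-style loss; the function name `loss`, the projection `x @ w`, and the tiny shapes are all assumptions standing in for whatever the actual training code does:

```python
import jax
import jax.numpy as jnp

# Tiny stand-ins for the question's batch_size=32, seqlen=512, vocab_size=50304.
batch_size, seqlen, vocab_size = 2, 4, 8

def loss(w, x):
    # Toy "LM head": project activations to logits, then mean log-softmax loss.
    logits = x @ w  # shape (batch_size * seqlen, vocab_size)
    return -jnp.mean(jax.nn.log_softmax(logits))

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (batch_size * seqlen, 16))
w = jax.random.normal(key_w, (16, vocab_size))

def loss_of_w(v):
    return loss(v, x)

# The jaxpr of the JVP lists every intermediate the forward-mode pass carries;
# a (batch_size * seqlen, vocab_size) tangent appears alongside the primal logits.
jaxpr = jax.make_jaxpr(
    lambda v: jax.jvp(loss_of_w, (v,), (jnp.ones_like(v),))
)(w)
print(jaxpr)
```

So a logits-sized intermediate in the JVP is not inherently a bug: forward mode propagates a tangent with the same shape as each primal intermediate, including the logits.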
-
I'm trying to debug excessive memory use when training LMs. A 65M-parameter model seems to OOM on a 16GB card at batch sizes > 32, and the situation is worse for bigger batches and models.

Profiling the graphs, one array I'm noticing is for the `jvp`, which appears to be of shape `(batch_size * seqlen, vocab_size) = (32 * 512, 50304)`. This is the biggest allocation on the heap, according to the profiler. Surely this array shouldn't materialize, as for larger models it'd be severely memory inefficient? Is there something I'm missing here, or is this expected?
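For reference, back-of-the-envelope arithmetic (assuming float32, 4 bytes per element) shows why a single materialized array of this shape hurts on a 16GB card:

```python
# Size of a float32 array of shape (batch_size * seqlen, vocab_size),
# using the values from the question; float32 dtype is an assumption.
batch_size, seqlen, vocab_size = 32, 512, 50304
num_elements = batch_size * seqlen * vocab_size
size_gib = num_elements * 4 / 2**30
print(f"{size_gib:.2f} GiB")  # prints "3.07 GiB"
```

That's roughly 3 GiB for one intermediate, before counting parameters, optimizer state, or any other activations.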