Gradient Accumulation #400
Comments
I think if we change the underlying storage to ... we would need to add implementations of `AddAssign` for the raw device storage.
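For illustration, a minimal sketch of what an `AddAssign` on device storage could look like, assuming a hypothetical `CpuStorage` buffer type (not the crate's actual storage type):

```rust
use std::ops::AddAssign;

/// Hypothetical CPU-side storage: a flat buffer of f32 values.
/// Illustrative stand-in only, not the library's real storage type.
struct CpuStorage {
    data: Vec<f32>,
}

impl AddAssign<&CpuStorage> for CpuStorage {
    fn add_assign(&mut self, rhs: &CpuStorage) {
        assert_eq!(self.data.len(), rhs.data.len(), "storage length mismatch");
        // Elementwise accumulation is all that gradient accumulation needs
        // at the storage level.
        for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) {
            *a += *b;
        }
    }
}

fn main() {
    let mut acc = CpuStorage { data: vec![0.0; 3] };
    let grad = CpuStorage { data: vec![1.0, 2.0, 3.0] };
    acc += &grad;
    acc += &grad;
    assert_eq!(acc.data, vec![2.0, 4.0, 6.0]);
}
```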
So I was working on this a bit because I thought I had a pretty clever solution. Turns out it isn't that simple. Notably, you can't have a trait that takes/returns `Self`:

```rust
trait AddSelf {
    fn add_self(self, rhs: Self) -> Self;
}
```

If we can't do this with the Gradients object, we may need to add some separate gradient accumulator object that does lazy addition:

```rust
struct GradientAccumulator {
    gradients: Vec<Gradients>,
}
```

and then add some abstraction layer for the optimizers to use:

```rust
trait HasGradients {
    fn remove<T>(&mut self, t: &T) -> Option<T::Gradient>
    where
        T: HasUniqueId + AllocGrad;
}
```

impl this for both `Gradients` and `GradientAccumulator`, and have the optimizer `update` accept anything that implements it:

```rust
fn update<G: HasGradients>(
    &mut self,
    module: &mut M,
    gradients: G,
) -> Result<(), OptimizerUpdateError<D>>;
```
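As a rough illustration of the lazy-addition idea, here is a sketch using simplified stand-in types (a plain `usize` id and `Vec<f32>` gradients instead of the real `HasUniqueId`/`AllocGrad` machinery):

```rust
use std::collections::HashMap;

// Simplified stand-ins for illustration only; the real Gradients type and the
// HasUniqueId / AllocGrad bounds are more involved.
type UniqueId = usize;
type Grad = Vec<f32>;

#[derive(Default)]
struct Gradients {
    by_id: HashMap<UniqueId, Grad>,
}

struct GradientAccumulator {
    gradients: Vec<Gradients>,
}

impl GradientAccumulator {
    /// Lazily sums a parameter's gradient across every stored Gradients,
    /// doing the addition only when the optimizer asks for it.
    fn remove(&mut self, id: UniqueId) -> Option<Grad> {
        let mut total: Option<Grad> = None;
        for grads in &mut self.gradients {
            if let Some(g) = grads.by_id.remove(&id) {
                if let Some(t) = total.as_mut() {
                    for (a, b) in t.iter_mut().zip(g.iter()) {
                        *a += *b;
                    }
                } else {
                    total = Some(g);
                }
            }
        }
        total
    }
}

fn main() {
    let mut a = Gradients::default();
    a.by_id.insert(0, vec![1.0, 1.0]);
    let mut b = Gradients::default();
    b.by_id.insert(0, vec![2.0, 2.0]);

    let mut acc = GradientAccumulator { gradients: vec![a, b] };
    assert_eq!(acc.remove(0), Some(vec![3.0, 3.0]));
}
```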
@coreylowman Why can't there be a trait that accepts/returns `Self`? That first trait you had looks acceptable to me.
You can't use that kind of object with the boxed (type-erased) storage.
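For context, a small standalone example of the restriction being discussed: a trait whose method takes and returns `Self` by value works fine with concrete types, but it is not object safe, so it cannot be used behind a type-erased pointer:

```rust
trait AddSelf {
    fn add_self(self, rhs: Self) -> Self;
}

impl AddSelf for f32 {
    fn add_self(self, rhs: Self) -> Self {
        self + rhs
    }
}

fn main() {
    // With a concrete type this is fine:
    assert_eq!(1.0f32.add_self(2.0), 3.0);

    // But the trait is not object safe: a method that takes and returns
    // `Self` by value cannot be called through a `dyn` pointer, so the line
    // below fails to compile (E0038, "cannot be made into an object") if
    // uncommented.
    // let _boxed: Box<dyn AddSelf> = Box::new(1.0f32);
}
```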
Another option for this: give an option to pass in an existing `Gradients` object to ... Pros: ...
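A rough sketch of how that option might look from the training loop's side, using made-up names and a heavily simplified gradient type (`backward_into` is hypothetical, not the crate's API):

```rust
/// Illustrative stand-in: gradients for a single parameter, not the
/// library's real Gradients type.
#[derive(Default)]
struct Gradients {
    grad: Vec<f32>,
}

/// Hypothetical backward pass that accumulates into an existing Gradients
/// object instead of allocating and returning a fresh one each call.
fn backward_into(grads: &mut Gradients, minibatch_grad: &[f32]) {
    if grads.grad.is_empty() {
        grads.grad = minibatch_grad.to_vec();
    } else {
        for (a, b) in grads.grad.iter_mut().zip(minibatch_grad) {
            *a += *b;
        }
    }
}

fn main() {
    // Keep one Gradients object alive across several mini-batches, then take
    // a single optimizer step once the effective batch size is reached.
    let mut grads = Gradients::default();
    for _ in 0..4 {
        backward_into(&mut grads, &[0.25, 0.25]);
    }
    assert_eq!(grads.grad, vec![1.0, 1.0]);
}
```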
Oftentimes it is desirable to train with larger batch sizes than can fit in GPU memory at once. Accumulating gradients across mini-batches solves this by effectively simulating larger batch sizes, albeit without the parallelism advantage.
A straightforward way to do this would be to impl `Add<Gradients<D>>` for `Gradients<D>` such that gradients on the same device can be added. I'm not sure how this would be handled, since the grads seem to be stored in a `Box`, which I don't see how you can add to.
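As a sketch of this proposal, here is what an `Add` impl could look like on a heavily simplified stand-in for `Gradients` (a plain `HashMap` of `Vec<f32>` buffers rather than the real boxed, device-generic storage):

```rust
use std::collections::HashMap;
use std::ops::Add;

/// Simplified stand-in for Gradients<D>; the real type stores type-erased
/// tensors behind a Box, which is exactly what makes this tricky.
#[derive(Default)]
struct Gradients {
    by_id: HashMap<usize, Vec<f32>>,
}

impl Add for Gradients {
    type Output = Gradients;

    /// Merge two gradient maps, summing entries present in both.
    fn add(mut self, rhs: Gradients) -> Gradients {
        for (id, g) in rhs.by_id {
            if let Some(t) = self.by_id.get_mut(&id) {
                for (a, b) in t.iter_mut().zip(g.iter()) {
                    *a += *b;
                }
            } else {
                self.by_id.insert(id, g);
            }
        }
        self
    }
}

fn main() {
    let mut a = Gradients::default();
    a.by_id.insert(0, vec![1.0]);
    let mut b = Gradients::default();
    b.by_id.insert(0, vec![2.0]);
    let summed = a + b;
    assert_eq!(summed.by_id[&0], vec![3.0]);
}
```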