Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/autodiff/checkpoint #1239

Closed
wants to merge 35 commits into from
Closed

Feat/autodiff/checkpoint #1239

wants to merge 35 commits into from

Conversation

louisfd
Copy link
Member

@louisfd louisfd commented Feb 2, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

WIP for #936

Changes

Most of the logic for the autodiff checkpointing strategy.
At the moment, the code is all independant from the rest of Burn, I suspect it won't pass the CI because everything is unused; thus why I put my PR as a draft.
It has already undergone several refactorings, so I'm pretty satisfied with the cleanness.
Next I will of course plug it in all autodiff operations.

Testing

Heavily tested in the tests.rs file

burn-autodiff/src/checkpoint/base.rs Outdated Show resolved Hide resolved
burn-autodiff/src/checkpoint/base.rs Outdated Show resolved Hide resolved
burn-autodiff/src/checkpoint/base.rs Outdated Show resolved Hide resolved
Copy link
Member

@nathanielsimard nathanielsimard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find the tests quite confusing, they normaly reflect how the code is going to be used, but it seems like the "scenario" is defined in many methods and assumed to be static.

What I think should improve the code:

  1. The "normal" forward pass that generates the Checkpoint struct with maybe some assertions on the Computed and Recompute tensors.
  2. Simulate the backward pass that calculates the "retro forwards" and asserts the results.

The more the forward pass is similar to normal tensor operations, the better.

burn-autodiff/src/checkpoint/tests.rs Outdated Show resolved Hide resolved
burn-autodiff/src/checkpoint/tests.rs Outdated Show resolved Hide resolved
Comment on lines 106 to 116
fn make_ids() -> [NodeID; 7] {
[
NodeID::new(),
NodeID::new(),
NodeID::new(),
NodeID::new(),
NodeID::new(),
NodeID::new(),
NodeID::new(),
]
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That method isn't helpful, we can inline each NodeId where they are needed.

}

/// Make the leaves for a div tree
fn make_leaves<B: Backend>(device: &B::Device, ids: [NodeID; 4]) -> (InnerStates, NodeTree) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, the NodeIds should be inlined.

burn-autodiff/src/checkpoint/tests.rs Outdated Show resolved Hide resolved
Comment on lines 45 to 47
#[cfg(test)]
mod tests {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are already in a test module this is unecessary.

burn-autodiff/src/checkpoint/mod.rs Show resolved Hide resolved
@louisfd
Copy link
Member Author

louisfd commented Feb 6, 2024

You should be happier with the new state of the production code. Tests are still convoluted but they were never really meant to outlive the development phase. I will migrate them to clean ones once I can use the autodiff api

@louisfd louisfd closed this Feb 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants