Feat/autodiff/checkpoint ops #1358

Merged: 64 commits into main from feat/autodiff/checkpoint_ops, Feb 26, 2024
Commits (64)
a4159d2
made it to ownership stage
louisfd Jan 29, 2024
478023f
working test!
louisfd Jan 30, 2024
1878ed9
refactor tests
louisfd Jan 31, 2024
2b1e832
tests for computed
louisfd Jan 31, 2024
d078587
retroforward works well (without n_required)
louisfd Feb 1, 2024
7daef33
refactor inner states
louisfd Feb 1, 2024
9f6f218
refactor into files
louisfd Feb 1, 2024
a91bb2b
refactor into files
louisfd Feb 1, 2024
ee3acd8
subtract overflow
louisfd Feb 1, 2024
cf195be
subtract overflow
louisfd Feb 1, 2024
6734590
tests pass
louisfd Feb 2, 2024
e40a035
cleanup and doc
louisfd Feb 2, 2024
85f2a79
merge main
louisfd Feb 2, 2024
b72cc20
code review
louisfd Feb 5, 2024
ab79012
remove retro forward
louisfd Feb 5, 2024
dfb7c97
some refactoring
louisfd Feb 5, 2024
368ea3f
switch to primitives
louisfd Feb 6, 2024
73a5f77
some cleanup
louisfd Feb 6, 2024
5d4e04d
clippy
louisfd Feb 6, 2024
7cade3c
wip integrate to ops
louisfd Feb 6, 2024
e321dab
compiles
louisfd Feb 7, 2024
99fda7b
some use cases work
louisfd Feb 7, 2024
4fc4648
notes
louisfd Feb 7, 2024
f9a2e74
fix bugs
louisfd Feb 8, 2024
d1b5c08
dirty - post pair prog
louisfd Feb 8, 2024
c1fae22
passes tests
louisfd Feb 8, 2024
2a8a18a
wip
louisfd Feb 10, 2024
1e520fb
wip
louisfd Feb 14, 2024
8174b4a
it works
louisfd Feb 15, 2024
94a8373
automatic tests
louisfd Feb 15, 2024
cb03b98
refactor and doc
louisfd Feb 16, 2024
612919b
wip bugfixes
louisfd Feb 16, 2024
0ea7ee2
confident everything works
louisfd Feb 16, 2024
b445187
fmt
louisfd Feb 16, 2024
75ee760
clippy
louisfd Feb 16, 2024
12cd9a3
refactor memory bound api
louisfd Feb 19, 2024
9925645
Merge branch 'main' into feat/autodiff/checkpoint
louisfd Feb 20, 2024
a219b34
ops
louisfd Feb 20, 2024
c61a51a
WIP fix bugs
nathanielsimard Feb 20, 2024
9ad240c
C
nathanielsimard Feb 20, 2024
8e2ad7c
checkpointer builder
louisfd Feb 21, 2024
bb0ce0e
finish builder refactor
louisfd Feb 21, 2024
47b44de
Merge branch 'feat/autodiff/checkpoint_ops' of github.com:tracel-ai/b…
louisfd Feb 21, 2024
dc5859e
wip retro forwards
louisfd Feb 21, 2024
e9aacdd
wip tensor memory bound
louisfd Feb 21, 2024
47efc4a
tensor ops compute or memory bound
louisfd Feb 21, 2024
1506764
all ops bound identified
louisfd Feb 21, 2024
88984ac
Merge branch 'main' into feat/autodiff/checkpoint_ops
louisfd Feb 21, 2024
6244787
merge main
louisfd Feb 21, 2024
fb38576
fix untracked bug
louisfd Feb 21, 2024
dfe82de
add checkpointing everywhere
louisfd Feb 22, 2024
085c28f
fmt
louisfd Feb 22, 2024
d56f6df
clippy
louisfd Feb 22, 2024
9580ba1
remove print
louisfd Feb 22, 2024
9edc8b3
Merge branch 'main' into feat/autodiff/checkpoint_ops
nathanielsimard Feb 22, 2024
bcfc649
Add autodiff
nathanielsimard Feb 22, 2024
1c9a2e7
Merge branch 'main' into feat/autodiff/checkpoint_ops
louisfd Feb 23, 2024
9a05b84
configurable checkpointing
louisfd Feb 23, 2024
0aeaa3f
alias ops for autodiff tests
louisfd Feb 23, 2024
eec4b12
tests in backends
louisfd Feb 23, 2024
4b69f01
clippy
louisfd Feb 23, 2024
5f0afe3
Merge branch 'main' into feat/autodiff/checkpoint_ops
louisfd Feb 23, 2024
851c4d6
Merge branch 'main' into feat/autodiff/checkpoint_ops
louisfd Feb 26, 2024
79eab9d
code review
louisfd Feb 26, 2024
Files changed
32 changes: 32 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion backend-comparison/Cargo.toml
@@ -10,7 +10,7 @@ repository = "https://github.com/tracel-ai/burn/tree/main/backend-comparison"
version.workspace = true

[features]
default = ["burn/std"]
default = ["burn/std", "burn/autodiff"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle", "burn/cuda"]
candle-metal = ["burn/candle", "burn/metal"]
84 changes: 61 additions & 23 deletions backend-comparison/benches/custom_gelu.rs
@@ -1,4 +1,5 @@
use backend_comparison::persistence::save;
use burn::backend::Autodiff;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use core::f64::consts::SQRT_2;
@@ -18,13 +19,18 @@ struct CustomGeluBenchmark<B: Backend, const D: usize> {
shape: Shape<D>,
device: B::Device,
kind: GeluKind,
autodiff: bool,
}

impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
type Args = Tensor<B, D>;

fn name(&self) -> String {
"gelu".into()
match self.autodiff {
true => "gelu_autodiff",
false => "gelu",
}
.into()
}

fn options(&self) -> Option<String> {
@@ -35,11 +41,26 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
vec![self.shape.dims.into()]
}

fn execute(&self, args: Self::Args) {
match self.kind {
GeluKind::Reference => burn::tensor::activation::gelu(args),
GeluKind::WithReferenceErf => gelu_custom(args, Tensor::erf),
GeluKind::WithCustomErf => gelu_custom(args, erf_custom),
fn execute(&self, tensor: Self::Args) {
match self.autodiff {
true => {
let tensor: Tensor<Autodiff<B>, D> = Tensor::from_inner(tensor).require_grad();
let output = match self.kind {
GeluKind::Reference => burn::tensor::activation::gelu(tensor.clone()),
GeluKind::WithReferenceErf => gelu_custom(tensor.clone(), Tensor::erf),
GeluKind::WithCustomErf => gelu_custom(tensor.clone(), erf_custom),
};
let mut gradients = output.sum().backward();
let _tmp = tensor.grad_remove(&mut gradients).unwrap();
}

false => {
match self.kind {
GeluKind::Reference => burn::tensor::activation::gelu(tensor),
GeluKind::WithReferenceErf => gelu_custom(tensor, Tensor::erf),
GeluKind::WithCustomErf => gelu_custom(tensor, erf_custom),
};
}
};
}

@@ -52,7 +73,7 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
}

fn num_samples(&self) -> usize {
50
10
}
}

@@ -97,22 +118,39 @@ fn bench<B: Backend>(device: &B::Device) {
const D: usize = 3;
let shape: Shape<D> = [32, 512, 2048].into();

let reference_gelu =
CustomGeluBenchmark::<B, D>::new(shape.clone(), device.clone(), GeluKind::Reference);
let reference_erf_gelu =
CustomGeluBenchmark::<B, D>::new(shape.clone(), device.clone(), GeluKind::WithReferenceErf);
let custom_erf_gelu =
CustomGeluBenchmark::<B, D>::new(shape, device.clone(), GeluKind::WithCustomErf);

save::<B>(
vec![
run_benchmark(reference_gelu),
run_benchmark(reference_erf_gelu),
run_benchmark(custom_erf_gelu),
],
device,
)
.unwrap();
let run = |autodiff: bool| {
let reference_gelu = CustomGeluBenchmark::<B, D>::new(
shape.clone(),
device.clone(),
GeluKind::Reference,
autodiff,
);
let reference_erf_gelu = CustomGeluBenchmark::<B, D>::new(
shape.clone(),
device.clone(),
GeluKind::WithReferenceErf,
autodiff,
);
let custom_erf_gelu = CustomGeluBenchmark::<B, D>::new(
shape.clone(),
device.clone(),
GeluKind::WithCustomErf,
autodiff,
);

save::<B>(
vec![
run_benchmark(reference_gelu),
run_benchmark(reference_erf_gelu),
run_benchmark(custom_erf_gelu),
],
device,
)
.unwrap();
};

run(false);
run(true);
}

fn main() {
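For readers following along, the autodiff branch added to `execute` above boils down to the standard Burn gradient workflow. The sketch below reproduces that path in isolation; the `NdArray` base backend (which assumes the `ndarray` feature), the tensor shape, and the `main` wrapper are illustrative choices, not part of this PR.

```rust
use burn::backend::{Autodiff, NdArray};
use burn::tensor::{activation, Distribution, Tensor};

fn main() {
    // Decorate any base backend with autodiff; here NdArray keeps the sketch CPU-only.
    type B = Autodiff<NdArray>;
    let device = Default::default();

    // Random input marked as requiring gradients, mirroring the benchmark's `require_grad()` call.
    let input: Tensor<B, 3> =
        Tensor::random([8, 64, 256], Distribution::Default, &device).require_grad();

    // Forward pass through the reference gelu, then backward from a scalar loss.
    let output = activation::gelu(input.clone());
    let mut gradients = output.sum().backward();

    // Retrieve (and drop from the gradient store) the gradient of the input,
    // exactly as the benchmark does with `grad_remove`.
    let grad = input.grad_remove(&mut gradients).unwrap();
    println!("input gradient shape: {:?}", grad.shape());
}
```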
77 changes: 46 additions & 31 deletions burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md
@@ -31,10 +31,11 @@ pub trait Backend: burn::tensor::backend::Backend {
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}
```

In our project, we can use these traits instead of the `burn::tensor::backend::{Backend, AutodiffBackend}`
traits provided by Burn. Burn's user APIs typically make use of the `Tensor` struct rather than
dealing directly with primitive tensor types. Therefore, we can encapsulate our newly defined
backend traits with functions that expose new operations while maintaining a consistent API.
In our project, we can use these traits instead of the
`burn::tensor::backend::{Backend, AutodiffBackend}` traits provided by Burn. Burn's user APIs
typically make use of the `Tensor` struct rather than dealing directly with primitive tensor types.
Therefore, we can encapsulate our newly defined backend traits with functions that expose new
operations while maintaining a consistent API.

```rust, ignore
/// We define our custom implementation using the added function on our custom backend.
@@ -193,9 +194,9 @@ impl<E: FloatElement> DynamicKernel for FusedMatmulAddRelu<E> {
}
```

Subsequently, we'll go into implementing our custom backend trait for the WGPU backend.
Note that we won't go into supporting the `fusion` feature flag in this tutorial, so
we implement the trait for the raw `WgpuBackend` type.
Subsequently, we'll go into implementing our custom backend trait for the WGPU backend. Note that we
won't go into supporting the `fusion` feature flag in this tutorial, so we implement the trait for
the raw `WgpuBackend` type.

```rust, ignore
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
@@ -296,7 +297,7 @@ operations.
// Note that we could implement the backend trait only for the Wgpu backend instead of any backend that
// also implements our own API. This would allow us to call any function only implemented for Wgpu
// and potentially call a custom kernel crafted only for this task.
impl<B: Backend> Backend for Autodiff<B> {
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
fn fused_matmul_add_relu<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
@@ -309,30 +310,32 @@ impl<B: Backend> Backend for Autodiff<B> {
// Implement the backward trait for the given backend B, the node gradient being of rank D
// with three other gradients to calculate (lhs, rhs, and bias).
impl<B: Backend, const D: usize> Backward<B, D, 3> for FusedMatmulAddReluBackward<D> {
// The state that must be built during the forward pass to compute the backward pass.
// Our state that we must build during the forward pass to compute the backward pass.
//
// Note that we could improve the performance further by only keeping the state of
// tensors that are tracked, improving memory management, but for simplicity, we avoid
// that part.
type State = (
FloatTensor<B, D>,
FloatTensor<B, D>,
FloatTensor<B, D>,
Shape<D>,
);

fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients) {
type State = (NodeID, NodeID, FloatTensor<B, D>, Shape<D>);

fn backward(
self,
ops: Ops<Self::State, 3>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
// Get the nodes of each variable.
let [node_lhs, node_rhs, node_bias] = ops.parents;
// Fetch the gradient for the current node.
let grad = grads.consume::<B, D>(&ops.node);

// Set the state.
let (lhs, rhs, output, shape_bias) = ops.state;
// Set our state.
let (lhs_state, rhs_state, output, shape_bias) = ops.state;
let lhs = checkpointer.retrieve_node_output(lhs_state);
let rhs = checkpointer.retrieve_node_output(rhs_state);

// Fetch shapes of the tensors to support broadcasting.
let shape_lhs = B::shape(&lhs);
let shape_rhs = B::shape(&rhs);
// Fetch shapes of our tensor to support broadcasting.
let shape_lhs = B::float_shape(&lhs);
let shape_rhs = B::float_shape(&rhs);

// Compute the gradient of the output using the already existing `relu_backward`
// function in the basic Burn backend trait.
@@ -341,13 +344,13 @@ impl<B: Backend> Backend for Autodiff<B> {
// Compute the lhs gradient, which is the derivative of matmul with support for
// broadcasting.
let grad_lhs = broadcast_shape::<B, D>(
B::matmul(grad_output.clone(), B::transpose(rhs)),
B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
&shape_lhs,
);
// Compute the rhs gradient, which is the derivative of matmul with support for
// broadcasting.
let grad_rhs = broadcast_shape::<B, D>(
B::matmul(B::transpose(lhs), grad_output.clone()),
B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
&shape_rhs,
);
// The add derivative is only 1, so we just need to support broadcasting to
@@ -372,23 +375,35 @@ impl<B: Backend> Backend for Autodiff<B> {
//
// Each node can be fetched with `ops.parents` in the same order as defined here.
match FusedMatmulAddReluBackward
.prepare(
[lhs.node, rhs.node, bias.node],
[lhs.graph, rhs.graph, bias.graph],
.prepare::<C>(
[lhs.node.clone(), rhs.node.clone(), bias.node.clone()],
[lhs.graph.clone(), rhs.graph.clone(), bias.graph.clone()],
)
// Marks the operation as compute bound, meaning it will save its
// state instead of recomputing itself during checkpointing
.compute_bound()
.stateful()
{
OpsKind::Tracked(prep) => {
OpsKind::Tracked(mut prep) => {
// When at least one node is tracked, we should register our backward step.
// We compute the output and the state before finishing the preparation.
let bias_shape = B::shape(&bias.primitive);

// The state consists of what will be needed for this operation's backward pass.
// Since we need the parents' outputs, we must checkpoint their ids to retrieve their node
// output at the beginning of the backward pass. We can also save utility data such as the bias shape.
// If we also need this operation's output, we can either save it in the state or recompute it
// during the backward pass. Here we choose to save it in the state because it's a compute bound operation.
let lhs_state = prep.checkpoint(&lhs);
let rhs_state = prep.checkpoint(&rhs);
let bias_shape = B::float_shape(&bias.primitive);

let output = B::fused_matmul_add_relu(
lhs.primitive.clone(),
rhs.primitive.clone(),
bias.primitive,
);

let state = (lhs.primitive, rhs.primitive, output.clone(), bias_shape);
let state = (lhs_state, rhs_state, output.clone(), bias_shape);

prep.finish(state, output)
}
OpsKind::UnTracked(prep) => {
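As a sanity check on the pattern this chapter documents, the sketch below runs a gradient check through a stand-in implementation. The stand-in simply decomposes matmul + add + relu with stock tensor ops so the snippet compiles on its own; in the actual guide, the call would go through the custom backend trait and the fused WGPU kernel. The `NdArray` backend, the shapes, and the function name are assumptions made for this example only.

```rust
use burn::backend::{Autodiff, NdArray};
use burn::tensor::{activation, backend::Backend, Distribution, Tensor};

/// Stand-in for the guide's fused `matmul + add + relu` operation.
fn matmul_add_relu_reference<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    activation::relu(lhs.matmul(rhs) + bias)
}

fn main() {
    type B = Autodiff<NdArray>;
    let device = Default::default();

    let lhs: Tensor<B, 3> =
        Tensor::random([2, 8, 16], Distribution::Default, &device).require_grad();
    let rhs: Tensor<B, 3> =
        Tensor::random([2, 16, 4], Distribution::Default, &device).require_grad();
    let bias: Tensor<B, 3> =
        Tensor::random([2, 8, 4], Distribution::Default, &device).require_grad();

    // Backward from a scalar loss; each tracked input receives its own gradient,
    // whether the backend runs the fused kernel or this decomposed fallback.
    let grads = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone())
        .sum()
        .backward();

    assert!(lhs.grad(&grads).is_some());
    assert!(rhs.grad(&grads).is_some());
    assert!(bias.grad(&grads).is_some());
}
```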
2 changes: 0 additions & 2 deletions crates/burn-autodiff/Cargo.toml
@@ -26,5 +26,3 @@ spin = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.13.0", default-features = false, features = [
"export_tests",
] }


14 changes: 10 additions & 4 deletions crates/burn-autodiff/src/backend.rs
@@ -1,4 +1,9 @@
use crate::{grads::Gradients, graph::backward::backward, tensor::AutodiffTensor};
use crate::{
checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
grads::Gradients,
graph::backward::backward,
tensor::AutodiffTensor,
};
use burn_tensor::backend::{AutodiffBackend, Backend};
use core::marker::PhantomData;

@@ -7,11 +12,12 @@ use core::marker::PhantomData;
/// This works as a backend decorator, extending the functionality of any backend with
/// backpropagation.
#[derive(Clone, Copy, Debug, Default)]
pub struct Autodiff<B> {
pub struct Autodiff<B, C = NoCheckpointing> {
_b: PhantomData<B>,
_checkpoint_strategy: PhantomData<C>,
}

impl<B: Backend> Backend for Autodiff<B> {
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
type Device = B::Device;

type FullPrecisionElem = B::FullPrecisionElem;
@@ -42,7 +48,7 @@ impl<B: Backend> Backend for Autodiff<B> {
}
}

impl<B: Backend> AutodiffBackend for Autodiff<B> {
impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
type InnerBackend = B;
type Gradients = Gradients;

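The defaulted `C = NoCheckpointing` parameter is what keeps this change backward compatible: existing `Autodiff<B>` aliases compile unchanged, while a second type argument opts into a different checkpoint strategy. A minimal sketch of that usage follows; the `NdArray` base backend is illustrative, and `BalancedCheckpointing` (and its re-export path) is an assumption about what the strategy module exposes, not something shown in this diff.

```rust
use burn::backend::{Autodiff, NdArray};

// Existing code keeps compiling: the strategy defaults to `NoCheckpointing`,
// which stores the state needed by the backward pass instead of recomputing it.
type Eager = Autodiff<NdArray>;

// Opting into checkpointing would only be a second type argument, letting
// memory-bound operations recompute their outputs during the backward pass
// instead of keeping them alive. The strategy name below is assumed:
// type Checkpointed = Autodiff<NdArray, BalancedCheckpointing>;

fn main() {
    // Nothing to run here; the point is that the alias names a valid backend.
    let _device: <Eager as burn::tensor::backend::Backend>::Device = Default::default();
}
```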