Skip to content

PyTorch autograd

QI JUN edited this page Jul 28, 2020 · 12 revisions

PyTorch 代码跟读

Forward 计算完成后,会产生一系列的 Backward Function。每个 Backward Function 包含一个反向计算函数。这些 Backward Function 组成了一张 Graph。

Backward Function 函数签名在 derivatives.yaml 中定义,Backward Function Class 定义由 gen_autograd_functions.py 编译产生。

Backward Function Class继承自 Node,有一个方法 release_variables 用来释放资源。

Engine 中有一个线程池来执行 Graph 中的若干 Backward Function,在 evaluate_function 方法中调用每个 Backward Function 的 release_variables 方法。

release_variables 会释放计算的中间结果,也就是 Variable。 一个 Variable 中包含一个 at::Tensorat::Tensor 中包含一个 intrusive_ptr(性能更好的 shared_ptr) 的 TensorImpl。

release_variables 方法内部会调用 reset_data 方法,reset_data 最终调用 intrusive_ptr 的 _reset 方法,会使引用计数减一。

libtorch C++

动态图的意思是每次 forward 计算都会创建一个新图。在 forward 中,会把计算中间结果放在 saved variable中(tape 的思路),并且产生对应的一组 backward function。 在 backward 中,会逐个执行 backward function。同时每执行完一个 backward function,就会调用 release_variables 方法,把该 backward function 引用的中间结果给释放掉。

libtorch中,C++ 在做 forward 的时候,因为在 tape 中做记录,引用计数 +1; C++ 在做 backward 的时候,每调用一个 backward function,因为调用了release_variable 方法,引用计数 -1。可以做到一边 backward,一边尽早释放资源。

GoTorch Go

我们在 libtorch C++ 之上,做了 Go wrapper,给每个中间结果都增加了一个额外的引用,那么这个引用应当由 Go 负责释放。并且,如果要达到“ 在不再需要的时候尽快释放”目标,这里我们就需要做到一边 backward,一边释放资源。

换句话说,直接调用 libtorch 的 backward 就不够了,libtorch只管 C++,管不了 Go,我们可能需要在 Go 中重新实现 backward,在每个 backward function 结束之后,插入 tensor.Close函数。

一种妥协方案是,在 forward 和 backward都结束之后,统一释放 Go 中的引用。这样做的好处是,直接复用 libtorch 中的 backward 函数,实现起来简单清晰;坏处是对内存/显存的占用比较大。

这里贴一个之前给 Paddle 显存 eager delete策略 的实验,该策略会在计算图的 backward 过程中插入一些 delete_tensor 算子。

Model no optimize release memory forward memory
Resnet 170590208 78004224(reduce 54.3%) 77488128

我们可以看出 Resnet 在没有优化的情况下,显存占用是 170590208,也就是说 forward + backward 总共要占用这么多显存。在 eager delete策略下,一边 backward,一边释放资源,显存占用是 78004224,基本上跟只做 forward 计算所需要的显存持平。

Clone this wiki locally