# Add design of asynchronous techniques on heterogeneous devices #7814

# Design Doc: Asynchronous and Non-blocking Techniques on Heterogeneous Devices

## Problem

We often use heterogeneous devices, such as GPUs and FPGAs, in a deep learning platform. Computation on heterogeneous devices is usually asynchronous and non-blocking: computation tasks are offloaded to the device, and the program must wait for them to finish before fetching the results. Heterogeneous devices can usually execute several kinds of operations simultaneously. For example, the Fermi architecture (CUDA compute capability 2.0+) can simultaneously:

* execute CUDA kernels,
* copy memory from host to device, and
* copy memory from device to host.

However, this simultaneous execution is not transparent to programmers.

Let's use CUDA as an example. CUDA provides a building block named `stream`. Streams introduce task-based parallelism to CUDA code: operations issued to the same stream are executed on the GPU in issue order.

Operations in different streams can run concurrently as long as they are placed in different streams and the hardware supports it. The CUDA hardware itself has no notion of streams; it has separate queues (engines) that perform memory copies and execute kernels.

If we want to take advantage of CUDA devices, we must use at least N streams, where N equals the number of hardware queues, and distribute operations among these streams. Here N equals three, since the CUDA hardware can simultaneously execute kernels, host-to-device memory copies, and device-to-host memory copies.

> **Review comment:** It seems that CUDA streams do not have a limit. As long as the resources (memory and computation) of the GPU are not fully occupied, in theory you can create a new stream.
>
> **Reply:** No, it does not. CUDA can create many streams without any limit. However, jobs in CUDA can execute simultaneously only under two conditions.
>
> **Review comment:** Threads in a block are launched on an SM (streaming multiprocessor). If the former CUDA kernel occupies few SMs and there are SMs left, another CUDA kernel can be executed in parallel. Please refer to https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/.
>
> **Review comment:** Please consider the following case: we have 7 kernels of 3 types (A, B, C). Their dependency relationship is A0, B0, C0, … If we have just 3 streams, it is possible that they are put into the 3 streams in the order A0 | B0 | C0 C1 C2. In this case C1 and C2 don't depend on C0 but still have to wait for C0's completion before being able to run. I think the number of streams we use should equal the concurrency expressed in our program, not the hardware.

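
To make the stream mechanism above concrete, here is a minimal CUDA sketch. It is not taken from the design; the `Scale` kernel, the chunk count, and the buffer sizes are arbitrary, and error checking and data initialization are omitted. It puts independent chunks of work on three streams so that host-to-device copies, kernel execution, and device-to-host copies can overlap:

```cpp
#include <cuda_runtime.h>

__global__ void Scale(float* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

int main() {
  const int kChunks = 3;        // one stream per hardware engine, as argued above
  const int kN = 1 << 20;
  float* host[kChunks];
  float* dev[kChunks];
  cudaStream_t streams[kChunks];

  for (int c = 0; c < kChunks; ++c) {
    cudaMallocHost((void**)&host[c], kN * sizeof(float));  // pinned memory, needed for async copies
    cudaMalloc((void**)&dev[c], kN * sizeof(float));
    cudaStreamCreate(&streams[c]);
  }

  for (int c = 0; c < kChunks; ++c) {
    // Within one stream these three operations run in issue order; across
    // streams, the copy engines and the compute engine can work concurrently.
    cudaMemcpyAsync(dev[c], host[c], kN * sizeof(float),
                    cudaMemcpyHostToDevice, streams[c]);
    Scale<<<(kN + 255) / 256, 256, 0, streams[c]>>>(dev[c], kN);
    cudaMemcpyAsync(host[c], dev[c], kN * sizeof(float),
                    cudaMemcpyDeviceToHost, streams[c]);
  }

  cudaDeviceSynchronize();  // wait for all streams before using the results

  for (int c = 0; c < kChunks; ++c) {
    cudaFreeHost(host[c]);
    cudaFree(dev[c]);
    cudaStreamDestroy(streams[c]);
  }
  return 0;
}
```

Within each stream the three operations still execute in issue order; the overlap only happens across streams, which is exactly why the design has to distribute operations over multiple streams.
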
Since execution on CUDA devices is asynchronous, there must be a wait operation when switching streams. For example, if we want to read a computation result back from the GPU, we must wait for the computation to complete before issuing the device-to-host memory copy.

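
One common way to express that wait without stalling the whole device is a CUDA event that is recorded on the computation stream and waited on by the copy stream. The sketch below is illustrative only; the `Scale` kernel and the stream/buffer parameters are placeholders assumed to be set up elsewhere:

```cpp
#include <cuda_runtime.h>

__global__ void Scale(float* data, int n);  // some computation kernel, defined elsewhere

void ComputeThenFetch(float* dev_buf, float* host_buf, int n,
                      cudaStream_t compute_stream, cudaStream_t copy_stream) {
  cudaEvent_t done;
  cudaEventCreate(&done);

  Scale<<<(n + 255) / 256, 256, 0, compute_stream>>>(dev_buf, n);
  cudaEventRecord(done, compute_stream);   // marks completion of the kernel

  // The copy stream will not start the transfer until `done` has fired,
  // so the copy observes the finished computation result.
  cudaStreamWaitEvent(copy_stream, done, 0);
  cudaMemcpyAsync(host_buf, dev_buf, n * sizeof(float),
                  cudaMemcpyDeviceToHost, copy_stream);

  // The host still needs to wait before it can safely read host_buf.
  cudaStreamSynchronize(copy_stream);
  cudaEventDestroy(done);
}
```

The device-side `cudaStreamWaitEvent` enforces the ordering on the GPU, while the final `cudaStreamSynchronize` is the host-side wait before the result is read.
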
## Solution

The solution is straightforward, based on the hardware properties described in the problem section. We should:

* Create N device contexts on each device, where N corresponds to the hardware property. For example, a CUDA device should have three device contexts.

> **Review comment:** Shouldn't the number of device contexts depend only on the "concurrency" requirement of the PaddlePaddle program, rather than on the hardware? Related question: #7814 (comment)
>
> **Review comment:** Is it true that one "device context" = "one stream"? It is a little confusing that the "Problem" section only talks about streams, while the "Solution" section mainly talks about contexts.

* Every tensor should hold the device context on which the current operation on that tensor is performed.

> **Review comment:** I wonder whether it is appropriate for every tensor to hold one device context.

* Wait for execution on the previous device context to complete when switching the current device context of a tensor.

> **Review comment:** If the stream on the given tensor is as follows: …
>
> You mentioned "Wait for the execution complete on the previous device context"; do we have to wait until …?

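
To connect the Problem and Solution sections: one way to realize a device context is as a thin wrapper around a single CUDA stream, so that "switching device contexts" means "switching streams". The following is a minimal sketch under that assumption, not necessarily PaddlePaddle's actual `DeviceContext`:

```cpp
#include <cuda_runtime.h>

// Sketch: a device context that owns exactly one CUDA stream.
// Wait() blocks until everything issued to that stream has finished.
class DeviceContext {
 public:
  DeviceContext() { cudaStreamCreate(&stream_); }
  ~DeviceContext() { cudaStreamDestroy(stream_); }

  cudaStream_t stream() const { return stream_; }
  void Wait() const { cudaStreamSynchronize(stream_); }

 private:
  cudaStream_t stream_;
};
```

With this mapping, "create N device contexts per device" in the first bullet is the same requirement as "use N streams" in the Problem section.
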

The sample C++ program is:

```cpp
// One device context per hardware engine; switching contexts implies
// switching the underlying stream/queue.
enum CUDAHardwareStream {
  kCOMPUTATION,
  kD2HMEMCPY,
  kH2DMEMCPY
};

std::map<CUDAHardwareStream, DeviceContext*> gDevCtxs;

class Tensor {
 public:
  ...

  void SwitchDevCtx(DeviceContext* new_ctx) {
    // Wait for the previous context to finish before the tensor
    // moves on to a new one.
    if (dev_ctx_ != new_ctx) {
      dev_ctx_->Wait();
    }
    dev_ctx_ = new_ctx;
  }

 private:
  ...
  DeviceContext* dev_ctx_;
};

void SomeTensorComputationFunction(Tensor* t) {
  t->SwitchDevCtx(gDevCtxs[kCOMPUTATION]);
  ...
}
```

> **Review comment:** Does this global `gDevCtxs` need a `device_id` to support multiple devices?
>
> **Reply:** Well, this code only demonstrates the basic idea of the solution; it does not go that deep into the details.

> **Review comment:** I am hesitant to add `SwitchDevCtx` here … Maybe we should put the explicit wait in the operator's `Run` instead?
>
> ```cpp
> void ReduceOp::Run(scope, place) {
>   gDevCtxs[place, kCOMPUTATION].Wait();
>   my_ctx = gDevCtx[place, kD2DMEMCPY];
>   ...
> }
> ```
>
> **Review comment:** Only data in different CUDA streams are independent; can these operations execute …
>
> **Reply:** @tonyyang-svail The previous operator of the tensor needs to be waited for, so we need …
>
> **Review comment:** @reyoung Maybe we can write some experimental code. Following is pseudocode: … We expect that after `op1` runs, `update_op` can run in parallel with `op2` and `op3`. But I am not sure whether the behavior will instead be `update_op` running only after `op1`/`op2`/`op3` finish, because the `dev_ctx1` stream has three operations on it. That is not what we want.
>
> **Reply:** @QiJune Thanks for the example. I am sure these four operators will be executed sequentially. As far as `parallel_do_grad` is concerned, I think the following program is good enough, even without an explicit scheduler: …

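
For illustration only, here is a hypothetical sketch of how an operator that fetches a result back to the host might use `SwitchDevCtx`, building on the `Tensor` class and `gDevCtxs` map above. The `data()` accessor and the `CopyToHost` method on `DeviceContext` are assumptions made for this sketch, not part of the proposed interface:

```cpp
// Hypothetical usage sketch; Tensor::data() and DeviceContext::CopyToHost()
// are assumed helpers that are not defined in the design above.
void FetchToHost(Tensor* t, void* host_buffer, size_t num_bytes) {
  // Switching to the D2H context first waits for the previous
  // (e.g. computation) context, so the copy observes the final result.
  t->SwitchDevCtx(gDevCtxs[kD2HMEMCPY]);

  // Issue the device-to-host transfer on the D2H context's stream.
  gDevCtxs[kD2HMEMCPY]->CopyToHost(host_buffer, t->data(), num_bytes);
}
```

This mirrors the third bullet of the solution: the wait happens exactly once, at the point where the tensor's device context changes.
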
> **Review comment:** "memory copies" ==> "data transfers"


> **Review comment:** I think I want to discuss another issue. Please refer to #7814 (comment).