Multi-GPU support #42

Closed · soumith opened this issue Sep 23, 2014 · 41 comments

soumith (Member) commented Sep 23, 2014

Multi-GPU support has been implemented in cutorch (and, by extension, in all Torch CUDA libraries like cunn, cudnn, etc.).

  • Switch the device on the fly with cutorch.setDevice(devID)
  • All CUDA calls are asynchronous and can be synchronized with cutorch.synchronize()

Example usage for tensors:

```lua
-- Let us do matrix addition for matrices sitting on two different GPUs
cutorch.setDevice(1)
matrix1 = torch.CudaTensor(10):fill(1)
print(matrix1) -- printing is a synchronous call, so you don't have to explicitly call cutorch.synchronize()
cutorch.setDevice(2)
matrix2 = torch.CudaTensor(10):fill(2)
print(matrix2)
matrix2:add(matrix1) -- matrix1 is seamlessly copied onto GPU2 and added to matrix2
print(matrix2)
```

If you want to do data-parallel training of neural nets (including convnets), your training loop can run like this:

For each mini-batch:

1. load data (preferably using multiple threads, for example using [threads-ffi](https://github.com/torch/threads-ffi))
2. loop over GPUs (the loop below is completely asynchronous, so it runs in parallel)
  2.1. model[gpuX]:forward
  2.2. criterion[gpuX]:forward
  2.3. criterion[gpuX]:backward
  2.4. model[gpuX]:backward
3. cutorch.synchronize()
4. accumulate GPUx's gradParameters to GPU1's gradParameters
5. do SGD on GPU1
6. copy back GPU1's parameters to GPUx
7. cutorch.synchronize() and print accuracy etc.

Loop back to 1 for the next mini-batch (a code sketch of this loop follows).
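To make the steps concrete, here is a minimal Lua sketch of that loop. This is an illustration rather than code from this repo: it assumes `models[g]` and `criterions[g]` are per-GPU replicas created beforehand, that `params[g], gradParams[g] = models[g]:getParameters()`, and that `nextBatch()`, `nIterations` and `learningRate` are provided by your training script.

```lua
require 'cutorch'

local nGPU = cutorch.getDeviceCount()

-- scratch buffer on GPU1, used to bring each GPU's gradients over before accumulating
cutorch.setDevice(1)
local gradBuffer = gradParams[1]:clone()

for iter = 1, nIterations do
   -- 1. load data (on the host)
   local inputs, targets = nextBatch()
   local inputSlices  = inputs:chunk(nGPU, 1)   -- split the mini-batch evenly
   local targetSlices = targets:chunk(nGPU, 1)

   -- 2. queue forward/backward on every GPU; nothing here should block
   for g = 1, nGPU do
      cutorch.setDevice(g)
      local x = inputSlices[g]:cuda()
      local y = targetSlices[g]:cuda()
      models[g]:zeroGradParameters()
      local out = models[g]:forward(x)
      criterions[g]:forward(out, y)
      models[g]:backward(x, criterions[g]:backward(out, y))
   end
   cutorch.synchronize()                        -- 3. wait for all GPUs

   -- 4. accumulate every GPU's gradParameters onto GPU1 (cross-GPU copy, then local add)
   cutorch.setDevice(1)
   for g = 2, nGPU do
      gradBuffer:copy(gradParams[g])
      gradParams[1]:add(gradBuffer)
   end

   -- 5. plain SGD step on GPU1
   params[1]:add(-learningRate, gradParams[1])

   -- 6. copy GPU1's parameters back to the other GPUs (UVA copy)
   for g = 2, nGPU do
      params[g]:copy(params[1])
   end
   cutorch.synchronize()                        -- 7. sync before computing/printing accuracy
end
```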

Also, to train ConvNets using multiple GPUs, I recommend using CuDNN for the convolution layers, as I've tested that they are completely asynchronous (meaning that the processing runs in parallel on multiple GPUs).

Comments below describe the technical details of changes made. If you just want to use Multi-GPU, you can stop reading now.

soumith (Member, Author) commented Sep 23, 2014

What is missing?

cutorch.setDevice right now resets the random seed as well. This needs to be separated so that you can use multiple GPUs as needed
https://github.com/torch/cutorch/blob/master/init.c#L42

GPU-to-GPU copy now has to use a host bridge. This needs to be changed to P2P GPU copy. This is really trivial to implement: in the cutorch initialization function, we just have to enable P2P for each GPU detected, using this function:
http://developer.download.nvidia.com/compute/cuda/4_1/rel/toolkit/docs/online/group__CUDART__PEER_g9e5ea65a18938c2b8715a5602105c306.html

After that, UVA takes care of everything else; copying tensors from one GPU to another is as simple as this:

```lua
cutorch.setDevice(1)
t1 = torch.randn(100):cuda()
cutorch.setDevice(2)
t2 = torch.randn(100):cuda()
-- UVA copy
t2:copy(t1)
```

Internally, Clement and we already have multi-GPU support, and we will get the changes back into cutorch slowly (it will take time to isolate the commits, get approval, etc.), but if you are really adventurous, this is a couple of hours of work.

nicholas-leonard (Member) commented:

And I am guessing we should use this for our D2D memory copies: http://developer.download.nvidia.com/compute/cuda/4_1/rel/toolkit/docs/online/group__CUDART__MEMORY_g046702971bc5a66d9bc6000682a6d844.html#g046702971bc5a66d9bc6000682a6d844

This means that if I have kernel sequence A->B->C (on device 1), followed by a device-to-device memcopy D, followed by kernels E->F->G (on device 2), then eventually A->B->C of iteration t should run in parallel with E->F->G of the previous iteration (t-1), right? Otherwise, I don't see how this can be useful, other than allowing the use of more GPU memory. I mean, ideally you want those GPUs to work on different kernels concurrently. Say A->B->C are the first 3 modules of an nn.Sequential, and E->F->G are the last 3.

soumith (Member, Author) commented Sep 23, 2014

@nicholas-leonard you don't need to call cudaMemcpyPeer explicitly anymore; UVA takes care of it.

nicholas-leonard (Member) commented:

Wow. So you are right, it would be super easy to implement. We check for the device's UVA flag, then call cudaDeviceEnablePeerAccess for every combination of such devices. Easy.

nicholas-leonard (Member) commented:

Still, would the two sequences in the above example be able to run concurrently?

soumith (Member, Author) commented Sep 23, 2014

Yes, they would run concurrently starting from the next iteration, as long as you have no blocking calls anywhere.
Clement already removed all the blocking calls from the convnet pipeline in a few commits earlier this year (I don't remember the exact hashes).

szagoruyko (Member) commented:

@soumith I did the UVA init for the GPUs and got rid of the "cuda runtime error : an illegal memory access was encountered" errors while copying directly from one tensor to another. I'm not sure, however, that the network calls are non-blocking everywhere. How do we test it?

soumith (Member, Author) commented Sep 24, 2014

@szagoruyko profile it (with sys.tic()/sys.toc(), or just with os.clock()).
If model:forward() or model:backward() takes more than, say, 5-10ms, there are blocking calls.
You should only see a big timing number once you call cutorch.synchronize(); until then, the CUDA calls should not take any time at all.
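For example, a rough check along those lines (illustrative; `model`, `input` and `gradOutput` are placeholders for whatever you are testing):

```lua
require 'sys'
require 'cutorch'

sys.tic()
model:forward(input)
print(('forward queued in %.2f ms'):format(sys.toc() * 1000))   -- should be ~0 if nothing blocks

sys.tic()
model:backward(input, gradOutput)
print(('backward queued in %.2f ms'):format(sys.toc() * 1000))  -- should be ~0 if nothing blocks

sys.tic()
cutorch.synchronize()
print(('synchronize took %.2f ms'):format(sys.toc() * 1000))    -- the real compute time shows up here
```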

soumith (Member, Author) commented Sep 24, 2014

@szagoruyko SpatialConvolutionMM seems to be randomly blocking sometimes, due to for-looped cublas calls. Use CuDNN; there is no blocking at all.
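For example (an illustrative layer; the sizes here are made up), the cudnn bindings are drop-in replacements for the nn convolution modules:

```lua
require 'cudnn'

-- instead of nn.SpatialConvolutionMM(3, 96, 11, 11, 4, 4):
local conv = cudnn.SpatialConvolution(3, 96, 11, 11, 4, 4):cuda()
```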

clementfarabet (Member) commented:

Yeah, this is really annoying. I could never figure out why these gemm calls
are blocking sometimes; there's got to be a reason...


soumith (Member, Author) commented Sep 24, 2014

@clementfarabet I've even tried moving to cublasv2, and that didn't help. Maybe CuBLAS has a queue that gets filled? It's only conjecture, as we don't have the source code.

clementfarabet (Member) commented:

It probably does, in which case we would need to use streams.


szagoruyko (Member) commented:

@soumith The cudnn forward itself is not blocking, but when I add nn.Reshape and nn.Linear it blocks. backward is not blocking at all, though.

soumith (Member, Author) commented Sep 24, 2014

@szagoruyko use nn.View.
nn.Linear shouldn't be blocking, I think. If it still is, let me know and I'll take a look.
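For reference, a minimal swap (illustrative; `nFeatures` and `nHidden` are made-up sizes):

```lua
require 'nn'

-- hypothetical sizes, just for illustration
local nFeatures, nHidden = 64 * 5 * 5, 128
local model = nn.Sequential()

-- Instead of copying into a new layout with nn.Reshape:
--   model:add(nn.Reshape(nFeatures))
-- use nn.View, which reinterprets the same storage without a copy:
model:add(nn.View(nFeatures))
model:add(nn.Linear(nFeatures, nHidden))
```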

nicholas-leonard (Member) commented:

I think the call to new():fill() blocks here: https://github.com/torch/nn/blob/master/Linear.lua#L48

soumith (Member, Author) commented Sep 24, 2014

I don't know anymore what the public cutorch looks like. It is possible that that line is blocking; that line is not needed there, it can be a temporary buffer that is reused.

soumith (Member, Author) commented Sep 24, 2014

```lua
if not self.addBuffer or (self.addBuffer:size(1) ~= nframe) then
   self.addBuffer = input.new(nframe):fill(1)
end
self.output:zero():addr(1, self.addBuffer, self.bias)
```

Shall I patch it, or does someone else want to do the honours? Same with lines 89 and 92; the same addBuffer can be reused.

szagoruyko (Member) commented:

@soumith @nicholas-leonard it doesn't go there, actually; nunit is 1 in my case.

soumith (Member, Author) commented Sep 24, 2014

Ah, for nunit=1, this line is blocking:
https://github.com/torch/nn/blob/master/Linear.lua#L45
because it gets the bias back into host memory. I'll patch it sometime today.

szagoruyko (Member) commented:

@soumith cool! By the way, MM and ccn2 are blocking, and ccn2's backward is not fully blocking. Should I share the test script somewhere?

nicholas-leonard (Member) commented:

@soumith thanks

soumith (Member, Author) commented Sep 24, 2014

@szagoruyko yes that would be helpful for all

szagoruyko (Member) commented:

And the pull request is here: #44

szagoruyko (Member) commented:

by the way, is it possible to have shared modules on different GPUs?

soumith (Member, Author) commented Sep 24, 2014

@szagoruyko I just added P2P access. Does anyone else want to take the task of not resetting the random seed every time setDevice is called? All you have to do is move the random-seed initialization to the CUDA initialization (per device).

soumith (Member, Author) commented Sep 24, 2014

@szagoruyko not directly, but if you want to do data-parallel training, your training loop can run like this (a setup sketch follows the list):

  • loop over GPUs
    • run model[gpuX]:forward + criterion[gpuX]:forward + criterion[gpuX]:backward + model[gpuX]:backward
  • accumulate GPUx's gradParameters to GPU1's gradParameters
  • do SGD on GPU1
  • copy back GPU1's parameters to GPUx
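A minimal sketch of setting up such per-GPU replicas (illustrative; `makeModel()` stands in for however you construct your network, and the criterion is just an example):

```lua
require 'cunn'

local nGPU = cutorch.getDeviceCount()
local models, criterions, params, gradParams = {}, {}, {}, {}

for g = 1, nGPU do
   cutorch.setDevice(g)
   models[g] = makeModel():cuda()                 -- an independent replica living on GPU g
   criterions[g] = nn.ClassNLLCriterion():cuda()
   params[g], gradParams[g] = models[g]:getParameters()
end

-- start all replicas from the same weights (cross-GPU copy is fine)
for g = 2, nGPU do
   params[g]:copy(params[1])
end
cutorch.setDevice(1)
```

From here, the loop sketched in the first comment applies: accumulate each GPU's gradParameters onto GPU1, take the SGD step there, and copy params[1] back to every other replica.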

szagoruyko (Member) commented:

@soumith cool, looks like we can do it efficiently now. Thanks!

dominikgrewe (Member) commented:

I've created a pull request for moving the random seed initialization: #45

soumith (Member, Author) commented Sep 25, 2014

Awesome! Now that this is done, basic multi-GPU support is essentially complete.
Anything on top of this is going to be a feature:
helper methods like getting the device associated with a Tensor/Storage, removing any remaining blocking calls in the whole of cutorch/cunn, etc.

So, @jonathantompson to answer your earlier question, torch has Multi-GPU support ;)

shugaoma commented:

An error happens when I run the example; it seems the math operations are not "seamless" as declared:

th> cutorch.setDevice(1)
th> matrix1 = torch.CudaTensor(10):fill(1)
th> print(matrix1)
cutorch.synchronize()
1
1
1
1
1
1
1
1
1
1
[torch.CudaTensor of size 10]
th> cutorch.setDevice(2)
th> matrix2 = torch.CudaTensor(10):fill(2)
th> print(matrix2)
2
2
2
2
2
2
2
2
2
2
[torch.CudaTensor of size 10]
th> matrix2:add(matrix1)
[string "matrix2:add(matrix1) -- matrix1 is seamlessly..."]:1: Assertion `THCudaTensor_checkGPU(state, 3, self_, src1, src2)' failed. at /tmp/luarocks_cutorch-scm-1-6658/cutorch/lib/THC/THCTensorMathPointwise.cu:82
stack traceback:
[C]: in function 'add'
[string "matrix2:add(matrix1) -- matrix1 is seamlessly..."]:1: in main chunk
[C]: in function 'xpcall'
/home/shugao/torch/install/share/lua/5.1/trepl/init.lua:648: in function 'repl'
...ugao/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:185: in main chunk
[C]: at 0x00406670

soumith (Member, Author) commented Aug 13, 2015

@Algred the only operation that is allowed across GPUs is the copy operation. All other operations are guarded by assertions, which is why the add above fails. To get good performance, you'll have to copy matrix1 onto matrix2's GPU and then do the mathematical operation.

If you don't like this behaviour, you can simply disable these assertions by adding the define DISABLE_CHECK_GPU and reinstalling cutorch.

https://github.com/torch/cutorch/blob/master/lib/THC/THCTensor.c#L761
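Concretely, the failing example above works if matrix1 is first copied onto GPU2 (a sketch of the pattern described here):

```lua
require 'cutorch'

cutorch.setDevice(1)
local matrix1 = torch.CudaTensor(10):fill(1)

cutorch.setDevice(2)
local matrix2 = torch.CudaTensor(10):fill(2)

-- cross-GPU copy is allowed, so bring matrix1 over to GPU2 first, then add locally
local matrix1onGPU2 = torch.CudaTensor(10):copy(matrix1)
matrix2:add(matrix1onGPU2)
print(matrix2)   -- a tensor of 3s
```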

shugaoma commented:

@soumith Thank you very much for replying!

darksigma commented:

Does this PR support GPUDirect RDMA? Or are additional lower-level modifications necessary to run on a multi-GPU Mellanox/GTX Titan X cluster?

eriche2016 commented:

If I do data-parallel training, how is the mini-batch data split over the multiple GPUs? Is it split evenly by a scheduler, or do I split the mini-batch manually?

soumith (Member, Author) commented Jan 26, 2016

evenly

soumith (Member, Author) commented Jan 26, 2016

@darksigma try looking at nccl.torch for that.

byronwwang commented:

@soumith For mini-batch splitting, is there any shuffle beforehand? Or is the batch just split evenly according to the original order of samples in the batch?

soumith (Member, Author) commented Jun 4, 2016

No shuffle, split evenly
