
Incorporate cuDNN, add conv2d CPU/GPU version (based on Eigen and cuDNN) #388

Merged: 14 commits into clab:master, Apr 17, 2017

Conversation

zhisbug
Collaborator

@zhisbug zhisbug commented Mar 20, 2017

#229 This is the CPU implementation based on Eigen SpatialConvolution. It is reported to be the fastest available CPU implementation of conv2d.
For GPU support, I think hand-writing a new version with cuBLAS kernels is not worthwhile, so I am currently incorporating cuDNN into DyNet and will provide a cuDNN-based (standard) implementation.
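
As a point of reference, the sketch below is an editorial illustration of what a single-channel, single-filter "valid" 2D convolution (cross-correlation, as in most deep learning toolkits) computes; it is not the Eigen SpatialConvolution or cuDNN code used in this PR, only the semantics being implemented.

#include <vector>

// Reference semantics only: x is stored row-major as H x W, k as kH x kW.
std::vector<float> conv2d_valid(const std::vector<float>& x, int H, int W,
                                const std::vector<float>& k, int kH, int kW) {
  const int oH = H - kH + 1, oW = W - kW + 1;
  std::vector<float> y(oH * oW, 0.0f);
  for (int i = 0; i < oH; ++i)
    for (int j = 0; j < oW; ++j)
      for (int a = 0; a < kH; ++a)
        for (int b = 0; b < kW; ++b)
          y[i * oW + j] += x[(i + a) * W + (j + b)] * k[a * kW + b];
  return y;
}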

@neubig
Contributor

neubig commented Mar 20, 2017

This is excellent, thank you! I think we should probably create a third_party directory like the one used for TensorFlow, though, just to make sure it's clear that the code was copied and pasted from somewhere else. Other than that, this looks great. We will also need docs for the convolutions eventually, but that can be in a follow-up commit.

@zhisbug
Collaborator Author

zhisbug commented Mar 21, 2017

Agreed. Sure, let me mark this one as WIP. I will reorganize the structure a little by creating a separate folder to hold the third-party files.
My next commits will cover the following:

  1. Support for cuDNN, and a revised installation guide (if using the GPU, cuDNN will be required).
  2. GPU conv2d based on cuDNN.
  3. CPU/GPU 2D max pooling.
  4. Code restructuring.
  5. Interface documentation on how to use 2D conv/max pooling.

Let me know if there are any concerns about the above modifications.

@neubig
Contributor

neubig commented Mar 21, 2017

Thanks! Two things:

  1. I think it's worth finishing this commit with only CPU support. Some people have been wanting conv2d, and even having it just on CPU will help people prototype while the other things are in progress.
  2. I'd like GPU support to be there even without cuDNN, but when the user uses a function that is only supported through cuDNN, it can throw a "please install cuDNN" error.

@zhisbug
Collaborator Author

zhisbug commented Mar 21, 2017

OK, cool.
For 1, let me prioritize the code reorganization and try to merge this CPU version tomorrow.
For 2, I think a possible solution is to add another compiler flag (e.g. USE_CUDNN).
My concern:
For some functions such as Conv2D, Conv3D, and MaxPooling on GPU, although it is possible to implement them without cuDNN, the best hand-written implementation would still be 1-2 orders of magnitude slower than cuDNN (on GPU). Implementing these functions would be time-consuming and the outcome would mostly be useless, because the implementation would eventually converge to a cuDNN-based one anyway, unless some other computing device replaces the GPU. See TensorFlow's strategy as an example: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/conv_ops.cc#L448
Let me know what you think!
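
A minimal sketch of the guard strategy being discussed, assuming a hypothetical HAVE_CUDNN macro set by the proposed USE_CUDNN build flag (the actual macro name and error message in DyNet may differ):

#include <stdexcept>

// Hypothetical illustration: dispatch to cuDNN when it was compiled in,
// otherwise fail with an actionable error instead of falling back to a slow
// hand-written GPU kernel.
void conv2d_gpu_forward(/* tensor arguments omitted */) {
#ifdef HAVE_CUDNN
  // ... call the cuDNN-backed implementation ...
#else
  throw std::runtime_error(
      "conv2d on GPU requires cuDNN: please install cuDNN and rebuild "
      "DyNet with cuDNN support enabled");
#endif
}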

@neubig
Contributor

neubig commented Mar 21, 2017

For 2., yes, the TensorFlow strategy is fine!

@sunalbert
Contributor

Good job! A CPU version is useful enough to help me prototype. Looking forward to DyNet with a convolution operation.

@pmichel31415
Collaborator

That is awesome, I've been needing that for a while.

Do you have an estimate of when you'll be pushing the GPU version?

@zhisbug zhisbug changed the title Add conv2d CPU version (based on Eigen) [WIP] Add conv2d CPU version (based on Eigen) Mar 22, 2017
@zhisbug
Collaborator Author

zhisbug commented Mar 22, 2017

I can probably push an 80%-complete but usable version (with GPU enabled) by next Wednesday.

@zhisbug zhisbug changed the title [WIP] Add conv2d CPU version (based on Eigen) [WIP] Incorporate cuDNN, add conv2d CPU/GPU version (based on Eigen and cuDNN) Apr 3, 2017
@zhisbug zhisbug changed the title [WIP] Incorporate cuDNN, add conv2d CPU/GPU version (based on Eigen and cuDNN) Incorporate cuDNN, add conv2d CPU/GPU version (based on Eigen and cuDNN) Apr 3, 2017
@zhisbug
Collaborator Author

zhisbug commented Apr 3, 2017

  • Support cuDNN 5.1
  • Add a USE_CUDNN compiler flag: CMake will try to find cuDNN if the user provides a -DCUDNN_ROOT argument
  • CPU conv2d based on Eigen
  • GPU conv2d based on cuDNN
  • Code restructuring: create a third_party folder to hold external functionality
  • Add interface documentation: how to use conv2d.
  • Modify setup tutorial: add an extra -DCUDNN_ROOT flag.

@zhisbug
Collaborator Author

zhisbug commented Apr 3, 2017

OK, I think most of the engineering effort for this issue is finished. It would be great if someone could review this commit and leave comments.
For MaxPooling2D, I would prefer to create a separate issue.
Once the code is ready to be merged, I will add interface documentation and modify the setup tutorial to reflect the cuDNN support.

@sunalbert
Contributor

sunalbert commented Apr 4, 2017

Why is the shape of x_1 in conv2d [W H C N]?
x_1 \in R^{W x H x Ci x N} (input)
I am confused by the layout. In C++, the layout is row-major, which means a 4D array has a shape of [N H W C]. According to the following comment in tensor.h:

This provides a bridge between classic C++ types and Eigen tensors.

I think the initial layout of the input (x_1) should be [C W H N] instead of [W H C N].
Looking forward to your reply. Thanks in advance.

@zhisbug
Collaborator Author

zhisbug commented Apr 4, 2017

Good question. This is actually a design choice (that needs some discussion). Here is the explanation (I'll assume RowMajor in the following discussion; x_1 is the feature maps and x_2 the filters):
In Eigen, only the layout NHWC for x_1 and WHCN for x_2 are supported.
Meanwhile, cuDNN supports NHWC or NCHW for x_1 and NHWC or NCHW for x_2, but only NCHW gets the maximal performance.

The above conditions imply that at least one layout transformation will happen, for either the CPU version or the GPU version.

Considering that layout transformation is usually equally fast on CPU and GPU, I designed it so that the transformation happens only for CPU convolution. Reasons are as follows:

  1. The CPU runs convolution much slower than the GPU, so the relative slowdown caused by the layout transformation is less significant;
  2. The code is easier to implement on CPU :)

Let me know what you think!

@neubig
Contributor

neubig commented Apr 4, 2017

Hi, just to confirm:
W = width = columns
H = height = rows
C = channels
N = number of filters.

I think the most important thing here is clarity and consistency with the standard DyNet API, so H should come before W, and it makes sense that C comes after HW. N could come either first or last, as both make sense.

So in summary, I think NHWC is the most intuitive option here. Even if the performance is a little bit worse on cuDNN, I think we should make sure that we go with consistency and simplicity of the API, which is one of the core design principles behind DyNet.

@zhisbug
Collaborator Author

zhisbug commented Apr 4, 2017

OK, then I will assume both x_1 and x_2 come in with the layout NHWC.
So for both CPU and GPU, a layout transformation is necessary:

  1. CPU: transform x_2 to WHCN
  2. GPU: transform x_1 and x_2 to NCHW (this step could be optional because cuDNN also supports NHWC, but I am not sure how much performance would be compromised.)

@neubig
Contributor

neubig commented Apr 4, 2017

@zhisbug Thanks! Also, two more things:

  1. This is probably a stupid question, but why does x_1, the input, need N? My understanding was that N is the number of filters, and there is no concept of a number of filters for the input.
  2. We can also probably support input where we omit "C", as there are plenty of situations where we don't have multiple channels.

@neubig
Contributor

neubig commented Apr 4, 2017

Sorry, for 1 I answered my own question: N is the number of batches. Currently DyNet only supports the number of batches being the last dimension, so we probably need the input to be
HWCN or CHWN, but this is in column-major format. Is cuDNN column-major or row-major? If it's row-major, then the memory layout would be the opposite, and things might get confusing here.

@zhisbug
Collaborator Author

zhisbug commented Apr 4, 2017

@neubig To clarify, what I was talking about above assumes RowMajor, so RowMajor NHWC/NCHW is ColMajor CWHN/WHCN, and cuDNN is RowMajor.

But I am a bit confused by your comment "so we probably need it to be HWCN or CHWN for the input, but this is in ColMajor format". Could you clarify which of HWCN and CHWN in ColMajor is preferred? Currently, the implementation assumes the input comes in NCHW in RowMajor, which is WHCN in ColMajor (aligned with DyNet's design that N comes as the last dimension).

@neubig
Contributor

neubig commented Apr 4, 2017

OK, just to confirm, the canonical layout in DyNet would be:

  • rows X columns X channels X filters

Re-written in TensorFlow terminology, maybe this is the following (does W correspond to rows? Because it's "width", I think it should correspond to the number of columns... Regardless, it doesn't matter, as the behavior of rows and columns is identical anyway.)

  • WHCN

If we assume that this is written in the opposite order of the actual memory layout (according to the cuDNN convention):

  • NCHW

So actually, I think things are OK. We will lay out things in the canonical order of "rows X columns X channels X filters" or "rows X columns X channels X minibatches" in DyNet Tensors, and then the memory order will correspond to the order that cuDNN expects. On CPU, this might mean that we have to do one transform, but I think that's OK.
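
As a concrete check of the column-major/row-major duality discussed here, the following standalone sketch (an editorial illustration, not code from this PR) views a single buffer both ways: an H x W x C x N column-major tensor and an N x C x W x H row-major tensor share the same memory layout.

#include <unsupported/Eigen/CXX11/Tensor>
#include <cassert>

int main() {
  const int H = 2, W = 3, C = 4, N = 5;
  // Eigen's default layout is column-major: the first index varies fastest.
  Eigen::Tensor<float, 4, Eigen::ColMajor> col(H, W, C, N);
  col.setRandom();
  // The same buffer re-interpreted as a row-major tensor with reversed dims.
  Eigen::TensorMap<Eigen::Tensor<float, 4, Eigen::RowMajor>> row(
      col.data(), N, C, W, H);
  // Element (h, w, c, n) in HWCN ColMajor sits at the same linear offset as
  // element (n, c, w, h) in NCWH RowMajor.
  assert(row(4, 3, 2, 1) == col(1, 2, 3, 4));
  return 0;
}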

@zhisbug
Collaborator Author

zhisbug commented Apr 4, 2017

@neubig Thanks! This makes sense.
In my description, row corresponds to H (height) and column corresponds to W (width). So we have reached agreement that the input should be (in ColMajor):

  • rows X columns X channels X (filters or batches)
  • HWCN (in CUDA/TensorFlow language)

Let me do a final check on whether the input I assumed matches this. If it does not, I will need to make a minor modification before merging.

@zhisbug
Collaborator Author

zhisbug commented Apr 5, 2017

Done. Now it assumes the input follows the HxWxCxN layout in ColMajor (equivalently NxCxWxH in RowMajor), where

  • H: number of rows
  • W: number of columns
  • C: number of channels
  • N: number of batches (or filters)
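
A hedged usage sketch of this layout from the expression-API side; the names follow the conv2d interface this PR adds, but the exact signatures (and whether the model container is Model or ParameterCollection) may differ from the merged code.

#include <vector>
#include "dynet/dynet.h"
#include "dynet/expr.h"

using namespace dynet;

// Illustrative only: build a conv2d node whose input follows the agreed
// H x W x C layout per instance, with N as the batch dimension.
Expression conv2d_sketch(ComputationGraph& cg, ParameterCollection& pc) {
  const unsigned H = 28, W = 28, C = 1, N = 16;  // rows, cols, channels, batch
  const unsigned fH = 5, fW = 5, Co = 32;        // filter rows/cols, output channels

  Parameter p_f = pc.add_parameters(Dim({fH, fW, C, Co}));
  std::vector<float> x_vals(H * W * C * N, 0.f); // stand-in for real image data
  Expression x = input(cg, Dim({H, W, C}, N), x_vals);
  Expression f = parameter(cg, p_f);

  std::vector<unsigned> stride = {1, 1};
  return conv2d(x, f, stride, /*is_valid=*/true); // "valid" (no padding) convolution
}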

Contributor

@neubig neubig left a comment


Thanks! This basically looks good, but there are a few places where we can remove the custom-implemented code and just rely on Eigen tensor operations, which will make things much simpler and cleaner. Also, please make sure that the ordering of function arguments always follows the canonical column-major ordering, so rows come first, then columns, etc. I probably won't have time to test this before merging, so perhaps we can ask the DyNet user base to stress-test it.

dynet/gpu-ops.h Outdated
@@ -34,6 +34,7 @@ void sparse_to_dense_assign(int n, const unsigned int* ids, float* src, float* t
void dense_to_sparse_subtract(int n, const unsigned int* ids, float* src, float* trg);
void sparse_to_dense_block_assign_and_multiply(int n, const unsigned *idx, int bsize, float mult, float *src, float *trg);
void dense_to_sparse_block_add(int n, const unsigned* ids, int bsize, float* src, float* trg);
void pad_input(float* output, const float* input, int N, int C, int H, int W, int pad_right, int pad_bottom);
Contributor


I think there may be an Eigen tensor operation "tensor.pad()" for this. Can you try to use this instead of writing a custom kernel? You can find an example in eigen/unsupported/tests/cxx11_tensor_padding.cpp.

};

// template?
struct NCWHToNWHC {
Contributor


Similarly, there is an Eigen operation for this. See eigen/unsupported/tests/cxx11_tensor_shuffling.cpp

Collaborator Author


I replaced these layout transformation functions with Eigen functions.
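
For reference, a standalone sketch (an editorial illustration, not the PR's code) of the two Eigen tensor operations suggested above: shuffle() for layout permutation and pad() for zero-padding.

#include <unsupported/Eigen/CXX11/Tensor>
#include <cstddef>
#include <utility>

int main() {
  Eigen::Tensor<float, 4> x(5, 4, 3, 2);  // e.g. N x C x W x H (ColMajor)
  x.setRandom();

  // shuffle(): permute dimensions; output dim i takes input dim perm[i],
  // here turning NCWH into NWHC.
  Eigen::array<int, 4> perm = {0, 2, 3, 1};
  Eigen::Tensor<float, 4> shuffled = x.shuffle(perm);

  // pad(): (before, after) zero-padding per dimension, here on the last two dims.
  Eigen::array<std::pair<std::ptrdiff_t, std::ptrdiff_t>, 4> paddings;
  paddings[0] = std::make_pair(0, 0);
  paddings[1] = std::make_pair(0, 0);
  paddings[2] = std::make_pair(1, 1);
  paddings[3] = std::make_pair(1, 1);
  Eigen::Tensor<float, 4> padded = x.pad(paddings);
  return 0;
}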

//Eigen::array<int, 4> offsets = {0, 0, 0, 0};
//Eigen::array<int, 4> extents = {static_cast<int>(XH), static_cast<int>(XW), static_cast<int>(XC), static_cast<int>(XN)};
//TODO this function cannot be linked
//dxi->tb<3>().device(*dev.edevice) = padded_dx.tb<3>().slice(offsets, extents);
Collaborator Author


Here I tried to use the Eigen slice() function to do negative padding in place of my custom functions, but ended up with a linking error. Can someone familiar with Eigen GPU operations take a look?

//paddings[2] = std::make_pair(0, 0);
//paddings[3] = std::make_pair(0, 0);
//TODO this function cannot be linked
//padded_x.tb<3>().device(*dev.edevice) = xs[0]->tb<3>().pad(paddings);
Collaborator Author

@zhisbug zhisbug Apr 11, 2017


Here I tried to use the Eigen pad() function to do positive padding on GPU tensors, but ended up with a linking error. Can someone familiar with Eigen GPU operations take a look?

@neubig
Contributor

neubig commented Apr 11, 2017

@zhisbug I think this is probably because you're compiling the cudnn-ops.cc file with regular gcc, not nvcc. In general, Eigen code that uses the GPU needs to be compiled with nvcc, which is why we have the dummy files gpu-XXX.cu that are compiled with nvcc. Could you try renaming the file to gpu-cudnn-ops.cu and moving it from the list of files compiled with gcc to the list of files compiled with nvcc in CMakeLists.txt?

@zhisbug
Collaborator Author

zhisbug commented Apr 11, 2017

Cool! Now it works!

@neubig
Contributor

neubig commented Apr 11, 2017

OK, great! I think the final step now is removing the unnecessary code and documenting the functions.

I think we should probably also deprecate (comment out) the existing conv1d functions, as they don't match the expected behavior from any paper that I'm aware of. See: #236

@zhisbug
Collaborator Author

zhisbug commented Apr 12, 2017

@neubig Do you think conv1d is still necessary? If it is, maybe I can take a look at it and reimplement it if needed.

@neubig
Contributor

neubig commented Apr 12, 2017

@zhisbug I don't think they're necessary, so for this commit let's just comment them out.

@neubig
Contributor

neubig commented Apr 14, 2017

@zhisbug Is this basically done except for doc? If so, let's do the doc and merge it in! (I'll be able to do this tomorrow if you're busy today)

@zhisbug
Collaborator Author

zhisbug commented Apr 14, 2017

@neubig Yes, this is a complete version except for the docs. I'll create the doc tonight and let you know when it is done and mergeable!

@pmichel31415
Collaborator

Can this be merged? I will add the Python documentation directly on master.

@neubig neubig merged commit cd7bc2e into clab:master Apr 17, 2017
@neubig
Contributor

neubig commented Apr 17, 2017

@pmichel31415 Thanks, done!

@zhisbug
Collaborator Author

zhisbug commented Apr 17, 2017

@neubig @pmichel31415 Cool! Thanks!

@zhisbug
Collaborator Author

zhisbug commented Apr 17, 2017

Oh, by the way, @neubig @pmichel31415: the installation guide should probably be updated to make users aware of the cuDNN support/requirement.

@zhisbug
Collaborator Author

zhisbug commented Apr 17, 2017

After merging, the build does not pass:

  1. There is an API change in the latest version. Could you change TensorTools::Constant() to TensorTools::constant() in dynet/nodes-conv2d.cc, and TensorTools::SetElements() to TensorTools::set_elements() in test/test-nodes.cc?
  2. @neubig I saw that you removed cudnn-ops.cu from dynet/CMakeLists.txt; why? I think this will cause a linking error.

@neubig
Contributor

neubig commented Apr 17, 2017

Sorry about that! There were a few bugs introduced by a manual merge of the CMakeLists that I didn't catch. These commits should fix them: ff2cc7d 04ca2b9
