Skip to content

Commit

Permalink
feat/cudnnv4: passive support for cuDNNv4
Browse files Browse the repository at this point in the history
  • Loading branch information
hobofan committed Mar 3, 2016
1 parent b685ae8 commit 0dc4630
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ license = "MIT OR Apache-2.0"

[dependencies]
collenchyma = { version = "0.0.8", default-features = false }
cudnn = { version = "1.3.0", optional = true }
cudnn = { version = "1.3.1", optional = true }
libc = "0.2"
lazy_static = "0.1"
log = "0.3.2"
Expand Down
23 changes: 11 additions & 12 deletions src/frameworks/cuda/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ impl ConvForwardAlgo {
ConvForwardAlgo::ImplicitGEMM => ::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
ConvForwardAlgo::ImplicitPrecompiledGEMM => ::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
ConvForwardAlgo::FFT => ::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
ConvForwardAlgo::FFTTiling => ::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
ConvForwardAlgo::Direct => ::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
})
}
Expand All @@ -130,6 +131,7 @@ impl ConvForwardAlgo {
::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM => ConvForwardAlgo::ImplicitGEMM,
::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM => ConvForwardAlgo::ImplicitPrecompiledGEMM,
::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT => ConvForwardAlgo::FFT,
::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING => ConvForwardAlgo::FFTTiling,
::cudnn::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT => ConvForwardAlgo::Direct,
}
}
Expand Down Expand Up @@ -161,6 +163,7 @@ impl ConvForwardAlgo {
ConvForwardAlgo::ImplicitGEMM => false,
ConvForwardAlgo::ImplicitPrecompiledGEMM => true,
ConvForwardAlgo::FFT => true,
ConvForwardAlgo::FFTTiling => true,
ConvForwardAlgo::Direct => true,
})
}
Expand Down Expand Up @@ -212,7 +215,7 @@ impl ConvBackwardFilterAlgo {
Ok(match *self {
ConvBackwardFilterAlgo::Auto => return Err(::co::error::Error::Plugin(::co::plugin::Error::Plugin("Can't check necessary workspace size for ConvBackwardFilterAlgo::Auto. Use `find_cudnn_algo` to find an algorithm."))),
ConvBackwardFilterAlgo::ImplicitGEMM => false,
ConvBackwardFilterAlgo::ImplicitGEMMSum => false,
ConvBackwardFilterAlgo::ImplicitGEMMSum => true,
ConvBackwardFilterAlgo::ImplicitPrecompiledGEMMSum => true,
ConvBackwardFilterAlgo::FFT => true,
})
Expand All @@ -227,6 +230,7 @@ impl ConvBackwardDataAlgo {
ConvBackwardDataAlgo::ImplicitGEMM => ::cudnn::cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
ConvBackwardDataAlgo::ImplicitGEMMSum => ::cudnn::cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
ConvBackwardDataAlgo::FFT => ::cudnn::cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
ConvBackwardDataAlgo::FFTTiling => ::cudnn::cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
})
}

Expand All @@ -236,6 +240,7 @@ impl ConvBackwardDataAlgo {
::cudnn::cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 => ConvBackwardDataAlgo::ImplicitGEMMSum,
::cudnn::cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 => ConvBackwardDataAlgo::ImplicitGEMM,
::cudnn::cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT => ConvBackwardDataAlgo::FFT,
::cudnn::cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING => ConvBackwardDataAlgo::FFTTiling,
}
}

Expand Down Expand Up @@ -265,6 +270,7 @@ impl ConvBackwardDataAlgo {
ConvBackwardDataAlgo::ImplicitGEMM => false,
ConvBackwardDataAlgo::ImplicitGEMMSum => false,
ConvBackwardDataAlgo::FFT => true,
ConvBackwardDataAlgo::FFTTiling => true,
})
}
}
Expand Down Expand Up @@ -298,18 +304,11 @@ macro_rules! impl_convolution_for_cuda_backend {
let useable_algo_bwd_filter = try!(algo_bwd_filter.find_cudnn_algo(&filter_desc, &conv_desc, &src_desc, &dest_desc));
let useable_algo_bwd_data = try!(algo_bwd_data.find_cudnn_algo(&filter_desc, &conv_desc, &src_desc, &dest_desc));

let workspace_size_fwd = match try!(useable_algo_fwd.needs_cudnn_workspace()) {
false => 0,
true => API::get_convolution_forward_workspace_size(*CUDNN.id_c(), useable_algo_fwd.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(),
};

let workspace_size_bwd_filter = match try!(useable_algo_bwd_filter.needs_cudnn_workspace()) {
false => 0,
true => API::get_convolution_backward_filter_workspace_size(*CUDNN.id_c(), useable_algo_bwd_filter.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(),
};

let workspace_size_fwd = API::get_convolution_forward_workspace_size(*CUDNN.id_c(), useable_algo_fwd.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap();
let workspace_size_bwd_filter = API::get_convolution_backward_filter_workspace_size(*CUDNN.id_c(), useable_algo_bwd_filter.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap();
// let workspace_size_bwd_data = API::get_convolution_backward_data_workspace_size(*CUDNN.id_c(), useable_algo_bwd_data.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap();
let workspace_size_bwd_data = match try!(useable_algo_bwd_data.needs_cudnn_workspace()) {
false => 0,
false => 1,
true => API::get_convolution_backward_data_workspace_size(*CUDNN.id_c(), useable_algo_bwd_data.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(),
};

Expand Down
10 changes: 10 additions & 0 deletions src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ pub enum ConvForwardAlgo {
///
/// Needs a significant memory workspace.
FFT,
/// Compute the convolution as Fast-Fourier Transform with 32x32 tiles.
///
/// Needs a significant memory workspace.
FFTTiling,
/// Compute the convolution without implicit or explicit matrix-multiplication. **Do not try to use this**.
///
/// Listed in cuDNN docs but cuDNN does not provide an implementation.
Expand Down Expand Up @@ -100,6 +104,12 @@ pub enum ConvBackwardDataAlgo {
///
/// The results are deterministic.
FFT,
/// Compute the convolution as Fast-Fourier Transform with 32x32 tiles.
///
/// Needs a significant memory workspace.
///
/// The results are deterministic.
FFTTiling,
}

impl ConvBackwardDataAlgo {
Expand Down

0 comments on commit 0dc4630

Please sign in to comment.