
add support for half-precision gemm #1080

Merged: 1 commit merged into JuliaGPU:master on Aug 5, 2021

Conversation

@bjarthur (Contributor) commented Aug 4, 2021

Is Float16 the right type to use in the ccall?

Also, what do you want to do about tests? CUBLAS only supports half precision for gemm, not any of the other functions (e.g. symm, trsm, hemm, etc.). I'd have to split that big for-loop apart and duplicate the array construction.

For size(A) == (64,1,1000), size(B) == (1,64,1000), and size(C) == (64,64,1000), the execution time of CUBLAS.gemm_strided_batched! with Float16 is about 3/4 that of Float32 on an RTX 8000, but several-fold slower on a 1080 Ti.

fixes #1076
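
A minimal sketch of that comparison (not part of the PR; the benchmark loop and BenchmarkTools usage are my assumptions, and the sizes are the ones quoted above):

```julia
using CUDA, BenchmarkTools

# Compare Float16 vs. Float32 strided-batched gemm at the quoted sizes.
for T in (Float16, Float32)
    A = CuArray(rand(T, 64, 1, 1000))
    B = CuArray(rand(T, 1, 64, 1000))
    C = CuArray(zeros(T, 64, 64, 1000))
    t = @belapsed CUDA.@sync CUBLAS.gemm_strided_batched!('N', 'N', one($T), $A, $B, zero($T), $C)
    println("$T: $(1e3 * t) ms")
end
```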

@@ -1314,6 +1314,16 @@ end
handle, uplo, n, alpha, x, incx, y, incy, AP)
end

@checked function cublasHgemm_v2(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
@maleadt (Member) commented on this diff:

This function is deprecated; see deprecated.jl.

@maleadt (Member) commented Aug 4, 2021

Float16 gemm already works; I assume you intend to add support for the batched APIs?
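
For reference, a quick illustration of that point (a sketch, assuming a device that supports Float16):

```julia
using CUDA

A = CuArray(rand(Float16, 32, 32))
B = CuArray(rand(Float16, 32, 32))
C = A * B                                                 # plain matmul already routes to CUBLAS
CUBLAS.gemm!('N', 'N', Float16(1), A, B, Float16(0), C)   # or call the wrapper directly
```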

> Also, what do you want to do about tests? CUBLAS only supports half precision for gemm, not any of the other functions (e.g. symm, trsm, hemm, etc.). I'd have to split that big for-loop apart and duplicate the array construction.

Moving those tests out to a self-contained testset that covers more types seems like the easiest solution. And the CUBLAS tests desperately need to be refactored like that anyway.
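
Something along these lines, say (a hypothetical sketch of such a self-contained testset; the type list and sizes are placeholders, and it assumes the allocating CUBLAS.gemm_strided_batched wrapper):

```julia
using Test, CUDA

@testset "gemm_strided_batched ($T)" for T in (Float16, Float32, Float64, ComplexF32, ComplexF64)
    A = rand(T, 8, 4, 6)
    B = rand(T, 4, 8, 6)
    dC = CUBLAS.gemm_strided_batched('N', 'N', CuArray(A), CuArray(B))

    # reference result, computed slice-by-slice on the CPU
    C = Array{T}(undef, 8, 8, 6)
    for k in axes(C, 3)
        C[:, :, k] = A[:, :, k] * B[:, :, k]
    end
    @test Array(dC) ≈ C
end
```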

@bjarthur (Contributor, Author) commented Aug 4, 2021

I've removed Hgemm_v2 and added tests.

@maleadt (Member) commented Aug 4, 2021

Thanks, looks good!
The Windows CI failure seems related, though: the Float16 tests will need to be gated on device capability.

@bjarthur (Contributor, Author) commented Aug 4, 2021

The failing Windows build lists these:

  │ CUDA toolkit 11.4.0, artifact installation
  │ CUDA driver 11.3.0
  │ NVIDIA driver 466.27.0
  │
  │ Libraries:
  │ - CUBLAS: 11.5.2
  │ - CURAND: 10.2.5
  │ - CUFFT: 10.5.0
  │ - CUSOLVER: 11.2.0
  │ - CUSPARSE: 11.6.0
  │ - CUPTI: 14.0.0
  │ - NVML: 11.0.0+466.27
  │ - CUDNN: 8.20.0 (for CUDA 11.3.0)
  │ - CUTENSOR: 1.3.0 (for CUDA 11.2.0)
  │
  │ Toolchain:
  │ - Julia: 1.6.2
  │ - LLVM: 11.0.1
  │ - PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0
  │ - Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80

The passing 11.4 build is identical, except the NVIDIA driver is 460.73 (and the NVML suffix matches it). On what should I gate?

@maleadt (Member) commented Aug 4, 2021

You left out the important part, listing the actual device and its capability :-) Float16 generally requires capability 5.3 or higher, while the GTX 970 is sm_52.
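
A sketch of the kind of gate that resolves this, using CUDA.jl's capability query (the surrounding type-list structure is hypothetical):

```julia
using CUDA

# Float16 CUBLAS routines need sm_53 or newer; skip them on older devices.
eltypes = [Float32, Float64, ComplexF32, ComplexF64]
if capability(device()) >= v"5.3"
    push!(eltypes, Float16)
end
```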

codecov bot commented Aug 4, 2021

Codecov Report

Merging #1080 (d807b5b) into master (5710e25) will not change coverage.
The diff coverage is n/a.

❗ Current head d807b5b differs from pull request most recent head 381ac17. Consider uploading reports for the commit 381ac17 to get more accurate results

@@           Coverage Diff           @@
##           master    #1080   +/-   ##
=======================================
  Coverage   80.00%   80.00%           
=======================================
  Files         118      118           
  Lines        7656     7656           
=======================================
  Hits         6125     6125           
  Misses       1531     1531           
Impacted Files           Coverage      Δ
lib/cublas/wrappers.jl   92.37% <ø>   (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5710e25...381ac17.

@bjarthur (Contributor, Author) commented Aug 5, 2021

All tests pass now.

@maleadt (Member) commented Aug 5, 2021

Great, thanks!

@maleadt maleadt merged commit 9f2ca66 into JuliaGPU:master Aug 5, 2021
@bjarthur bjarthur deleted the bja/Hgemm branch August 5, 2021 11:52
@maleadt maleadt added the cuda libraries (Stuff about CUDA library wrappers) and enhancement (New feature or request) labels on Aug 13, 2021
Merging this pull request may close: cublasHgemmStridedBatched