add support for half-precision gemm #1080
Conversation
lib/cublas/libcublas.jl (outdated)
@@ -1314,6 +1314,16 @@ end
    handle, uplo, n, alpha, x, incx, y, incy, AP)
end

@checked function cublasHgemm_v2(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
This function is deprecated; see deprecated.jl.
Float16 gemm already works, I assume you intend to add support for the batched APIs?
Moving those tests out to a self-contained testset that covers more types seems like the easiest solution, and the CUBLAS tests desperately need to be refactored like that anyway.
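As a hedged illustration of that refactor (the testset name, sizes, and type list are made up here, and running it requires a CUDA-capable device), a self-contained level-3 testset parameterized over element types could look something like:

```julia
# Hypothetical sketch: one testset per routine, parameterized over the types
# that routine actually supports, so gemm can include Float16 while symm,
# trsm, hemm, etc. keep their narrower type lists.
using Test, CUDA, LinearAlgebra

@testset "gemm! $elty" for elty in [Float16, Float32, Float64, ComplexF32, ComplexF64]
    m, n, k = 64, 32, 16
    A = CUDA.rand(elty, m, k)
    B = CUDA.rand(elty, k, n)
    C = CUDA.zeros(elty, m, n)
    CUBLAS.gemm!('N', 'N', one(elty), A, B, zero(elty), C)
    # compare against the CPU result
    @test Array(C) ≈ Array(A) * Array(B)
end
```

This avoids one big for-loop over types by letting each routine declare its own supported set.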
I've removed it.
Thanks, looks good!
The failing Windows build lists these:
The passing 11.4 build is identical, except the NVIDIA driver is 460.73, as is the NVML version. On what should I gate?
You left out the important part: the actual device and its capability :-) Float16 generally requires capability 5.3 or higher, while the GTX 970 is sm_52.
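A minimal sketch of gating on that, assuming a hypothetical `supports_float16` helper (only `capability(device())` is the actual CUDA.jl query):

```julia
# Hedged sketch: Float16 arithmetic needs compute capability 5.3 or higher,
# so gate the Float16 tests on the device's capability.
# `supports_float16` is a hypothetical helper introduced for illustration.
supports_float16(cap::VersionNumber) = cap >= v"5.3"

# In the testset, something like:
#   if supports_float16(capability(device()))
#       # run the Float16 gemm tests
#   end
```

Gating on compute capability rather than driver or CUDA version matches the actual constraint here: the GTX 970 (sm_52) would be skipped regardless of driver.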
Codecov Report
@@ Coverage Diff @@
## master #1080 +/- ##
=======================================
Coverage 80.00% 80.00%
=======================================
Files 118 118
Lines 7656 7656
=======================================
Hits 6125 6125
Misses 1531 1531
All tests pass now.
Great, thanks!
Is Float16 the right type to use in the ccall?

Also, what do you want to do about tests? CUBLAS only supports half precision for gemm, not any of the other functions (e.g. symm, trsm, hemm, etc.). I'd have to split that big for-loop apart and duplicate the array construction.
For size(A) == (64,1,1000), size(B) == (1,64,1000), size(C) == (64,64,1000), the execution time of CUBLAS.gemm_strided_batched! is 3/4 as long compared to Float32 on an RTX 8000, and severalfold slower on a 1080 Ti.

Fixes #1076
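A rough sketch of how that comparison could be reproduced (requires a CUDA-capable GPU; BenchmarkTools is an extra dependency, and the sizes mirror the ones above):

```julia
# Hedged benchmark sketch: time Float16 vs Float32 strided-batched gemm
# at the batched outer-product sizes from the comment above.
using CUDA, BenchmarkTools

for elty in (Float16, Float32)
    A = CUDA.rand(elty, 64, 1, 1000)
    B = CUDA.rand(elty, 1, 64, 1000)
    C = CUDA.zeros(elty, 64, 64, 1000)
    # CUDA.@sync ensures the kernel finishes before the timer stops
    t = @belapsed CUDA.@sync CUBLAS.gemm_strided_batched!('N', 'N', one($elty),
                                                          $A, $B, zero($elty), $C)
    println("$elty: $(round(t * 1e6; digits=1)) μs")
end
```

On pre-sm_53 or consumer cards without fast fp16 throughput (like the 1080 Ti), the Float16 path being slower than Float32 is expected.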