Heyi fused grad accumulation #138
base: master
Conversation
//cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
//if(g_hipblas_handle == nullptr)
//  CHECK_HIPBLAS_ERROR(hipblasCreate(&g_hipblas_handle));
hipblasLtHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
Since this is hipblasLt, I think we need to use getCurrentCUDABlasLtHandle?
Correct. CUDA handles can be used interchangeably, but hipBLAS and hipBLASLt handles cannot.
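A minimal sketch of the suggested fix, assuming the PyTorch build in use exposes at::cuda::getCurrentCUDABlasLtHandle() (on ROCm builds the cublasLt handle type maps to hipblasLtHandle_t):

// Sketch, not the PR's code: fetch the Lt-specific handle rather than the
// regular BLAS handle, since the two handle types are distinct on ROCm.
hipblasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle();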
hipblaslt_ext::Gemm gemm(
    handle, transa, transb, HIP_R_16BF, HIP_R_16BF, HIP_R_32F, HIP_R_32F, HIPBLAS_COMPUTE_32F);

hipblaslt_ext::GemmEpilogue
This is using the hipblasLtExt API. We haven't used that a lot; we mostly use the hipblasLt API. I will need to check with the hipblaslt team regarding the difference (in performance implications). And as the hipblasLt API is more commonly used, I expect it to be more stable and have fewer bugs.
hipblasLtMatmulDesc_t matmul;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
    matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(int32_t)));
Not sure if transa is int32_t or not. But I guess we can just use sizeof(transa).
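A sketch of the suggested change; sizeof(transa) stays correct regardless of whether transa is an int32_t or, say, a hipblasOperation_t:

// Suggested form (sketch): size the attribute from the variable itself.
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
    matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));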
const int request_solutions = 1;
hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
int returnedAlgoCount = 0;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle,
@jeffdaily Do we do any autotuning on the hipblasLt path? I am wondering if we need to do autotuning here.
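For reference, a rough sketch of what light tuning on top of the heuristic query could look like: request several candidate algorithms instead of one and keep the fastest. The descriptors (matmul, layoutA/B/C/D, pref) and the time_matmul() helper are placeholders for illustration, not code from this PR.

// Sketch only: ask the heuristic for several candidates and time each one.
const int request_solutions = 8;
hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
int returnedAlgoCount = 0;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(
    handle, matmul, layoutA, layoutB, layoutC, layoutD, pref,
    request_solutions, heuristicResult, &returnedAlgoCount));

int best = 0;
float best_ms = 1e30f;
for (int i = 0; i < returnedAlgoCount; ++i) {
  // time_matmul(): hypothetical helper that runs hipblasLtMatmul with
  // heuristicResult[i].algo a few times and returns an average time in ms.
  float ms = time_matmul(heuristicResult[i].algo);
  if (ms < best_ms) { best_ms = ms; best = i; }
}
// heuristicResult[best].algo would then be cached per problem shape.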
import math, pdb
from torch.testing._internal import common_utils

torch.backends.cuda.matmul.allow_tf32 = False
@jeffdaily Do we have this defaulting to False on ROCm?
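A quick way to check from the test itself, shown as a sketch; to my knowledge recent PyTorch releases default this flag to False, but setting it explicitly keeps the test independent of that default:

import torch

# TF32 only affects FP32 matmuls; pin the flag so the comparison in this test
# does not depend on the PyTorch default (believed to be False since 1.12).
print("allow_tf32 default:", torch.backends.cuda.matmul.allow_tf32)
torch.backends.cuda.matmul.allow_tf32 = False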
else:
    print("========FAIL======")

grad_weight = grad_weight.view(-1)
We don't need to reshape the tensor now, as torch.allclose can take 2D tensors.
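A small illustration of the point, assuming the two tensors being compared already have matching shapes:

import torch

a = torch.randn(4, 8)
b = a.clone()

# torch.allclose accepts tensors of any matching (broadcastable) shape, so
# comparing the 2-D tensors directly is equivalent to flattening them first.
assert torch.allclose(a, b)
assert torch.allclose(a.view(-1), b.view(-1))  # same result, flatten not needed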
@pruthvistony @jithunnair-amd Do we have any CI setup for Apex?
@wenchenvincent -- there's no CI for apex.
Regarding the tolerance for UTs, I looked at matmul UTs in PyTorch and it seems there are some
hipblaslt implementation in fused_weight_gradient_dense