Heyi fused grad accumulation #138

Open · wants to merge 6 commits into master

Conversation

eliotwang

hipblaslt implementation in fused_weight_gradient_dense

//cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
//if(g_hipblas_handle == nullptr)
// CHECK_HIPBLAS_ERROR(hipblasCreate(&g_hipblas_handle));
hipblasLtHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

Since this is hipblasLt, I think we need to use getCurrentCUDABlasLtHandle?

Correct. On CUDA, the cuBLAS and cuBLASLt handles can be used interchangeably, but hipBLAS and hipBLASLt handles cannot.
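
For reference, a minimal sketch of the handle fix being suggested (the include path can vary between PyTorch versions, and this is not the PR's final code):

    #include <ATen/cuda/CUDAContext.h>

    // hipBLASLt entry points expect a hipblasLtHandle_t, so fetch the Lt handle
    // that PyTorch already manages instead of the plain hipBLAS handle.
    hipblasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle();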

hipblaslt_ext::Gemm gemm(
handle, transa, transb, HIP_R_16BF, HIP_R_16BF, HIP_R_32F, HIP_R_32F, HIPBLAS_COMPUTE_32F);

hipblaslt_ext::GemmEpilogue

This is using the hipblasLt Ext API. We haven't used that much; we mostly use the plain hipblasLt API. I will need to check with the hipblasLt team regarding the difference (and its performance implications). And since the hipblasLt API is more commonly used, I expect it to be more stable and have fewer bugs.
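
For reference, a rough sketch of what the plain (non-ext) hipblasLt call sequence looks like for this GEMM. The CHECK_HIPBLASLT_ERROR macro and the bf16-in / fp32-out types are taken from this PR; the operand names (d_a, d_b, d_weight), dimensions (m, n, k), alpha/beta, workspace, and stream are placeholders, not the final implementation:

    hipblasOperation_t transa = HIPBLAS_OP_T;
    hipblasOperation_t transb = HIPBLAS_OP_N;

    // Matmul descriptor with fp32 compute and fp32 scale type.
    hipblasLtMatmulDesc_t matmul;
    CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
    CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
        matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
    CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
        matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));

    // One layout per operand: bf16 inputs, fp32 accumulation/output.
    hipblasLtMatrixLayout_t layoutA, layoutB, layoutC;
    CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&layoutA, HIP_R_16BF, k, m, k));
    CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&layoutB, HIP_R_16BF, k, n, k));
    CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&layoutC, HIP_R_32F, m, n, m));

    // Heuristic query followed by the matmul itself.
    hipblasLtMatmulPreference_t pref;
    CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref));
    CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceSetAttribute(
        pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));

    hipblasLtMatmulHeuristicResult_t heuristicResult;
    int returnedAlgoCount = 0;
    CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(
        handle, matmul, layoutA, layoutB, layoutC, layoutC,
        pref, 1, &heuristicResult, &returnedAlgoCount));

    // d_weight is used as both C and D so the gradient is accumulated in place.
    CHECK_HIPBLASLT_ERROR(hipblasLtMatmul(
        handle, matmul, &alpha, d_a, layoutA, d_b, layoutB, &beta,
        d_weight, layoutC, d_weight, layoutC, &heuristicResult.algo,
        workspace, workspace_size, stream));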

@eliotwang (Author)

update summary:

  1. replace the hipblaslt_ext API with the hipblaslt API
  2. create and destroy the handle per call
  3. add test_weight_grad.py to tests/L0/run_transformer/, following the conventions of the existing tests there

hipblasLtMatmulDesc_t matmul;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(int32_t)));

Not sure whether transa is actually an int32_t, but I guess we can just use sizeof(transa).
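
A minimal sketch of the suggested tweak (assuming transa is declared as hipblasOperation_t, which is the type hipblasLt expects for this attribute):

    // sizeof(transa) stays correct even if the declared type of transa changes later.
    CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
        matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));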

const int request_solutions = 1;
hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
int returnedAlgoCount = 0;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle,

@jeffdaily Do we do any autotuning on the hipblasLt path? I am wondering if we need to do autotuning here.
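
For illustration, requesting more than one candidate from the heuristic is what a simple autotuning pass could start from; the layout and preference names below are placeholders, not this PR's code:

    // Ask the heuristic for several candidate algorithms instead of just one; an
    // autotuning pass could time each candidate and cache the fastest, whereas the
    // current code simply takes the first result.
    const int request_solutions = 8;  // illustrative count
    hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
    int returnedAlgoCount = 0;
    CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(
        handle, matmul, layoutA, layoutB, layoutC, layoutC,
        pref, request_solutions, heuristicResult, &returnedAlgoCount));
    // heuristicResult[0..returnedAlgoCount-1] hold the library's ranked suggestions.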

import math, pdb
from torch.testing._internal import common_utils

torch.backends.cuda.matmul.allow_tf32 = False

@jeffdaily Is this already False by default on ROCm?
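
For a quick check of the effective default on a given build (purely illustrative):

    import torch

    # Print the current TF32 flags; whether they have any effect on a given
    # ROCm GPU is a separate question, but this shows what the build defaults to.
    print("matmul allow_tf32:", torch.backends.cuda.matmul.allow_tf32)
    print("cudnn  allow_tf32:", torch.backends.cudnn.allow_tf32)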

@eliotwang (Author)

update summary:

  1. Use getCurrentCUDABlasLtHandle to get the hipblasLt handle;
  2. Update the test files to fit the unittest infrastructure and style used by Apex: https://github.com/eliotwang/apex/blob/heyi_fused_grad_accumulation/tests/L0/run_transformer/test_weight_grad.py
    run the test with: python tests/L0/run_test.py --include run_transformer
  3. Replace the custom cosine-similarity check with torch.allclose and set a tolerance threshold for each test case;
  4. Set TENSILE_DB=0x8000 to distinguish which kernel is called through hipblas vs. hipblaslt;
    hipblas:
    Cijk_Ailk_Bjlk_BBS_BH_MT128x64x16_MI16x16x16x1_SN_1LDSB0_APM1_ABV0_ACED0_AF0EM1_AF1EM1_AMAS3_ASE_ASGT_ASLT_ASM_ASAE01_ASCE01_ASEM1_AAC0_BL1_BS1_CLR1_DTLA0_DTLB0_DTVA0_DTVB0_DVO0_ETSP_EPS1_ELFLR0_EMLL0_FSSC10_FL0_GLVWA4_GLVWB2_GRCGA1_GRCGB1_GRPM1_GRVWn1_GSU1_GSUASB_GLS0_ISA942_IU1_K1_KLA_LBSPPA0_LBSPPB512_LPA0_LPB16_LDL1_LRVW4_LWPMn1_LDW0_FMA_MIAV0_MDA2_MO40_MMFSC_MKFGSU256_NTA0_NTB0_NTC0_NTD0_NEPBS0_NLCA2_NLCB2_ONLL1_OPLV0_PK0_PAP0_PGR2_PLR1_PKA0_SIA3_SLW1_SS1_SU32_SUM0_SUS256_SCIUI1_SPO0_SRVW0_SSO0_SVW4_SNLL0_TSGRA0_TSGRB0_TT8_16_TLDS0_UMLDSA0_UMLDSB0_USFGROn1_VAW1_VSn1_VW4_VWB1_VFLRP1_WSGRA0_WSGRB0_WS64_WG16_16_1_WGMn16
    hipblaslt:
    Cijk_Ailk_Bjlk_BBS_BH_Bias_AS_SAV_UserArgs_MT64x32x32_MI16x16x1_SN_LDSB0_AFC1_AFEM1_AFEM1_ASEM1_CLR1_CADS0_EPS0_GRVWA8_GRVWB4_GSUAMB_ISA942_IU1_K1_LBSPPA512_LBSPPB256_LBSPPM0_LPA32_LPB16_LPM0_LRVW4_LWPMn1_MIAV0_MIWT2_1_MO40_NTn1_NTA0_NTB0_NTC0_NTD0_NTM0_NEPBS16_NLCA1_NLCB1_ONLL1_PGR2_PLR1_PKA1_SIA3_SS1_SPO0_SRVW0_SSO0_SVW2_TLDS0_USFGROn1_VSn1_VWA2_VWB1_WSGRA0_WSGRB0_WS64_WG32_8_1
  5. Remove dc_tensor and use d_weight as both the input and the output;
  6. Remove the batch_count > 1 case;
  7. Replace sizeof(int32_t) with sizeof(transa);
  8. Rename variables to "gradient of output" and "gradient of weight" to avoid confusion.

else:
    print("========FAIL======")

grad_weight = grad_weight.view(-1)

We don't need to reshape the tensor now as torch.allclose can take 2d tensors.
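
A minimal sketch of the comparison without the reshape (tensor names, shapes, and tolerances are illustrative):

    import torch

    # grad_weight: result from the fused path; ref_grad_weight: reference matmul.
    # Both can stay 2-D; no .view(-1) is needed for torch.allclose.
    grad_weight = torch.randn(256, 512)
    ref_grad_weight = grad_weight + 1e-5 * torch.randn_like(grad_weight)

    assert torch.allclose(grad_weight, ref_grad_weight, rtol=1e-3, atol=1e-3)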

@wenchenvincent

@pruthvistony @jithunnair-amd Do we have any CI setup for Apex?

@pragupta

@pruthvistony @jithunnair-amd Do we have any CI setup for Apex?

@wenchenvincent -- there's no CI for apex.

@pragupta

Regarding the tolerance for the UTs, I looked at the matmul UTs in PyTorch and it seems there are toleranceOverride decorators which define the tolerance level for various operators and dtypes. Please see lines like these in this file: https://github.com/pytorch/pytorch/blob/68272ab5967f448ed6d2986039a0ef0ddf0e1b37/test/test_matmul_cuda.py#L119
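
For reference, this is roughly the pattern those PyTorch tests use; the test body, class name, and tolerance values below are made up for illustration, and the decorators require PyTorch's device-type test framework (instantiate_device_type_tests) rather than plain unittest:

    import torch
    from torch.testing._internal.common_device_type import (
        dtypes,
        instantiate_device_type_tests,
        onlyCUDA,
        tol,
        toleranceOverride,
    )
    from torch.testing._internal.common_utils import TestCase, run_tests


    class TestWeightGradTolerance(TestCase):
        # Loosen the comparison only for the low-precision dtypes; fp32 keeps the
        # framework's default precision for assertEqual.
        @onlyCUDA
        @toleranceOverride({torch.bfloat16: tol(atol=1e-2, rtol=1e-2),
                            torch.float16: tol(atol=1e-3, rtol=1e-3)})
        @dtypes(torch.bfloat16, torch.float16, torch.float32)
        def test_weight_grad_matches_reference(self, device, dtype):
            grad_output = torch.randn(128, 256, device=device, dtype=dtype)
            inputs = torch.randn(128, 512, device=device, dtype=dtype)
            grad_weight = grad_output.t() @ inputs
            reference = (grad_output.float().t() @ inputs.float()).to(dtype)
            # assertEqual picks up the dtype-specific atol/rtol from the
            # toleranceOverride decorator above.
            self.assertEqual(grad_weight, reference)


    instantiate_device_type_tests(TestWeightGradTolerance, globals())

    if __name__ == "__main__":
        run_tests()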
