Heyi fused grad accumulation #138

Open · wants to merge 6 commits into base: master

Changes from 1 commit
2 changes: 2 additions & 0 deletions README_heyi_fused_grad_accumulation.md
@@ -11,3 +11,5 @@ cd apex
 python setup.py install --cpp_ext --cuda_ext

 python tests/L0/run_transformer/test_weight_grad.py
+or
+python tests/L0/run_test.py --include run_transformer
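
(Note: the op under test computes the weight gradient of a linear layer and accumulates it in place: with `input` collapsed to `[hidden_dim, in_dim]` and `d_output` to `[hidden_dim, out_dim]`, it performs `d_weight += d_output.T @ input`. A minimal PyTorch sketch of the equivalence such a test can check — the module and function names `fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16` are assumed from apex's build, not shown in this diff:

```python
import torch
import fused_weight_gradient_mlp_cuda  # assumed extension name, built by setup.py above

k, m, n = 1024, 512, 256  # hidden_dim, in_dim, out_dim
inp   = torch.randn(k, m, dtype=torch.bfloat16, device="cuda")
d_out = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
d_w   = torch.randn(n, m, dtype=torch.bfloat16, device="cuda")

# The fused kernel accumulates d_w += d_out^T @ inp (alpha = beta = 1).
ref = d_w.float() + d_out.t().float() @ inp.float()

fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(inp, d_out, d_w)  # assumed entry point
torch.testing.assert_close(d_w.float(), ref, rtol=2e-2, atol=2e-2)
```
)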
248 changes: 19 additions & 229 deletions csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu
@@ -64,36 +64,12 @@ void gemmex_wrapper_fp16(
   CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIP_R_16BF, m, n, m));
   CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIP_R_16BF, m, n, m));

-  if(batch_count > 1)
-  {
-      int64_t stride_a = m * k;
-      int64_t stride_b = k * n;
-      int64_t stride_c = m * n;
-      int64_t stride_d = m * n;
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matA, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matA, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matB, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matB, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matC, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matC, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matD, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matD, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_d, sizeof(stride_d)));
-  }
-
   hipblasLtMatmulDesc_t matmul;
   CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
   CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
-      matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(int32_t)));
+      matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
   CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
-      matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(int32_t)));
+      matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));

   hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
   CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
@@ -131,9 +107,6 @@ void gemmex_wrapper_fp16(
   uint64_t workspace_size = 0;
   for(int i = 0; i < returnedAlgoCount; i++)
       workspace_size = max(workspace_size, heuristicResult[i].workspaceSize);
-  // In this sample, the workspace is already allocated with max_workspace_size
-  // If not, allocate d_workspace here
-  // CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size));

   CHECK_HIPBLASLT_ERROR(hipblasLtMatmul(handle,
                                         matmul,
@@ -185,36 +158,12 @@ void gemmex_wrapper_fp16(
   CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIP_R_16F, m, n, m));
   CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIP_R_16F, m, n, m));

-  if(batch_count > 1)
-  {
-      int64_t stride_a = m * k;
-      int64_t stride_b = k * n;
-      int64_t stride_c = m * n;
-      int64_t stride_d = m * n;
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matA, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matA, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matB, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matB, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matC, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matC, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matD, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
-      CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute(
-          matD, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_d, sizeof(stride_d)));
-  }
-
   hipblasLtMatmulDesc_t matmul;
   CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
   CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
-      matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(int32_t)));
+      matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
   CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
-      matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(int32_t)));
+      matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));

   hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
   CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
@@ -252,9 +201,6 @@ void gemmex_wrapper_fp16(
   uint64_t workspace_size = 0;
   for(int i = 0; i < returnedAlgoCount; i++)
       workspace_size = max(workspace_size, heuristicResult[i].workspaceSize);
-  // In this sample, the workspace is already allocated with max_workspace_size
-  // If not, allocate d_workspace here
-  // CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size));

   CHECK_HIPBLASLT_ERROR(hipblasLtMatmul(handle,
                                         matmul,
@@ -281,19 +227,11 @@ void gemmex_wrapper_fp16(
   return;
 }

-
-//hipblasLtHandle_t g_hipblas_handle = nullptr;
-
 template <typename T>
-void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *dc_tensor, T *d_weight, int in_dim, int hidden_dim, int out_dim) {
-  //hipblasLtHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  //hipStream_t stream;
-  //hipblasGetStream(handle, &stream);
-  hipblasLtHandle_t handle;
-  hipStream_t stream;
-  CHECK_HIP_ERROR(hipStreamCreate(&stream));
-  CHECK_HIPBLASLT_ERROR(hipblasLtCreate(&handle));
+void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) {

+  hipblasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle();
+  hipStream_t stream = at::cuda::getCurrentCUDAStream();
   float alpha = 1.0;
   float beta = 1.0;
   const int batch_count = 1;
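
(Note: the hunk above is the main lifecycle fix — instead of creating a fresh `hipStream_t` and `hipblasLtHandle_t` on every call, the kernel now launches on the stream and handle PyTorch is already using, so it is ordered with surrounding torch ops for free. A toy PyTorch sketch of the hazard a private per-call stream creates — illustration only, not code from this PR:

```python
import torch

x = torch.randn(4096, 4096, device="cuda")    # written on the current stream

side = torch.cuda.Stream()                    # analogous to the old per-call hipStreamCreate
side.wait_stream(torch.cuda.current_stream()) # without this fence, the matmul below
with torch.cuda.stream(side):                 # may read x before randn has finished
    y = x @ x
torch.cuda.current_stream().wait_stream(side) # later torch ops must also wait on `side`

# Running on torch.cuda.current_stream(), as the new code does, makes both
# fences unnecessary.
```
)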
@@ -305,26 +243,24 @@ void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *dc_tensor, T *d_weigh
       handle,
       HIPBLAS_OP_N,
       HIPBLAS_OP_T,
-      in_dim, //m
-      out_dim, //n
-      hidden_dim, //k
+      in_dim,     //m
+      out_dim,    //n
+      hidden_dim, //k
       batch_count,
       alpha,
       beta,
-      input, //da
-      d_output, //db
-      dc_tensor, //dc
-      d_weight, //dd
+      input,    //da
+      d_output, //db
+      d_weight, //dc
+      d_weight, //dd
       d_workspace,
       max_workspace_size,
       stream);

-  CHECK_HIPBLASLT_ERROR(hipblasLtDestroy(handle));
-  CHECK_HIP_ERROR(hipStreamDestroy(stream));
 }

-template void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half *input, at::Half *d_output, at::Half *dc_tensor, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim);
-template void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *dc_tensor, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim);
+template void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim);
+template void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim);

 void wgrad_gemm_accum_fp16_cuda_stub(
     at::Tensor &input,
@@ -347,163 +283,17 @@ void wgrad_gemm_accum_fp16_cuda_stub(
     d_output_2d = d_output;
   }

-  at::Tensor dc_tensor = at::empty_like(d_weight);
-  dc_tensor.copy_(d_weight);
-  //at::Tensor dst_tensor = at::zeros_like(d_weight);
-
   const int hidden_dim = input_2d.size(0); //k
-  const int in_dim = input_2d.size(1); //m
-  const int out_dim = d_weight.size(0); //n
+  const int in_dim = input_2d.size(1);     //m
+  const int out_dim = d_weight.size(0);    //n

   DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp16",
     wgrad_gemm_accum_fp16_cuda<scalar_t>(
       input_2d.data_ptr<scalar_t>(),
       d_output_2d.data_ptr<scalar_t>(),
-      dc_tensor.data_ptr<scalar_t>(),
       d_weight.data_ptr<scalar_t>(),
       in_dim,
       hidden_dim,
       out_dim);
   );
 }
-/*
-// BF16 inputs and BF16 accumulation
-void gemmex_wrapper_fp16(
-    cublasHandle_t handle,
-    cublasOperation_t transa,
-    cublasOperation_t transb,
-    int m,
-    int n,
-    int k,
-    const float* alpha,
-    at::BFloat16* A,
-    int lda,
-    at::BFloat16* B,
-    int ldb,
-    const float* beta,
-    at::BFloat16* C,
-    int ldc) {
-  TORCH_CUDABLAS_CHECK(cublasGemmEx(
-      handle,
-      transa,
-      transb,
-      m,
-      n,
-      k,
-      alpha,
-      A,
-      CUDA_R_16BF,
-      lda,
-      B,
-      CUDA_R_16BF,
-      ldb,
-      beta,
-      C,
-      CUDA_R_16BF,
-      ldc,
-      CUDA_R_32F,
-      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-}
-
-// FP16 inputs and FP16 accumulation
-void gemmex_wrapper_fp16(
-    cublasHandle_t handle,
-    cublasOperation_t transa,
-    cublasOperation_t transb,
-    int m,
-    int n,
-    int k,
-    const float* alpha,
-    at::Half* A,
-    int lda,
-    at::Half* B,
-    int ldb,
-    const float* beta,
-    at::Half* C,
-    int ldc) {
-  TORCH_CUDABLAS_CHECK(cublasGemmEx(
-      handle,
-      transa,
-      transb,
-      m,
-      n,
-      k,
-      alpha,
-      A,
-      CUDA_R_16F,
-      lda,
-      B,
-      CUDA_R_16F,
-      ldb,
-      beta,
-      C,
-      CUDA_R_16F,
-      ldc,
-      CUDA_R_32F,
-      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-}
-
-template <typename T>
-void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) {
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  cudaStream_t stream;
-  cublasGetStream(handle, &stream);
-  const float alpha = 1.0;
-  const float beta = 1.0;
-
-  gemmex_wrapper_fp16(
-      handle,
-      CUBLAS_OP_N,
-      CUBLAS_OP_T,
-      in_dim,
-      out_dim,
-      hidden_dim,
-      &alpha,
-      input,
-      in_dim,
-      d_output,
-      out_dim,
-      &beta,
-      d_weight,
-      in_dim);
-}
-
-template void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim);
-template void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim);
-
-void wgrad_gemm_accum_fp16_cuda_stub(
-    at::Tensor &input,
-    at::Tensor &d_output,
-    at::Tensor &d_weight
-) {
-  at::Tensor input_2d, d_output_2d;
-  // input tensor: collapse to the first dim
-  auto in_sizes = input.sizes();
-  if (input.dim() > 2) {
-    input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
-  } else {
-    input_2d = input;
-  }
-  // d_output tensor: collapse to the first dim
-  auto d_out_sizes = d_output.sizes();
-  if (d_output.dim() > 2) {
-    d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
-  } else {
-    d_output_2d = d_output;
-  }
-
-  const int hidden_dim = input_2d.size(0); //k
-  const int in_dim = input_2d.size(1); //m
-  const int out_dim = d_weight.size(0); //n
-
-  DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp16",
-    wgrad_gemm_accum_fp16_cuda<scalar_t>(
-      input_2d.data_ptr<scalar_t>(),
-      d_output_2d.data_ptr<scalar_t>(),
-      d_weight.data_ptr<scalar_t>(),
-      in_dim,
-      hidden_dim,
-      out_dim);
-  );
-}
-*/
 }
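
(Net effect of this file's changes: with `alpha = beta = 1` and `d_weight` bound as both the C and D matrices, hipBLASLt computes `D = alpha·op(A)·op(B) + beta·C`, so the new gradient is accumulated into `d_weight` in place. That is what makes the `dc_tensor` copy (`at::empty_like` + `copy_`) deletable, saving an allocation plus a full read and write of the weight-gradient buffer on every call. A minimal sketch of the resulting semantics in PyTorch terms — an illustration using the stub's shape names, not the extension's API:

```python
import torch

def wgrad_gemm_accum_ref(input_2d, d_output_2d, d_weight):
    """Reference semantics of the fused hipBLASLt call.

    input_2d:    [hidden_dim, in_dim]   -> A (op N in the column-major GEMM)
    d_output_2d: [hidden_dim, out_dim]  -> B (op T)
    d_weight:    [out_dim,   in_dim]    -> C and D (beta = 1 accumulates)
    """
    # D = 1 * op(A) @ op(B) + 1 * C, realized as an in-place add on d_weight:
    d_weight.add_(d_output_2d.t() @ input_2d)
    return d_weight
```
)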