Heyi fused grad accumulation #138

Open: wants to merge 6 commits into base: master
Changes from 5 commits
15 changes: 15 additions & 0 deletions README_heyi_fused_grad_accumulation.md
@@ -0,0 +1,15 @@
docker pull rocm/pytorch:latest

docker run -it --network=host -v /home/yigex/huggingface/heyi:/workspace --device=/dev/kfd --device=/dev/dri --group-add video --security-opt seccomp=unconfined --ipc=host --name heyi_apex_wx rocm/pytorch:latest

mkdir heyi && cd heyi

git clone -b heyi_fused_grad_accumulation https://github.com/eliotwang/apex

cd apex

python setup.py install --cpp_ext --cuda_ext

python tests/L0/run_transformer/test_weight_grad.py
or
python tests/L0/run_test.py --include run_transformer
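
A rough sketch of the check the weight-grad test is expected to perform. The extension module and function names below (fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16) and the tensor shapes are assumptions based on apex's existing Megatron bindings, not something this PR spells out; the fused kernel is expected to accumulate d_weight += d_output^T @ input in place.

import torch
import fused_weight_gradient_mlp_cuda  # assumed module name, built above with --cpp_ext --cuda_ext

hidden_dim, in_dim, out_dim = 1024, 512, 256   # k, m, n of the underlying GEMM
inp = torch.randn(hidden_dim, in_dim, dtype=torch.bfloat16, device="cuda")
d_output = torch.randn(hidden_dim, out_dim, dtype=torch.bfloat16, device="cuda")
d_weight = torch.randn(out_dim, in_dim, dtype=torch.bfloat16, device="cuda")

# Reference in FP32: the op adds d_output^T @ input into the existing d_weight
expected = d_weight.float() + d_output.float().t() @ inp.float()
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(inp, d_output, d_weight)
torch.testing.assert_close(d_weight.float(), expected, rtol=1e-2, atol=1e-2)
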
308 changes: 226 additions & 82 deletions csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu
@@ -10,113 +10,257 @@
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "type_shim.h"

/* Includes, HIP */
#include <hipblaslt/hipblaslt.h>
#include <hipblaslt/hipblaslt-ext.hpp>

#ifndef CHECK_HIP_ERROR
#define CHECK_HIP_ERROR(error) \
if(error != hipSuccess) \
{ \
fprintf(stderr, \
"Hip error: '%s'(%d) at %s:%d\n", \
hipGetErrorString(error), \
error, \
__FILE__, \
__LINE__); \
exit(EXIT_FAILURE); \
}
#endif
#ifndef CHECK_HIPBLASLT_ERROR
#define CHECK_HIPBLASLT_ERROR(error) \
if(error != HIPBLAS_STATUS_SUCCESS) \
{ \
fprintf(stderr, "hipBLASLt error(Err=%d) at %s:%d\n", error, __FILE__, __LINE__); \
fprintf(stderr, "\n"); \
exit(EXIT_FAILURE); \
}
#endif

// BF16 inputs and BF16 accumulation
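// The hipBLASLt path below computes D = alpha * op(A) * op(B) + beta * C in column-major
// order with FP32 compute (HIPBLAS_COMPUTE_32F); the caller passes the same buffer for C
// and D, so with beta = 1 the GEMM accumulates into the weight-gradient tensor in place.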
void gemmex_wrapper_fp16(
-cublasHandle_t handle,
-cublasOperation_t transa,
-cublasOperation_t transb,
hipblasLtHandle_t handle,
hipblasOperation_t transa,
hipblasOperation_t transb,
int m,
int n,
int k,
-const float* alpha,
int batch_count,
float& alpha,
float& beta,
at::BFloat16* A,
-int lda,
at::BFloat16* B,
-int ldb,
-const float* beta,
at::BFloat16* C,
-int ldc) {
-TORCH_CUDABLAS_CHECK(cublasGemmEx(
-handle,
-transa,
-transb,
-m,
-n,
-k,
-alpha,
-A,
-CUDA_R_16BF,
-lda,
-B,
-CUDA_R_16BF,
-ldb,
-beta,
-C,
-CUDA_R_16BF,
-ldc,
-CUDA_R_32F,
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
at::BFloat16* D,
void* d_workspace,
int64_t max_workspace_size,
hipStream_t stream)
{
hipblasLtMatrixLayout_t matA, matB, matC, matD;
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIP_R_16BF, m, k, m));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIP_R_16BF, n, k, n));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIP_R_16BF, m, n, m));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIP_R_16BF, m, n, m));

hipblasLtMatmulDesc_t matmul;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));

hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));

// Set User Preference attributes
hipblasLtMatmulPreference_t pref;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref));
CHECK_HIPBLASLT_ERROR(
hipblasLtMatmulPreferenceSetAttribute(pref,
HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size,
sizeof(max_workspace_size)));

const int request_solutions = 1;
hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
int returnedAlgoCount = 0;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle,
matmul,
matA,
matB,
matC,
matD,
pref,
request_solutions,
heuristicResult,
&returnedAlgoCount));

if(returnedAlgoCount == 0)
{
std::cerr << "No valid solution found!" << std::endl;
return;
}

uint64_t workspace_size = 0;
for(int i = 0; i < returnedAlgoCount; i++)
workspace_size = max(workspace_size, heuristicResult[i].workspaceSize);

CHECK_HIPBLASLT_ERROR(hipblasLtMatmul(handle,
matmul,
&alpha,
A,
matA,
B,
matB,
&beta,
C,
matC,
D,
matD,
&heuristicResult[0].algo,
d_workspace,
workspace_size,
stream));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matA));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matB));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matC));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matD));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceDestroy(pref));
return;
}

// FP16 inputs and FP16 accumulation
void gemmex_wrapper_fp16(
-cublasHandle_t handle,
-cublasOperation_t transa,
-cublasOperation_t transb,
hipblasLtHandle_t handle,
hipblasOperation_t transa,
hipblasOperation_t transb,
int m,
int n,
int k,
-const float* alpha,
int batch_count,
float& alpha,
float& beta,
at::Half* A,
-int lda,
at::Half* B,
-int ldb,
-const float* beta,
at::Half* C,
-int ldc) {
-TORCH_CUDABLAS_CHECK(cublasGemmEx(
-handle,
-transa,
-transb,
-m,
-n,
-k,
-alpha,
-A,
-CUDA_R_16F,
-lda,
-B,
-CUDA_R_16F,
-ldb,
-beta,
-C,
-CUDA_R_16F,
-ldc,
-CUDA_R_32F,
-CUBLAS_GEMM_DEFAULT_TENSOR_OP));
at::Half* D,
void* d_workspace,
int64_t max_workspace_size,
hipStream_t stream)
{
hipblasLtMatrixLayout_t matA, matB, matC, matD;
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIP_R_16F, m, k, m));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIP_R_16F, n, k, n));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIP_R_16F, m, n, m));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIP_R_16F, m, n, m));

hipblasLtMatmulDesc_t matmul;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));

hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));

// Set User Preference attributes
hipblasLtMatmulPreference_t pref;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref));
CHECK_HIPBLASLT_ERROR(
hipblasLtMatmulPreferenceSetAttribute(pref,
HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size,
sizeof(max_workspace_size)));

const int request_solutions = 1;
hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
int returnedAlgoCount = 0;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle,
matmul,
matA,
matB,
matC,
matD,
pref,
request_solutions,
heuristicResult,
&returnedAlgoCount));

if(returnedAlgoCount == 0)
{
std::cerr << "No valid solution found!" << std::endl;
return;
}

uint64_t workspace_size = 0;
for(int i = 0; i < returnedAlgoCount; i++)
workspace_size = max(workspace_size, heuristicResult[i].workspaceSize);

CHECK_HIPBLASLT_ERROR(hipblasLtMatmul(handle,
matmul,
&alpha,
A,
matA,
B,
matB,
&beta,
C,
matC,
D,
matD,
&heuristicResult[0].algo,
d_workspace,
workspace_size,
stream));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matA));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matB));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matC));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutDestroy(matD));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescDestroy(matmul));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceDestroy(pref));
return;
}

template <typename T>
-void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) {
-cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-cudaStream_t stream;
-cublasGetStream(handle, &stream);
-const float alpha = 1.0;
-const float beta = 1.0;
void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) {

hipblasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle();
hipStream_t stream = at::cuda::getCurrentCUDAStream();
float alpha = 1.0;
float beta = 1.0;
const int batch_count = 1;
void* d_workspace;
int64_t max_workspace_size = 32*1024*1024;
if(max_workspace_size > 0)
CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size));
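// Column-major mapping: A = input (m=in_dim x k=hidden_dim, op N), B = d_output
// (n=out_dim x k=hidden_dim, op T), C = D = d_weight, and alpha = beta = 1, so the
// call below accumulates d_weight += d_output^T * input without a separate add.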
gemmex_wrapper_fp16(
handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
-in_dim,
-out_dim,
-hidden_dim,
-&alpha,
-input,
-in_dim,
-d_output,
-out_dim,
-&beta,
-d_weight,
-in_dim);
HIPBLAS_OP_N,
HIPBLAS_OP_T,
in_dim, //m
out_dim, //n
hidden_dim, //k
batch_count,
alpha,
beta,
input, //da
d_output, //db
d_weight, //dc
d_weight, //dd
d_workspace,
max_workspace_size,
stream);

}

template void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim);

void wgrad_gemm_accum_fp16_cuda_stub(
at::Tensor &input,
@@ -139,17 +283,17 @@ void wgrad_gemm_accum_fp16_cuda_stub(
d_output_2d = d_output;
}

-const int hidden_dim = input_2d.size(0);
-const int in_dim = input_2d.size(1);
-const int out_dim = d_weight.size(0);
const int hidden_dim = input_2d.size(0); //k
const int in_dim = input_2d.size(1); //m
const int out_dim = d_weight.size(0); //n

DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp16",
wgrad_gemm_accum_fp16_cuda<scalar_t>(
input_2d.data_ptr<scalar_t>(),
d_output_2d.data_ptr<scalar_t>(),
d_weight.data_ptr<scalar_t>(),
in_dim,
hidden_dim,
out_dim);
);
}
}