From aff6786f7ab44b42b8f246c944e84f5c9096fbe0 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Tue, 7 May 2024 18:07:28 +0000 Subject: [PATCH 01/34] [Kernel] Add GPU kernels. --- src/common/allocator.h | 26 ++- src/common/transformer_ctx.h | 9 +- src/layers/attention.h | 4 + src/layers/dist_linear.h | 15 +- src/layers/mlp_llama.h | 7 +- src/layers/rotary_embedding.cpp | 75 ++++++++ src/layers/rotary_embedding.h | 1 + src/models/common_decoder.h | 11 +- src/utils/matmul_helper.h | 300 +++++++++++++++++++++++++++++--- src/utils/simple_mem_pool.h | 13 +- 10 files changed, 423 insertions(+), 38 deletions(-) diff --git a/src/common/allocator.h b/src/common/allocator.h index 2e454772..afd8d1a7 100644 --- a/src/common/allocator.h +++ b/src/common/allocator.h @@ -18,6 +18,10 @@ #include #include "environment.h" +#ifdef GPU +#include +#endif + namespace xft { constexpr size_t g_thp_threshold = (size_t)2 * 1024 * 1024; @@ -26,11 +30,18 @@ static inline bool is_thp_alloc(size_t nbytes) { return (Env::getInstance().getTHPEnabled() && (nbytes >= g_thp_threshold)); } -static inline void *alloc(size_t nbytes, size_t alignment = 64) { +static inline void *alloc(size_t nbytes, size_t alignment = 64, void *device = nullptr) { if (nbytes == 0) { return nullptr; } void *data; +#ifdef GPU + if (device != nullptr) { + data = sycl::malloc_device(nbytes, *static_cast(device)); + return data; + } +#endif + int err = posix_memalign(&data, alignment, nbytes); if (err != 0) { printf("Unable to allocate buffer with size of %zu, err=%d\n", nbytes, err); @@ -47,4 +58,17 @@ static inline void *alloc(size_t nbytes, size_t alignment = 64) { return data; } + +static inline void dealloc(void *data, void *device = nullptr) { +#ifdef GPU + if (device != nullptr) { + sycl::free(data, *static_cast(device)); + return; + } +#endif + + free(data); + return; +} + } // namespace xft \ No newline at end of file diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index 3685baae..e955bf2f 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -111,6 +111,7 @@ struct DecoderContext { hpj::Matrix imOut; // intermediate output MMHelper *mmHelper; + void *device; std::string configPath; INIReader configReader; @@ -238,8 +239,12 @@ struct DecoderContext { bool cached(const std::string &name) { return SimpleMemPool::instance().cached(name); } template - T *getBuffer(const std::string &name, size_t size, size_t alignment = 64) { - return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, alignment); + T *getBuffer(const std::string &name, size_t size, void *device = nullptr, size_t alignment = 64) { + return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, device, alignment); + } + + void freeBuffer(const std::string &name, void *device = nullptr) { + SimpleMemPool::instance().freeBuffer(name, device); } void dump() { diff --git a/src/layers/attention.h b/src/layers/attention.h index 568cbcfc..d9aeb83a 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -294,6 +294,10 @@ class Attention { std::iota(posIds.begin(), posIds.end(), pastSeqLen); } qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, posIds.data()); +#ifdef GPU + sycl::queue *gpu_queue = static_cast(ctx->device); + gpu_queue->memcpy(qkvMatMul.Data(), query.Data(), ctx->batchSize * ctx->inputSeqLen * qkvCols * sizeof(float)).wait(); +#endif } t3.release(); diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h index 6bd581a6..c5569571 100644 --- 
a/src/layers/dist_linear.h +++ b/src/layers/dist_linear.h @@ -59,14 +59,27 @@ class DistLinear { int K = inputSize; int N = this->splitSize; - weight.Resize(K, N); + scaleWeight.Resize(N); zeroWeight.Resize(N); hpj::Matrix quantizedWeight; ctx->mmHelper->convertWeight( true, K, N, w + splitOffset * K, nullptr, nullptr, quantizedWeight, scaleWeight, zeroWeight, sumWeight); +#ifdef GPU + hpj::Matrix tWeight; + tWeight.Resize(K, N); + ctx->mmHelper->transposeWeight(true, quantizedWeight, tWeight); + + sycl::queue *gpu_queue = static_cast(ctx->device); + WeiT *input_data = sycl::malloc_device(K * N, *gpu_queue); + weight.Assign(input_data, K, N, N); + gpu_queue->memcpy(weight.Data(), tWeight.Data(), tWeight.Rows() * tWeight.Cols() * sizeof(WeiT)) + .wait(); +#else + weight.Resize(K, N); ctx->mmHelper->packWeight(true, quantizedWeight, weight); +#endif // Copy Bias if (b) { diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index d3f102eb..42789085 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -275,8 +275,7 @@ class LlamaMLP : public SingletonBase> { } } - template - void catGateUpProj(DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output, hpj::Matrix &siluBuf) { + void catGateUpProj(DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output, hpj::Matrix &siluBuf) { TimeLine t("catGateUpProj"); assert(input.Rows() == output.Rows()); @@ -286,12 +285,12 @@ class LlamaMLP : public SingletonBase> { int M = input.Rows(), N = output.Cols(), K = input.Cols(); int lda = input.Stride(), ldc = output.Stride(); - const T1 *A = input.Data(); + const InT *A = input.Data(); const WeiT *B = catWeights.Data(); const float *scaleB = catWeightsScale.Data(); const float *zeroB = catWeightsZero.Data(); const float *sumB = catWeightsSum.Data(); - T2 *C = output.Data(); + ImT *C = output.Data(); ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc); diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index 6e495d28..e4745dbe 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -42,6 +42,19 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { inv_freq[i] = 1.0 / pow(base, float(i * 2) / dim); } llamaCalEmb(inv_freq, max_position_embeddings); +#ifdef GPU + if (device != nullptr) { + sycl::queue *gpu_queue = static_cast(device); + float *emb_cos_bak = emb_cos; + float *emb_sin_bak = emb_sin; + emb_cos = ctx->getBuffer(emb_cos_str + "_gpu", max_position_embeddings * inv_freq_size, gpu_queue); + emb_sin = ctx->getBuffer(emb_sin_str + "_gpu", max_position_embeddings * inv_freq_size, gpu_queue); + gpu_queue->memcpy(emb_cos, emb_cos_bak, max_position_embeddings * inv_freq_size * sizeof(float)).wait(); + gpu_queue->memcpy(emb_sin, emb_sin_bak, max_position_embeddings * inv_freq_size * sizeof(float)).wait(); + ctx->freeBuffer(emb_cos_str); + ctx->freeBuffer(emb_sin_str); + } +#endif } else if (dim != inv_freq_size * 2) { printf("Incorrect dim=%d, inv_freq_size=%d\n", dim, inv_freq_size); exit(-1); @@ -112,6 +125,66 @@ void LlamaRotaryEmbedding::llamaCalEmb(const float *inv_freq, const int max_posi // |_____| |_____| // head_size/2 head_size/2 +#ifdef GPU + +void LlamaRotaryEmbedding::forward( + float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { + const int batchSize = qkShape[0]; + const int seqLen = qkShape[1]; + const int qHeads = qkShape[2]; + const int kHeads = qkShape[4]; + const int head_num = std::max(qHeads, kHeads); + 
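// This GPU implementation maps one work-item to each (batch*seq, head, half-dim) triple:
// a 3D nd_range of (batchSize * seqLen, max(qHeads, kHeads), (head_size + 1) / 2). Every
// work-item fetches the cos/sin pair for its position id and rotates one element pair
// (i, i + half) of the query and key vectors; the idx_head_num < qHeads / kHeads checks
// keep grouped-query models with fewer KV heads in bounds.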
const int head_size = qkShape[3]; + const int half_head_size = (head_size + 1) / 2; + using namespace sycl; + + auto rope_kernel = [](sycl::nd_item<3> &item, const float *embCos, const float *embSin, const int qHeads, + const int kHeads, const int seq_size, const int head_size, const int half, float *query, float *key, + int qStride, int kStride, const sycl::accessor &positionIds) { + size_t idx_bs_seq = item.get_global_id(0); + size_t idx_head_num = item.get_global_id(1); + size_t idx_half_head_dim = item.get_global_id(2); + + size_t pos = positionIds[idx_bs_seq % seq_size]; + float cos = embCos[pos * half + idx_half_head_dim]; + float sin = embSin[pos * half + idx_half_head_dim]; + + float *q = query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim; + float *k = key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim; + + if (idx_head_num < qHeads) { + auto q1 = q[0]; + q[0] = q1 * cos - q[half] * sin; + q[half] = q[half] * cos + q1 * sin; + } + if (idx_head_num < kHeads) { + auto k1 = k[0]; + k[0] = k1 * cos - k[half] * sin; + k[half] = k[half] * cos + k1 * sin; + } + }; + + // Reorder input + sycl::queue *gpu_queue = static_cast(device); + float *embCos = emb_cos; + float *embSin = emb_sin; + + sycl::buffer positionIdsBuf(positionIds, sycl::range<1>(seqLen)); + gpu_queue->submit([&](sycl::handler &cgh) { + sycl::accessor position(positionIdsBuf, cgh, sycl::read_only); + sycl::range<3> globalSize(batchSize * seqLen, head_num, half_head_size); + sycl::range<3> workGroupSize(1, 1, 1); + + cgh.parallel_for( + sycl::nd_range(globalSize, workGroupSize), [=, this](sycl::nd_item<3> item) { + rope_kernel(item, embCos, embSin, qHeads, kHeads, seqLen, head_size, half_head_size, + query, key, qStride, kStride, position); + }); + }).wait(); +} + +#else + void LlamaRotaryEmbedding::forward( float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { int dim = inv_freq_size * 2; @@ -214,3 +287,5 @@ void LlamaRotaryEmbedding::forward( } } } + +#endif // GPU \ No newline at end of file diff --git a/src/layers/rotary_embedding.h b/src/layers/rotary_embedding.h index b488a5f6..2c3746eb 100644 --- a/src/layers/rotary_embedding.h +++ b/src/layers/rotary_embedding.h @@ -58,4 +58,5 @@ class LlamaRotaryEmbedding { float *inv_freq = nullptr; float *emb_cos = nullptr; float *emb_sin = nullptr; + void *device = nullptr; }; diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index ab289027..6396c774 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -638,10 +638,17 @@ class CommonDecoder : public AbstractDecoder { epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, ppSize, ppRank, ropeParamsPtr, useLogN, useNTK)); + int engineIdx = 0; if (env.getEngineKind() == xft::DeviceKind::iGPU && env.getEngineIndex() < 0) // Sequential assignment - this->context->mmHelper = new MMHelper(env.getEngineKind(), ppRank * tpSize + tpRank); + engineIdx = ppRank * tpSize + tpRank; else // assignment through the user - this->context->mmHelper = new MMHelper(env.getEngineKind(), env.getEngineIndex()); + engineIdx = env.getEngineIndex(); + + this->context->mmHelper = new MMHelper(env.getEngineKind(), engineIdx); +#ifdef GPU + auto devices = sycl::device::get_devices(sycl::info::device_type::gpu); + this->context->device = new sycl::queue(devices[this->context->mmHelper->getEngineCount() + engineIdx]); +#endif } return this->context.get(); diff --git 
a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 5ad5fbc9..927a24be 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -360,9 +360,11 @@ class MMHelper { weight.Resize(dims[0], dims[1]); weight.Resize(K, N); - dnnl::memory packedB_mem(desc, *engine, weight.Data()); - dnnl::reorder(B_mem, packedB_mem).execute(*stream, B_mem, packedB_mem); - stream->wait(); + dnnl::engine engine(dnnl::engine::kind::cpu, 0); + dnnl::stream stream(engine); + dnnl::memory packedB_mem(desc, engine, weight.Data()); + dnnl::reorder(B_mem, packedB_mem).execute(stream, B_mem, packedB_mem); + stream.wait(); } // INT4 @@ -396,6 +398,37 @@ class MMHelper { } } + template + void transposeWeight(bool trans, hpj::Matrix &src, hpj::Matrix &dst) { + using namespace dnnl; + using tag = memory::format_tag; + using dt = memory::data_type; + + dt weight_dt; + if constexpr (std::is_same_v) { + weight_dt = dt::f32; + } else if constexpr (std::is_same_v) { + weight_dt = dt::bf16; + } else if constexpr (std::is_same_v) { + weight_dt = dt::f16; + } else { + printf(">>> onednn_gemm_compute: input date type not supported."); + exit(-1); + } + + int K = trans ? src.Cols() : src.Rows(); + int N = trans ? src.Rows() : src.Cols(); + + dnnl::engine engine(dnnl::engine::kind::cpu, 0); + dnnl::stream stream(engine); + auto weight_md = memory::desc({K, N}, weight_dt, trans ? tag::ba : tag::ab); + auto weight_mem = memory(weight_md, engine, src.Data()); + auto transposed_weight_md = memory::desc({K, N}, weight_dt, get_onednn_weight_layout(weight_dt)); + auto transposed_weight_mem = memory(transposed_weight_md, engine, dst.Data()); + dnnl::reorder(weight_mem, transposed_weight_mem).execute(stream, weight_mem, transposed_weight_mem); + stream.wait(); + } + template void compute(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB, const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc) { @@ -408,9 +441,8 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute", - xdnn_sgemm_f32f16f32_compute( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + GEMMVERBOSE("onednn_gemm_compute", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute", xdnn_hgemm_f32f16f32_compute( @@ -426,7 +458,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute", xdnn_sgemm_f32bf16f32_compute( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) // TODO: xdnn impl? if constexpr (std::is_same_v) { @@ -533,7 +565,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_biasadd", xdnn_sgemm_f32bf16f32_compute_biasadd( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc, bias)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) // TODO: xdnn impl? 
if constexpr (std::is_same_v) { @@ -643,7 +675,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_biasadd_relu", xdnn_sgemm_f32bf16f32_compute_biasadd_relu( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc, bias)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if (M > AMXThresholdM) { GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu", @@ -746,7 +778,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_silu", xdnn_sgemm_f32bf16f32_compute_silu( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if constexpr (std::is_same_v) { GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_silu", @@ -855,7 +887,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_gelu", xdnn_sgemm_f32bf16f32_compute_gelu( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if constexpr (std::is_same_v) { GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_gelu", @@ -965,7 +997,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_resmul", xdnn_sgemm_f32bf16f32_compute_resmul( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc, res, ldres)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, res, ldres)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if constexpr (std::is_same_v) { GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_resmul", @@ -1076,7 +1108,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_residential", xdnn_sgemm_f32bf16f32_compute_residential(transA, M, N, K, alpha, A, lda, - (const XDNN_UINT4x2 *)packedB, beta, C, ldc, bias, res, ldres)); + (const XDNN_BF16 *)packedB, beta, C, ldc, bias, res, ldres)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) // TODO: xdnn impl? 
if constexpr (std::is_same_v) { @@ -1187,7 +1219,7 @@ class MMHelper { else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_resext", - xdnn_sgemm_f32bf16f32_compute_resext(transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, + xdnn_sgemm_f32bf16f32_compute_resext(transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias, gamma, res, ldres)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if constexpr (std::is_same_v) { @@ -1283,6 +1315,11 @@ class MMHelper { } } + int getEngineCount() { + int count = engine->get_count(kind); + return count; + } + private: dnnl::engine::kind kind; dnnl::engine *engine; @@ -1310,10 +1347,10 @@ class MMHelper { dnnl::memory::format_tag get_onednn_input_layout(dnnl::memory::data_type dt) { if (this->kind == dnnl::engine::kind::cpu) { - return dnnl::memory::format_tag::undef; + return dnnl::memory::format_tag::ab; } else if (this->kind == dnnl::engine::kind::gpu) { - return dnnl::memory::format_tag::AB32a16b; - // return dnnl::memory::format_tag::any; + return dnnl::memory::format_tag::ab; + // return dnnl::memory::format_tag::AB32a16b; } else { printf("[XFT][ERROR] Need a right engine kind in input layout."); std::exit(-1); @@ -1322,7 +1359,7 @@ class MMHelper { dnnl::memory::format_tag get_onednn_weight_layout(dnnl::memory::data_type dt) { if (this->kind == dnnl::engine::kind::cpu) { - if (dt == dnnl::memory::data_type::bf16) { + if (dt == dnnl::memory::data_type::bf16 || dt == dnnl::memory::data_type::f16) { return dnnl::memory::format_tag::BA16a64b2a; } else if (dt == dnnl::memory::data_type::s8) { return dnnl::memory::format_tag::BA16a64b4a; @@ -1331,26 +1368,237 @@ class MMHelper { std::exit(-1); } } else if (this->kind == dnnl::engine::kind::gpu) { - return dnnl::memory::format_tag::BA4b8a8b2a; - // return dnnl::memory::format_tag::any; + return dnnl::memory::format_tag::ba; + // return dnnl::memory::format_tag::BA4b8a8b2a; } else { printf("[XFT][ERROR] Need a right engine kind in weight layout."); std::exit(-1); } } + dnnl::memory::format_tag get_onednn_bias_layout(dnnl::memory::data_type dt) { + if (this->kind == dnnl::engine::kind::cpu) { + return dnnl::memory::format_tag::ab; + } else if (this->kind == dnnl::engine::kind::gpu) { + return dnnl::memory::format_tag::ab; + } else { + printf("[XFT][ERROR] Need a right engine kind in bias layout."); + std::exit(-1); + } + } + + dnnl::memory::format_tag get_onednn_shift_layout(dnnl::memory::data_type dt) { + if (this->kind == dnnl::engine::kind::cpu) { + return dnnl::memory::format_tag::ab; + } else if (this->kind == dnnl::engine::kind::gpu) { + return dnnl::memory::format_tag::ab; + } else { + printf("[XFT][ERROR] Need a right engine kind in shift layout."); + std::exit(-1); + } + } + dnnl::memory::format_tag get_onednn_output_layout(dnnl::memory::data_type dt) { if (this->kind == dnnl::engine::kind::cpu) { - return dnnl::memory::format_tag::undef; + return dnnl::memory::format_tag::ab; } else if (this->kind == dnnl::engine::kind::gpu) { - return dnnl::memory::format_tag::AB32a16b; - // return dnnl::memory::format_tag::any; + return dnnl::memory::format_tag::ab; + // return dnnl::memory::format_tag::AB32a16b; } else { printf("[XFT][ERROR] Need a right engine kind in output layout."); std::exit(-1); } } + template + void onednn_gemm_compute(bool transA, int M, int N, int K, float alpha, const Tin *A, int lda, + const Twei *packedB, float beta, Tout *C, int ldc, const Tbias *bias = nullptr, + const Tres *res = nullptr, 
int ldres = -1, const matmul_kinds postAlg = matmul_kinds::Basic) { + TimeLine t("onednn_gemm_compute"); + TimeLine t1("onednn_gemm_compute.create_primitive"); + using namespace dnnl; + using tag = memory::format_tag; + using dt = memory::data_type; + + dt input_dt; + if constexpr (std::is_same_v) { + input_dt = dt::f32; + } else if constexpr (std::is_same_v) { + input_dt = dt::bf16; + } else if constexpr (std::is_same_v) { + input_dt = dt::f16; + } else { + printf(">>> onednn_gemm_compute: input date type not supported."); + exit(-1); + } + + dt weight_dt; + if constexpr (std::is_same_v) { + weight_dt = dt::f32; + } else if constexpr (std::is_same_v) { + weight_dt = dt::bf16; + } else if constexpr (std::is_same_v) { + weight_dt = dt::f16; + } else { + printf(">>> onednn_gemm_compute: weight date type not supported."); + exit(-1); + } + + dt output_dt; + if constexpr (std::is_same_v) { + output_dt = dt::f32; + } else if constexpr (std::is_same_v) { + output_dt = dt::bf16; + } else if constexpr (std::is_same_v) { + output_dt = dt::f16; + } else { + printf(">>> onednn_gemm_compute: output date type not supported."); + exit(-1); + } + + dt bias_dt; + if constexpr (std::is_same_v) { + bias_dt = dt::f32; + } else if constexpr (std::is_same_v) { + bias_dt = dt::bf16; + } else if constexpr (std::is_same_v) { + bias_dt = dt::f16; + } else { + printf(">>> onednn_gemm_compute: bias date type not supported."); + exit(-1); + } + + dt shift_dt; + if constexpr (std::is_same_v) { + shift_dt = dt::f32; + } else if constexpr (std::is_same_v) { + shift_dt = dt::bf16; + } else if constexpr (std::is_same_v) { + shift_dt = dt::f16; + } else { + printf(">>> onednn_gemm_compute: res date type not supported."); + exit(-1); + } + + matmul::primitive_desc *matmul_pd; + matmul *matmul_prim; + std::string key = create_key(transA, M, N, K, postAlg); + auto it = matmul_hub.find(key); + if (it != matmul_hub.end()) { + matmul_pd = std::get<0>(it->second); + matmul_prim = std::get<1>(it->second); + } else { + // Source (A), weights (B) and destination (C) matrix dimensions. + memory::dims input_dims = {M, K}; + memory::dims weight_dims = {K, N}; + memory::dims output_dims = {M, N}; + memory::dims bias_dims = {1, N}; + memory::dims shift_dims = {M, N}; + + // Create memory descriptors and memory objects for src, weights, bias, and dst. + auto input_md = memory::desc(input_dims, input_dt, get_onednn_input_layout(input_dt)); + auto weight_md = memory::desc(weight_dims, weight_dt, get_onednn_weight_layout(weight_dt)); + auto output_md = memory::desc(output_dims, output_dt, get_onednn_output_layout(output_dt)); + auto bias_md = memory::desc(bias_dims, bias_dt, get_onednn_bias_layout(bias_dt)); + auto shift_md = memory::desc(shift_dims, shift_dt, get_onednn_shift_layout(shift_dt)); + + // Create primitive descriptor and primitive. 
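// postAlg selects the fused epilogue: Silu maps to an eltwise_swish post-op, Gelu to
// eltwise_gelu_tanh, and Residential appends a binary_add against the shift (residual)
// tensor; Basic builds the primitive without post-ops, optionally with a bias.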
+ primitive_attr matmul_attr; + switch (postAlg) { + case matmul_kinds::Basic: { + break; + } + case matmul_kinds::Silu: { + const float post_alpha = 1.0f; + const float post_beta = 0.0f; + post_ops matmul_ops; + matmul_ops.append_eltwise(algorithm::eltwise_swish, post_alpha, post_beta); + matmul_attr.set_post_ops(matmul_ops); + break; + } + case matmul_kinds::Gelu: { + const float post_alpha = 1.0f; + const float post_beta = 0.0f; + post_ops matmul_ops; + matmul_ops.append_eltwise(algorithm::eltwise_gelu_tanh, post_alpha, post_beta); + matmul_attr.set_post_ops(matmul_ops); + break; + } + case matmul_kinds::Residential: { + if (res == nullptr) { + printf(">>> onednn_gemm_compute: Residential need be valuable."); + exit(-1); + } + + post_ops matmul_ops; + matmul_ops.append_binary(algorithm::binary_add, shift_md); + matmul_attr.set_post_ops(matmul_ops); + break; + } + default: { + printf(">>> onednn_gemm_compute: postAlg type %s not supported.", std::to_string(postAlg).c_str()); + exit(-1); + } + } + + if (postAlg == matmul_kinds::Basic) { + if (bias != nullptr) + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, bias_md, output_md); + else + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md); + } else { + if (bias != nullptr) + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, bias_md, output_md, matmul_attr); + else + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr); + } + + matmul_prim = new matmul(*matmul_pd); + + // Cache primitive_desc and matmul + std::string key = create_key(transA, M, N, K, postAlg); + std::tuple value(matmul_pd, matmul_prim); + matmul_hub[key] = value; + } + + // Repack and convert input data. + memory input_mem; + if constexpr (std::is_same_v) { + input_mem = memory(matmul_pd->src_desc(), *engine); + } else { + input_mem = memory(matmul_pd->src_desc(), *engine, const_cast(A)); + } + + auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast(packedB)); + auto output_mem = memory(matmul_pd->dst_desc(), *engine, C); + memory bias_mem; + if (bias != nullptr) { bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast(bias)); } + auto shift_md = memory::desc({M, N}, shift_dt, get_onednn_shift_layout(shift_dt)); + auto shift_mem = memory(shift_md, *engine, const_cast(res)); + + // Create the primitive args. + std::unordered_map matmul_args; + matmul_args.insert({DNNL_ARG_SRC, input_mem}); + matmul_args.insert({DNNL_ARG_WEIGHTS, weight_mem}); + if (bias != nullptr) { matmul_args.insert({DNNL_ARG_BIAS, bias_mem}); } + if (res != nullptr) { matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, shift_mem}); } + matmul_args.insert({DNNL_ARG_DST, output_mem}); + t1.release(); + + // Executions. 
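// If the activation type differs from the primitive's source type (e.g. fp32 input with
// bf16 weights), each row of A is converted into the primitive-owned src memory before
// execute(); otherwise the src memory object simply wraps the caller's A pointer.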
+ TimeLine t2("onednn_gemm_compute.execute_primitive"); + // Reorder + if constexpr (std::is_same_v && std::is_same_v) { +#pragma omp parallel for + for (uint64_t i = 0; i < M; ++i) { + bfloat16_t::cvt_float_to_bfloat16(A + i * lda, (bfloat16_t *)input_mem.get_data_handle() + i * K, K); + } + } + + matmul_prim->execute(*stream, matmul_args); + stream->wait(); + } + template void onednn_amx_sgemm_f32bf16f32_compute(bool transA, int M, int N, int K, float alpha, const Tin *A, int lda, const bfloat16_t *packedB, float beta, Tout *C, int ldc, const matmul_kinds postAlg = matmul_kinds::Basic) { @@ -1374,13 +1622,13 @@ class MMHelper { memory::dims output_dims = {M, N}; // Create memory descriptors and memory objects for src, weights, bias, and dst. - auto input_md = memory::desc(input_dims, dt::bf16, tag::ab); + auto input_md = memory::desc(input_dims, dt::bf16, get_onednn_input_layout(dt::bf16)); auto weight_md = memory::desc(weight_dims, dt::bf16, get_onednn_weight_layout(dt::bf16)); memory::desc output_md; if constexpr (std::is_same_v) { - output_md = memory::desc(output_dims, dt::f32, tag::ab); + output_md = memory::desc(output_dims, dt::f32, get_onednn_output_layout(dt::f32)); } else if constexpr (std::is_same_v) { - output_md = memory::desc(output_dims, dt::bf16, tag::ab); + output_md = memory::desc(output_dims, dt::bf16, get_onednn_output_layout(dt::bf16)); } else { printf(">>> onednn amx output date type not supported."); exit(-1); diff --git a/src/utils/simple_mem_pool.h b/src/utils/simple_mem_pool.h index 6fe36633..6ab765e9 100644 --- a/src/utils/simple_mem_pool.h +++ b/src/utils/simple_mem_pool.h @@ -46,7 +46,7 @@ class SimpleMemPool { } // Allocate or reallocate memory buffer based on name and size - void *getBuffer(const std::string &name, size_t size, size_t alignment = 64) { + void *getBuffer(const std::string &name, size_t size, void *device = nullptr, size_t alignment = 64) { if (size == 0) { // std::cout << "[Warning] Try to allocate 0 bytes for buffer:" << name << std::endl; return nullptr; @@ -65,7 +65,7 @@ class SimpleMemPool { } // Allocate new aligned buffer - void *buffer = xft::alloc(size, alignment); + void *buffer = xft::alloc(size, alignment, device); if (buffer == nullptr) { // Allocation failed std::cerr << "Memory allocation failed for buffer:" << name << " size:" << size << std::endl; @@ -78,6 +78,15 @@ class SimpleMemPool { return buffer; } + // Free allocated memory based on name + void *freeBuffer(const std::string &name, void *device = nullptr) { + auto it = memoryMap.find(name); + + if (it != memoryMap.end()) { + xft::dealloc(it->second.first, device); + } + } + // Destructor to free all allocated memory on program termination ~SimpleMemPool() { for (auto &entry : memoryMap) { From 0b1be0ec2b9a69ee38601bb5ef6bd1744c584aa6 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Tue, 7 May 2024 18:19:55 +0000 Subject: [PATCH 02/34] format code --- src/layers/attention.h | 5 ++- src/layers/dist_linear.h | 3 +- src/layers/rotary_embedding.cpp | 61 ++++++++++++++-------------- src/utils/matmul_helper.h | 70 ++++++++++++++++----------------- 4 files changed, 70 insertions(+), 69 deletions(-) diff --git a/src/layers/attention.h b/src/layers/attention.h index d9aeb83a..0959de08 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -295,8 +295,9 @@ class Attention { } qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, posIds.data()); #ifdef GPU - sycl::queue *gpu_queue = static_cast(ctx->device); - gpu_queue->memcpy(qkvMatMul.Data(), 
query.Data(), ctx->batchSize * ctx->inputSeqLen * qkvCols * sizeof(float)).wait(); + sycl::queue *q = static_cast(ctx->device); + int64_t size = ctx->batchSize * ctx->inputSeqLen * qkvCols * sizeof(float); + q->memcpy(qkvMatMul.Data(), query.Data(), size).wait(); #endif } t3.release(); diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h index c5569571..f95dd791 100644 --- a/src/layers/dist_linear.h +++ b/src/layers/dist_linear.h @@ -74,8 +74,7 @@ class DistLinear { sycl::queue *gpu_queue = static_cast(ctx->device); WeiT *input_data = sycl::malloc_device(K * N, *gpu_queue); weight.Assign(input_data, K, N, N); - gpu_queue->memcpy(weight.Data(), tWeight.Data(), tWeight.Rows() * tWeight.Cols() * sizeof(WeiT)) - .wait(); + gpu_queue->memcpy(weight.Data(), tWeight.Data(), tWeight.Rows() * tWeight.Cols() * sizeof(WeiT)).wait(); #else weight.Resize(K, N); ctx->mmHelper->packWeight(true, quantizedWeight, weight); diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index e4745dbe..ff8cbae5 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -28,8 +28,7 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { ctx->GetAttr("rope_theta", &this->base, 10000); ctx->GetAttr("rope_type", &this->rope_type, std::to_string(-1)); - if (this->rope_type == "linear") - ctx->GetAttr("scaling_factor", &this->scaling_factor, 1.0f); + if (this->rope_type == "linear") ctx->GetAttr("scaling_factor", &this->scaling_factor, 1.0f); inv_freq_size = (dim + 1) / 2; @@ -138,31 +137,32 @@ void LlamaRotaryEmbedding::forward( const int half_head_size = (head_size + 1) / 2; using namespace sycl; - auto rope_kernel = [](sycl::nd_item<3> &item, const float *embCos, const float *embSin, const int qHeads, - const int kHeads, const int seq_size, const int head_size, const int half, float *query, float *key, - int qStride, int kStride, const sycl::accessor &positionIds) { - size_t idx_bs_seq = item.get_global_id(0); - size_t idx_head_num = item.get_global_id(1); - size_t idx_half_head_dim = item.get_global_id(2); - - size_t pos = positionIds[idx_bs_seq % seq_size]; - float cos = embCos[pos * half + idx_half_head_dim]; - float sin = embSin[pos * half + idx_half_head_dim]; - - float *q = query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim; - float *k = key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim; - - if (idx_head_num < qHeads) { - auto q1 = q[0]; - q[0] = q1 * cos - q[half] * sin; - q[half] = q[half] * cos + q1 * sin; - } - if (idx_head_num < kHeads) { - auto k1 = k[0]; - k[0] = k1 * cos - k[half] * sin; - k[half] = k[half] * cos + k1 * sin; - } - }; + auto rope_kernel + = [](sycl::nd_item<3> &item, const float *embCos, const float *embSin, const int qHeads, const int kHeads, + const int seq_size, const int head_size, const int half, float *query, float *key, int qStride, + int kStride, const sycl::accessor &positionIds) { + size_t idx_bs_seq = item.get_global_id(0); + size_t idx_head_num = item.get_global_id(1); + size_t idx_half_head_dim = item.get_global_id(2); + + size_t pos = positionIds[idx_bs_seq % seq_size]; + float cos = embCos[pos * half + idx_half_head_dim]; + float sin = embSin[pos * half + idx_half_head_dim]; + + float *q = query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim; + float *k = key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim; + + if (idx_head_num < qHeads) { + auto q1 = q[0]; + q[0] = q1 * cos - q[half] * sin; + q[half] = 
q[half] * cos + q1 * sin; + } + if (idx_head_num < kHeads) { + auto k1 = k[0]; + k[0] = k1 * cos - k[half] * sin; + k[half] = k[half] * cos + k1 * sin; + } + }; // Reorder input sycl::queue *gpu_queue = static_cast(device); @@ -177,10 +177,11 @@ void LlamaRotaryEmbedding::forward( cgh.parallel_for( sycl::nd_range(globalSize, workGroupSize), [=, this](sycl::nd_item<3> item) { - rope_kernel(item, embCos, embSin, qHeads, kHeads, seqLen, head_size, half_head_size, - query, key, qStride, kStride, position); + rope_kernel(item, embCos, embSin, qHeads, kHeads, seqLen, head_size, half_head_size, query, key, + qStride, kStride, position); }); - }).wait(); + }); + gpu_queue->wait(); } #else diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 927a24be..469466df 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -441,8 +441,8 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("onednn_gemm_compute", - onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc)); + GEMMVERBOSE( + "onednn_gemm_compute", onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute", xdnn_hgemm_f32f16f32_compute( @@ -1411,9 +1411,9 @@ class MMHelper { } template - void onednn_gemm_compute(bool transA, int M, int N, int K, float alpha, const Tin *A, int lda, - const Twei *packedB, float beta, Tout *C, int ldc, const Tbias *bias = nullptr, - const Tres *res = nullptr, int ldres = -1, const matmul_kinds postAlg = matmul_kinds::Basic) { + void onednn_gemm_compute(bool transA, int M, int N, int K, float alpha, const Tin *A, int lda, const Twei *packedB, + float beta, Tout *C, int ldc, const Tbias *bias = nullptr, const Tres *res = nullptr, int ldres = -1, + const matmul_kinds postAlg = matmul_kinds::Basic) { TimeLine t("onednn_gemm_compute"); TimeLine t1("onednn_gemm_compute.create_primitive"); using namespace dnnl; @@ -1548,7 +1548,8 @@ class MMHelper { matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md); } else { if (bias != nullptr) - matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, bias_md, output_md, matmul_attr); + matmul_pd + = new matmul::primitive_desc(*engine, input_md, weight_md, bias_md, output_md, matmul_attr); else matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr); } @@ -1635,35 +1636,34 @@ class MMHelper { } // Create primitive descriptor and primitive. 
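// Same construction as in onednn_gemm_compute: Basic builds a plain matmul primitive,
// while Silu and Gelu attach the corresponding eltwise post-op; the reformatted switch
// below is behaviourally identical to the removed one.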
- switch (postAlg) - { - case matmul_kinds::Basic: - matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md); - break; - case matmul_kinds::Silu:{ - const float post_alpha = 1.0f; - const float post_beta = 0.0f; - post_ops matmul_ops; - matmul_ops.append_eltwise(algorithm::eltwise_swish, post_alpha, post_beta); - primitive_attr matmul_attr; - matmul_attr.set_post_ops(matmul_ops); - matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr); - break; - } - case matmul_kinds::Gelu:{ - const float post_alpha = 1.0f; - const float post_beta = 0.0f; - post_ops matmul_ops; - matmul_ops.append_eltwise(algorithm::eltwise_gelu_tanh, post_alpha, post_beta); - primitive_attr matmul_attr; - matmul_attr.set_post_ops(matmul_ops); - matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr); - break; - } - default: - printf(">>> onednn amx postAlg type %s not supported.", std::to_string(postAlg).c_str()); - exit(-1); - break; + switch (postAlg) { + case matmul_kinds::Basic: + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md); + break; + case matmul_kinds::Silu: { + const float post_alpha = 1.0f; + const float post_beta = 0.0f; + post_ops matmul_ops; + matmul_ops.append_eltwise(algorithm::eltwise_swish, post_alpha, post_beta); + primitive_attr matmul_attr; + matmul_attr.set_post_ops(matmul_ops); + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr); + break; + } + case matmul_kinds::Gelu: { + const float post_alpha = 1.0f; + const float post_beta = 0.0f; + post_ops matmul_ops; + matmul_ops.append_eltwise(algorithm::eltwise_gelu_tanh, post_alpha, post_beta); + primitive_attr matmul_attr; + matmul_attr.set_post_ops(matmul_ops); + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr); + break; + } + default: + printf(">>> onednn amx postAlg type %s not supported.", std::to_string(postAlg).c_str()); + exit(-1); + break; } matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul From 3ef6143f6c46349466cdf5d37c4517d47393f0e0 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Tue, 7 May 2024 20:51:35 +0000 Subject: [PATCH 03/34] fix running issue. --- src/models/model_factory.h | 2 +- src/utils/matmul_helper.h | 13 +++++++------ src/utils/simple_mem_pool.h | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/models/model_factory.h b/src/models/model_factory.h index d13a9310..2730347d 100644 --- a/src/models/model_factory.h +++ b/src/models/model_factory.h @@ -109,4 +109,4 @@ class DecoderRegister { MODEL(IMPLEMENT, CLASS, NAME) #define REGISTER_MODEL(CLASS, NAME) \ - MODEL(REGISTER, CLASS, NAME) \ No newline at end of file + MODEL(REGISTER, CLASS, NAME) diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 469466df..ca922036 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -349,8 +349,11 @@ class MMHelper { // W8A8 else if constexpr (std::is_same_v) { using dt = dnnl::memory::data_type; + dnnl::engine eng(dnnl::engine::kind::cpu, 0); + dnnl::stream stm(eng); + auto tag = trans ? 
dnnl::memory::format_tag::ba : dnnl::memory::format_tag::ab; - dnnl::memory B_mem({{K, N}, dt::s8, tag}, *this->engine, src.Data()); + dnnl::memory B_mem({{K, N}, dt::s8, tag}, eng, src.Data()); dnnl::memory::desc desc({K, N}, dt::s8, get_onednn_weight_layout(dt::s8)); // When converting to oneDNN blocked memory format, padded dims can be larger than [K, N] @@ -360,11 +363,9 @@ class MMHelper { weight.Resize(dims[0], dims[1]); weight.Resize(K, N); - dnnl::engine engine(dnnl::engine::kind::cpu, 0); - dnnl::stream stream(engine); - dnnl::memory packedB_mem(desc, engine, weight.Data()); - dnnl::reorder(B_mem, packedB_mem).execute(stream, B_mem, packedB_mem); - stream.wait(); + dnnl::memory packedB_mem(desc, eng, weight.Data()); + dnnl::reorder(B_mem, packedB_mem).execute(stm, B_mem, packedB_mem); + stm.wait(); } // INT4 diff --git a/src/utils/simple_mem_pool.h b/src/utils/simple_mem_pool.h index 6ab765e9..63f7a2fa 100644 --- a/src/utils/simple_mem_pool.h +++ b/src/utils/simple_mem_pool.h @@ -79,7 +79,7 @@ class SimpleMemPool { } // Free allocated memory based on name - void *freeBuffer(const std::string &name, void *device = nullptr) { + void freeBuffer(const std::string &name, void *device = nullptr) { auto it = memoryMap.find(name); if (it != memoryMap.end()) { From 619e7883db05b90935621dfaf6b09c49221b087e Mon Sep 17 00:00:00 2001 From: changqi1 Date: Wed, 15 May 2024 20:57:19 +0000 Subject: [PATCH 04/34] Add RmsNorm kernel. --- src/layers/attention.h | 63 +++++++++++++++++++++++++-------- src/layers/layer_norm.cpp | 28 +++++++++++++-- src/layers/layer_norm.h | 3 ++ src/layers/mlp_chatglm2.h | 9 ++--- src/layers/mlp_llama.cpp | 4 +-- src/layers/mlp_llama.h | 52 +++++++++++++++++++-------- src/layers/rms_norm.cpp | 37 +++++++++++++++++-- src/layers/rms_norm.h | 3 ++ src/layers/rotary_embedding.cpp | 1 + src/models/llama.cpp | 7 ++-- src/models/llama.h | 2 +- 11 files changed, 165 insertions(+), 44 deletions(-) diff --git a/src/layers/attention.h b/src/layers/attention.h index 0959de08..7b23581e 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -25,12 +25,12 @@ #include "gemm_kernel_ext.h" #include "kvcache_tensor.h" #include "matmul_helper.h" +#include "rms_norm.h" +#include "rotary_embedding.h" #include "simple_mem_pool.h" #include "transformer_ctx.h" #include "transformer_util.h" -#include "rotary_embedding.h" - /** * WeiT: weight data type * InT: input data type @@ -50,6 +50,8 @@ class Attention { //todo(marvin): clear this code after all rotary_emb refactor if constexpr (std::is_same::value) { qkpo = LlamaRotaryEmbedding(ctx); } + norm = new NORM_CLS(ctx); + // Group attention or multi-head attention (multi-head attn is a special case of group attn) if (ctx->attHeadNum % ctx->kvHeadNum == 0) { // We are responsible for the range [startQHead, endQHead) @@ -85,7 +87,6 @@ class Attention { int qResponsibleCols = (this->endQHead - this->startQHead) * headSize; int kvResponsibleCols = (this->endKVHead - this->startKVHead) * headSize; int responsibleCols = qResponsibleCols + 2 * kvResponsibleCols; - qkvWeight.Resize(hiddenSize, responsibleCols); constexpr int sizeFactor = std::is_same_v ? 
2 : 1; @@ -135,7 +136,21 @@ class Attention { hpj::Matrix convertedqkvWeight; ctx->mmHelper->convertWeight(trans, hiddenSize, responsibleCols, concatBuf, concatScale, concatZero, convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum); + +#ifdef GPU + hpj::Matrix qkvWeightT; + qkvWeightT.Resize(hiddenSize, responsibleCols); + ctx->mmHelper->transposeWeight(true, convertedqkvWeight, qkvWeightT); + + sycl::queue *gpu_queue_1 = static_cast(ctx->device); + WeiT *qkvWeiData = sycl::malloc_device(hiddenSize * responsibleCols, *gpu_queue_1); + qkvWeight.Assign(qkvWeiData, hiddenSize, responsibleCols, responsibleCols); + gpu_queue_1->memcpy(qkvWeight.Data(), qkvWeightT.Data(), qkvWeightT.Rows() * qkvWeightT.Cols() * sizeof(WeiT)) + .wait(); +#else + qkvWeight.Resize(hiddenSize, responsibleCols); ctx->mmHelper->packWeight(trans, convertedqkvWeight, qkvWeight); +#endif free(concatBuf); free(concatScale); @@ -162,16 +177,30 @@ class Attention { // Weights for attention output // Horizontally split the weight, as the source (PyTorch weight) is transposed, thus looks like vertically - hpj::Matrix convertedWeight; + hpj::Matrix convertedOutWeight; ctx->mmHelper->convertWeight(trans, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, attnOutWeight, attnOutScale, - attnOutZero, this->startQHead * headSize, qResponsibleCols, false, convertedWeight, + attnOutZero, this->startQHead * headSize, qResponsibleCols, false, convertedOutWeight, attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, true); - ctx->mmHelper->packWeight(trans, convertedWeight, attnOutputWeight); + +#ifdef GPU + hpj::Matrix outWeightT; + outWeightT.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize); + ctx->mmHelper->transposeWeight(true, convertedOutWeight, outWeightT); + + sycl::queue *gpu_queue_2 = static_cast(ctx->device); + WeiT *outWeiData = sycl::malloc_device(ctx->attHeadNum * ctx->attHeadSize * hiddenSize, *gpu_queue_2); + attnOutputWeight.Assign(outWeiData, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, hiddenSize); + int outWeightTSize = outWeightT.Rows() * outWeightT.Cols() * sizeof(WeiT); + gpu_queue_2->memcpy(attnOutputWeight.Data(), outWeightT.Data(), outWeightTSize).wait(); +#else + attnOutputWeight.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize); + ctx->mmHelper->packWeight(trans, convertedOutWeight, attnOutputWeight); +#endif #ifdef DEBUG - dbg.debugPrint(">>> attention output weight: [%d, %d] (%d)\n", convertedWeight.Rows(), convertedWeight.Cols(), - convertedWeight.Stride()); - dbg.dumpMatrix(convertedWeight); + dbg.debugPrint(">>> attention output weight: [%d, %d] (%d)\n", convertedOutWeight.Rows(), + convertedOutWeight.Cols(), convertedOutWeight.Stride()); + dbg.dumpMatrix(convertedOutWeight); dbg.debugPrint("attention output packed weight: [%d, %d] (%d)\n", attnOutputWeight.Rows(), attnOutputWeight.Cols(), attnOutputWeight.Stride()); dbg.dumpMatrix(attnOutputWeight); @@ -188,7 +217,7 @@ class Attention { } // LayerNorm - this->norm.setWeight(gamma1, beta1, hiddenSize); + this->norm->setWeight(gamma1, beta1, hiddenSize); } #ifdef DEBUG @@ -242,7 +271,7 @@ class Attention { if (doLnBefore) { TimeLine t1("input.layer_norm"); - norm.forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), + norm->forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), imBuffer.Stride(), epsilon); } #ifdef DEBUG @@ -297,7 +326,7 @@ class Attention { #ifdef GPU sycl::queue *q = static_cast(ctx->device); int64_t size = ctx->batchSize * 
ctx->inputSeqLen * qkvCols * sizeof(float); - q->memcpy(qkvMatMul.Data(), query.Data(), size).wait(); + q->memcpy(qkvMatMul.Data(), query.Data(), size).wait(); // error: need CPU ptr and GPU ptr #endif } t3.release(); @@ -344,6 +373,12 @@ class Attention { dbg.dumpMatrix(attnSplit); #endif +#ifdef GPU + sycl::queue *q = static_cast(ctx->device); + int64_t size = ctx->batchSize * ctx->inputSeqLen * qkvCols * sizeof(float); + q->memcpy(qkvMatMul.Data(), attnSplit.Data(), size).wait(); // error: need CPU ptr and GPU ptr +#endif + TimeLine t5("Output"); // Output/projection in attention, only add the input in the first split if (ctx->splitIdx == 0) { @@ -389,7 +424,7 @@ class Attention { if (!doLnBefore) { TimeLine t6("result.layer_norm"); - norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); + norm->forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); #ifdef DEBUG dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(), outBuffer.Stride()); @@ -906,7 +941,7 @@ class Attention { QKPO_CLS qkpo; // layerNorm param - NORM_CLS norm; + NORM_CLS *norm; int layerId; // The responsible head in the global view diff --git a/src/layers/layer_norm.cpp b/src/layers/layer_norm.cpp index 4ebfc6b7..c5428d37 100644 --- a/src/layers/layer_norm.cpp +++ b/src/layers/layer_norm.cpp @@ -31,17 +31,32 @@ LayerNorm::LayerNorm() { normSize = 0; } +LayerNorm::LayerNorm(DecoderContext *ctx) { + device = ctx->device; + gamma = nullptr; + beta = nullptr; + normSize = 0; +} + LayerNorm::~LayerNorm() { - if (gamma) { free(gamma); } - if (beta) { free(beta); } + if (gamma) { xft::dealloc(gamma); } + if (beta) { xft::dealloc(beta); } } void LayerNorm::setWeight(const float *gamma, const float *beta, int cols) { this->normSize = cols; +#ifdef GPU + sycl::queue *gpu_queue = static_cast(device); + this->gamma = sycl::malloc_device(cols, *gpu_queue); + this->beta = sycl::malloc_device(cols, *gpu_queue); + gpu_queue->memcpy(this->gamma, gamma, cols * sizeof(float)).wait(); + gpu_queue->memcpy(this->beta, beta, cols * sizeof(float)).wait(); +#else this->gamma = (float *)xft::alloc(cols * sizeof(float)); this->beta = (float *)xft::alloc(cols * sizeof(float)); memcpy(this->gamma, gamma, cols * sizeof(float)); memcpy(this->beta, beta, cols * sizeof(float)); +#endif } void LayerNorm::setWeight(const std::string &gammaPath, const std::string &betaPath, int cols) { @@ -52,11 +67,18 @@ void LayerNorm::setWeight(const std::string &gammaPath, const std::string &betaP // input and output are in shape of (rows, normSize) // TODO: column-wise parallel +#ifdef GPU +void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { + TimeLine t("LayerNorm.forward"); + const float *pgamma = gamma; + const float *pbeta = beta; +} +#else void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("LayerNorm.forward"); const float *pgamma = gamma; const float *pbeta = beta; invokeLayerNorm(output, input, pgamma, pbeta, rows, normSize, iStride, oStride); } - +#endif } // namespace xft \ No newline at end of file diff --git a/src/layers/layer_norm.h b/src/layers/layer_norm.h index 8b554648..75d9409b 100644 --- a/src/layers/layer_norm.h +++ b/src/layers/layer_norm.h @@ -16,6 +16,7 @@ #include #include "weight_util.h" +#include "transformer_ctx.h" namespace xft { @@ -23,6 +24,7 @@ namespace xft { class 
LayerNorm { public: LayerNorm(); + LayerNorm(DecoderContext *ctx); ~LayerNorm(); void setWeight(const float *gamma, const float *beta, int cols); @@ -37,6 +39,7 @@ class LayerNorm { float *gamma = nullptr; float *beta = nullptr; + void *device = nullptr; }; } // namespace xft \ No newline at end of file diff --git a/src/layers/mlp_chatglm2.h b/src/layers/mlp_chatglm2.h index dbc83cd8..2193725c 100644 --- a/src/layers/mlp_chatglm2.h +++ b/src/layers/mlp_chatglm2.h @@ -121,9 +121,10 @@ class ChatGLM2MLP : public LlamaMLP { this->dbg.dumpMatrix(this->downWeight); #endif // norm.setWeight(normW, NULL, hiddenSize); - if (normW) { - this->normWeight.Resize(hiddenSize); - memcpy(this->normWeight.Data(), normW, sizeof(float) * hiddenSize); - } + + if (normW) { norm->setWeight(normW, nullptr, hiddenSize); } } + +private: + using LlamaMLP::norm; }; diff --git a/src/layers/mlp_llama.cpp b/src/layers/mlp_llama.cpp index c457a344..433616ed 100644 --- a/src/layers/mlp_llama.cpp +++ b/src/layers/mlp_llama.cpp @@ -45,7 +45,7 @@ void invokeMLPLLaMA(DataType dt, int numTokens, int hiddenSize, int intermediate auto it_created = llama_mlp_hub.find(llama_mlp_key); if (it_created == llama_mlp_hub.end()) { // LlamaMLP &llama_mlp = LlamaMLP::getInstance(); - llama_mlp = new LlamaMLP; + llama_mlp = new LlamaMLP(ctx); llama_mlp->setWeights(ctx, (float *)gateWeight, nullptr, nullptr, nullptr, (float *)upWeight, nullptr, nullptr, nullptr, nullptr, nullptr, (float *)downWeight, nullptr, nullptr, false); llama_mlp_hub[llama_mlp_key] = llama_mlp; @@ -77,7 +77,7 @@ void invokeMLPLLaMA(DataType dt, int numTokens, int hiddenSize, int intermediate auto it_created = llama_mlp_hub.find(llama_mlp_key); if (it_created == llama_mlp_hub.end()) { // LlamaMLP &llama_mlp = LlamaMLP::getInstance(); - llama_mlp = new LlamaMLP; + llama_mlp = new LlamaMLP(ctx); llama_mlp->setWeights(ctx, (float *)gateWeight, nullptr, nullptr, nullptr, (float *)upWeight, nullptr, nullptr, nullptr, nullptr, nullptr, (float *)downWeight, nullptr, nullptr, false); llama_mlp_hub[llama_mlp_key] = llama_mlp; diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 42789085..7b3ca792 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -23,7 +23,7 @@ #include "debugger.h" #include "decoder_util.h" #include "matmul_helper.h" -#include "rmsnorm_kernels.h" +#include "rms_norm.h" #include "simple_mem_pool.h" #include "singleton.h" #include "timeline.h" @@ -38,12 +38,10 @@ // def forward(self, x): // return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) // But please also be noted: we extended the MLP to include layer norm -template -class LlamaMLP : public SingletonBase> { +template +class LlamaMLP { public: - LlamaMLP() {} - - LlamaMLP(DecoderContext *ctx) {} + LlamaMLP(DecoderContext *ctx) { norm = new NORM_CLS(ctx); } // OriWeiT: float, int8_t or uint4x2_t template @@ -61,7 +59,6 @@ class LlamaMLP : public SingletonBase> { hpj::Matrix quantizedGateWeight, quantizedUpWeight, quantizedDownWeight; auto it = SplitUtil::getTaskRange(imSize, ctx->numSplit, ctx->splitIdx); - downWeight.Resize(it.second - it.first, hiddenSize); ctx->mmHelper->convertWeight(ctx, trans, hiddenSize, imSize, gateW, gateS, gateZ, true, quantizedGateWeight, gateWeightScale, gateWeightZero, gateWeightSum); @@ -80,13 +77,41 @@ class LlamaMLP : public SingletonBase> { catWeightsSum); quantizedGateWeight.Release(); quantizedUpWeight.Release(); + +#ifdef GPU + hpj::Matrix catWeightsT; + int catWeiRows = quantizedCatWeights.Rows(); + int catWeiCols = 
quantizedCatWeights.Cols(); + catWeightsT.Resize(catWeiRows, catWeiCols); + ctx->mmHelper->transposeWeight(true, quantizedCatWeights, catWeightsT); + + sycl::queue *gpu_queue = static_cast(ctx->device); + WeiT *catWeiData = sycl::malloc_device(catWeiRows * catWeiCols, *gpu_queue); + catWeights.Assign(catWeiData, catWeiRows, catWeiCols, catWeiCols); + gpu_queue->memcpy(catWeights.Data(), catWeightsT.Data(), catWeiRows * catWeiCols * sizeof(WeiT)).wait(); +#else catWeights.Resize(quantizedCatWeights.Rows(), quantizedCatWeights.Cols()); ctx->mmHelper->packWeight(trans, quantizedCatWeights, catWeights); +#endif } // Horizontally split the down weight ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false, quantizedDownWeight, downWeightScale, downWeightZero, downWeightSum); +#ifdef GPU + hpj::Matrix downWeightT; + int downWeiRows = it.second - it.first; + int downWeiCols = hiddenSize; + downWeightT.Resize(downWeiRows, downWeiCols); + ctx->mmHelper->transposeWeight(true, quantizedDownWeight, downWeightT); + + sycl::queue *gpu_queue = static_cast(ctx->device); + WeiT *downWeiData = sycl::malloc_device(downWeiRows * downWeiCols, *gpu_queue); + downWeight.Assign(downWeiData, downWeiRows, downWeiCols, downWeiCols); + gpu_queue->memcpy(downWeight.Data(), downWeightT.Data(), downWeiRows * downWeiCols * sizeof(WeiT)).wait(); +#else + downWeight.Resize(it.second - it.first, hiddenSize); ctx->mmHelper->packWeight(trans, quantizedDownWeight, downWeight); +#endif #ifdef DEBUG dbg.debugPrint("quantizedGateWeight:\n"); @@ -100,10 +125,7 @@ class LlamaMLP : public SingletonBase> { #endif // LlamaRMSNorm - if (normW) { - normWeight.Resize(hiddenSize); - memcpy(normWeight.Data(), normW, sizeof(float) * hiddenSize); - } + if (normW) { norm->setWeight(normW, nullptr, hiddenSize); } } #ifdef DEBUG @@ -125,8 +147,7 @@ class LlamaMLP : public SingletonBase> { (ImT *)ctx->normBuf.Data(), ctx->normBuf.Rows(), ctx->normBuf.Cols(), ctx->normBuf.Stride()); if (doLnBefore == true) { - xft::rmsNorm(normBuffer.Data(), inBuffer.Data(), normWeight.Data(), M, hiddenSize, inBuffer.Stride(), - normBuffer.Stride(), 1e-6); + norm->forward(inBuffer.Data(), normBuffer.Data(), M, inBuffer.Stride(), normBuffer.Stride(), 1e-6); } #ifdef DEBUG @@ -275,7 +296,8 @@ class LlamaMLP : public SingletonBase> { } } - void catGateUpProj(DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output, hpj::Matrix &siluBuf) { + void catGateUpProj( + DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output, hpj::Matrix &siluBuf) { TimeLine t("catGateUpProj"); assert(input.Rows() == output.Rows()); @@ -362,7 +384,7 @@ class LlamaMLP : public SingletonBase> { hpj::Vector downWeightSum; // For int8_t weight // LlamaRMSNorm param - hpj::Vector normWeight; + NORM_CLS *norm; #ifdef DEBUG Debugger dbg; diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index dc54cdd8..43727e54 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -21,6 +21,11 @@ #include "rms_norm.h" #include "rmsnorm_kernels.h" #include "timeline.h" +#include "transformer_ctx.h" + +#ifdef GPU +#include "gpudnn/gpu_layernorm_kernels.h" +#endif namespace xft { @@ -29,14 +34,26 @@ RmsNorm::RmsNorm() { normSize = 0; } +RmsNorm::RmsNorm(DecoderContext *ctx) { + device = ctx->device; + weight = nullptr; + normSize = 0; +} + RmsNorm::~RmsNorm() { - if (weight) { free(weight); } + if (weight) { xft::dealloc(weight); } } void RmsNorm::setWeight(const float *w, const float *, int cols) { this->normSize = cols; +#ifdef GPU + sycl::queue 
*gpu_queue = static_cast(device); + this->weight = sycl::malloc_device(cols, *gpu_queue); + gpu_queue->memcpy(this->weight, w, cols * sizeof(float)).wait(); +#else this->weight = (float *)xft::alloc(cols * sizeof(float)); memcpy(weight, w, cols * sizeof(float)); +#endif } void RmsNorm::setWeight(const std::string &modelPath, const std::string &, int cols) { @@ -44,6 +61,22 @@ void RmsNorm::setWeight(const std::string &modelPath, const std::string &, int c loadWeight(modelPath, weight, cols); } +#ifdef GPU +void RmsNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { + TimeLine t("RmsNorm.forward"); + sycl::queue *gpu_queue = static_cast(device); + fastertransformer::invokeGeneralT5LayerNorm( + output, input, weight, (const float *)nullptr, epsilon, rows, iStride, gpu_queue); +} + +void RmsNorm::forward(const float *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { + TimeLine t("RmsNorm.forward"); +} + +void RmsNorm::forward(const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { + TimeLine t("RmsNorm.forward"); +} +#else // input and output are in shape of (rows, normSize) void RmsNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); @@ -59,5 +92,5 @@ void RmsNorm::forward(const bfloat16_t *input, bfloat16_t *output, int rows, int TimeLine t("RmsNorm.forward"); rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); } - +#endif } // namespace xft \ No newline at end of file diff --git a/src/layers/rms_norm.h b/src/layers/rms_norm.h index 7babe853..420deb11 100644 --- a/src/layers/rms_norm.h +++ b/src/layers/rms_norm.h @@ -16,6 +16,7 @@ #include "bfloat16.h" #include "weight_util.h" +#include "transformer_ctx.h" namespace xft { @@ -23,6 +24,7 @@ namespace xft { class RmsNorm { public: RmsNorm(); + RmsNorm(DecoderContext *ctx); ~RmsNorm(); void setWeight(const float *w, const float *, int cols); @@ -44,6 +46,7 @@ class RmsNorm { // the scale weight float *weight = nullptr; + void *device = nullptr; }; } // namespace xft \ No newline at end of file diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index ff8cbae5..de5fd099 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -42,6 +42,7 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { } llamaCalEmb(inv_freq, max_position_embeddings); #ifdef GPU + device = ctx->device; if (device != nullptr) { sycl::queue *gpu_queue = static_cast(device); float *emb_cos_bak = emb_cos; diff --git a/src/models/llama.cpp b/src/models/llama.cpp index 75c520cd..dfdbaa8c 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -31,6 +31,7 @@ LlamaLLM::LlamaLLM(const std::string &modelPath) setEmbeddingWeights(modelPath); // Final LN + finalLN = new RmsNorm(ctx); setFinalLnWeight(modelPath); } @@ -46,7 +47,7 @@ void LlamaLLM::setEmbeddingWeights(const std::string &modelPath) template void LlamaLLM::setFinalLnWeight(const std::string &modelPath) { - finalLN.setWeight(modelPath + "/model.final_layernorm.weight.bin", "", embedding->getHiddenSize()); + finalLN->setWeight(modelPath + "/model.final_layernorm.weight.bin", "", embedding->getHiddenSize()); } // Prepare attention_mask which is like: @@ -116,12 +117,12 @@ void LlamaLLM::embeddingForward(int *ids, bfloat16_t *output, in template void LlamaLLM::lastLayerNormForward(float *input, float *output, int rows) { - 
finalLN.forward(input, output, rows); + finalLN->forward(input, output, rows); } template void LlamaLLM::lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows) { - finalLN.forward(input, output, rows); + finalLN->forward(input, output, rows); } IMPLEMENT_MODEL(LlamaLLM, llama) \ No newline at end of file diff --git a/src/models/llama.h b/src/models/llama.h index 5fad1e24..fdf4782d 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -46,7 +46,7 @@ class LlamaLLM private: TokenEmbedding *embedding; - RmsNorm finalLN; + RmsNorm *finalLN; }; REGISTER_MODEL(LlamaLLM, llama) \ No newline at end of file From 39ec0b4687233744a3674a085deb3393631bf6e3 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Mon, 27 May 2024 11:28:20 +0800 Subject: [PATCH 05/34] Fix some issues. --- src/common/allocator.h | 8 ++++++-- src/layers/layer_norm.cpp | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/common/allocator.h b/src/common/allocator.h index afd8d1a7..cb9dd281 100644 --- a/src/common/allocator.h +++ b/src/common/allocator.h @@ -15,8 +15,8 @@ #pragma once #include #include -#include #include "environment.h" +#include #ifdef GPU #include @@ -33,11 +33,15 @@ static inline bool is_thp_alloc(size_t nbytes) { static inline void *alloc(size_t nbytes, size_t alignment = 64, void *device = nullptr) { if (nbytes == 0) { return nullptr; } - void *data; + void *data = nullptr; #ifdef GPU if (device != nullptr) { data = sycl::malloc_device(nbytes, *static_cast(device)); + if (data == nullptr) { + printf("Unable to allocate buffer with size of %zu in GPU.\n", nbytes); + exit(-1); + } return data; } #endif diff --git a/src/layers/layer_norm.cpp b/src/layers/layer_norm.cpp index c5428d37..332feb02 100644 --- a/src/layers/layer_norm.cpp +++ b/src/layers/layer_norm.cpp @@ -47,8 +47,8 @@ void LayerNorm::setWeight(const float *gamma, const float *beta, int cols) { this->normSize = cols; #ifdef GPU sycl::queue *gpu_queue = static_cast(device); - this->gamma = sycl::malloc_device(cols, *gpu_queue); - this->beta = sycl::malloc_device(cols, *gpu_queue); + this->gamma = (float *)xft::alloc(cols * sizeof(float), 64, *gpu_queue); + this->beta = (float *)xft::alloc(cols * sizeof(float), 64, *gpu_queue); gpu_queue->memcpy(this->gamma, gamma, cols * sizeof(float)).wait(); gpu_queue->memcpy(this->beta, beta, cols * sizeof(float)).wait(); #else From 2ef7f7aa7ff9aff178acb0dc5293027947894577 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Mon, 27 May 2024 12:31:25 +0800 Subject: [PATCH 06/34] Optimze alloc --- src/common/allocator.h | 19 ++++++++++++++++--- src/common/transformer_ctx.h | 4 ++-- src/layers/attention.h | 19 +++++++------------ src/layers/dist_linear.h | 5 ++--- src/layers/layer_norm.cpp | 16 ++++------------ src/layers/mlp_llama.h | 17 ++++++++--------- src/layers/rms_norm.cpp | 10 ++-------- src/layers/rotary_embedding.cpp | 9 ++++----- src/models/common_decoder.h | 2 +- src/utils/simple_mem_pool.h | 2 +- 10 files changed, 47 insertions(+), 56 deletions(-) diff --git a/src/common/allocator.h b/src/common/allocator.h index cb9dd281..99b31d3f 100644 --- a/src/common/allocator.h +++ b/src/common/allocator.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include #include "environment.h" #include @@ -30,14 +31,15 @@ static inline bool is_thp_alloc(size_t nbytes) { return (Env::getInstance().getTHPEnabled() && (nbytes >= g_thp_threshold)); } -static inline void *alloc(size_t nbytes, size_t alignment = 64, void *device = nullptr) { +static inline void *alloc(size_t nbytes, void *device 
= nullptr, size_t alignment = 64) { if (nbytes == 0) { return nullptr; } void *data = nullptr; #ifdef GPU if (device != nullptr) { - data = sycl::malloc_device(nbytes, *static_cast(device)); + sycl::queue *gpu_queue = static_cast(device); + data = sycl::malloc_device(nbytes, *gpu_queue); if (data == nullptr) { printf("Unable to allocate buffer with size of %zu in GPU.\n", nbytes); exit(-1); @@ -72,7 +74,18 @@ static inline void dealloc(void *data, void *device = nullptr) { #endif free(data); - return; +} + +static inline void memcopy(void *dst, const void *src, size_t size, void *device = nullptr) { +#ifdef GPU + if (device != nullptr) { + sycl::queue *gpu_queue = static_cast(device); + gpu_queue->memcpy(dst, src, size).wait(); + return; + } +#endif + + memcpy(dst, src, size); } } // namespace xft \ No newline at end of file diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index e955bf2f..91041554 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -110,8 +110,8 @@ struct DecoderContext { hpj::Matrix qkvMatMul; // query, key, value hpj::Matrix imOut; // intermediate output - MMHelper *mmHelper; - void *device; + MMHelper *mmHelper = nullptr; + void *device = nullptr; std::string configPath; INIReader configReader; diff --git a/src/layers/attention.h b/src/layers/attention.h index 7b23581e..ab3400f1 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -142,11 +142,9 @@ class Attention { qkvWeightT.Resize(hiddenSize, responsibleCols); ctx->mmHelper->transposeWeight(true, convertedqkvWeight, qkvWeightT); - sycl::queue *gpu_queue_1 = static_cast(ctx->device); - WeiT *qkvWeiData = sycl::malloc_device(hiddenSize * responsibleCols, *gpu_queue_1); + WeiT *qkvWeiData = xft::alloc(hiddenSize * responsibleCols * sizeof(WeiT), ctx->device); qkvWeight.Assign(qkvWeiData, hiddenSize, responsibleCols, responsibleCols); - gpu_queue_1->memcpy(qkvWeight.Data(), qkvWeightT.Data(), qkvWeightT.Rows() * qkvWeightT.Cols() * sizeof(WeiT)) - .wait(); + xft::memcopy(qkvWeight.Data(), qkvWeightT.Data(), hiddenSize * responsibleCols * sizeof(WeiT), ctx->device); #else qkvWeight.Resize(hiddenSize, responsibleCols); ctx->mmHelper->packWeight(trans, convertedqkvWeight, qkvWeight); @@ -187,11 +185,10 @@ class Attention { outWeightT.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize); ctx->mmHelper->transposeWeight(true, convertedOutWeight, outWeightT); - sycl::queue *gpu_queue_2 = static_cast(ctx->device); - WeiT *outWeiData = sycl::malloc_device(ctx->attHeadNum * ctx->attHeadSize * hiddenSize, *gpu_queue_2); + WeiT *outWeiData = xft::alloc(ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT), ctx->device); attnOutputWeight.Assign(outWeiData, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, hiddenSize); - int outWeightTSize = outWeightT.Rows() * outWeightT.Cols() * sizeof(WeiT); - gpu_queue_2->memcpy(attnOutputWeight.Data(), outWeightT.Data(), outWeightTSize).wait(); + int outWeightTSize = ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT); + xft::memcopy(attnOutputWeight.Data(), outWeightT.Data(), outWeightTSize, ctx->device); #else attnOutputWeight.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize); ctx->mmHelper->packWeight(trans, convertedOutWeight, attnOutputWeight); @@ -324,9 +321,8 @@ class Attention { } qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, posIds.data()); #ifdef GPU - sycl::queue *q = static_cast(ctx->device); int64_t size = ctx->batchSize * ctx->inputSeqLen * qkvCols * 
sizeof(float); - q->memcpy(qkvMatMul.Data(), query.Data(), size).wait(); // error: need CPU ptr and GPU ptr + xft::memcopy(qkvMatMul.Data(), query.Data(), size, ctx->device); // error: need CPU ptr and GPU ptr #endif } t3.release(); @@ -374,9 +370,8 @@ class Attention { #endif #ifdef GPU - sycl::queue *q = static_cast(ctx->device); int64_t size = ctx->batchSize * ctx->inputSeqLen * qkvCols * sizeof(float); - q->memcpy(qkvMatMul.Data(), attnSplit.Data(), size).wait(); // error: need CPU ptr and GPU ptr + xft::memcopy(qkvMatMul.Data(), attnSplit.Data(), size, ctx->device); // error: need CPU ptr and GPU ptr #endif TimeLine t5("Output"); diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h index f95dd791..e0a9563d 100644 --- a/src/layers/dist_linear.h +++ b/src/layers/dist_linear.h @@ -71,10 +71,9 @@ class DistLinear { tWeight.Resize(K, N); ctx->mmHelper->transposeWeight(true, quantizedWeight, tWeight); - sycl::queue *gpu_queue = static_cast(ctx->device); - WeiT *input_data = sycl::malloc_device(K * N, *gpu_queue); + WeiT *input_data = xft::alloc(K * N * sizeof(WeiT), ctx->device); weight.Assign(input_data, K, N, N); - gpu_queue->memcpy(weight.Data(), tWeight.Data(), tWeight.Rows() * tWeight.Cols() * sizeof(WeiT)).wait(); + xft::memcopy(weight.Data(), tWeight.Data(), tWeight.Rows() * tWeight.Cols() * sizeof(WeiT), ctx->device); #else weight.Resize(K, N); ctx->mmHelper->packWeight(true, quantizedWeight, weight); diff --git a/src/layers/layer_norm.cpp b/src/layers/layer_norm.cpp index 332feb02..761a8fbb 100644 --- a/src/layers/layer_norm.cpp +++ b/src/layers/layer_norm.cpp @@ -45,18 +45,10 @@ LayerNorm::~LayerNorm() { void LayerNorm::setWeight(const float *gamma, const float *beta, int cols) { this->normSize = cols; -#ifdef GPU - sycl::queue *gpu_queue = static_cast(device); - this->gamma = (float *)xft::alloc(cols * sizeof(float), 64, *gpu_queue); - this->beta = (float *)xft::alloc(cols * sizeof(float), 64, *gpu_queue); - gpu_queue->memcpy(this->gamma, gamma, cols * sizeof(float)).wait(); - gpu_queue->memcpy(this->beta, beta, cols * sizeof(float)).wait(); -#else - this->gamma = (float *)xft::alloc(cols * sizeof(float)); - this->beta = (float *)xft::alloc(cols * sizeof(float)); - memcpy(this->gamma, gamma, cols * sizeof(float)); - memcpy(this->beta, beta, cols * sizeof(float)); -#endif + this->gamma = (float *)xft::alloc(cols * sizeof(float), device); + this->beta = (float *)xft::alloc(cols * sizeof(float), device); + xft::memcopy(this->gamma, gamma, cols * sizeof(float), device); + xft::memcopy(this->beta, beta, cols * sizeof(float), device); } void LayerNorm::setWeight(const std::string &gammaPath, const std::string &betaPath, int cols) { diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 7b3ca792..5bc099a4 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -85,10 +85,9 @@ class LlamaMLP { catWeightsT.Resize(catWeiRows, catWeiCols); ctx->mmHelper->transposeWeight(true, quantizedCatWeights, catWeightsT); - sycl::queue *gpu_queue = static_cast(ctx->device); - WeiT *catWeiData = sycl::malloc_device(catWeiRows * catWeiCols, *gpu_queue); + WeiT *catWeiData = xft::alloc(catWeiRows * catWeiCols * sizeof(WeiT), ctx->device); catWeights.Assign(catWeiData, catWeiRows, catWeiCols, catWeiCols); - gpu_queue->memcpy(catWeights.Data(), catWeightsT.Data(), catWeiRows * catWeiCols * sizeof(WeiT)).wait(); + xft::memcopy(catWeights.Data(), catWeightsT.Data(), catWeiRows * catWeiCols * sizeof(WeiT), ctx->device); #else catWeights.Resize(quantizedCatWeights.Rows(), 
quantizedCatWeights.Cols()); ctx->mmHelper->packWeight(trans, quantizedCatWeights, catWeights); @@ -104,10 +103,9 @@ class LlamaMLP { downWeightT.Resize(downWeiRows, downWeiCols); ctx->mmHelper->transposeWeight(true, quantizedDownWeight, downWeightT); - sycl::queue *gpu_queue = static_cast(ctx->device); - WeiT *downWeiData = sycl::malloc_device(downWeiRows * downWeiCols, *gpu_queue); + WeiT *downWeiData = xft::alloc(downWeiRows * downWeiCols * sizeof(WeiT), ctx->device); downWeight.Assign(downWeiData, downWeiRows, downWeiCols, downWeiCols); - gpu_queue->memcpy(downWeight.Data(), downWeightT.Data(), downWeiRows * downWeiCols * sizeof(WeiT)).wait(); + xft::memcopy(downWeight.Data(), downWeightT.Data(), downWeiRows * downWeiCols * sizeof(WeiT), ctx->device); #else downWeight.Resize(it.second - it.first, hiddenSize); ctx->mmHelper->packWeight(trans, quantizedDownWeight, downWeight); @@ -296,8 +294,9 @@ class LlamaMLP { } } + template void catGateUpProj( - DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output, hpj::Matrix &siluBuf) { + DecoderContext *ctx, hpj::Matrix &input, hpj::Matrix &output, hpj::Matrix &siluBuf) { TimeLine t("catGateUpProj"); assert(input.Rows() == output.Rows()); @@ -307,12 +306,12 @@ class LlamaMLP { int M = input.Rows(), N = output.Cols(), K = input.Cols(); int lda = input.Stride(), ldc = output.Stride(); - const InT *A = input.Data(); + const T1 *A = input.Data(); const WeiT *B = catWeights.Data(); const float *scaleB = catWeightsScale.Data(); const float *zeroB = catWeightsZero.Data(); const float *sumB = catWeightsSum.Data(); - ImT *C = output.Data(); + T2 *C = output.Data(); ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc); diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index 43727e54..8feae8d2 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -46,14 +46,8 @@ RmsNorm::~RmsNorm() { void RmsNorm::setWeight(const float *w, const float *, int cols) { this->normSize = cols; -#ifdef GPU - sycl::queue *gpu_queue = static_cast(device); - this->weight = sycl::malloc_device(cols, *gpu_queue); - gpu_queue->memcpy(this->weight, w, cols * sizeof(float)).wait(); -#else - this->weight = (float *)xft::alloc(cols * sizeof(float)); - memcpy(weight, w, cols * sizeof(float)); -#endif + this->weight = (float *)xft::alloc(cols * sizeof(float), device); + xft::memcopy(this->weight, w, cols * sizeof(float), device); } void RmsNorm::setWeight(const std::string &modelPath, const std::string &, int cols) { diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index de5fd099..1ca8ffb8 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -44,13 +44,12 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { #ifdef GPU device = ctx->device; if (device != nullptr) { - sycl::queue *gpu_queue = static_cast(device); float *emb_cos_bak = emb_cos; float *emb_sin_bak = emb_sin; - emb_cos = ctx->getBuffer(emb_cos_str + "_gpu", max_position_embeddings * inv_freq_size, gpu_queue); - emb_sin = ctx->getBuffer(emb_sin_str + "_gpu", max_position_embeddings * inv_freq_size, gpu_queue); - gpu_queue->memcpy(emb_cos, emb_cos_bak, max_position_embeddings * inv_freq_size * sizeof(float)).wait(); - gpu_queue->memcpy(emb_sin, emb_sin_bak, max_position_embeddings * inv_freq_size * sizeof(float)).wait(); + emb_cos = ctx->getBuffer(emb_cos_str + "_gpu", max_position_embeddings * inv_freq_size, device); + emb_sin = ctx->getBuffer(emb_sin_str + "_gpu", 
max_position_embeddings * inv_freq_size, device); + xft::memcopy(emb_cos, emb_cos_bak, max_position_embeddings * inv_freq_size * sizeof(float), device); + xft::memcopy(emb_sin, emb_sin_bak, max_position_embeddings * inv_freq_size * sizeof(float), device); ctx->freeBuffer(emb_cos_str); ctx->freeBuffer(emb_sin_str); } diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 6396c774..c2909f15 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -667,7 +667,7 @@ class CommonDecoder : public AbstractDecoder { int kvSize = attHeadSize * kvHeadNum; int qkvSize = qSize + 2 * kvSize; -#define ALLOC(size, alignment) xft::alloc((size), (alignment)) +#define ALLOC(size, alignment) xft::alloc((size), nullptr, (alignment)) OriWeiT *qkvWeight = (OriWeiT *)ALLOC(hiddenSize * qkvSize * sizeof(OriWeiT), 64); float *qkvScales = nullptr; float *qkvZeros = nullptr; diff --git a/src/utils/simple_mem_pool.h b/src/utils/simple_mem_pool.h index 63f7a2fa..7ade4858 100644 --- a/src/utils/simple_mem_pool.h +++ b/src/utils/simple_mem_pool.h @@ -65,7 +65,7 @@ class SimpleMemPool { } // Allocate new aligned buffer - void *buffer = xft::alloc(size, alignment, device); + void *buffer = xft::alloc(size, device, alignment); if (buffer == nullptr) { // Allocation failed std::cerr << "Memory allocation failed for buffer:" << name << " size:" << size << std::endl; From 29f01c183735bf5cc2f2bae33c79513cfe70b2ee Mon Sep 17 00:00:00 2001 From: changqi1 Date: Mon, 27 May 2024 12:53:40 +0800 Subject: [PATCH 07/34] Use unified onednn engine --- src/models/common_decoder.h | 2 ++ src/utils/matmul_helper.h | 29 ++++++++++++++--------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index c2909f15..8096a9ce 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -648,6 +648,8 @@ class CommonDecoder : public AbstractDecoder { #ifdef GPU auto devices = sycl::device::get_devices(sycl::info::device_type::gpu); this->context->device = new sycl::queue(devices[this->context->mmHelper->getEngineCount() + engineIdx]); +#else + this->context->device = nullptr; #endif } diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index ca922036..82bdd9c5 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -53,6 +53,8 @@ class MMHelper { } AMXThresholdM = Env::getInstance().getAMXThresholdM(); + cpu_engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); + cpu_stream = new dnnl::stream(*cpu_engine); } ~MMHelper() { @@ -349,11 +351,9 @@ class MMHelper { // W8A8 else if constexpr (std::is_same_v) { using dt = dnnl::memory::data_type; - dnnl::engine eng(dnnl::engine::kind::cpu, 0); - dnnl::stream stm(eng); auto tag = trans ? 
dnnl::memory::format_tag::ba : dnnl::memory::format_tag::ab; - dnnl::memory B_mem({{K, N}, dt::s8, tag}, eng, src.Data()); + dnnl::memory B_mem({{K, N}, dt::s8, tag}, *cpu_engine, src.Data()); dnnl::memory::desc desc({K, N}, dt::s8, get_onednn_weight_layout(dt::s8)); // When converting to oneDNN blocked memory format, padded dims can be larger than [K, N] @@ -363,9 +363,9 @@ class MMHelper { weight.Resize(dims[0], dims[1]); weight.Resize(K, N); - dnnl::memory packedB_mem(desc, eng, weight.Data()); - dnnl::reorder(B_mem, packedB_mem).execute(stm, B_mem, packedB_mem); - stm.wait(); + dnnl::memory packedB_mem(desc, *cpu_engine, weight.Data()); + dnnl::reorder(B_mem, packedB_mem).execute(*cpu_stream, B_mem, packedB_mem); + cpu_stream->wait(); } // INT4 @@ -419,15 +419,12 @@ class MMHelper { int K = trans ? src.Cols() : src.Rows(); int N = trans ? src.Rows() : src.Cols(); - - dnnl::engine engine(dnnl::engine::kind::cpu, 0); - dnnl::stream stream(engine); auto weight_md = memory::desc({K, N}, weight_dt, trans ? tag::ba : tag::ab); - auto weight_mem = memory(weight_md, engine, src.Data()); + auto weight_mem = memory(weight_md, *cpu_engine, src.Data()); auto transposed_weight_md = memory::desc({K, N}, weight_dt, get_onednn_weight_layout(weight_dt)); - auto transposed_weight_mem = memory(transposed_weight_md, engine, dst.Data()); - dnnl::reorder(weight_mem, transposed_weight_mem).execute(stream, weight_mem, transposed_weight_mem); - stream.wait(); + auto transposed_weight_mem = memory(transposed_weight_md, *cpu_engine, dst.Data()); + dnnl::reorder(weight_mem, transposed_weight_mem).execute(*cpu_stream, weight_mem, transposed_weight_mem); + cpu_stream->wait(); } template @@ -1323,9 +1320,11 @@ class MMHelper { private: dnnl::engine::kind kind; - dnnl::engine *engine; - dnnl::stream *stream; + dnnl::engine *engine; // For runtime engine + dnnl::stream *stream; // For runtime stream std::unordered_map> matmul_hub; + dnnl::engine *cpu_engine; + dnnl::stream *cpu_stream; int AMXThresholdM; From 17f224e5ff1f446cb51b1b77179122807e569a26 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Mon, 27 May 2024 15:41:15 +0800 Subject: [PATCH 08/34] Fix compile --- src/layers/attention.h | 8 ++++---- src/layers/decoder_block.h | 2 +- src/layers/dist_linear.h | 2 +- src/layers/mlp_llama.cpp | 2 +- src/layers/mlp_llama.h | 4 ++-- src/models/llama.cpp | 2 +- src/utils/matmul_helper.h | 2 +- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/layers/attention.h b/src/layers/attention.h index c56e3a5d..d731ae4f 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -141,7 +141,7 @@ class Attention { convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum); #ifdef GPU - hpj::Matrix qkvWeightT; + xft::Matrix qkvWeightT; qkvWeightT.Resize(hiddenSize, responsibleCols); ctx->mmHelper->transposeWeight(true, convertedqkvWeight, qkvWeightT); @@ -184,7 +184,7 @@ class Attention { attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, true); #ifdef GPU - hpj::Matrix outWeightT; + xft::Matrix outWeightT; outWeightT.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize); ctx->mmHelper->transposeWeight(true, convertedOutWeight, outWeightT); @@ -423,7 +423,7 @@ class Attention { if (doLnAfter) { TimeLine t6("result.layer_norm"); - norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); + norm->forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); #ifdef DEBUG dbg.debugPrint("LayerNorm after 
attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(), outBuffer.Stride()); @@ -466,7 +466,7 @@ class Attention { if (doLnBefore) { TimeLine t1("input.layer_norm"); - norm.forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), + norm->forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), imBuffer.Stride(), epsilon); } #ifdef DEBUG diff --git a/src/layers/decoder_block.h b/src/layers/decoder_block.h index 4a18c13d..ba352066 100644 --- a/src/layers/decoder_block.h +++ b/src/layers/decoder_block.h @@ -147,7 +147,7 @@ class DecoderBlock { int kvSize = attHeadSize * kvHeadNum; int qkvSize = qSize + 2 * kvSize; -#define ALLOC(size, alignment) xft::alloc((size), (alignment)) +#define ALLOC(size, alignment) xft::alloc((size), nullptr, (alignment)) OriWeiT *qkvWeight = (OriWeiT *)ALLOC(hiddenSize * qkvSize * sizeof(OriWeiT), 64); float *qkvScales = nullptr; float *qkvZeros = nullptr; diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h index a6cd2d91..a52f0bbe 100644 --- a/src/layers/dist_linear.h +++ b/src/layers/dist_linear.h @@ -67,7 +67,7 @@ class DistLinear { ctx->mmHelper->convertWeight( true, K, N, w + splitOffset * K, nullptr, nullptr, quantizedWeight, scaleWeight, zeroWeight, sumWeight); #ifdef GPU - hpj::Matrix tWeight; + xft::Matrix tWeight; tWeight.Resize(K, N); ctx->mmHelper->transposeWeight(true, quantizedWeight, tWeight); diff --git a/src/layers/mlp_llama.cpp b/src/layers/mlp_llama.cpp index 816c5d69..749b39e0 100644 --- a/src/layers/mlp_llama.cpp +++ b/src/layers/mlp_llama.cpp @@ -58,7 +58,7 @@ void MLPLLaMAImpl(DataType dt, ActivationType at, int numTokens, int hiddenSize, auto it_created = llama_mlp_hub.find(llama_mlp_key); if (it_created == llama_mlp_hub.end()) { // MLP &llama_mlp = MLP::getInstance(); - llama_mlp = new MLP(); + llama_mlp = new MLP(ctx); llama_mlp->setWeights(ctx, (float *)gateWeight, nullptr, nullptr, nullptr, (float *)upWeight, nullptr, nullptr, nullptr, nullptr, nullptr, (float *)downWeight, nullptr, nullptr, false); llama_mlp_hub[llama_mlp_key] = llama_mlp; diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index a93bd2c7..094edb5a 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -79,7 +79,7 @@ class LlamaMLP { quantizedUpWeight.Release(); #ifdef GPU - hpj::Matrix catWeightsT; + xft::Matrix catWeightsT; int catWeiRows = quantizedCatWeights.Rows(); int catWeiCols = quantizedCatWeights.Cols(); catWeightsT.Resize(catWeiRows, catWeiCols); @@ -97,7 +97,7 @@ class LlamaMLP { ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false, quantizedDownWeight, downWeightScale, downWeightZero, downWeightSum); #ifdef GPU - hpj::Matrix downWeightT; + xft::Matrix downWeightT; int downWeiRows = it.second - it.first; int downWeiCols = hiddenSize; downWeightT.Resize(downWeiRows, downWeiCols); diff --git a/src/models/llama.cpp b/src/models/llama.cpp index 336d5bb6..8c2ba400 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -132,7 +132,7 @@ void LlamaLLM::lastLayerNormForward(bfloat16_t *input, bfloat16_ template void LlamaLLM::lastLayerNormForward(float16_t *input, float16_t *output, int rows) { - finalLN.forward(input, output, rows); + finalLN->forward(input, output, rows); } IMPLEMENT_MODEL(LlamaLLM, llama) \ No newline at end of file diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 01030c75..aa03b0dd 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -400,7 +400,7 @@ class 
MMHelper { } template - void transposeWeight(bool trans, hpj::Matrix &src, hpj::Matrix &dst) { + void transposeWeight(bool trans, xft::Matrix &src, xft::Matrix &dst) { using namespace dnnl; using tag = memory::format_tag; using dt = memory::data_type; From b61fe526e1eaedf5e2043212f65e723cf0c2e3cd Mon Sep 17 00:00:00 2001 From: changqi1 Date: Mon, 27 May 2024 18:20:41 +0800 Subject: [PATCH 09/34] Fix onednn gemm issue --- src/utils/matmul_helper.h | 131 +++++++++++++++++++++++--------------- 1 file changed, 80 insertions(+), 51 deletions(-) diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index aa03b0dd..3690b869 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -1339,10 +1339,11 @@ class MMHelper { Resext, }; - std::string create_key(bool transA, int M, int N, int K, int matmul_kind) { - std::string key = std::to_string(transA) + "_" + std::to_string(M) + "_" + std::to_string(N) + "_" - + std::to_string(K) + "_" + std::to_string(matmul_kind); - return key; + template + std::string create_key(bool transA, int M, int N, int K, int matmul_kind, const Twei *packedB) { + std::stringstream key; + key << transA << "_" << M << "_" << N << "_" << K << "_" << matmul_kind << "_" << packedB; + return key.str(); } dnnl::memory::format_tag get_onednn_input_layout(dnnl::memory::data_type dt) { @@ -1359,8 +1360,10 @@ class MMHelper { dnnl::memory::format_tag get_onednn_weight_layout(dnnl::memory::data_type dt) { if (this->kind == dnnl::engine::kind::cpu) { - if (dt == dnnl::memory::data_type::bf16 || dt == dnnl::memory::data_type::f16) { + if (dt == dnnl::memory::data_type::bf16) { return dnnl::memory::format_tag::BA16a64b2a; + } else if (dt == dnnl::memory::data_type::f16) { + return dnnl::memory::format_tag::BA16a64b; } else if (dt == dnnl::memory::data_type::s8) { return dnnl::memory::format_tag::BA16a64b4a; } else { @@ -1410,9 +1413,26 @@ class MMHelper { } } - template + // Tin | Twei | Tout | Tbias | matmul + // --- | ---- | ---- | ----- | ------ + // f32 | f32 | f32 | f32 | sgemm + // f32 | f32 | f16 | f32 | sgemm_f32f32f16 + // f32 | f32 | bf16 | f32 | sgemm_f32f32bf16 + // f16 | f32 | f32 | f32 | sgemm_f16f32f32 + // bf16| f32 | f32 | f32 | sgemm_bf16f32f32 + // f16 | f32 | f16 | f32 | sgemm_f16f32f16 + // bf16| f32 | bf16 | f32 | sgemm_bf16f32bf16 + // f32 | f16 | f32 | f32 | hgemm_f32f16f32 + // f32 | f16 | f16 | f32 | hgemm_f32f16f16 + // f16 | f16 | f32 | f32 | hgemm_f16f16f32 + // f16 | f16 | f16 | f32 | hgemm + // f32 | bf16 | f32 | f32 | bgemm_f32bf16f32 + // f32 | bf16 | bf16 | f32 | bgemm_f32bf16bf16 + // bf16| bf16 | f32 | f32 | bgemm_bf16bf16f32 + // bf16| bf16 | bf16 | f32 | bgemm + template void onednn_gemm_compute(bool transA, int M, int N, int K, float alpha, const Tin *A, int lda, const Twei *packedB, - float beta, Tout *C, int ldc, const Tbias *bias = nullptr, const Tres *res = nullptr, int ldres = -1, + float beta, Tout *C, int ldc, const Tbias *bias = nullptr, const Tin *res = nullptr, int ldres = -1, const matmul_kinds postAlg = matmul_kinds::Basic) { TimeLine t("onednn_gemm_compute"); TimeLine t1("onednn_gemm_compute.create_primitive"); @@ -1421,26 +1441,22 @@ class MMHelper { using dt = memory::data_type; dt input_dt; - if constexpr (std::is_same_v) { - input_dt = dt::f32; - } else if constexpr (std::is_same_v) { - input_dt = dt::bf16; - } else if constexpr (std::is_same_v) { - input_dt = dt::f16; - } else { - printf(">>> onednn_gemm_compute: input date type not supported."); - exit(-1); - } - dt weight_dt; + dt shift_dt; if 
constexpr (std::is_same_v) { + input_dt = dt::f32; weight_dt = dt::f32; + shift_dt = dt::f32; } else if constexpr (std::is_same_v) { + input_dt = dt::bf16; weight_dt = dt::bf16; + shift_dt = dt::bf16; } else if constexpr (std::is_same_v) { + input_dt = dt::f16; weight_dt = dt::f16; + shift_dt = dt::f16; } else { - printf(">>> onednn_gemm_compute: weight date type not supported."); + printf(">>> onednn_gemm_compute: input and weight date type not supported."); exit(-1); } @@ -1468,21 +1484,9 @@ class MMHelper { exit(-1); } - dt shift_dt; - if constexpr (std::is_same_v) { - shift_dt = dt::f32; - } else if constexpr (std::is_same_v) { - shift_dt = dt::bf16; - } else if constexpr (std::is_same_v) { - shift_dt = dt::f16; - } else { - printf(">>> onednn_gemm_compute: res date type not supported."); - exit(-1); - } - matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, postAlg); + std::string key = create_key(transA, M, N, K, postAlg, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1557,7 +1561,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, postAlg); + std::string key = create_key(transA, M, N, K, postAlg, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1570,12 +1574,21 @@ class MMHelper { input_mem = memory(matmul_pd->src_desc(), *engine, const_cast(A)); } - auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast(packedB)); - auto output_mem = memory(matmul_pd->dst_desc(), *engine, C); + memory weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast(packedB)); + memory output_mem = memory(matmul_pd->dst_desc(), *engine, C); memory bias_mem; if (bias != nullptr) { bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast(bias)); } - auto shift_md = memory::desc({M, N}, shift_dt, get_onednn_shift_layout(shift_dt)); - auto shift_mem = memory(shift_md, *engine, const_cast(res)); + + memory::desc shift_md; + memory shift_mem; + if (res != nullptr) { + shift_md = memory::desc({M, N}, shift_dt, get_onednn_shift_layout(shift_dt)); + if constexpr (std::is_same_v) { + shift_mem = memory(shift_md, *engine); + } else { + shift_mem = memory(shift_md, *engine, const_cast(res)); + } + } // Create the primitive args. std::unordered_map matmul_args; @@ -1589,10 +1602,26 @@ class MMHelper { // Executions. 
TimeLine t2("onednn_gemm_compute.execute_primitive"); // Reorder - if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (std::is_same_v && !std::is_same_v) { #pragma omp parallel for for (uint64_t i = 0; i < M; ++i) { - bfloat16_t::cvt_float_to_bfloat16(A + i * lda, (bfloat16_t *)input_mem.get_data_handle() + i * K, K); + void *input_ptr = input_mem.get_data_handle(); + if constexpr (std::is_same_v) { + bfloat16_t::cvt_float_to_bfloat16(A + i * lda, (bfloat16_t *)input_ptr + i * K, K); + if (res != nullptr) { + void *shift_ptr = shift_mem.get_data_handle(); + bfloat16_t::cvt_float_to_bfloat16(res + i * lda, (bfloat16_t *)shift_ptr + i * K, K); + } + } else if constexpr (std::is_same_v) { + float16_t::cvt_float_to_float16(A + i * lda, (float16_t *)input_ptr + i * K, K); + if (res != nullptr) { + void *shift_ptr = shift_mem.get_data_handle(); + float16_t::cvt_float_to_float16(res + i * lda, (float16_t *)shift_ptr + i * K, K); + } + } else { + printf(">>> onednn_gemm_compute: input and res date type convert not supported."); + exit(-1); + } } } @@ -1611,7 +1640,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, postAlg); + std::string key = create_key(transA, M, N, K, postAlg, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1667,7 +1696,7 @@ class MMHelper { } matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, postAlg); + std::string key = create_key(transA, M, N, K, postAlg, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1717,7 +1746,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1747,7 +1776,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1799,7 +1828,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1837,7 +1866,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1889,7 +1918,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul); + std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1928,7 +1957,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key 
= create_key(transA, M, N, K, matmul_kinds::Resmul); + std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1996,7 +2025,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::Residential); + std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -2042,7 +2071,7 @@ class MMHelper { } // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::Residential); + std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -2101,7 +2130,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::Basic); + std::string key = create_key(transA, M, N, K, matmul_kinds::Basic, B); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -2123,7 +2152,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::Basic); + std::string key = create_key(transA, M, N, K, matmul_kinds::Basic, B); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } From c5b7ac711478cc3c25eec712ac9df7cd0a883587 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Mon, 3 Jun 2024 09:29:34 +0800 Subject: [PATCH 10/34] Fix build --- src/utils/matmul_helper.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 3690b869..0e481c20 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -20,6 +20,7 @@ #include "dtype.h" #include "environment.h" #include "float16.h" +#include "intrinsics_util.h" #include "my_types.h" #include "normal_float4x2.h" #include "oneapi/dnnl/dnnl.hpp" From 0faa9deb06bfa2f9765f5c0db20b46029657ab56 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Tue, 4 Jun 2024 14:24:52 +0800 Subject: [PATCH 11/34] Add fp16 rope kernels --- src/kernels/rotary_embedding_kernels.cpp | 79 ++++++++++++++++++++++++ src/kernels/rotary_embedding_kernels.h | 14 ++++- src/layers/rms_norm.cpp | 3 + src/layers/rotary_embedding.cpp | 68 +++++--------------- src/layers/token_embedding.h | 2 + src/utils/verbose.h | 29 ++++++++- 6 files changed, 139 insertions(+), 56 deletions(-) diff --git a/src/kernels/rotary_embedding_kernels.cpp b/src/kernels/rotary_embedding_kernels.cpp index 99f52b08..91361a35 100644 --- a/src/kernels/rotary_embedding_kernels.cpp +++ b/src/kernels/rotary_embedding_kernels.cpp @@ -386,4 +386,83 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i maxSupportedSeqLength, qkShape, positionIds); } +#ifdef GPU +// For LLaMA +template +static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, int qStride, int kStride, float *emb_cos, + float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) { + int dim = inv_freq_size * 2; + REQUIRES(dim == qkShape[3], "Incorrect shape, this dimention is not the head size."); + + const int batchSize = qkShape[0]; + const int seqLen = qkShape[1]; + const int qHeads = qkShape[2]; + const int kHeads = qkShape[4]; + const int head_num = std::max(qHeads, kHeads); + const int head_size = 
qkShape[3]; + const int half_head_size = (head_size + 1) / 2; + using namespace sycl; + + auto rope_kernel + = [](sycl::nd_item<3> &item, const float *embCos, const float *embSin, const int qHeads, const int kHeads, + const int seq_size, const int head_size, const int half, T *query, T *key, int qStride, + int kStride, const sycl::accessor &positionIds) { + size_t idx_bs_seq = item.get_global_id(0); + size_t idx_head_num = item.get_global_id(1); + size_t idx_half_head_dim = item.get_global_id(2); + + size_t pos = positionIds[idx_bs_seq % seq_size]; + float cos = embCos[pos * half + idx_half_head_dim]; + float sin = embSin[pos * half + idx_half_head_dim]; + + T *q = query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim; + T *k = key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim; + + if (idx_head_num < qHeads) { + auto q1 = q[0]; + q[0] = q1 * cos - q[half] * sin; + q[half] = q[half] * cos + q1 * sin; + } + if (idx_head_num < kHeads) { + auto k1 = k[0]; + k[0] = k1 * cos - k[half] * sin; + k[half] = k[half] * cos + k1 * sin; + } + }; + + // Reorder input + sycl::queue *gpu_queue = static_cast(device); + sycl::buffer positionIdsBuf(positionIds, sycl::range<1>(seqLen)); + gpu_queue->submit([&](sycl::handler &cgh) { + sycl::accessor position(positionIdsBuf, cgh, sycl::read_only); + sycl::range<3> globalSize(batchSize * seqLen, head_num, half_head_size); + sycl::range<3> workGroupSize(1, 1, 1); + + cgh.parallel_for(sycl::nd_range(globalSize, workGroupSize), [=, this](sycl::nd_item<3> item) { + rope_kernel(item, emb_cos, emb_sin, qHeads, kHeads, seqLen, head_size, half_head_size, query, key, qStride, + kStride, position); + }); + }); + gpu_queue->wait(); +} + +void llamaApplyRotaryPosEmbeding(void *device, float *query, float *key, int qStride, int kStride, float *emb_cos, + float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) { + llamaApplyRotaryPosEmbeding( + device, query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); +} + +void llamaApplyRotaryPosEmbeding(void *device, bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, + float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) { + llamaApplyRotaryPosEmbeding( + device, query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); +} + +void llamaApplyRotaryPosEmbeding(void *device, float16_t *query, float16_t *key, int qStride, int kStride, + float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) { + llamaApplyRotaryPosEmbeding( + device, query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); +} +#endif + } // namespace xft diff --git a/src/kernels/rotary_embedding_kernels.h b/src/kernels/rotary_embedding_kernels.h index bf150351..3348e968 100644 --- a/src/kernels/rotary_embedding_kernels.h +++ b/src/kernels/rotary_embedding_kernels.h @@ -32,7 +32,7 @@ void llamaApplyRotaryPosEmbeding(bfloat16_t *query, bfloat16_t *key, int qStride void llamaApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds); -// For continous batching +// For LLaMA continous batching void llamaApplyRotaryPosEmbed(float *query, float *key, float *embCos, float *embSin, int qStride, int kStride, int dim, int totSeqLen, int qHeads, int kHeads, const int *positionIds); @@ -65,4 +65,16 @@ void 
qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape, const int *positionIds); +#ifdef GPU +// For LLaMA +void llamaApplyRotaryPosEmbeding(void *device, float *query, float *key, int qStride, int kStride, float *emb_cos, + float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds); + +void llamaApplyRotaryPosEmbeding(void *device, bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, + float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds); + +void llamaApplyRotaryPosEmbeding(void *device, float16_t *query, float16_t *key, int qStride, int kStride, + float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds); +#endif + } // namespace xft diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index dd097b96..a15bc2ca 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -69,6 +69,9 @@ void RmsNorm::forward(const float *input, bfloat16_t *output, int rows, int iStr void RmsNorm::forward(const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); + sycl::queue *gpu_queue = static_cast(device); + fastertransformer::invokeGeneralT5LayerNorm( + output, input, weight, (const bfloat16_t *)nullptr, epsilon, rows, iStride, gpu_queue); } void RmsNorm::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index 320e585f..47a9809d 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -107,60 +107,20 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(const int dim, const int max_position void LlamaRotaryEmbedding::forward( float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { - const int batchSize = qkShape[0]; - const int seqLen = qkShape[1]; - const int qHeads = qkShape[2]; - const int kHeads = qkShape[4]; - const int head_num = std::max(qHeads, kHeads); - const int head_size = qkShape[3]; - const int half_head_size = (head_size + 1) / 2; - using namespace sycl; - - auto rope_kernel - = [](sycl::nd_item<3> &item, const float *embCos, const float *embSin, const int qHeads, const int kHeads, - const int seq_size, const int head_size, const int half, float *query, float *key, int qStride, - int kStride, const sycl::accessor &positionIds) { - size_t idx_bs_seq = item.get_global_id(0); - size_t idx_head_num = item.get_global_id(1); - size_t idx_half_head_dim = item.get_global_id(2); - - size_t pos = positionIds[idx_bs_seq % seq_size]; - float cos = embCos[pos * half + idx_half_head_dim]; - float sin = embSin[pos * half + idx_half_head_dim]; - - float *q = query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim; - float *k = key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim; - - if (idx_head_num < qHeads) { - auto q1 = q[0]; - q[0] = q1 * cos - q[half] * sin; - q[half] = q[half] * cos + q1 * sin; - } - if (idx_head_num < kHeads) { - auto k1 = k[0]; - k[0] = k1 * cos - k[half] * sin; - k[half] = k[half] * cos + k1 * sin; - } - }; - - // Reorder input - sycl::queue *gpu_queue = static_cast(device); - float *embCos = emb_cos; - float *embSin = emb_sin; - - sycl::buffer positionIdsBuf(positionIds, sycl::range<1>(seqLen)); - 
gpu_queue->submit([&](sycl::handler &cgh) { - sycl::accessor position(positionIdsBuf, cgh, sycl::read_only); - sycl::range<3> globalSize(batchSize * seqLen, head_num, half_head_size); - sycl::range<3> workGroupSize(1, 1, 1); - - cgh.parallel_for( - sycl::nd_range(globalSize, workGroupSize), [=, this](sycl::nd_item<3> item) { - rope_kernel(item, embCos, embSin, qHeads, kHeads, seqLen, head_size, half_head_size, query, key, - qStride, kStride, position); - }); - }); - gpu_queue->wait(); + xft::llamaApplyRotaryPosEmbeding(this->device, + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); +} + +void LlamaRotaryEmbedding::forward( + bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { + xft::llamaApplyRotaryPosEmbeding(this->device, + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); +} + +void LlamaRotaryEmbedding::forward( + float16_t *query, float16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { + xft::llamaApplyRotaryPosEmbeding(this->device, + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } #else diff --git a/src/layers/token_embedding.h b/src/layers/token_embedding.h index f49a135e..c49dd597 100644 --- a/src/layers/token_embedding.h +++ b/src/layers/token_embedding.h @@ -24,6 +24,7 @@ class TokenEmbedding { TokenEmbedding(DecoderContext *ctx) { this->vocabSize = ctx->vocabSize; this->hiddenSize = ctx->hiddenSize; + this->device = ctx->device; } void setWeights(float *tokenEmb) { @@ -59,4 +60,5 @@ class TokenEmbedding { int hiddenSize; T *embTable = nullptr; + void *device = nullptr; }; diff --git a/src/utils/verbose.h b/src/utils/verbose.h index 9c207e37..f47fa1fb 100644 --- a/src/utils/verbose.h +++ b/src/utils/verbose.h @@ -41,11 +41,38 @@ class Printer { printf("xft_verbose,exec,cpu,api,%s,m%dn%dk%d,%.6lf\n", api_func, M, N, K, ms); fflush(stdout); } + static void matrix(int rows, int cols, int stride, size_t totalmem) { printf("xft_verbose,matrix:rows%d_cols%d_stride%d,use:%zu bytes of memory\n", rows, cols, stride, totalmem); fflush(stdout); } -}; + + template + static void print(std::string buf_name, T *buf, int rows, int cols, int stride, bool printAll = false, void *device = nullptr) { + std::cout << buf_name.c_str() << ":" << std::endl; +#ifdef GPU + if (device != nullptr) { + sycl::queue *gpu_queue = static_cast(device); + gpu_queue + ->submit([&](sycl::handler &cgh) { + auto out = sycl::stream(1024, 768, cgh); + cgh.parallel_for(sycl::nd_range<1>(1, 1), [=](sycl::nd_item<1> item) { + int idx_col = item.get_global_id(0); + if (idx_col == 0) { + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + out << (float)buf[row * stride + col] << ", "; + } + out << sycl::endl; + } + out << sycl::endl; + } + }); + }) + .wait(); +#endif + } + }; #define GEMMVERBOSE(api_func, compute_func) \ if (Env::getInstance().getVerbose() >= 1) { \ From 3cfa6501b6232adc9d57b532ed8cb75a8a6b0e8f Mon Sep 17 00:00:00 2001 From: changqi1 Date: Tue, 4 Jun 2024 14:56:22 +0800 Subject: [PATCH 12/34] Fix attention UT issue. 
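The value-projection pointer passed to compareAttentionLLaMA was offset by kvSize only, so it overlapped the key block. Assuming the packed test buffer is laid out as [ query | key | value ] and that qSize and kvSize are the element counts of the query and key blocks (their definitions are earlier in the test and not shown here), the value pointer has to skip both preceding blocks. A minimal sketch of the offset arithmetic the fixed call sites rely on, for illustration only:

    // Offsets into the packed [query | key | value] projection buffer (assumed layout).
    const float *qWeight = qkvProj;                  // query block starts at the beginning
    const float *kWeight = qkvProj + qSize;          // key block follows the query block
    const float *vWeight = qkvProj + qSize + kvSize; // value block follows the key block (the fix below)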
--- tests/ut/layers_attention_test.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/ut/layers_attention_test.cpp b/tests/ut/layers_attention_test.cpp index 942393c7..134020f3 100644 --- a/tests/ut/layers_attention_test.cpp +++ b/tests/ut/layers_attention_test.cpp @@ -87,14 +87,17 @@ void test_AttentionLLaMA(void) { int nextTokenNum = 1; compareAttentionLLaMA(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, - kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + kvSize, oProj); + kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + qSize + kvSize, + oProj); pastSeqLen += inputSeqLen; currentSeqLen = nextTokenNum; compareAttentionLLaMA(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, - kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + kvSize, oProj); + kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + qSize + kvSize, + oProj); pastSeqLen += nextTokenNum; compareAttentionLLaMA(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, - kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + kvSize, oProj); + kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + qSize + kvSize, + oProj); free(qkvProj); free(oProj); From 881bc78d0ceff8c42f94b6f2d2703607ee3bc85c Mon Sep 17 00:00:00 2001 From: changqi1 Date: Tue, 4 Jun 2024 17:51:32 +0000 Subject: [PATCH 13/34] Fix ICX build issue. --- src/common/sequence.h | 71 ++++++++++++++++++++++++ src/kernels/rotary_embedding_kernels.cpp | 2 +- src/layers/attention.h | 4 +- src/layers/dist_linear.h | 2 +- src/layers/mlp_llama.h | 4 +- src/layers/rms_norm.cpp | 8 +-- src/layers/rotary_embedding.cpp | 5 +- src/models/common_decoder.h | 18 +++--- src/utils/compile_util.h | 4 ++ src/utils/transpose_util.h | 2 +- src/utils/verbose.h | 6 +- tests/ut/attention_kernels_test.cpp | 18 +++--- tests/ut/kv_reorder_test.cpp | 2 +- tests/ut/rotary_embedding_test.cpp | 4 +- 14 files changed, 114 insertions(+), 36 deletions(-) diff --git a/src/common/sequence.h b/src/common/sequence.h index 211b69ed..26c99310 100644 --- a/src/common/sequence.h +++ b/src/common/sequence.h @@ -19,6 +19,7 @@ #include #include +#include "allocator.h" #include "environment.h" #include "sampling_params.h" @@ -81,6 +82,20 @@ class SequenceMeta { , promptTokens(_inputSeqLen, 0) , step(0) {} + SequenceMeta(int32_t _sequenceID, std::vector &_promptTokens) + : sequenceID(_sequenceID) + , inputSeqLen(_promptTokens.size()) + , pastSeqLen(0) + , promptTokens(_promptTokens) + , step(0) {} + + SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen) + : sequenceID(_sequenceID) + , inputSeqLen(_inputSeqLen) + , pastSeqLen(0) + , promptTokens(_inputSeqLen, 0) + , step(0) {} + ~SequenceMeta() {} int32_t getSequenceID() const { return sequenceID; } @@ -207,6 +222,38 @@ class SequenceGroupMeta { groupID = sequences[0].getSequenceID(); } + SequenceGroupMeta(int32_t _sequenceID, std::vector &_inputTokens, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) { + sequences.reserve(samplingMeta.config.numBeams); + for (int i = 0; i < samplingMeta.config.numBeams; ++i) { + sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens)); + } + groupID = sequences[0].getSequenceID(); + } + + SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen, SamplingMeta &samplingMeta_) : 
samplingMeta(samplingMeta_) { + sequences.reserve(samplingMeta.config.numBeams); + for (int i = 0; i < samplingMeta.config.numBeams; ++i) { + sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen)); + } + groupID = sequences[0].getSequenceID(); + } + + SequenceGroupMeta(int32_t _sequenceID, std::vector &_inputTokens) { + sequences.reserve(samplingMeta.config.numBeams); + for (int i = 0; i < samplingMeta.config.numBeams; ++i) { + sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens)); + } + groupID = sequences[0].getSequenceID(); + } + + SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen) { + sequences.reserve(samplingMeta.config.numBeams); + for (int i = 0; i < samplingMeta.config.numBeams; ++i) { + sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen)); + } + groupID = sequences[0].getSequenceID(); + } + int32_t getGroupID() { return groupID; } int32_t getGroupSize() { return samplingMeta.config.numBeams; } @@ -272,6 +319,30 @@ class SequencePool { return group; } + SequenceGroupMeta *newGroupMeta(int32_t sequenceID, std::vector &inputTokens, SamplingMeta &samplingMeta_) { + auto *group = new SequenceGroupMeta(sequenceID, inputTokens, samplingMeta_); + this->add(group); + return group; + } + + SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen, SamplingMeta &samplingMeta_) { + auto *group = new SequenceGroupMeta(sequenceID, inputSeqLen, samplingMeta_); + this->add(group); + return group; + } + + SequenceGroupMeta *newGroupMeta(int32_t sequenceID, std::vector &inputTokens) { + auto *group = new SequenceGroupMeta(sequenceID, inputTokens); + this->add(group); + return group; + } + + SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen) { + auto *group = new SequenceGroupMeta(sequenceID, inputSeqLen); + this->add(group); + return group; + } + bool add(SequenceGroupMeta *sequenceGroup, bool force = false) { int32_t groupID = sequenceGroup->getGroupID(); bool isSuccess = false; diff --git a/src/kernels/rotary_embedding_kernels.cpp b/src/kernels/rotary_embedding_kernels.cpp index 91361a35..983a3251 100644 --- a/src/kernels/rotary_embedding_kernels.cpp +++ b/src/kernels/rotary_embedding_kernels.cpp @@ -438,7 +438,7 @@ static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, i sycl::range<3> globalSize(batchSize * seqLen, head_num, half_head_size); sycl::range<3> workGroupSize(1, 1, 1); - cgh.parallel_for(sycl::nd_range(globalSize, workGroupSize), [=, this](sycl::nd_item<3> item) { + cgh.parallel_for(sycl::nd_range(globalSize, workGroupSize), [=](sycl::nd_item<3> item) { rope_kernel(item, emb_cos, emb_sin, qHeads, kHeads, seqLen, head_size, half_head_size, query, key, qStride, kStride, position); }); diff --git a/src/layers/attention.h b/src/layers/attention.h index d731ae4f..a51efaf5 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -145,7 +145,7 @@ class Attention { qkvWeightT.Resize(hiddenSize, responsibleCols); ctx->mmHelper->transposeWeight(true, convertedqkvWeight, qkvWeightT); - WeiT *qkvWeiData = xft::alloc(hiddenSize * responsibleCols * sizeof(WeiT), ctx->device); + WeiT *qkvWeiData = (WeiT *)xft::alloc(hiddenSize * responsibleCols * sizeof(WeiT), ctx->device); qkvWeight.Assign(qkvWeiData, hiddenSize, responsibleCols, responsibleCols); xft::memcopy(qkvWeight.Data(), qkvWeightT.Data(), hiddenSize * responsibleCols * sizeof(WeiT), ctx->device); #else @@ -188,7 +188,7 @@ class Attention { outWeightT.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize); 
ctx->mmHelper->transposeWeight(true, convertedOutWeight, outWeightT); - WeiT *outWeiData = xft::alloc(ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT), ctx->device); + WeiT *outWeiData = (WeiT *)xft::alloc(ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT), ctx->device); attnOutputWeight.Assign(outWeiData, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, hiddenSize); int outWeightTSize = ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT); xft::memcopy(attnOutputWeight.Data(), outWeightT.Data(), outWeightTSize, ctx->device); diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h index a52f0bbe..1eefb3af 100644 --- a/src/layers/dist_linear.h +++ b/src/layers/dist_linear.h @@ -71,7 +71,7 @@ class DistLinear { tWeight.Resize(K, N); ctx->mmHelper->transposeWeight(true, quantizedWeight, tWeight); - WeiT *input_data = xft::alloc(K * N * sizeof(WeiT), ctx->device); + WeiT *input_data = (WeiT *)xft::alloc(K * N * sizeof(WeiT), ctx->device); weight.Assign(input_data, K, N, N); xft::memcopy(weight.Data(), tWeight.Data(), tWeight.Rows() * tWeight.Cols() * sizeof(WeiT), ctx->device); #else diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 094edb5a..2780479f 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -85,7 +85,7 @@ class LlamaMLP { catWeightsT.Resize(catWeiRows, catWeiCols); ctx->mmHelper->transposeWeight(true, quantizedCatWeights, catWeightsT); - WeiT *catWeiData = xft::alloc(catWeiRows * catWeiCols * sizeof(WeiT), ctx->device); + WeiT *catWeiData = (WeiT *)xft::alloc(catWeiRows * catWeiCols * sizeof(WeiT), ctx->device); catWeights.Assign(catWeiData, catWeiRows, catWeiCols, catWeiCols); xft::memcopy(catWeights.Data(), catWeightsT.Data(), catWeiRows * catWeiCols * sizeof(WeiT), ctx->device); #else @@ -103,7 +103,7 @@ class LlamaMLP { downWeightT.Resize(downWeiRows, downWeiCols); ctx->mmHelper->transposeWeight(true, quantizedDownWeight, downWeightT); - WeiT *downWeiData = xft::alloc(downWeiRows * downWeiCols * sizeof(WeiT), ctx->device); + WeiT *downWeiData = (WeiT *)xft::alloc(downWeiRows * downWeiCols * sizeof(WeiT), ctx->device); downWeight.Assign(downWeiData, downWeiRows, downWeiCols, downWeiCols); xft::memcopy(downWeight.Data(), downWeightT.Data(), downWeiRows * downWeiCols * sizeof(WeiT), ctx->device); #else diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index a15bc2ca..2df9bd0a 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -70,8 +70,8 @@ void RmsNorm::forward(const float *input, bfloat16_t *output, int rows, int iStr void RmsNorm::forward(const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); sycl::queue *gpu_queue = static_cast(device); - fastertransformer::invokeGeneralT5LayerNorm( - output, input, weight, (const bfloat16_t *)nullptr, epsilon, rows, iStride, gpu_queue); + // fastertransformer::invokeGeneralT5LayerNorm( + // output, input, weight, (const bfloat16_t *)nullptr, epsilon, rows, iStride, gpu_queue); } void RmsNorm::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { @@ -81,8 +81,8 @@ void RmsNorm::forward(const float *input, float16_t *output, int rows, int iStri void RmsNorm::forward(const float16_t *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); sycl::queue *gpu_queue = static_cast(device); - fastertransformer::invokeGeneralT5LayerNorm( - output, input, weight, (const 
float16_t *)nullptr, epsilon, rows, iStride, gpu_queue); + // fastertransformer::invokeGeneralT5LayerNorm( + // output, input, weight, (const float16_t *)nullptr, epsilon, rows, iStride, gpu_queue); } #else diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index 47a9809d..12207530 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -183,6 +183,8 @@ void LlamaRotaryEmbedding::forward( query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } +#endif // GPU + // For continuous batching void LlamaRotaryEmbedding::forward( float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { @@ -200,5 +202,4 @@ void LlamaRotaryEmbedding::forward(float16_t *query, float16_t *key, int totSeqL int qHeads, int kHeads, int *positionIds) { xft::llamaApplyRotaryPosEmbed( query, key, emb_cos, emb_sin, qStride, kStride, this->dim, totSeqLen, qHeads, kHeads, positionIds); -} -#endif // GPU \ No newline at end of file +} \ No newline at end of file diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index c0a5e6b9..e44f9b23 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -342,21 +342,21 @@ class CommonDecoder : public AbstractDecoder { // TODO: Error: different scope when dynamic loading so file // this->messenger.worldRecvFP32(embBuf, count, prev_world_rank, curr_world_rank); if (!SequencePool::getInstance().has(sequenceID)) { - auto *seqs = SequencePool::getInstance().newMeta(sequenceID, seqLen); - seqs->get(0)->setPastSeqLen(pastSeqLen); - seqs->get(0)->allocBuffer(hiddenSize, embBuf); - SequencePool::getInstance().add(seqs->get(0)->getSequenceID(), seqs); + auto *groupMeta = SequencePool::getInstance().newGroupMeta(sequenceID, seqLen); + groupMeta->get(0)->setPastSeqLen(pastSeqLen); + groupMeta->get(0)->allocBuffer(hiddenSize, embBuf); + SequencePool::getInstance().add(groupMeta); } TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(sequenceID)); } if (!InputQueue::getInstance().empty()) { if (!TaskWaitingQueue::getInstance().isFull()) { - auto *seqs = InputQueue::getInstance().pop(); - seqs->get(0)->setPastSeqLen(pastSeqLen); - seqs->get(0)->allocBuffer(hiddenSize, embBuf); - SequencePool::getInstance().add(seqs->get(0)->getSequenceID(), seqs); - TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(seqs->get(0)->getSequenceID())); + auto *groupMeta = InputQueue::getInstance().pop(); + groupMeta->get(0)->setPastSeqLen(pastSeqLen); + groupMeta->get(0)->allocBuffer(hiddenSize, embBuf); + SequencePool::getInstance().add(groupMeta); + TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(groupMeta->get(0)->getSequenceID())); } } diff --git a/src/utils/compile_util.h b/src/utils/compile_util.h index e11cf32c..3874d763 100644 --- a/src/utils/compile_util.h +++ b/src/utils/compile_util.h @@ -17,6 +17,10 @@ #include #include +#ifdef GPU +#include +#endif + #define likely(x) __builtin_expect((x), 1) #define unlikely(x) __builtin_expect((x), 0) diff --git a/src/utils/transpose_util.h b/src/utils/transpose_util.h index f2a07d43..0ae7ec00 100644 --- a/src/utils/transpose_util.h +++ b/src/utils/transpose_util.h @@ -14,7 +14,7 @@ // ============================================================================ #ifndef _TRANSPOSE_H #define _TRANSPOSE_H -#include + #include #include diff --git a/src/utils/verbose.h b/src/utils/verbose.h index f47fa1fb..5716127d 100644 --- 
a/src/utils/verbose.h +++ b/src/utils/verbose.h @@ -48,7 +48,8 @@ class Printer { } template - static void print(std::string buf_name, T *buf, int rows, int cols, int stride, bool printAll = false, void *device = nullptr) { + static void print(std::string buf_name, T *buf, int rows, int cols, int stride, bool printAll = false, + void *device = nullptr) { std::cout << buf_name.c_str() << ":" << std::endl; #ifdef GPU if (device != nullptr) { @@ -72,7 +73,8 @@ class Printer { .wait(); #endif } - }; + } +}; #define GEMMVERBOSE(api_func, compute_func) \ if (Env::getInstance().getVerbose() >= 1) { \ diff --git a/tests/ut/attention_kernels_test.cpp b/tests/ut/attention_kernels_test.cpp index 57ba121a..d9100323 100644 --- a/tests/ut/attention_kernels_test.cpp +++ b/tests/ut/attention_kernels_test.cpp @@ -86,7 +86,7 @@ static void selfAttentionRef(bfloat16_t *output, bfloat16_t *query, bfloat16_t * int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes, const float scale) { - int rowOffsets[batchSize] = {0}; + int rowOffsets[batchSize]; for (int i = 1; i < batchSize; i++) { rowOffsets[i] = rowOffsets[i - 1] + tokenSizes[i - 1]; } @@ -178,49 +178,49 @@ void testSelfAttention( } TEST(AttentionKernelsTest, SeparateCopyTest1) { - int batchSize = 1; + const int batchSize = 1; int tokenSizes[batchSize] = {80}; testSelfAttention(128, 2, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, SeparateCopyTest2) { - int batchSize = 1; + const int batchSize = 1; int tokenSizes[batchSize] = {100}; testSelfAttention(128, 6, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, SeparateCopyTest3) { - int batchSize = 2; + const int batchSize = 2; int tokenSizes[batchSize] = {100, 200}; testSelfAttention(128, 8, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, SeparateCopyTest4) { - int batchSize = 3; + const int batchSize = 3; int tokenSizes[batchSize] = {100, 101, 102}; testSelfAttention(128, 8, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, SeparateCopyTest5) { - int batchSize = 4; + const int batchSize = 4; int tokenSizes[batchSize] = {100, 55, 111, 203}; testSelfAttention(128, 8, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, FusedCopyTest1) { - int batchSize = 1; + const int batchSize = 1; int tokenSizes[batchSize] = {100}; testSelfAttention(128, 2, 2, tokenSizes, batchSize, false); } TEST(AttentionKernelsTest, FusedCopyTest2) { - int batchSize = 2; + const int batchSize = 2; int tokenSizes[batchSize] = {100, 101}; testSelfAttention(128, 4, 4, tokenSizes, batchSize, false); } TEST(AttentionKernelsTest, FusedCopyTest3) { - int batchSize = 4; + const int batchSize = 4; int tokenSizes[batchSize] = {100, 101, 102, 103}; testSelfAttention(128, 4, 4, tokenSizes, batchSize, false); } diff --git a/tests/ut/kv_reorder_test.cpp b/tests/ut/kv_reorder_test.cpp index f082bc7b..17890026 100644 --- a/tests/ut/kv_reorder_test.cpp +++ b/tests/ut/kv_reorder_test.cpp @@ -14,7 +14,7 @@ // ============================================================================ #include -#include "opt_decoder.h" +#include "kvcache_tensor.h" #include "gtest/gtest.h" template diff --git a/tests/ut/rotary_embedding_test.cpp b/tests/ut/rotary_embedding_test.cpp index c339e6ec..0008b320 100644 --- a/tests/ut/rotary_embedding_test.cpp +++ b/tests/ut/rotary_embedding_test.cpp @@ -26,7 +26,7 @@ static bool compare(const float *result, const float *ground_truth, const int si } TEST(RotrayEmbedding, RotrayEmbeddingTest) { - int bs = 2, seq = 2, headnum = 2, dim = 2; + const 
int bs = 2, seq = 2, headnum = 2, dim = 2; int max_len = 10; int stride = bs * seq, size = bs * seq * headnum * dim; @@ -57,7 +57,7 @@ TEST(RotrayEmbedding, RotrayEmbeddingTest) { } TEST(RotrayEmbedding, BF16Test) { - int bs = 2, seq = 2, headnum = 2, dim = 2; + const int bs = 2, seq = 2, headnum = 2, dim = 2; int max_len = 10; int stride = bs * seq, size = bs * seq * headnum * dim; From 9b2af7edff78bc024803406b916294a1f6877cf2 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Tue, 4 Jun 2024 18:55:08 +0000 Subject: [PATCH 14/34] Fix build. --- src/utils/verbose.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/verbose.h b/src/utils/verbose.h index 5716127d..0716b045 100644 --- a/src/utils/verbose.h +++ b/src/utils/verbose.h @@ -71,8 +71,8 @@ class Printer { }); }) .wait(); -#endif } +#endif } }; From c03484908e2e74e0652c8eb9626336efc065acfe Mon Sep 17 00:00:00 2001 From: changqi1 Date: Fri, 7 Jun 2024 15:09:55 +0000 Subject: [PATCH 15/34] Add rmsNorm impl and XFT_DEBUG --- CMakeLists.txt | 10 +-- src/kernels/attention_kernels.cpp | 6 +- src/kernels/attention_kernels.h | 16 ++-- src/layers/attention.h | 42 ++++----- src/layers/decoder_layer.h | 6 +- src/layers/mlp_chatglm2.h | 2 +- src/layers/mlp_llama.h | 22 ++--- src/layers/mlp_standard.h | 12 +-- src/layers/rms_norm.cpp | 136 ++++++++++++++++++++++++------ src/layers/rms_norm.h | 23 +++-- src/models/common_decoder.h | 18 ++-- tests/ut/cross_attention_test.cpp | 2 +- 12 files changed, 189 insertions(+), 106 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 81eeecc7..42832658 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,10 +35,6 @@ else() message(STATUS "Notice: GCC version: ${GCC_VERSION}") endif() -if(NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE Release) -endif() - set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mavx512bw -mavx512vl -fPIC") if(WITH_GPU) @@ -73,11 +69,15 @@ if(GCC_VERSION VERSION_GREATER_EQUAL "10.1") endif() endif() +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + if(CMAKE_BUILD_TYPE MATCHES "Debug") message(STATUS "Notice: Using Debug mode.") set(CMAKE_C_FLAGS "-O0 -g") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g") - add_definitions(-DDEBUG=true) + add_definitions(-DXFT_DEBUG=true) add_definitions(-DSTEP_BY_STEP_ATTN=true) else() message(STATUS "Notice: Using Release mode.") diff --git a/src/kernels/attention_kernels.cpp b/src/kernels/attention_kernels.cpp index 2ace33d7..497c99ed 100644 --- a/src/kernels/attention_kernels.cpp +++ b/src/kernels/attention_kernels.cpp @@ -66,7 +66,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo small_sgemm_bf16bf16f32_b(true, m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)baseB, ldb, C, ldc, blkIndices, cacheBlkStride, cacheBlkSize); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("Q * K, first head:\n"); auto p = C; @@ -78,7 +78,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo // Softmax(Q * K) small_softmax_f32(C, scale, n); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("Softmax(Q * K), first head:\n"); auto p = C; @@ -100,7 +100,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo small_sgemm_f32bf16bf16_b(false, m, n, k, C, lda, (XDNN_BF16 *)baseB, ldb, (XDNN_BF16 *)baseC, ldc, blkIndices, cacheBlkStride, cacheBlkSize); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("Softmax(Q * K) * V, first head:\n"); auto p = C; diff --git 
a/src/kernels/attention_kernels.h b/src/kernels/attention_kernels.h index ca0dac8f..1a9d96bf 100644 --- a/src/kernels/attention_kernels.h +++ b/src/kernels/attention_kernels.h @@ -237,7 +237,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ xdnn_small_amx_sgemm_bf16bf16bf16_compute( m, endSeq, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { auto B = key + offsets[b] * kvStride + kvHeadIdx * headSize; printf("mnk=%d,%d,%d, ldabc=%d,%d,%d, A[0]=%f, B[0]=%f, packedB[0]=%f\n", m, n, k, lda, ldb, ldc, @@ -260,7 +260,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t)); } -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("Softmax(Q * Káµ€), first head:\n"); auto p = C; @@ -290,7 +290,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_ xdnn_small_amx_sgemm_bf16bf16bf16_compute( m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("Softmax(Q * Káµ€) * V, first head:\n"); auto p = C; @@ -306,7 +306,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes, const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache, const Lambda2 &getVCache) { -#ifdef DEBUG +#ifdef XFT_DEBUG printf("Q[0]=%f, K[0]=%f, V[0]=%f\n", (float)query[0], (float)key[0], (float)value[0]); printf("kvHeadNum=%d, headSize=%d, qStride=%d, kvStride=%d, batchSize=%d\n", kvHeadNum, headSize, qStride, kvStride, batchSize); @@ -337,7 +337,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance().getBuffer( "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(bfloat16_t)); -#ifdef DEBUG +#ifdef XFT_DEBUG printf("maxTokenSize=%d, tokenSizes[0]=%d, offsets[0]=%d, kvStride=%d\n", maxTokenSize, tokenSizes[0], offsets[0], kvStride); #endif @@ -389,7 +389,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * xdnn_small_amx_sgemm_bf16bf16bf16_compute( m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("mnk=%d,%d,%d, ldabc=%d,%d,%d, A[0]=%f, B[0]=%f, packedB[0]=%f\n", m, n, k, lda, ldb, ldc, (float)A[0], (float)B[0], (float)packedB[0]); @@ -411,7 +411,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t)); } -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("Softmax(Q * Káµ€), first head:\n"); auto p = C; @@ -430,7 +430,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * xdnn_small_amx_sgemm_bf16bf16bf16_compute( m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("Softmax(Q * Káµ€) * V, first head:\n"); auto p = C; diff --git a/src/layers/attention.h b/src/layers/attention.h index a51efaf5..89be2bfd 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -157,7 +157,7 @@ class Attention { free(concatScale); free(concatZero); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("attention 
qkv weight: [%d, %d] (%d)\n", convertedqkvWeight.Rows(), convertedqkvWeight.Cols(), convertedqkvWeight.Stride()); dbg.dumpMatrix(convertedqkvWeight); @@ -197,7 +197,7 @@ class Attention { ctx->mmHelper->packWeight(trans, convertedOutWeight, attnOutputWeight); #endif -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> attention output weight: [%d, %d] (%d)\n", convertedOutWeight.Rows(), convertedOutWeight.Cols(), convertedOutWeight.Stride()); dbg.dumpMatrix(convertedOutWeight); @@ -220,7 +220,7 @@ class Attention { if (doLNorm) this->norm->setWeight(gamma1, beta1, hiddenSize); } -#ifdef DEBUG +#ifdef XFT_DEBUG void setDebugger(const Debugger &debugger) { this->dbg = debugger; } #endif @@ -264,7 +264,7 @@ class Attention { auto &qkvMatMul = ctx->qkvMatMul; xft::Matrix qkvGroupMatMul((ImT *)qkvMatMul.Data(), qkvRows, qkvCols, qkvStride); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("---- DecoderLayer.forward (useSelfAttn=%d) ----\n", useSelfAttn); dbg.debugPrint("input:\n"); dbg.dumpMatrix(inputBuffer); @@ -275,7 +275,7 @@ class Attention { norm->forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), imBuffer.Stride(), epsilon); } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("layer norm:\n"); dbg.dumpMatrix(imBuffer); dbg.debugPrint("qkvWeight [%d, %d]:\n", this->qkvWeight.Rows(), this->qkvWeight.Cols()); @@ -299,7 +299,7 @@ class Attention { xft::Matrix key(qkvGroupMatMul, 0, inputBuffer.Rows(), qCols, kvCols); xft::Matrix value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kvCols); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("Q[%d,%d](%d):\n", query.Rows(), query.Cols(), query.Stride()); dbg.dumpMatrix(query); dbg.debugPrint("K[%d,%d](%d):\n", key.Rows(), key.Cols(), key.Stride()); @@ -331,7 +331,7 @@ class Attention { } t3.release(); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride()); dbg.dumpMatrix(query); dbg.debugPrint("K[%d,%d](%d) after post op:\n", key.Rows(), key.Cols(), key.Stride()); @@ -367,7 +367,7 @@ class Attention { } t4.release(); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(), attnSplit.Cols(), attnSplit.Stride()); dbg.dumpMatrix(attnSplit); @@ -415,7 +415,7 @@ class Attention { } t5.release(); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> attention output/projection[%d, %d] (%d):\n", outBuffer.Rows(), outBuffer.Cols(), outBuffer.Stride()); dbg.dumpMatrix(outBuffer); @@ -424,7 +424,7 @@ class Attention { if (doLnAfter) { TimeLine t6("result.layer_norm"); norm->forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(), outBuffer.Stride()); dbg.dumpMatrix(outBuffer); @@ -458,7 +458,7 @@ class Attention { auto &qkvMatMul = ctx->qkvMatMul; xft::Matrix qkvGroupMatMul((ImT *)qkvMatMul.Data(), qkvRows, qkvCols, qkvStride); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("---- DecoderLayer.forward ----\n"); dbg.debugPrint("input:\n"); dbg.dumpMatrix(inputBuffer); @@ -469,7 +469,7 @@ class Attention { norm->forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), imBuffer.Stride(), epsilon); } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("layer norm:\n"); dbg.dumpMatrix(imBuffer); dbg.debugPrint("qkvWeight [%d, %d]:\n", this->qkvWeight.Rows(), this->qkvWeight.Cols()); @@ -493,7 
+493,7 @@ class Attention { xft::Matrix key(qkvGroupMatMul, 0, inputBuffer.Rows(), qCols, kvCols); xft::Matrix value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kvCols); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("Q[%d,%d](%d):\n", query.Rows(), query.Cols(), query.Stride()); dbg.dumpMatrix(query); dbg.debugPrint("K[%d,%d](%d):\n", key.Rows(), key.Cols(), key.Stride()); @@ -523,7 +523,7 @@ class Attention { } t3.release(); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride()); dbg.dumpMatrix(query); dbg.debugPrint("K[%d,%d](%d) after post op:\n", key.Rows(), key.Cols(), key.Stride()); @@ -559,7 +559,7 @@ class Attention { } t4.release(); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(), attnSplit.Cols(), attnSplit.Stride()); dbg.dumpMatrix(attnSplit); @@ -602,7 +602,7 @@ class Attention { } t5.release(); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> attention output/projection[%d, %d] (%d):\n", outBuffer.Rows(), outBuffer.Cols(), outBuffer.Stride()); dbg.dumpMatrix(outBuffer); @@ -611,7 +611,7 @@ class Attention { if (!doLnBefore) { TimeLine t6("result.layer_norm"); norm->forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(), outBuffer.Stride()); dbg.dumpMatrix(outBuffer); @@ -929,7 +929,7 @@ class Attention { this->gemm1(A, keyMatInfo, C, m, n, headSize, lda, ldc); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { dbg.debugPrint("Q * K, first head:\n"); auto p = scoreBuf; @@ -942,7 +942,7 @@ class Attention { // Softmax(Q * K) this->softmax(ctx, C, getMask(attnMask, b, i, queryLen, keyLen), m, n, ldc, startSeq); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { dbg.debugPrint("Softmax(Q * K), first head:\n"); auto p = scoreBuf; @@ -960,7 +960,7 @@ class Attention { auto output = result.Row(b * ctx->inputSeqLen + startSeq) + i * ctx->attHeadSize; this->gemm2(C, valueMat, output, m, headSize, keyLen, scoreStride, result.Stride()); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { dbg.debugPrint("Softmax(Q * K) * V, first head:\n"); auto p = output; @@ -1201,7 +1201,7 @@ class Attention { int endQHead; int startKVHead; int endKVHead; -#ifdef DEBUG +#ifdef XFT_DEBUG Debugger dbg; #endif }; diff --git a/src/layers/decoder_layer.h b/src/layers/decoder_layer.h index 9a44b13f..3cb58736 100644 --- a/src/layers/decoder_layer.h +++ b/src/layers/decoder_layer.h @@ -59,11 +59,11 @@ class Decoder { : layerIdx(_layerIdx) , attn(_layerIdx, _ctx) , mlp(_ctx) -#ifdef DEBUG +#ifdef XFT_DEBUG , dbg(Debugger::formatStr("%d_%d.csv", _layerIdx, _ctx->splitIdx)) #endif { -#ifdef DEBUG +#ifdef XFT_DEBUG attn.setDebugger(dbg); mlp.setDebugger(dbg); #endif @@ -126,7 +126,7 @@ class Decoder { ATTN_CLS attn; MLP_CLS mlp; -#ifdef DEBUG +#ifdef XFT_DEBUG Debugger dbg; #endif }; diff --git a/src/layers/mlp_chatglm2.h b/src/layers/mlp_chatglm2.h index 8a6296d2..67bb40fa 100644 --- a/src/layers/mlp_chatglm2.h +++ b/src/layers/mlp_chatglm2.h @@ -94,7 +94,7 @@ class ChatGLM2MLP : public LlamaMLP { ctx->mmHelper->convertWeight(ctx, trans, intermediateSize, hiddenSize, downW, nullptr, nullptr, false, convertedDownWeight, this->downWeightScale, this->downWeightZero, this->downWeightSum); ctx->mmHelper->packWeight(trans, convertedDownWeight, this->downWeight); -#ifdef 
DEBUG +#ifdef XFT_DEBUG this->dbg.debugPrint("convertedGateWeight [%d, %d](%d):\n", convertedGateWeight.Rows(), convertedGateWeight.Cols(), convertedGateWeight.Stride()); this->dbg.dumpMatrix(convertedGateWeight); diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 2780479f..80f2a8fc 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -14,10 +14,6 @@ // ============================================================================ #pragma once -#ifdef UNDEBUG -#undef NDEBUG -#endif - #include "bert_util.h" #include "copy_util.h" #include "debugger.h" @@ -111,7 +107,7 @@ class LlamaMLP { ctx->mmHelper->packWeight(trans, quantizedDownWeight, downWeight); #endif -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("quantizedGateWeight:\n"); dbg.dumpMatrix(quantizedGateWeight); @@ -126,7 +122,7 @@ class LlamaMLP { if (normW) { norm->setWeight(normW, nullptr, hiddenSize); } } -#ifdef DEBUG +#ifdef XFT_DEBUG void setDebugger(const Debugger &debugger) { this->dbg = debugger; } #endif @@ -149,7 +145,7 @@ class LlamaMLP { norm->forward(inBuffer.Data(), normBuffer.Data(), M, inBuffer.Stride(), normBuffer.Stride(), 1e-6); } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("LayerNorm before MLP:\n"); dbg.dumpMatrix(normBuffer); dbg.debugPrint(">>> residential: [%d, %d] (%d)\n", inBuffer.Rows(), inBuffer.Cols(), inBuffer.Stride()); @@ -161,7 +157,7 @@ class LlamaMLP { (ImT *)ctx->imOut.Data(), ctx->imOut.Rows(), ctx->imOut.Cols(), ctx->imOut.Stride()); gateProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint( ">>> gateWeight: [%d, %d] (%d)\n", gateWeight.Rows(), gateWeight.Cols(), gateWeight.Stride()); dbg.dumpMatrix(gateWeight); @@ -171,7 +167,7 @@ class LlamaMLP { upProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> upWeight: [%d, %d] (%d)\n", upWeight.Rows(), upWeight.Cols(), upWeight.Stride()); dbg.dumpMatrix(upWeight); dbg.debugPrint(">>> up output: [%d, %d] (%d)\n", imBuffer.Rows(), imBuffer.Cols(), imBuffer.Stride()); @@ -189,7 +185,7 @@ class LlamaMLP { auto bufSize = sizeof(ImT) * M * cols; ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize); xft::Matrix siluBuf(t, M, cols, cols); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint( ">>> enableCATMLP imBuffer: [%d, %d] (%d)\n", imBuffer.Rows(), imBuffer.Cols(), imBuffer.Stride()); dbg.dumpMatrix(imBuffer); @@ -197,7 +193,7 @@ class LlamaMLP { dbg.dumpMatrix(inBuffer); #endif catGateUpProj(ctx, doLnBefore ? 
normBuffer : inBuffer, imBuffer, siluBuf); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("catWeights:\n"); dbg.dumpMatrix(catWeights); dbg.debugPrint("gateUp output:\n"); @@ -208,7 +204,7 @@ class LlamaMLP { downProj(ctx, siluBuf, outBuffer, inBuffer, ctx->splitIdx == 0); } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> downWeight: [%d, %d] (%d)\n", downWeight.Rows(), downWeight.Cols(), downWeight.Stride()); dbg.dumpMatrix(downWeight); dbg.debugPrint(">>> residential: [%d, %d] (%d)\n", inBuffer.Rows(), inBuffer.Cols(), inBuffer.Stride()); @@ -385,7 +381,7 @@ class LlamaMLP { // LlamaRMSNorm param NORM_CLS *norm; -#ifdef DEBUG +#ifdef XFT_DEBUG Debugger dbg; #endif }; diff --git a/src/layers/mlp_standard.h b/src/layers/mlp_standard.h index 0e9c6123..a4d65170 100644 --- a/src/layers/mlp_standard.h +++ b/src/layers/mlp_standard.h @@ -71,7 +71,7 @@ class MLP { } } -#ifdef DEBUG +#ifdef XFT_DEBUG void setDebugger(const Debugger &debugger) { this->dbg = debugger; } #endif @@ -99,7 +99,7 @@ class MLP { auto &imInput = doLnBefore ? (INPUT_AS_RESID ? resultBuffer1 : resultBuffer2) : resultBuffer2; -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("layer norm after attention:\n"); dbg.dumpMatrix(imInput); #endif @@ -110,7 +110,7 @@ class MLP { case DecoderContext::GELU: intermediate_gelu(ctx, imInput, imBuffer); break; } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("intermediate:\n"); dbg.dumpMatrix(imBuffer); #endif @@ -149,7 +149,7 @@ class MLP { } } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("output:\n"); dbg.dumpMatrix(resultBuffer1); #endif @@ -157,7 +157,7 @@ class MLP { // layerNorm if (!doLnBefore) { DecoderUtil::layerNorm(resultBuffer1, resultBuffer1, gamma2, beta2); } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("final output:\n"); dbg.dumpMatrix(resultBuffer1); #endif @@ -239,7 +239,7 @@ class MLP { // layerNorm param xft::Vector gamma2, beta2; -#ifdef DEBUG +#ifdef XFT_DEBUG Debugger dbg; #endif }; diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index 2df9bd0a..7998eab8 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -25,92 +25,172 @@ #ifdef GPU #include "gpudnn/gpu_layernorm_kernels.h" +#include #endif namespace xft { -RmsNorm::RmsNorm() { +template +RmsNormImp::RmsNormImp() { weight = nullptr; normSize = 0; } -RmsNorm::RmsNorm(DecoderContext *ctx) { +template +RmsNormImp::RmsNormImp(DecoderContext *ctx) { device = ctx->device; weight = nullptr; normSize = 0; } -RmsNorm::~RmsNorm() { +template +RmsNormImp::~RmsNormImp() { if (weight) { xft::dealloc(weight); } } -void RmsNorm::setWeight(const float *w, const float *, int cols) { +template +void RmsNormImp::setWeight(const float *w, const float *, int cols) { + T weightBuf[cols]; + if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { + float16_t::cvt_float_to_float16(w, weightBuf, cols); + } else if constexpr (std::is_same_v) { + bfloat16_t::cvt_float_to_bfloat16(w, weightBuf, cols); + } else { + printf("%s:%d: Could not setWeight in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); + } + this->normSize = cols; - this->weight = (float *)xft::alloc(cols * sizeof(float), device); - xft::memcopy(this->weight, w, cols * sizeof(float), device); + this->weight = (T *)xft::alloc(cols * sizeof(T), device); + xft::memcopy(this->weight, weightBuf, cols * sizeof(T), device); } -void RmsNorm::setWeight(const std::string &modelPath, const std::string &, int cols) { +template +void RmsNormImp::setWeight(const std::string &modelPath, const std::string &, int cols) { 
this->normSize = cols; loadWeight(modelPath, weight, cols); } #ifdef GPU -void RmsNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); sycl::queue *gpu_queue = static_cast(device); - fastertransformer::invokeGeneralT5LayerNorm( - output, input, weight, (const float *)nullptr, epsilon, rows, iStride, gpu_queue); + if constexpr (std::is_same_v) { + fastertransformer::invokeGeneralT5LayerNorm( + output, input, weight, (const float *)nullptr, epsilon, rows, iStride, gpu_queue); + } else { + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); + } } -void RmsNorm::forward(const float *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward(const float *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); } -void RmsNorm::forward(const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward( + const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); sycl::queue *gpu_queue = static_cast(device); - // fastertransformer::invokeGeneralT5LayerNorm( - // output, input, weight, (const bfloat16_t *)nullptr, epsilon, rows, iStride, gpu_queue); + if constexpr (std::is_same_v) { + // TODO: Add BF16 RmsNorm Implemention. + // fastertransformer::invokeGeneralT5LayerNorm( + // output, input, weight, (const bfloat16_t *)nullptr, epsilon, rows, iStride, gpu_queue); + } else { + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); + } } -void RmsNorm::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); } -void RmsNorm::forward(const float16_t *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward( + const float16_t *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); sycl::queue *gpu_queue = static_cast(device); - // fastertransformer::invokeGeneralT5LayerNorm( - // output, input, weight, (const float16_t *)nullptr, epsilon, rows, iStride, gpu_queue); + if constexpr (std::is_same_v) { + fastertransformer::invokeGeneralT5LayerNorm((sycl::half *)output, (const sycl::half *)input, + (const sycl::half *)weight, (const sycl::half *)nullptr, epsilon, rows, iStride, gpu_queue); + } else { + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); + } } #else // input and output are in shape of (rows, normSize) -void RmsNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); - 
rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + if constexpr (std::is_same_v) { + rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + } else { + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); + } } -void RmsNorm::forward(const float *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward(const float *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); - rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + if constexpr (std::is_same_v) { + rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + } else { + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); + } } -void RmsNorm::forward(const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward( + const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); - rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + if constexpr (std::is_same_v) { + rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + } else { + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); + } } -void RmsNorm::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); - rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + if constexpr (std::is_same_v) { + rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + } else { + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); + } } -void RmsNorm::forward(const float16_t *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { +template +void RmsNormImp::forward( + const float16_t *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); - rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + if constexpr (std::is_same_v) { + rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + } else { + printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); + } } #endif +template class RmsNormImp; +template class RmsNormImp; +template class RmsNormImp; + } // namespace xft \ No newline at end of file diff --git a/src/layers/rms_norm.h b/src/layers/rms_norm.h index b355963a..05cafd2b 100644 --- a/src/layers/rms_norm.h +++ b/src/layers/rms_norm.h @@ -15,17 +15,18 @@ #pragma once #include "bfloat16.h" -#include "weight_util.h" #include "transformer_ctx.h" +#include "weight_util.h" namespace xft { // RMS normalization: only support the norm along last dimension -class RmsNorm { +template +class RmsNormImp { public: - RmsNorm(); - RmsNorm(DecoderContext *ctx); - ~RmsNorm(); + RmsNormImp(); + RmsNormImp(DecoderContext *ctx); + ~RmsNormImp(); void setWeight(const float *w, const float *, int cols); void setWeight(const std::string &modelPath, const std::string &, int cols); @@ -41,8 +42,8 @@ class RmsNorm { void 
forward(const bfloat16_t *input, bfloat16_t *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6); - void forward(const float *input, float16_t *output, int rows, int iStride = -1, int oStride = -1, - float epsilon = 1e-6); + void forward( + const float *input, float16_t *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6); void forward(const float16_t *input, float16_t *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6); @@ -51,8 +52,14 @@ class RmsNorm { int normSize; // the scale weight - float *weight = nullptr; + T *weight = nullptr; void *device = nullptr; }; +#ifdef GPU +using RmsNorm = RmsNormImp; +#else +using RmsNorm = RmsNormImp; +#endif + } // namespace xft \ No newline at end of file diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index e44f9b23..2f2724b9 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -155,7 +155,7 @@ class CommonDecoder : public AbstractDecoder { public: CommonDecoder(const std::string &modelPath, const std::string &modelType) : messenger(Messenger::getInstance()) -#ifdef DEBUG +#ifdef XFT_DEBUG , dbg("model_decoder.csv") #endif { @@ -313,7 +313,7 @@ class CommonDecoder : public AbstractDecoder { this->embeddingForward(ids, embBuf, batchSize * inputSeqLen); this->accSeqLen += seqLen; -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("---- embedding.forward ----\n"); dbg.debugPrint("ids:\n"); dbg.dumpMatrix(ids, batchSize, inputSeqLen, inputSeqLen); @@ -451,7 +451,7 @@ class CommonDecoder : public AbstractDecoder { } } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> DecoderLayer Output[%d, %d] (%d):\n", batchSize * inputSeqLen, hiddenSize, hiddenSize); dbg.dumpMatrix(embBuf, batchSize * inputSeqLen, hiddenSize, hiddenSize); dbg.debugPrint("LayerNorm In:\n"); @@ -469,7 +469,7 @@ class CommonDecoder : public AbstractDecoder { else lastLayerNormForward(lnIn, lnOut, batchSize * seqLen); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("LayerNorm Out:\n"); if (!logitsAll) dbg.dumpMatrix(lnOut, batchSize, hiddenSize, hiddenSize); @@ -484,7 +484,7 @@ class CommonDecoder : public AbstractDecoder { else this->predictor->forward(ctx, lnOut, finalOut, batchSize * seqLen); -#ifdef DEBUG +#ifdef XFT_DEBUG auto splitSize = this->predictor->getSplitSize(); dbg.debugPrint("finalOut:\n"); if (!logitsAll) @@ -562,7 +562,7 @@ class CommonDecoder : public AbstractDecoder { } } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> DecoderLayer Output[%d, %d] (%d):\n", logitRows, hiddenSize, hiddenSize); dbg.dumpMatrix(embBuf, logitRows, hiddenSize, hiddenSize); dbg.debugPrint("LayerNorm In:\n"); @@ -574,7 +574,7 @@ class CommonDecoder : public AbstractDecoder { MlpOutT *lnOut = embBuf; lastLayerNormForward(lnIn, lnOut, logitRows); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("LayerNorm Out:\n"); dbg.dumpMatrix(lnOut, logitRows, hiddenSize, hiddenSize); #endif @@ -583,7 +583,7 @@ class CommonDecoder : public AbstractDecoder { float *finalOut = (float *)outBuf; this->predictor->forward(ctx, lnOut, finalOut, logitRows); -#ifdef DEBUG +#ifdef XFT_DEBUG auto splitSize = this->predictor->getSplitSize(); dbg.debugPrint("finalOut:\n"); dbg.dumpMatrix(finalOut, logitRows, splitSize, splitSize); @@ -1114,7 +1114,7 @@ class CommonDecoder : public AbstractDecoder { int startId; int endId; -#ifdef DEBUG +#ifdef XFT_DEBUG Debugger dbg; #endif }; diff --git a/tests/ut/cross_attention_test.cpp b/tests/ut/cross_attention_test.cpp index 980b1307..c6694a77 100644 --- 
a/tests/ut/cross_attention_test.cpp +++ b/tests/ut/cross_attention_test.cpp @@ -97,7 +97,7 @@ static void crossAttentionRef(bfloat16_t *output, const bfloat16_t *query, const // Score = Softmax(Q * Káµ€) softmaxRef(pscore, presentSeqLen); -#ifdef DEBUG +#ifdef XFT_DEBUG printf("pscore: "); for (int i = 0; i < presentSeqLen; ++i) { printf("%.6f ", pscore[i]); From ec463a325c2ff09945678d523daf41ea2ed9a294 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Fri, 7 Jun 2024 16:24:50 +0000 Subject: [PATCH 16/34] Update. --- src/common/sequence.h | 26 +++++++++++++------------- src/layers/attention.h | 24 +++++++++++++----------- src/layers/layer_norm.cpp | 3 +++ src/layers/mlp_chatglm2.h | 2 +- src/layers/mlp_llama.h | 11 ++++++----- src/models/llama.cpp | 1 + src/utils/matmul_helper.h | 20 ++++++++++++++++---- tests/ut/attention_kernels_test.cpp | 1 + 8 files changed, 54 insertions(+), 34 deletions(-) diff --git a/src/common/sequence.h b/src/common/sequence.h index 26c99310..6376e4a6 100644 --- a/src/common/sequence.h +++ b/src/common/sequence.h @@ -68,7 +68,7 @@ class SequenceIDManager { // The SequenceMeta is one sequence of batch inputs and includes the generated tokens. class SequenceMeta { public: - SequenceMeta(std::vector &_promptTokens) + SequenceMeta(const std::vector &_promptTokens) : sequenceID(SequenceIDManager::getInstance().createSequenceID()) , inputSeqLen(_promptTokens.size()) , pastSeqLen(0) @@ -82,7 +82,7 @@ class SequenceMeta { , promptTokens(_inputSeqLen, 0) , step(0) {} - SequenceMeta(int32_t _sequenceID, std::vector &_promptTokens) + SequenceMeta(int32_t _sequenceID, const std::vector &_promptTokens) : sequenceID(_sequenceID) , inputSeqLen(_promptTokens.size()) , pastSeqLen(0) @@ -90,11 +90,7 @@ class SequenceMeta { , step(0) {} SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen) - : sequenceID(_sequenceID) - , inputSeqLen(_inputSeqLen) - , pastSeqLen(0) - , promptTokens(_inputSeqLen, 0) - , step(0) {} + : sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), pastSeqLen(0), promptTokens(_inputSeqLen, 0), step(0) {} ~SequenceMeta() {} @@ -190,7 +186,8 @@ class SequenceGroupMeta { groupID = sequences[0].getSequenceID(); } - SequenceGroupMeta(std::vector &_inputTokens, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) { + SequenceGroupMeta(const std::vector &_inputTokens, SamplingMeta &samplingMeta_) + : samplingMeta(samplingMeta_) { sequences.reserve(samplingMeta.config.numBeams); for (int i = 0; i < samplingMeta.config.numBeams; ++i) { sequences.emplace_back(SequenceMeta(_inputTokens)); @@ -206,7 +203,7 @@ class SequenceGroupMeta { groupID = sequences[0].getSequenceID(); } - SequenceGroupMeta(std::vector &_inputTokens) { + SequenceGroupMeta(const std::vector &_inputTokens) { sequences.reserve(samplingMeta.config.numBeams); for (int i = 0; i < samplingMeta.config.numBeams; ++i) { sequences.emplace_back(SequenceMeta(_inputTokens)); @@ -222,7 +219,8 @@ class SequenceGroupMeta { groupID = sequences[0].getSequenceID(); } - SequenceGroupMeta(int32_t _sequenceID, std::vector &_inputTokens, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) { + SequenceGroupMeta(int32_t _sequenceID, const std::vector &_inputTokens, SamplingMeta &samplingMeta_) + : samplingMeta(samplingMeta_) { sequences.reserve(samplingMeta.config.numBeams); for (int i = 0; i < samplingMeta.config.numBeams; ++i) { sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens)); @@ -230,7 +228,8 @@ class SequenceGroupMeta { groupID = sequences[0].getSequenceID(); } - SequenceGroupMeta(int32_t 
_sequenceID, int32_t _inputSeqLen, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) { + SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen, SamplingMeta &samplingMeta_) + : samplingMeta(samplingMeta_) { sequences.reserve(samplingMeta.config.numBeams); for (int i = 0; i < samplingMeta.config.numBeams; ++i) { sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen)); @@ -238,7 +237,7 @@ class SequenceGroupMeta { groupID = sequences[0].getSequenceID(); } - SequenceGroupMeta(int32_t _sequenceID, std::vector &_inputTokens) { + SequenceGroupMeta(int32_t _sequenceID, const std::vector &_inputTokens) { sequences.reserve(samplingMeta.config.numBeams); for (int i = 0; i < samplingMeta.config.numBeams; ++i) { sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens)); @@ -319,7 +318,8 @@ class SequencePool { return group; } - SequenceGroupMeta *newGroupMeta(int32_t sequenceID, std::vector &inputTokens, SamplingMeta &samplingMeta_) { + SequenceGroupMeta *newGroupMeta( + int32_t sequenceID, std::vector &inputTokens, SamplingMeta &samplingMeta_) { auto *group = new SequenceGroupMeta(sequenceID, inputTokens, samplingMeta_); this->add(group); return group; diff --git a/src/layers/attention.h b/src/layers/attention.h index 89be2bfd..adf3f809 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -46,12 +46,13 @@ template class Attention { public: - Attention(int layerId, DecoderContext *ctx) : layerId(layerId), qkpo(ctx->attHeadSize, ctx->maxPosEmbed) { + Attention(int layerId, DecoderContext *ctx) + : layerId(layerId), qkpo(ctx->attHeadSize, ctx->maxPosEmbed), norm(NORM_CLS(ctx)) { //todo(marvin): clear this code after all rotary_emb refactor - if constexpr (std::is_same::value) { qkpo = LlamaRotaryEmbedding(ctx); } - - norm = new NORM_CLS(ctx); + if constexpr (std::is_same::value) { + qkpo = LlamaRotaryEmbedding(ctx); + } // Group attention or multi-head attention (multi-head attn is a special case of group attn) if (ctx->attHeadNum % ctx->kvHeadNum == 0) { @@ -188,7 +189,8 @@ class Attention { outWeightT.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize); ctx->mmHelper->transposeWeight(true, convertedOutWeight, outWeightT); - WeiT *outWeiData = (WeiT *)xft::alloc(ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT), ctx->device); + WeiT *outWeiData + = (WeiT *)xft::alloc(ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT), ctx->device); attnOutputWeight.Assign(outWeiData, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, hiddenSize); int outWeightTSize = ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT); xft::memcopy(attnOutputWeight.Data(), outWeightT.Data(), outWeightTSize, ctx->device); @@ -217,7 +219,7 @@ class Attention { } // LayerNorm - if (doLNorm) this->norm->setWeight(gamma1, beta1, hiddenSize); + if (doLNorm) this->norm.setWeight(gamma1, beta1, hiddenSize); } #ifdef XFT_DEBUG @@ -272,7 +274,7 @@ class Attention { if (doLnBefore) { TimeLine t1("input.layer_norm"); - norm->forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), + norm.forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), imBuffer.Stride(), epsilon); } #ifdef XFT_DEBUG @@ -423,7 +425,7 @@ class Attention { if (doLnAfter) { TimeLine t6("result.layer_norm"); - norm->forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); + norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); #ifdef 
XFT_DEBUG dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(), outBuffer.Stride()); @@ -466,7 +468,7 @@ class Attention { if (doLnBefore) { TimeLine t1("input.layer_norm"); - norm->forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), + norm.forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(), imBuffer.Stride(), epsilon); } #ifdef XFT_DEBUG @@ -610,7 +612,7 @@ class Attention { if (!doLnBefore) { TimeLine t6("result.layer_norm"); - norm->forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); + norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride()); #ifdef XFT_DEBUG dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(), outBuffer.Stride()); @@ -1189,7 +1191,7 @@ class Attention { QKPO_CLS qkpo; // layerNorm param - NORM_CLS *norm; + NORM_CLS norm; int layerId; // Alibi Slopes diff --git a/src/layers/layer_norm.cpp b/src/layers/layer_norm.cpp index 761a8fbb..ba958cb6 100644 --- a/src/layers/layer_norm.cpp +++ b/src/layers/layer_norm.cpp @@ -64,6 +64,9 @@ void LayerNorm::forward(const float *input, float *output, int rows, int iStride TimeLine t("LayerNorm.forward"); const float *pgamma = gamma; const float *pbeta = beta; + // TODO: Add LayerNorm Impl + printf("%s:%d: Could not forward in LayerNorm with undefined data type.\n", __FILE__, __LINE__); + exit(-1); } #else void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { diff --git a/src/layers/mlp_chatglm2.h b/src/layers/mlp_chatglm2.h index 67bb40fa..885a6280 100644 --- a/src/layers/mlp_chatglm2.h +++ b/src/layers/mlp_chatglm2.h @@ -121,7 +121,7 @@ class ChatGLM2MLP : public LlamaMLP { #endif // norm.setWeight(normW, NULL, hiddenSize); - if (normW) { norm->setWeight(normW, nullptr, hiddenSize); } + if (normW) { norm.setWeight(normW, nullptr, hiddenSize); } } private: diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 80f2a8fc..d3833247 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -34,10 +34,11 @@ // def forward(self, x): // return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) // But please also be noted: we extended the MLP to include layer norm -template +template class LlamaMLP { public: - LlamaMLP(DecoderContext *ctx) { norm = new NORM_CLS(ctx); } + LlamaMLP(DecoderContext *ctx) : norm(NORM_CLS(ctx)) {} // OriWeiT: float, int8_t or uint4x2_t template @@ -119,7 +120,7 @@ class LlamaMLP { #endif // LlamaRMSNorm - if (normW) { norm->setWeight(normW, nullptr, hiddenSize); } + if (normW) { norm.setWeight(normW, nullptr, hiddenSize); } } #ifdef XFT_DEBUG @@ -142,7 +143,7 @@ class LlamaMLP { (ImT *)ctx->normBuf.Data(), ctx->normBuf.Rows(), ctx->normBuf.Cols(), ctx->normBuf.Stride()); if (doLnBefore == true) { - norm->forward(inBuffer.Data(), normBuffer.Data(), M, inBuffer.Stride(), normBuffer.Stride(), 1e-6); + norm.forward(inBuffer.Data(), normBuffer.Data(), M, inBuffer.Stride(), normBuffer.Stride(), 1e-6); } #ifdef XFT_DEBUG @@ -379,7 +380,7 @@ class LlamaMLP { xft::Vector downWeightSum; // For int8_t weight // LlamaRMSNorm param - NORM_CLS *norm; + NORM_CLS norm; #ifdef XFT_DEBUG Debugger dbg; diff --git a/src/models/llama.cpp b/src/models/llama.cpp index 8c2ba400..70fb155e 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -38,6 +38,7 @@ LlamaLLM::LlamaLLM(const 
std::string &modelPath) template LlamaLLM::~LlamaLLM() { delete embedding; + delete finalLN; } template diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index a3fe1a8d..38104d10 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -61,6 +61,14 @@ class MMHelper { ~MMHelper() { if (engine) delete engine; if (stream) delete stream; + + for (auto &pair : matmul_hub) { + dnnl::matmul::primitive_desc *primitive_desc_ptr = std::get<0>(pair.second); + dnnl::matmul *matmul_ptr = std::get<1>(pair.second); + + delete primitive_desc_ptr; + delete matmul_ptr; + } } // Pack the MatMul weight from 'src(rows, cols)' to 'weight' @@ -1050,7 +1058,8 @@ class MMHelper { if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_compute_resmul", xdnn_hgemm_compute_resmul(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, - (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, (const XDNN_FP16 *)res, ldres)); + (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, (const XDNN_FP16 *)res, + ldres)); } else if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f16f16f32_compute_resmul", xdnn_hgemm_f16f16f32_compute_resmul(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, @@ -1173,7 +1182,8 @@ class MMHelper { if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_compute_residential", xdnn_hgemm_compute_residential(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, - (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, bias, (const XDNN_FP16 *)res, ldres)); + (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, bias, (const XDNN_FP16 *)res, + ldres)); } else if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f16f16f32_compute_residential", xdnn_hgemm_f16f16f32_compute_residential(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, @@ -1297,11 +1307,13 @@ class MMHelper { if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_compute_resext", xdnn_hgemm_compute_resext(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, - (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, bias, gamma, (const XDNN_FP16 *)res, ldres)); + (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, bias, gamma, + (const XDNN_FP16 *)res, ldres)); } else if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f16f16f32_compute_resext", xdnn_hgemm_f16f16f32_compute_resext(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, - (const XDNN_FP16 *)packedB, beta, C, ldc, bias, gamma, (const XDNN_FP16 *)res, ldres)); + (const XDNN_FP16 *)packedB, beta, C, ldc, bias, gamma, (const XDNN_FP16 *)res, + ldres)); } } #else diff --git a/tests/ut/attention_kernels_test.cpp b/tests/ut/attention_kernels_test.cpp index d9100323..09ed540d 100644 --- a/tests/ut/attention_kernels_test.cpp +++ b/tests/ut/attention_kernels_test.cpp @@ -87,6 +87,7 @@ static void selfAttentionRef(bfloat16_t *output, bfloat16_t *query, bfloat16_t * const float scale) { int rowOffsets[batchSize]; + memset(rowOffsets, 0 , batchSize * sizeof(int)); for (int i = 1; i < batchSize; i++) { rowOffsets[i] = rowOffsets[i - 1] + tokenSizes[i - 1]; } From 69dd33af7773521145ce69b54c428ca58f47a2e0 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Fri, 7 Jun 2024 16:33:34 +0000 Subject: [PATCH 17/34] update. 
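Construct the norm member directly from the DecoderContext in Attention and LlamaMLP instead of copy-constructing it from a NORM_CLS temporary.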
--- src/layers/attention.h | 2 +- src/layers/mlp_llama.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/attention.h b/src/layers/attention.h index adf3f809..2ff9ed62 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -47,7 +47,7 @@ template attHeadSize, ctx->maxPosEmbed), norm(NORM_CLS(ctx)) { + : layerId(layerId), qkpo(ctx->attHeadSize, ctx->maxPosEmbed), norm(ctx) { //todo(marvin): clear this code after all rotary_emb refactor if constexpr (std::is_same::value) { diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index d3833247..c37f803b 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -38,7 +38,7 @@ template class LlamaMLP { public: - LlamaMLP(DecoderContext *ctx) : norm(NORM_CLS(ctx)) {} + LlamaMLP(DecoderContext *ctx) : norm(ctx) {} // OriWeiT: float, int8_t or uint4x2_t template From 5f438970f54efd96720aaa1c9390089281a8a62e Mon Sep 17 00:00:00 2001 From: changqi1 Date: Wed, 12 Jun 2024 23:11:20 +0000 Subject: [PATCH 18/34] Add GPU memory to run kernels. --- src/common/allocator.h | 12 +++++ src/common/transformer_ctx.h | 21 +++++---- src/kernels/rotary_embedding_kernels.cpp | 58 ++++++++++-------------- src/layers/attention.h | 34 +++++++++----- src/layers/layer_norm.cpp | 4 +- src/layers/mlp_llama.h | 8 ++-- src/layers/rms_norm.cpp | 7 ++- src/layers/rotary_embedding.cpp | 4 +- src/models/common_decoder.h | 18 ++++++++ src/utils/debugger.h | 11 ++++- src/utils/decoder_util.h | 36 +++++++++++++++ src/utils/matmul_helper.h | 56 +++++++++++++---------- src/utils/simple_mem_pool.h | 20 ++++---- 13 files changed, 194 insertions(+), 95 deletions(-) diff --git a/src/common/allocator.h b/src/common/allocator.h index 99b31d3f..87ad1190 100644 --- a/src/common/allocator.h +++ b/src/common/allocator.h @@ -88,4 +88,16 @@ static inline void memcopy(void *dst, const void *src, size_t size, void *device memcpy(dst, src, size); } +static inline void memsetv(void *dst, int ch, size_t size, void *device = nullptr) { +#ifdef GPU + if (device != nullptr) { + sycl::queue *gpu_queue = static_cast(device); + gpu_queue->memset(dst, ch, size).wait(); + return; + } +#endif + + memset(dst, ch, size); +} + } // namespace xft \ No newline at end of file diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index 6999b634..71d05f31 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -171,9 +171,8 @@ struct DecoderContext { } } - this->rawBufSize = 4 * 32 * intermediateSize + 4 * attHeadNum * 32 * 32; // assume bs=4, seq=32 - this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize); - memset(this->rawBuffer, 0, sizeof(float) * rawBufSize); + this->rawBufSize = 0; + this->rawBuffer = nullptr; if (act == "relu") { this->actType = RELU; @@ -245,8 +244,8 @@ struct DecoderContext { return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, device, alignment); } - void freeBuffer(const std::string &name, void *device = nullptr) { - SimpleMemPool::instance().freeBuffer(name, device); + void freeBuffer(const std::string &name) { + SimpleMemPool::instance().freeBuffer(name); } void dump() { @@ -291,10 +290,10 @@ struct DecoderContext { uint64_t total = size1 + size2 + size3; if (total > this->rawBufSize) { this->rawBufSize = total; - free(this->rawBuffer); + if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device); - this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize); - memset(this->rawBuffer, 0, sizeof(float) * rawBufSize); + this->rawBuffer = (float 
*)xft::alloc(sizeof(float) * rawBufSize, this->device); + xft::memsetv(this->rawBuffer, 0, sizeof(float) * rawBufSize, this->device); } // Assign the buffer @@ -317,5 +316,9 @@ struct DecoderContext { return rawBufSize - size1 - size2; } - ~DecoderContext() { free(this->rawBuffer); } + ~DecoderContext() { + xft::dealloc(this->rawBuffer, this->device); + if (this->mmHelper) delete this->mmHelper; + if (this->device) delete this->device; + } }; \ No newline at end of file diff --git a/src/kernels/rotary_embedding_kernels.cpp b/src/kernels/rotary_embedding_kernels.cpp index 983a3251..aa67b1d7 100644 --- a/src/kernels/rotary_embedding_kernels.cpp +++ b/src/kernels/rotary_embedding_kernels.cpp @@ -389,8 +389,8 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i #ifdef GPU // For LLaMA template -static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, int qStride, int kStride, float *emb_cos, - float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) { +static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, int qStride, int kStride, + const float *emb_cos, const float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) { int dim = inv_freq_size * 2; REQUIRES(dim == qkShape[3], "Incorrect shape, this dimention is not the head size."); @@ -403,33 +403,6 @@ static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, i const int half_head_size = (head_size + 1) / 2; using namespace sycl; - auto rope_kernel - = [](sycl::nd_item<3> &item, const float *embCos, const float *embSin, const int qHeads, const int kHeads, - const int seq_size, const int head_size, const int half, T *query, T *key, int qStride, - int kStride, const sycl::accessor &positionIds) { - size_t idx_bs_seq = item.get_global_id(0); - size_t idx_head_num = item.get_global_id(1); - size_t idx_half_head_dim = item.get_global_id(2); - - size_t pos = positionIds[idx_bs_seq % seq_size]; - float cos = embCos[pos * half + idx_half_head_dim]; - float sin = embSin[pos * half + idx_half_head_dim]; - - T *q = query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim; - T *k = key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim; - - if (idx_head_num < qHeads) { - auto q1 = q[0]; - q[0] = q1 * cos - q[half] * sin; - q[half] = q[half] * cos + q1 * sin; - } - if (idx_head_num < kHeads) { - auto k1 = k[0]; - k[0] = k1 * cos - k[half] * sin; - k[half] = k[half] * cos + k1 * sin; - } - }; - // Reorder input sycl::queue *gpu_queue = static_cast(device); sycl::buffer positionIdsBuf(positionIds, sycl::range<1>(seqLen)); @@ -439,8 +412,27 @@ static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, i sycl::range<3> workGroupSize(1, 1, 1); cgh.parallel_for(sycl::nd_range(globalSize, workGroupSize), [=](sycl::nd_item<3> item) { - rope_kernel(item, emb_cos, emb_sin, qHeads, kHeads, seqLen, head_size, half_head_size, query, key, qStride, - kStride, position); + size_t idx_bs_seq = item.get_global_id(0); + size_t idx_head_num = item.get_global_id(1); + size_t idx_half_head_dim = item.get_global_id(2); + + size_t pos = position[idx_bs_seq % seqLen]; + const sycl::half cos = (sycl::half)emb_cos[pos * half_head_size + idx_half_head_dim]; + const sycl::half sin = (sycl::half)emb_sin[pos * half_head_size + idx_half_head_dim]; + + sycl::half *q = (sycl::half *)query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim; + sycl::half *k = (sycl::half *)key 
+ idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim; + + if (idx_head_num < qHeads) { + auto q1 = q[0]; + q[0] = q1 * cos - q[half_head_size] * sin; + q[half_head_size] = q[half_head_size] * cos + q1 * sin; + } + if (idx_head_num < kHeads) { + auto k1 = k[0]; + k[0] = k1 * cos - k[half_head_size] * sin; + k[half_head_size] = k[half_head_size] * cos + k1 * sin; + } }); }); gpu_queue->wait(); @@ -460,8 +452,8 @@ void llamaApplyRotaryPosEmbeding(void *device, bfloat16_t *query, bfloat16_t *ke void llamaApplyRotaryPosEmbeding(void *device, float16_t *query, float16_t *key, int qStride, int kStride, float *emb_cos, float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) { - llamaApplyRotaryPosEmbeding( - device, query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); + llamaApplyRotaryPosEmbeding(device, (sycl::half *)query, (sycl::half *)key, qStride, kStride, emb_cos, + emb_sin, inv_freq_size, qkShape, positionIds); } #endif diff --git a/src/layers/attention.h b/src/layers/attention.h index 2ff9ed62..b6ddd484 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -144,7 +144,7 @@ class Attention { #ifdef GPU xft::Matrix qkvWeightT; qkvWeightT.Resize(hiddenSize, responsibleCols); - ctx->mmHelper->transposeWeight(true, convertedqkvWeight, qkvWeightT); + ctx->mmHelper->transposeWeight(trans, convertedqkvWeight, qkvWeightT); WeiT *qkvWeiData = (WeiT *)xft::alloc(hiddenSize * responsibleCols * sizeof(WeiT), ctx->device); qkvWeight.Assign(qkvWeiData, hiddenSize, responsibleCols, responsibleCols); @@ -187,7 +187,7 @@ class Attention { #ifdef GPU xft::Matrix outWeightT; outWeightT.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize); - ctx->mmHelper->transposeWeight(true, convertedOutWeight, outWeightT); + ctx->mmHelper->transposeWeight(trans, convertedOutWeight, outWeightT); WeiT *outWeiData = (WeiT *)xft::alloc(ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT), ctx->device); @@ -326,13 +326,18 @@ class Attention { std::iota(posIds.begin(), posIds.end(), pastSeqLen); } qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qkShape, posIds.data()); -#ifdef GPU - int64_t size = ctx->batchSize * ctx->inputSeqLen * qkvCols * sizeof(float); - xft::memcopy(qkvMatMul.Data(), query.Data(), size, ctx->device); // error: need CPU ptr and GPU ptr -#endif } t3.release(); +#ifdef GPU + int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT); + ImT *qkvTmp = (ImT *)xft::alloc(qkvSize); + xft::memcopy(qkvTmp, qkvGroupMatMul.Data(), qkvSize, ctx->device); // error: need CPU ptr and GPU ptr + query.Assign(qkvTmp, inputBuffer.Rows(), qCols, qkvCols); + key.Assign(qkvTmp + qCols, inputBuffer.Rows(), kvCols, qkvCols); + value.Assign(qkvTmp + qCols + kvCols, inputBuffer.Rows(), kvCols, qkvCols); +#endif + #ifdef XFT_DEBUG dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride()); dbg.dumpMatrix(query); @@ -356,6 +361,12 @@ class Attention { // For multiple nodes inference, not the whole result buffer xft::Matrix attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); +#ifdef GPU + int64_t attnSplitSize = imBuffer.Rows() * qCols * sizeof(ImT); + ImT *attnSplitTmp = (ImT *)xft::alloc(attnSplitSize); + attnSplit.Assign(attnSplitTmp, imBuffer.Rows(), qCols, qCols); +#endif + if (pastSeqLen == 0) { if (ctx->inputSeqLen > getFlashThresh()) { flashAttention(ctx, query, key, value, attnSplit, presentKey, presentValue, attnMask, pastSeqLen); @@ -369,17 +380,18 @@ class Attention { } 
t4.release(); +#ifdef GPU + xft::memcopy(imBuffer.Data(), attnSplit.Data(), attnSplitSize, ctx->device); + attnSplit.Assign(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); + xft::dealloc(qkvTmp); +#endif + #ifdef XFT_DEBUG dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(), attnSplit.Cols(), attnSplit.Stride()); dbg.dumpMatrix(attnSplit); #endif -#ifdef GPU - int64_t size = ctx->batchSize * ctx->inputSeqLen * qkvCols * sizeof(float); - xft::memcopy(qkvMatMul.Data(), attnSplit.Data(), size, ctx->device); // error: need CPU ptr and GPU ptr -#endif - TimeLine t5("Output"); // Output/projection in attention, only add the input in the first split if (ctx->splitIdx == 0) { diff --git a/src/layers/layer_norm.cpp b/src/layers/layer_norm.cpp index ba958cb6..44ad949f 100644 --- a/src/layers/layer_norm.cpp +++ b/src/layers/layer_norm.cpp @@ -39,8 +39,8 @@ LayerNorm::LayerNorm(DecoderContext *ctx) { } LayerNorm::~LayerNorm() { - if (gamma) { xft::dealloc(gamma); } - if (beta) { xft::dealloc(beta); } + if (gamma) { xft::dealloc(gamma, device); } + if (beta) { xft::dealloc(beta, device); } } void LayerNorm::setWeight(const float *gamma, const float *beta, int cols) { diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index c37f803b..189fcf7b 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -98,7 +98,7 @@ class LlamaMLP { int downWeiRows = it.second - it.first; int downWeiCols = hiddenSize; downWeightT.Resize(downWeiRows, downWeiCols); - ctx->mmHelper->transposeWeight(true, quantizedDownWeight, downWeightT); + ctx->mmHelper->transposeWeight(trans, quantizedDownWeight, downWeightT); WeiT *downWeiData = (WeiT *)xft::alloc(downWeiRows * downWeiCols * sizeof(WeiT), ctx->device); downWeight.Assign(downWeiData, downWeiRows, downWeiCols, downWeiCols); @@ -184,7 +184,7 @@ class LlamaMLP { // Need to allocate extra buffer as oneDNN does not support the case of stride > cols const int cols = N / 2; auto bufSize = sizeof(ImT) * M * cols; - ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize); + ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize, ctx->device); xft::Matrix siluBuf(t, M, cols, cols); #ifdef XFT_DEBUG dbg.debugPrint( @@ -314,9 +314,9 @@ class LlamaMLP { // Compute silu on the left half and then add it with the right half if (ctx->actType == DecoderContext::SILU) { - DecoderUtil::siluSum(output, siluBuf); + DecoderUtil::siluSum(output, siluBuf, ctx->device); } else if (ctx->actType == DecoderContext::SWIGLU) { // chatglm2/3 - DecoderUtil::siluSum(output, siluBuf); + DecoderUtil::siluSum(output, siluBuf, ctx->device); } else if (ctx->actType == DecoderContext::GELU) { // gemma DecoderUtil::geluSum(output, siluBuf); } else { diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index 7998eab8..8842a965 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -45,7 +45,7 @@ RmsNormImp::RmsNormImp(DecoderContext *ctx) { template RmsNormImp::~RmsNormImp() { - if (weight) { xft::dealloc(weight); } + if (weight) { xft::dealloc(weight, device); } } template @@ -69,7 +69,10 @@ void RmsNormImp::setWeight(const float *w, const float *, int cols) { template void RmsNormImp::setWeight(const std::string &modelPath, const std::string &, int cols) { this->normSize = cols; - loadWeight(modelPath, weight, cols); + float *weiBuf = (float *)xft::alloc(cols * sizeof(float)); + loadWeight(modelPath, weiBuf, cols, DataType::fp32); + this->setWeight(weiBuf, nullptr, cols); + 
xft::dealloc(weiBuf); } #ifdef GPU diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index 12207530..f8a473ff 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -22,6 +22,7 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { const std::string emb_cos_str = "emb_cos"; const std::string emb_sin_str = "emb_sin"; + this->device = ctx->device; this->dim = ctx->attHeadSize; this->max_position_embeddings = ctx->maxPosEmbed; ctx->GetAttr("rope_theta", &this->base, 10000); @@ -43,8 +44,7 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { } xft::llamaSetCosSinCache(inv_freq, emb_cos, emb_sin, inv_freq_size, max_position_embeddings); #ifdef GPU - device = ctx->device; - if (device != nullptr) { + if (this->device != nullptr) { float *emb_cos_bak = emb_cos; float *emb_sin_bak = emb_sin; emb_cos = ctx->getBuffer(emb_cos_str + "_gpu", max_position_embeddings * inv_freq_size, device); diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 2f2724b9..ff2b308a 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -370,6 +370,15 @@ class CommonDecoder : public AbstractDecoder { TimeLine t("Decoder.Seq" + std::to_string(sequenceID) + ".Step"); #endif +#ifdef GPU + size_t embBufSize = batchSize * inputSeqLen * hiddenSize * sizeof(AttnInT); + AttnInT *embBufTmp = (AttnInT *)xft::alloc(embBufSize, ctx->device); + AttnInT *outBufTmp = (AttnInT *)xft::alloc(embBufSize, ctx->device); + xft::memcopy(embBufTmp, embBuf, embBufSize, ctx->device); + embBuf = embBufTmp; + outBuf = outBufTmp; +#endif + // Decoder: forward int layers_per_pp_stage = decoderBlock->size(); for (int i = 0; i < layers_per_pp_stage; ++i) { @@ -423,6 +432,15 @@ class CommonDecoder : public AbstractDecoder { } } +#ifdef GPU + embBufTmp = (AttnInT *)actBuffers->Data(); + xft::memcopy(embBufTmp, embBuf, embBufSize, ctx->device); + xft::dealloc(embBuf, ctx->device); + xft::dealloc(outBuf, ctx->device); + embBuf = embBufTmp; + outBuf = (MlpOutT *)(embBuf + batchSize * inputSeqLen * hiddenSize); +#endif + #ifdef PIPELINE_PARALLEL } diff --git a/src/utils/debugger.h b/src/utils/debugger.h index ca74d115..5dea56de 100644 --- a/src/utils/debugger.h +++ b/src/utils/debugger.h @@ -114,6 +114,15 @@ class Debugger { } } +#ifdef GPU + template + void dumpMatrix(xft::Matrix &m, bool print_all = false) { + } + + template + void dumpMatrix(T *data, uint64_t rows, uint64_t cols, uint64_t stride, bool print_all = false) { + } +#else template void dumpMatrix(xft::Matrix &m, bool print_all = false) { std::ostringstream oss; @@ -281,7 +290,7 @@ class Debugger { fflush(debugFile); } } - +#endif // Function to store float* data to a file template void storeMatrix(const std::string &filename, const T *data, uint64_t rows, uint64_t cols) { diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 08289a16..343b36ff 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -28,6 +28,10 @@ #include "transformer_ctx.h" #include "xdnn.h" +#ifdef GPU +#include +#endif + extern int getFlashThresh(); extern bool enableCATMLP(); extern bool enableSkipMsk(); @@ -444,6 +448,37 @@ class DecoderUtil { return std::make_pair(maxVal, sum); } +#ifdef GPU + template + static void siluSum(xft::Matrix &src, xft::Matrix &dst, void *device = nullptr) { + int M = src.Rows(); + int lds = src.Stride(); + int N = lds / 2; + int ldd = dst.Stride(); + + if (device != nullptr) { + sycl::queue *gpu_queue = static_cast(device); + + 
if constexpr (std::is_same_v && std::is_same_v) { + const float16_t *src0 = src.Data(); + const float16_t *src1 = src.Data() + N; + sycl::half *dest = (sycl::half *)dst.Data(); + + gpu_queue + ->submit([&](sycl::handler &h) { + h.parallel_for(M * N, [=](auto i) { + int32_t row = i / N; + int32_t col = i % N; + dest[row * ldd + col] = ((sycl::half)src0[row * lds + col] + / ((sycl::half)1.0f + (sycl::half)sycl::native::exp(-src0[row * lds + col])) + * (sycl::half)src1[row * lds + col]); + }); + }) + .wait(); + } + } + } +#else // compute silu on the left half and then add it with the right half template static void siluSum(xft::Matrix &src, xft::Matrix &dst) { @@ -469,6 +504,7 @@ class DecoderUtil { } } } +#endif // compute gelu on the left half and then add it with the right half template diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 38104d10..63ec8515 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -566,9 +566,9 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_biasadd", - xdnn_sgemm_f32f16f32_compute_biasadd( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias)); + GEMMVERBOSE("onednn_gemm_compute_bias", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, (const InT *)nullptr, -1, + matmul_kinds::BiasAdd)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_biasadd", @@ -685,12 +685,13 @@ class MMHelper { GEMMVERBOSE("xdnn_sgemm_compute_biasadd_relu", xdnn_sgemm_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias)); } + // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_biasadd_relu", - xdnn_sgemm_f32f16f32_compute_biasadd_relu( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias)); + GEMMVERBOSE("onednn_gemm_compute_bias_relu", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, (const InT *)nullptr, -1, + matmul_kinds::BiasAdd_Relu)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_biasadd_relu", @@ -803,9 +804,9 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_silu", - xdnn_sgemm_f32f16f32_compute_silu( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + GEMMVERBOSE("onednn_gemm_compute_silu", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, (const float *)nullptr, + (const InT *)nullptr, -1, matmul_kinds::Silu)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_silu", @@ -924,9 +925,9 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_gelu", - xdnn_sgemm_f32f16f32_compute_gelu( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + GEMMVERBOSE("onednn_gemm_compute_gelu", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, (const float *)nullptr, + (const InT *)nullptr, -1, matmul_kinds::Gelu)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_gelu", @@ -1046,9 
+1047,9 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_resmul", - xdnn_sgemm_f32f16f32_compute_resmul( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, res, ldres)); + GEMMVERBOSE("onednn_gemm_compute_resmul", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, (const float *)nullptr, + res, ldres, matmul_kinds::Resmul)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_resmul", @@ -1170,9 +1171,9 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_residential", - xdnn_sgemm_f32f16f32_compute_residential(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, - beta, C, ldc, bias, res, ldres)); + GEMMVERBOSE("onednn_gemm_compute_residential", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres, + matmul_kinds::Residential)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_residential", @@ -1295,9 +1296,19 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_resext", - xdnn_sgemm_f32f16f32_compute_resext(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, - beta, C, ldc, bias, gamma, res, ldres)); +#pragma omp parallel for collapse(2) + for (uint64_t i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + auto remain = N - j; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + auto v = xft::load_avx512(mask, &res[i * ldres + j]); + v = _mm512_mul_ps(_mm512_set1_ps(gamma), v); + xft::store_avx512(&res[i * ldres + j], mask, v); + } + } + GEMMVERBOSE("onednn_gemm_compute_resext", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres, + matmul_kinds::Residential)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_resext", @@ -1688,10 +1699,9 @@ class MMHelper { memory bias_mem; if (bias != nullptr) { bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast(bias)); } - memory::desc shift_md; memory shift_mem; if (res != nullptr) { - shift_md = memory::desc({M, N}, shift_dt, get_onednn_shift_layout(shift_dt)); + memory::desc shift_md = memory::desc({M, N}, shift_dt, get_onednn_shift_layout(shift_dt)); if constexpr (std::is_same_v) { shift_mem = memory(shift_md, *engine); } else { diff --git a/src/utils/simple_mem_pool.h b/src/utils/simple_mem_pool.h index 7ade4858..dabb1a7f 100644 --- a/src/utils/simple_mem_pool.h +++ b/src/utils/simple_mem_pool.h @@ -24,7 +24,7 @@ class SimpleMemPool { private: - std::unordered_map> memoryMap; + std::unordered_map> memoryMap; // Private constructor to enforce Singleton pattern SimpleMemPool() {} @@ -47,6 +47,8 @@ class SimpleMemPool { // Allocate or reallocate memory buffer based on name and size void *getBuffer(const std::string &name, size_t size, void *device = nullptr, size_t alignment = 64) { + if (name.empty()) return nullptr; + if (size == 0) { // std::cout << "[Warning] Try to allocate 0 bytes for buffer:" << name << std::endl; return nullptr; @@ -55,12 +57,12 @@ class SimpleMemPool { if (it != memoryMap.end()) { // Buffer with the given name found - if (it->second.second >= size) { + if (std::get<1>(it->second) 
>= size) { // Existing buffer size is sufficient, return it - return it->second.first; + return std::get<0>(it->second); } else { // Reallocate the buffer - free(it->second.first); + xft::dealloc(std::get<0>(it->second), std::get<2>(it->second)); } } @@ -73,24 +75,26 @@ class SimpleMemPool { } // Update or insert entry in the mapping - memoryMap[name] = std::make_pair(buffer, size); + memoryMap[name] = std::make_tuple(buffer, size, device); return buffer; } // Free allocated memory based on name - void freeBuffer(const std::string &name, void *device = nullptr) { + void freeBuffer(const std::string &name) { auto it = memoryMap.find(name); if (it != memoryMap.end()) { - xft::dealloc(it->second.first, device); + xft::dealloc(std::get<0>(it->second), std::get<2>(it->second)); + memoryMap.erase(it->first); } } // Destructor to free all allocated memory on program termination ~SimpleMemPool() { for (auto &entry : memoryMap) { - free(entry.second.first); + if (!entry.first.empty()) + freeBuffer(entry.first); } memoryMap.clear(); } From 23d2053d80655f0b43acf3aabe96f3eabe1c5b01 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Wed, 12 Jun 2024 23:33:07 +0000 Subject: [PATCH 19/34] Add gpu matmul kernels --- src/utils/matmul_helper.h | 110 +++++++++++++++++++++++++++----------- 1 file changed, 80 insertions(+), 30 deletions(-) diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 63ec8515..2ae1c12f 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -448,8 +448,14 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE( - "onednn_gemm_compute", onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute", + xdnn_sgemm_f32f16f32_compute( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + } else { + GEMMVERBOSE("onednn_gemm_compute", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute", @@ -566,9 +572,15 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("onednn_gemm_compute_bias", - onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, (const InT *)nullptr, -1, - matmul_kinds::BiasAdd)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_biasadd", + xdnn_sgemm_f32f16f32_compute_biasadd( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias)); + } else { + GEMMVERBOSE("onednn_gemm_compute_bias", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, + (const InT *)nullptr, -1, matmul_kinds::BiasAdd)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_biasadd", @@ -689,9 +701,15 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("onednn_gemm_compute_bias_relu", - onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, (const InT *)nullptr, -1, - matmul_kinds::BiasAdd_Relu)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_biasadd_relu", + xdnn_sgemm_f32f16f32_compute_biasadd_relu( + transA, M, N, K, alpha, A, lda, (const 
XDNN_FP16 *)packedB, beta, C, ldc, bias)); + } else { + GEMMVERBOSE("onednn_gemm_compute_bias_relu", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, + (const InT *)nullptr, -1, matmul_kinds::BiasAdd_Relu)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_biasadd_relu", @@ -804,9 +822,15 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("onednn_gemm_compute_silu", - onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, (const float *)nullptr, - (const InT *)nullptr, -1, matmul_kinds::Silu)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_silu", + xdnn_sgemm_f32f16f32_compute_silu( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + } else { + GEMMVERBOSE("onednn_gemm_compute_silu", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, + (const float *)nullptr, (const InT *)nullptr, -1, matmul_kinds::Silu)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_silu", @@ -925,9 +949,16 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("onednn_gemm_compute_gelu", - onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, (const float *)nullptr, - (const InT *)nullptr, -1, matmul_kinds::Gelu)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_gelu", + xdnn_sgemm_f32f16f32_compute_gelu( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + } else { + GEMMVERBOSE("onednn_gemm_compute_gelu", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, + (const float *)nullptr, (const InT *)nullptr, -1, matmul_kinds::Gelu)); + } + #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_gelu", @@ -1047,9 +1078,15 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("onednn_gemm_compute_resmul", - onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, (const float *)nullptr, - res, ldres, matmul_kinds::Resmul)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_resmul", + xdnn_sgemm_f32f16f32_compute_resmul( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, res, ldres)); + } else { + GEMMVERBOSE("onednn_gemm_compute_resmul", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, + (const float *)nullptr, res, ldres, matmul_kinds::Resmul)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_resmul", @@ -1171,9 +1208,15 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("onednn_gemm_compute_residential", - onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres, - matmul_kinds::Residential)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_residential", + xdnn_sgemm_f32f16f32_compute_residential(transA, M, N, K, alpha, A, lda, + (const XDNN_FP16 *)packedB, beta, C, ldc, bias, res, ldres)); + } else { + 
GEMMVERBOSE("onednn_gemm_compute_residential", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres, + matmul_kinds::Residential)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_residential", @@ -1296,19 +1339,26 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_resext", + xdnn_sgemm_f32f16f32_compute_resext(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, + beta, C, ldc, bias, gamma, res, ldres)); + } else { #pragma omp parallel for collapse(2) - for (uint64_t i = 0; i < M; ++i) { - for (int j = 0; j < N; ++j) { - auto remain = N - j; - __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); - auto v = xft::load_avx512(mask, &res[i * ldres + j]); - v = _mm512_mul_ps(_mm512_set1_ps(gamma), v); - xft::store_avx512(&res[i * ldres + j], mask, v); + for (uint64_t i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + auto remain = N - j; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + auto v = xft::load_avx512(mask, &res[i * ldres + j]); + v = _mm512_mul_ps(_mm512_set1_ps(gamma), v); + xft::store_avx512(&res[i * ldres + j], mask, v); + } } + + GEMMVERBOSE("onednn_gemm_compute_resext", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres, + matmul_kinds::Residential)); } - GEMMVERBOSE("onednn_gemm_compute_resext", - onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres, - matmul_kinds::Residential)); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_resext", From b7dc9eb4e15986c46395eaca158c8ec0a251f27a Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 10:14:07 +0000 Subject: [PATCH 20/34] Fix CPU build issue. 
--- src/common/transformer_ctx.h | 12 +++++++----- src/layers/attention.cpp | 5 +++-- src/layers/decoder_layer.cpp | 5 +++-- src/layers/mlp_llama.cpp | 5 +++-- src/layers/mlp_llama.h | 2 +- src/models/common_decoder.h | 20 +++++++++++--------- src/utils/decoder_util.h | 4 ++-- 7 files changed, 30 insertions(+), 23 deletions(-) diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index 71d05f31..398a789f 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -131,7 +131,7 @@ struct DecoderContext { public: DecoderContext(int _layers, int _hiddenSize, int _headSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act, float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength, - int _splitIdx, int _splits, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr, + int _splitIdx, int _splits, MMHelper *mmHelper, void *device = nullptr, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr, bool _useLogN = true, bool _useNTK = true, int numThreads = 0) : layers(_layers) , hiddenSize(_hiddenSize) @@ -171,8 +171,12 @@ struct DecoderContext { } } - this->rawBufSize = 0; - this->rawBuffer = nullptr; + this->mmHelper = mmHelper; + this->device = device; + + this->rawBufSize = 4 * 32 * intermediateSize + 4 * attHeadNum * 32 * 32; // assume bs=4, seq=32 + this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize, this->device); + xft::memsetv(this->rawBuffer, 0, sizeof(float) * rawBufSize, this->device); if (act == "relu") { this->actType = RELU; @@ -318,7 +322,5 @@ struct DecoderContext { ~DecoderContext() { xft::dealloc(this->rawBuffer, this->device); - if (this->mmHelper) delete this->mmHelper; - if (this->device) delete this->device; } }; \ No newline at end of file diff --git a/src/layers/attention.cpp b/src/layers/attention.cpp index 0e9d0669..bd15b0af 100644 --- a/src/layers/attention.cpp +++ b/src/layers/attention.cpp @@ -77,15 +77,16 @@ void AttentionLLaMAImpl(DataType dt, int batchSize, int inputSeqLen, int attHead using ATTENTION = Attention; static std::unordered_map llama_attention_hub; + static MMHelper *mmHelper; static DecoderContext *ctx; static KVCacheManager *kvCacheMgr; if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) { if (ctx != nullptr) delete ctx; printf(">> create context: %d %d\n", hiddenSize, attHeadDim); + mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex()); ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0, maxPositions, - maxPosEmbed, -1, 0, 1); - ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex()); + maxPosEmbed, -1, 0, 1, mmHelper); if (kvCacheMgr != nullptr) delete kvCacheMgr; kvCacheMgr = new KVCacheManager(1); } diff --git a/src/layers/decoder_layer.cpp b/src/layers/decoder_layer.cpp index d1017648..02f13cbf 100644 --- a/src/layers/decoder_layer.cpp +++ b/src/layers/decoder_layer.cpp @@ -85,6 +85,7 @@ void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize, using DECODER = Decoder, LlamaMLP>; static std::unordered_map llama_layer_hub; + static MMHelper *mmHelper; static DecoderContext *ctx; static KVCacheManager *kvCacheMgr; @@ -104,9 +105,9 @@ void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize, || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize 
!= intermediateSize))) { if (ctx != nullptr) delete ctx; printf(">> create context: %d %d\n", hiddenSize, intermediateSize); + mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex()); ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, intermediateSize, actType, 1e-6, 0, - 0, maxPositions, maxPosEmbed, -1, 0, 1); - ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex()); + 0, maxPositions, maxPosEmbed, -1, 0, 1, mmHelper); if (kvCacheMgr != nullptr) delete kvCacheMgr; kvCacheMgr = new KVCacheManager(1); } diff --git a/src/layers/mlp_llama.cpp b/src/layers/mlp_llama.cpp index 749b39e0..50d5e79e 100644 --- a/src/layers/mlp_llama.cpp +++ b/src/layers/mlp_llama.cpp @@ -26,6 +26,7 @@ void MLPLLaMAImpl(DataType dt, ActivationType at, int numTokens, int hiddenSize, using MLP = LlamaMLP; static std::unordered_map llama_mlp_hub; + static MMHelper *mmHelper; static DecoderContext *ctx; std::string actType; @@ -44,8 +45,8 @@ void MLPLLaMAImpl(DataType dt, ActivationType at, int numTokens, int hiddenSize, || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) { if (ctx != nullptr) delete ctx; printf(">> create context: %d %d\n", hiddenSize, intermediateSize); - ctx = new DecoderContext(1, hiddenSize, 1, 1, 1, intermediateSize, actType, 1e-6, 0, 0, 0, 0, 0, 0, 1); - ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex()); + mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex()); + ctx = new DecoderContext(1, hiddenSize, 1, 1, 1, intermediateSize, actType, 1e-6, 0, 0, 0, 0, 0, 0, 1, mmHelper); } // create hash key and value: if hidden and intermediateSize is changed , then memory pointer is also changed. diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 189fcf7b..6e835aaf 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -318,7 +318,7 @@ class LlamaMLP { } else if (ctx->actType == DecoderContext::SWIGLU) { // chatglm2/3 DecoderUtil::siluSum(output, siluBuf, ctx->device); } else if (ctx->actType == DecoderContext::GELU) { // gemma - DecoderUtil::geluSum(output, siluBuf); + DecoderUtil::geluSum(output, siluBuf, ctx->device); } else { printf("ERROR: unsupported activation in MLP.\n"); exit(-1); diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index ff2b308a..e32c1634 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -356,7 +356,8 @@ class CommonDecoder : public AbstractDecoder { groupMeta->get(0)->setPastSeqLen(pastSeqLen); groupMeta->get(0)->allocBuffer(hiddenSize, embBuf); SequencePool::getInstance().add(groupMeta); - TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(groupMeta->get(0)->getSequenceID())); + TaskWaitingQueue::getInstance().push( + SequencePool::getInstance().get(groupMeta->get(0)->getSequenceID())); } } @@ -533,7 +534,7 @@ class CommonDecoder : public AbstractDecoder { } std::tuple forward(std::vector &seqs, bool logitsAll = false) { - // Assume all sequences are all prompts(step==0) or all decodes(step>0) + // Assume all sequences are all prompts(step==0) or all decodes(step>0) // Assume input has been synced with master in higher level. 
TimeLine t("Decoder.forward"); TimeLine t1("Decoder.embedding"); @@ -757,23 +758,22 @@ class CommonDecoder : public AbstractDecoder { exit(-1); } } else { - this->context.reset(new DecoderContext(layers, hiddenSize, headSize, attHeadNum, kvHeadNum, imSize, act, - epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, ppSize, - ppRank, ropeParamsPtr, useLogN, useNTK)); - int engineIdx = 0; if (env.getEngineKind() == xft::DeviceKind::iGPU && env.getEngineIndex() < 0) // Sequential assignment engineIdx = ppRank * tpSize + tpRank; else // assignment through the user engineIdx = env.getEngineIndex(); - this->context->mmHelper = new MMHelper(env.getEngineKind(), engineIdx); + this->mmHelper.reset(new MMHelper(env.getEngineKind(), engineIdx)); #ifdef GPU auto devices = sycl::device::get_devices(sycl::info::device_type::gpu); - this->context->device = new sycl::queue(devices[this->context->mmHelper->getEngineCount() + engineIdx]); + this->device.reset(new sycl::queue(devices[this->mmHelper->getEngineCount() + engineIdx])); #else - this->context->device = nullptr; + this->device.reset(nullptr); #endif + this->context.reset(new DecoderContext(layers, hiddenSize, headSize, attHeadNum, kvHeadNum, imSize, act, + epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, + this->mmHelper.get(), this->device.get(), ppSize, ppRank, ropeParamsPtr, useLogN, useNTK)); } return this->context.get(); @@ -1094,6 +1094,8 @@ class CommonDecoder : public AbstractDecoder { // Execution context std::shared_ptr context; + std::shared_ptr mmHelper; + std::shared_ptr device; // The initial input sequence length, which is the prompt token size int initSeqLen; diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 343b36ff..0d465118 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -481,7 +481,7 @@ class DecoderUtil { #else // compute silu on the left half and then add it with the right half template - static void siluSum(xft::Matrix &src, xft::Matrix &dst) { + static void siluSum(xft::Matrix &src, xft::Matrix &dst, void *device = nullptr) { __m512 one = _mm512_set1_ps(1.f); __m512 negOne = _mm512_set1_ps(-1.f); int M = src.Rows(); @@ -508,7 +508,7 @@ class DecoderUtil { // compute gelu on the left half and then add it with the right half template - static void geluSum(xft::Matrix &src, xft::Matrix &dst) { + static void geluSum(xft::Matrix &src, xft::Matrix &dst, void *device = nullptr) { const __m512 c1 = _mm512_set1_ps(0.044715f); const __m512 c2 = _mm512_set1_ps(0.7978845608f); const __m512 vone = _mm512_set1_ps(1.0f); From daec9dd566e4dc2fa350a096c5a3835b1455d0af Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 10:21:16 +0000 Subject: [PATCH 21/34] fix --- src/models/common_decoder.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index e32c1634..e3edcaa9 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -769,7 +769,7 @@ class CommonDecoder : public AbstractDecoder { auto devices = sycl::device::get_devices(sycl::info::device_type::gpu); this->device.reset(new sycl::queue(devices[this->mmHelper->getEngineCount() + engineIdx])); #else - this->device.reset(nullptr); + this->device.reset(&nullptr); #endif this->context.reset(new DecoderContext(layers, hiddenSize, headSize, attHeadNum, kvHeadNum, imSize, act, epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, 
From c3e83f22f88060b24aa4b214e648156d4893037d Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 10:29:55 +0000 Subject: [PATCH 22/34] fix --- src/models/common_decoder.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index e3edcaa9..a8ae1bbf 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -768,8 +768,6 @@ class CommonDecoder : public AbstractDecoder { #ifdef GPU auto devices = sycl::device::get_devices(sycl::info::device_type::gpu); this->device.reset(new sycl::queue(devices[this->mmHelper->getEngineCount() + engineIdx])); -#else - this->device.reset(&nullptr); #endif this->context.reset(new DecoderContext(layers, hiddenSize, headSize, attHeadNum, kvHeadNum, imSize, act, epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, @@ -1095,7 +1093,7 @@ class CommonDecoder : public AbstractDecoder { // Execution context std::shared_ptr context; std::shared_ptr mmHelper; - std::shared_ptr device; + std::shared_ptr device(nullptr); // The initial input sequence length, which is the prompt token size int initSeqLen; From 277de9b0dd35b3b8c7de303400db648cdd41c996 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 10:45:10 +0000 Subject: [PATCH 23/34] fix --- src/models/common_decoder.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index a8ae1bbf..cb4a531e 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -1093,7 +1093,7 @@ class CommonDecoder : public AbstractDecoder { // Execution context std::shared_ptr context; std::shared_ptr mmHelper; - std::shared_ptr device(nullptr); + std::shared_ptr device; // The initial input sequence length, which is the prompt token size int initSeqLen; From dd1d3fb254b8896a5fe8372e000bf0b472983e52 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 11:26:27 +0000 Subject: [PATCH 24/34] fix --- src/utils/simple_mem_pool.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/utils/simple_mem_pool.h b/src/utils/simple_mem_pool.h index dabb1a7f..f4a65db2 100644 --- a/src/utils/simple_mem_pool.h +++ b/src/utils/simple_mem_pool.h @@ -93,8 +93,7 @@ class SimpleMemPool { // Destructor to free all allocated memory on program termination ~SimpleMemPool() { for (auto &entry : memoryMap) { - if (!entry.first.empty()) - freeBuffer(entry.first); + xft::dealloc(std::get<0>(entry.second), std::get<2>(entry.second)); } memoryMap.clear(); } From 15fc202038f1dce0f088363cb0ef9dcbfca1cc8a Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 12:21:37 +0000 Subject: [PATCH 25/34] Fix build issue. 
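Guard the cleanup and conversion paths that assumed optional inputs were always present:
~DecoderContext only releases rawBuffer when it was actually allocated, and convertWeight skips
the scale/zero copies when the caller passes nullptr (as non-quantized weight conversions may do).
The guarded copy, mirroring the matmul_helper.h hunks below:

    if (scales) memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float));
    if (zeros) memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float));
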
--- src/common/transformer_ctx.h | 2 +- src/models/common_decoder.h | 2 +- src/utils/matmul_helper.h | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index 398a789f..9d79a1e5 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -321,6 +321,6 @@ struct DecoderContext { } ~DecoderContext() { - xft::dealloc(this->rawBuffer, this->device); + if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device); } }; \ No newline at end of file diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index cb4a531e..3b26547d 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -256,7 +256,7 @@ class CommonDecoder : public AbstractDecoder { virtual ~CommonDecoder() { if (this->inputTokens) free(this->inputTokens); - if (this->attnMask) free(this->attnMask); + if (this->attnMask) xft::alloc(this->attnMask); delete this->decoderBlock; delete this->predictor; diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 2ae1c12f..8b9a4092 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -226,8 +226,8 @@ class MMHelper { int offset = trans ? rowOffset : colOffset; scaleWeight.Resize(size); zeroWeight.Resize(size); - memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); - memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); + if (scales) memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); + if (zeros) memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); #pragma omp parallel for for (uint64_t i = 0; i < rowSize; i++) { WeiT *dst = convertedWeight.Data() + i * convertedWeight.Stride(); @@ -242,8 +242,8 @@ class MMHelper { int offset = trans ? rowOffset : colOffset; scaleWeight.Resize(size); zeroWeight.Resize(size); - memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); - memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); + if (scales) memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); + if (zeros) memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); #pragma omp parallel for for (uint64_t i = 0; i < rowSize; i++) { WeiT *dst = convertedWeight.Data() + i * convertedWeight.Stride() / 2; From 726d35675953bc6d633c480c4500c34553883798 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 12:23:50 +0000 Subject: [PATCH 26/34] Fix build issue. 
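The previous patch accidentally called xft::alloc on attnMask in ~CommonDecoder instead of
releasing it; switch the call to xft::dealloc. With this fix the destructor reads roughly as
follows (a sketch of the surrounding context from the hunk below):

    virtual ~CommonDecoder() {
        if (this->inputTokens) free(this->inputTokens);
        if (this->attnMask) xft::dealloc(this->attnMask);

        delete this->decoderBlock;
        delete this->predictor;
        ...
    }
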
--- src/models/common_decoder.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 3b26547d..d6f574ee 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -256,7 +256,7 @@ class CommonDecoder : public AbstractDecoder { virtual ~CommonDecoder() { if (this->inputTokens) free(this->inputTokens); - if (this->attnMask) xft::alloc(this->attnMask); + if (this->attnMask) xft::dealloc(this->attnMask); delete this->decoderBlock; delete this->predictor; From 003c46b53db86f2e166f210db6be728380c5e8a0 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 15:19:27 +0000 Subject: [PATCH 27/34] Fix LN bug --- src/layers/rms_norm.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index 8842a965..f86e3b5e 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -52,6 +52,7 @@ template void RmsNormImp::setWeight(const float *w, const float *, int cols) { T weightBuf[cols]; if constexpr (std::is_same_v) { + xft::memcopy(weightBuf, w, cols * sizeof(float)); } else if constexpr (std::is_same_v) { float16_t::cvt_float_to_float16(w, weightBuf, cols); } else if constexpr (std::is_same_v) { From 4cb98cfb24440e3f545768663aee4a29d09b1d0a Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 15:50:49 +0000 Subject: [PATCH 28/34] Fix final LN --- src/layers/rms_norm.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index f86e3b5e..4f8a6c25 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -69,11 +69,10 @@ void RmsNormImp::setWeight(const float *w, const float *, int cols) { template void RmsNormImp::setWeight(const std::string &modelPath, const std::string &, int cols) { + float weightBuf[cols]; + loadWeight(modelPath, weightBuf, cols); this->normSize = cols; - float *weiBuf = (float *)xft::alloc(cols * sizeof(float)); - loadWeight(modelPath, weiBuf, cols, DataType::fp32); - this->setWeight(weiBuf, nullptr, cols); - xft::dealloc(weiBuf); + this->setWeight(weightBuf, nullptr, cols); } #ifdef GPU From 6a85769a0f21d028fc38f7a0f9a0fcfd9d8106fa Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 15:57:33 +0000 Subject: [PATCH 29/34] Fix 2 --- src/utils/weight_util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/weight_util.h b/src/utils/weight_util.h index 0ac594c1..b8b7ae6d 100644 --- a/src/utils/weight_util.h +++ b/src/utils/weight_util.h @@ -153,7 +153,7 @@ int loadWeightWithConvert(T *ptr, int size, const std::string &filename, bool re } template -int loadWeight(std::string filename, T *&ptr, int size, DataType w_type = DataType::unknown, bool required = true) { +int loadWeight(std::string filename, T *ptr, int size, DataType w_type = DataType::unknown, bool required = true) { // By default, read the config.ini configuration file // in the same directory as the model file to determine the data type of the file. 
From f6e6e6477e0f2b0b06305bc0efe8c831f946e5a6 Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 16:25:00 +0000 Subject: [PATCH 30/34] Fix 3 --- src/layers/rms_norm.cpp | 3 ++- src/utils/weight_util.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index 4f8a6c25..05e21d69 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -70,7 +70,8 @@ void RmsNormImp::setWeight(const float *w, const float *, int cols) { template void RmsNormImp::setWeight(const std::string &modelPath, const std::string &, int cols) { float weightBuf[cols]; - loadWeight(modelPath, weightBuf, cols); + float *weiBuf = &weightBuf[0]; + loadWeight(modelPath, weiBuf, cols); this->normSize = cols; this->setWeight(weightBuf, nullptr, cols); } diff --git a/src/utils/weight_util.h b/src/utils/weight_util.h index b8b7ae6d..0ac594c1 100644 --- a/src/utils/weight_util.h +++ b/src/utils/weight_util.h @@ -153,7 +153,7 @@ int loadWeightWithConvert(T *ptr, int size, const std::string &filename, bool re } template -int loadWeight(std::string filename, T *ptr, int size, DataType w_type = DataType::unknown, bool required = true) { +int loadWeight(std::string filename, T *&ptr, int size, DataType w_type = DataType::unknown, bool required = true) { // By default, read the config.ini configuration file // in the same directory as the model file to determine the data type of the file. From 5f93020b4574c0c8334f84566abafe5da3117c0c Mon Sep 17 00:00:00 2001 From: changqi1 Date: Thu, 13 Jun 2024 23:23:47 +0000 Subject: [PATCH 31/34] Done --- src/common/transformer_ctx.h | 2 +- src/layers/attention.h | 14 ++++----- src/layers/mlp_llama.h | 2 +- src/layers/rms_norm.cpp | 6 ++-- src/layers/rotary_embedding.cpp | 24 ++++++++------- src/models/common_decoder.h | 41 ++++++++++++++----------- src/utils/decoder_util.h | 11 +++---- src/utils/simple_mem_pool.h | 6 ++-- src/utils/type_selector.h | 14 +++++---- src/utils/verbose.h | 53 +++++++++++++++++++++++++++++---- 10 files changed, 113 insertions(+), 60 deletions(-) diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index 9d79a1e5..34081369 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -321,6 +321,6 @@ struct DecoderContext { } ~DecoderContext() { - if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device); + // if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device); } }; \ No newline at end of file diff --git a/src/layers/attention.h b/src/layers/attention.h index b6ddd484..215a8539 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -329,6 +329,13 @@ class Attention { } t3.release(); +#ifdef XFT_DEBUG + dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride()); + dbg.dumpMatrix(query); + dbg.debugPrint("K[%d,%d](%d) after post op:\n", key.Rows(), key.Cols(), key.Stride()); + dbg.dumpMatrix(key); +#endif + #ifdef GPU int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT); ImT *qkvTmp = (ImT *)xft::alloc(qkvSize); @@ -338,13 +345,6 @@ class Attention { value.Assign(qkvTmp + qCols + kvCols, inputBuffer.Rows(), kvCols, qkvCols); #endif -#ifdef XFT_DEBUG - dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride()); - dbg.dumpMatrix(query); - dbg.debugPrint("K[%d,%d](%d) after post op:\n", key.Rows(), key.Cols(), key.Stride()); - dbg.dumpMatrix(key); -#endif - // Revise attnFactor before softmax (for some models, attnFactor may be not the default value) // 
We initially introduced the code for ChatGLM, but eventually found it has no difference and was unnecessary. // However, we have chosen to keep it in the codebase in case it becomes useful for future models. diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 6e835aaf..6d63fd1e 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -80,7 +80,7 @@ class LlamaMLP { int catWeiRows = quantizedCatWeights.Rows(); int catWeiCols = quantizedCatWeights.Cols(); catWeightsT.Resize(catWeiRows, catWeiCols); - ctx->mmHelper->transposeWeight(true, quantizedCatWeights, catWeightsT); + ctx->mmHelper->transposeWeight(trans, quantizedCatWeights, catWeightsT); WeiT *catWeiData = (WeiT *)xft::alloc(catWeiRows * catWeiCols * sizeof(WeiT), ctx->device); catWeights.Assign(catWeiData, catWeiRows, catWeiCols, catWeiCols); diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index 05e21d69..7aa5cd68 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -83,7 +83,7 @@ void RmsNormImp::forward(const float *input, float *output, int rows, int iSt sycl::queue *gpu_queue = static_cast(device); if constexpr (std::is_same_v) { fastertransformer::invokeGeneralT5LayerNorm( - output, input, weight, (const float *)nullptr, epsilon, rows, iStride, gpu_queue); + output, input, weight, (const float *)nullptr, epsilon, rows, normSize, gpu_queue); } else { printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); exit(-1); @@ -105,7 +105,7 @@ void RmsNormImp::forward( if constexpr (std::is_same_v) { // TODO: Add BF16 RmsNorm Implemention. // fastertransformer::invokeGeneralT5LayerNorm( - // output, input, weight, (const bfloat16_t *)nullptr, epsilon, rows, iStride, gpu_queue); + // output, input, weight, (const bfloat16_t *)nullptr, epsilon, rows, normSize, gpu_queue); } else { printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); exit(-1); @@ -126,7 +126,7 @@ void RmsNormImp::forward( sycl::queue *gpu_queue = static_cast(device); if constexpr (std::is_same_v) { fastertransformer::invokeGeneralT5LayerNorm((sycl::half *)output, (const sycl::half *)input, - (const sycl::half *)weight, (const sycl::half *)nullptr, epsilon, rows, iStride, gpu_queue); + (const sycl::half *)weight, (const sycl::half *)nullptr, epsilon, rows, normSize, gpu_queue); } else { printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__); exit(-1); diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index f8a473ff..f5e36659 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -43,22 +43,24 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { inv_freq[i] /= this->scaling_factor; } xft::llamaSetCosSinCache(inv_freq, emb_cos, emb_sin, inv_freq_size, max_position_embeddings); + } else if (dim != inv_freq_size * 2) { + printf("Incorrect dim=%d, inv_freq_size=%d\n", dim, inv_freq_size); + exit(-1); + } + #ifdef GPU - if (this->device != nullptr) { - float *emb_cos_bak = emb_cos; - float *emb_sin_bak = emb_sin; - emb_cos = ctx->getBuffer(emb_cos_str + "_gpu", max_position_embeddings * inv_freq_size, device); - emb_sin = ctx->getBuffer(emb_sin_str + "_gpu", max_position_embeddings * inv_freq_size, device); + if (this->device != nullptr) { + float *emb_cos_bak = emb_cos; + float *emb_sin_bak = emb_sin; + emb_cos = ctx->getBuffer(emb_cos_str + "_gpu", max_position_embeddings * inv_freq_size, device); + emb_sin = 
ctx->getBuffer(emb_sin_str + "_gpu", max_position_embeddings * inv_freq_size, device); + if (!ctx->cached(inv_freq_str + "_gpu")) { + inv_freq = ctx->getBuffer(inv_freq_str + "_gpu", inv_freq_size); xft::memcopy(emb_cos, emb_cos_bak, max_position_embeddings * inv_freq_size * sizeof(float), device); xft::memcopy(emb_sin, emb_sin_bak, max_position_embeddings * inv_freq_size * sizeof(float), device); - ctx->freeBuffer(emb_cos_str); - ctx->freeBuffer(emb_sin_str); } -#endif - } else if (dim != inv_freq_size * 2) { - printf("Incorrect dim=%d, inv_freq_size=%d\n", dim, inv_freq_size); - exit(-1); } +#endif } // This API is deprecated, will delete after all rotary embed code refactor. diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index d6f574ee..35fc1a1a 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -374,7 +374,8 @@ class CommonDecoder : public AbstractDecoder { #ifdef GPU size_t embBufSize = batchSize * inputSeqLen * hiddenSize * sizeof(AttnInT); AttnInT *embBufTmp = (AttnInT *)xft::alloc(embBufSize, ctx->device); - AttnInT *outBufTmp = (AttnInT *)xft::alloc(embBufSize, ctx->device); + AttnInT *outBufTmp = (AttnInT *)xft::alloc( + actBuffers->Rows() * actBuffers->Cols() * sizeof(float) - embBufSize, ctx->device); xft::memcopy(embBufTmp, embBuf, embBufSize, ctx->device); embBuf = embBufTmp; outBuf = outBufTmp; @@ -433,15 +434,6 @@ class CommonDecoder : public AbstractDecoder { } } -#ifdef GPU - embBufTmp = (AttnInT *)actBuffers->Data(); - xft::memcopy(embBufTmp, embBuf, embBufSize, ctx->device); - xft::dealloc(embBuf, ctx->device); - xft::dealloc(outBuf, ctx->device); - embBuf = embBufTmp; - outBuf = (MlpOutT *)(embBuf + batchSize * inputSeqLen * hiddenSize); -#endif - #ifdef PIPELINE_PARALLEL } @@ -465,8 +457,8 @@ class CommonDecoder : public AbstractDecoder { lnIn = outBuf; #pragma omp parallel for for (int b = 0; b < batchSize; ++b) { - memcpy(lnIn + b * hiddenSize, embBuf + ((b + 1) * inputSeqLen - 1) * hiddenSize, - hiddenSize * sizeof(MlpOutT)); + xft::memcopy(lnIn + b * hiddenSize, embBuf + ((b + 1) * inputSeqLen - 1) * hiddenSize, + hiddenSize * sizeof(MlpOutT), ctx->device ? 
ctx->device : nullptr); } } @@ -475,10 +467,11 @@ class CommonDecoder : public AbstractDecoder { dbg.dumpMatrix(embBuf, batchSize * inputSeqLen, hiddenSize, hiddenSize); dbg.debugPrint("LayerNorm In:\n"); - if (!logitsAll) + if (!logitsAll) { dbg.dumpMatrix(lnIn, batchSize, hiddenSize, hiddenSize); - else + } else { dbg.dumpMatrix(lnIn, batchSize * inputSeqLen, hiddenSize, hiddenSize); + } #endif // LN, as it supports inplace computing, input and output can be the same @@ -490,10 +483,11 @@ class CommonDecoder : public AbstractDecoder { #ifdef XFT_DEBUG dbg.debugPrint("LayerNorm Out:\n"); - if (!logitsAll) + if (!logitsAll) { dbg.dumpMatrix(lnOut, batchSize, hiddenSize, hiddenSize); - else + } else { dbg.dumpMatrix(lnOut, batchSize * inputSeqLen, hiddenSize, hiddenSize); + } #endif // Predictor @@ -506,10 +500,21 @@ class CommonDecoder : public AbstractDecoder { #ifdef XFT_DEBUG auto splitSize = this->predictor->getSplitSize(); dbg.debugPrint("finalOut:\n"); - if (!logitsAll) + if (!logitsAll) { dbg.dumpMatrix(finalOut, batchSize, splitSize, splitSize); - else + } else { dbg.dumpMatrix(finalOut, batchSize * inputSeqLen, splitSize, splitSize); + } +#endif + +#ifdef GPU + xft::dealloc(embBuf, ctx->device); + embBuf = (AttnInT *)actBuffers->Data(); + + float *finalOutTmp = (float *)(embBuf + batchSize * inputSeqLen * hiddenSize); + xft::memcopy(finalOutTmp, finalOut, batchSize * splitSize * sizeof(float), ctx->device); + xft::dealloc(outBuf, ctx->device); + finalOut = finalOutTmp; #endif // Expand the result to make it cover multiple beams diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 0d465118..8038c288 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -460,8 +460,8 @@ class DecoderUtil { sycl::queue *gpu_queue = static_cast(device); if constexpr (std::is_same_v && std::is_same_v) { - const float16_t *src0 = src.Data(); - const float16_t *src1 = src.Data() + N; + sycl::half *src0 = (sycl::half *)src.Data(); + sycl::half *src1 = (sycl::half *)(src.Data() + N); sycl::half *dest = (sycl::half *)dst.Data(); gpu_queue @@ -469,9 +469,10 @@ class DecoderUtil { h.parallel_for(M * N, [=](auto i) { int32_t row = i / N; int32_t col = i % N; - dest[row * ldd + col] = ((sycl::half)src0[row * lds + col] - / ((sycl::half)1.0f + (sycl::half)sycl::native::exp(-src0[row * lds + col])) - * (sycl::half)src1[row * lds + col]); + sycl::half tmp0 = src0[row * lds + col]; + sycl::half tmp1 = src1[row * lds + col]; + dest[row * ldd + col] = tmp0 * tmp1 + / ((sycl::half)1.0f + (sycl::half)sycl::native::exp(tmp0 * -1.0f)); }); }) .wait(); diff --git a/src/utils/simple_mem_pool.h b/src/utils/simple_mem_pool.h index f4a65db2..590c6f1c 100644 --- a/src/utils/simple_mem_pool.h +++ b/src/utils/simple_mem_pool.h @@ -92,9 +92,9 @@ class SimpleMemPool { // Destructor to free all allocated memory on program termination ~SimpleMemPool() { - for (auto &entry : memoryMap) { - xft::dealloc(std::get<0>(entry.second), std::get<2>(entry.second)); - } + // for (auto &entry : memoryMap) { + // xft::dealloc(std::get<0>(entry.second), std::get<2>(entry.second)); + // } memoryMap.clear(); } }; \ No newline at end of file diff --git a/src/utils/type_selector.h b/src/utils/type_selector.h index 7708ccad..60e47154 100644 --- a/src/utils/type_selector.h +++ b/src/utils/type_selector.h @@ -31,9 +31,11 @@ struct TypeSelector { using OutType = bfloat16_t; }; -// template <> -// struct TypeSelector { -// using InType = float16_t; -// using ImType = float16_t; -// using OutType = float16_t; -// }; \ No 
newline at end of file
+#ifdef GPU
+template <>
+struct TypeSelector<float16_t> {
+    using InType = float16_t;
+    using ImType = float16_t;
+    using OutType = float16_t;
+};
+#endif
\ No newline at end of file
diff --git a/src/utils/verbose.h b/src/utils/verbose.h
index 0716b045..cd838df8 100644
--- a/src/utils/verbose.h
+++ b/src/utils/verbose.h
@@ -56,21 +56,64 @@ class Printer {
         sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
         gpu_queue
                 ->submit([&](sycl::handler &cgh) {
-                    auto out = sycl::stream(1024, 768, cgh);
+                    auto out = sycl::stream(10240, 7680, cgh);
                     cgh.parallel_for(sycl::nd_range<1>(1, 1), [=](sycl::nd_item<1> item) {
                         int idx_col = item.get_global_id(0);
                         if (idx_col == 0) {
-                            for (int row = 0; row < rows; ++row) {
-                                for (int col = 0; col < cols; ++col) {
-                                    out << (float)buf[row * stride + col] << ", ";
+                            if (printAll == false) {
+                                for (int row = 0; row < 6; ++row) {
+                                    for (int col = 0; col < 6; ++col) {
+                                        out << (float)buf[row * stride + col] << ", ";
+                                    }
+                                    out << " ... ";
+                                    for (int col = cols - 6; col < cols; ++col) {
+                                        out << (float)buf[row * stride + col] << ", ";
+                                    }
+                                    out << sycl::endl;
+                                }
+                                out << "..." << sycl::endl;
+                                for (int row = rows - 6; row < rows; ++row) {
+                                    for (int col = 0; col < 6; ++col) {
+                                        out << (float)buf[row * stride + col] << ", ";
+                                    }
+                                    out << " ... ";
+                                    for (int col = cols - 6; col < cols; ++col) {
+                                        out << (float)buf[row * stride + col] << ", ";
+                                    }
+                                    out << sycl::endl;
+                                }
                                 out << sycl::endl;
+                            } else {
+                                for (int row = 0; row < rows; ++row) {
+                                    for (int col = 0; col < 6; ++col) {
+                                        out << (float)buf[row * stride + col] << ", ";
+                                    }
+                                    out << " ... ";
+                                    for (int col = cols - 6; col < cols; ++col) {
+                                        out << (float)buf[row * stride + col] << ", ";
+                                    }
+                                    out << sycl::endl;
+                                }
                             }
-                            out << sycl::endl;
                         }
                     });
                 })
                 .wait();
+        } else {
+            for (int row = 0; row < 6; ++row) {
+                for (int col = 0; col < 6; ++col) {
+                    std::cout << (float)buf[row * stride + col] << ", ";
+                }
+                std::cout << std::endl;
+            }
+            std::cout << "..."
<< std::endl; + for (int row = rows - 6; row < rows; ++row) { + for (int col = cols - 6; col < cols; ++col) { + std::cout << (float)buf[row * stride + col] << ", "; + } + std::cout << std::endl; + } + std::cout << std::endl; } #endif } From 175c4dc492f56cfdd38de0e1fecab1304b74c15b Mon Sep 17 00:00:00 2001 From: changqi1 Date: Fri, 14 Jun 2024 09:08:04 +0000 Subject: [PATCH 32/34] Finish --- src/common/transformer_ctx.h | 4 +++- src/utils/simple_mem_pool.h | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index 34081369..356a2532 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -321,6 +321,8 @@ struct DecoderContext { } ~DecoderContext() { - // if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device); +#ifndef GPU + if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device); +#endif } }; \ No newline at end of file diff --git a/src/utils/simple_mem_pool.h b/src/utils/simple_mem_pool.h index 590c6f1c..7338dbfe 100644 --- a/src/utils/simple_mem_pool.h +++ b/src/utils/simple_mem_pool.h @@ -92,9 +92,11 @@ class SimpleMemPool { // Destructor to free all allocated memory on program termination ~SimpleMemPool() { - // for (auto &entry : memoryMap) { - // xft::dealloc(std::get<0>(entry.second), std::get<2>(entry.second)); - // } +#ifndef GPU + for (auto &entry : memoryMap) { + xft::dealloc(std::get<0>(entry.second), std::get<2>(entry.second)); + } +#endif memoryMap.clear(); } }; \ No newline at end of file From 8d35cfc31640acc2696752894e14c32d39ec6c6b Mon Sep 17 00:00:00 2001 From: changqi1 Date: Fri, 14 Jun 2024 09:23:22 +0000 Subject: [PATCH 33/34] change macro GPU to XFT_GPU --- CMakeLists.txt | 2 +- src/common/allocator.h | 10 +++++----- src/common/transformer_ctx.h | 2 +- src/kernels/rotary_embedding_kernels.cpp | 2 +- src/kernels/rotary_embedding_kernels.h | 2 +- src/layers/attention.h | 10 +++++----- src/layers/dist_linear.h | 2 +- src/layers/layer_norm.cpp | 2 +- src/layers/mlp_llama.h | 4 ++-- src/layers/rms_norm.cpp | 4 ++-- src/layers/rms_norm.h | 2 +- src/layers/rotary_embedding.cpp | 4 ++-- src/models/common_decoder.h | 9 ++++----- src/utils/compile_util.h | 2 +- src/utils/debugger.h | 2 +- src/utils/decoder_util.h | 4 ++-- src/utils/simple_mem_pool.h | 2 +- src/utils/type_selector.h | 2 +- src/utils/verbose.h | 2 +- 19 files changed, 34 insertions(+), 35 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 42832658..1e868b83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,7 +20,7 @@ project(xfastertransformer LANGUAGES C CXX) option(WITH_GPU "Build with GPU" OFF) if(WITH_GPU) message(STATUS "Notice: Building with GPU.") - add_definitions(-DGPU=true) + add_definitions(-DXFT_GPU=true) # Get compiler version execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE ICPX_VERSION diff --git a/src/common/allocator.h b/src/common/allocator.h index 87ad1190..6fc80f80 100644 --- a/src/common/allocator.h +++ b/src/common/allocator.h @@ -19,7 +19,7 @@ #include "environment.h" #include -#ifdef GPU +#ifdef XFT_GPU #include #endif @@ -36,7 +36,7 @@ static inline void *alloc(size_t nbytes, void *device = nullptr, size_t alignmen void *data = nullptr; -#ifdef GPU +#ifdef XFT_GPU if (device != nullptr) { sycl::queue *gpu_queue = static_cast(device); data = sycl::malloc_device(nbytes, *gpu_queue); @@ -66,7 +66,7 @@ static inline void *alloc(size_t nbytes, void *device = nullptr, size_t alignmen } static inline void dealloc(void *data, void 
*device = nullptr) {
-#ifdef GPU
+#ifdef XFT_GPU
     if (device != nullptr) {
         sycl::free(data, *static_cast<sycl::queue *>(device));
         return;
@@ -77,7 +77,7 @@ static inline void dealloc(void *data, void *device = nullptr) {
 }
 
 static inline void memcopy(void *dst, const void *src, size_t size, void *device = nullptr) {
-#ifdef GPU
+#ifdef XFT_GPU
     if (device != nullptr) {
         sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
         gpu_queue->memcpy(dst, src, size).wait();
@@ -89,7 +89,7 @@ static inline void memcopy(void *dst, const void *src, size_t size, void *device
 }
 
 static inline void memsetv(void *dst, int ch, size_t size, void *device = nullptr) {
-#ifdef GPU
+#ifdef XFT_GPU
     if (device != nullptr) {
         sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
         gpu_queue->memset(dst, ch, size).wait();
diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h
index 356a2532..27b777bc 100644
--- a/src/common/transformer_ctx.h
+++ b/src/common/transformer_ctx.h
@@ -321,7 +321,7 @@ struct DecoderContext {
     }
 
     ~DecoderContext() {
-#ifndef GPU
+#ifndef XFT_GPU
         if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device);
 #endif
     }
diff --git a/src/kernels/rotary_embedding_kernels.cpp b/src/kernels/rotary_embedding_kernels.cpp
index aa67b1d7..dd3e8298 100644
--- a/src/kernels/rotary_embedding_kernels.cpp
+++ b/src/kernels/rotary_embedding_kernels.cpp
@@ -386,7 +386,7 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i
             maxSupportedSeqLength, qkShape, positionIds);
 }
 
-#ifdef GPU
+#ifdef XFT_GPU
 // For LLaMA
 template <typename T>
 static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, int qStride, int kStride,
diff --git a/src/kernels/rotary_embedding_kernels.h b/src/kernels/rotary_embedding_kernels.h
index 3348e968..b1e1f388 100644
--- a/src/kernels/rotary_embedding_kernels.h
+++ b/src/kernels/rotary_embedding_kernels.h
@@ -65,7 +65,7 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i
         float *cur_emb_sin, int inv_freq_size, const float *logn, int maxSupportedSeqLength, const int *qkShape,
         const int *positionIds);
 
-#ifdef GPU
+#ifdef XFT_GPU
 // For LLaMA
 void llamaApplyRotaryPosEmbeding(void *device, float *query, float *key, int qStride, int kStride, float *emb_cos,
         float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds);
diff --git a/src/layers/attention.h b/src/layers/attention.h
index 215a8539..135d9bd4 100644
--- a/src/layers/attention.h
+++ b/src/layers/attention.h
@@ -141,7 +141,7 @@ class Attention {
         ctx->mmHelper->convertWeight(trans, hiddenSize, responsibleCols, concatBuf, concatScale, concatZero,
                 convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum);
 
-#ifdef GPU
+#ifdef XFT_GPU
         xft::Matrix qkvWeightT;
         qkvWeightT.Resize(hiddenSize, responsibleCols);
         ctx->mmHelper->transposeWeight(trans, convertedqkvWeight, qkvWeightT);
@@ -184,7 +184,7 @@ class Attention {
                 attnOutZero, this->startQHead * headSize, qResponsibleCols, false, convertedOutWeight,
                 attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, true);
 
-#ifdef GPU
+#ifdef XFT_GPU
         xft::Matrix outWeightT;
         outWeightT.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize);
         ctx->mmHelper->transposeWeight(trans, convertedOutWeight, outWeightT);
@@ -336,7 +336,7 @@ class Attention {
         dbg.dumpMatrix(key);
 #endif
 
-#ifdef GPU
+#ifdef XFT_GPU
         int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT);
         ImT *qkvTmp = (ImT *)xft::alloc(qkvSize);
         xft::memcopy(qkvTmp, qkvGroupMatMul.Data(), qkvSize, ctx->device); // error: need CPU ptr and GPU ptr
@@ -361,7 +361,7 @@ class 
Attention { // For multiple nodes inference, not the whole result buffer xft::Matrix attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); -#ifdef GPU +#ifdef XFT_GPU int64_t attnSplitSize = imBuffer.Rows() * qCols * sizeof(ImT); ImT *attnSplitTmp = (ImT *)xft::alloc(attnSplitSize); attnSplit.Assign(attnSplitTmp, imBuffer.Rows(), qCols, qCols); @@ -380,7 +380,7 @@ class Attention { } t4.release(); -#ifdef GPU +#ifdef XFT_GPU xft::memcopy(imBuffer.Data(), attnSplit.Data(), attnSplitSize, ctx->device); attnSplit.Assign(imBuffer.Data(), imBuffer.Rows(), qCols, qCols); xft::dealloc(qkvTmp); diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h index 1eefb3af..b118b5fb 100644 --- a/src/layers/dist_linear.h +++ b/src/layers/dist_linear.h @@ -66,7 +66,7 @@ class DistLinear { xft::Matrix quantizedWeight; ctx->mmHelper->convertWeight( true, K, N, w + splitOffset * K, nullptr, nullptr, quantizedWeight, scaleWeight, zeroWeight, sumWeight); -#ifdef GPU +#ifdef XFT_GPU xft::Matrix tWeight; tWeight.Resize(K, N); ctx->mmHelper->transposeWeight(true, quantizedWeight, tWeight); diff --git a/src/layers/layer_norm.cpp b/src/layers/layer_norm.cpp index 44ad949f..012f8460 100644 --- a/src/layers/layer_norm.cpp +++ b/src/layers/layer_norm.cpp @@ -59,7 +59,7 @@ void LayerNorm::setWeight(const std::string &gammaPath, const std::string &betaP // input and output are in shape of (rows, normSize) // TODO: column-wise parallel -#ifdef GPU +#ifdef XFT_GPU void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("LayerNorm.forward"); const float *pgamma = gamma; diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index 6d63fd1e..334644bb 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -75,7 +75,7 @@ class LlamaMLP { quantizedGateWeight.Release(); quantizedUpWeight.Release(); -#ifdef GPU +#ifdef XFT_GPU xft::Matrix catWeightsT; int catWeiRows = quantizedCatWeights.Rows(); int catWeiCols = quantizedCatWeights.Cols(); @@ -93,7 +93,7 @@ class LlamaMLP { // Horizontally split the down weight ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false, quantizedDownWeight, downWeightScale, downWeightZero, downWeightSum); -#ifdef GPU +#ifdef XFT_GPU xft::Matrix downWeightT; int downWeiRows = it.second - it.first; int downWeiCols = hiddenSize; diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index 7aa5cd68..0fb0fe01 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -23,7 +23,7 @@ #include "timeline.h" #include "transformer_ctx.h" -#ifdef GPU +#ifdef XFT_GPU #include "gpudnn/gpu_layernorm_kernels.h" #include #endif @@ -76,7 +76,7 @@ void RmsNormImp::setWeight(const std::string &modelPath, const std::string &, this->setWeight(weightBuf, nullptr, cols); } -#ifdef GPU +#ifdef XFT_GPU template void RmsNormImp::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); diff --git a/src/layers/rms_norm.h b/src/layers/rms_norm.h index 05cafd2b..b78a3ae0 100644 --- a/src/layers/rms_norm.h +++ b/src/layers/rms_norm.h @@ -56,7 +56,7 @@ class RmsNormImp { void *device = nullptr; }; -#ifdef GPU +#ifdef XFT_GPU using RmsNorm = RmsNormImp; #else using RmsNorm = RmsNormImp; diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp index f5e36659..5efc6020 100644 --- a/src/layers/rotary_embedding.cpp +++ b/src/layers/rotary_embedding.cpp @@ -48,7 +48,7 @@ 
LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) { exit(-1); } -#ifdef GPU +#ifdef XFT_GPU if (this->device != nullptr) { float *emb_cos_bak = emb_cos; float *emb_sin_bak = emb_sin; @@ -105,7 +105,7 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(const int dim, const int max_position // |_____| |_____| // head_size/2 head_size/2 -#ifdef GPU +#ifdef XFT_GPU void LlamaRotaryEmbedding::forward( float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 35fc1a1a..d463a899 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -371,7 +371,7 @@ class CommonDecoder : public AbstractDecoder { TimeLine t("Decoder.Seq" + std::to_string(sequenceID) + ".Step"); #endif -#ifdef GPU +#ifdef XFT_GPU size_t embBufSize = batchSize * inputSeqLen * hiddenSize * sizeof(AttnInT); AttnInT *embBufTmp = (AttnInT *)xft::alloc(embBufSize, ctx->device); AttnInT *outBufTmp = (AttnInT *)xft::alloc( @@ -491,6 +491,7 @@ class CommonDecoder : public AbstractDecoder { #endif // Predictor + const int splitSize = this->predictor->getSplitSize(); float *finalOut = (float *)outBuf; if (!logitsAll) this->predictor->forward(ctx, lnOut, finalOut, batchSize); @@ -498,7 +499,6 @@ class CommonDecoder : public AbstractDecoder { this->predictor->forward(ctx, lnOut, finalOut, batchSize * seqLen); #ifdef XFT_DEBUG - auto splitSize = this->predictor->getSplitSize(); dbg.debugPrint("finalOut:\n"); if (!logitsAll) { dbg.dumpMatrix(finalOut, batchSize, splitSize, splitSize); @@ -507,7 +507,7 @@ class CommonDecoder : public AbstractDecoder { } #endif -#ifdef GPU +#ifdef XFT_GPU xft::dealloc(embBuf, ctx->device); embBuf = (AttnInT *)actBuffers->Data(); @@ -519,7 +519,6 @@ class CommonDecoder : public AbstractDecoder { // Expand the result to make it cover multiple beams if (step == 0 && beamSize > 1) { - const int splitSize = this->predictor->getSplitSize(); for (int b = userSideBS - 1; b >= 0; --b) { float *src = finalOut + b * splitSize; #pragma omp parallel for @@ -770,7 +769,7 @@ class CommonDecoder : public AbstractDecoder { engineIdx = env.getEngineIndex(); this->mmHelper.reset(new MMHelper(env.getEngineKind(), engineIdx)); -#ifdef GPU +#ifdef XFT_GPU auto devices = sycl::device::get_devices(sycl::info::device_type::gpu); this->device.reset(new sycl::queue(devices[this->mmHelper->getEngineCount() + engineIdx])); #endif diff --git a/src/utils/compile_util.h b/src/utils/compile_util.h index 3874d763..e77e8f6b 100644 --- a/src/utils/compile_util.h +++ b/src/utils/compile_util.h @@ -17,7 +17,7 @@ #include #include -#ifdef GPU +#ifdef XFT_GPU #include #endif diff --git a/src/utils/debugger.h b/src/utils/debugger.h index 5dea56de..ce130393 100644 --- a/src/utils/debugger.h +++ b/src/utils/debugger.h @@ -114,7 +114,7 @@ class Debugger { } } -#ifdef GPU +#ifdef XFT_GPU template void dumpMatrix(xft::Matrix &m, bool print_all = false) { } diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 8038c288..4c8dba1c 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -28,7 +28,7 @@ #include "transformer_ctx.h" #include "xdnn.h" -#ifdef GPU +#ifdef XFT_GPU #include #endif @@ -448,7 +448,7 @@ class DecoderUtil { return std::make_pair(maxVal, sum); } -#ifdef GPU +#ifdef XFT_GPU template static void siluSum(xft::Matrix &src, xft::Matrix &dst, void *device = nullptr) { int M = src.Rows(); diff --git a/src/utils/simple_mem_pool.h b/src/utils/simple_mem_pool.h index 
7338dbfe..9ea074f9 100644
--- a/src/utils/simple_mem_pool.h
+++ b/src/utils/simple_mem_pool.h
@@ -92,7 +92,7 @@ class SimpleMemPool {
     // Destructor to free all allocated memory on program termination
     ~SimpleMemPool() {
-#ifndef GPU
+#ifndef XFT_GPU
         for (auto &entry : memoryMap) {
             xft::dealloc(std::get<0>(entry.second), std::get<2>(entry.second));
         }
diff --git a/src/utils/type_selector.h b/src/utils/type_selector.h
index 60e47154..252e83ed 100644
--- a/src/utils/type_selector.h
+++ b/src/utils/type_selector.h
@@ -31,7 +31,7 @@ struct TypeSelector {
     using OutType = bfloat16_t;
 };
 
-#ifdef GPU
+#ifdef XFT_GPU
 template <>
 struct TypeSelector<float16_t> {
     using InType = float16_t;
diff --git a/src/utils/verbose.h b/src/utils/verbose.h
index cd838df8..542e83c5 100644
--- a/src/utils/verbose.h
+++ b/src/utils/verbose.h
@@ -51,7 +51,7 @@ class Printer {
     static void print(std::string buf_name, T *buf, int rows, int cols, int stride, bool printAll = false,
             void *device = nullptr) {
         std::cout << buf_name.c_str() << ":" << std::endl;
-#ifdef GPU
+#ifdef XFT_GPU
         if (device != nullptr) {
             sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
             gpu_queue

From ea1679d0f1227ca9a061c5bbd57666592c6b1433 Mon Sep 17 00:00:00 2001
From: changqi1
Date: Fri, 14 Jun 2024 11:28:37 +0000
Subject: [PATCH 34/34] Add requirements-gpu.txt

---
 requirements-gpu.txt | 8 ++++++++
 1 file changed, 8 insertions(+)
 create mode 100644 requirements-gpu.txt

diff --git a/requirements-gpu.txt b/requirements-gpu.txt
new file mode 100644
index 00000000..f699e100
--- /dev/null
+++ b/requirements-gpu.txt
@@ -0,0 +1,8 @@
+-f https://download.pytorch.org/whl/torch_stable.html
+cmake==3.26.1
+sentencepiece==0.1.99
+torch==2.3.0+cpu.cxx11.abi
+transformers==4.40.0
+accelerate==0.23.0
+protobuf
+tiktoken
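
For reference, a minimal sketch of how the device-aware helpers in src/common/allocator.h are meant to be used after this series. The helper signatures (xft::alloc, xft::memcopy, xft::dealloc taking an opaque void *device) follow the patched header; the include path, queue selection, and buffer size below are illustrative assumptions, not part of the patches.

// Illustrative sketch only: exercises xft::alloc / xft::memcopy / xft::dealloc
// as declared in src/common/allocator.h after this series. Include path,
// device selection, and sizes are assumptions, not part of the patch.
#include "common/allocator.h"   // adjust to the actual include path in your build
#include <vector>

int main() {
    void *device = nullptr;               // CPU path: plain aligned host allocation
#ifdef XFT_GPU
    sycl::queue q{sycl::gpu_selector_v};  // any SYCL GPU queue, passed as an opaque void *
    device = &q;
#endif

    const size_t n = 1024;
    std::vector<float> host(n, 1.0f);

    // Allocates on the device when a queue is given, otherwise on the host.
    float *buf = static_cast<float *>(xft::alloc(n * sizeof(float), device));

    // memcopy dispatches to queue->memcpy(...).wait() on the GPU path,
    // presumably a plain host copy otherwise.
    xft::memcopy(buf, host.data(), n * sizeof(float), device);

    // Release through the matching path (sycl::free vs. free).
    xft::dealloc(buf, device);
    return 0;
}

Passing the queue as a void * is what lets the same call sites compile in both the CPU-only and the XFT_GPU builds, which appears to be why the helpers keep SYCL types out of their signatures.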