From cccf85f68a25d5d322a0cb1f8fd154eb476302b7 Mon Sep 17 00:00:00 2001
From: changqi1
Date: Tue, 20 Feb 2024 20:52:41 +0800
Subject: [PATCH] [kernel] Add GPU compiler.

---
 CMakeLists.txt              | 26 +++++++++++++++++++++-----
 cmake/onednn.cmake          |  9 +++++++--
 examples/cpp/CMakeLists.txt |  3 +++
 include/dtype.h             |  4 ++--
 src/common/my_types.h       |  2 +-
 src/models/chatglm2.cpp     | 36 ++++++++++++++++++------------------
 src/models/chatglm2.h       |  8 ++++----
 src/models/chatglm3.h       |  6 +++---
 src/models/common_decoder.h |  2 +-
 src/models/models.cpp       | 24 ++++++++++++------------
 src/pytorch/CMakeLists.txt  |  5 +----
 src/utils/matmul_helper.h   |  4 ++--
 12 files changed, 75 insertions(+), 54 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 21783d16..09b2d7e9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -16,11 +16,24 @@ cmake_minimum_required(VERSION 3.15.1)
 
 project(xfastertransformer LANGUAGES C CXX)
 
-# Get gcc version
-execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpfullversion
-                OUTPUT_VARIABLE GCC_VERSION
-                OUTPUT_STRIP_TRAILING_WHITESPACE)
-message(STATUS "Notice: GCC version: ${GCC_VERSION}")
+# Enable GPU
+option(WITH_GPU "Build with GPU" OFF)
+if(WITH_GPU)
+    message(STATUS "Notice: Building with GPU.")
+    add_definitions(-DGPU=true)
+    # Get compiler version
+    execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version
+                    OUTPUT_VARIABLE ICPX_VERSION
+                    OUTPUT_STRIP_TRAILING_WHITESPACE)
+    message(STATUS "Notice: ICPX version: ${ICPX_VERSION}")
+else()
+    message(STATUS "Notice: Building with CPU.")
+    # Get compiler version
+    execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpfullversion
+                    OUTPUT_VARIABLE GCC_VERSION
+                    OUTPUT_STRIP_TRAILING_WHITESPACE)
+    message(STATUS "Notice: GCC version: ${GCC_VERSION}")
+endif()
 
 if(NOT CMAKE_BUILD_TYPE)
     set(CMAKE_BUILD_TYPE Release)
@@ -29,6 +42,9 @@ endif()
 
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp -mavx512f -mavx512bw -mavx512vl -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
+if(WITH_GPU)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -fsycl-device-code-split=per_kernel -lOpenCL")
+endif()
 
 # GCC>=10.1 should support avx512bf16, but need to double check as some versions have issues
 if(GCC_VERSION VERSION_GREATER_EQUAL "10.1")
diff --git a/cmake/onednn.cmake b/cmake/onednn.cmake
index 688b217a..399fd1ca 100644
--- a/cmake/onednn.cmake
+++ b/cmake/onednn.cmake
@@ -24,13 +24,18 @@ project(dependency NONE)
 
 include(ExternalProject)
 
+set(ONEDNN_BUILD_OPTIONS -DONEDNN_LIBRARY_TYPE=STATIC -DONEDNN_BUILD_TESTS=OFF -DONEDNN_BUILD_EXAMPLES=OFF)
+if(WITH_GPU)
+    set(ONEDNN_BUILD_OPTIONS "${ONEDNN_BUILD_OPTIONS} -DONEDNN_GPU_RUNTIME=SYCL")
+endif()
+
 # cmake-format: off
 ExternalProject_Add(onednn
     GIT_REPOSITORY    https://github.com/oneapi-src/oneDNN.git
-    GIT_TAG           v3.2
+    GIT_TAG           v3.3.3
     SOURCE_DIR        ${CMAKE_SOURCE_DIR}/3rdparty/onednn
     BINARY_DIR        ${CMAKE_SOURCE_DIR}/3rdparty/onednn
-    CONFIGURE_COMMAND ${CMAKE_COMMAND} -E make_directory "build" && ${CMAKE_COMMAND} -E chdir "build" ${CMAKE_COMMAND} -DONEDNN_LIBRARY_TYPE=STATIC -DONEDNN_BUILD_TESTS=OFF -DONEDNN_BUILD_EXAMPLES=OFF ..
+    CONFIGURE_COMMAND ${CMAKE_COMMAND} -E make_directory "build" && ${CMAKE_COMMAND} -E chdir "build" ${CMAKE_COMMAND} ${ONEDNN_BUILD_OPTIONS} ..
     BUILD_COMMAND     ${CMAKE_COMMAND} -E chdir "build" make -j all
     INSTALL_COMMAND   ""
     TEST_COMMAND      ""
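Building oneDNN with -DONEDNN_GPU_RUNTIME=SYCL is what makes dnnl::engine::kind::gpu usable at runtime. As a rough sketch of what the SYCL-enabled library provides (illustrative only, not part of this patch; assumes a machine where a SYCL GPU device is visible):

    #include <iostream>

    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
        // Returns 0 unless oneDNN was built with a GPU runtime (SYCL here).
        size_t ngpus = dnnl::engine::get_count(dnnl::engine::kind::gpu);
        if (ngpus == 0) {
            std::cout << "No oneDNN GPU engines found." << std::endl;
            return 1;
        }
        // Index 0 selects the first GPU; MMHelper below follows the same pattern.
        dnnl::engine eng(dnnl::engine::kind::gpu, 0);
        dnnl::stream strm(eng);
        std::cout << "oneDNN GPU engine and stream created." << std::endl;
        return 0;
    }
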
diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt
index 4d875737..86f778fc 100644
--- a/examples/cpp/CMakeLists.txt
+++ b/examples/cpp/CMakeLists.txt
@@ -32,5 +32,8 @@ else()
     target_link_libraries(example PRIVATE xfastertransformer_static)
 endif()
 target_link_libraries(example PRIVATE sentencepiece -lstdc++fs)
+if(WITH_GPU)
+    target_link_libraries(example PRIVATE -fsycl -fsycl-device-code-split=per_kernel -lOpenCL)
+endif()
 
 add_dependencies(example cmdline sentencepiece_lib)
diff --git a/include/dtype.h b/include/dtype.h
index 542306f9..67a0b221 100644
--- a/include/dtype.h
+++ b/include/dtype.h
@@ -34,7 +34,7 @@
 };
 
 enum DeviceKind {
-    CPU = 0,
-    GPU,
+    iCPU = 0,
+    iGPU,
 };
 
 } // namespace xft
diff --git a/src/common/my_types.h b/src/common/my_types.h
index b356800b..74b37c1a 100644
--- a/src/common/my_types.h
+++ b/src/common/my_types.h
@@ -22,7 +22,7 @@
 typedef int8_t s8;
 typedef uint8_t u8;
 
-typedef struct {
+typedef struct w8a8 {
     int8_t s8;
     operator int8_t() { return s8; }
 } w8a8_t;
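Giving the previously anonymous struct the tag w8a8 changes nothing about its behavior; it only gives the type a declarable name, which the stricter icpx/SYCL compilation path presumably requires (the rationale is an assumption, not stated by the patch). A minimal standalone sketch of how w8a8_t behaves, mirroring the patched definition:

    #include <cstdint>

    // Same shape as the patched definition in src/common/my_types.h.
    typedef struct w8a8 {
        int8_t s8;
        operator int8_t() { return s8; }
    } w8a8_t;

    int main() {
        w8a8_t v{42};
        int8_t raw = v; // implicit conversion through operator int8_t()
        return raw == 42 ? 0 : 1;
    }
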
diff --git a/src/models/chatglm2.cpp b/src/models/chatglm2.cpp
index 0633ced9..8ed867be 100644
--- a/src/models/chatglm2.cpp
+++ b/src/models/chatglm2.cpp
@@ -18,10 +18,10 @@
 #include "INIReader.h"
 #include "chatglm2.h"
 
-template <typename WeiT, typename NormT>
-ChatGLM2<WeiT, NormT>::ChatGLM2(const std::string &modelPath, const std::string &modelType)
-    : CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, NormT>,
-              ChatGLM2MLP<WeiT, NormT>>(modelPath, modelType) {
+template <typename WeiT>
+ChatGLM2<WeiT>::ChatGLM2(const std::string &modelPath, const std::string &modelType)
+    : CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, RmsNorm>,
+              ChatGLM2MLP<WeiT, RmsNorm>>(modelPath, modelType) {
     this->positionIds = nullptr;
     this->posBufSize = 0;
@@ -36,15 +36,15 @@ ChatGLM2<WeiT, NormT>::ChatGLM2(const std::string &modelPath, const std::string
     setFinalLnWeight(modelPath);
 }
 
-template <typename WeiT, typename NormT>
-ChatGLM2<WeiT, NormT>::~ChatGLM2() {
+template <typename WeiT>
+ChatGLM2<WeiT>::~ChatGLM2() {
     delete embedding;
     if (positionIds) { free(positionIds); }
 }
 
-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::setEmbeddingWeights(const std::string &modelPath) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::setEmbeddingWeights(const std::string &modelPath) {
     int vocabSize = embedding->getVocabSize();
     int hiddenSize = embedding->getHiddenSize();
@@ -57,8 +57,8 @@ void ChatGLM2<WeiT, NormT>::setEmbeddingWeights(const std::string &modelPath) {
     free(tokenEmb);
 }
 
-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::setFinalLnWeight(const std::string &modelPath) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::setFinalLnWeight(const std::string &modelPath) {
     int hiddenSize = embedding->getHiddenSize();
 
     float *gamma = (float *)malloc(hiddenSize * sizeof(float));
@@ -85,8 +85,8 @@
 //         attention_mask = (attention_mask < 0.5).bool()
 //
 //     return attention_mask
-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::prepareAttnMask(int *ids, int step) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::prepareAttnMask(int *ids, int step) {
     DecoderContext *ctx = this->getContext();
     int seqLen = ctx->inputSeqLen;
     int sizeRequired = ctx->batchSize * seqLen * seqLen;
@@ -127,13 +127,13 @@
     }
 }
 
-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
     embedding->forward(ids, output, batchSize, seqLen);
 }
 
-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::lastLayerNormForward(float *input, float *output, int rows) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::lastLayerNormForward(float *input, float *output, int rows) {
     finalLN.forward(input, output, rows);
 }
 
@@ -147,8 +147,8 @@ void ChatGLM2<WeiT, NormT>::lastLayerNormForward(float *input, float *output, in
 //     batch_size, seq_length = input_ids.shape
 //     position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
 //     return position_ids
-template <typename WeiT, typename NormT>
-int *ChatGLM2<WeiT, NormT>::getPositionIds(int *ids, int batchSize, int seqLen, int step) {
+template <typename WeiT>
+int *ChatGLM2<WeiT>::getPositionIds(int *ids, int batchSize, int seqLen, int step) {
     // Prepare buffer
     int sizeNeeded = (batchSize * seqLen + 63) / 64 * 64; // position_ids + block_position_ids
     if (posBufSize < sizeNeeded) {
diff --git a/src/models/chatglm2.h b/src/models/chatglm2.h
index 20ec4daf..e0f63104 100644
--- a/src/models/chatglm2.h
+++ b/src/models/chatglm2.h
@@ -22,9 +22,9 @@
 #include "rotary_embedding_chatglm2.h"
 #include "token_embedding.h"
 
-template <typename WeiT, typename NormT = RmsNorm>
-class ChatGLM2 : public CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, NormT>,
-          ChatGLM2MLP<WeiT, NormT>> {
+template <typename WeiT>
+class ChatGLM2 : public CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, RmsNorm>,
+          ChatGLM2MLP<WeiT, RmsNorm>> {
 public:
     ChatGLM2(const std::string &modelPath, const std::string &modelType = "chatglm2");
     ~ChatGLM2();
@@ -40,7 +40,7 @@ class ChatGLM2 : public CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding,
     TokenEmbedding<float16_t> *embedding;
-    NormT finalLN;
+    RmsNorm finalLN;
 
     // Record last block positions
     std::vector<int> lastBlockPositions;
diff --git a/src/models/chatglm3.h b/src/models/chatglm3.h
index 615e7dc3..6b7ffe6e 100644
--- a/src/models/chatglm3.h
+++ b/src/models/chatglm3.h
@@ -17,10 +17,10 @@
 #include "chatglm2.h"
 
 // ChatGLM3 and ChatGLM2 have the same structure, so ChatGLM3 utilizes the implementation of ChatGLM2.
-template <typename WeiT, typename NormT = RmsNorm>
-class ChatGLM3 : public ChatGLM2<WeiT, NormT> {
+template <typename WeiT>
+class ChatGLM3 : public ChatGLM2<WeiT> {
 public:
-    ChatGLM3(const std::string &modelPath) : ChatGLM2<WeiT, NormT>(modelPath, "chatglm3") {}
+    ChatGLM3(const std::string &modelPath) : ChatGLM2<WeiT>(modelPath, "chatglm3") {}
 };
 
 template class ChatGLM3<float16_t>;
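The net effect of the chatglm2/chatglm3 changes is that the norm class is no longer a template parameter: it is fixed to RmsNorm inside the class, so call sites name only the weight type. A hypothetical before/after at a call site (the helper name is illustrative; the type names are the ones used elsewhere in this patch):

    #include "chatglm2.h"

    void buildDecoder(const std::string &modelPath) {
        // Before this patch, the norm was an explicit (defaulted) argument:
        //     ChatGLM2<float16_t, RmsNorm> model(modelPath);
        // After this patch, RmsNorm is hardcoded inside the class:
        ChatGLM2<float16_t> model(modelPath);
    }
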
diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h
index d67dc836..45e1fed7 100644
--- a/src/models/common_decoder.h
+++ b/src/models/common_decoder.h
@@ -614,7 +614,7 @@ class CommonDecoder : public AbstractDecoder {
             this->context.reset(new DecoderContext(layers, hiddenSize, attHeadNum, kvHeadNum, imSize, act, epsilon,
                     vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, ppSize, ppRank,
                     ropeParamsPtr));
-            this->context->mmHelper = new MMHelper(xft::DeviceKind::CPU, 0);
+            this->context->mmHelper = new MMHelper(xft::DeviceKind::iCPU, 0);
         }
 
         return this->context.get();
diff --git a/src/models/models.cpp b/src/models/models.cpp
index fcddb647..8bec9dde 100644
--- a/src/models/models.cpp
+++ b/src/models/models.cpp
@@ -377,12 +377,12 @@ AutoModel::AutoModel(std::string modelPath, xft::DataType datatype) : Model() {
         }
     } else if (modeltype == "chatglm2") {
         switch (datatype) {
-            case xft::DataType::fp16: setDecoder(new ChatGLM2<float16_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::bf16: setDecoder(new ChatGLM2<bfloat16_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::int8: setDecoder(new ChatGLM2<int8_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::w8a8: setDecoder(new ChatGLM2<w8a8_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::int4: setDecoder(new ChatGLM2<uint4x2_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::nf4: setDecoder(new ChatGLM2<nf4x2_t, RmsNorm>(modelPath)); break;
+            case xft::DataType::fp16: setDecoder(new ChatGLM2<float16_t>(modelPath)); break;
+            case xft::DataType::bf16: setDecoder(new ChatGLM2<bfloat16_t>(modelPath)); break;
+            case xft::DataType::int8: setDecoder(new ChatGLM2<int8_t>(modelPath)); break;
+            case xft::DataType::w8a8: setDecoder(new ChatGLM2<w8a8_t>(modelPath)); break;
+            case xft::DataType::int4: setDecoder(new ChatGLM2<uint4x2_t>(modelPath)); break;
+            case xft::DataType::nf4: setDecoder(new ChatGLM2<nf4x2_t>(modelPath)); break;
             case xft::DataType::bf16_fp16:
                 setDecoder(new HybridModel<ChatGLM2, bfloat16_t, float16_t>(modelPath));
                 break;
@@ -399,12 +399,12 @@ AutoModel::AutoModel(std::string modelPath, xft::DataType datatype) : Model() {
         }
     } else if (modeltype == "chatglm3") {
         switch (datatype) {
-            case xft::DataType::fp16: setDecoder(new ChatGLM3<float16_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::bf16: setDecoder(new ChatGLM3<bfloat16_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::int8: setDecoder(new ChatGLM3<int8_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::w8a8: setDecoder(new ChatGLM3<w8a8_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::int4: setDecoder(new ChatGLM3<uint4x2_t, RmsNorm>(modelPath)); break;
-            case xft::DataType::nf4: setDecoder(new ChatGLM3<nf4x2_t, RmsNorm>(modelPath)); break;
+            case xft::DataType::fp16: setDecoder(new ChatGLM3<float16_t>(modelPath)); break;
+            case xft::DataType::bf16: setDecoder(new ChatGLM3<bfloat16_t>(modelPath)); break;
+            case xft::DataType::int8: setDecoder(new ChatGLM3<int8_t>(modelPath)); break;
+            case xft::DataType::w8a8: setDecoder(new ChatGLM3<w8a8_t>(modelPath)); break;
+            case xft::DataType::int4: setDecoder(new ChatGLM3<uint4x2_t>(modelPath)); break;
+            case xft::DataType::nf4: setDecoder(new ChatGLM3<nf4x2_t>(modelPath)); break;
             case xft::DataType::bf16_fp16:
                 setDecoder(new HybridModel<ChatGLM3, bfloat16_t, float16_t>(modelPath));
                 break;
diff --git a/src/pytorch/CMakeLists.txt b/src/pytorch/CMakeLists.txt
index 6d7519f0..d6c9b10e 100644
--- a/src/pytorch/CMakeLists.txt
+++ b/src/pytorch/CMakeLists.txt
@@ -14,8 +14,6 @@
 # ============================================================================
 cmake_minimum_required(VERSION 3.15.1)
 
-find_package(OpenMP REQUIRED)
-
 execute_process(COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
                 OUTPUT_VARIABLE TORCH_CMAKE_PREFIX_PATH
                 OUTPUT_STRIP_TRAILING_WHITESPACE)
@@ -47,8 +45,7 @@ target_include_directories(xfastertransformer_pt PUBLIC ${PyTorch_INCLUDE_DIR})
 
 # Link against LibTorch and others
 target_link_libraries(xfastertransformer_pt
-                      PRIVATE OpenMP::OpenMP_CXX
-                              "${TORCH_LIBS}"
+                      PRIVATE "${TORCH_LIBS}"
                               xfastertransformer_static
                               stdc++fs)
diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h
index f212c33c..10187eaa 100644
--- a/src/utils/matmul_helper.h
+++ b/src/utils/matmul_helper.h
@@ -38,11 +38,11 @@ class MMHelper {
 public:
     MMHelper(xft::DeviceKind device_kind, int idx) {
-        if (device_kind == xft::DeviceKind::CPU) {
+        if (device_kind == xft::DeviceKind::iCPU) {
             kind = dnnl::engine::kind::cpu;
             engine = new dnnl::engine(kind, idx);
             stream = new dnnl::stream(*engine);
-        } else if (device_kind == xft::DeviceKind::GPU) {
+        } else if (device_kind == xft::DeviceKind::iGPU) {
             kind = dnnl::engine::kind::gpu;
             engine = new dnnl::engine(kind, idx);
             stream = new dnnl::stream(*engine);
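The iCPU/iGPU spellings sidestep a macro collision: with WITH_GPU the top-level CMakeLists.txt adds -DGPU=true, so a plain GPU enumerator would be rewritten by the preprocessor. A minimal usage sketch of the updated constructor (illustrative, assuming the xFasterTransformer headers are on the include path):

    #include "matmul_helper.h"

    int main() {
        // CPU engine: the default that CommonDecoder now constructs.
        MMHelper cpuHelper(xft::DeviceKind::iCPU, 0);

        // GPU engine: only meaningful in a WITH_GPU build against a
        // SYCL-enabled oneDNN; index 0 picks the first device.
        // MMHelper gpuHelper(xft::DeviceKind::iGPU, 0);
        return 0;
    }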