[kernel] Add ICX compiler. #228

Merged (1 commit, Feb 21, 2024)
26 changes: 21 additions & 5 deletions CMakeLists.txt
@@ -16,11 +16,24 @@
cmake_minimum_required(VERSION 3.15.1)
project(xfastertransformer LANGUAGES C CXX)

-# Get gcc version
-execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpfullversion
-OUTPUT_VARIABLE GCC_VERSION
-OUTPUT_STRIP_TRAILING_WHITESPACE)
-message(STATUS "Notice: GCC version: ${GCC_VERSION}")
+# Enable GPU
+option(WITH_GPU "Build with GPU" OFF)
+if(WITH_GPU)
+message(STATUS "Notice: Building with GPU.")
+add_definitions(-DGPU=true)
+# Get compiler version
+execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version
+OUTPUT_VARIABLE ICPX_VERSION
+OUTPUT_STRIP_TRAILING_WHITESPACE)
+message(STATUS "Notice: ICPX version: ${ICPX_VERSION}")
+else()
+message(STATUS "Notice: Building with CPU.")
+# Get compiler version
+execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpfullversion
+OUTPUT_VARIABLE GCC_VERSION
+OUTPUT_STRIP_TRAILING_WHITESPACE)
+message(STATUS "Notice: GCC version: ${GCC_VERSION}")
+endif()

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
@@ -29,6 +29,9 @@ endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS
"${CMAKE_CXX_FLAGS} -fopenmp -mavx512f -mavx512bw -mavx512vl -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
+if(WITH_GPU)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -fsycl-device-code-split=per_kernel -lOpenCL")
+endif()

# GCC>=10.1 should support avx512bf16, but need to double check as some versions have issues
if(GCC_VERSION VERSION_GREATER_EQUAL "10.1")
9 changes: 7 additions & 2 deletions cmake/onednn.cmake
@@ -24,13 +24,18 @@ project(dependency NONE)

include(ExternalProject)

+set(ONEDNN_BUILD_OPTIONS -DONEDNN_LIBRARY_TYPE=STATIC -DONEDNN_BUILD_TESTS=OFF -DONEDNN_BUILD_EXAMPLES=OFF)
+if(WITH_GPU)
+set(ONEDNN_BUILD_OPTIONS "${ONEDNN_BUILD_OPTIONS} -DONEDNN_GPU_RUNTIME=SYCL")
+endif()
+
# cmake-format: off
ExternalProject_Add(onednn
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
-GIT_TAG v3.2
+GIT_TAG v3.3.3
SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/onednn
BINARY_DIR ${CMAKE_SOURCE_DIR}/3rdparty/onednn
-CONFIGURE_COMMAND ${CMAKE_COMMAND} -E make_directory "build" && ${CMAKE_COMMAND} -E chdir "build" ${CMAKE_COMMAND} -DONEDNN_LIBRARY_TYPE=STATIC -DONEDNN_BUILD_TESTS=OFF -DONEDNN_BUILD_EXAMPLES=OFF ..
+CONFIGURE_COMMAND ${CMAKE_COMMAND} -E make_directory "build" && ${CMAKE_COMMAND} -E chdir "build" ${CMAKE_COMMAND} ${ONEDNN_BUILD_OPTIONS} ..
BUILD_COMMAND ${CMAKE_COMMAND} -E chdir "build" make -j all
INSTALL_COMMAND ""
TEST_COMMAND ""
3 changes: 3 additions & 0 deletions examples/cpp/CMakeLists.txt
@@ -32,5 +32,8 @@ else()
target_link_libraries(example PRIVATE xfastertransformer_static)
endif()
target_link_libraries(example PRIVATE sentencepiece -lstdc++fs)
+if(WITH_GPU)
+target_link_libraries(example PRIVATE -fsycl -fsycl-device-code-split=per_kernel -lOpenCL)
+endif()

add_dependencies(example cmdline sentencepiece_lib)
4 changes: 2 additions & 2 deletions include/dtype.h
@@ -34,7 +34,7 @@ enum DataType {
};

enum DeviceKind {
-CPU = 0,
-GPU,
+iCPU = 0,
+iGPU,
};
} // namespace xft
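
Note: the PR itself does not say why the enumerators gain an "i" prefix. A plausible reason (an assumption here, not stated in the PR) is the add_definitions(-DGPU=true) added in CMakeLists.txt above: a preprocessor macro named GPU would be substituted into a plain "GPU" enumerator before the compiler sees it. A minimal C++ sketch of that interaction:

```cpp
// Hypothetical illustration only: -DGPU=true defines a macro that collides
// with a bare "GPU" enumerator, while the prefixed names are left untouched.
#define GPU true  // stands in for add_definitions(-DGPU=true)

namespace xft {
enum DeviceKind {
    // CPU = 0,
    // GPU,      // would expand to "true," and fail to compile
    iCPU = 0,    // unaffected by the macro
    iGPU,
};
} // namespace xft

int main() {
    xft::DeviceKind d = xft::iGPU;
    return d == xft::DeviceKind::iGPU ? 0 : 1;
}
```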
2 changes: 1 addition & 1 deletion src/common/my_types.h
@@ -22,7 +22,7 @@
typedef int8_t s8;
typedef uint8_t u8;

-typedef struct {
+typedef struct w8a8 {
int8_t s8;
operator int8_t() { return s8; }
} w8a8_t;
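Note: the diff only gives the previously anonymous struct a tag name; the PR gives no rationale. As a general C++ point (not a claim about the authors' motivation), a tagged struct can be forward-declared independently of its typedef, which the anonymous form cannot. A small sketch:

```cpp
#include <cstdint>

struct w8a8;                 // forward declaration works once the struct has a tag
typedef struct w8a8 w8a8_t;  // the typedef can even precede the definition

struct w8a8 {
    int8_t s8;
    operator int8_t() { return s8; }
};

int main() {
    w8a8_t v{42};
    return static_cast<int8_t>(v) == 42 ? 0 : 1;
}
```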
36 changes: 18 additions & 18 deletions src/models/chatglm2.cpp
@@ -18,10 +18,10 @@
#include "INIReader.h"
#include "chatglm2.h"

-template <typename WeiT, typename NormT>
-ChatGLM2<WeiT, NormT>::ChatGLM2(const std::string &modelPath, const std::string &modelType)
-: CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, NormT, float, float, float, true>,
-ChatGLM2MLP<WeiT, float, float, float, NormT, true>>(modelPath, modelType) {
+template <typename WeiT>

Review thread on the line above:

Contributor:
The previous code had two template parameters, "template <typename WeiT, typename NormT>". So here the normalization operator will always be RmsNorm?

Contributor (Author):
Yes, synced with Gui Sheng; he confirmed it is rms_norm currently in chatglm2/3. By the way, ICX could compile the template template default parameter.

+ChatGLM2<WeiT>::ChatGLM2(const std::string &modelPath, const std::string &modelType)
+: CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, RmsNorm, float, float, float, true>,
+ChatGLM2MLP<WeiT, float, float, float, RmsNorm, true>>(modelPath, modelType) {
this->positionIds = nullptr;
this->posBufSize = 0;

@@ -36,15 +36,15 @@ ChatGLM2<WeiT, NormT>::ChatGLM2(const std::string &modelPath, const std::string
setFinalLnWeight(modelPath);
}

-template <typename WeiT, typename NormT>
-ChatGLM2<WeiT, NormT>::~ChatGLM2() {
+template <typename WeiT>
+ChatGLM2<WeiT>::~ChatGLM2() {
delete embedding;

if (positionIds) { free(positionIds); }
}

-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::setEmbeddingWeights(const std::string &modelPath) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::setEmbeddingWeights(const std::string &modelPath) {
int vocabSize = embedding->getVocabSize();
int hiddenSize = embedding->getHiddenSize();

@@ -57,8 +57,8 @@ void ChatGLM2<WeiT, NormT>::setEmbeddingWeights(const std::string &modelPath) {
free(tokenEmb);
}

-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::setFinalLnWeight(const std::string &modelPath) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::setFinalLnWeight(const std::string &modelPath) {
int hiddenSize = embedding->getHiddenSize();

float *gamma = (float *)malloc(hiddenSize * sizeof(float));
@@ -85,8 +85,8 @@ void ChatGLM2<WeiT, NormT>::setFinalLnWeight(const std::string &modelPath) {
// attention_mask = (attention_mask < 0.5).bool()
//
// return attention_mask
-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::prepareAttnMask(int *ids, int step) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::prepareAttnMask(int *ids, int step) {
DecoderContext *ctx = this->getContext();
int seqLen = ctx->inputSeqLen;
int sizeRequired = ctx->batchSize * seqLen * seqLen;
@@ -127,13 +127,13 @@ void ChatGLM2<WeiT, NormT>::prepareAttnMask(int *ids, int step) {
}
}

-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::embeddingForward(int *ids, float *output, int batchSize, int seqLen) {
embedding->forward(ids, output, batchSize, seqLen);
}

-template <typename WeiT, typename NormT>
-void ChatGLM2<WeiT, NormT>::lastLayerNormForward(float *input, float *output, int rows) {
+template <typename WeiT>
+void ChatGLM2<WeiT>::lastLayerNormForward(float *input, float *output, int rows) {
finalLN.forward(input, output, rows);
}

@@ -147,8 +147,8 @@ void ChatGLM2<WeiT, NormT>::lastLayerNormForward(float *input, float *output, in
// batch_size, seq_length = input_ids.shape
// position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
// return position_ids
-template <typename WeiT, typename NormT>
-int *ChatGLM2<WeiT, NormT>::getPositionIds(int *ids, int batchSize, int seqLen, int step) {
+template <typename WeiT>
+int *ChatGLM2<WeiT>::getPositionIds(int *ids, int batchSize, int seqLen, int step) {
// Prepare buffer
int sizeNeeded = (batchSize * seqLen + 63) / 64 * 64; // position_ids + block_position_ids
if (posBufSize < sizeNeeded) {
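Note: to make the review thread in this file concrete, here is a minimal, self-contained sketch (invented stand-in names, not code from this repository) of the difference between a defaulted NormT template parameter and the hard-coded RmsNorm this PR switches to:

```cpp
// Stand-ins for the real normalization types.
struct RmsNorm {};
struct LayerNorm {};

// Before this PR: the norm type was a template parameter defaulting to RmsNorm.
template <typename WeiT, typename NormT = RmsNorm>
struct ChatGLM2_Before {
    NormT finalLN;
};

// After this PR: the norm is fixed to RmsNorm.
template <typename WeiT>
struct ChatGLM2_After {
    RmsNorm finalLN;
};

int main() {
    ChatGLM2_Before<float> a;             // default was RmsNorm
    ChatGLM2_Before<float, LayerNorm> b;  // overriding was still possible before
    ChatGLM2_After<float> c;              // only RmsNorm now
    (void)a; (void)b; (void)c;
    return 0;
}
```

With the defaulted parameter a caller could in principle instantiate the model with another norm; after the change RmsNorm is the only option, which matches the author's reply that chatglm2/3 currently use rms_norm.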
8 changes: 4 additions & 4 deletions src/models/chatglm2.h
@@ -22,9 +22,9 @@
#include "rotary_embedding_chatglm2.h"
#include "token_embedding.h"

-template <typename WeiT, typename NormT = RmsNorm>
-class ChatGLM2 : public CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, NormT, float, float, float, true>,
-ChatGLM2MLP<WeiT, float, float, float, NormT, true>> {
+template <typename WeiT>
+class ChatGLM2 : public CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, RmsNorm, float, float, float, true>,
+ChatGLM2MLP<WeiT, float, float, float, RmsNorm, true>> {
public:
ChatGLM2(const std::string &modelPath, const std::string &modelType = "chatglm2");
~ChatGLM2();
@@ -40,7 +40,7 @@ class ChatGLM2 : public CommonDecoder<Attention<WeiT, ChatGLM2RotaryEmbedding, N

private:
TokenEmbedding<float16_t> *embedding;
-NormT finalLN;
+RmsNorm finalLN;

// Record last block positions
std::vector<int> lastBlockPositions;
6 changes: 3 additions & 3 deletions src/models/chatglm3.h
@@ -17,10 +17,10 @@
#include "chatglm2.h"

// ChatGLM3 and ChatGLM2 have the same structure, so ChatGLM3 utilizes the implementation of ChatGLM2.
-template <typename WeiT, typename NormT = RmsNorm>
-class ChatGLM3 : public ChatGLM2<WeiT, NormT> {
+template <typename WeiT>
+class ChatGLM3 : public ChatGLM2<WeiT> {
public:
-ChatGLM3(const std::string &modelPath) : ChatGLM2<WeiT, NormT>(modelPath, "chatglm3") {}
+ChatGLM3(const std::string &modelPath) : ChatGLM2<WeiT>(modelPath, "chatglm3") {}
};

template class ChatGLM3<float>;
2 changes: 1 addition & 1 deletion src/models/common_decoder.h
@@ -614,7 +614,7 @@ class CommonDecoder : public AbstractDecoder {
this->context.reset(new DecoderContext(layers, hiddenSize, attHeadNum, kvHeadNum, imSize, act, epsilon,
vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, ppSize, ppRank,
ropeParamsPtr));
-this->context->mmHelper = new MMHelper(xft::DeviceKind::CPU, 0);
+this->context->mmHelper = new MMHelper(xft::DeviceKind::iCPU, 0);
}

return this->context.get();
24 changes: 12 additions & 12 deletions src/models/models.cpp
@@ -377,12 +377,12 @@ AutoModel::AutoModel(std::string modelPath, xft::DataType datatype) : Model() {
}
} else if (modeltype == "chatglm2") {
switch (datatype) {
-case xft::DataType::fp16: setDecoder(new ChatGLM2<float16_t, RmsNorm>(modelPath)); break;
-case xft::DataType::bf16: setDecoder(new ChatGLM2<bfloat16_t, RmsNorm>(modelPath)); break;
-case xft::DataType::int8: setDecoder(new ChatGLM2<int8_t, RmsNorm>(modelPath)); break;
-case xft::DataType::w8a8: setDecoder(new ChatGLM2<w8a8_t, RmsNorm>(modelPath)); break;
-case xft::DataType::int4: setDecoder(new ChatGLM2<uint4x2_t, RmsNorm>(modelPath)); break;
-case xft::DataType::nf4: setDecoder(new ChatGLM2<nf4x2_t, RmsNorm>(modelPath)); break;
+case xft::DataType::fp16: setDecoder(new ChatGLM2<float16_t>(modelPath)); break;
+case xft::DataType::bf16: setDecoder(new ChatGLM2<bfloat16_t>(modelPath)); break;
+case xft::DataType::int8: setDecoder(new ChatGLM2<int8_t>(modelPath)); break;
+case xft::DataType::w8a8: setDecoder(new ChatGLM2<w8a8_t>(modelPath)); break;
+case xft::DataType::int4: setDecoder(new ChatGLM2<uint4x2_t>(modelPath)); break;
+case xft::DataType::nf4: setDecoder(new ChatGLM2<nf4x2_t>(modelPath)); break;
case xft::DataType::bf16_fp16:
setDecoder(new HybridModel<ChatGLM2, bfloat16_t, float16_t>(modelPath));
break;
@@ -399,12 +399,12 @@ AutoModel::AutoModel(std::string modelPath, xft::DataType datatype) : Model() {
}
} else if (modeltype == "chatglm3") {
switch (datatype) {
-case xft::DataType::fp16: setDecoder(new ChatGLM3<float16_t, RmsNorm>(modelPath)); break;
-case xft::DataType::bf16: setDecoder(new ChatGLM3<bfloat16_t, RmsNorm>(modelPath)); break;
-case xft::DataType::int8: setDecoder(new ChatGLM3<int8_t, RmsNorm>(modelPath)); break;
-case xft::DataType::w8a8: setDecoder(new ChatGLM3<w8a8_t, RmsNorm>(modelPath)); break;
-case xft::DataType::int4: setDecoder(new ChatGLM3<uint4x2_t, RmsNorm>(modelPath)); break;
-case xft::DataType::nf4: setDecoder(new ChatGLM3<nf4x2_t, RmsNorm>(modelPath)); break;
+case xft::DataType::fp16: setDecoder(new ChatGLM3<float16_t>(modelPath)); break;
+case xft::DataType::bf16: setDecoder(new ChatGLM3<bfloat16_t>(modelPath)); break;
+case xft::DataType::int8: setDecoder(new ChatGLM3<int8_t>(modelPath)); break;
+case xft::DataType::w8a8: setDecoder(new ChatGLM3<w8a8_t>(modelPath)); break;
+case xft::DataType::int4: setDecoder(new ChatGLM3<uint4x2_t>(modelPath)); break;
+case xft::DataType::nf4: setDecoder(new ChatGLM3<nf4x2_t>(modelPath)); break;
case xft::DataType::bf16_fp16:
setDecoder(new HybridModel<ChatGLM3, bfloat16_t, float16_t>(modelPath));
break;
5 changes: 1 addition & 4 deletions src/pytorch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# ============================================================================
cmake_minimum_required(VERSION 3.15.1)

-find_package(OpenMP REQUIRED)
-
execute_process(COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
OUTPUT_VARIABLE TORCH_CMAKE_PREFIX_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE)
@@ -47,8 +45,7 @@ target_include_directories(xfastertransformer_pt PUBLIC ${PyTorch_INCLUDE_DIR})

# Link against LibTorch and others
target_link_libraries(xfastertransformer_pt
-PRIVATE OpenMP::OpenMP_CXX
-"${TORCH_LIBS}"
+PRIVATE "${TORCH_LIBS}"
xfastertransformer_static
stdc++fs)

4 changes: 2 additions & 2 deletions src/utils/matmul_helper.h
@@ -38,11 +38,11 @@
class MMHelper {
public:
MMHelper(xft::DeviceKind device_kind, int idx) {
-if (device_kind == xft::DeviceKind::CPU) {
+if (device_kind == xft::DeviceKind::iCPU) {
kind = dnnl::engine::kind::cpu;
engine = new dnnl::engine(kind, idx);
stream = new dnnl::stream(*engine);
-} else if (device_kind == xft::DeviceKind::GPU) {
+} else if (device_kind == xft::DeviceKind::iGPU) {
kind = dnnl::engine::kind::gpu;
engine = new dnnl::engine(kind, idx);
stream = new dnnl::stream(*engine);