Metal support for all models (#138)
li-plus authored Oct 9, 2023
1 parent 9be06f0 commit b9a2388
Showing 13 changed files with 495 additions and 362 deletions.
5 changes: 3 additions & 2 deletions CMakeLists.txt
@@ -25,7 +25,8 @@ add_subdirectory(third_party/sentencepiece)

if (GGML_CUBLAS)
add_compile_definitions(GGML_USE_CUBLAS)
set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES "52;61;70;75;80;86")
set(CUDA_ARCHITECTURES "52;61;70;75;80;86" CACHE STRING "chatglm: cuda architectures to compile")
set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES ${CUDA_ARCHITECTURES})
endif ()

if (GGML_METAL)
@@ -72,7 +73,7 @@ if (CHATGLM_ENABLE_TESTING)
gtest_discover_tests(chatglm_test)
endif ()

option(CHATGLM_ENABLE_PYBIND, "chatglm: enable python binding" OFF)
option(CHATGLM_ENABLE_PYBIND "chatglm: enable python binding" OFF)
if (CHATGLM_ENABLE_PYBIND)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
set_target_properties(chatglm ggml sentencepiece-static PROPERTIES POSITION_INDEPENDENT_CODE TRUE)
23 changes: 15 additions & 8 deletions README.md
@@ -13,10 +13,10 @@ C++ implementation of [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [Cha
## Features

Highlights:
* [x] Pure C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp).
* [x] Accelerated memory-efficient CPU inference with int4/int8 quantization, optimized KV cache and parallel computing.
* [x] Streaming generation with typewriter effect.
* [x] Python binding, web demo, api servers and more possibilities.
* Pure C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp).
* Accelerated memory-efficient CPU inference with int4/int8 quantization, optimized KV cache and parallel computing.
* Streaming generation with typewriter effect.
* Python binding, web demo, api servers and more possibilities.

Support Matrix:
* Hardwares: x86/arm CPU, NVIDIA GPU, Apple Silicon GPU
@@ -42,7 +42,7 @@ git submodule update --init --recursive
Install necessary packages for loading and quantizing Hugging Face models:
```sh
python3 -m pip install -U pip
python3 -m pip install torch tabulate tqdm transformers sentencepiece
python3 -m pip install torch tabulate tqdm transformers accelerate sentencepiece
```

Use `convert.py` to transform ChatGLM-6B or ChatGLM2-6B into quantized GGML format. For example, to convert the fp16 original model to q4_0 (quantized int4) GGML model, run:
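The command itself falls inside a collapsed part of this hunk; as a sketch, a typical invocation looks like the following, where the `-i`/`-t`/`-o` flags are assumptions rather than something shown in this diff:
```sh
python3 convert.py -i THUDM/chatglm-6b -t q4_0 -o chatglm-ggml.bin
```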
@@ -176,7 +176,11 @@ cuBLAS uses NVIDIA GPU to accelerate BLAS. Add the CMake flag `-DGGML_CUBLAS=ON`
cmake -B build -DGGML_CUBLAS=ON && cmake --build build -j
```
Note that the current GGML CUDA implementation is really slow. The community is making efforts to optimize it.
By default, kernels are compiled for every supported CUDA architecture, which slows the build. To target a specific device, set `CUDA_ARCHITECTURES` to speed up nvcc compilation. For example:
```sh
cmake -B build -DGGML_CUBLAS=ON -DCUDA_ARCHITECTURES="80" # for A100
cmake -B build -DGGML_CUBLAS=ON -DCUDA_ARCHITECTURES="70;75" # compatible with both V100 and T4
```
**Metal**
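The body of this section is collapsed in the diff. Given the `GGML_METAL` option already defined in CMakeLists.txt, the build step presumably mirrors the cuBLAS one:
```sh
cmake -B build -DGGML_METAL=ON && cmake --build build -j
```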
@@ -419,14 +423,15 @@ Python demo and API servers are also supported in pre-built image. Use it in the
Environment:
* CPU backend performance is measured on a Linux server with Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz using 16 threads.
* CUDA backend is measured on a V100-SXM2-32GB GPU using 1 thread.
* MPS backend is measured on an Apple M2 Ultra device using 1 thread (currently only supports ChatGLM2).
* MPS backend is measured on an Apple M2 Ultra device using 1 thread.
ChatGLM-6B:
| | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 | F16 |
|--------------------------------|-------|-------|-------|-------|-------|-------|
| ms/token (CPU @ Platinum 8260) | 74 | 77 | 86 | 89 | 114 | 189 |
| ms/token (CUDA @ V100 SXM2) | 8.1 | 8.7 | 9.4 | 9.5 | 12.0 | 19.1 |
| ms/token (MPS @ M2 Ultra) | 11.5 | 12.3 | N/A | N/A | 16.1 | 24.4 |
| file size | 3.3G | 3.7G | 4.0G | 4.4G | 6.2G | 12G |
| mem usage | 4.0G | 4.4G | 4.7G | 5.1G | 6.9G | 13G |
@@ -436,7 +441,7 @@ ChatGLM2-6B / CodeGeeX2:
|--------------------------------|-------|-------|-------|-------|-------|-------|
| ms/token (CPU @ Platinum 8260) | 64 | 71 | 79 | 83 | 106 | 189 |
| ms/token (CUDA @ V100 SXM2) | 7.9 | 8.3 | 9.2 | 9.2 | 11.7 | 18.5 |
| ms/token (MPS @ M2 Ultra) | 11.0 | 11.7 | N/A | N/A | N/A | 32.1 |
| ms/token (MPS @ M2 Ultra) | 10.0 | 10.8 | N/A | N/A | 14.5 | 22.2 |
| file size | 3.3G | 3.7G | 4.0G | 4.4G | 6.2G | 12G |
| mem usage | 3.4G | 3.8G | 4.1G | 4.5G | 6.2G | 12G |
@@ -446,6 +451,7 @@ Baichuan-7B / Baichuan2-7B:
|--------------------------------|-------|-------|-------|-------|-------|-------|
| ms/token (CPU @ Platinum 8260) | 85.3 | 94.8 | 103.4 | 109.6 | 136.8 | 248.5 |
| ms/token (CUDA @ V100 SXM2) | 8.7 | 9.2 | 10.2 | 10.3 | 13.2 | 21.0 |
| ms/token (MPS @ M2 Ultra) | 11.3 | 12.0 | N/A | N/A | 16.4 | 25.6 |
| file size | 4.0G | 4.4G | 4.9G | 5.3G | 7.5G | 14G |
| mem usage | 4.5G | 4.9G | 5.3G | 5.7G | 7.8G | 14G |
@@ -455,6 +461,7 @@ Baichuan-13B / Baichuan2-13B:
|--------------------------------|-------|-------|-------|-------|-------|-------|
| ms/token (CPU @ Platinum 8260) | 161.7 | 175.8 | 189.9 | 192.3 | 255.6 | 459.6 |
| ms/token (CUDA @ V100 SXM2) | 13.7 | 15.1 | 16.3 | 16.9 | 21.9 | 36.8 |
| ms/token (MPS @ M2 Ultra) | 18.2 | 18.8 | N/A | N/A | 27.2 | 44.4 |
| file size | 7.0G | 7.8G | 8.5G | 9.3G | 14G | 25G |
| mem usage | 7.8G | 8.8G | 9.5G | 10G | 14G | 25G |
23 changes: 11 additions & 12 deletions chatglm.cpp
@@ -437,6 +437,9 @@ BaseModelForCausalLM::BaseModelForCausalLM(ModelType model_type, ModelConfig con
ctx_.compute_buffer.resize(mem_size);
ctx_.scratch_buffer.resize(scratch_size);
ctx_.scratch = {0, ctx_.scratch_buffer.size(), ctx_.scratch_buffer.data()};
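// give the CUDA backend a scratch buffer of the same size as the host one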
#ifdef GGML_USE_CUBLAS
ggml_cuda_set_scratch_size(scratch_size);
#endif
}

int BaseModelForCausalLM::generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
@@ -473,6 +476,11 @@ int BaseModelForCausalLM::generate_next_token(const std::vector<int> &input_ids,
int vocab_size = lm_logits->ne[0];
float *next_token_logits = (float *)lm_logits->data;

// check nan
for (int i = 0; i < vocab_size; i++) {
CHATGLM_CHECK(std::isfinite(next_token_logits[i])) << "nan/inf encountered at lm_logits[" << i << "]";
}

// logits pre-process
if (gen_config.repetition_penalty != 1.f) {
sampling_repetition_penalty(next_token_logits, next_token_logits + vocab_size, input_ids,
@@ -780,13 +788,14 @@ ggml_tensor *GLMContextMasker::operator()(ModelContext *ctx, ggml_tensor *attn_s
return attn_scores;
}

ggml_tensor *GLMBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, int n_past, int n_ctx) const {
ggml_tensor *GLMBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,
int n_ctx) const {
ggml_context *gctx = ctx->ctx_b.get();

ggml_tensor *alpha = ggml_new_f32(gctx, alpha_value);

ggml_tensor *attn_input = input_layernorm.forward(ctx, hidden_states);
ggml_tensor *attn_output = attention.forward(ctx, attn_input, n_past, n_ctx);
ggml_tensor *attn_output = attention.forward(ctx, attn_input, position_ids, n_past, n_ctx);
ggml_build_forward_expand(&ctx->gf, attn_output);
attn_input = tensor_assign_buffers(ggml_scale_inplace(gctx, attn_input, alpha));
hidden_states = tensor_assign_buffers(ggml_add_inplace(gctx, attn_input, attn_output));
@@ -800,16 +809,6 @@ ggml_tensor *GLMBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, in
return output;
}

std::vector<GLMBlock> ChatGLMModel::build_layers(ModelContext *ctx, const ModelConfig &config) {
std::vector<GLMBlock> layers;
layers.reserve(config.num_hidden_layers);
for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) {
layers.emplace_back(ctx, config.hidden_size, config.num_attention_heads, config.num_hidden_layers,
config.max_length);
}
return layers;
}

ChatGLMForCausalLM::ChatGLMForCausalLM(const ModelConfig &config)
: BasicModelForCausalLM(MODEL_TYPE_CHATGLM, config, MEM_SIZE, SCRATCH_SIZE) {
constexpr size_t tensor_ovhd = GGML_TENSOR_SIZE + GGML_OBJECT_SIZE;
135 changes: 105 additions & 30 deletions chatglm.h
@@ -299,21 +299,65 @@ enum RopeType {
};

struct NoopRoper {
ggml_tensor *operator()(ggml_context *ctx, ggml_tensor *a, int n_past, int n_ctx) const { return a; }
ggml_tensor *operator()(ModelContext *ctx, ggml_tensor *a, ggml_tensor *b, int n_ctx) const { return a; }
};

template <RopeType MODE, int DIM_SCALE>
struct BasicRoper {
ggml_tensor *operator()(ggml_context *ctx, ggml_tensor *a, int n_past, int n_ctx) const {
ggml_tensor *operator()(ModelContext *ctx, ggml_tensor *a, ggml_tensor *b, int n_ctx) const {
// tensor a (activation) is of shape [qlen, heads, head_size]
// tensor b (position_ids) is of shape [qlen]
ggml_context *gctx = ctx->ctx_b.get();
#ifdef GGML_USE_CUBLAS
if (!ggml_is_contiguous(a)) {
a = tensor_assign_buffers(ggml_cont(ctx, a));
a = tensor_assign_buffers(ggml_cont(gctx, a));
}
#endif
const int head_size = a->ne[0];
const int rope_dim = head_size / DIM_SCALE;
a = tensor_assign_buffers(ggml_rope_inplace(ctx, a, n_past, rope_dim, MODE, n_ctx)); // [qlen, heads, head_size]
a = tensor_assign_buffers(ggml_rope_inplace(gctx, a, b, rope_dim, MODE, n_ctx)); // [qlen, heads, head_size]

return a;
}
};

struct GLMRoper {
ggml_tensor *operator()(ModelContext *ctx, ggml_tensor *a, ggml_tensor *b, int n_ctx) const {
// tensor a (activation) is of shape [qlen, heads, head_size]
// tensor b (position_ids) is of shape [2 * qlen]
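// b packs ChatGLM's two position streams back to back: absolute positions in b[0:qlen] and
// block positions in b[qlen:2*qlen]; each is applied NEOX-style to its own half of head_size below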
ggml_context *gctx = ctx->ctx_b.get();

const int head_size = a->ne[0];
const int num_heads = a->ne[1];
const int qlen = a->ne[2];
const int rope_dim = head_size / 2;

ggml_tensor *b1 = ggml_view_1d(gctx, b, qlen, 0);
ggml_tensor *b2 = ggml_view_1d(gctx, b, qlen, qlen * ggml_element_size(b));

ggml_tensor *a1 = ggml_view_3d(gctx, a, head_size / 2, num_heads, qlen, a->nb[1], a->nb[2], 0);
ggml_tensor *a2 = ggml_view_3d(gctx, a, head_size / 2, num_heads, qlen, a->nb[1], a->nb[2],
head_size / 2 * ggml_element_size(a));

ggml_tensor *a1_rope = a1;
ggml_tensor *a2_rope = a2;
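// the half-width views are strided; on CUDA, rope contiguous copies and write them back into a afterwards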
#ifdef GGML_USE_CUBLAS
a1_rope = tensor_assign_buffers(ggml_cont(gctx, a1_rope));
a2_rope = tensor_assign_buffers(ggml_cont(gctx, a2_rope));
#endif

a1_rope = tensor_assign_buffers(
ggml_rope_inplace(gctx, a1_rope, b1, rope_dim, ROPE_TYPE_NEOX, n_ctx)); // [qlen, heads, head_size/2]
a2_rope = tensor_assign_buffers(
ggml_rope_inplace(gctx, a2_rope, b2, rope_dim, ROPE_TYPE_NEOX, n_ctx)); // [qlen, heads, head_size/2]

#ifdef GGML_USE_CUBLAS
a1_rope = ggml_cpy(gctx, a1_rope, a1);
a2_rope = ggml_cpy(gctx, a2_rope, a2);
#endif
ggml_build_forward_expand(&ctx->gf, a1_rope);
ggml_build_forward_expand(&ctx->gf, a2_rope);

return a;
}
};
@@ -333,7 +377,8 @@ class BasicAttention {
v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads,
num_kv_heads)) {}

ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, int n_past, int n_ctx) const {
ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,
int n_ctx) const {
ggml_context *gctx = ctx->ctx_b.get();

const int hidden_size = hidden_states->ne[0];
@@ -368,8 +413,8 @@ class BasicAttention {
qkv->nb[1], (hidden_size + head_size * num_kv_heads) * ggml_element_size(qkv));
}

query_layer = roper_(gctx, query_layer, n_past, n_ctx);
key_layer = roper_(gctx, key_layer, n_past, n_ctx);
query_layer = roper_(ctx, query_layer, position_ids, n_ctx);
key_layer = roper_(ctx, key_layer, position_ids, n_ctx);

query_layer = tensor_assign_buffers(
ggml_cont(gctx, ggml_permute(gctx, query_layer, 0, 2, 1, 3))); // [heads, qlen, head_size]
@@ -462,12 +507,13 @@ class BasicBlock {
attention(ctx, hidden_size, num_attention_heads, num_kv_heads, max_length),
post_attention_layernorm(ctx, hidden_size, false, norm_eps), mlp(ctx, hidden_size, intermediate_size) {}

ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, int n_past, int n_ctx) const {
ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,
int n_ctx) const {
ggml_context *gctx = ctx->ctx_b.get();

ggml_tensor *residual = hidden_states;
hidden_states = input_layernorm.forward(ctx, hidden_states);
hidden_states = attention.forward(ctx, hidden_states, n_past, n_ctx);
hidden_states = attention.forward(ctx, hidden_states, position_ids, n_past, n_ctx);
hidden_states = tensor_assign_buffers(ggml_add_inplace(gctx, hidden_states, residual));

residual = hidden_states;
@@ -490,7 +536,33 @@ class BasicBlock {
MLP mlp;
};

template <typename Block, typename Norm>
struct NoopPositionIdsGenerator {
ggml_tensor *operator()(ggml_context *ctx, int qlen, int n_past, int n_ctx) const { return nullptr; }
};

struct BasicPositionIdsGenerator {
ggml_tensor *operator()(ggml_context *ctx, int qlen, int n_past, int n_ctx) const {
ggml_tensor *position_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, qlen);
for (int i = 0; i < qlen; i++) {
((int *)position_ids->data)[i] = n_past + i;
}
return position_ids;
}
};

struct GLMPositionIdsGenerator {
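// ChatGLM-6B 2D positional encoding: the first qlen entries are absolute positions clamped to
// n_ctx - 2 (the [gMASK] slot of the prompt), the last qlen entries are block positions counting
// tokens generated past that point, e.g. n_ctx = 5, n_past = 6 gives position 3 and block position 3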
ggml_tensor *operator()(ggml_context *ctx, int qlen, int n_past, int n_ctx) const {
ggml_tensor *position_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, qlen * 2);
for (int i = 0; i < qlen; i++) {
const int p = n_past + i;
((int *)position_ids->data)[i] = std::min(p, n_ctx - 2);
((int *)position_ids->data)[qlen + i] = std::max(p - (n_ctx - 2), 0);
}
return position_ids;
}
};

template <typename Block, typename Norm, typename PositionIdsGenerator>
class BasicModel {
public:
BasicModel() = default;
@@ -504,10 +576,17 @@ class BasicModel {

ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx) const {
ggml_context *gctx = ctx->ctx_b.get();
ggml_tensor *position_ids = pos_ids_gen_(gctx, input_ids->ne[0], n_past, n_ctx);
if (position_ids) {
tensor_to_device(position_ids);
}
ggml_tensor *hidden_states = word_embeddings.forward(ctx, input_ids);
for (const auto &layer : layers) {
ggml_set_scratch(gctx, ctx->scratch);
hidden_states = layer.forward(ctx, hidden_states, n_past, n_ctx);
hidden_states = layer.forward(ctx, hidden_states, position_ids, n_past, n_ctx);
}
if (position_ids) {
tensor_to_cpu(position_ids);
}
ggml_scratch empty_scratch = {0, 0, nullptr};
ggml_set_scratch(gctx, empty_scratch);
@@ -531,6 +610,9 @@ class BasicModel {
Embedding word_embeddings;
std::vector<Block> layers;
Norm final_layernorm;

private:
PositionIdsGenerator pos_ids_gen_;
};

class BaseStreamer {
Expand Down Expand Up @@ -800,35 +882,28 @@ struct GLMContextMasker {
ggml_tensor *operator()(ModelContext *ctx, ggml_tensor *attn_scores, int n_past) const;
};

using GLMAttention = BasicAttention<true, true, true, BasicRoper<ROPE_TYPE_CHATGLM, 2>, false, GLMContextMasker>;
using GLMAttention = BasicAttention<true, true, true, GLMRoper, false, GLMContextMasker>;

using GLMMLP = BasicMLP<ACT_TYPE_GELU>;

class GLMBlock : public BasicBlock<LayerNorm, GLMAttention, GLMMLP> {
public:
GLMBlock() = default;
GLMBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_hidden_layers, int max_length)
: BasicBlock(LayerNorm(ctx, hidden_size),
GLMBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
int max_length, float norm_eps)
: BasicBlock(LayerNorm(ctx, hidden_size, true, norm_eps),
GLMAttention(ctx, hidden_size, num_attention_heads, num_attention_heads, max_length),
LayerNorm(ctx, hidden_size), GLMMLP(ctx, hidden_size, 4 * hidden_size)),
alpha_value(std::sqrt(2.f * num_hidden_layers)) {}
LayerNorm(ctx, hidden_size, true, norm_eps), GLMMLP(ctx, hidden_size, intermediate_size)),
alpha_value(std::sqrt(2.f * 28)) {}
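// alpha = sqrt(2 * num_layers); the hard-coded 28 is ChatGLM-6B's layer count, no longer taken from the constructor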

ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, int n_past, int n_ctx) const;
ggml_tensor *forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *position_ids, int n_past,
int n_ctx) const;

public:
float alpha_value;
};

class ChatGLMModel : public BasicModel<GLMBlock, LayerNorm> {
public:
ChatGLMModel() = default;
ChatGLMModel(ModelContext *ctx, const ModelConfig &config)
: BasicModel(Embedding(ctx, config.vocab_size, config.hidden_size), build_layers(ctx, config),
LayerNorm(ctx, config.hidden_size)) {}

private:
static std::vector<GLMBlock> build_layers(ModelContext *ctx, const ModelConfig &config);
};
using ChatGLMModel = BasicModel<GLMBlock, LayerNorm, GLMPositionIdsGenerator>;

class ChatGLMForCausalLM : public BasicModelForCausalLM<ChatGLMModel> {
public:
Expand Down Expand Up @@ -872,7 +947,7 @@ using GLM2MLP = BasicGLU<ACT_TYPE_SILU, false>;

using GLM2Block = BasicBlock<RMSNorm, GLM2Attention, GLM2MLP>;

using ChatGLM2Model = BasicModel<GLM2Block, RMSNorm>;
using ChatGLM2Model = BasicModel<GLM2Block, RMSNorm, BasicPositionIdsGenerator>;

class ChatGLM2ForCausalLM : public BasicModelForCausalLM<ChatGLM2Model> {
public:
Expand Down Expand Up @@ -921,7 +996,7 @@ using Baichuan7BMLP = BasicGLU<ACT_TYPE_SILU, false>;

using Baichuan7BBlock = BasicBlock<RMSNorm, Baichuan7BAttention, Baichuan7BMLP>;

using Baichuan7BModel = BasicModel<Baichuan7BBlock, RMSNorm>;
using Baichuan7BModel = BasicModel<Baichuan7BBlock, RMSNorm, BasicPositionIdsGenerator>;

class Baichuan7BForCausalLM : public BasicModelForCausalLM<Baichuan7BModel> {
public:
@@ -942,7 +1017,7 @@ using Baichuan13BMLP = BasicGLU<ACT_TYPE_SILU, false>;

using Baichuan13BBlock = BasicBlock<RMSNorm, Baichuan13BAttention, Baichuan13BMLP>;

using Baichuan13BModel = BasicModel<Baichuan13BBlock, RMSNorm>;
using Baichuan13BModel = BasicModel<Baichuan13BBlock, RMSNorm, NoopPositionIdsGenerator>;

class Baichuan13BForCausalLM : public BasicModelForCausalLM<Baichuan13BModel> {
public: