Fix compilation on Metal #337

Merged · 3 commits · Jul 30, 2024
CMakeLists.txt (24 changes: 15 additions & 9 deletions)
@@ -25,6 +25,8 @@ if (CHATGLM_ENABLE_PYBIND)
 endif ()
 
 # third-party libraries
+
+# ggml
 if (GGML_CUDA)
     add_compile_definitions(GGML_USE_CUDA)
     enable_language(CUDA)
@@ -42,33 +44,37 @@ if (GGML_CUDA)
     set(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH_LIST} CACHE STRING "")
 endif ()
 
+if (GGML_METAL)
+    add_compile_definitions(GGML_USE_METAL)
+    set(GGML_METAL_EMBED_LIBRARY ON CACHE BOOL "" FORCE)
+endif ()
+
+if (GGML_PERF)
+    add_compile_definitions(GGML_PERF)
+endif ()
+
 include_directories(third_party/ggml/include/ggml third_party/ggml/src)
 add_subdirectory(third_party/ggml)
 
+# sentencepiece
 set(SPM_ENABLE_SHARED OFF CACHE BOOL "chatglm: disable sentencepiece shared libraries by default")
 set(SPM_ENABLE_TCMALLOC OFF CACHE BOOL "chatglm: disable tcmalloc by default")
 include_directories(third_party/sentencepiece/src)
 add_subdirectory(third_party/sentencepiece)
 
 include_directories(third_party/sentencepiece/third_party/protobuf-lite)
 
+# absl
 set(ABSL_ENABLE_INSTALL ON CACHE BOOL "" FORCE)
 set(ABSL_PROPAGATE_CXX_STD ON CACHE BOOL "" FORCE)
 add_subdirectory(third_party/abseil-cpp)
 
+# re2
 add_subdirectory(third_party/re2)
 
+# stb
 include_directories(third_party/stb)
 
-if (GGML_METAL)
-    add_compile_definitions(GGML_USE_METAL)
-    configure_file(third_party/ggml/src/ggml-metal.metal ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
-endif ()
-
-if (GGML_PERF)
-    add_compile_definitions(GGML_PERF)
-endif ()
-
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 
 file(GLOB CPP_SOURCES
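Note: the Metal and PERF blocks move above add_subdirectory(third_party/ggml) because the compile definitions and the GGML_METAL_EMBED_LIBRARY cache variable must be in place before ggml itself is configured; embedding the shader library also makes the old configure_file copy of ggml-metal.metal unnecessary. A minimal sketch of what the GGML_USE_METAL definition enables on the C++ side (illustrative only, not part of this diff; ggml_backend_metal_init and ggml_backend_cpu_init are ggml's public backend constructors):

    // backend_init.cpp (sketch): compile-time backend selection driven by
    // the GGML_USE_METAL definition set in the CMakeLists.txt above
    #include "ggml-backend.h"
    #ifdef GGML_USE_METAL
    #include "ggml-metal.h"
    #endif

    static ggml_backend_t init_backend() {
    #ifdef GGML_USE_METAL
        // Shaders are embedded via GGML_METAL_EMBED_LIBRARY, so no
        // ggml-metal.metal file needs to sit next to the executable.
        if (ggml_backend_t backend = ggml_backend_metal_init()) {
            return backend;
        }
    #endif
        return ggml_backend_cpu_init(); // portable fallback
    }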
chatglm.cpp (25 changes: 22 additions & 3 deletions)
@@ -1118,8 +1118,27 @@ ggml_tensor *GLMBlock::forward(ModelContext *mctx, ggml_tensor *hidden_states, g
     return output;
 }
 
+static void alloc_weight_context(ModelContext *mctx, const ggml_backend_buffer_t sd_buf) {
+    void *sd_buf_base = ggml_backend_buffer_get_base(sd_buf);
+    const size_t sd_buf_size = ggml_backend_buffer_get_size(sd_buf);
+    if (ggml_backend_is_cpu(mctx->backend.get())) {
+        mctx->buf_w = unique_ggml_backend_buffer_t(ggml_backend_cpu_buffer_from_ptr(sd_buf_base, sd_buf_size));
+    }
+#ifdef GGML_USE_METAL
+    else if (ggml_backend_is_metal(mctx->backend.get())) {
+        const size_t max_size = ggml_get_max_tensor_size(mctx->ctx_w.get());
+        mctx->buf_w =
+            unique_ggml_backend_buffer_t(ggml_backend_metal_buffer_from_ptr(sd_buf_base, sd_buf_size, max_size));
+    }
+#endif
+    else {
+        mctx->buf_w =
+            unique_ggml_backend_buffer_t(ggml_backend_alloc_ctx_tensors(mctx->ctx_w.get(), mctx->backend.get()));
+    }
+}
+
 void ChatGLMForCausalLM::load_state_dict(const StateDict &sd) {
-    alloc_weight_context(sd.buf.get());
+    alloc_weight_context(mctx_.get(), sd.buf.get());
 
     StateDict self_sd = state_dict();
     for (auto &item : self_sd.kv) {
@@ -1259,7 +1278,7 @@ bool ChatGLM2Tokenizer::is_special_id(int id) const {
 }
 
 void ChatGLM2ForCausalLM::load_state_dict(const StateDict &sd) {
-    alloc_weight_context(sd.buf.get());
+    alloc_weight_context(mctx_.get(), sd.buf.get());
 
     if (config.num_virtual_tokens > 0) {
         ggml_tensor *past_key_values = sd.kv.at("past_key_values");
@@ -1959,7 +1978,7 @@ int ChatGLM4VForCausalLM::count_tokens(const std::vector<int> &input_ids, const
 }
 
 void ChatGLM4VForCausalLM::load_state_dict(const StateDict &sd) {
-    alloc_weight_context(sd.buf.get());
+    alloc_weight_context(mctx_.get(), sd.buf.get());
 
     auto self_sd = state_dict();
     ChatGLM2ForCausalLM::load_state_dict(mctx_.get(), self_sd, sd);
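The helper is now a file-local free function that takes the ModelContext explicitly, so all three load_state_dict overrides share one implementation. On Metal it wraps the memory-mapped state dict zero-copy via ggml_backend_metal_buffer_from_ptr, with max_size (the largest tensor in the weight context) letting the backend split a large mapping without splitting any single tensor. A hedged end-to-end sketch that exercises this path (model path and prompt are illustrative; Pipeline is the library's public entry point and calls load_state_dict internally):

    // main.cpp (sketch, assuming a Metal-enabled build and a local model file)
    #include "chatglm.h"
    #include <iostream>
    #include <vector>

    int main() {
        chatglm::Pipeline pipeline("models/chatglm3-ggml.bin"); // runs load_state_dict
        std::vector<chatglm::ChatMessage> messages{
            {chatglm::ChatMessage::ROLE_USER, "hello"}};
        chatglm::GenerationConfig gen_config;
        std::cout << pipeline.chat(messages, gen_config).content << std::endl;
        return 0;
    }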
chatglm.h (20 changes: 0 additions & 20 deletions)
@@ -999,26 +999,6 @@ class BasicModelForCausalLM : public BaseModelForCausalLM {
 
     void load_prefix_cache(ggml_tensor *past_key_values) { transformer.load_prefix_cache(config, past_key_values); }
 
-  protected:
-    void alloc_weight_context(const ggml_backend_buffer_t sd_buf) const {
-        void *sd_buf_base = ggml_backend_buffer_get_base(sd_buf);
-        const size_t sd_buf_size = ggml_backend_buffer_get_size(sd_buf);
-        if (ggml_backend_is_cpu(mctx_->backend.get())) {
-            mctx_->buf_w = unique_ggml_backend_buffer_t(ggml_backend_cpu_buffer_from_ptr(sd_buf_base, sd_buf_size));
-        }
-#ifdef GGML_USE_METAL
-        else if (ggml_backend_is_metal(mctx_->backend.get())) {
-            const size_t max_size = ggml_get_max_tensor_size(mctx_->ctx_w.get());
-            mctx_->buf_w =
-                unique_ggml_backend_buffer_t(ggml_backend_metal_buffer_from_ptr(sd_buf_base, sd_buf_size, max_size));
-        }
-#endif
-        else {
-            mctx_->buf_w =
-                unique_ggml_backend_buffer_t(ggml_backend_alloc_ctx_tensors(mctx_->ctx_w.get(), mctx_->backend.get()));
-        }
-    }
-
  public:
     Model transformer;
     Linear lm_head;
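This deletion mirrors the chatglm.cpp change: the protected member became a static free function in one translation unit, so the GGML_USE_METAL conditional no longer leaks into every consumer of chatglm.h. A generic sketch of the idiom (all names hypothetical, not chatglm.cpp's):

    // widget.h (sketch) - the header stays backend-agnostic
    struct Widget {
        void load();
    };

    // widget.cpp (sketch) - platform #ifdefs confined to one translation unit
    #include <cstdio>
    static void alloc_buffers() { // file-local, no declaration in the header
    #ifdef USE_METAL
        std::puts("mapping weights into the GPU address space");
    #else
        std::puts("wrapping weights in a CPU buffer");
    #endif
    }
    void Widget::load() { alloc_buffers(); }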
chatglm_cpp/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,7 @@
 import chatglm_cpp._C as _C
 from chatglm_cpp._C import ChatMessage, Image
 
-__version__ = "0.4.1"
+__version__ = "0.4.2"
 
 
 @dataclass