From 0ff80973ae878592f46c24f22df463deb3b8775b Mon Sep 17 00:00:00 2001 From: RangerUFO <ufownl@gmail.com> Date: Sat, 8 Jun 2024 19:46:40 +0800 Subject: [PATCH] Adapt to changes in gemma.cpp Add optional `weight_type` option to `cgemma.new` method. --- README.md | 4 ++++ src/cgemma.cpp | 2 -- src/instance.cpp | 19 ++++++++++++++++--- src/session.cpp | 22 ++++++++-------------- 4 files changed, 28 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 2f62601..2c0e53a 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,10 @@ Available options: -- gr2b-pt (griffin 2B parameters, pretrained). -- (required) weights = "/path/to/2b-it-sfp.sbs", -- Path of model weights file. (required) + weight_type = "sfp", -- Weight type: + -- sfp (8-bit FP, default) + -- f32 (float) + -- bf16 (bfloat16) scheduler = sched_inst, -- Instance of scheduler, if not provided a default -- scheduler will be attached. } diff --git a/src/cgemma.cpp b/src/cgemma.cpp index 2d9f578..5d6954d 100644 --- a/src/cgemma.cpp +++ b/src/cgemma.cpp @@ -26,8 +26,6 @@ int info(lua_State* L) { << "Date & Time : " << std::put_time(std::localtime(&now), "%F %T") << std::endl << "Max Sequence Length : " << gcpp::kSeqLen << std::endl << "Top-K : " << gcpp::kTopK << std::endl - << "Weight Type : " << gcpp::TypeName(gcpp::GemmaWeightT()) << std::endl - << "Embedder Input Type : " << gcpp::TypeName(gcpp::EmbedderInputT()) << std::endl << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize << std::endl << "Hardware Concurrency : " << std::thread::hardware_concurrency() << std::endl << "Instruction Set : " << hwy::TargetName(hwy::DispatchedTarget()) << " (" << hwy::VectorBytes() * 8 << " bits)" << std::endl diff --git a/src/instance.cpp b/src/instance.cpp index ba630de..9058919 100644 --- a/src/instance.cpp +++ b/src/instance.cpp @@ -26,7 +26,7 @@ instance::instance(int argc, char* argv[], scheduler* s) default_sched_ = std::make_unique<scheduler>(); sched_ = default_sched_.get(); } - model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.ModelType(), sched_->pool()); + model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.ModelType(), args_.WeightType(), sched_->pool()); } void instance::declare(lua_State* L) { @@ -61,7 +61,9 @@ int instance::create(lua_State* L) { lua_pop(L, 1); constexpr const char* required_options[] = {"--tokenizer", "--model", "--weights"}; constexpr const int n = sizeof(required_options) / sizeof(required_options[0]); - char* argv[n * 2 + 1] = {const_cast<char*>("lua-cgemma")}; + constexpr const char* optional_options[] = {"--weight_type"}; + constexpr const int m = sizeof(optional_options) / sizeof(optional_options[0]); + char* argv[(n + m) * 2 + 1] = {const_cast<char*>("lua-cgemma")}; for (int i = 0; i < n; ++i) { auto k = required_options[i] + 2; lua_getfield(L, 1, k); @@ -73,9 +75,20 @@ int instance::create(lua_State* L) { argv[i * 2 + 2] = const_cast<char*>(v); lua_pop(L, 1); } + auto argc = n * 2 + 1; + for (auto opt: optional_options) { + auto k = opt + 2; + lua_getfield(L, 1, k); + auto v = lua_tostring(L, -1); + if (v) { + argv[argc++] = const_cast<char*>(opt); + argv[argc++] = const_cast<char*>(v); + } + lua_pop(L, 1); + } auto ud = lua_newuserdata(L, sizeof(instance)); try { - new(ud) instance(n * 2 + 1, argv, s); + new(ud) instance(argc, argv, s); luaL_getmetatable(L, name); lua_setmetatable(L, -2); return 1; diff --git a/src/session.cpp b/src/session.cpp index d573e34..127a185 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -19,7 +19,7 @@ std::vector<int> text2prompt(cgemma::session* sess, const char* text) { constexpr const char model_sot[] = "<start_of_turn>model\n"; constexpr const char eot[] = "<end_of_turn>\n"; std::string s; - if (sess->inst()->args().ModelTraining() == gcpp::ModelTraining::GEMMA_IT) { + if (sess->inst()->args().ModelTrainingType() == gcpp::ModelTraining::GEMMA_IT) { s.reserve(sizeof(eot) - 1 + sizeof(user_sot) - 1 + std::strlen(text) @@ -169,22 +169,16 @@ class kv_cache_size_store { std::array<size_t, static_cast<size_t>(kv_cache_field::end)> store_; }; -kv_cache_size_store kv_cache_size(gcpp::Model type, size_t pos) { - switch (type) { - case gcpp::Model::GEMMA_2B: - return kv_cache_size_store(gcpp::ConfigGemma2B(), pos); - case gcpp::Model::GEMMA_7B: - return kv_cache_size_store(gcpp::ConfigGemma7B(), pos); - case gcpp::Model::GRIFFIN_2B: - return kv_cache_size_store(gcpp::ConfigGriffin2B(), pos); - default: - throw std::invalid_argument("Invalid model type."); +template <class Config> +struct kv_cache_size { + kv_cache_size_store operator()(size_t pos) { + return kv_cache_size_store(Config(), pos); } -} +}; size_t dump_impl(char* buf, const cgemma::session* sess) { auto type = sess->inst()->args().ModelType(); - auto size = kv_cache_size(type, sess->pos()); + auto size = gcpp::CallForModel<void, kv_cache_size>(type, sess->pos()); uint16_t pos = sess->pos(); if (buf) { std::memcpy(buf, name, sizeof(name) - 1); @@ -223,7 +217,7 @@ void load_impl(cgemma::session* sess, const char* buf, size_t n) { buf += sizeof(name); size_t pos = *reinterpret_cast<const uint16_t*>(buf); buf += sizeof(uint16_t); - auto size = kv_cache_size(type, pos); + auto size = gcpp::CallForModel<void, kv_cache_size>(type, pos); if (n != sizeof(name) + sizeof(uint16_t) + size.total()) { throw std::invalid_argument("Invalid dump format: KVCache length mismatch"); }