From 0ff80973ae878592f46c24f22df463deb3b8775b Mon Sep 17 00:00:00 2001
From: RangerUFO <ufownl@gmail.com>
Date: Sat, 8 Jun 2024 19:46:40 +0800
Subject: [PATCH] Adapt to changes in gemma.cpp

gemma.cpp now takes the weight type as a runtime constructor argument, so pass
`args_.WeightType()` to `gcpp::Gemma` and expose it as an optional
`weight_type` option of the `cgemma.new` method (`sfp` by default, or
`f32`/`bf16`).

Other adaptations:

- Use `ModelTrainingType()` instead of `ModelTraining()`.
- Dispatch `kv_cache_size` per model via `gcpp::CallForModel` instead of a
  hand-written switch over model types.
- Drop the compile-time Weight Type and Embedder Input Type lines from the
  `info` output.
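
A minimal usage sketch of the new option, assuming the module is loaded as
`require("cgemma")` and following the `cgemma.new` options table from the
README (paths and the model name are placeholders):

    local cgemma = require("cgemma")
    local gemma = cgemma.new{
      tokenizer = "/path/to/tokenizer.spm",
      model = "2b-it",
      weights = "/path/to/2b-it-sfp.sbs",
      weight_type = "sfp",  -- optional; omitted means the default ("sfp")
    }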
---
 README.md        |  4 ++++
 src/cgemma.cpp   |  2 --
 src/instance.cpp | 19 ++++++++++++++++---
 src/session.cpp  | 22 ++++++++--------------
 4 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index 2f62601..2c0e53a 100644
--- a/README.md
+++ b/README.md
@@ -106,6 +106,10 @@ Available options:
                     -- gr2b-pt (griffin 2B parameters, pretrained).
                     -- (required)
   weights = "/path/to/2b-it-sfp.sbs",  -- Path of model weights file. (required)
+  weight_type = "sfp",  -- Weight type:
+                        -- sfp (8-bit FP, default)
+                        -- f32 (float)
+                        -- bf16 (bfloat16)
   scheduler = sched_inst,  -- Instance of scheduler, if not provided a default
                            -- scheduler will be attached.
 }
diff --git a/src/cgemma.cpp b/src/cgemma.cpp
index 2d9f578..5d6954d 100644
--- a/src/cgemma.cpp
+++ b/src/cgemma.cpp
@@ -26,8 +26,6 @@ int info(lua_State* L) {
     << "Date & Time                                : " << std::put_time(std::localtime(&now), "%F %T") << std::endl
     << "Max Sequence Length                        : " << gcpp::kSeqLen << std::endl
     << "Top-K                                      : " << gcpp::kTopK << std::endl
-    << "Weight Type                                : " << gcpp::TypeName(gcpp::GemmaWeightT()) << std::endl
-    << "Embedder Input Type                        : " << gcpp::TypeName(gcpp::EmbedderInputT()) << std::endl
     << "Prefill Token Batch Size                   : " << gcpp::kPrefillBatchSize << std::endl
     << "Hardware Concurrency                       : " << std::thread::hardware_concurrency() << std::endl
     << "Instruction Set                            : " << hwy::TargetName(hwy::DispatchedTarget()) << " (" << hwy::VectorBytes() * 8 << " bits)" << std::endl
diff --git a/src/instance.cpp b/src/instance.cpp
index ba630de..9058919 100644
--- a/src/instance.cpp
+++ b/src/instance.cpp
@@ -26,7 +26,7 @@ instance::instance(int argc, char* argv[], scheduler* s)
     default_sched_ = std::make_unique<scheduler>();
     sched_ = default_sched_.get();
   }
-  model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.ModelType(), sched_->pool());
+  model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.ModelType(), args_.WeightType(), sched_->pool());
 }
 
 void instance::declare(lua_State* L) {
@@ -61,7 +61,9 @@ int instance::create(lua_State* L) {
   lua_pop(L, 1);
   constexpr const char* required_options[] = {"--tokenizer", "--model", "--weights"};
   constexpr const int n = sizeof(required_options) / sizeof(required_options[0]);
-  char* argv[n * 2 + 1] = {const_cast<char*>("lua-cgemma")};
+  constexpr const char* optional_options[] = {"--weight_type"};
+  constexpr const int m = sizeof(optional_options) / sizeof(optional_options[0]);
+  char* argv[(n + m) * 2 + 1] = {const_cast<char*>("lua-cgemma")};
   for (int i = 0; i < n; ++i) {
     auto k = required_options[i] + 2;
     lua_getfield(L, 1, k);
@@ -73,9 +75,20 @@ int instance::create(lua_State* L) {
     argv[i * 2 + 2] = const_cast<char*>(v);
     lua_pop(L, 1);
   }
+  auto argc = n * 2 + 1;
+  for (auto opt: optional_options) {
+    auto k = opt + 2;
+    lua_getfield(L, 1, k);
+    auto v = lua_tostring(L, -1);
+    if (v) {
+      argv[argc++] = const_cast<char*>(opt);
+      argv[argc++] = const_cast<char*>(v);
+    }
+    lua_pop(L, 1);
+  }
   auto ud = lua_newuserdata(L, sizeof(instance));
   try {
-    new(ud) instance(n * 2 + 1, argv, s);
+    new(ud) instance(argc, argv, s);
     luaL_getmetatable(L, name);
     lua_setmetatable(L, -2);
     return 1;
diff --git a/src/session.cpp b/src/session.cpp
index d573e34..127a185 100644
--- a/src/session.cpp
+++ b/src/session.cpp
@@ -19,7 +19,7 @@ std::vector<int> text2prompt(cgemma::session* sess, const char* text) {
   constexpr const char model_sot[] = "<start_of_turn>model\n";
   constexpr const char eot[] = "<end_of_turn>\n";
   std::string s;
-  if (sess->inst()->args().ModelTraining() == gcpp::ModelTraining::GEMMA_IT) {
+  if (sess->inst()->args().ModelTrainingType() == gcpp::ModelTraining::GEMMA_IT) {
     s.reserve(sizeof(eot) - 1
             + sizeof(user_sot) - 1
             + std::strlen(text)
@@ -169,22 +169,16 @@ class kv_cache_size_store {
   std::array<size_t, static_cast<size_t>(kv_cache_field::end)> store_;
 };
 
-kv_cache_size_store kv_cache_size(gcpp::Model type, size_t pos) {
-  switch (type) {
-    case gcpp::Model::GEMMA_2B:
-      return kv_cache_size_store(gcpp::ConfigGemma2B(), pos);
-    case gcpp::Model::GEMMA_7B:
-      return kv_cache_size_store(gcpp::ConfigGemma7B(), pos);
-    case gcpp::Model::GRIFFIN_2B:
-      return kv_cache_size_store(gcpp::ConfigGriffin2B(), pos);
-    default:
-      throw std::invalid_argument("Invalid model type.");
+template <class Config>
+struct kv_cache_size {
+  kv_cache_size_store operator()(size_t pos) {
+    return kv_cache_size_store(Config(), pos);
   }
-}
+};
 
 size_t dump_impl(char* buf, const cgemma::session* sess) {
   auto type = sess->inst()->args().ModelType();
-  auto size = kv_cache_size(type, sess->pos());
+  auto size = gcpp::CallForModel<void, kv_cache_size>(type, sess->pos());
   uint16_t pos = sess->pos();
   if (buf) {
     std::memcpy(buf, name, sizeof(name) - 1);
@@ -223,7 +217,7 @@ void load_impl(cgemma::session* sess, const char* buf, size_t n) {
   buf += sizeof(name);
   size_t pos = *reinterpret_cast<const uint16_t*>(buf);
   buf += sizeof(uint16_t);
-  auto size = kv_cache_size(type, pos);
+  auto size = gcpp::CallForModel<void, kv_cache_size>(type, pos);
   if (n != sizeof(name) + sizeof(uint16_t) + size.total()) {
     throw std::invalid_argument("Invalid dump format: KVCache length mismatch");
   }