Skip to content

Commit

Permalink
Adapt to changes in gemma.cpp
Browse files Browse the repository at this point in the history
Add an optional `weight_type` option to the `cgemma.new` method.
  • Loading branch information
ufownl committed Jun 8, 2024
1 parent a220ea8 commit 0ff8097
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 19 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ Available options:
-- gr2b-pt (griffin 2B parameters, pretrained).
-- (required)
weights = "/path/to/2b-it-sfp.sbs", -- Path of model weights file. (required)
weight_type = "sfp", -- Weight type:
-- sfp (8-bit FP, default)
-- f32 (float)
-- bf16 (bfloat16)
scheduler = sched_inst, -- Instance of scheduler, if not provided a default
-- scheduler will be attached.
}
Expand Down
2 changes: 0 additions & 2 deletions src/cgemma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ int info(lua_State* L) {
<< "Date & Time : " << std::put_time(std::localtime(&now), "%F %T") << std::endl
<< "Max Sequence Length : " << gcpp::kSeqLen << std::endl
<< "Top-K : " << gcpp::kTopK << std::endl
<< "Weight Type : " << gcpp::TypeName(gcpp::GemmaWeightT()) << std::endl
<< "Embedder Input Type : " << gcpp::TypeName(gcpp::EmbedderInputT()) << std::endl
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize << std::endl
<< "Hardware Concurrency : " << std::thread::hardware_concurrency() << std::endl
<< "Instruction Set : " << hwy::TargetName(hwy::DispatchedTarget()) << " (" << hwy::VectorBytes() * 8 << " bits)" << std::endl
Expand Down
19 changes: 16 additions & 3 deletions src/instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ instance::instance(int argc, char* argv[], scheduler* s)
default_sched_ = std::make_unique<scheduler>();
sched_ = default_sched_.get();
}
model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.ModelType(), sched_->pool());
model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.ModelType(), args_.WeightType(), sched_->pool());
}

void instance::declare(lua_State* L) {
Expand Down Expand Up @@ -61,7 +61,9 @@ int instance::create(lua_State* L) {
lua_pop(L, 1);
constexpr const char* required_options[] = {"--tokenizer", "--model", "--weights"};
constexpr const int n = sizeof(required_options) / sizeof(required_options[0]);
char* argv[n * 2 + 1] = {const_cast<char*>("lua-cgemma")};
constexpr const char* optional_options[] = {"--weight_type"};
constexpr const int m = sizeof(optional_options) / sizeof(optional_options[0]);
char* argv[(n + m) * 2 + 1] = {const_cast<char*>("lua-cgemma")};
for (int i = 0; i < n; ++i) {
auto k = required_options[i] + 2;
lua_getfield(L, 1, k);
Expand All @@ -73,9 +75,20 @@ int instance::create(lua_State* L) {
argv[i * 2 + 2] = const_cast<char*>(v);
lua_pop(L, 1);
}
auto argc = n * 2 + 1;
for (auto opt: optional_options) {
auto k = opt + 2;
lua_getfield(L, 1, k);
auto v = lua_tostring(L, -1);
if (v) {
argv[argc++] = const_cast<char*>(opt);
argv[argc++] = const_cast<char*>(v);
}
lua_pop(L, 1);
}
auto ud = lua_newuserdata(L, sizeof(instance));
try {
new(ud) instance(n * 2 + 1, argv, s);
new(ud) instance(argc, argv, s);
luaL_getmetatable(L, name);
lua_setmetatable(L, -2);
return 1;
Expand Down
22 changes: 8 additions & 14 deletions src/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ std::vector<int> text2prompt(cgemma::session* sess, const char* text) {
constexpr const char model_sot[] = "<start_of_turn>model\n";
constexpr const char eot[] = "<end_of_turn>\n";
std::string s;
if (sess->inst()->args().ModelTraining() == gcpp::ModelTraining::GEMMA_IT) {
if (sess->inst()->args().ModelTrainingType() == gcpp::ModelTraining::GEMMA_IT) {
s.reserve(sizeof(eot) - 1
+ sizeof(user_sot) - 1
+ std::strlen(text)
Expand Down Expand Up @@ -169,22 +169,16 @@ class kv_cache_size_store {
std::array<size_t, static_cast<size_t>(kv_cache_field::end)> store_;
};

kv_cache_size_store kv_cache_size(gcpp::Model type, size_t pos) {
switch (type) {
case gcpp::Model::GEMMA_2B:
return kv_cache_size_store(gcpp::ConfigGemma2B(), pos);
case gcpp::Model::GEMMA_7B:
return kv_cache_size_store(gcpp::ConfigGemma7B(), pos);
case gcpp::Model::GRIFFIN_2B:
return kv_cache_size_store(gcpp::ConfigGriffin2B(), pos);
default:
throw std::invalid_argument("Invalid model type.");
// Functor that computes the KV-cache size layout for a given model Config at
// position `pos`. Replaces the former switch-on-gcpp::Model free function:
// callers now dispatch via gcpp::CallForModel<void, kv_cache_size>(type, pos)
// (see dump_impl/load_impl below), so each supported Config instantiates its
// own specialization instead of being enumerated by hand.
// NOTE(review): this span is reconstructed from a diff scrape with +/- markers
// stripped; one of the closing braces below appears to belong to the removed
// old function — confirm against the actual src/session.cpp.
template <class Config>
struct kv_cache_size {
kv_cache_size_store operator()(size_t pos) {
// Default-construct the Config tag type; kv_cache_size_store derives all
// per-field sizes from it (presumably layer count / dims — verify there).
return kv_cache_size_store(Config(), pos);
}
}
};

size_t dump_impl(char* buf, const cgemma::session* sess) {
auto type = sess->inst()->args().ModelType();
auto size = kv_cache_size(type, sess->pos());
auto size = gcpp::CallForModel<void, kv_cache_size>(type, sess->pos());
uint16_t pos = sess->pos();
if (buf) {
std::memcpy(buf, name, sizeof(name) - 1);
Expand Down Expand Up @@ -223,7 +217,7 @@ void load_impl(cgemma::session* sess, const char* buf, size_t n) {
buf += sizeof(name);
size_t pos = *reinterpret_cast<const uint16_t*>(buf);
buf += sizeof(uint16_t);
auto size = kv_cache_size(type, pos);
auto size = gcpp::CallForModel<void, kv_cache_size>(type, pos);
if (n != sizeof(name) + sizeof(uint16_t) + size.total()) {
throw std::invalid_argument("Invalid dump format: KVCache length mismatch");
}
Expand Down

0 comments on commit 0ff8097

Please sign in to comment.