Skip to content

Commit

Permalink
Add support for new single-file weight format
Browse files Browse the repository at this point in the history
  • Loading branch information
ufownl committed Jan 17, 2025
1 parent bd60906 commit 2e9f9e2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ Available options:

```lua
{
- tokenizer = "/path/to/tokenizer.spm", -- Path of tokenizer model file. (required)
+ tokenizer = "/path/to/tokenizer.spm", -- Path of tokenizer model file.
model = "gemma2-2b-it", -- Model type:
-- 2b-it (Gemma 2B parameters, instruction-tuned),
-- 2b-pt (Gemma 2B parameters, pretrained),
Expand All @@ -193,7 +193,6 @@ Available options:
-- paligemma2-3b-448 (PaliGemma2 3B 448*448),
-- paligemma2-10b-224 (PaliGemma2 10B 224*224),
-- paligemma2-10b-448 (PaliGemma2 10B 448*448),
- -- (required)
weights = "/path/to/2.0-2b-it-sfp.sbs", -- Path of model weights file. (required)
weight_type = "sfp", -- Weight type:
-- sfp (8-bit FP, default)
Expand Down
10 changes: 7 additions & 3 deletions src/instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ instance::instance(int argc, char* argv[], unsigned int seed, scheduler* sched)
default_sched_ = std::make_unique<scheduler>();
sched = default_sched_.get();
}
- model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.Info(), sched->pools());
+ if (args_.Info().weight == gcpp::Type::kUnknown || args_.Info().model == gcpp::Model::UNKNOWN || args_.tokenizer.path.empty()) {
+ model_ = std::make_unique<gcpp::Gemma>(args_.weights, sched->pools());
+ } else {
+ model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.Info(), sched->pools());
+ }
}

void instance::declare(lua_State* L) {
Expand Down Expand Up @@ -73,9 +77,9 @@ instance* instance::check(lua_State* L, int index) {

int instance::create(lua_State* L) {
luaL_checktype(L, 1, LUA_TTABLE);
- constexpr const char* required_options[] = {"--tokenizer", "--model", "--weights"};
+ constexpr const char* required_options[] = {"--weights"};
constexpr const int n = sizeof(required_options) / sizeof(required_options[0]);
- constexpr const char* optional_options[] = {"--weight_type"};
+ constexpr const char* optional_options[] = {"--tokenizer", "--model", "--weight_type"};
constexpr const int m = sizeof(optional_options) / sizeof(optional_options[0]);
char* argv[(n + m) * 2 + 1] = {const_cast<char*>("lua-cgemma")};
for (int i = 0; i < n; ++i) {
Expand Down

0 comments on commit 2e9f9e2

Please sign in to comment.