From 3f3c329f76293a657558fc7c81ba3088b11c9a74 Mon Sep 17 00:00:00 2001 From: Harrison Date: Wed, 10 Jan 2024 03:05:01 +1100 Subject: [PATCH] Smaller thread group sizes --- godot-rwkv.h | 63 ++++++++++++++++++++++++++-------------------------- rwkv.cuh | 2 +- 2 files changed, 32 insertions(+), 33 deletions(-) diff --git a/godot-rwkv.h b/godot-rwkv.h index 513aa24..86025f0 100644 --- a/godot-rwkv.h +++ b/godot-rwkv.h @@ -107,14 +107,13 @@ class Agent : public Resource { protected: static void _bind_methods() { - ClassDB::bind_method(D_METHOD("add_context"), &Agent::add_context); - ClassDB::bind_method(D_METHOD("generate"), &Agent::generate); - ClassDB::bind_method(D_METHOD("set_temperature"), &Agent::set_temperature); - ClassDB::bind_method(D_METHOD("set_tau"), &Agent::set_tau); - ClassDB::bind_method(D_METHOD("set_stop_sequences"), &Agent::set_stop_sequences); + ClassDB::bind_method(D_METHOD("add_context", "Context"), &Agent::add_context); + ClassDB::bind_method(D_METHOD("generate", "Tokens"), &Agent::generate); + ClassDB::bind_method(D_METHOD("set_temperature", "Temp"), &Agent::set_temperature); + ClassDB::bind_method(D_METHOD("set_tau", "Tau"), &Agent::set_tau); + ClassDB::bind_method(D_METHOD("set_stop_sequences", "Sequences"), &Agent::set_stop_sequences); ClassDB::bind_method(D_METHOD("set_last_token"), &Agent::set_last_token); ClassDB::bind_method(D_METHOD("get_context"), &Agent::get_context); - ClassDB::bind_method(D_METHOD("get_last_token"), &Agent::get_last_token); ClassDB::bind_method(D_METHOD("get_max_queued_tokens"), &Agent::get_max_queued_tokens); } @@ -130,16 +129,19 @@ class GodotRWKV : public Resource { RWKV* model = nullptr; RWKVTokenizer* tokenizer = nullptr; size_t lastToken = 187; - size_t max_agents = 50; std::vector agents = {}; GodotRWKV() { } - void loadModel(String path, int max_batch = 50) { - max_agents = (size_t)max_batch; - // model.loadFile(std::string(path.utf8().get_data())); - model = new RWKV(std::string(path.utf8().get_data())); + void loadModel(String path, size_t NumThreads = 8) { + + if (NumThreads != 1 && NumThreads != 2 && NumThreads != 4 && NumThreads != 8) { + ERR_PRINT("NumThreads must be 8, 4, 2 or 1"); + return; + } + + model = new RWKV(std::string(path.utf8().get_data()), NumThreads); }; void loadTokenizer(String path) { @@ -162,21 +164,19 @@ class GodotRWKV : public Resource { model->set_state(agents[i]->state, 0); std::cout << "state set" << std::endl; - auto maxBatchSeqSize = max_agents; + // process tokens in batches of maxBatchSeqSize - for (size_t oi = 0; oi < tokens.size()-1; oi += maxBatchSeqSize) { - auto tokensBatch = std::vector(); - tokensBatch.push_back(agents[i]->last_token); - for (size_t j = oi; j < MIN(oi + maxBatchSeqSize, tokens.size()-1); j++) { - tokensBatch.push_back(tokens[j]); - } - std::cout << "tokensBatch: " << oi << std::endl; - auto outputs = (*model)({tokensBatch}); - if (oi + maxBatchSeqSize >= tokens.size()-1) { - agents[i]->last_token = tokens[tokens.size()-1]; - } + auto tokensBatch = std::vector(); + tokensBatch.push_back(agents[i]->last_token); + for (size_t j = 0; j < tokens.size()-1; j++) { + tokensBatch.push_back(tokens[j]); } + auto outputs = (*model)({tokensBatch}); + + agents[i]->last_token = tokens[tokens.size()-1]; + + std::cout << "context processed" << std::endl; agents[i]->add_context_queue = ""; @@ -253,22 +253,21 @@ class GodotRWKV : public Resource { }; Variant createAgent() { - if (agents.size() < max_agents) { - Agent *agent = new Agent(model, tokenizer); - agents.push_back(agent); - return Variant(agent); - } - print_error("max agents reached"); - return Variant(); + Agent *agent = new Agent(model, tokenizer); + agents.push_back(agent); + return Variant(agent); + + + } protected: static void _bind_methods() { ClassDB::bind_method(D_METHOD("listen"), &GodotRWKV::listen); - ClassDB::bind_method(D_METHOD("loadModel"), &GodotRWKV::loadModel); - ClassDB::bind_method(D_METHOD("loadTokenizer"), &GodotRWKV::loadTokenizer); + ClassDB::bind_method(D_METHOD("loadModel", "Path", "Threads"), &GodotRWKV::loadModel, DEFVAL(4)); + ClassDB::bind_method(D_METHOD("loadTokenizer", "Path"), &GodotRWKV::loadTokenizer); ClassDB::bind_method(D_METHOD("createAgent"), &GodotRWKV::createAgent); } }; diff --git a/rwkv.cuh b/rwkv.cuh index 5f0fc49..b4e275b 160000 --- a/rwkv.cuh +++ b/rwkv.cuh @@ -1 +1 @@ -Subproject commit 5f0fc4987458af52fb7f783fe380c1cde4a8289f +Subproject commit b4e275b61d76db8d5c08fa7af6a2d279abd1d1be