Skip to content

Commit

Permalink
Smaller thread group sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
harrisonvanderbyl committed Jan 9, 2024
1 parent 7b65fc5 commit 3f3c329
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 33 deletions.
63 changes: 31 additions & 32 deletions godot-rwkv.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,13 @@ class Agent : public Resource {

protected:
static void _bind_methods() {
ClassDB::bind_method(D_METHOD("add_context"), &Agent::add_context);
ClassDB::bind_method(D_METHOD("generate"), &Agent::generate);
ClassDB::bind_method(D_METHOD("set_temperature"), &Agent::set_temperature);
ClassDB::bind_method(D_METHOD("set_tau"), &Agent::set_tau);
ClassDB::bind_method(D_METHOD("set_stop_sequences"), &Agent::set_stop_sequences);
ClassDB::bind_method(D_METHOD("add_context", "Context"), &Agent::add_context);
ClassDB::bind_method(D_METHOD("generate", "Tokens"), &Agent::generate);
ClassDB::bind_method(D_METHOD("set_temperature", "Temp"), &Agent::set_temperature);
ClassDB::bind_method(D_METHOD("set_tau", "Tau"), &Agent::set_tau);
ClassDB::bind_method(D_METHOD("set_stop_sequences", "Sequences"), &Agent::set_stop_sequences);
ClassDB::bind_method(D_METHOD("set_last_token"), &Agent::set_last_token);
ClassDB::bind_method(D_METHOD("get_context"), &Agent::get_context);
ClassDB::bind_method(D_METHOD("get_last_token"), &Agent::get_last_token);
ClassDB::bind_method(D_METHOD("get_max_queued_tokens"), &Agent::get_max_queued_tokens);
}

Expand All @@ -130,16 +129,19 @@ class GodotRWKV : public Resource {
RWKV* model = nullptr;
RWKVTokenizer* tokenizer = nullptr;
size_t lastToken = 187;
size_t max_agents = 50;
std::vector<Agent*> agents = {};
GodotRWKV() {

}

void loadModel(String path, int max_batch = 50) {
max_agents = (size_t)max_batch;
// model.loadFile(std::string(path.utf8().get_data()));
model = new RWKV(std::string(path.utf8().get_data()));
// Loads the RWKV model found at `path` and stores it in `model`.
//
// path       - model file location; converted to a UTF-8 std::string for the RWKV constructor.
// NumThreads - forwarded to the RWKV constructor. Only 1, 2, 4 or 8 are accepted;
//              any other value is reported via ERR_PRINT and the call returns
//              before `model` is assigned.
//
// NOTE(review): the script binding for loadModel passes DEFVAL(4), while the C++
// default here is 8 — confirm which default is intended.
void loadModel(String path, size_t NumThreads = 8) {

// Reject unsupported thread counts up front so no model is constructed for them.
if (NumThreads != 1 && NumThreads != 2 && NumThreads != 4 && NumThreads != 8) {
ERR_PRINT("NumThreads must be 8, 4, 2 or 1");
return;
}

// NOTE(review): `model` is overwritten with a raw `new` and no release of a
// previously assigned instance is visible here — confirm loadModel is called only once.
model = new RWKV(std::string(path.utf8().get_data()), NumThreads);
};

void loadTokenizer(String path) {
Expand All @@ -162,21 +164,19 @@ class GodotRWKV : public Resource {
model->set_state(agents[i]->state, 0);
std::cout << "state set" << std::endl;

auto maxBatchSeqSize = max_agents;


// process tokens in batches of maxBatchSeqSize
for (size_t oi = 0; oi < tokens.size()-1; oi += maxBatchSeqSize) {
auto tokensBatch = std::vector<size_t>();
tokensBatch.push_back(agents[i]->last_token);
for (size_t j = oi; j < MIN(oi + maxBatchSeqSize, tokens.size()-1); j++) {
tokensBatch.push_back(tokens[j]);
}
std::cout << "tokensBatch: " << oi << std::endl;
auto outputs = (*model)({tokensBatch});
if (oi + maxBatchSeqSize >= tokens.size()-1) {
agents[i]->last_token = tokens[tokens.size()-1];
}
auto tokensBatch = std::vector<size_t>();
tokensBatch.push_back(agents[i]->last_token);
for (size_t j = 0; j < tokens.size()-1; j++) {
tokensBatch.push_back(tokens[j]);
}
auto outputs = (*model)({tokensBatch});

agents[i]->last_token = tokens[tokens.size()-1];


std::cout << "context processed" << std::endl;

agents[i]->add_context_queue = "";
Expand Down Expand Up @@ -253,22 +253,21 @@ class GodotRWKV : public Resource {
};

Variant createAgent() {
if (agents.size() < max_agents) {
Agent *agent = new Agent(model, tokenizer);
agents.push_back(agent);
return Variant(agent);
}

print_error("max agents reached");
return Variant();
Agent *agent = new Agent(model, tokenizer);
agents.push_back(agent);
return Variant(agent);




}

protected:
static void _bind_methods() {
ClassDB::bind_method(D_METHOD("listen"), &GodotRWKV::listen);
ClassDB::bind_method(D_METHOD("loadModel"), &GodotRWKV::loadModel);
ClassDB::bind_method(D_METHOD("loadTokenizer"), &GodotRWKV::loadTokenizer);
ClassDB::bind_method(D_METHOD("loadModel", "Path", "Threads"), &GodotRWKV::loadModel, DEFVAL(4));
ClassDB::bind_method(D_METHOD("loadTokenizer", "Path"), &GodotRWKV::loadTokenizer);
ClassDB::bind_method(D_METHOD("createAgent"), &GodotRWKV::createAgent);
}
};
Expand Down
2 changes: 1 addition & 1 deletion rwkv.cuh

0 comments on commit 3f3c329

Please sign in to comment.