Skip to content

Commit

Permalink
Adapt to PerClusterPools APIs of gemma.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ufownl committed Aug 6, 2024
1 parent e1fd9a1 commit 525f6b5
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 34 deletions.
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,24 @@ Available options:

#### cgemma.scheduler

**syntax:** `<cgemma.scheduler>sched, <string>err = cgemma.scheduler([<number>num_threads])`
**syntax:** `<cgemma.scheduler>sched, <string>err = cgemma.scheduler([<number>max_threads, <number>max_clusters])`

Create a scheduler instance.

A successful call returns a scheduler instance. Otherwise, it returns `nil` and a string describing the error.

The only parameter `num_threads` indicates the number of threads in the internal thread pool. If not provided or `num_threads <= 0`, it will create a default scheduler with the number of threads depending on the concurrent threads supported by the implementation.
Available parameters:

#### cgemma.scheduler.pin_threads
| Parameter | Description |
| --------- | ----------- |
| max_threads | Maximum number of threads to use. (default: `0` means unlimited) |
| max_clusters | Maximum number of sockets/CCXs to use. (default: `0` means unlimited) |

**syntax:** `sched:pin_threads()`
#### cgemma.scheduler.cpu_topology

Pin the scheduler's threads to logical processors.
**syntax:** `<table>clusters = sched:cpu_topology()`

Query CPU topology.

#### cgemma.instance.disabled_tokens

Expand Down
11 changes: 5 additions & 6 deletions src/instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,16 @@ int disabled_tokens(lua_State* L) {

namespace cgemma {

instance::instance(int argc, char* argv[], scheduler* s)
: args_(argc, argv)
, sched_(s) {
instance::instance(int argc, char* argv[], scheduler* sched)
: args_(argc, argv) {
if (auto err = args_.Validate()) {
throw std::invalid_argument(err);
}
if (!s) {
if (!sched) {
default_sched_ = std::make_unique<scheduler>();
sched_ = default_sched_.get();
sched = default_sched_.get();
}
model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.Info(), sched_->pool());
model_ = std::make_unique<gcpp::Gemma>(args_.tokenizer, args_.weights, args_.Info(), sched->pools());
}

void instance::declare(lua_State* L) {
Expand Down
2 changes: 0 additions & 2 deletions src/instance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class instance {
explicit instance(int argc, char* argv[], scheduler* s);

const gcpp::LoaderArgs& args() const { return args_; }
scheduler& sched() const { return *sched_; }
gcpp::Gemma& model() const { return *model_; }
const std::unordered_set<int>& disabled_tokens() const { return disabled_tokens_; }

Expand All @@ -27,7 +26,6 @@ class instance {

private:
gcpp::LoaderArgs args_;
scheduler* sched_;
std::unique_ptr<scheduler> default_sched_;
std::unique_ptr<gcpp::Gemma> model_;
std::unordered_set<int> disabled_tokens_;
Expand Down
38 changes: 23 additions & 15 deletions src/scheduler.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
#include "scheduler.hpp"
#include <util/app.h>
#include <algorithm>
#include <thread>

namespace {

constexpr const char name[] = "cgemma.scheduler";

int pin_threads(lua_State* L) {
gcpp::PinWorkersToCores(cgemma::scheduler::check(L, 1)->pool());
return 0;
int cpu_topology(lua_State* L) {
auto sched = cgemma::scheduler::check(L, 1);
lua_newtable(L);
size_t i = 0;
for (auto& cluster: sched->pools().CoresPerCluster()) {
lua_pushinteger(L, ++i);
lua_newtable(L);
size_t j = 0;
cluster.Foreach([&](size_t cpu) {
lua_pushinteger(L, ++j);
lua_pushinteger(L, cpu);
lua_settable(L, -3);
});
lua_settable(L, -3);
}
return 1;
}

int destroy(lua_State* L) {
Expand All @@ -22,12 +33,12 @@ int destroy(lua_State* L) {
namespace cgemma {

scheduler::scheduler()
: pool_(std::min(static_cast<size_t>(std::thread::hardware_concurrency()), gcpp::kMaxThreads)) {
: pools_(0, 0) {
// nop
}

scheduler::scheduler(size_t num_threads)
: pool_(std::min(num_threads, gcpp::kMaxThreads)) {
scheduler::scheduler(size_t max_threads, size_t max_clusters)
: pools_(max_clusters, max_threads) {
// nop
}

Expand All @@ -37,7 +48,7 @@ void scheduler::declare(lua_State* L) {
{nullptr, nullptr}
};
constexpr const luaL_Reg methods[] = {
{"pin_threads", pin_threads},
{"cpu_topology", cpu_topology},
{nullptr, nullptr}
};
luaL_newmetatable(L, name);
Expand All @@ -62,14 +73,11 @@ scheduler* scheduler::check(lua_State* L, int index) {
}

int scheduler::create(lua_State* L) {
auto num_threads = lua_tointeger(L, 1);
auto max_threads = lua_tointeger(L, 1);
auto max_clusters = lua_tointeger(L, 2);
auto ud = lua_newuserdata(L, sizeof(scheduler));
try {
if (num_threads > 0) {
new(ud) scheduler(num_threads);
} else {
new(ud) scheduler();
}
new(ud) scheduler(max_threads, max_clusters);
luaL_getmetatable(L, name);
lua_setmetatable(L, -2);
return 1;
Expand Down
8 changes: 4 additions & 4 deletions src/scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@
#define CGEMMA_SCHEDULER_HPP

#include <lua.hpp>
#include <hwy/contrib/thread_pool/thread_pool.h>
#include <util/threading.h>

namespace cgemma {

class scheduler {
public:
scheduler();
explicit scheduler(size_t num_threads);
explicit scheduler(size_t max_threads, size_t max_clusters);

hwy::ThreadPool& pool() { return pool_; }
gcpp::PerClusterPools& pools() { return pools_; }

static void declare(lua_State* L);
static scheduler* to(lua_State* L, int index);
static scheduler* check(lua_State* L, int index);
static int create(lua_State* L);

private:
hwy::ThreadPool pool_;
gcpp::PerClusterPools pools_;
};

}
Expand Down
5 changes: 3 additions & 2 deletions tools/cache_prompt.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ if args.help then
print("Usage: cat /path/to/prompt.txt | resty cache_prompt.lua [options]")
print()
print("Available options:")
print(" --num_threads: Number of threads in scheduler. (default: hardware concurrency)")
print(" --max_threads: Maximum number of threads to use. (default: 0 means unlimited)")
print(" --max_clusters: Maximum number of sockets/CCXs to use. (default: 0 means unlimited)")
print(" --tokenizer: Path of tokenizer model file. (default: tokenizer.spm)")
print(" --model: Model type (default: gemma2-2b-pt)")
print(" 2b-it = Gemma 2B parameters, instruction-tuned")
Expand All @@ -42,7 +43,7 @@ if args.help then
end

-- Create a scheduler instance
local sched, err = require("cgemma").scheduler(tonumber(args.num_threads))
local sched, err = require("cgemma").scheduler(tonumber(args.max_threads), tonumber(args.max_clusters))
if not sched then
print("Opoos! ", err)
return
Expand Down

0 comments on commit 525f6b5

Please sign in to comment.