Skip to content

Commit

Permalink
Update the tool and demo according to the changes in the scheduler
Browse files (browse the repository at this point in the history)
  • Loading branch information
ufownl committed Oct 20, 2024
1 parent f3812a2 commit 6f028c0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion demo/cgemma_demo.conf
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ http {
include mime.types;

init_worker_by_lua_block {
local sched, err = require("cgemma").scheduler(1, 1, 0)
local sched, err = require("cgemma").scheduler({num_threads = 2})
if not sched then
ngx.log(ngx.ERR, "cgemma error: ", err)
end
Expand Down
22 changes: 18 additions & 4 deletions tools/cache_prompt.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ if args.help then
print("Usage: cat /path/to/prompt.txt | resty cache_prompt.lua [options]")
print()
print("Available options:")
print(" --max_threads: Maximum number of threads to use, 0 = unlimited. (default: 0)")
print(" --max_clusters: Maximum number of sockets/CCXs to use, 0 = unlimited. (default: 0)")
print(" --pin_threads: Pin threads? -1 = auto, 0 = no, 1 = yes. (default: -1)")
print(" --tokenizer: Path of tokenizer model file. (default: tokenizer.spm)")
print(" --model: Model type (default: gemma2-2b-pt)")
print(" 2b-it = Gemma 2B parameters, instruction-tuned")
Expand All @@ -38,12 +35,29 @@ if args.help then
print(" --weight_type: Weight type (default: sfp)")
print(" --prefill_tbatch: Maximum batch size during prefill phase (default: 64)")
print(" --output: Path of output file. (default: dump.bin)")
print(" --num_threads: Maximum number of threads to use, 0 = unlimited. (default: 0)")
print(" --pin: Pin threads? -1 = auto, 0 = no, 1 = yes. (default: -1)")
print(" --skip_packages: Index of the first socket to use, 0 = unlimited. (default: 0)")
print(" --max_packages: Maximum number of sockets to use, 0 = unlimited. (default: 0)")
print(" --skip_clusters: Index of the first CCX to use, 0 = unlimited. (default: 0)")
print(" --max_clusters: Maximum number of CCXs to use, 0 = unlimited. (default: 0)")
print(" --skip_lps: Index of the first LP to use, 0 = unlimited. (default: 0)")
print(" --max_lps: Maximum number of LPs to use, 0 = unlimited. (default: 0)")
print(" --stats: Print statistics at end.")
return
end

-- Create a scheduler instance
local sched, err = require("cgemma").scheduler(tonumber(args.max_threads), tonumber(args.max_clusters), tonumber(args.pin_threads))
local sched, err = require("cgemma").scheduler({
num_threads = args.num_threads,
pin = args.pin,
skip_packages = args.skip_packages,
max_packages = args.max_packages,
skip_clusters = args.skip_clusters,
max_clusters = args.max_clusters,
skip_lps = args.skip_lps,
max_lps = args.max_lps
})
if not sched then
print("Opoos! ", err)
return
Expand Down

0 comments on commit 6f028c0

Please sign in to comment.