From f2e4affb6a7a6496d5e463c6367d1d5611756d79 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 19 May 2023 11:59:33 -0400 Subject: [PATCH] Fix llama_cpp and Llama type signatures. Closes #221 --- examples/llama_cpp.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/llama_cpp.py b/examples/llama_cpp.py index a8f90f861c605..6bddadff3c5e4 100644 --- a/examples/llama_cpp.py +++ b/examples/llama_cpp.py @@ -206,7 +206,7 @@ def llama_free(ctx: llama_context_p): # nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given def llama_model_quantize( fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int -) -> c_int: +) -> int: return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread) @@ -225,7 +225,7 @@ def llama_apply_lora_from_file( path_lora: c_char_p, path_base_model: c_char_p, n_threads: c_int, -) -> c_int: +) -> int: return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads) @@ -234,7 +234,7 @@ def llama_apply_lora_from_file( # Returns the number of tokens in the KV cache -def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int: +def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int: return _lib.llama_get_kv_cache_token_count(ctx) @@ -253,7 +253,7 @@ def llama_set_rng_seed(ctx: llama_context_p, seed: c_int): # Returns the maximum size in bytes of the state (rng, logits, embedding # and kv_cache) - will often be smaller after compacting tokens -def llama_get_state_size(ctx: llama_context_p) -> c_size_t: +def llama_get_state_size(ctx: llama_context_p) -> int: return _lib.llama_get_state_size(ctx) @@ -293,7 +293,7 @@ def llama_load_session_file( tokens_out, # type: Array[llama_token] n_token_capacity: c_size_t, n_token_count_out, # type: _Pointer[c_size_t] -) -> c_size_t: +) -> int: return _lib.llama_load_session_file( ctx, path_session, tokens_out, n_token_capacity, n_token_count_out ) @@ -314,7 +314,7 @@ def llama_save_session_file( path_session: bytes, tokens, # type: Array[llama_token] n_token_count: c_size_t, -) -> c_size_t: +) -> int: return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count) @@ -337,7 +337,7 @@ def llama_eval( n_tokens: c_int, n_past: c_int, n_threads: c_int, -) -> c_int: +) -> int: return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads) @@ -364,7 +364,7 @@ def llama_tokenize( _lib.llama_tokenize.restype = c_int -def llama_n_vocab(ctx: llama_context_p) -> c_int: +def llama_n_vocab(ctx: llama_context_p) -> int: return _lib.llama_n_vocab(ctx) @@ -372,7 +372,7 @@ def llama_n_vocab(ctx: llama_context_p) -> c_int: _lib.llama_n_vocab.restype = c_int -def llama_n_ctx(ctx: llama_context_p) -> c_int: +def llama_n_ctx(ctx: llama_context_p) -> int: return _lib.llama_n_ctx(ctx) @@ -380,7 +380,7 @@ def llama_n_ctx(ctx: llama_context_p) -> c_int: _lib.llama_n_ctx.restype = c_int -def llama_n_embd(ctx: llama_context_p) -> c_int: +def llama_n_embd(ctx: llama_context_p) -> int: return _lib.llama_n_embd(ctx) @@ -426,7 +426,7 @@ def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes: # Special tokens -def llama_token_bos() -> llama_token: +def llama_token_bos() -> int: return _lib.llama_token_bos() @@ -434,7 +434,7 @@ def llama_token_bos() -> llama_token: _lib.llama_token_bos.restype = llama_token -def llama_token_eos() -> llama_token: +def llama_token_eos() -> int: return _lib.llama_token_eos() @@ -442,7 +442,7 @@ def llama_token_eos() -> llama_token: _lib.llama_token_eos.restype = llama_token -def llama_token_nl() -> llama_token: +def llama_token_nl() -> int: return _lib.llama_token_nl() @@ -625,7 +625,7 @@ def llama_sample_token_mirostat( eta: c_float, m: c_int, mu, # type: _Pointer[c_float] -) -> llama_token: +) -> int: return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu) @@ -651,7 +651,7 @@ def llama_sample_token_mirostat_v2( tau: c_float, eta: c_float, mu, # type: _Pointer[c_float] -) -> llama_token: +) -> int: return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu) @@ -669,7 +669,7 @@ def llama_sample_token_mirostat_v2( def llama_sample_token_greedy( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] -) -> llama_token: +) -> int: return _lib.llama_sample_token_greedy(ctx, candidates) @@ -684,7 +684,7 @@ def llama_sample_token_greedy( def llama_sample_token( ctx: llama_context_p, candidates, # type: _Pointer[llama_token_data_array] -) -> llama_token: +) -> int: return _lib.llama_sample_token(ctx, candidates)