Skip to content

Commit

Permalink
Fix llama_cpp and Llama type signatures. Closes ggerganov#221
Browse files Browse the repository at this point in the history
  • Loading branch information
abetlen authored and Don Mahurin committed May 31, 2023
1 parent 18eca89 commit f2e4aff
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions examples/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def llama_free(ctx: llama_context_p):
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
def llama_model_quantize(
fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
) -> c_int:
) -> int:
return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)


Expand All @@ -225,7 +225,7 @@ def llama_apply_lora_from_file(
path_lora: c_char_p,
path_base_model: c_char_p,
n_threads: c_int,
) -> c_int:
) -> int:
return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)


Expand All @@ -234,7 +234,7 @@ def llama_apply_lora_from_file(


# Returns the number of tokens in the KV cache
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
return _lib.llama_get_kv_cache_token_count(ctx)


Expand All @@ -253,7 +253,7 @@ def llama_set_rng_seed(ctx: llama_context_p, seed: c_int):

# Returns the maximum size in bytes of the state (rng, logits, embedding
# and kv_cache) - will often be smaller after compacting tokens
def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
def llama_get_state_size(ctx: llama_context_p) -> int:
return _lib.llama_get_state_size(ctx)


Expand Down Expand Up @@ -293,7 +293,7 @@ def llama_load_session_file(
tokens_out, # type: Array[llama_token]
n_token_capacity: c_size_t,
n_token_count_out, # type: _Pointer[c_size_t]
) -> c_size_t:
) -> int:
return _lib.llama_load_session_file(
ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
)
Expand All @@ -314,7 +314,7 @@ def llama_save_session_file(
path_session: bytes,
tokens, # type: Array[llama_token]
n_token_count: c_size_t,
) -> c_size_t:
) -> int:
return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)


Expand All @@ -337,7 +337,7 @@ def llama_eval(
n_tokens: c_int,
n_past: c_int,
n_threads: c_int,
) -> c_int:
) -> int:
return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)


Expand All @@ -364,23 +364,23 @@ def llama_tokenize(
_lib.llama_tokenize.restype = c_int


def llama_n_vocab(ctx: llama_context_p) -> c_int:
def llama_n_vocab(ctx: llama_context_p) -> int:
return _lib.llama_n_vocab(ctx)


_lib.llama_n_vocab.argtypes = [llama_context_p]
_lib.llama_n_vocab.restype = c_int


def llama_n_ctx(ctx: llama_context_p) -> c_int:
def llama_n_ctx(ctx: llama_context_p) -> int:
return _lib.llama_n_ctx(ctx)


_lib.llama_n_ctx.argtypes = [llama_context_p]
_lib.llama_n_ctx.restype = c_int


def llama_n_embd(ctx: llama_context_p) -> c_int:
def llama_n_embd(ctx: llama_context_p) -> int:
return _lib.llama_n_embd(ctx)


Expand Down Expand Up @@ -426,23 +426,23 @@ def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes:
# Special tokens


def llama_token_bos() -> llama_token:
def llama_token_bos() -> int:
return _lib.llama_token_bos()


_lib.llama_token_bos.argtypes = []
_lib.llama_token_bos.restype = llama_token


def llama_token_eos() -> llama_token:
def llama_token_eos() -> int:
return _lib.llama_token_eos()


_lib.llama_token_eos.argtypes = []
_lib.llama_token_eos.restype = llama_token


def llama_token_nl() -> llama_token:
def llama_token_nl() -> int:
return _lib.llama_token_nl()


Expand Down Expand Up @@ -625,7 +625,7 @@ def llama_sample_token_mirostat(
eta: c_float,
m: c_int,
mu, # type: _Pointer[c_float]
) -> llama_token:
) -> int:
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)


Expand All @@ -651,7 +651,7 @@ def llama_sample_token_mirostat_v2(
tau: c_float,
eta: c_float,
mu, # type: _Pointer[c_float]
) -> llama_token:
) -> int:
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)


Expand All @@ -669,7 +669,7 @@ def llama_sample_token_mirostat_v2(
def llama_sample_token_greedy(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
) -> llama_token:
) -> int:
return _lib.llama_sample_token_greedy(ctx, candidates)


Expand All @@ -684,7 +684,7 @@ def llama_sample_token_greedy(
def llama_sample_token(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
) -> llama_token:
) -> int:
return _lib.llama_sample_token(ctx, candidates)


Expand Down

0 comments on commit f2e4aff

Please sign in to comment.