Add ingestion progress callback
stduhpf committed May 2, 2023
1 parent 6124751 commit 1ebd5ba
Showing 6 changed files with 18 additions and 11 deletions.
2 changes: 1 addition & 1 deletion include/bridge.hpp
@@ -61,7 +61,7 @@ namespace fastllama {
FastLlama& operator=(FastLlama &&) noexcept = default;
~FastLlama() { m_model.unload(); }

-bool ingest(std::string prompt, bool is_system_prompt = false);
+bool ingest(std::string prompt, std::function<void(size_t const&, size_t const&)> fn, bool is_system_prompt = false);
bool generate(
std::function<void(std::string const&)> fn,
std::size_t num_tokens,
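For reference, a minimal sketch of how a caller might use the new `ingest` signature is shown below. It assumes an already constructed `fastllama::FastLlama` instance (here named `bridge`, as in `src/alpaca.cpp`) and that the two `size_t` callback arguments are the number of tokens ingested so far and the total to ingest; the helper name `ingest_with_progress` and that interpretation are illustrative, not taken from this commit.

```cpp
// Illustrative sketch only: assumes the callback receives (tokens done, total tokens).
#include <cstddef>
#include <cstdio>
#include <string>
#include <utility>

#include "bridge.hpp"  // include/bridge.hpp from this repository

bool ingest_with_progress(fastllama::FastLlama& bridge, std::string prompt) {
    return bridge.ingest(std::move(prompt),
        [](std::size_t const& done, std::size_t const& total) {
            if (total != 0) {  // guard against division by zero on empty prompts
                std::printf("\ringesting: %zu/%zu (%.0f%%)", done, total, 100.0 * done / total);
                std::fflush(stdout);
            }
        });  // is_system_prompt defaults to false
}
```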
6 changes: 4 additions & 2 deletions interfaces/c/fastllama.h
@@ -98,21 +98,23 @@ bool llama_set_stop_words(struct llama_model_context* model_context, char const*
*
* @param model_context is the context that is constructed using `llama_create_context`.
* @param prompt is user string that will be processed and produce output.
+* @param progress_fn callback used to track the progress of the ingestion
* @return true if it successfully ingests the prompt.
* @return false if it is unable to ingest the prompt.
*/
-bool llama_ingest_system_prompt(struct llama_model_context* model_context, char const* prompt);
+bool llama_ingest_system_prompt(struct llama_model_context* model_context, char const* prompt, LLAMA_LOGGER_PROGRESS_FUNC progress_fn);

/**
* @brief Ingests the prompt that will not be preserved across memory reset to save memory.
* It is used for having conversation with the model.
*
* @param model_context is the context that is constructed using `llama_create_context`.
* @param prompt is user string that will be processed and produce output.
+* @param progress_fn callback used to track the progress of the ingestion
* @return true if it successfully ingests the prompt.
* @return false if it is unable to ingest the prompt.
*/
-bool llama_ingest(struct llama_model_context* model_context, char const* prompt);
+bool llama_ingest(struct llama_model_context* model_context, char const* prompt, LLAMA_LOGGER_PROGRESS_FUNC progress_fn);
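From the C interface side, a hedged sketch is given below (written as C++, matching `interfaces/c/main.cpp`; the same calls work from C). It assumes `LLAMA_LOGGER_PROGRESS_FUNC` is a plain `void (*)(size_t, size_t)` typedef, which is inferred from the Python binding's `CFUNCTYPE(None, c_size_t, c_size_t)` further down and is not confirmed by this diff, and that `ctx` was obtained from `llama_create_context` as the header comments describe.

```cpp
// Sketch under assumptions: LLAMA_LOGGER_PROGRESS_FUNC is taken to be
// void (*)(size_t, size_t); the (done, total) meaning of its arguments is
// inferred from the Python binding, not documented in this header.
#include <cstdio>

#include "fastllama.h"  // interfaces/c/fastllama.h

static void print_progress(size_t done, size_t total) {
    if (total != 0) {
        std::printf("\ringested %zu of %zu tokens", done, total);
        std::fflush(stdout);
    }
}

static bool c_ingest_with_progress(struct llama_model_context* ctx, char const* prompt) {
    // llama_ingest_system_prompt takes the same extra callback argument.
    return llama_ingest(ctx, prompt, print_progress);
}
```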

/**
* @brief Generate the model output from the previously ingested prompt or past conversation.
8 changes: 4 additions & 4 deletions interfaces/c/main.cpp
@@ -149,16 +149,16 @@ extern "C" {
return true;
}

-bool llama_ingest(struct llama_model_context* model_context, char const* prompt) {
+bool llama_ingest(struct llama_model_context* model_context, char const* prompt, LLAMA_LOGGER_PROGRESS_FUNC progress_fn) {
if (!is_model_valid(model_context)) return false;

-return model_context->inner->ingest(std::string(prompt), false);
+return model_context->inner->ingest(std::string(prompt), progress_fn, false);
}

-bool llama_ingest_system_prompt(struct llama_model_context* model_context, char const* prompt) {
+bool llama_ingest_system_prompt(struct llama_model_context* model_context, char const* prompt, LLAMA_LOGGER_PROGRESS_FUNC progress_fn) {
if (!is_model_valid(model_context)) return false;

-return model_context->inner->ingest(std::string(prompt), true);
+return model_context->inner->ingest(std::string(prompt), progress_fn, true);
}

bool llama_generate(
9 changes: 7 additions & 2 deletions interfaces/python/fastllama.py
@@ -321,22 +321,27 @@ def load_state(self, filepath: str) -> bool:
fn.restype = ctypes.c_bool
return bool(fn(self.ctx, bytes(filepath, 'utf-8')))

-def ingest(self, prompt: str, is_system_prompt: bool = False) -> bool:
+def ingest(self, prompt: str, progress_fn: Callable[[int, int], None] = lambda s, t: None, is_system_prompt: bool = False) -> bool:
"""
Ingests a prompt into the model.
:param prompt: The prompt to be ingested.
:param is_system_prompt: Flag to indicate if the prompt is a system prompt. Default is False.
:return: True if successful, False otherwise.
"""

+def callback_fn(s: ctypes.c_size_t, t: ctypes.c_size_t):
+    progress_fn(int(s), int(t))
+ctype_callback_fn = ctypes.CFUNCTYPE(None, ctypes.c_size_t, ctypes.c_size_t)
+
if is_system_prompt:
ingest_fn = self.lib.llama_ingest_system_prompt
else:
ingest_fn = self.lib.llama_ingest

ingest_fn.argtypes = [c_llama_model_context_ptr, ctypes.c_char_p]
ingest_fn.restype = ctypes.c_bool
-return bool(ingest_fn(self.ctx, bytes(prompt, 'utf-8')))
+return bool(ingest_fn(self.ctx, bytes(prompt, 'utf-8'), ctype_callback_fn(callback_fn)))

def generate(
self,
2 changes: 1 addition & 1 deletion lib/bridge.cpp
@@ -183,7 +183,7 @@ namespace fastllama {
return m_model.dump_vocab(filepath);
}

-bool FastLlama::ingest(std::string prompt, bool is_system_prompt) {
+bool FastLlama::ingest(std::string prompt, std::function<void(size_t const&, size_t const&)> fn, bool is_system_prompt) {
m_model.logger.reset();
if (!m_model.is_valid) {
m_model.logger.log_err("FastLlama::ingest", "tried to ingest using invalid model");
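The body of `FastLlama::ingest` is cut off in this view, so the diff does not show where `fn` is invoked. Purely as an illustration of how a `(done, total)` callback is typically driven from a batched evaluation loop, here is a hypothetical sketch; `tokenize_stub` and `eval_batch_stub` are stand-ins invented for the example and are not functions from this codebase.

```cpp
// Hypothetical illustration only -- the stubs below are NOT fastllama APIs;
// they stand in for the model's real tokenizer and evaluation step.
#include <algorithm>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

static std::vector<int> tokenize_stub(std::string const& s) { return std::vector<int>(s.size(), 0); }
static void eval_batch_stub(int const*, std::size_t) {}

void ingest_sketch(std::string const& prompt,
                   std::function<void(std::size_t const&, std::size_t const&)> fn) {
    std::vector<int> tokens = tokenize_stub(prompt);
    std::size_t const total = tokens.size();
    std::size_t const batch_size = 64;  // arbitrary batch size for the sketch
    for (std::size_t i = 0; i < total; i += batch_size) {
        std::size_t const n = std::min(batch_size, total - i);
        eval_batch_stub(tokens.data() + i, n);
        fn(i + n, total);  // report progress after each evaluated batch
    }
}
```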
2 changes: 1 addition & 1 deletion src/alpaca.cpp
@@ -33,7 +33,7 @@ int main() {

prompt = "### Instruction:\n\n" + prompt + "\n\n ### Response:\n\n";

-if (!bridge.ingest(prompt)) return 2;
+if (!bridge.ingest(prompt, [](size_t const& s, size_t const& t){})) return 2;

auto gen_res = bridge.generate([](std::string const& s) {
std::cout<<s;
