Skip to content

Commit

Permalink
Merge pull request #11 from laelhalawani/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
laelhalawani authored Jan 11, 2024
2 parents daf7df7 + 56946b9 commit 0f30e90
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
6 changes: 0 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@ Generate text by calling infer():
text = ai.infer("Once upon a time")
print(text)
```
Adjust model tokens to fit longer prompts:
```python
big_prompt = "..." # prompt longer than max input tokens

text = ai.infer(big_prompt, max_tokens_if_needed=2000)
```
## Installation

```python
Expand Down
25 changes: 22 additions & 3 deletions gguf_llama/gguf_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class LlamaAI:
_llama_kwrgs (dict): Additional kwargs to pass when loading Llama model.
"""

def __init__(self, model_gguf_path:str, max_tokens:int, **llama_kwrgs:Any) -> None:
def __init__(self, model_gguf_path:str, max_tokens:int, **llama_kwargs:Any) -> None:
"""
Initialize the LlamaAI instance.
Expand All @@ -32,7 +32,8 @@ def __init__(self, model_gguf_path:str, max_tokens:int, **llama_kwrgs:Any) -> No
self.llm = None
self.tokenizer = None
self._loaded = False
self._llama_kwrgs = llama_kwrgs
self._llama_kwrgs = llama_kwargs
self._embeddings_mode = True
self.load()


Expand All @@ -44,10 +45,28 @@ def load(self) -> None:
Sets _loaded to True once complete.
"""
print(f"Loading model from {self.model_path}...")
self.llm = Llama(model_path=self.model_path, verbose=False, n_ctx=self.max_tokens, kwargs=self._llama_kwrgs)
llama_kwargs = {"embedding": self._embeddings_mode}
for k, v in self._llama_kwrgs.items():
llama_kwargs[k] = v
self.llm = Llama(model_path=self.model_path, verbose=False, n_ctx=self.max_tokens, **llama_kwargs)
self.tokenizer = LlamaTokenizer(self.llm)
self._loaded = True

def create_embeddings(self, text:str) -> list[float]:
    """
    Compute an embedding vector for the given input text.

    If the model was not loaded with embeddings enabled, it is reloaded
    in embeddings mode first, then the underlying Llama embed call is
    delegated to.

    Args:
        text (str): The text to create embeddings for.

    Returns:
        list[float]: The embedding vector produced by the model.
    """
    self._check_loaded()
    needs_reload = not self._embeddings_mode
    if needs_reload:
        print("Switching to embeddings mode...")
        self._embeddings_mode = True
        self.load()
    return self.llm.embed(text)

def _try_fixing_format(self, text: str, only_letters: bool = False, rem_list_formatting: bool = False) -> str:
"""
Attempt to fix formatting issues in the input text.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = fh.read()
setup(
name="gguf_llama",
version="0.0.15",
version="0.0.16",
packages=find_packages(),
install_requires=[
'util_helper>=0.0.3',
Expand Down

0 comments on commit 0f30e90

Please sign in to comment.