import torch
from transformers import BitsAndBytesConfig

from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.llms import HuggingFaceLLM
from llama_index.prompts import PromptTemplate
# display_response is used at the end of the script; this import path assumes the
# llama_index 0.9-style package layout used by the imports above.
from llama_index.response.notebook_utils import display_response
# Load documents from disk; replace "./data" with your document directory path.
documents = SimpleDirectoryReader("./data").load_data()
# Without quantization the model consumes a lot of memory; you can skip this if you have enough memory.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
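# Skipping quantization (a note, not part of the original script): if you do have enough GPU
# memory, you can omit quantization_config entirely and pass model_kwargs={} to HuggingFaceLLM
# below; the model then loads in full precision.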
def messages_to_prompt(messages):
    prompt = ""
    for message in messages:
        if message.role == 'system':
            prompt += f"<|system|>\n{message.content}</s>\n"
        elif message.role == 'user':
            prompt += f"<|user|>\n{message.content}</s>\n"
        elif message.role == 'assistant':
            prompt += f"<|assistant|>\n{message.content}</s>\n"
    # ensure we start with a system prompt, insert blank if needed
    if not prompt.startswith("<|system|>\n"):
        prompt = "<|system|>\n</s>\n" + prompt
    # add final assistant prompt
    prompt = prompt + "<|assistant|>\n"
    return prompt
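# For example, messages_to_prompt turns a single user message "Summarize the document"
# (with no system message) into:
#
#   <|system|>
#   </s>
#   <|user|>
#   Summarize the document</s>
#   <|assistant|>
#
# This is the same Zephyr-style chat template hard-coded in query_wrapper_prompt below; note
# that Llama-2-chat models were originally trained on the [INST]-style format, so you may want
# to adapt the template to whichever model you load.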
llm = HuggingFaceLLM(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
    query_wrapper_prompt=PromptTemplate("<|system|>\n</s>\n<|user|>\n{query_str}</s>\n<|assistant|>\n"),
    context_window=3900,
    max_new_tokens=256,
    model_kwargs={"quantization_config": quantization_config},
    # tokenizer_kwargs={},
    generate_kwargs={"temperature": 0.3, "top_k": 50, "top_p": 0.95},
    messages_to_prompt=messages_to_prompt,
    device_map="auto",
)
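# Optional sanity check (not part of the original script): the LLM can be exercised directly
# before wiring it into an index, e.g.
#   print(llm.complete("What is retrieval-augmented generation?"))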
# Build the vector index over the loaded documents and query it.
service_context = ServiceContext.from_defaults(llm=llm, embed_model="local:BAAI/bge-small-en-v1.5")
vector_index = VectorStoreIndex.from_documents(documents, service_context=service_context)
query_engine = vector_index.as_query_engine(response_mode="compact")

response = query_engine.query("Can you please summarize the document?")
display_response(response)
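# Optional follow-up (a minimal sketch, not in the original script): persist the index so the
# documents don't have to be re-embedded on every run. "./storage" is an illustrative path, and
# the llama_index 0.9-style storage API is assumed here.
from llama_index import StorageContext, load_index_from_storage

vector_index.storage_context.persist(persist_dir="./storage")

# A later run can reload the persisted index instead of rebuilding it:
storage_context = StorageContext.from_defaults(persist_dir="./storage")
vector_index = load_index_from_storage(storage_context, service_context=service_context)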