import logging

from transformers import pipeline
from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import TextLoader, UnstructuredPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain

LOG = logging.getLogger(__name__)


def documents_loader(files, chunk_size, chunk_overlap):
    """Load supported .txt/.pdf files and split them into overlapping chunks."""
    documents = []
    for name in files:
        suffix = name.split(".")[-1]
        if suffix == "txt":
            loader = TextLoader(name)
        elif suffix == "pdf":
            loader = UnstructuredPDFLoader(name)
        else:
            # Skip unsupported files instead of falling through with a stale or
            # undefined loader.
            LOG.warning(f"Document {name} is not supported, skipping")
            continue
        documents += loader.load()
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    split_docs = text_splitter.split_documents(documents)
    return split_docs
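
# Illustrative usage (the file names are placeholders, not files shipped with this repo):
#
#     docs = documents_loader(["notes.txt", "report.pdf"], chunk_size=500, chunk_overlap=50)
#     # -> list of langchain Document chunks of roughly 500 characters with 50-character overlap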


def make_pipeline(model, tokenizer):
    """Wrap a model and tokenizer in a Hugging Face text-generation pipeline."""
    LOG.info("Creating transformer pipeline...")
    model_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0,
        max_length=512,
        temperature=0.7,
    )
    return model_pipeline
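
# Illustrative usage (the checkpoint name is a placeholder; any causal LM from the
# Hugging Face Hub that supports the text-generation task should work):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     generator = make_pipeline(model, tokenizer)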


# A RetrievalQA chain could be used here instead, to fetch the source documents
# along with the answer (see the sketch after this function).
def make_chain(pipeline):
    """Build a question-answering chain on top of a Hugging Face pipeline."""
    LOG.info("Creating chain...")
    llm = HuggingFacePipeline(pipeline=pipeline)
    prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Answer: """
    prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    chain_type = "stuff"
    # chain_type = "map_rerank"
    LOG.info("Loading Q&A chain...")
    chain = load_qa_chain(llm, chain_type=chain_type, prompt=prompt)
    # Log the final prompt template for debugging.
    LOG.debug(chain.llm_chain.prompt.template)
    return chain
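
# Illustrative end-to-end usage (a sketch: model and tokenizer are built as in the
# make_pipeline example above, and the question string is just an example):
#
#     docs = documents_loader(["notes.txt"], chunk_size=500, chunk_overlap=50)
#     chain = make_chain(make_pipeline(model, tokenizer))
#     result = chain({"input_documents": docs, "question": "What is this document about?"})
#     print(result["output_text"])
#
# If source documents should be returned alongside the answer, a retrieval-based chain
# is one option (sketch only; it assumes a vector store such as FAISS has already been
# built from the split documents and is bound to `vectorstore`):
#
#     from langchain.chains import RetrievalQA
#     qa = RetrievalQA.from_chain_type(
#         llm=HuggingFacePipeline(pipeline=make_pipeline(model, tokenizer)),
#         chain_type="stuff",
#         retriever=vectorstore.as_retriever(),
#         return_source_documents=True,
#     )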