CustomConversationalRetrievalChain.py
'''
Partial duplicate of LangChain code, for customization on top of LangChain.
'''
from __future__ import annotations

import warnings
from abc import abstractmethod
# from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Extra, Field, root_validator

# from langchain.chains.base import Chain
# from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.conversational_retrieval.base import BaseConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
# from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.schema import BaseRetriever, Document
from langchain.base_language import BaseLanguageModel
# from langchain.vectorstores.base import VectorStore

from Custom_load_qa_chain import custom_load_qa_chain
from CommonHelper import *  # star import; provides get_rough_token_len
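
# `get_rough_token_len` is expected to come from CommonHelper's star import
# above. In case that module does not export it, a crude fallback is sketched
# here (an assumption: roughly 4 characters per token for English text; the
# real CommonHelper implementation may differ):
if "get_rough_token_len" not in globals():
    def get_rough_token_len(text: str) -> int:
        # Heuristic estimate only, not a real tokenizer.
        return max(1, len(text) // 4)
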

class CustomConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
    """Chain for chatting with an index."""

    retriever: BaseRetriever
    """Index to connect to."""

    max_tokens_limit: Optional[int] = None
    """If set, restricts the docs returned from the store based on token count;
    enforced only for StuffDocumentsChain and RefineDocumentsChain."""

    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
        num_docs = len(docs)
        # Upstream LangChain only checks StuffDocumentsChain:
        # if self.max_tokens_limit and isinstance(self.combine_docs_chain, StuffDocumentsChain):
        if self.max_tokens_limit is not None and self.max_tokens_limit > 0 and isinstance(
            self.combine_docs_chain, (StuffDocumentsChain, RefineDocumentsChain)
        ):
            tokens = [
                # self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
                get_rough_token_len(doc.page_content)
                for doc in docs
            ]
            token_count = sum(tokens[:num_docs])
            # Drop documents from the tail until the total fits the budget.
            while token_count > self.max_tokens_limit:
                num_docs -= 1
                token_count -= tokens[num_docs]
        return docs[:num_docs]
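
    # Worked example (an illustration, not from the source): with
    # max_tokens_limit=100 and per-document token counts [60, 50, 40],
    # token_count starts at 150; the loop drops the last doc (150 -> 110),
    # then the middle one (110 -> 60), so only docs[:1] is returned.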

    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
        print(f"\n[New question] {question}\n")
        docs = self.retriever.get_relevant_documents(question)
        return self._reduce_tokens_below_limit(docs)

    async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
        docs = await self.retriever.aget_relevant_documents(question)
        return self._reduce_tokens_below_limit(docs)

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        retriever: BaseRetriever,
        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
        # qa_prompt: Optional[BasePromptTemplate] = None,
        chain_type: str = "refine",
        **kwargs: Any,
    ) -> BaseConversationalRetrievalChain:
        """Load chain from LLM."""
        doc_chain = custom_load_qa_chain(
            llm,
            chain_type=chain_type,
            # question_prompt=qa_prompt,
        )
        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
        return cls(
            retriever=retriever,
            combine_docs_chain=doc_chain,
            question_generator=condense_question_chain,
            **kwargs,
        )
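

# A minimal usage sketch (an assumption, not part of the original file): it
# presumes an OpenAI API key in the environment and builds a small FAISS index
# inline; `texts` and the model settings are placeholders.
if __name__ == "__main__":
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.llms import OpenAI
    from langchain.vectorstores import FAISS

    texts = ["LangChain composes LLM calls.", "FAISS does vector search."]
    vectorstore = FAISS.from_texts(texts, OpenAIEmbeddings())

    chain = CustomConversationalRetrievalChain.from_llm(
        llm=OpenAI(temperature=0),
        retriever=vectorstore.as_retriever(),
        chain_type="refine",
        max_tokens_limit=2000,
    )
    # BaseConversationalRetrievalChain expects "question" and "chat_history"
    # inputs and returns the reply under the "answer" key.
    result = chain({"question": "What does FAISS do?", "chat_history": []})
    print(result["answer"])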