Fixed multi input prompt for MapReduceChain #4979

Merged: 8 commits, Jun 3, 2023
139 changes: 137 additions & 2 deletions docs/modules/chains/index_examples/summarize.ipynb
@@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "e9db25f3",
"metadata": {},
"outputs": [],
@@ -318,6 +318,141 @@
"chain({\"input_documents\": docs}, return_only_outputs=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b882e209",
"metadata": {},
"source": [
"## The custom `MapReduceChain`\n",
"\n",
"**Multi input prompt**\n",
"\n",
"You can also use prompt with multi input. In this example, we will use a MapReduce chain to answer specifc question about our code."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f7ad9ee2",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain\n",
"from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
"\n",
"map_template_string = \"\"\"Give the following python code information, generate a description that explains what the code does and also mention the time complexity.\n",
"Code:\n",
"{code}\n",
"\n",
"Return the the description in the following format:\n",
"name of the function: description of the function\n",
"\"\"\"\n",
"\n",
"\n",
"reduce_template_string = \"\"\"Give the following following python fuctions name and their descritpion, answer the following question\n",
"{code_description}\n",
"Question: {question}\n",
"Answer:\n",
"\"\"\"\n",
"\n",
"MAP_PROMPT = PromptTemplate(input_variables=[\"code\"], template=map_template_string)\n",
"REDUCE_PROMPT = PromptTemplate(input_variables=[\"code_description\", \"question\"], template=reduce_template_string)\n",
"\n",
"llm = OpenAI()\n",
"\n",
"map_llm_chain = LLMChain(llm=llm, prompt=MAP_PROMPT)\n",
"reduce_llm_chain = LLMChain(llm=llm, prompt=REDUCE_PROMPT)\n",
"\n",
"generative_result_reduce_chain = StuffDocumentsChain(\n",
" llm_chain=reduce_llm_chain,\n",
" document_variable_name=\"code_description\",\n",
")\n",
"\n",
"combine_documents = MapReduceDocumentsChain(\n",
" llm_chain=map_llm_chain,\n",
" combine_document_chain=generative_result_reduce_chain,\n",
" document_variable_name=\"code\",\n",
")\n",
"\n",
"map_reduce = MapReduceChain(\n",
" combine_documents_chain=combine_documents,\n",
" text_splitter=CharacterTextSplitter(separator=\"\\n##\\n\", chunk_size=100, chunk_overlap=0),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0d4caccb",
"metadata": {},
"outputs": [],
"source": [
"code = \"\"\"\n",
"def bubblesort(list):\n",
" for iter_num in range(len(list)-1,0,-1):\n",
" for idx in range(iter_num):\n",
" if list[idx]>list[idx+1]:\n",
" temp = list[idx]\n",
" list[idx] = list[idx+1]\n",
" list[idx+1] = temp\n",
" return list\n",
"##\n",
"def insertion_sort(InputList):\n",
" for i in range(1, len(InputList)):\n",
" j = i-1\n",
" nxt_element = InputList[i]\n",
" while (InputList[j] > nxt_element) and (j >= 0):\n",
" InputList[j+1] = InputList[j]\n",
" j=j-1\n",
" InputList[j+1] = nxt_element\n",
" return InputList\n",
"##\n",
"def shellSort(input_list):\n",
" gap = len(input_list) // 2\n",
" while gap > 0:\n",
" for i in range(gap, len(input_list)):\n",
" temp = input_list[i]\n",
" j = i\n",
" while j >= gap and input_list[j - gap] > temp:\n",
" input_list[j] = input_list[j - gap]\n",
" j = j-gap\n",
" input_list[j] = temp\n",
" gap = gap//2\n",
" return input_list\n",
"\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d5a9a35b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Created a chunk of size 247, which is longer than the specified 100\n",
"Created a chunk of size 267, which is longer than the specified 100\n"
]
},
{
"data": {
"text/plain": [
"'shellSort has a better time complexity than both bubblesort and insertion_sort, as it has a time complexity of O(n^2), while the other two have a time complexity of O(n^2).'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"map_reduce.run(input_text=code, question=\"Which function has a better time complexity?\")"
]
},
{
"cell_type": "markdown",
"id": "f61350f9",
@@ -470,7 +605,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.8.16"
},
"vscode": {
"interpreter": {
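The notebook cell above passes two inputs, `input_text` and `question`, to `map_reduce.run(...)`. That only works because of the change to `MapReduceChain._call` in the `mapreduce.py` diff below, which forwards every input other than the chain's `input_key` on to the combine-documents chain. A simplified sketch of that forwarding, using stand-ins rather than the real LangChain objects (not code from the PR):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Document:
    """Stand-in for langchain.docstore.document.Document."""
    page_content: str


def split_text(text: str) -> List[str]:
    """Stand-in for the CharacterTextSplitter configured in the notebook cell."""
    return [chunk for chunk in text.split("\n##\n") if chunk.strip()]


# Inputs as passed to map_reduce.run(...) in the notebook.
inputs = {
    "input_text": "def f(xs): ...\n##\ndef g(xs): ...",
    "question": "Which function has a better time complexity?",
}

doc_text = inputs.pop("input_text")  # the chain's input_key
docs = [Document(page_content=t) for t in split_text(doc_text)]

# The patched _call builds this dict: extra inputs such as "question" are
# forwarded unchanged alongside the split documents.
_inputs = {**inputs, "input_documents": docs}
print(sorted(_inputs), len(docs))  # ['input_documents', 'question'] 2
```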
20 changes: 16 additions & 4 deletions langchain/chains/mapreduce.py
@@ -5,7 +5,7 @@
"""
from __future__ import annotations

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Mapping, Optional

from pydantic import Extra

@@ -38,15 +38,22 @@ def from_params(
prompt: BasePromptTemplate,
text_splitter: TextSplitter,
callbacks: Callbacks = None,
combine_chain_kwargs: Optional[Mapping[str, Any]] = None,
reduce_chain_kwargs: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> MapReduceChain:
"""Construct a map-reduce chain that uses the chain for map and reduce."""
llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
reduce_chain = StuffDocumentsChain(llm_chain=llm_chain, callbacks=callbacks)
reduce_chain = StuffDocumentsChain(
llm_chain=llm_chain,
callbacks=callbacks,
**(reduce_chain_kwargs if reduce_chain_kwargs else {}),
)
combine_documents_chain = MapReduceDocumentsChain(
llm_chain=llm_chain,
combine_document_chain=reduce_chain,
callbacks=callbacks,
**(combine_chain_kwargs if combine_chain_kwargs else {}),
)
return cls(
combine_documents_chain=combine_documents_chain,
@@ -84,9 +91,14 @@ def _call(
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
# Split the larger text into smaller chunks.
texts = self.text_splitter.split_text(inputs[self.input_key])
doc_text = inputs.pop(self.input_key)
texts = self.text_splitter.split_text(doc_text)
docs = [Document(page_content=text) for text in texts]
_inputs: Dict[str, Any] = {
**inputs,
self.combine_documents_chain.input_key: docs,
}
outputs = self.combine_documents_chain.run(
input_documents=docs, callbacks=_run_manager.get_child()
_inputs, callbacks=_run_manager.get_child()
)
return {self.output_key: outputs}
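For completeness, a hedged sketch of how the two new `from_params` keyword arguments might be used, assuming the LangChain API at the time of this PR; the `document_variable_name` values are illustrative (by default they are inferred when the prompt has a single input variable):

```python
from langchain import OpenAI, PromptTemplate
from langchain.chains.mapreduce import MapReduceChain
from langchain.text_splitter import CharacterTextSplitter

prompt = PromptTemplate(
    input_variables=["text"],
    template="Summarize the following text:\n{text}\nSummary:",
)

chain = MapReduceChain.from_params(
    llm=OpenAI(),
    prompt=prompt,
    text_splitter=CharacterTextSplitter(chunk_size=1000, chunk_overlap=0),
    # New in this PR: kwargs forwarded to the underlying StuffDocumentsChain
    # (the reduce step) and MapReduceDocumentsChain (the map/combine wrapper).
    reduce_chain_kwargs={"document_variable_name": "text"},
    combine_chain_kwargs={"document_variable_name": "text"},
)

print(chain.run(input_text="Some long document ..."))
```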