Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cookbook for selectins llms based on context length #12486

Merged
merged 1 commit into from
Oct 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions cookbook/selecting_llms_based_on_context_length.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e93283d1",
"metadata": {},
"source": [
"# Selecting LLMs based on Context Length\n",
"\n",
"Different LLMs have different context lengths. As a very immediate an practical example, OpenAI has two versions of GPT-3.5-Turbo: one with 4k context, another with 16k context. This notebook shows how to route between them based on input."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "cc453450",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.schema.prompt import PromptValue\n",
"from langchain.schema.messages import BaseMessage\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from typing import Union, Sequence"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1cec6a10",
"metadata": {},
"outputs": [],
"source": [
"short_context_model = ChatOpenAI(model=\"gpt-3.5-turbo\")\n",
"long_context_model = ChatOpenAI(model=\"gpt-3.5-turbo-16k\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "772da153",
"metadata": {},
"outputs": [],
"source": [
"def get_context_length(prompt: PromptValue):\n",
" messages = prompt.to_messages()\n",
" tokens = short_context_model.get_num_tokens_from_messages(messages)\n",
" return tokens"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "db771e20",
"metadata": {},
"outputs": [],
"source": [
"prompt = PromptTemplate.from_template(\"Summarize this passage: {context}\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "af057e2f",
"metadata": {},
"outputs": [],
"source": [
"def choose_model(prompt: PromptValue):\n",
" context_len = get_context_length(prompt)\n",
" if context_len < 30:\n",
" print(\"short model\")\n",
" return short_context_model\n",
" else:\n",
" print(\"long model\")\n",
" return long_context_model"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "84f3e07d",
"metadata": {},
"outputs": [],
"source": [
"chain = prompt | choose_model | StrOutputParser()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "d8b14f8f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"short model\n"
]
},
{
"data": {
"text/plain": [
"'The passage mentions that a frog visited a pond.'"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"context\": \"a frog went to a pond\"})"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "70ebd3dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"long model\n"
]
},
{
"data": {
"text/plain": [
"'The passage describes a frog that moved from one pond to another and perched on a log.'"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"context\": \"a frog went to a pond and sat on a log and went to a different pond\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a7e29fef",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading