handle < 3.7
vowelparrot committed May 22, 2023
1 parent 5563713 commit d1f4103
Showing 3 changed files with 57 additions and 12 deletions.
27 changes: 17 additions & 10 deletions langchain/chat_models/openai.py
@@ -3,7 +3,17 @@

import logging
import sys
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)

from pydantic import Extra, Field, root_validator
from tenacity import (
@@ -29,13 +39,14 @@
    SystemMessage,
)
from langchain.utils import get_from_dict_or_env

if TYPE_CHECKING:
    import tiktoken

logger = logging.getLogger(__name__)


-def _import_tiktoken() -> tiktoken:
+def _import_tiktoken() -> Any:
    try:
        import tiktoken
    except ImportError:
@@ -369,7 +380,7 @@ def _llm_type(self) -> str:
        return "openai-chat"

    def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
-        tiktoken = _import_tiktoken()
+        tiktoken_ = _import_tiktoken()
        model = self.model_name
        if model == "gpt-3.5-turbo":
            # gpt-3.5-turbo may change over time.
@@ -381,11 +392,11 @@ def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
            model = "gpt-4-0314"
        # Returns the number of tokens used by a list of messages.
        try:
-            encoding = tiktoken.encoding_for_model(model)
+            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            model = "cl100k_base"
-            encoding = tiktoken.get_encoding(model)
+            encoding = tiktoken_.get_encoding(model)
        return model, encoding

def get_token_ids(self, text: str) -> List[int]:
@@ -394,11 +405,7 @@ def get_token_ids(self, text: str) -> List[int]:
        if sys.version_info[1] <= 7:
            return super().get_token_ids(text)
        _, encoding_model = self._get_encoding_model()
-        return encoding_model.encode(
-            text,
-            allowed_special=self.allowed_special,
-            disallowed_special=self.disallowed_special,
-        )
+        return encoding_model.encode(text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
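The substance of this commit is the version guard above: tiktoken only supports Python 3.8+, so on 3.7 and earlier get_token_ids falls back to the base-class tokenizer rather than importing tiktoken (note that sys.version_info[1] <= 7 assumes a 3.x interpreter; sys.version_info < (3, 8) is the more explicit spelling). A minimal standalone sketch of the same lazy-import-plus-fallback pattern — count_tokens and its whitespace fallback are hypothetical illustrations, not part of this commit:

import sys
from typing import Any

def _import_tiktoken() -> Any:
    # Import lazily so the package still works when tiktoken is not installed.
    try:
        import tiktoken
    except ImportError:
        raise ImportError("Could not import tiktoken. Install it with `pip install tiktoken`.")
    return tiktoken

def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
    # Hypothetical helper: tiktoken requires Python 3.8+, so fall back to a
    # rough whitespace split on older interpreters.
    if sys.version_info < (3, 8):
        return len(text.split())
    tiktoken_ = _import_tiktoken()
    try:
        encoding = tiktoken_.encoding_for_model(model)
    except KeyError:
        # Unknown model names fall back to the cl100k_base encoding.
        encoding = tiktoken_.get_encoding("cl100k_base")
    return len(encoding.encode(text))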
4 changes: 2 additions & 2 deletions langchain/llms/openai.py
@@ -375,7 +375,7 @@ def get_sub_prompts(
                )
            params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
        sub_prompts = [
-            prompts[i: i + self.batch_size]
+            prompts[i : i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]
        return sub_prompts
@@ -386,7 +386,7 @@ def create_llm_result(
        """Create the LLMResult from the choices and prompts."""
        generations = []
        for i, _ in enumerate(prompts):
-            sub_choices = choices[i * self.n: (i + 1) * self.n]
+            sub_choices = choices[i * self.n : (i + 1) * self.n]
            generations.append(
                [
                    Generation(
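Both hunks in this file are whitespace-only: black's slice-formatting rule (following PEP 8) puts spaces around the colon when a slice bound is an expression rather than a bare name. For context, the surrounding logic splits prompts into API-sized batches and then regroups the flat choices list, n completions per prompt — a minimal sketch with hypothetical names:

from typing import Any, Dict, List

def batch_prompts(prompts: List[str], batch_size: int) -> List[List[str]]:
    # Split prompts into batches of at most batch_size for one API call each.
    return [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]

def group_choices(choices: List[Dict[str, Any]], n: int, num_prompts: int) -> List[List[Dict[str, Any]]]:
    # The completions endpoint returns a flat list holding n choices per
    # prompt, in prompt order; slice it back into per-prompt groups.
    return [choices[i * n : (i + 1) * n] for i in range(num_prompts)]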
38 changes: 38 additions & 0 deletions tests/integration_tests/llms/test_openai.py
@@ -6,6 +6,7 @@
import pytest

from langchain.callbacks.manager import CallbackManager
+from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI, OpenAIChat
from langchain.schema import LLMResult
@@ -237,3 +238,40 @@ def test_openai_modelname_to_contextsize_invalid() -> None:
"""Test model name to context size on an invalid model."""
with pytest.raises(ValueError):
OpenAI().modelname_to_contextsize("foobar")


+_EXPECTED_NUM_TOKENS = {
+    "ada": 17,
+    "babbage": 17,
+    "curie": 17,
+    "davinci": 17,
+    "gpt-4": 12,
+    "gpt-4-32k": 12,
+    "gpt-3.5-turbo": 12,
+}
+
+_MODELS = [
+    "ada",
+    "babbage",
+    "curie",
+    "davinci",
+]
+_CHAT_MODELS = [
+    "gpt-4",
+    "gpt-4-32k",
+    "gpt-3.5-turbo",
+]
+
+
+@pytest.mark.parametrize("model", _MODELS)
+def test_openai_get_num_tokens(model: str) -> None:
+    """Test get_num_tokens."""
+    llm = OpenAI(model=model)
+    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
+
+
+@pytest.mark.parametrize("model", _CHAT_MODELS)
+def test_chat_openai_get_num_tokens(model: str) -> None:
+    """Test get_num_tokens."""
+    llm = ChatOpenAI(model=model)
+    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
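These tests pin exact token counts for a mixed CJK/emoji string across both completion and chat models, exercising the tiktoken path end to end (the multibyte input would tokenize differently under the Python <= 3.7 fallback). A quick way to reproduce the counts outside pytest — assuming tiktoken is installed and OPENAI_API_KEY is set, since constructing the wrappers validates the key:

from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.openai import OpenAI

text = "表情符号是\n🦜🔗"
print(OpenAI(model="davinci").get_num_tokens(text))            # expected 17
print(ChatOpenAI(model="gpt-3.5-turbo").get_num_tokens(text))  # expected 12 (cl100k_base vocab)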
