From 2e5c379632b0159acd39a2e7627c688f1bd9c2eb Mon Sep 17 00:00:00 2001 From: ccurme Date: Tue, 27 Aug 2024 16:18:19 -0400 Subject: [PATCH] openai[patch]: fix get_num_tokens for function calls (#25785) Closes https://github.com/langchain-ai/langchain/issues/25784 See additional discussion [here](https://github.com/langchain-ai/langchain/commit/0a4ee864e9fbee7e9e328b21280df1f0fbd788e7#r145147380). --- libs/partners/openai/langchain_openai/chat_models/base.py | 2 +- .../openai/tests/unit_tests/chat_models/test_base.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index ca54ce32e6e6d..47c89929ea5db 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -947,7 +947,7 @@ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: else: # Cast str(value) in case the message value is not a string # This occurs with function messages - num_tokens += len(encoding.encode(value)) + num_tokens += len(encoding.encode(str(value))) if key == "name": num_tokens += tokens_per_name # every reply is primed with assistant diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 4041718368ca2..4e959f005990d 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -677,7 +677,10 @@ def test_get_num_tokens_from_messages() -> None: AIMessage( "", additional_kwargs={ - "function_call": json.dumps({"arguments": "old", "name": "fun"}) + "function_call": { + "arguments": json.dumps({"arg1": "arg1"}), + "name": "fun", + } }, ), AIMessage( @@ -688,6 +691,6 @@ def test_get_num_tokens_from_messages() -> None: ), ToolMessage("foobar", tool_call_id="foo"), ] - expected = 170 + expected = 176 actual = llm.get_num_tokens_from_messages(messages) assert expected == actual