feat: fix anthropic reasking (#560)
jxnl authored Apr 3, 2024
1 parent 1f1cb5e commit 1f7926d
Showing 9 changed files with 106 additions and 35 deletions.
3 changes: 1 addition & 2 deletions examples/groq/groq_example2.py
@@ -1,6 +1,5 @@
 import os
-from pydantic import BaseModel, Field
-from typing import List
+from pydantic import BaseModel
 from groq import Groq
 import instructor

17 changes: 7 additions & 10 deletions instructor/__init__.py
@@ -1,3 +1,5 @@
+import importlib
+
 from .mode import Mode
 from .process_response import handle_response_model
 from .distil import FinetuneFormat, Instructions
@@ -44,18 +46,13 @@
     "handle_response_model",
 ]

-try:
-    import anthropic
-
+if importlib.util.find_spec("anthropic") is not None:
     from .client_anthropic import from_anthropic

-    __all__.append("from_anthropic")
-except ImportError:
-    pass
+    __all__ += ["from_anthropic"]

-try:
-    import groq
+if importlib.util.find_spec("groq") is not None:
     from .client_groq import from_groq

-    __all__.append("from_groq")
-except ImportError:
-    pass
+    __all__ += ["from_groq"]
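Note on the guard above: importlib.util.find_spec only checks whether a package can be located, so the old pattern of importing anthropic or groq just to detect them is no longer needed. A minimal sketch of the same idea, where the package name is a placeholder rather than anything instructor actually checks:

import importlib.util

# Expose an optional integration only when its dependency is installed.
# "some_optional_dep" is an illustrative placeholder.
if importlib.util.find_spec("some_optional_dep") is not None:
    print("optional dependency found; safe to import the integration module")
else:
    print("optional dependency missing; skip it without raising ImportError")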
11 changes: 7 additions & 4 deletions instructor/client_anthropic.py
@@ -25,10 +25,13 @@ def from_anthropic(
     mode: instructor.Mode = instructor.Mode.ANTHROPIC_JSON,
     **kwargs,
 ) -> instructor.Instructor | instructor.AsyncInstructor:
-    assert mode in {
-        instructor.Mode.ANTHROPIC_JSON,
-        instructor.Mode.ANTHROPIC_TOOLS,
-    }, "Mode be one of {instructor.Mode.ANTHROPIC_JSON, instructor.Mode.ANTHROPIC_TOOLS}"
+    assert (
+        mode
+        in {
+            instructor.Mode.ANTHROPIC_JSON,
+            instructor.Mode.ANTHROPIC_TOOLS,
+        }
+    ), "Mode be one of {instructor.Mode.ANTHROPIC_JSON, instructor.Mode.ANTHROPIC_TOOLS}"

     assert isinstance(
         client, (anthropic.Anthropic, anthropic.AsyncAnthropic)
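For context, the call site this assert protects looks like the updated tests further down: wrap an Anthropic client with from_anthropic and pass one of the two supported modes. The model name and prompt below are illustrative only:

import anthropic
import instructor
from pydantic import BaseModel

class User(BaseModel):
    name: str
    age: int

client = instructor.from_anthropic(
    anthropic.Anthropic(), mode=instructor.Mode.ANTHROPIC_JSON
)

# Illustrative request; any mode other than the two above trips the assert.
resp = client.messages.create(
    model="claude-3-haiku-20240307",
    max_tokens=1024,
    messages=[{"role": "user", "content": "John is 18 years old."}],
    response_model=User,
)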
9 changes: 6 additions & 3 deletions instructor/function_calls.py
@@ -128,12 +128,15 @@ def parse_anthropic_tools(
     @classmethod
     def parse_anthropic_json(
         cls: Type[BaseModel],
-        completion: ChatCompletion,
+        completion,
         validation_context: Optional[Dict[str, Any]] = None,
         strict: Optional[bool] = None,
     ) -> BaseModel:
-        assert hasattr(completion, "content")
-        text = completion.content[0].text  # type: ignore
+        from anthropic.types import Message
+
+        assert isinstance(completion, Message)
+
+        text = completion.content[0].text
         extra_text = extract_json_from_codeblock(text)
         return cls.model_validate_json(
             extra_text, context=validation_context, strict=strict
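A rough sketch of what this parser handles for a Claude JSON-mode reply, assuming extract_json_from_codeblock pulls the JSON object out of surrounding prose and code fences, which is how it is used above; the reply text here is made up:

from pydantic import BaseModel
from instructor.utils import extract_json_from_codeblock

class User(BaseModel):
    name: str
    age: int

# Example of the kind of text found in message.content[0].text for JSON mode.
raw_reply = 'Here is the extracted user:\n```json\n{"name": "JOHN", "age": 18}\n```'

json_text = extract_json_from_codeblock(raw_reply)
user = User.model_validate_json(json_text)
assert user == User(name="JOHN", age=18)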
7 changes: 6 additions & 1 deletion instructor/process_response.py
@@ -7,7 +7,7 @@
 from instructor.dsl.partial import PartialBase
 from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
 from instructor.function_calls import OpenAISchema, openai_schema
-
+from instructor.utils import merge_consecutive_messages
 from openai.types.chat import ChatCompletion
 from pydantic import BaseModel

@@ -333,6 +333,11 @@ def handle_response_model(
                 for message in new_kwargs.get("messages", [])
                 if message["role"] != "system"
             ]
+
+            # the messages array must be alternating roles of user and assistant, we must merge
+            # consecutive user messages into a single message
+            new_kwargs["messages"] = merge_consecutive_messages(new_kwargs["messages"])
+
         else:
             raise ValueError(f"Invalid patch mode: {mode}")
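The comment above states the constraint driving this change: the Anthropic messages array must alternate user and assistant roles, so once the system message is stripped out (and, later, a reask message is appended), consecutive same-role entries have to be collapsed. A small sketch of the effect, with made-up message text:

from instructor.utils import merge_consecutive_messages

messages = [
    {"role": "user", "content": "Extract the user from: John is 18 years old"},
    {"role": "user", "content": "Validation Errors found:\n...\nRecall the function correctly, fix the errors"},
]

merged = merge_consecutive_messages(messages)
# The two consecutive "user" entries become one, joined with a blank line.
assert len(merged) == 1
assert merged[0]["role"] == "user"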
18 changes: 16 additions & 2 deletions instructor/retry.py
@@ -4,7 +4,11 @@
 from openai.types.chat import ChatCompletion
 from instructor.mode import Mode
 from instructor.process_response import process_response, process_response_async
-from instructor.utils import dump_message, update_total_usage
+from instructor.utils import (
+    dump_message,
+    update_total_usage,
+    merge_consecutive_messages,
+)

 from openai.types.completion_usage import CompletionUsage
 from pydantic import ValidationError
@@ -26,9 +30,18 @@
 def reask_messages(response: ChatCompletion, mode: Mode, exception: Exception):
     if mode == Mode.ANTHROPIC_TOOLS:
         # TODO: we need to include the original response
+        yield {
+            "role": "assistant",
+            "content": f"Validation Errors found:\n{exception}\nRecall the function correctly, fix the errors",
+        }
+        return
+    if mode == Mode.ANTHROPIC_JSON:
+        from anthropic.types import Message
+
+        assert isinstance(response, Message)
         yield {
             "role": "user",
-            "content": f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors",
+            "content": f"""Validation Errors found:\n{exception}\nRecall the function correctly, fix the errors found in the following attempt:\n{response.content[0].text}""",
         }
         return
@@ -94,6 +107,7 @@ def retry_sync(
                 except (ValidationError, JSONDecodeError) as e:
                     logger.debug(f"Error response: {response}")
                     kwargs["messages"].extend(reask_messages(response, mode, e))
+                    kwargs["messages"] = merge_consecutive_messages(kwargs["messages"])
                     raise e
     except RetryError as e:
         logger.exception(f"Failed after retries: {e.last_attempt.exception}")
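Taken together with the merge helper, the Anthropic reask flow now works roughly like the sketch below. naive_reask_loop is a simplified stand-in for retry_sync (no tenacity, usage tracking, or streaming), and create/parse stand in for the client call and process_response:

from json import JSONDecodeError

from pydantic import ValidationError

from instructor.retry import reask_messages
from instructor.utils import merge_consecutive_messages


def naive_reask_loop(create, parse, kwargs, mode, max_retries=2):
    last_exc = None
    for _ in range(max_retries + 1):
        response = create(**kwargs)
        try:
            return parse(response)
        except (ValidationError, JSONDecodeError) as e:
            last_exc = e
            # Append the reask message; in ANTHROPIC_JSON mode it now carries the
            # validation errors plus the text of the failed attempt.
            kwargs["messages"].extend(reask_messages(response, mode, e))
            # Collapse consecutive same-role messages so the next request still
            # alternates user/assistant roles.
            kwargs["messages"] = merge_consecutive_messages(kwargs["messages"])
    raise last_exc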
12 changes: 12 additions & 0 deletions instructor/utils.py
@@ -120,3 +120,15 @@ def is_async(func: Callable) -> bool:
         func = func.__wrapped__
         is_coroutine = is_coroutine or inspect.iscoroutinefunction(func)
     return is_coroutine
+
+
+def merge_consecutive_messages(messages: list[dict]) -> list[dict]:
+    # merge all consecutive user messages into a single message
+    new_messages = []
+    for message in messages:
+        if len(new_messages) > 0 and message["role"] == new_messages[-1]["role"]:
+            new_messages[-1]["content"] += f"\n\n{message['content']}"
+        else:
+            new_messages.append(message)
+
+    return new_messages
31 changes: 18 additions & 13 deletions tests/llm/test_anthropic/evals/test_simple.py
@@ -1,11 +1,11 @@
 import anthropic
 import instructor
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 from typing import List, Literal
 from enum import Enum

-create = instructor.patch(
-    create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_JSON
+client = instructor.from_anthropic(
+    anthropic.Anthropic(), mode=instructor.Mode.ANTHROPIC_JSON
 )

@@ -14,10 +14,15 @@ class User(BaseModel):
     name: str
     age: int

-resp = create(
+    @field_validator("name")
+    def name_is_uppercase(cls, v: str):
+        assert v.isupper(), "Name must be uppercase"
+        return v
+
+resp = client.messages.create(
     model="claude-3-haiku-20240307",
     max_tokens=1024,
-    max_retries=0,
+    max_retries=2,
     messages=[
         {
             "role": "user",
@@ -28,7 +33,7 @@ class User(BaseModel):
 )  # type: ignore

 assert isinstance(resp, User)
-assert resp.name == "John"
+assert resp.name == "JOHN"  # due to validation
 assert resp.age == 18

@@ -42,7 +47,7 @@ class User(BaseModel):
     age: int
     address: Address

-resp = create(
+resp = client.messages.create(
     model="claude-3-haiku-20240307",
     max_tokens=1024,
     max_retries=0,
@@ -70,7 +75,7 @@ class User(BaseModel):
     age: int
     family: List[str]

-resp = create(
+resp = client.messages.create(
     model="claude-3-haiku-20240307",
     max_tokens=1024,
     max_retries=0,
@@ -98,7 +103,7 @@ class User(BaseModel):
     name: str
     role: Role

-resp = create(
+resp = client.messages.create(
     model="claude-3-haiku-20240307",
     max_tokens=1024,
     max_retries=0,
@@ -120,10 +125,10 @@ class User(BaseModel):
     name: str
     role: Literal["admin", "user"]

-resp = create(
+resp = client.messages.create(
     model="claude-3-haiku-20240307",
     max_tokens=1024,
-    max_retries=0,
+    max_retries=2,
     messages=[
         {
             "role": "user",
@@ -147,7 +152,7 @@ class User(BaseModel):
     age: int
     properties: List[Properties]

-resp = create(
+resp = client.messages.create(
     model="claude-3-haiku-20240307",
     max_tokens=1024,
     max_retries=0,
@@ -170,7 +175,7 @@ class User(BaseModel):
     name: str
     age: int

-resp = create(
+resp = client.messages.create(
     model="claude-3-haiku-20240307",
     max_tokens=1024,
     max_retries=0,
33 changes: 33 additions & 0 deletions tests/test_utils.py
@@ -4,6 +4,7 @@
     extract_json_from_codeblock,
     extract_json_from_stream,
     extract_json_from_stream_async,
+    merge_consecutive_messages,
 )

@@ -125,3 +126,35 @@ async def batch_strings_async(chunks, n=2):
         "key": "value",
         "another_key": [{"key": {"key": "value"}}, {"key": "value"}],
     }
+
+
+def test_merge_consecutive_messages():
+    messages = [
+        {"role": "user", "content": "Hello"},
+        {"role": "user", "content": "How are you"},
+        {"role": "assistant", "content": "Hello"},
+        {"role": "assistant", "content": "I am good"},
+    ]
+    result = merge_consecutive_messages(messages)
+    assert result == [
+        {"role": "user", "content": "Hello\n\nHow are you"},
+        {"role": "assistant", "content": "Hello\n\nI am good"},
+    ]
+
+
+def test_merge_consecutive_messages_empty():
+    messages = []
+    result = merge_consecutive_messages(messages)
+    assert result == []
+
+
+def test_merge_consecutive_messages_single():
+    messages = [
+        {"role": "user", "content": "Hello"},
+        {"role": "assistant", "content": "Hello"},
+    ]
+    result = merge_consecutive_messages(messages)
+    assert result == [
+        {"role": "user", "content": "Hello"},
+        {"role": "assistant", "content": "Hello"},
+    ]