feat: llama-3.2 on bedrock
This is initial support for Llama 3.2 90B vision instruct model!

For such a big model, it's very hard to run locally while meeting all of
Alumnium's requirements (tool calling, structured output, multimodal
input). For the time being, AWS Bedrock is a provider that has proven to
work well in this initial implementation.

There are a few things to keep in mind in this initial implementation:
1. Tool calling types are less strict (e.g. it's common for the model to
   return str instead of int/bool). Pydantic coercion helps with this
   (see the sketch after this list).
2. Vision is disabled for now - when the model is used with both an image
   and structured output, the latter does not work. This can probably be
   worked around with custom response parsing, but that is left for the
   future (maybe AWS will fix it eventually).
3. Images need to be resized to a maximum of 1120x1120, but this is not
   implemented yet due to the previous point (a sketch of the resize
   follows below).
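
A minimal sketch of the coercion from point 1, using a hypothetical
ClickTool (not Alumnium's actual tool definition):

    from pydantic import BaseModel

    class ClickTool(BaseModel):
        id: int
        double: bool = False

    # Llama 3.2 tends to return every argument as a string...
    args = {"id": "42", "double": "true"}

    # ...but Pydantic coerces them to the declared types on instantiation.
    tool = ClickTool(**args)
    assert tool.id == 42 and tool.double is True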

It would be great to use Ollama or Llama.cpp to support true local
inference. This commit, however, proves that Alumnium can be used with
open models!
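
For point 3, a sketch of what the eventual resize could look like
(hypothetical helper, assuming Pillow; not part of this commit):

    from PIL import Image

    MAX_SIZE = (1120, 1120)  # Bedrock Llama 3.2 vision limit

    def downscale_for_llama(path: str) -> Image.Image:
        image = Image.open(path)
        # thumbnail() resizes in place, preserves the aspect ratio,
        # and never upscales smaller images.
        image.thumbnail(MAX_SIZE)
        return image
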
p0deje authored and sh3pik committed Nov 22, 2024
1 parent aa16d26 commit ce07ce8
Showing 6 changed files with 21 additions and 14 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -30,6 +30,7 @@ jobs:
       matrix:
         model:
           - aws_anthropic
+          - aws_meta
           - azure_openai
     steps:
       - uses: actions/checkout@v4
20 changes: 10 additions & 10 deletions alumnium/agents/actor_agent.py
@@ -41,16 +41,16 @@ def invoke(self, goal: str):
         logger.info(f" <- Usage: {message.usage_metadata}")
 
         # Move to tool itself to avoid hardcoding its parameters.
-        for tool in message.tool_calls:
-            args = tool.get("args", {}).copy()
-            if "id" in args:
-                args["id"] = aria.cached_ids[args["id"]]
-            if "from_id" in args:
-                args["from_id"] = aria.cached_ids[args["from_id"]]
-            if "to_id" in args:
-                args["to_id"] = aria.cached_ids[args["to_id"]]
-
-            ALL_TOOLS[tool["name"]](**args).invoke(self.driver)
+        for tool_call in message.tool_calls:
+            tool = ALL_TOOLS[tool_call["name"]](**tool_call["args"])
+            if "id" in tool.model_fields_set:
+                tool.id = aria.cached_ids[tool.id]
+            if "from_id" in tool.model_fields_set:
+                tool.from_id = aria.cached_ids[tool.from_id]
+            if "to_id" in tool.model_fields_set:
+                tool.to_id = aria.cached_ids[tool.to_id]
+
+            tool.invoke(self.driver)
 
     @lru_cache()
     def __prompt(self, goal: str, aria: str):
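
The id remapping in the new loop relies on Pydantic v2's
model_fields_set, which contains only the fields the model explicitly
passed, so defaulted fields are never remapped. A minimal sketch with a
hypothetical DragTool (Alumnium's real tools may differ):

    from pydantic import BaseModel

    class DragTool(BaseModel):
        from_id: int = 0
        to_id: int = 0

    tool = DragTool(from_id=3)
    # Only explicitly passed fields show up, so the loop above would
    # remap from_id but leave the defaulted to_id alone.
    print(tool.model_fields_set)  # {'from_id'}
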
8 changes: 4 additions & 4 deletions alumnium/alumni.py
@@ -2,7 +2,7 @@
 from os import getenv
 
 from langchain_anthropic import ChatAnthropic
-from langchain_aws import ChatBedrock
+from langchain_aws import ChatBedrockConverse
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from langchain_google_genai import ChatGoogleGenerativeAI
 
@@ -35,10 +35,10 @@ def __init__(
             )
         elif model == Model.ANTHROPIC:
             llm = ChatAnthropic(model=model.value, temperature=0, max_retries=2)
-        elif model == Model.AWS_ANTHROPIC:
-            llm = ChatBedrock(
+        elif model == Model.AWS_ANTHROPIC or model == Model.AWS_META:
+            llm = ChatBedrockConverse(
                 model_id=model.value,
-                model_kwargs={"temperature": 0},
+                temperature=0,
                 aws_access_key_id=getenv("AWS_ACCESS_KEY", ""),
                 aws_secret_access_key=getenv("AWS_SECRET_KEY", ""),
                 region_name=getenv("AWS_REGION_NAME", "us-east-1"),
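
A quick way to smoke-test the new Bedrock model outside Alumnium (a
sketch assuming AWS credentials are configured in the environment; the
prompt is arbitrary):

    from langchain_aws import ChatBedrockConverse

    llm = ChatBedrockConverse(
        model_id="us.meta.llama3-2-90b-instruct-v1:0",
        temperature=0,
        region_name="us-east-1",
    )
    print(llm.invoke("Reply with the single word: pong").content)
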
1 change: 1 addition & 0 deletions alumnium/models.py
@@ -6,6 +6,7 @@ class Model(Enum):
     AZURE_OPENAI = "gpt-4o-mini"  # 2024-07-18
     ANTHROPIC = "claude-3-haiku-20240307"
     AWS_ANTHROPIC = "anthropic.claude-3-haiku-20240307-v1:0"
+    AWS_META = "us.meta.llama3-2-90b-instruct-v1:0"
     GOOGLE = "gemini-1.5-flash-002"
     OPENAI = "gpt-4o-mini-2024-07-18"
 
1 change: 1 addition & 0 deletions examples/pytest/calculator_test.py
@@ -3,6 +3,7 @@
 
 
 @mark.xfail(Model.load() == Model.AWS_ANTHROPIC, reason="Bedrock version of Haiku is subpar")
+@mark.xfail(Model.load() == Model.AWS_META, reason="It is too hard for Llama 3.2")
 def test_addition(al, driver):
     driver.get("https://seleniumbase.io/apps/calculator")
     al.do("1 + 1 =")
4 changes: 4 additions & 0 deletions examples/pytest/drag_and_drop_test.py
@@ -3,6 +3,10 @@
 
 
 @mark.xfail(Model.load() == Model.AWS_ANTHROPIC, reason="Bedrock version of Haiku is subpar")
+@mark.xfail(
+    Model.load() == Model.AWS_META,
+    reason="Bedrock Llama 3.2 doesn't support vision and structured output at the same time",
+)
 def test_drag_and_drop(al, driver):
     driver.get("https://the-internet.herokuapp.com/drag_and_drop")
     al.check("square A is positioned to the left of square B", vision=True)
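
The limitation behind this xfail (point 2 in the commit message) can be
sketched directly against LangChain. The Verdict schema and screenshot
path are hypothetical, and this illustrates the failing combination
rather than a confirmed reproduction:

    import base64

    from langchain_aws import ChatBedrockConverse
    from langchain_core.messages import HumanMessage
    from pydantic import BaseModel

    class Verdict(BaseModel):
        result: bool

    llm = ChatBedrockConverse(model_id="us.meta.llama3-2-90b-instruct-v1:0")
    structured = llm.with_structured_output(Verdict)

    with open("screenshot.png", "rb") as f:
        b64 = base64.b64encode(f.read()).decode()

    message = HumanMessage(
        content=[
            {"type": "text", "text": "Is square A left of square B?"},
            {"type": "image_url",
             "image_url": {"url": f"data:image/png;base64,{b64}"}},
        ]
    )

    # On Bedrock Llama 3.2, attaching the image breaks the structured
    # output, which is why vision stays disabled for this model.
    structured.invoke([message])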
