
ChatBedrockConverse#stream not streaming response for model ids with cross region inference when bind_tools is used #239

Closed
renjiexu-amzn opened this issue Oct 14, 2024 · 3 comments · Fixed by #242

Comments

@renjiexu-amzn
Contributor

To reproduce: the following code produces the same behavior as invoke (the full response arrives at once); if the .bind_tools line is commented out, the response is streamed properly.

from langchain_aws import ChatBedrockConverse
from langchain_core.tools import tool

@tool(response_format="content_and_artifact")
def simple_calculator(a: int, b: int):
    """Use this tool to calcuate the sum of two integers.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        int: The sum of the two integers.
    """
    return a + b

llm = ChatBedrockConverse(
    model="us.anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=0,
    top_p=1,
    max_tokens=4096,
    region_name="us-west-2"
).bind_tools(tools=[simple_calculator])

a = llm.stream(
    input=[
        ("human", "Hello"),
    ],
)

full = next(a)

for x in a:
    print(x)
    full += x

print(full)
@langcarl langcarl bot added the investigate label Oct 14, 2024
@renjiexu-amzn
Contributor Author

renjiexu-amzn commented Oct 14, 2024

The root cause is that the logic for inferring the provider from the model/model ID doesn't handle cross-region inference profile IDs properly; for those IDs the provider is the second element after splitting on ".", not the first.
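
For illustration, a naive split on "." (a sketch of the inference logic, not the actual langchain-aws code) picks up the region prefix instead of the provider for these IDs:

model_id = "us.anthropic.claude-3-sonnet-20240229-v1:0"

print(model_id.split(".")[0])  # "us" -- the cross-region prefix, not a provider
print(model_id.split(".")[1])  # "anthropic" -- the actual provider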

The workaround is to explicitly pass the provider value when setting up ChatBedrockConverse:

from langchain_aws import ChatBedrockConverse
from langchain_core.tools import tool

@tool(response_format="content_and_artifact")
def simple_calculator(a: int, b: int):
    """Use this tool to calcuate the sum of two integers.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        int: The sum of the two integers.
    """
    return a + b

llm = ChatBedrockConverse(
    model="us.anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=0,
    top_p=1,
    max_tokens=4096,
    region_name="us-west-2",
    provider="anthropic"
).bind_tools(tools=[simple_calculator])

a = llm.stream(
    input=[
        ("human", "Hello"),
    ],
)

full = next(a)

for x in a:
    print(x)
    full += x

print(full)

@3coins 3coins changed the title ChatBedrockConverse#stream not streaming the response if has bind_tools ChatBedrockConverse#stream not streaming response for model ids with cross region inference when bind_tools is used Oct 15, 2024
@3coins
Collaborator

3coins commented Oct 15, 2024

@renjiexu-amzn
Thanks for reporting this issue. The Converse API has many different ways to specify a model: a mix of ARNs, foundation model IDs, inference profile IDs, and plain model IDs. While we can look at a long-term solution that supports and identifies each of these formats, a short-term fix to support inference profile IDs (without hard-coding regions) is to look at how many parts the model ID has. Here is a quick attempt at this:

def get_provider(model_id: str) -> str:
    parts = model_id.split(".")
    return parts[1] if len(parts) == 3 else parts[0]

assert "meta" == get_provider("meta.llama3-2-3b-instruct-v1:0") # mode id
assert "meta" == get_provider("us.meta.llama3-2-3b-instruct-v1:0") # inference profile id

Let me know if the above works for you, and if you want to open a PR to make the change.

An alternate solution could be to use the Bedrock API to get more info about the model. I am not sure whether the Bedrock API returns the provider info for all models, so we would have to verify that. This solution would also need some more consideration around calling the API only once, during initialization of the chat class.
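
As a rough sketch of that alternative (assuming the Bedrock control-plane GetFoundationModel API, which returns a providerName in modelDetails; whether it covers every model is the verification mentioned above):

import boto3

def lookup_provider(model_id: str, region_name: str) -> str:
    # Hypothetical helper; would be called once during chat class initialization.
    bedrock = boto3.client("bedrock", region_name=region_name)
    parts = model_id.split(".")
    # Strip a cross-region prefix such as "us." so GetFoundationModel receives a foundation model ID.
    base_model_id = ".".join(parts[1:]) if len(parts) == 3 else model_id
    details = bedrock.get_foundation_model(modelIdentifier=base_model_id)
    return details["modelDetails"]["providerName"].lower()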

@sidatcd

sidatcd commented Nov 18, 2024

These inference profile model IDs are validated by their region prefixes.
There is one available for APAC now. Can we include that in the validation list? A sketch of what that could look like is below.
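
A minimal sketch (hypothetical names; langchain-aws may not keep an explicit prefix list like this):

# Hypothetical prefix list; "apac" added per this comment.
KNOWN_INFERENCE_PROFILE_PREFIXES = {"us", "eu", "apac"}

def get_provider(model_id: str) -> str:
    parts = model_id.split(".")
    # Cross-region inference profile IDs look like "<region-prefix>.<provider>.<model>".
    if len(parts) > 2 and parts[0] in KNOWN_INFERENCE_PROFILE_PREFIXES:
        return parts[1]
    return parts[0]

assert get_provider("apac.anthropic.claude-3-sonnet-20240229-v1:0") == "anthropic"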
