Commit
update LLM callings
dongyuanjushi committed Nov 25, 2024
1 parent 90b358e commit dae66a2
Showing 6 changed files with 728 additions and 3 deletions.
8 changes: 6 additions & 2 deletions aios/llm_core/adapter.py
@@ -1,7 +1,7 @@
-
 from aios.llm_core.cores.base import BaseLLM
 from aios.llm_core.registry import MODEL_REGISTRY
 
+from aios.llm_core.cores.local.ollama import OllamaLLM
 
 class LLMAdapter:
     """Parameters for LLMs
@@ -37,7 +37,11 @@ def __init__(self,
         # For locally-deployed LLM
         else:
             if use_backend == "ollama" or llm_name.startswith("ollama"):
-                pass
+                self.model = OllamaLLM(
+                    llm_name=llm_name,
+                    log_mode=log_mode,
+                    use_context_manager=use_context_manager
+                )
                 # ollama here
             elif use_backend == "vllm":
                 # VLLM here
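For context, a minimal sketch of how this new branch might be exercised. This is a hedged example: the model name is invented, and the LLMAdapter constructor arguments beyond those visible in the diff (llm_name, log_mode, use_context_manager, use_backend) are assumptions.

    from aios.llm_core.adapter import LLMAdapter

    # With an ollama-prefixed name, the adapter now constructs an OllamaLLM
    # backend instead of silently passing.
    adapter = LLMAdapter(
        llm_name="ollama/llama3",  # hypothetical model name
        log_mode="console",
        use_context_manager=False,
    )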
213 changes: 213 additions & 0 deletions aios/llm_core/cores/api/anthropic.py
@@ -0,0 +1,213 @@
import re
import json
import time
import anthropic
from typing import List, Dict, Any

from aios.llm_core.cores.base import BaseLLM

from cerebrum.llm.communication import Response


class ClaudeLLM(BaseLLM):
    """
    ClaudeLLM class for interacting with Anthropic's Claude models.

    This class provides methods for processing queries using Claude models,
    including handling of tool calls and message formatting.

    Attributes:
        model (anthropic.Anthropic): The Anthropic client for API calls.
        tokenizer (None): Placeholder for a tokenizer; not used in this implementation.
    """

    def __init__(
        self,
        llm_name: str,
        max_gpu_memory: Dict[int, str] = None,
        eval_device: str = None,
        max_new_tokens: int = 256,
        log_mode: str = "console",
        use_context_manager: bool = False,
    ):
        """
        Initialize the ClaudeLLM instance.

        Args:
            llm_name (str): Name of the Claude model to use.
            max_gpu_memory (Dict[int, str], optional): GPU memory configuration.
            eval_device (str, optional): Device for evaluation.
            max_new_tokens (int, optional): Maximum number of new tokens to generate.
            log_mode (str, optional): Logging mode, defaults to "console".
            use_context_manager (bool, optional): Whether to use the context manager, defaults to False.
        """
        super().__init__(
            llm_name,
            max_gpu_memory=max_gpu_memory,
            eval_device=eval_device,
            max_new_tokens=max_new_tokens,
            log_mode=log_mode,
            use_context_manager=use_context_manager,
        )

    def load_llm_and_tokenizer(self) -> None:
        """
        Load the Anthropic client for API calls.
        """
        self.model = anthropic.Anthropic()
        self.tokenizer = None

    def convert_tools(self, tools):
        """
        Convert OpenAI-style tool definitions to Anthropic's format by
        renaming "parameters" to "input_schema". Note: mutates the input dicts.
        """
        anthropic_tools = []
        for tool in tools:
            anthropic_tool = tool["function"]
            anthropic_tool["input_schema"] = anthropic_tool["parameters"]
            anthropic_tool.pop("parameters")
            anthropic_tools.append(anthropic_tool)
        return anthropic_tools
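    # For illustration, a hypothetical OpenAI-style tool definition such as
    #   {"type": "function", "function": {"name": "get_weather",
    #    "description": "Look up weather", "parameters": {"type": "object", "properties": {}}}}
    # would be converted to Anthropic's expected shape:
    #   {"name": "get_weather", "description": "Look up weather",
    #    "input_schema": {"type": "object", "properties": {}}}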

    def address_syscall(self, llm_syscall, temperature: float = 0.0) -> Response:
        """
        Process an LLM system call using the Claude model.

        Args:
            llm_syscall: The system call containing the query messages and tools.
            temperature (float, optional): Sampling temperature for generation.

        Returns:
            Response: The model's response, including any parsed tool calls.

        Raises:
            AssertionError: If the model name doesn't contain 'claude'.
        """
        assert re.search(
            r"claude", self.model_name, re.IGNORECASE
        ), "Model name must contain 'claude'"
        llm_syscall.set_status("executing")
        llm_syscall.set_start_time(time.time())
        messages = llm_syscall.query.messages
        tools = llm_syscall.query.tools

        self.logger.log(f"{messages}", level="info")
        self.logger.log(
            f"{llm_syscall.agent_name} is switched to executing.", level="executing"
        )

        if tools:
            tools = self.convert_tools(tools)

        anthropic_messages = self._convert_to_anthropic_messages(messages)
        self.logger.log(f"{anthropic_messages}", level="info")

        try:
            response = self.model.messages.create(
                model=self.model_name,
                messages=anthropic_messages,
                max_tokens=self.max_new_tokens,
                temperature=temperature,
                # tools=tools,  # converted tools are not yet passed to the API
            )

            response_message = response.content[0].text
            self.logger.log(f"API Response: {response_message}", level="info")
            tool_calls = self.parse_tool_calls(response_message) if tools else None

            response = Response(
                response_message=response_message, tool_calls=tool_calls
            )
        except anthropic.APIError as e:
            error_message = f"Anthropic API error: {str(e)}"
            self.logger.log(error_message, level="warning")
            response = Response(response_message=f"Error: {str(e)}", tool_calls=None)
        except Exception as e:
            error_message = f"Unexpected error: {str(e)}"
            self.logger.log(error_message, level="warning")
            response = Response(
                response_message=f"Unexpected error: {str(e)}", tool_calls=None
            )

        return response

    def _convert_to_anthropic_messages(
        self, messages: List[Dict[str, str]]
    ) -> List[Dict[str, str]]:
        """
        Convert messages to the format expected by the Anthropic API.

        Args:
            messages (List[Dict[str, str]]): Original messages.

        Returns:
            List[Dict[str, str]]: Converted messages for the Anthropic API.
        """
        anthropic_messages = []
        for message in messages:
            if message["role"] == "system":
                # The Anthropic messages list does not accept a "system" role,
                # so it is emulated with a user/assistant exchange.
                anthropic_messages.append(
                    {"role": "user", "content": f"System: {message['content']}"}
                )
                anthropic_messages.append(
                    {
                        "role": "assistant",
                        "content": "Understood. I will follow these instructions.",
                    }
                )
            else:
                anthropic_messages.append(
                    {
                        "role": "user" if message["role"] == "user" else "assistant",
                        "content": message["content"],
                    }
                )
        return anthropic_messages
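    # For example, a hypothetical input
    #   [{"role": "system", "content": "Be terse."},
    #    {"role": "user", "content": "Hi"}]
    # becomes
    #   [{"role": "user", "content": "System: Be terse."},
    #    {"role": "assistant", "content": "Understood. I will follow these instructions."},
    #    {"role": "user", "content": "Hi"}]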

    def tool_calling_output_format(
        self, tool_calling_messages: str
    ) -> List[Dict[str, Any]]:
        """
        Parse the tool calling output from the model's response.

        Args:
            tool_calling_messages (str): The model's response containing tool calls.

        Returns:
            List[Dict[str, Any]]: Parsed tool calls, or None if parsing fails.
        """
        try:
            json_content = json.loads(tool_calling_messages)
            if (
                isinstance(json_content, list)
                and len(json_content) > 0
                and "name" in json_content[0]
            ):
                return json_content
        except json.JSONDecodeError:
            pass
        return super().tool_calling_output_format(tool_calling_messages)
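For reference, a minimal sketch of the happy path through tool_calling_output_format (the tool name and arguments below are invented for illustration):

    import json

    raw = '[{"name": "get_weather", "parameters": {"city": "Paris"}}]'
    content = json.loads(raw)
    # A non-empty JSON list whose first element has a "name" key is returned
    # as-is; any other shape falls back to BaseLLM's generic parser.
    assert isinstance(content, list) and "name" in content[0]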