Commit
add cerebrum example (#319)
* wip

* endpoint
BRama10 authored Nov 17, 2024
1 parent 3bec8b6 commit 467f138
Showing 12 changed files with 597 additions and 0 deletions.
83 changes: 83 additions & 0 deletions aios/llm_cores/adapter.py
@@ -0,0 +1,83 @@

from aios.llm_cores.base import BaseLLM
from aios.llm_cores.registry import API_MODEL_REGISTRY
# from .llm_classes.hf_native_llm import HfNativeLLM

# standard implementation of LLM methods
# from .llm_classes.ollama_llm import OllamaLLM
# from .llm_classes.vllm import vLLM

class LLMAdapter:
"""Parameters for LLMs
Args:
llm_name (str): Name of the LLMs
max_gpu_memory (dict, optional): Maximum GPU resources that can be allocated to the LLM. Defaults to None.
eval_device (str, optional): Evaluation device of binding LLM to designated devices for inference. Defaults to None.
max_new_tokens (int, optional): Maximum token length generated by the LLM. Defaults to 256.
log_mode (str, optional): Mode of logging the LLM processing status. Defaults to "console".
use_backend (str, optional): Backend to use for speeding up open-source LLMs. Defaults to None. Choices are ["vllm", "ollama"]
"""

    def __init__(self,
                 llm_name: str,
                 max_gpu_memory: dict = None,
                 eval_device: str = None,
                 max_new_tokens: int = 256,
                 use_backend: str = None
                 ):

        self.model: BaseLLM | None = None

        # For API-based LLMs
        if llm_name in API_MODEL_REGISTRY:
            self.model = API_MODEL_REGISTRY[llm_name](
                llm_name=llm_name,
            )
        # For locally-deployed LLMs
        else:
            if use_backend == "ollama" or llm_name.startswith("ollama"):
                # self.model = OllamaLLM(
                #     llm_name=llm_name,
                #     max_gpu_memory=max_gpu_memory,
                #     eval_device=eval_device,
                #     max_new_tokens=max_new_tokens,
                #     log_mode=log_mode
                # )
                pass

            elif use_backend == "vllm":
                # self.model = vLLM(
                #     llm_name=llm_name,
                #     max_gpu_memory=max_gpu_memory,
                #     eval_device=eval_device,
                #     max_new_tokens=max_new_tokens,
                #     log_mode=log_mode
                # )
                pass
            else:  # use a Hugging Face LLM without a serving backend
                # self.model = HfNativeLLM(
                #     llm_name=llm_name,
                #     max_gpu_memory=max_gpu_memory,
                #     eval_device=eval_device,
                #     max_new_tokens=max_new_tokens,
                #     log_mode=log_mode
                # )
                pass

    # def execute(self,
    #             agent_process,
    #             temperature=0.0) -> None:
    #     """Address a request sent from the agent
    #
    #     Args:
    #         agent_process: AgentProcess object that contains the request sent from the agent
    #         temperature (float, optional): Parameter to control the randomness of LLM output. Defaults to 0.0.
    #     """
    #     self.model.execute(agent_process, temperature)

    def get_model(self) -> BaseLLM | None:
        return self.model
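
For context, a minimal usage sketch of the adapter. The model name below is an assumed placeholder for whatever keys API_MODEL_REGISTRY actually contains in this commit; the local backends above are still stubbed out, so unregistered names leave the model as None.

from aios.llm_cores.adapter import LLMAdapter

# "gpt-4o-mini" is an assumed registry key for illustration; substitute any
# name present in API_MODEL_REGISTRY. Non-registry names currently fall
# through to the stubbed local backends and leave self.model as None.
adapter = LLMAdapter(llm_name="gpt-4o-mini", max_new_tokens=256)

model = adapter.get_model()
if model is None:
    raise RuntimeError("No backend is wired up for this llm_name yet")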



119 changes: 119 additions & 0 deletions aios/llm_cores/base.py
@@ -0,0 +1,119 @@
import json
import re

# abc lets us define abstract classes
from abc import ABC, abstractmethod

from aios.utils.id_generator import generator_tool_call_id
from pyopenagi.utils.chat_template import LLMQuery as Query


class BaseLLM(ABC):
    def __init__(self,
                 llm_name: str,
                 max_gpu_memory: dict = None,
                 eval_device: str = None,
                 max_new_tokens: int = 256,
                 ):
        self.max_gpu_memory = max_gpu_memory
        self.eval_device = eval_device
        self.max_new_tokens = max_new_tokens

        self.model_name = llm_name

        self.load_llm_and_tokenizer()

    def convert_map(self, mapping: dict) -> dict:
        """Helper utility to convert the keys of a mapping to int."""
        return {int(k): v for k, v in mapping.items()}

    def check_model_type(self, model_name):
        # TODO: add more model types
        return "causal_lm"


    @abstractmethod
    def load_llm_and_tokenizer(self) -> None:
        """Load the model and tokenizer for each type of LLM."""
        return

    # Only used for open-source LLMs.
    def tool_calling_input_format(self, messages: list, tools: list) -> list:
        """Integrate tool information into the messages for open-source LLMs.

        Args:
            messages (list): messages with different roles
            tools (list): tool information
        """
        prefix_prompt = "In and only in the current step, you need to call tools. Available tools are: "
        tool_prompt = json.dumps(tools)
        suffix_prompt = (
            'Must call functions that are available. To call a function, respond '
            'immediately and only with a list of JSON objects of the following format: '
            '[{"name":"function_name_value","parameters":{"parameter_name1":"parameter_value1",'
            '"parameter_name2":"parameter_value2"}}]'
        )

        # Translate tool-call messages for models that don't support tool calls.
        for message in messages:
            if "tool_calls" in message:
                message["content"] = json.dumps(message.pop("tool_calls"))
            elif message["role"] == "tool":
                message["role"] = "user"
                tool_call_id = message.pop("tool_call_id")
                content = message.pop("content")
                message["content"] = f"The result of the execution of function(id: {tool_call_id}) is: {content}. "

        messages[-1]["content"] += (prefix_prompt + tool_prompt + suffix_prompt)
        return messages

    def parse_json_format(self, message: str) -> str:
        """Extract the first decodable JSON array (or object) embedded in a raw model response."""
        json_array_pattern = r'\[\s*\{.*?\}\s*\]'
        json_object_pattern = r'\{\s*.*?\s*\}'

        match_array = re.search(json_array_pattern, message)

        if match_array:
            json_array_substring = match_array.group(0)

            try:
                json_array_data = json.loads(json_array_substring)
                return json.dumps(json_array_data)
            except json.JSONDecodeError:
                pass

        match_object = re.search(json_object_pattern, message)

        if match_object:
            json_object_substring = match_object.group(0)

            try:
                json_object_data = json.loads(json_object_substring)
                return json.dumps(json_object_data)
            except json.JSONDecodeError:
                pass
        return '[]'

    def parse_tool_calls(self, message):
        # Add a tool-call id and type for models that don't support tool calls.
        tool_calls = json.loads(self.parse_json_format(message))
        for tool_call in tool_calls:
            tool_call["id"] = generator_tool_call_id()
            tool_call["type"] = "function"
        return tool_calls

    def execute(self, query: Query):
        return self.process(query)

    @abstractmethod
    def process(self, query: Query):
        raise NotImplementedError
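
A short sketch of how the parsing helpers behave in isolation. The _DemoLLM stub and the raw completion string are hypothetical; only load_llm_and_tokenizer and process need to be stubbed to instantiate the abstract base.

from aios.llm_cores.base import BaseLLM

class _DemoLLM(BaseLLM):
    # Concrete stub so the parsing helpers can be exercised on their own.
    def load_llm_and_tokenizer(self) -> None:
        self.model, self.tokenizer = None, None

    def process(self, query):
        raise NotImplementedError

llm = _DemoLLM(llm_name="demo")

# The chatter around the JSON array is exactly what parse_json_format strips.
raw = 'Calling the tool now: [{"name": "get_weather", "parameters": {"city": "Paris"}}]'

calls = llm.parse_tool_calls(raw)
# -> [{"name": "get_weather", "parameters": {"city": "Paris"},
#      "id": "<generated>", "type": "function"}]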
141 changes: 141 additions & 0 deletions aios/llm_cores/providers/api/anthropic.py
@@ -0,0 +1,141 @@
import re
import json
import anthropic
from typing import List, Dict, Any

from cerebrum.llm.base import BaseLLM
from cerebrum.utils.chat import Query, Response

class ClaudeLLM(BaseLLM):
    """
    ClaudeLLM class for interacting with Anthropic's Claude models.

    This class provides methods for processing queries using Claude models,
    including handling of tool calls and message formatting.

    Attributes:
        model (anthropic.Anthropic): The Anthropic client for API calls.
        tokenizer (None): Placeholder for a tokenizer; not used in this implementation.
    """

    def __init__(self, llm_name: str,
                 max_gpu_memory: Dict[int, str] = None,
                 eval_device: str = None,
                 max_new_tokens: int = 256):
        """
        Initialize the ClaudeLLM instance.

        Args:
            llm_name (str): Name of the Claude model to use.
            max_gpu_memory (Dict[int, str], optional): GPU memory configuration.
            eval_device (str, optional): Device for evaluation.
            max_new_tokens (int, optional): Maximum number of new tokens to generate.
        """
        super().__init__(llm_name,
                         max_gpu_memory=max_gpu_memory,
                         eval_device=eval_device,
                         max_new_tokens=max_new_tokens)

    def load_llm_and_tokenizer(self) -> None:
        """
        Load the Anthropic client for API calls.
        """
        self.model = anthropic.Anthropic()
        self.tokenizer = None

    def process(self, query: Query):
        """
        Process a query using the Claude model.

        Args:
            query (Query): Query object containing the messages and tool information.

        Raises:
            AssertionError: If the model name doesn't contain 'claude'.
            anthropic.APIError: If there's an error with the Anthropic API call.
            Exception: For any other unexpected errors.
        """
        assert re.search(r'claude', self.model_name, re.IGNORECASE), "Model name must contain 'claude'"
        messages = query.messages
        tools = query.tools

        print(messages)

        if tools:
            messages = self.tool_calling_input_format(messages, tools)

        anthropic_messages = self._convert_to_anthropic_messages(messages)

        try:
            response = self.model.messages.create(
                model=self.model_name,
                messages=anthropic_messages,
                max_tokens=self.max_new_tokens,
                temperature=0.0
            )

            response_message = response.content[0].text
            # self.logger.log(f"API Response: {response_message}", level="info")
            tool_calls = self.parse_tool_calls(response_message) if tools else None

            return Response(
                response_message=response_message,
                tool_calls=tool_calls
            )
        except anthropic.APIError as e:
            error_message = f"Anthropic API error: {str(e)}"
            print(error_message)
            return Response(
                response_message=f"Error: {str(e)}",
                tool_calls=None
            )
        except Exception as e:
            error_message = f"Unexpected error: {str(e)}"
            print(error_message)
            return Response(
                response_message=f"Unexpected error: {str(e)}",
                tool_calls=None
            )


    def _convert_to_anthropic_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Convert messages to the format expected by the Anthropic API.

        Args:
            messages (List[Dict[str, str]]): Original messages.

        Returns:
            List[Dict[str, str]]: Converted messages for the Anthropic API.
        """
        anthropic_messages = []
        for message in messages:
            if message['role'] == 'system':
                anthropic_messages.append({"role": "user", "content": f"System: {message['content']}"})
                anthropic_messages.append({"role": "assistant", "content": "Understood. I will follow these instructions."})
            else:
                anthropic_messages.append({
                    "role": "user" if message['role'] == 'user' else "assistant",
                    "content": message['content']
                })
        return anthropic_messages

    def tool_calling_output_format(self, tool_calling_messages: str) -> List[Dict[str, Any]]:
        """
        Parse the tool-calling output from the model's response.

        Args:
            tool_calling_messages (str): The model's response containing tool calls.

        Returns:
            List[Dict[str, Any]]: Parsed tool calls, or the superclass fallback if direct parsing fails.
        """
        try:
            json_content = json.loads(tool_calling_messages)
            if isinstance(json_content, list) and len(json_content) > 0 and 'name' in json_content[0]:
                return json_content
        except json.JSONDecodeError:
            pass
        return super().tool_calling_output_format(tool_calling_messages)
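
And a hedged end-to-end sketch of the Claude path. It assumes ANTHROPIC_API_KEY is set in the environment, that Query accepts messages and tools keyword arguments (as suggested by process above), and that the model name is current; none of this is pinned down by the diff itself.

from aios.llm_cores.providers.api.anthropic import ClaudeLLM
from cerebrum.utils.chat import Query

# Assumes ANTHROPIC_API_KEY is exported; anthropic.Anthropic() picks it up
# from the environment. The model name below is illustrative.
llm = ClaudeLLM(llm_name="claude-3-5-sonnet-20241022", max_new_tokens=256)

# System messages are rewritten into a user/assistant preamble by
# _convert_to_anthropic_messages rather than sent as a system role.
query = Query(
    messages=[
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Name one Claude model."},
    ],
    tools=None,
)

response = llm.process(query)
print(response.response_message)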
Empty file.