Skip to content

Commit

Permalink
[Core] implement redis cache mode (#1222)
Browse files Browse the repository at this point in the history
* Implement Redis cache mode: if `redis_url` is set in the `llm_config`, it
will be used. Also adds a test to validate both the existing
and the Redis cache behavior.

* PR feedback, add unit tests

* more PR feedback, move the new style cache to a context manager

* Update agent_chat.md

* more PR feedback, remove tests from contrib and have them run with the normal jobs

* doc

* updated

* Update website/docs/Use-Cases/agent_chat.md

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* update docs

* update docs; let openaiwrapper to use cache object

* typo

* Update website/docs/Use-Cases/enhanced_inference.md

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* save previous client cache and reset it after send/a_send

* a_run_chat

---------

Co-authored-by: Vijay Ramesh <vijay@regrello.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
  • Loading branch information
4 people authored Jan 20, 2024
1 parent e97b639 commit ee6ad8d
Show file tree
Hide file tree
Showing 21 changed files with 1,149 additions and 17 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
if: matrix.python-version == '3.10'
run: |
pip install -e .[test]
pip install -e .[redis]
coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai
coverage xml
- name: Upload coverage to Codecov
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/openai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ jobs:
python-version: ["3.9", "3.10", "3.11", "3.12"]
runs-on: ${{ matrix.os }}
environment: openai1
services:
redis:
image: redis
ports:
- 6379:6379
options: --entrypoint redis-server
steps:
# checkout to pr branch
- name: Checkout
Expand All @@ -42,6 +48,7 @@ jobs:
if: matrix.python-version == '3.9'
run: |
pip install docker
pip install -e .[redis]
- name: Coverage
if: matrix.python-version == '3.9'
env:
Expand Down
23 changes: 22 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union

from .. import OpenAIWrapper
from ..cache.cache import Cache
from ..code_utils import (
DEFAULT_MODEL,
UNKNOWN,
Expand Down Expand Up @@ -135,6 +136,9 @@ def __init__(
self.llm_config.update(llm_config)
self.client = OpenAIWrapper(**self.llm_config)

# Initialize standalone client cache object.
self.client_cache = None

self._code_execution_config: Union[Dict, Literal[False]] = (
{} if code_execution_config is None else code_execution_config
)
Expand Down Expand Up @@ -665,6 +669,7 @@ def initiate_chat(
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
**context,
):
"""Initiate a chat with the recipient agent.
Expand All @@ -677,6 +682,7 @@ def initiate_chat(
recipient: the recipient agent.
clear_history (bool): whether to clear the chat history with the agent.
silent (bool or None): (Experimental) whether to print the messages for this conversation.
cache (Cache or None): the cache client to be used for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
Expand All @@ -686,14 +692,20 @@ def initiate_chat(
"""
for agent in [self, recipient]:
agent._raise_exception_on_async_reply_functions()
agent.previous_cache = agent.client_cache
agent.client_cache = cache
self._prepare_chat(recipient, clear_history)
self.send(self.generate_init_message(**context), recipient, silent=silent)
for agent in [self, recipient]:
agent.client_cache = agent.previous_cache
agent.previous_cache = None

async def a_initiate_chat(
self,
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
**context,
):
"""(async) Initiate a chat with the recipient agent.
Expand All @@ -706,12 +718,19 @@ async def a_initiate_chat(
recipient: the recipient agent.
clear_history (bool): whether to clear the chat history with the agent.
silent (bool or None): (Experimental) whether to print the messages for this conversation.
cache (Cache or None): the cache client to be used for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
"""
self._prepare_chat(recipient, clear_history)
for agent in [self, recipient]:
agent.previous_cache = agent.client_cache
agent.client_cache = cache
await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
for agent in [self, recipient]:
agent.client_cache = agent.previous_cache
agent.previous_cache = None

def reset(self):
"""Reset the agent."""
Expand Down Expand Up @@ -778,7 +797,9 @@ def generate_oai_reply(

# TODO: #1143 handle token limit exceeded error
response = client.create(
context=messages[-1].pop("context", None), messages=self._oai_system_message + all_messages
context=messages[-1].pop("context", None),
messages=self._oai_system_message + all_messages,
cache=self.client_cache,
)

extracted_response = client.extract_text_or_completion_object(response)[0]
Expand Down
18 changes: 17 additions & 1 deletion autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,17 @@ def run_chat(
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[GroupChat] = None,
) -> Union[str, Dict, None]:
) -> Tuple[bool, Optional[str]]:
"""Run a group chat."""
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
speaker = sender
groupchat = config
if self.client_cache is not None:
for a in groupchat.agents:
a.previous_cache = a.client_cache
a.client_cache = self.client_cache
for i in range(groupchat.max_round):
groupchat.append(message, speaker)
if self._is_termination_msg(message):
Expand Down Expand Up @@ -389,6 +393,10 @@ def run_chat(
message = self.last_message(speaker)
if i == groupchat.max_round - 1:
groupchat.append(message, speaker)
if self.client_cache is not None:
for a in groupchat.agents:
a.client_cache = a.previous_cache
a.previous_cache = None
return True, None

async def a_run_chat(
Expand All @@ -403,6 +411,10 @@ async def a_run_chat(
message = messages[-1]
speaker = sender
groupchat = config
if self.client_cache is not None:
for a in groupchat.agents:
a.previous_cache = a.client_cache
a.client_cache = self.client_cache
for i in range(groupchat.max_round):
groupchat.append(message, speaker)

Expand Down Expand Up @@ -436,6 +448,10 @@ async def a_run_chat(
# The speaker sends the message without requesting a reply
await speaker.a_send(reply, self, request_reply=False)
message = self.last_message(speaker)
if self.client_cache is not None:
for a in groupchat.agents:
a.client_cache = a.previous_cache
a.previous_cache = None
return True, None

def _raise_exception_on_async_reply_functions(self) -> None:
Expand Down
Empty file added autogen/cache/__init__.py
Empty file.
90 changes: 90 additions & 0 deletions autogen/cache/abstract_cache_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod


class AbstractCache(ABC):
    """
    Interface that every cache backend must implement.

    Concrete subclasses (for example Redis- or disk-backed caches) supply
    the actual storage behind these operations. Implementations are also
    expected to behave as context managers so that resources are released
    deterministically when used in a ``with`` statement.
    """

    @abstractmethod
    def get(self, key, default=None):
        """
        Look up an item in the cache.
        Subclasses must implement the actual retrieval logic.
        Args:
            key (str): The key identifying the cached item.
            default (optional): Value returned when the key is absent.
                Defaults to None.
        Returns:
            The cached value for the key, or the default when missing.
        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """

    @abstractmethod
    def set(self, key, value):
        """
        Store an item in the cache.
        Subclasses must implement the actual storage logic.
        Args:
            key (str): The key under which the value is stored.
            value: The value to store.
        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """

    @abstractmethod
    def close(self):
        """
        Release any resources held by the cache.
        Subclasses should perform whatever cleanup is required, such as
        closing network connections or freeing file handles.
        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """

    @abstractmethod
    def __enter__(self):
        """
        Enter the runtime context for this cache.
        The return value is bound to the target of the ``as`` clause of the
        enclosing ``with`` statement, if one is given.
        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """

    @abstractmethod
    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit the runtime context and clean up the cache.
        Subclasses are responsible for releasing resources here so that a
        ``with`` block always leaves the cache properly closed.
        Args:
            exc_type: The exception type if an exception was raised in the context.
            exc_value: The exception value if an exception was raised in the context.
            traceback: The traceback if an exception was raised in the context.
        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """
137 changes: 137 additions & 0 deletions autogen/cache/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
from typing import Dict, Any

from autogen.cache.cache_factory import CacheFactory


class Cache:
    """
    A wrapper class for managing cache configuration and instances.
    This class provides a unified interface for creating and interacting with
    different types of cache (e.g., Redis, Disk). It abstracts the underlying
    cache implementation details, providing methods for cache operations.
    Attributes:
        config (Dict[str, Any]): A dictionary containing cache configuration.
        cache: The cache instance created based on the provided configuration.
    Methods:
        redis(cache_seed=42, redis_url="redis://localhost:6379/0"): Static method to create a Redis cache instance.
        disk(cache_seed=42, cache_path_root=".cache"): Static method to create a Disk cache instance.
        __init__(self, config): Initializes the Cache with the given configuration.
        __enter__(self): Context management entry, returning the cache instance.
        __exit__(self, exc_type, exc_value, traceback): Context management exit.
        get(self, key, default=None): Retrieves an item from the cache.
        set(self, key, value): Sets an item in the cache.
        close(self): Closes the cache.
    """

    # Only these keys are accepted in the config dict passed to __init__.
    ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]

    @staticmethod
    def redis(cache_seed=42, redis_url="redis://localhost:6379/0"):
        """
        Create a Redis cache instance.
        Args:
            cache_seed (int, optional): A seed for the cache. Defaults to 42.
            redis_url (str, optional): The URL for the Redis server. Defaults to "redis://localhost:6379/0".
        Returns:
            Cache: A Cache instance configured for Redis.
        """
        return Cache({"cache_seed": cache_seed, "redis_url": redis_url})

    @staticmethod
    def disk(cache_seed=42, cache_path_root=".cache"):
        """
        Create a Disk cache instance.
        Args:
            cache_seed (int, optional): A seed for the cache. Defaults to 42.
            cache_path_root (str, optional): The root path for the disk cache. Defaults to ".cache".
        Returns:
            Cache: A Cache instance configured for Disk caching.
        """
        return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root})

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the Cache with the given configuration.
        Validates the configuration keys and creates the cache instance.
        Args:
            config (Dict[str, Any]): A dictionary containing the cache configuration.
        Raises:
            ValueError: If an invalid configuration key is provided.
        """
        self.config = config
        # validate config: reject unknown keys early instead of silently ignoring them
        for key in self.config.keys():
            if key not in self.ALLOWED_CONFIG_KEYS:
                raise ValueError(f"Invalid config key: {key}")
        # create cache instance
        # Default seed is the integer 42, matching the documented default and the
        # defaults used by redis()/disk(); previously the string "42" was used here,
        # which was inconsistent with those factory methods.
        self.cache = CacheFactory.cache_factory(
            self.config.get("cache_seed", 42),
            self.config.get("redis_url", None),
            self.config.get("cache_path_root", None),
        )

    def __enter__(self):
        """
        Enter the runtime context related to the cache object.
        Returns:
            The underlying cache instance (not this wrapper) for use within a
            context block.
        """
        return self.cache.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit the runtime context related to the cache object.
        Delegates to the underlying cache so it can release resources and
        decide how to handle any exception raised within the context.
        Args:
            exc_type: The exception type if an exception was raised in the context.
            exc_value: The exception value if an exception was raised in the context.
            traceback: The traceback if an exception was raised in the context.
        """
        return self.cache.__exit__(exc_type, exc_value, traceback)

    def get(self, key, default=None):
        """
        Retrieve an item from the cache.
        Args:
            key (str): The key identifying the item in the cache.
            default (optional): The default value to return if the key is not found.
                Defaults to None.
        Returns:
            The value associated with the key if found, else the default value.
        """
        return self.cache.get(key, default)

    def set(self, key, value):
        """
        Set an item in the cache.
        Args:
            key (str): The key under which the item is to be stored.
            value: The value to be stored in the cache.
        """
        self.cache.set(key, value)

    def close(self):
        """
        Close the cache.
        Perform any necessary cleanup, such as closing connections or releasing resources.
        """
        self.cache.close()
Loading

0 comments on commit ee6ad8d

Please sign in to comment.