Improve redis semantic cache implementation #5412

Merged
Changes from 2 commits
1 change: 1 addition & 0 deletions .circleci/requirements.txt
@@ -5,6 +5,7 @@ tiktoken
importlib_metadata
cohere
redis
redisvl==0.3.2
anthropic
orjson==3.9.15
pydantic==2.7.1
3 changes: 0 additions & 3 deletions Dockerfile
@@ -35,9 +35,6 @@ RUN pip install dist/*.whl
# install dependencies as wheels
RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt

# install semantic-cache [Experimental]- we need this here and not in requirements.txt because redisvl pins to pydantic 1.0
RUN pip install redisvl==0.0.7 --no-deps

# ensure pyjwt is used, not jwt
RUN pip uninstall jwt -y
RUN pip uninstall PyJWT -y
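
Note: the separate `pip install redisvl==0.0.7 --no-deps` step is dropped here (and in the other two Dockerfiles below) because redisvl 0.3.2 is now pinned in `.circleci/requirements.txt` and no longer pins pydantic 1.x. A quick sanity check, as a hypothetical snippet that is not part of this PR:

```python
# Hypothetical check: with redisvl 0.3.x the old pydantic-1 pin is gone, so both
# packages can be installed together from requirements.txt without --no-deps.
from importlib.metadata import version

print(version("redisvl"))   # expected: 0.3.2 (pin added in this PR)
print(version("pydantic"))  # expected: 2.7.1 (existing pin in requirements)
```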
3 changes: 0 additions & 3 deletions Dockerfile.database
@@ -50,9 +50,6 @@ COPY --from=builder /wheels/ /wheels/
# Install the built wheel using pip; again using a wildcard if it's the only file
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels

# install semantic-cache [Experimental]- we need this here and not in requirements.txt because redisvl pins to pydantic 1.0
RUN pip install redisvl==0.0.7 --no-deps

# ensure pyjwt is used, not jwt
RUN pip uninstall jwt -y
RUN pip uninstall PyJWT -y
3 changes: 0 additions & 3 deletions Dockerfile.non_root
@@ -50,9 +50,6 @@ COPY --from=builder /wheels/ /wheels/
# Install the built wheel using pip; again using a wildcard if it's the only file
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels

# install semantic-cache [Experimental]- we need this here and not in requirements.txt because redisvl pins to pydantic 1.0
RUN pip install redisvl==0.0.7 --no-deps

# ensure pyjwt is used, not jwt
RUN pip uninstall jwt -y
RUN pip uninstall PyJWT -y
2 changes: 1 addition & 1 deletion docs/my-website/docs/caching/all_caches.md
@@ -90,7 +90,7 @@ response2 = completion(

Install redis
```shell
pip install redisvl==0.0.7
pip install redisvl==0.3.2
```

For the hosted version you can setup your own Redis DB here: https://app.redislabs.com/
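
For reference, a minimal sketch of constructing the updated cache class directly, using the `RedisSemanticCache.__init__` parameters from this PR (connection values are placeholders, not part of the change):

```python
# Minimal sketch (not part of this PR): construct the semantic cache directly.
# Connection details are placeholders; similarity_threshold is required and is
# converted internally to distance_threshold = 1 - similarity_threshold.
from litellm.caching import RedisSemanticCache

semantic_cache = RedisSemanticCache(
    host="localhost",                   # placeholder
    port="6379",                        # placeholder; concatenated into a redis:// URL
    password="<your-redis-password>",   # placeholder
    similarity_threshold=0.8,
    embedding_model="text-embedding-ada-002",    # default shown in the diff
    index_name="litellm_semantic_cache_index",   # optional; this is the default
)
```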
223 changes: 82 additions & 141 deletions litellm/caching.py
@@ -913,6 +913,8 @@ def delete_cache(self, key):


class RedisSemanticCache(BaseCache):
DEFAULT_REDIS_INDEX_NAME = "litellm_semantic_cache_index"

def __init__(
self,
host=None,
@@ -922,38 +924,26 @@ def __init__(
similarity_threshold=None,
use_async=False,
embedding_model="text-embedding-ada-002",
index_name=None,
**kwargs,
):
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import CustomTextVectorizer

if index_name is None:
index_name = self.DEFAULT_REDIS_INDEX_NAME

print_verbose(
"redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
f"redis semantic-cache initializing INDEX - {index_name}"
)

if similarity_threshold is None:
raise Exception("similarity_threshold must be provided, passed None")

self.similarity_threshold = similarity_threshold
self.distance_threshold = 1 - similarity_threshold
self.embedding_model = embedding_model
schema = {
"index": {
"name": "litellm_semantic_cache_index",
"prefix": "litellm",
"storage_type": "hash",
},
"fields": {
"text": [{"name": "response"}],
"text": [{"name": "prompt"}],
"vector": [
{
"name": "litellm_embedding",
"dims": 1536,
"distance_metric": "cosine",
"algorithm": "flat",
"datatype": "float32",
}
],
},
}

if redis_url is None:
# if no url passed, check if host, port and password are passed, if not raise an Exception
if host is None or port is None or password is None:
@@ -967,20 +957,29 @@ def __init__(
raise Exception("Redis host, port, and password must be provided")

redis_url = "redis://:" + password + "@" + host + ":" + port

print_verbose(f"redis semantic-cache redis_url: {redis_url}")
if use_async == False:
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url)
try:
self.index.create(overwrite=False) # don't overwrite existing index
except Exception as e:
print_verbose(f"Got exception creating semantic cache index: {str(e)}")
elif use_async == True:
schema["index"]["name"] = "litellm_semantic_cache_index_async"
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url, use_async=True)

#
def generate_cache_embeddings(prompt: str):
# create an embedding from prompt
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
embedding = embedding_response["data"][0]["embedding"]
return embedding

cache_vectorizer = CustomTextVectorizer(generate_cache_embeddings)

self.llmcache = SemanticCache(
name=index_name,
redis_url=redis_url,
vectorizer=cache_vectorizer,
distance_threshold=self.distance_threshold,
overwrite=False
)

def _get_cache_logic(self, cached_response: Any):
"""
Common 'get_cache_logic' across sync + async redis client implementations
@@ -998,111 +997,70 @@ def _get_cache_logic(self, cached_response: Any):
) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response)

return cached_response

def set_cache(self, key, value, **kwargs):
import numpy as np

print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")

# get the prompt
# get the prompt and value
messages = kwargs["messages"]
prompt = "".join(message["content"] for message in messages)

# create an embedding for prompt
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)

# get the embedding
embedding = embedding_response["data"][0]["embedding"]

# make the embedding a numpy array, convert to bytes
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
value = str(value)
assert isinstance(value, str)

new_data = [
{"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
]

# Add more data
keys = self.index.load(new_data)
# store in redis semantic cache
self.llmcache.store(
prompt=prompt,
response=value
)

return

def get_cache(self, key, **kwargs):
print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
import numpy as np
from redisvl.query import VectorQuery

# query
# get the messages
messages = kwargs["messages"]
prompt = "".join(message["content"] for message in messages)

# convert to embedding
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)

# get the embedding
embedding = embedding_response["data"][0]["embedding"]

query = VectorQuery(
vector=embedding,
vector_field_name="litellm_embedding",
return_fields=["response", "prompt", "vector_distance"],
num_results=1,
)
# check the cache
results = self.llmcache.check(prompt=prompt)

results = self.index.query(query)
if results == None:
# handle results / cache hit
if not results:
return None
if isinstance(results, list):
if len(results) == 0:
return None

vector_distance = results[0]["vector_distance"]
vector_distance = float(vector_distance)
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
similarity = 1 - vector_distance
cached_prompt = results[0]["prompt"]
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]

# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
f"got a cache hit: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, current_prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
if similarity > self.similarity_threshold:
# cache hit !
cached_value = results[0]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
return self._get_cache_logic(cached_response=cached_response)

pass
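
The sync paths above now delegate storage and lookup to redisvl's `SemanticCache`. A self-contained sketch of that round trip, mirroring the diff (Redis URL, index name, and prompts are illustrative; credentials for the embedding model are assumed for `litellm.embedding`):

```python
# Sketch of the redisvl round trip that set_cache/get_cache delegate to.
# Assumes a reachable Redis Stack instance; URL, index name, and prompts are illustrative.
import litellm
from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import CustomTextVectorizer

def embed(text: str):
    # Same idea as generate_cache_embeddings in __init__ above.
    response = litellm.embedding(model="text-embedding-ada-002", input=text)
    return response["data"][0]["embedding"]

llmcache = SemanticCache(
    name="litellm_semantic_cache_index",   # illustrative index name
    redis_url="redis://localhost:6379",    # illustrative URL
    vectorizer=CustomTextVectorizer(embed),
    distance_threshold=0.2,                # i.e. similarity_threshold = 0.8
)

# set_cache path: store the stringified response under the prompt's embedding.
llmcache.store(prompt="What is the capital of France?", response="Paris")

# get_cache path: semantic lookup; a hit exposes the fields read in get_cache.
results = llmcache.check(prompt="what's france's capital city?")
if results:
    hit = results[0]
    print(hit["response"], 1 - float(hit["vector_distance"]))
```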

async def async_set_cache(self, key, value, **kwargs):
import numpy as np
# TODO - patch async support in redisvl for SemanticCache

from litellm.proxy.proxy_server import llm_model_list, llm_router

try:
await self.index.acreate(overwrite=False) # don't overwrite existing index
except Exception as e:
print_verbose(f"Got exception creating semantic cache index: {str(e)}")

print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")

# get the prompt
# get the prompt and value
messages = kwargs["messages"]
prompt = "".join(message["content"] for message in messages)
value = str(value)
assert isinstance(value, str)

# create an embedding for prompt
router_model_names = (
[m["model_name"] for m in llm_model_list]
@@ -1132,23 +1090,19 @@ async def async_set_cache(self, key, value, **kwargs):
# get the embedding
embedding = embedding_response["data"][0]["embedding"]

# make the embedding a numpy array, convert to bytes
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
value = str(value)
assert isinstance(value, str)

new_data = [
{"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
]
# store in redis semantic cache
self.llmcache.store(
prompt=prompt,
response=value,
vector=embedding # pass through custom embedding here
)

# Add more data
keys = await self.index.aload(new_data)
return

async def async_get_cache(self, key, **kwargs):
# TODO - patch async support in redisvl for SemanticCache

print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
import numpy as np
from redisvl.query import VectorQuery

from litellm.proxy.proxy_server import llm_model_list, llm_router

@@ -1185,47 +1139,34 @@ async def async_get_cache(self, key, **kwargs):
# get the embedding
embedding = embedding_response["data"][0]["embedding"]

query = VectorQuery(
vector=embedding,
vector_field_name="litellm_embedding",
return_fields=["response", "prompt", "vector_distance"],
# check the cache
results = self.llmcache.check(
prompt=prompt, vector=embedding
)
results = await self.index.aquery(query)
if results == None:

# handle results / cache hit
if not results:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
if isinstance(results, list):
if len(results) == 0:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None

vector_distance = results[0]["vector_distance"]
vector_distance = float(vector_distance)
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
similarity = 1 - vector_distance
cached_prompt = results[0]["prompt"]
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]

# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
f"got a cache hit: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, current_prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)

# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

if similarity > self.similarity_threshold:
# cache hit !
cached_value = results[0]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
pass

async def _index_info(self):
return await self.index.ainfo()
# TODO - patch async support in redisvl for SemanticCache
return self.llmcache.index.info()


class QdrantSemanticCache(BaseCache):