Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate cachetools for in-memory LM caching, including unhashable types & pydantic #1896

Merged
merged 7 commits into from
Dec 7, 2024

Conversation

dbczumar
Copy link
Collaborator

@dbczumar dbczumar commented Dec 6, 2024

Integrate cachetools for in-memory LM caching, including unhashable types & pydantic

Fixes #1759

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
key=lambda request, *args, **kwargs: cache_key(request),
# Use a lock to ensure thread safety for the cache when DSPy LMs are queried
# concurrently, e.g. during optimization and evaluation
lock=threading.Lock(),
Copy link
Collaborator Author

@dbczumar dbczumar Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cachetools provides thread safety natively. alternatively, we could try to implement our own cache with the required thread safety functionality, but I suspect there might be bugs (best to reuse something that is known to work)

Copy link
Collaborator

@okhat okhat Dec 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a blocker for merge, but I'm slightly uneasy about Python-level locking (compared to whatever functools normally does?). Maybe it's required for thread safety, but since it's happening for every single LM call it's a bit worrisome.

Copy link
Collaborator Author

@dbczumar dbczumar Dec 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @okhat ! Functools uses a Python lock as well (Rlock). Ill follow up with a small PR to use Rlock instead of Lock.

@cached(
# NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
cache=LRUCache(maxsize=maxsize or float("inf")),
key=lambda request, *args, **kwargs: cache_key(request),
Copy link
Collaborator Author

@dbczumar dbczumar Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the key advantage of cachetools. Unlike lru_cache, it allows us to define a cache key by applying a custom function to one or more arguments, rather than forcing all arguments to be hashed / JSON-encoded, passed to the function, and then decoded afterwards. Encoding / decoding is infeasible for callables.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do a global dspy.settings.request_cache default = LRUCache(maxsize=10_000_000) and then have this function pull from that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this naming, it could be confused with the disk cache though, right? It seems like we'd want some unified way to refer to both caches, or more distinctive naming. Thoughts?

return litellm_completion(
request,
cache={"no-cache": False, "no-store": False},
num_retries=num_retries,
)


def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer have to serialize / deserialize request within the litellm_completion and litellm_text_completion calls

Comment on lines 114 to 124
def test_lm_calls_support_unhashable_types(litellm_test_server, temporary_blank_cache_dir):
api_base, server_log_file_path = litellm_test_server

lm_with_unhashable_callable = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
# Define a callable kwarg for the LM to use during inference
azure_ad_token_provider=lambda *args, **kwargs: None,
)
lm_with_unhashable_callable("Query")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fails on main with:

        )
E       TypeError: <function test_lm_calls_support_unhashable_types.<locals>.<lambda> at 0x31204d5a0> is not JSON serializable

Comment on lines 127 to 139
def test_lm_calls_support_pydantic_models(litellm_test_server, temporary_blank_cache_dir):
api_base, server_log_file_path = litellm_test_server

class ResponseFormat(pydantic.BaseModel):
response: str

lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
response_format=ResponseFormat,
)
lm("Query")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fails on main with:

TypeError: <class 'tests.caching.test_caching.test_lm_calls_support_pydantic_models.<locals>.ResponseFormat'> is not JSON serializable

@@ -212,47 +219,82 @@ def copy(self, **kwargs):
return new_instance


@functools.lru_cache(maxsize=None)
def cached_litellm_completion(request, num_retries: int):
def request_cache(maxsize: Optional[int] = None):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@okhat @bahtman @CyrusNuevoDia Thoughts on this approach? See inline comments discussing advantages below

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks cool!

Could set default maxsize = float("inf") here

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
assert azure_openai_lm("azure openai query") == expected_response


def test_text_lms_can_be_queried(litellm_test_server):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're making changes to litellm_text_completion as well, we should have some coverage for LM queries with model_type="text"

Comment on lines +52 to +62
def test_lm_calls_support_unhashable_types(litellm_test_server):
api_base, server_log_file_path = litellm_test_server

lm_with_unhashable_callable = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
# Define a callable kwarg for the LM to use during inference
azure_ad_token_provider=lambda *args, **kwargs: None,
)
lm_with_unhashable_callable("Query")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fails on main with:

E       TypeError: <function test_lm_calls_support_unhashable_types.<locals>.<lambda> at 0x31204d5a0> is not JSON serializable

Comment on lines +65 to +77
def test_lm_calls_support_pydantic_models(litellm_test_server):
api_base, server_log_file_path = litellm_test_server

class ResponseFormat(pydantic.BaseModel):
response: str

lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
response_format=ResponseFormat,
)
lm("Query")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fails on main with:

TypeError: <class 'tests.caching.test_caching.test_lm_calls_support_pydantic_models.<locals>.ResponseFormat'> is not JSON serializable

@CyrusNuevoDia
Copy link
Collaborator

Looks awesome! Is there a way to have a global cache that we can dump/load?

@dbczumar
Copy link
Collaborator Author

dbczumar commented Dec 6, 2024

Looks awesome! Is there a way to have a global cache that we can dump/load?

Totally! We can add that if / when we need it by leveraging cachetools LRUCache.items() method

@CyrusNuevoDia
Copy link
Collaborator

Awesome, lgtm! Appreciate you 🙏

Copy link
Collaborator

@CyrusNuevoDia CyrusNuevoDia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One quick fix and lgtm

@cached(
# NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
cache=LRUCache(maxsize=maxsize or float("inf")),
key=lambda request, *args, **kwargs: cache_key(request),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do a global dspy.settings.request_cache default = LRUCache(maxsize=10_000_000) and then have this function pull from that?

@okhat okhat merged commit a8d1107 into stanfordnlp:main Dec 7, 2024
4 checks passed
isaacbmiller pushed a commit that referenced this pull request Dec 11, 2024
…ypes & pydantic (#1896)

* Impl

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Cachetools add

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Inline

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* tweak

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Update lm.py

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Using Azure token provider in dspy.LM
3 participants