Python API support for passing an API key to a model #744

Open
simonw opened this issue Feb 11, 2025 · 14 comments


simonw commented Feb 11, 2025

For some applications it may be useful to provide API keys for models at runtime, when a prompt is executed. That's not easy right now: you can set model.key = "x", but that key is then shared across every use of that model instance. This is bad for multi-user environments like web applications, and there's no easy way to create new model instances either, since llm.get_model(model_id) returns a single shared object.

I don't want to break existing code here, so I'm going to make this a new optional argument: model.prompt("prompt", key=...).
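
To make the goal concrete, usage might look roughly like this (the model ID is just an example and the key is a placeholder):

import llm

model = llm.get_model("gpt-4o-mini")  # example model ID

# Proposed: this key applies only to this prompt() call
response = model.prompt("Say hello", key="sk-example-not-a-real-key")
print(response.text())

# Existing workaround: mutates the shared model instance for every caller
model.key = "sk-example-not-a-real-key"

The important property is that the key= argument never touches shared state on the model object.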

@simonw
Copy link
Owner Author

simonw commented Feb 11, 2025

This is going to be a bit fiddly. Here's how .prompt() works right now:

llm/llm/models.py

Lines 639 to 659 in f67c215

def prompt(
    self,
    prompt: str,
    *,
    attachments: Optional[List[Attachment]] = None,
    system: Optional[str] = None,
    stream: bool = True,
    **options,
) -> Response:
    self._validate_attachments(attachments)
    return Response(
        Prompt(
            prompt,
            attachments=attachments,
            system=system,
            model=self,
            options=self.Options(**options),
        ),
        self,
        stream,
    )

Also relevant:

llm/llm/models.py

Lines 386 to 398 in f67c215

def __iter__(self) -> Iterator[str]:
    self._start = time.monotonic()
    self._start_utcnow = datetime.datetime.now(datetime.timezone.utc)
    if self._done:
        yield from self._chunks
        return
    for chunk in self.model.execute(
        self.prompt,
        stream=self.stream,
        response=self,
        conversation=self.conversation,
    ):

So I have two options here: I could stash the optional key on the Prompt instance, or I could stash it on the Response.

Either way, the .execute() method provided by the underlying implementation needs to know where to look for it. Currently the common pattern is for plugins to call self.get_key() at some point, which resolves the key like this:

llm/llm/models.py

Lines 567 to 591 in f67c215

def get_key(self):
    from llm import get_key

    if self.needs_key is None:
        # This model doesn't use an API key
        return None

    if self.key is not None:
        # Someone already set model.key='...'
        return self.key

    # Attempt to load a key using llm.get_key()
    key = get_key(
        explicit_key=None, key_alias=self.needs_key, env_var=self.key_env_var
    )
    if key:
        return key

    # Show a useful error message
    message = "No key found - add one using 'llm keys set {}'".format(
        self.needs_key
    )
    if self.key_env_var:
        message += " or set the {} environment variable".format(self.key_env_var)
    raise NeedsKeyException(message)

But now that mechanism needs visibility into either the Prompt or the arguments that were passed to the updated .execute() method.
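
To illustrate why, here's a hypothetical plugin (the names are invented) following the current pattern, where the key is resolved entirely inside execute() and the caller of model.prompt() has no way to influence it:

import llm


class MyPlugin(llm.Model):
    # Hypothetical plugin following the current pattern
    model_id = "my-plugin"
    needs_key = "myservice"
    key_env_var = "MYSERVICE_API_KEY"

    def execute(self, prompt, stream, response, conversation=None):
        # get_key() only sees model.key, stored keys and the env var;
        # it never sees anything the caller passed to model.prompt()
        api_key = self.get_key()
        yield f"(would call the API using key {api_key!r})"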


simonw commented Feb 11, 2025

New idea! What if I use dependency injection here?

If your plugin provides an .execute() method with a key= parameter, the framework passes in an API key for that plugin to use.

That way I can upgrade plugins to the new mechanism one at a time, without breaking plugins that haven't been updated.
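
Concretely, the dispatch could work by inspecting the plugin's execute() signature. A rough sketch of the call site (the helper names here are invented, but the approach matches the prototype later in this thread):

import inspect


def accepts_parameter(fn, name):
    # Does this callable declare a parameter with the given name?
    return name in inspect.signature(fn).parameters


def call_execute(model, prompt, stream, response, conversation, api_key=None):
    # Only pass key= to plugins whose execute() asks for it, so existing
    # plugins keep working unchanged
    kwargs = {"stream": stream, "response": response, "conversation": conversation}
    if accepts_parameter(model.execute, "key"):
        kwargs["key"] = api_key
    return model.execute(prompt, **kwargs)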

simonw pinned this issue Feb 11, 2025

simonw commented Feb 11, 2025

... in which case the underlying implementation will be for the Response class to gain a ._api_key property that stashes the value passed to model.prompt(), then pass it to self.model.execute() later if that method accepts the parameter.


simonw commented Feb 12, 2025

I have a working prototype for this now. I need to document it, test it, and also decide what to do about conversations as opposed to straight-up model.prompt() calls.


simonw commented Feb 12, 2025

Here's that prototype sketch; it's making mypy unhappy though:

diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index 0a9dab2..177660f 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -474,7 +474,7 @@ class _Shared:
             input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage)
         )
 
-    def get_client(self, async_=False):
+    def get_client(self, *, async_=False, key=None):
         kwargs = {}
         if self.api_base:
             kwargs["base_url"] = self.api_base
@@ -485,7 +485,7 @@ class _Shared:
         if self.api_engine:
             kwargs["engine"] = self.api_engine
         if self.needs_key:
-            kwargs["api_key"] = self.get_key()
+            kwargs["api_key"] = self.get_key(key)
         else:
             # OpenAI-compatible models don't need a key, but the
             # openai client library requires one
@@ -522,12 +522,12 @@ class Chat(_Shared, Model):
             default=None,
         )
 
-    def execute(self, prompt, stream, response, conversation=None):
+    def execute(self, prompt, stream, response, conversation=None, key=None):
         if prompt.system and not self.allows_system_prompt:
             raise NotImplementedError("Model does not support system prompts")
         messages = self.build_messages(prompt, conversation)
         kwargs = self.build_kwargs(prompt, stream)
-        client = self.get_client()
+        client = self.get_client(key=key)
         usage = None
         if stream:
             completion = client.chat.completions.create(
@@ -574,13 +574,13 @@ class AsyncChat(_Shared, AsyncModel):
         )
 
     async def execute(
-        self, prompt, stream, response, conversation=None
+        self, prompt, stream, response, conversation=None, key=None
     ) -> AsyncGenerator[str, None]:
         if prompt.system and not self.allows_system_prompt:
             raise NotImplementedError("Model does not support system prompts")
         messages = self.build_messages(prompt, conversation)
         kwargs = self.build_kwargs(prompt, stream)
-        client = self.get_client(async_=True)
+        client = self.get_client(async_=True, key=key)
         usage = None
         if stream:
             completion = await client.chat.completions.create(
diff --git a/llm/models.py b/llm/models.py
index 2bed85e..25b23a9 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -5,6 +5,7 @@ import datetime
 from .errors import NeedsKeyException
 import hashlib
 import httpx
+import inspect
 from itertools import islice
 import re
 import time
@@ -204,11 +205,13 @@ class _BaseResponse:
         model: "_BaseModel",
         stream: bool,
         conversation: Optional[_BaseConversation] = None,
+        key: Optional[str] = None,
     ):
         self.prompt = prompt
         self._prompt_json = None
         self.model = model
         self.stream = stream
+        self._key = key
         self._chunks: List[str] = []
         self._done = False
         self.response_json = None
@@ -390,12 +393,15 @@ class Response(_BaseResponse):
             yield from self._chunks
             return
 
-        for chunk in self.model.execute(
-            self.prompt,
-            stream=self.stream,
-            response=self,
-            conversation=self.conversation,
-        ):
+        kwargs = {
+            "stream": self.stream,
+            "response": self,
+            "conversation": self.conversation,
+        }
+        if _accepts_parameter(self.model.execute, "key"):
+            kwargs["key"] = self.model.get_key(self._key)
+
+        for chunk in self.model.execute(self.prompt, **kwargs):
             yield chunk
             self._chunks.append(chunk)
 
@@ -447,12 +453,14 @@ class AsyncResponse(_BaseResponse):
             return chunk
 
         if not hasattr(self, "_generator"):
-            self._generator = self.model.execute(
-                self.prompt,
-                stream=self.stream,
-                response=self,
-                conversation=self.conversation,
-            )
+            kwargs = {
+                "stream": self.stream,
+                "response": self,
+                "conversation": self.conversation,
+            }
+            if _accepts_parameter(self.model.execute, "key"):
+                kwargs["key"] = self.model.get_key(self._key)
+            self._generator = self.model.execute(self.prompt, **kwargs)
 
         try:
             chunk = await self._generator.__anext__()
@@ -564,7 +572,7 @@ _Options = Options
 
 
 class _get_key_mixin:
-    def get_key(self):
+    def get_key(self, explicit_key: Optional[str] = None) -> Optional[str]:
         from llm import get_key
 
         if self.needs_key is None:
@@ -577,7 +585,9 @@ class _get_key_mixin:
 
         # Attempt to load a key using llm.get_key()
         key = get_key(
-            explicit_key=None, key_alias=self.needs_key, env_var=self.key_env_var
+            explicit_key=explicit_key,
+            key_alias=self.needs_key,
+            env_var=self.key_env_var,
         )
         if key:
             return key
@@ -633,6 +643,7 @@ class Model(_BaseModel):
         stream: bool,
         response: Response,
         conversation: Optional[Conversation],
+        key: Optional[str] = None,
     ) -> Iterator[str]:
         pass
 
@@ -643,6 +654,7 @@ class Model(_BaseModel):
         attachments: Optional[List[Attachment]] = None,
         system: Optional[str] = None,
         stream: bool = True,
+        key: Optional[str] = None,
         **options,
     ) -> Response:
         self._validate_attachments(attachments)
@@ -656,6 +668,7 @@ class Model(_BaseModel):
             ),
             self,
             stream,
+            key=key,
         )
 
 
@@ -670,6 +683,7 @@ class AsyncModel(_BaseModel):
         stream: bool,
         response: AsyncResponse,
         conversation: Optional[AsyncConversation],
+        key: Optional[str] = None,
     ) -> AsyncGenerator[str, None]:
         yield ""
 
@@ -680,6 +694,7 @@ class AsyncModel(_BaseModel):
         attachments: Optional[List[Attachment]] = None,
         system: Optional[str] = None,
         stream: bool = True,
+        key: Optional[str] = None,
         **options,
     ) -> AsyncResponse:
         self._validate_attachments(attachments)
@@ -693,6 +708,7 @@ class AsyncModel(_BaseModel):
             ),
             self,
             stream,
+            key=key,
         )
 
 
@@ -780,3 +796,7 @@ def _conversation_name(text):
     if len(text) <= CONVERSATION_NAME_LENGTH:
         return text
     return text[: CONVERSATION_NAME_LENGTH - 1] + "…"
+
+
+def _accepts_parameter(callable: Callable, parameter: str) -> bool:
+    return parameter in inspect.signature(callable).parameters


simonw commented Feb 12, 2025

I also need to consider whether this design should apply to embedding models too. It probably should.
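
The same signature-inspection trick might carry over. A rough sketch, assuming embedding plugins keep their existing embed_batch() entry point and share the same get_key() helper (the call-site function name is invented):

import inspect


def call_embed_batch(embedding_model, items, api_key=None):
    # Pass key= only to embedding plugins that declare it, mirroring the
    # approach used for execute() above; get_key(api_key) assumes the
    # updated get_key(explicit_key) signature from the prototype
    kwargs = {}
    if "key" in inspect.signature(embedding_model.embed_batch).parameters:
        kwargs["key"] = embedding_model.get_key(api_key)
    return embedding_model.embed_batch(items, **kwargs)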


simonw commented Feb 12, 2025

Documentation for this will go in the advanced model plugins docs: https://llm.datasette.io/en/stable/plugins/advanced-model-plugins.html


simonw commented Feb 12, 2025

After much digging, I don't think it's possible to define an ABC with an @abstractmethod for execute() that allows the key= keyword argument to be optional for implementations. That messes up my design somewhat.

Options:

  1. Come up with a completely different design. My desire to NOT break existing plugins makes that hard.
  2. Drop the mypy support.
  3. Maybe something involving protocols instead of ABCs? Not sure about that yet; a rough sketch of that idea follows below. https://chatgpt.com/share/67ac15b3-63e4-8006-a618-249f3e9f29d1
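
For option 3, a Protocol-based sketch might look something like this. It only demonstrates the structural-typing idea; the class and function names are hypothetical, not the real llm classes:

from typing import Any, Iterator, List, Optional, Protocol


class KeyedExecute(Protocol):
    # Structural type: anything whose execute() accepts a key= parameter
    # satisfies this for mypy, with no ABC inheritance required
    def execute(
        self,
        prompt: Any,
        stream: bool,
        response: Any,
        conversation: Any = None,
        key: Optional[str] = None,
    ) -> Iterator[str]: ...


class FakeModel:
    # A plugin-style class that matches the protocol purely by shape
    def execute(self, prompt, stream, response, conversation=None, key=None):
        yield f"echo: {prompt} (key {'provided' if key else 'absent'})"


def run(model: KeyedExecute, prompt: str, key: Optional[str] = None) -> List[str]:
    # mypy checks this call against the protocol, not a base class
    return list(model.execute(prompt, stream=False, response=None, key=key))


if __name__ == "__main__":
    print(run(FakeModel(), "hello", key="sk-example"))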


simonw commented Feb 12, 2025

This seems to work though:

from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def execute(self, **kwargs):
        print(kwargs)


class One(Base):
    def execute(self, **kwargs):
        print("One", kwargs)


class Two(Base):
    def execute(self, a, b):
        print("Two", a, b)


if __name__ == "__main__":
    one = One()
    one.execute(a=1, b=2)

    two = Two()
    two.execute(a=1, b=2)

Running mypy against that reports no errors.


simonw commented Feb 12, 2025

I don't understand why this file passes mypy:

from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def execute(self, a: int, b: int):
        print(a, b)


class Two(Base):
    def execute(self, a, b, d):
        print("Two", a, b, d)


if __name__ == "__main__":
    two = Two()
    two.execute(a=1, b=2, d=4)

mypy base.py
Success: no issues found in 1 source file


simonw commented Feb 12, 2025

... figured that out. If I explicitly annotate two with the type Base, I get failures:

from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def execute(self, a: int, b: int):
        print(a, b)


class Two(Base):
    def execute(self, a, b, d):
        print("Two", a, b, d)


if __name__ == "__main__":
    two: Base = Two()
    two.execute(a=1, b=2, d=4)

base.py:6: note: "execute" of "Base" defined here
base.py:17: error: Unexpected keyword argument "d" for "execute" of "Base"  [call-arg]
Found 1 error in 1 file (checked 1 source file)


simonw commented Feb 12, 2025

I tried to solve this with TypeGuard but it got very messy, partly due to the whole Response vs. AsyncResponse thing: https://gist.github.com/simonw/0bcbdedd734562ccdbb78862194869d0


simonw commented Feb 12, 2025

The best option at this point may be to abandon ABC and @abstractmethod for the model subclasses entirely. Does the optional key= parameter design gain me enough that it's worth giving those up?


simonw commented Feb 12, 2025

Or... how about if I have an llm.Model base class and a separate llm.ModelWithKey base class, and tell plugins to subclass the appropriate one?

Then I could have the code that calls .execute() change how it sends parameters based on an isinstance() check.
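
A minimal sketch of what that isinstance() dispatch could look like, assuming ModelWithKey extends Model with the extra parameter (everything beyond those two names is invented):

from typing import Iterator, Optional


class Model:
    # Existing-style base: execute() knows nothing about per-call keys
    def execute(self, prompt, stream, response, conversation=None) -> Iterator[str]:
        raise NotImplementedError


class ModelWithKey(Model):
    # New-style base: execute() accepts an explicit key= argument
    def execute(
        self, prompt, stream, response, conversation=None, key: Optional[str] = None
    ) -> Iterator[str]:
        raise NotImplementedError


def call_execute(model: Model, prompt, stream, response, conversation, key=None):
    # The caller picks the calling convention based on which base class the plugin chose
    if isinstance(model, ModelWithKey):
        return model.execute(prompt, stream, response, conversation, key=key)
    return model.execute(prompt, stream, response, conversation)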
