diff --git a/llm/__init__.py b/llm/__init__.py index 5a083e02..190b9d81 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -14,7 +14,7 @@ from .templates import Template from .plugins import pm import click -from typing import Dict, List +from typing import Dict, List, Optional import json import os import pathlib @@ -94,15 +94,28 @@ def get_model(name): raise UnknownModelError("Unknown model: " + name) -def get_key(key_arg, default_key, env_var=None): - keys = load_keys() - if key_arg in keys: - return keys[key_arg] - if key_arg: - return key_arg +def get_key( + explicit_key: Optional[str], key_alias: str, env_var: Optional[str] = None +) -> Optional[str]: + """ + Return an API key based on a hierarchy of potential sources. + + :param provided_key: A key provided by the user. This may be the key, or an alias of a key in keys.json. + :param key_alias: The alias used to retrieve the key from the keys.json file. + :param env_var: Name of the environment variable to check for the key. + """ + stored_keys = load_keys() + # If user specified an alias, use the key stored for that alias + if explicit_key in stored_keys: + return stored_keys[explicit_key] + if explicit_key: + # User specified a key that's not an alias, use that + return explicit_key + # Environment variables over-ride the default key if env_var and os.environ.get(env_var): return os.environ[env_var] - return keys.get(default_key) + # Return the key stored for the default alias + return stored_keys.get(key_alias) def load_keys():