Skip to content

Commit

Permalink
Litellm dev 11 20 2024 (#6831)
Browse files Browse the repository at this point in the history
* feat(customer_endpoints.py): support passing budget duration via `/customer/new` endpoint

Closes #5651

* docs: add missing params to swagger + api documentation test

* docs: add documentation for all key endpoints

documents all params on swagger

* docs(internal_user_endpoints.py): document all /user/new params

Ensures all params are documented

* docs(team_endpoints.py): add missing documentation for team endpoints

Ensures 100% param documentation on swagger

* docs(organization_endpoints.py): document all org params

Adds documentation for all params in org endpoint

* docs(customer_endpoints.py): add coverage for all params on /customer endpoints

ensures all /customer/* params are documented

* ci(config.yml): add endpoint doc testing to ci/cd

* fix: fix internal_user_endpoints.py

* fix(internal_user_endpoints.py): support 'duration' param

* fix(partner_models/main.py): fix anthropic re-raise exception on vertex

* fix: fix pydantic obj
  • Loading branch information
krrishdholakia authored Nov 20, 2024
1 parent a1f06de commit 689cd67
Show file tree
Hide file tree
Showing 11 changed files with 480 additions and 139 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,7 @@ jobs:
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
- run: python ./tests/documentation_tests/test_env_keys.py
- run: python ./tests/documentation_tests/test_api_docs.py
- run: helm lint ./deploy/charts/litellm-helm

db_migration_disable_update_check:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,6 @@ def completion(
)

except Exception as e:
if hasattr(e, "status_code"):
raise e
raise VertexAIError(status_code=500, message=str(e))
138 changes: 95 additions & 43 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,8 @@ class GenerateRequestBase(LiteLLMBase):
Overlapping schema between key and user generate/update requests
"""

key_alias: Optional[str] = None
duration: Optional[str] = None
models: Optional[list] = []
spend: Optional[float] = 0
max_budget: Optional[float] = None
Expand All @@ -635,13 +637,6 @@ class GenerateRequestBase(LiteLLMBase):
budget_duration: Optional[str] = None
allowed_cache_controls: Optional[list] = []
soft_budget: Optional[float] = None


class _GenerateKeyRequest(GenerateRequestBase):
key_alias: Optional[str] = None
key: Optional[str] = None
duration: Optional[str] = None
aliases: Optional[dict] = {}
config: Optional[dict] = {}
permissions: Optional[dict] = {}
model_max_budget: Optional[dict] = (
Expand All @@ -654,6 +649,11 @@ class _GenerateKeyRequest(GenerateRequestBase):
model_tpm_limit: Optional[dict] = None
guardrails: Optional[List[str]] = None
blocked: Optional[bool] = None
aliases: Optional[dict] = {}


class _GenerateKeyRequest(GenerateRequestBase):
    # Extends GenerateRequestBase with one extra optional field, `key`.
    # `key` is a string and is None when not provided.
    key: Optional[str] = None


class GenerateKeyRequest(_GenerateKeyRequest):
Expand Down Expand Up @@ -719,7 +719,7 @@ class LiteLLM_ModelTable(LiteLLMBase):
model_config = ConfigDict(protected_namespaces=())


class NewUserRequest(_GenerateKeyRequest):
class NewUserRequest(GenerateRequestBase):
max_budget: Optional[float] = None
user_email: Optional[str] = None
user_alias: Optional[str] = None
Expand Down Expand Up @@ -786,15 +786,58 @@ class DeleteUserRequest(LiteLLMBase):
AllowedModelRegion = Literal["eu", "us"]


class NewCustomerRequest(LiteLLMBase):
class BudgetNew(LiteLLMBase):
    """
    Budget settings: spend caps, rate limits and a budget duration.

    Every field is optional and defaults to None.
    """

    budget_id: Optional[str] = Field(default=None, description="The unique budget id.")
    # Hard cap: per the description, requests fail once this is exceeded.
    max_budget: Optional[float] = Field(
        default=None,
        description="Requests will fail if this budget (in USD) is exceeded.",
    )
    # Soft cap: per the description, exceeding it only triggers alerting.
    soft_budget: Optional[float] = Field(
        default=None,
        description="Requests will NOT fail if this is exceeded. Will fire alerting though.",
    )
    max_parallel_requests: Optional[int] = Field(
        default=None, description="Max concurrent requests allowed for this budget id."
    )
    tpm_limit: Optional[int] = Field(
        default=None, description="Max tokens per minute, allowed for this budget id."
    )
    rpm_limit: Optional[int] = Field(
        default=None, description="Max requests per minute, allowed for this budget id."
    )
    # Example uses the same unit suffixes as the endpoint docs ("30s", "30m",
    # "30h", "30d"); the previous '1hr' example did not match that format.
    budget_duration: Optional[str] = Field(
        default=None,
        description="Max duration budget should be set for (e.g. '1h', '1d', '28d')",
    )


class BudgetRequest(LiteLLMBase):
    # Request model with one required field: a list of strings under `budgets`
    # (presumably budget ids -- confirm against the endpoint that consumes it).
    budgets: List[str]


class BudgetDeleteRequest(LiteLLMBase):
    # Request model with one required string field, `id`
    # (presumably the budget id to delete -- confirm against the endpoint).
    id: str


class CustomerBase(LiteLLMBase):
    # Customer record fields. Only `user_id` is required; everything else
    # has a default.
    user_id: str
    alias: Optional[str] = None
    spend: float = 0.0
    # Restricted to the AllowedModelRegion literals ("eu" / "us") when set.
    allowed_model_region: Optional[AllowedModelRegion] = None
    default_model: Optional[str] = None
    budget_id: Optional[str] = None
    # Nested budget settings, typed as BudgetNew; None when absent.
    litellm_budget_table: Optional[BudgetNew] = None
    blocked: bool = False


class NewCustomerRequest(BudgetNew):
"""
Create a new customer, allocate a budget to them
"""

user_id: str
alias: Optional[str] = None # human-friendly alias
blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
Expand Down Expand Up @@ -1083,39 +1126,6 @@ class OrganizationRequest(LiteLLMBase):
organizations: List[str]


class BudgetNew(LiteLLMBase):
budget_id: str = Field(default=None, description="The unique budget id.")
max_budget: Optional[float] = Field(
default=None,
description="Requests will fail if this budget (in USD) is exceeded.",
)
soft_budget: Optional[float] = Field(
default=None,
description="Requests will NOT fail if this is exceeded. Will fire alerting though.",
)
max_parallel_requests: Optional[int] = Field(
default=None, description="Max concurrent requests allowed for this budget id."
)
tpm_limit: Optional[int] = Field(
default=None, description="Max tokens per minute, allowed for this budget id."
)
rpm_limit: Optional[int] = Field(
default=None, description="Max requests per minute, allowed for this budget id."
)
budget_duration: Optional[str] = Field(
default=None,
description="Max duration budget should be set for (e.g. '1hr', '1d', '28d')",
)


class BudgetRequest(LiteLLMBase):
budgets: List[str]


class BudgetDeleteRequest(LiteLLMBase):
id: str


class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault"
Expand Down Expand Up @@ -2081,3 +2091,45 @@ class JWTKeyItem(TypedDict, total=False):

class JWKUrlResponse(TypedDict, total=False):
keys: JWKKeyValue


class UserManagementEndpointParamDocStringEnums(str, enum.Enum):
    """
    Shared parameter descriptions for the user-management endpoints.

    Each member's value is a one-line description of the form
    "<type> - <explanation>" for the request parameter named by the member
    (member name minus the `_doc_str` suffix). Members are `str` instances
    because the enum mixes in `str`.
    """

    user_id_doc_str = (
        "Optional[str] - Specify a user id. If not set, a unique id will be generated."
    )
    user_alias_doc_str = (
        "Optional[str] - A descriptive name for you to know who this user id refers to."
    )
    teams_doc_str = "Optional[list] - specify a list of team id's a user belongs to."
    user_email_doc_str = "Optional[str] - Specify a user email."
    send_invite_email_doc_str = (
        "Optional[bool] - Specify if an invite email should be sent."
    )
    # GitHub file links need the `/blob/<branch>/` segment; without it the URL 404s.
    user_role_doc_str = """Optional[str] - Specify a user role - "proxy_admin", "proxy_admin_viewer", "internal_user", "internal_user_viewer", "team", "customer". Info about each role here: `https://github.com/BerriAI/litellm/blob/main/litellm/proxy/_types.py#L20`"""
    max_budget_doc_str = """Optional[float] - Specify max budget for a given user."""
    budget_duration_doc_str = """Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"), months ("1mo")."""
    models_doc_str = """Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)"""
    tpm_limit_doc_str = (
        """Optional[int] - Specify tpm limit for a given user (Tokens per minute)"""
    )
    rpm_limit_doc_str = (
        """Optional[int] - Specify rpm limit for a given user (Requests per minute)"""
    )
    auto_create_key_doc_str = """bool - Default=True. Flag used for returning a key as part of the /user/new response"""
    aliases_doc_str = """Optional[dict] - Model aliases for the user - [Docs](https://litellm.vercel.app/docs/proxy/virtual_keys#model-aliases)"""
    config_doc_str = """Optional[dict] - [DEPRECATED PARAM] User-specific config."""
    allowed_cache_controls_doc_str = """Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request-"""
    blocked_doc_str = (
        """Optional[bool] - [Not Implemented Yet] Whether the user is blocked."""
    )
    guardrails_doc_str = """Optional[List[str]] - [Not Implemented Yet] List of active guardrails for the user"""
    permissions_doc_str = """Optional[dict] - [Not Implemented Yet] User-specific permissions, eg. turning off pii masking."""
    metadata_doc_str = """Optional[dict] - Metadata for user, store information for user. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }"""
    max_parallel_requests_doc_str = """Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x."""
    soft_budget_doc_str = """Optional[float] - Get alerts when user crosses given budget, doesn't block requests."""
    model_max_budget_doc_str = """Optional[dict] - Model-specific max budget for user. [Docs](https://docs.litellm.ai/docs/proxy/users#add-model-specific-budgets-to-keys)"""
    # The per-model limits are declared as Optional[dict] on the request model,
    # so the advertised type is dict rather than float.
    model_rpm_limit_doc_str = """Optional[dict] - Model-specific rpm limit for user. [Docs](https://docs.litellm.ai/docs/proxy/users#add-model-specific-limits-to-keys)"""
    model_tpm_limit_doc_str = """Optional[dict] - Model-specific tpm limit for user. [Docs](https://docs.litellm.ai/docs/proxy/users#add-model-specific-limits-to-keys)"""
    spend_doc_str = """Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used."""
    team_id_doc_str = """Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None."""
    duration_doc_str = """Optional[str] - Duration for the key auto-created on `/user/new`. Default is None."""
67 changes: 57 additions & 10 deletions litellm/proxy/management_endpoints/customer_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
"""
CUSTOMER MANAGEMENT
All /customer management endpoints
/customer/new
/customer/info
/customer/update
/customer/delete
"""

#### END-USER/CUSTOMER MANAGEMENT ####
import asyncio
import copy
Expand Down Expand Up @@ -129,6 +140,26 @@ async def unblock_user(data: BlockUsers):
return {"blocked_users": litellm.blocked_user_list}


def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNew]:
    """
    Build a `BudgetNew` from the budget-related values present on `data`.

    Every `BudgetNew` field except `budget_id` is read off `data`; fields that
    are missing or None are left out. Returns None when no value remains, so
    the caller can tell "no new budget requested" apart from an empty budget.
    """
    # Pair each candidate field with its value on the request, skipping the id.
    candidates = (
        (name, getattr(data, name, None))
        for name in BudgetNew.model_fields.keys()
        if name != "budget_id"
    )
    collected = {name: value for name, value in candidates if value is not None}

    return BudgetNew(**collected) if collected else None


@router.post(
"/end_user/new",
tags=["Customer Management"],
Expand Down Expand Up @@ -157,6 +188,11 @@ async def new_end_user(
- allowed_model_region: Optional[Union[Literal["eu"], Literal["us"]]] - Require all user requests to use models in this specific region.
- default_model: Optional[str] - If no equivalent model in the allowed region, default all requests to this model.
- metadata: Optional[dict] = Metadata for customer, store information for customer. Example metadata = {"data_training_opt_out": True}
- budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
- tpm_limit: Optional[int] - [Not Implemented Yet] Specify tpm limit for a given customer (Tokens per minute)
- rpm_limit: Optional[int] - [Not Implemented Yet] Specify rpm limit for a given customer (Requests per minute)
- max_parallel_requests: Optional[int] - [Not Implemented Yet] Specify max parallel requests for a given customer.
- soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests.
- Allow specifying allowed regions
Expand Down Expand Up @@ -223,14 +259,19 @@ async def new_end_user(
new_end_user_obj: Dict = {}

## CREATE BUDGET ## if set
if data.max_budget is not None:
budget_record = await prisma_client.db.litellm_budgettable.create(
data={
"max_budget": data.max_budget,
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, # type: ignore
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
}
)
_new_budget = new_budget_request(data)
if _new_budget is not None:
try:
budget_record = await prisma_client.db.litellm_budgettable.create(
data={
**_new_budget.model_dump(exclude_unset=True),
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, # type: ignore
"updated_by": user_api_key_dict.user_id
or litellm_proxy_admin_name,
}
)
except Exception as e:
raise HTTPException(status_code=422, detail={"error": str(e)})

new_end_user_obj["budget_id"] = budget_record.budget_id
elif data.budget_id is not None:
Expand All @@ -239,16 +280,22 @@ async def new_end_user(
_user_data = data.dict(exclude_none=True)

for k, v in _user_data.items():
if k != "max_budget" and k != "budget_id":
if k not in BudgetNew.model_fields.keys():
new_end_user_obj[k] = v

## WRITE TO DB ##
end_user_record = await prisma_client.db.litellm_endusertable.create(
data=new_end_user_obj # type: ignore
data=new_end_user_obj, # type: ignore
include={"litellm_budget_table": True},
)

return end_user_record
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.management_endpoints.customer_endpoints.new_end_user(): Exception occured - {}".format(
str(e)
)
)
if "Unique constraint failed on the fields: (`user_id`)" in str(e):
raise ProxyException(
message=f"Customer already exists, passed user_id={data.user_id}. Please pass a new user_id.",
Expand Down
Loading

0 comments on commit 689cd67

Please sign in to comment.