Skip to content

Commit

Permalink
feat: app rate limit (langgenius#5844)
Browse files Browse the repository at this point in the history
Co-authored-by: liuzhenghua-jk <liuzhenghua-jk@360shuke.com>
Co-authored-by: takatost <takatost@gmail.com>
  • Loading branch information
3 people authored Jul 10, 2024
1 parent cc8dc6d commit 9622fbb
Show file tree
Hide file tree
Showing 19 changed files with 277 additions and 56 deletions.
2 changes: 1 addition & 1 deletion api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,4 @@ WORKFLOW_CALL_MAX_DEPTH=5

# App configuration
APP_MAX_EXECUTION_TIME=1200

APP_MAX_ACTIVE_REQUESTS=0
4 changes: 4 additions & 0 deletions api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class AppExecutionConfig(BaseSettings):
description='execution timeout in seconds for app execution',
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description='max active request per app, 0 means unlimited',
default=0,
)


class CodeExecutionSandboxConfig(BaseSettings):
Expand Down
1 change: 1 addition & 0 deletions api/controllers/console/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def put(self, app_model):
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json')
args = parser.parse_args()

app_service = AppService()
Expand Down
11 changes: 8 additions & 3 deletions api/controllers/console/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
Expand Down Expand Up @@ -75,7 +80,7 @@ def post(self, app_model):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
Expand Down Expand Up @@ -141,7 +146,7 @@ def post(self, app_model):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
Expand Down
3 changes: 2 additions & 1 deletion api/controllers/console/app/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import AppInvokeQuotaExceededError
from fields.workflow_fields import workflow_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
Expand Down Expand Up @@ -279,7 +280,7 @@ def post(self, app_model: App):
)

return helper.compact_generate_response(response)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
Expand Down
11 changes: 8 additions & 3 deletions api/controllers/service_api/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
Expand Down Expand Up @@ -69,7 +74,7 @@ def post(self, app_model: App, end_user: EndUser):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
Expand Down Expand Up @@ -132,7 +137,7 @@ def post(self, app_model: App, end_user: EndUser):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
Expand Down
9 changes: 7 additions & 2 deletions api/controllers/service_api/app/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser
Expand Down Expand Up @@ -59,7 +64,7 @@ def post(self, app_model: App, end_user: EndUser):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
Expand Down
1 change: 1 addition & 0 deletions api/core/app/features/rate_limiting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .rate_limit import RateLimit
120 changes: 120 additions & 0 deletions api/core/app/features/rate_limiting/rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import logging
import time
import uuid
from collections.abc import Generator
from datetime import timedelta
from typing import Optional, Union

from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client

logger = logging.getLogger(__name__)


class RateLimit:
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {}

def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
return cls._instance_dict[client_id]

def __init__(self, client_id: str, max_active_requests: int):
self.max_active_requests = max_active_requests
if hasattr(self, 'initialized'):
return
self.initialized = True
self.client_id = client_id
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
self.last_recalculate_time = float('-inf')
self.flush_cache(use_local_value=True)

def flush_cache(self, use_local_value=False):
self.last_recalculate_time = time.time()
# flush max active requests
if use_local_value or not redis_client.exists(self.max_active_requests_key):
with redis_client.pipeline() as pipe:
pipe.set(self.max_active_requests_key, self.max_active_requests)
pipe.expire(self.max_active_requests_key, timedelta(days=1))
pipe.execute()
else:
with redis_client.pipeline() as pipe:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
redis_client.expire(self.max_active_requests_key, timedelta(days=1))

# flush max active requests (in-transit request list)
if not redis_client.exists(self.active_requests_key):
return
request_details = redis_client.hgetall(self.active_requests_key)
redis_client.expire(self.active_requests_key, timedelta(days=1))
timeout_requests = [k for k, v in request_details.items() if
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)

def enter(self, request_id: Optional[str] = None) -> str:
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
self.flush_cache()
if self.max_active_requests <= 0:
return RateLimit._UNLIMITED_REQUEST_ID
if not request_id:
request_id = RateLimit.gen_request_key()

active_requests_count = redis_client.hlen(self.active_requests_key)
if active_requests_count >= self.max_active_requests:
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
"concurrent requests allowed is {}.".format(self.max_active_requests))
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
return request_id

def exit(self, request_id: str):
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
return
redis_client.hdel(self.active_requests_key, request_id)

@staticmethod
def gen_request_key() -> str:
return str(uuid.uuid4())

def generate(self, generator: Union[Generator, callable, dict], request_id: str):
if isinstance(generator, dict):
return generator
else:
return RateLimitGenerator(self, generator, request_id)


class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
self.rate_limit = rate_limit
if callable(generator):
self.generator = generator()
else:
self.generator = generator
self.request_id = request_id
self.closed = False

def __iter__(self):
return self

def __next__(self):
if self.closed:
raise StopIteration
try:
return next(self.generator)
except StopIteration:
self.close()
raise

def close(self):
if not self.closed:
self.closed = True
self.rate_limit.exit(self.request_id)
if self.generator is not None and hasattr(self.generator, 'close'):
self.generator.close()
7 changes: 7 additions & 0 deletions api/core/errors/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ class QuotaExceededError(Exception):
description = "Quota Exceeded"


class AppInvokeQuotaExceededError(Exception):
"""
Custom exception raised when the quota for an app has been exceeded.
"""
description = "App Invoke Quota Exceeded"


class ModelCurrentlyNotSupportError(Exception):
"""
Custom exception raised when the model not support
Expand Down
1 change: 1 addition & 0 deletions api/fields/app_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
app_partial_fields = {
'id': fields.String,
'name': fields.String,
'max_active_requests': fields.Raw(),
'description': fields.String(attribute='desc_or_prompt'),
'mode': fields.String(attribute='mode_compatible_with_agent'),
'icon': fields.String,
Expand Down
9 changes: 9 additions & 0 deletions api/libs/external_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException

from core.errors.error import AppInvokeQuotaExceededError


class ExternalApi(Api):

Expand Down Expand Up @@ -43,6 +45,13 @@ def handle_error(self, e):
'message': str(e),
'status': status_code
}
elif isinstance(e, AppInvokeQuotaExceededError):
status_code = 429
default_data = {
'code': 'too_many_requests',
'message': str(e),
'status': status_code
}
else:
status_code = 500
default_data = {
Expand Down
33 changes: 33 additions & 0 deletions api/migrations/versions/408176b91ad3_add_max_active_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""'add_max_active_requests'
Revision ID: 408176b91ad3
Revises: 7e6a8693e07a
Create Date: 2024-07-04 09:25:14.029023
"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = '408176b91ad3'
down_revision = '161cadc1af8d'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('max_active_requests', sa.Integer(), nullable=True))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.drop_column('max_active_requests')

# ### end Alembic commands ###
1 change: 1 addition & 0 deletions api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class App(db.Model):
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
tracing = db.Column(db.Text, nullable=True)
max_active_requests = db.Column(db.Integer, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

Expand Down
Loading

0 comments on commit 9622fbb

Please sign in to comment.