From a04542e80b73a0209e50716df468c0358631072c Mon Sep 17 00:00:00 2001 From: yym68686 Date: Sun, 13 Oct 2024 04:20:06 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Feature:=20Add=20support=20for=20fr?= =?UTF-8?q?ontend=20page=20operation=20configuration=20files.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 514 ++++++++++++++++++++++++++++++++++++++++-- test/xue/test_home.py | 20 +- 2 files changed, 504 insertions(+), 30 deletions(-) diff --git a/main.py b/main.py index 5ec8f1f..3545931 100644 --- a/main.py +++ b/main.py @@ -3,12 +3,12 @@ import re import httpx import secrets -import time as time_module +from time import time from contextlib import asynccontextmanager from starlette.middleware.base import BaseHTTPMiddleware from fastapi.middleware.cors import CORSMiddleware -from fastapi import FastAPI, HTTPException, Depends, Request +from fastapi import FastAPI, HTTPException, Depends, Request, APIRouter from fastapi.responses import JSONResponse from fastapi.responses import StreamingResponse as FastAPIStreamingResponse from starlette.responses import StreamingResponse as StarletteStreamingResponse @@ -77,6 +77,13 @@ def _get_default_sql(default): @asynccontextmanager async def lifespan(app: FastAPI): + # print("Main app routes:") + # for route in app.routes: + # print(f"Route: {route.path}, methods: {route.methods}") + + # print("\nFrontend router routes:") + # for route in frontend_router.routes: + # print(f"Route: {route.path}, methods: {route.methods}") # 启动时的代码 await create_tables() @@ -95,6 +102,16 @@ async def lifespan(app: FastAPI): ) # app.state.client = httpx.AsyncClient(timeout=timeout) app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app) + + for item in app.state.api_keys_db: + if item.get("role") == "admin": + app.state.admin_api_key = item.get("api") + if not hasattr(app.state, "admin_api_key"): + if len(app.state.api_keys_db) >= 1: + app.state.admin_api_key = app.state.api_keys_db[0].get("api") + else: + raise Exception("No admin API key found") + yield # 关闭时的代码 await app.state.client.aclose() @@ -113,7 +130,6 @@ async def http_exception_handler(request: Request, exc: HTTPException): import uuid import json import asyncio -from time import time import contextvars request_info = contextvars.ContextVar('request_info', default={}) @@ -391,18 +407,19 @@ async def dispatch(self, request: Request, call_next): try: response = await call_next(request) - if isinstance(response, (FastAPIStreamingResponse, StarletteStreamingResponse)) or type(response).__name__ == '_StreamingResponse': - response = LoggingStreamingResponse( - content=response.body_iterator, - status_code=response.status_code, - media_type=response.media_type, - headers=response.headers, - current_info=current_info, - ) - elif hasattr(response, 'json'): - logger.info(f"Response: {await response.json()}") - else: - logger.info(f"Response: type={type(response).__name__}, status_code={response.status_code}, headers={response.headers}") + if request.url.path.startswith("/v1"): + if isinstance(response, (FastAPIStreamingResponse, StarletteStreamingResponse)) or type(response).__name__ == '_StreamingResponse': + response = LoggingStreamingResponse( + content=response.body_iterator, + status_code=response.status_code, + media_type=response.media_type, + headers=response.headers, + current_info=current_info, + ) + elif hasattr(response, 'json'): + logger.info(f"Response: {await response.json()}") + else: + logger.info(f"Response: type={type(response).__name__}, status_code={response.status_code}, headers={response.headers}") return response finally: @@ -793,7 +810,7 @@ def __init__(self): self.requests = defaultdict(list) async def is_rate_limited(self, key: str, limit: int, period: int) -> bool: - now = time_module.time() + now = time() self.requests[key] = [req for req in self.requests[key] if req > now - period] if len(self.requests[key]) >= limit: return True @@ -910,7 +927,7 @@ async def audio_transcriptions( traceback.print_exc() raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}") -@app.get("/generate-api-key", dependencies=[Depends(rate_limit_dependency)]) +@app.get("/v1/generate-api-key", dependencies=[Depends(rate_limit_dependency)]) def generate_api_key(): # Define the character set (only alphanumeric) chars = string.ascii_letters + string.digits @@ -924,7 +941,7 @@ def generate_api_key(): from sqlalchemy import func, desc, case from fastapi import Query -@app.get("/stats", dependencies=[Depends(rate_limit_dependency)]) +@app.get("/v1/stats", dependencies=[Depends(rate_limit_dependency)]) async def get_stats( request: Request, token: str = Depends(verify_admin_api_key), @@ -1026,6 +1043,467 @@ async def get_stats( return JSONResponse(content=stats) + + +from fastapi import FastAPI, Request +from fastapi import Form as FastapiForm, HTTPException, Depends +from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse +from fastapi.security import APIKeyHeader +from typing import Optional, List + +from xue import HTML, Head, Body, Div, xue_initialize, Script +from xue.components.menubar import ( + Menubar, MenubarMenu, MenubarTrigger, MenubarContent, + MenubarItem, MenubarSeparator +) +from xue.components import input +from xue.components import dropdown, sheet, form, button, checkbox +from xue.components.model_config_row import model_config_row +# import sys +# import os +# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +from components.provider_table import data_table + +from ruamel.yaml import YAML +yaml = YAML() +yaml.preserve_quotes = True +yaml.indent(mapping=2, sequence=4, offset=2) + + +frontend_router = APIRouter() + +API_KEY_NAME = "X-API-Key" +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) +async def get_api_key(request: Request, x_api_key: Optional[str] = Depends(api_key_header)): + if not x_api_key: + x_api_key = request.cookies.get("x_api_key") or request.query_params.get("x_api_key") + # print(f"Cookie x_api_key: {request.cookies.get('x_api_key')}") # 添加此行 + # print(f"Query param x_api_key: {request.query_params.get('x_api_key')}") # 添加此行 + # print(f"Header x_api_key: {x_api_key}") # 添加此行 + # logger.info(f"x_api_key: {x_api_key} {x_api_key == 'your_admin_api_key'}") + + if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥 + return x_api_key + else: + return None + +async def frontend_rate_limit_dependency(request: Request, x_api_key: str = Depends(get_api_key)): + token = x_api_key if x_api_key else None + limit, period = 100, 60 + + # 使用 IP 地址和 token(如果有)作为限制键 + client_ip = request.client.host + rate_limit_key = f"{client_ip}:{token}" if token else client_ip + + if await rate_limiter.is_rate_limited(rate_limit_key, limit, period): + raise HTTPException(status_code=429, detail="Too many requests") + +# def get_backend_router_api_list(): +# api_list = [] +# for route in frontend_router.routes: +# api_list.append({ +# "path": f"/api{route.path}", # 加上前缀 +# "method": route.methods, +# "name": route.name, +# "summary": route.summary +# }) +# return api_list + +# @app.get("/backend-router-api-list") +# async def backend_router_api_list(): +# return get_backend_router_api_list() + +xue_initialize(tailwind=True) + +API_YAML_PATH = "./api.yaml" + +data_table_columns = [ + # {"label": "Status", "value": "status", "sortable": True}, + {"label": "Provider", "value": "provider", "sortable": True}, + {"label": "Base url", "value": "base_url", "sortable": True}, + # {"label": "Engine", "value": "engine", "sortable": True}, + {"label": "Tools", "value": "tools", "sortable": True}, +] + +@frontend_router.get("/login", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def login_page(): + return HTML( + Head(title="登录"), + Body( + Div( + form.Form( + form.FormField("API Key", "x_api_key", type="password", placeholder="输入API密钥", required=True), + Div(id="error-message", class_="text-red-500 mt-2"), + Div( + button.button("提交", variant="primary", type="submit"), + class_="flex justify-end mt-4" + ), + hx_post="/verify-api-key", + hx_target="#error-message", + hx_swap="innerHTML", + class_="space-y-4" + ), + class_="container mx-auto p-4 max-w-md" + ) + ) + ).render() + + +@frontend_router.post("/verify-api-key", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def verify_api_key(x_api_key: str = FastapiForm(...)): + if x_api_key == app.state.admin_api_key: # 替换为实际的管理员API密钥 + response = JSONResponse(content={"success": True}) + response.headers["HX-Redirect"] = "/" # 添加这一行 + response.set_cookie( + key="x_api_key", + value=x_api_key, + httponly=True, + max_age=1800, # 30分钟 + secure=False, # 在开发环境中设置为False,生产环境中使用HTTPS时设置为True + samesite="lax" # 改为"lax"以允许重定向时携带cookie + ) + return response + else: + return Div("无效的API密钥", class_="text-red-500").render() + +@frontend_router.get("/", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def root(x_api_key: str = Depends(get_api_key)): + if not x_api_key: + return RedirectResponse(url="/login", status_code=303) + + result = HTML( + Head( + Script(""" + document.addEventListener('DOMContentLoaded', function() { + const filterInput = document.getElementById('users-table-filter'); + filterInput.addEventListener('input', function() { + const filterValue = this.value; + htmx.ajax('GET', `/filter-table?filter=${filterValue}`, '#users-table'); + }); + }); + """), + title="Menubar Example" + ), + Body( + Div( + Menubar( + MenubarMenu( + MenubarTrigger("File", "file-menu"), + MenubarContent( + MenubarItem("New Tab", shortcut="⌘T"), + MenubarItem("New Window", shortcut="⌘N"), + MenubarItem("New Incognito Window", disabled=True), + MenubarSeparator(), + MenubarItem("Print...", shortcut="⌘P"), + ), + id="file-menu" + ), + MenubarMenu( + MenubarTrigger("Edit", "edit-menu"), + MenubarContent( + MenubarItem("Undo", shortcut="⌘Z"), + MenubarItem("Redo", shortcut="⇧⌘Z"), + MenubarSeparator(), + MenubarItem("Cut"), + MenubarItem("Copy"), + MenubarItem("Paste"), + ), + id="edit-menu" + ), + MenubarMenu( + MenubarTrigger("View", "view-menu"), + MenubarContent( + MenubarItem("Always Show Bookmarks Bar"), + MenubarItem("Always Show Full URLs"), + MenubarSeparator(), + MenubarItem("Reload", shortcut="⌘R"), + MenubarItem("Force Reload", shortcut="⇧⌘R", disabled=True), + MenubarSeparator(), + MenubarItem("Toggle Fullscreen"), + MenubarItem("Hide Sidebar"), + ), + id="view-menu" + ), + ), + class_="p-4" + ), + Div( + data_table(data_table_columns, app.state.config["providers"], "users-table"), + class_="p-4" + ), + Div(id="sheet-container"), # 这里是 sheet 将被加载的地方 + class_="container mx-auto", + id="body" + ) + ).render() + # print(result) + return result + +@frontend_router.get("/dropdown-menu/{menu_id}/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def get_columns_menu(menu_id: str, row_id: str): + columns = [ + { + "label": "Edit", + "value": "edit", + "hx-get": f"/edit-sheet/{row_id}", + "hx-target": "#sheet-container", + "hx-swap": "innerHTML" + }, + { + "label": "Duplicate", + "value": "duplicate", + "hx-post": f"/duplicate/{row_id}", + "hx-target": "body", + "hx-swap": "outerHTML" + }, + { + "label": "Delete", + "value": "delete", + "hx-delete": f"/delete/{row_id}", + "hx-target": "body", + "hx-swap": "outerHTML", + "hx-confirm": "确定要删除这个配置吗?" + }, + ] + result = dropdown.dropdown_menu_content(menu_id, columns).render() + print(result) + return result + +@frontend_router.get("/dropdown-menu/{menu_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def get_columns_menu(menu_id: str): + result = dropdown.dropdown_menu_content(menu_id, data_table_columns).render() + print(result) + return result + +@frontend_router.get("/filter-table", response_class=HTMLResponse) +async def filter_table(filter: str = ""): + filtered_data = [ + provider for provider in app.state.config["providers"] + if filter.lower() in str(provider["provider"]).lower() or + filter.lower() in str(provider["base_url"]).lower() or + filter.lower() in str(provider["tools"]).lower() + ] + return data_table(data_table_columns, filtered_data, "users-table", with_filter=False).render() + +@frontend_router.post("/add-model", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def add_model(): + new_model_id = f"model{hash(str(time()))}" # 生成一个唯一的ID + new_model = model_config_row(new_model_id).render() + return new_model + +@frontend_router.get("/edit-sheet/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def get_edit_sheet(row_id: str, x_api_key: str = Depends(get_api_key)): + row_data = get_row_data(row_id) + print("row_data", row_data) + + model_list = [] + for index, model in enumerate(row_data["model"]): + if isinstance(model, str): + model_list.append(model_config_row(f"model{index}", model, "", True)) + if isinstance(model, dict): + # print("model", model, list(model.items())[0]) + key, value = list(model.items())[0] + model_list.append(model_config_row(f"model{index}", key, value, True)) + + sheet_id = "edit-sheet" + edit_sheet_content = sheet.SheetContent( + sheet.SheetHeader( + sheet.SheetTitle("Edit Item"), + sheet.SheetDescription("Make changes to your item here.") + ), + sheet.SheetBody( + Div( + form.Form( + form.FormField("Provider", "provider", value=row_data["provider"], placeholder="Enter provider name", required=True), + form.FormField("Base URL", "base_url", value=row_data["base_url"], placeholder="Enter base URL", required=True), + form.FormField("API Key", "api_key", value=row_data["api"], type="text", placeholder="Enter API key"), + Div( + Div("Models", class_="text-lg font-semibold mb-2"), + Div( + *model_list, + id="models-container" + ), + button.button( + "Add Model", + class_="mt-2", + hx_post="/add-model", + hx_target="#models-container", + hx_swap="beforeend" + ), + class_="mb-4" + ), + Div( + checkbox.checkbox("tools", "Enable Tools", checked=row_data["tools"], name="tools"), + class_="mb-4" + ), + form.FormField("Notes", "notes", value=row_data.get("notes", ""), placeholder="Enter any additional notes"), + Div( + button.button("Submit", variant="primary", type="submit"), + button.button("Cancel", variant="outline", type="button", class_="ml-2", onclick=f"toggleSheet('{sheet_id}')"), + class_="flex justify-end mt-4" + ), + hx_post=f"/submit/{row_id}", + hx_swap="outerHTML", + hx_target="body", + class_="space-y-4" + ), + class_="container mx-auto p-4 max-w-2xl" + ) + ) + ) + + result = sheet.Sheet( + sheet_id, + Div(), + edit_sheet_content, + width="80%", + max_width="800px" + ).render() + return result + +@frontend_router.get("/add-provider-sheet", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def get_add_provider_sheet(): + edit_sheet_content = sheet.SheetContent( + sheet.SheetHeader( + sheet.SheetTitle("Add New Provider"), + sheet.SheetDescription("Enter details for the new provider.") + ), + sheet.SheetBody( + Div( + form.Form( + form.FormField("Provider", "provider", placeholder="Enter provider name", required=True), + form.FormField("Base URL", "base_url", placeholder="Enter base URL", required=True), + form.FormField("API Key", "api_key", type="text", placeholder="Enter API key"), + Div( + Div("Models", class_="text-lg font-semibold mb-2"), + Div(id="models-container"), + button.button( + "Add Model", + class_="mt-2", + hx_post="/add-model", + hx_target="#models-container", + hx_swap="beforeend" + ), + class_="mb-4" + ), + Div( + checkbox.checkbox("tools", "Enable Tools", name="tools"), + class_="mb-4" + ), + form.FormField("Notes", "notes", placeholder="Enter any additional notes"), + Div( + button.button("Submit", variant="primary", type="submit"), + button.button("Cancel", variant="outline", class_="ml-2"), + class_="flex justify-end mt-4" + ), + hx_post="/submit/new", + hx_swap="outerHTML", + hx_target="body", + class_="space-y-4" + ), + class_="container mx-auto p-4 max-w-2xl" + ) + ) + ) + + result = sheet.Sheet( + "add-provider-sheet", + Div(), + edit_sheet_content, + width="80%", + max_width="800px" + ).render() + return result + +def get_row_data(row_id): + index = int(row_id) + # print(app.state.config["providers"]) + return app.state.config["providers"][index] + +def update_row_data(row_id, updated_data): + print(row_id, updated_data) + index = int(row_id) + app.state.config["providers"][index] = updated_data + save_api_yaml() + +def save_api_yaml(): + with open(API_YAML_PATH, "w", encoding="utf-8") as f: + yaml.dump(app.state.config, f) + +@frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def submit_form( + row_id: str, + request: Request, + provider: str = FastapiForm(...), + base_url: str = FastapiForm(...), + api_key: Optional[str] = FastapiForm(None), + tools: Optional[str] = FastapiForm(None), + notes: Optional[str] = FastapiForm(None), + x_api_key: str = Depends(get_api_key) +): + form_data = await request.form() + + # 收集模型数据 + models = [] + for key, value in form_data.items(): + if key.startswith("model_name_"): + model_id = key.split("_")[-1] + enabled = form_data.get(f"model_enabled_{model_id}") == "on" + rename = form_data.get(f"model_rename_{model_id}") + if value: + if rename: + models.append({value: rename}) + else: + models.append(value) + + updated_data = { + "provider": provider, + "base_url": base_url, + "api": api_key, + "model": models, + "tools": tools == "on", + "notes": notes, + } + + print("updated_data", updated_data) + + if row_id == "new": + # 添加新提供者 + app.state.config["providers"].append(updated_data) + else: + # 更新现有提供者 + update_row_data(row_id, updated_data) + + # 保存更新后的配置 + save_api_yaml() + + return await root() + +@frontend_router.post("/duplicate/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def duplicate_row(row_id: str): + index = int(row_id) + original_data = app.state.config["providers"][index] + new_data = original_data.copy() + new_data["provider"] += "-copy" + app.state.config["providers"].insert(index + 1, new_data) + + # 保存更新后的配置 + save_api_yaml() + + return await root() + +@frontend_router.delete("/delete/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)]) +async def delete_row(row_id: str): + index = int(row_id) + del app.state.config["providers"][index] + + # 保存更新后的配置 + save_api_yaml() + + return await root() + +app.include_router(frontend_router, tags=["frontend"]) + # async def on_fetch(request, env): # import asgi # return await asgi.fetch(app, request, env) diff --git a/test/xue/test_home.py b/test/xue/test_home.py index 59250eb..231280a 100644 --- a/test/xue/test_home.py +++ b/test/xue/test_home.py @@ -33,6 +33,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +API_YAML_PATH = "./api.yaml" class RequestBodyLoggerMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): if request.method == "POST" and request.url.path.startswith("/submit/"): @@ -47,7 +48,6 @@ async def dispatch(self, request: Request, call_next): from contextlib import asynccontextmanager @asynccontextmanager async def lifespan(app: FastAPI): - # app.state.client = httpx.AsyncClient(timeout=timeout) app.state.config, app.state.api_keys_db, app.state.api_list = await load_config() for item in app.state.api_keys_db: if item.get("role") == "admin": @@ -58,10 +58,6 @@ async def lifespan(app: FastAPI): else: raise Exception("No admin API key found") - global data - # providers_data = app.state.config["providers"] - - # print("data", data) yield # 关闭时的代码 await app.state.client.aclose() @@ -393,7 +389,10 @@ def update_row_data(row_id, updated_data): print(row_id, updated_data) index = int(row_id) app.state.config["providers"][index] = updated_data - with open("./api1.yaml", "w", encoding="utf-8") as f: + save_api_yaml() + +def save_api_yaml(): + with open(API_YAML_PATH, "w", encoding="utf-8") as f: yaml.dump(app.state.config, f) @app.post("/submit/{row_id}", response_class=HTMLResponse) @@ -441,8 +440,7 @@ async def submit_form( update_row_data(row_id, updated_data) # 保存更新后的配置 - with open("./api1.yaml", "w", encoding="utf-8") as f: - yaml.dump(app.state.config, f) + save_api_yaml() return await root() @@ -455,8 +453,7 @@ async def duplicate_row(row_id: str): app.state.config["providers"].insert(index + 1, new_data) # 保存更新后的配置 - with open("./api1.yaml", "w", encoding="utf-8") as f: - yaml.dump(app.state.config, f) + save_api_yaml() return await root() @@ -466,8 +463,7 @@ async def delete_row(row_id: str): del app.state.config["providers"][index] # 保存更新后的配置 - with open("./api1.yaml", "w", encoding="utf-8") as f: - yaml.dump(app.state.config, f) + save_api_yaml() return await root()