diff --git a/src/webui/api/planner.py b/src/webui/api/planner.py index cea892dc..c1d36955 100644 --- a/src/webui/api/planner.py +++ b/src/webui/api/planner.py @@ -8,11 +8,10 @@ 3. 详情按需加载 """ +import json from pathlib import Path from typing import Dict, List, Optional -import json - from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel @@ -272,8 +271,7 @@ async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge all_files = [] for chat_dir in PLAN_LOG_DIR.iterdir(): if chat_dir.is_dir(): - for log_file in chat_dir.glob("*.json"): - all_files.append((chat_dir.name, log_file)) + all_files.extend((chat_dir.name, log_file) for log_file in chat_dir.glob("*.json")) # 按时间戳排序 all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True) diff --git a/src/webui/api/replier.py b/src/webui/api/replier.py index fe25459f..fff82267 100644 --- a/src/webui/api/replier.py +++ b/src/webui/api/replier.py @@ -8,11 +8,10 @@ 3. 详情按需加载 """ +import json from pathlib import Path from typing import Dict, List, Optional -import json - from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel diff --git a/src/webui/app.py b/src/webui/app.py index 934694c6..d0b98e26 100644 --- a/src/webui/app.py +++ b/src/webui/app.py @@ -1,11 +1,12 @@ """FastAPI 应用工厂 - 创建和配置 WebUI 应用实例""" +import mimetypes +import shutil from dataclasses import dataclass from importlib import import_module from pathlib import Path from subprocess import CompletedProcess, TimeoutExpired, run -import mimetypes -import shutil +from typing import Any, Dict, List, Tuple from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware @@ -61,12 +62,12 @@ def _get_dashboard_root() -> Path: return _get_project_root() / "dashboard" -def _format_dashboard_shell_commands(*commands: list[str]) -> str: +def _format_dashboard_shell_commands(*commands: List[str]) -> str: formatted_commands = " && ".join(" ".join(command) for command in commands) return f"cd dashboard && {formatted_commands}" -def _validate_static_path(static_path: Path | None) -> tuple[str, dict[str, object]] | None: +def _validate_static_path(static_path: Path | None) -> Tuple[str, Dict[str, Any]] | None: if static_path is None: return "startup.webui_static_dir_missing", {} @@ -81,7 +82,7 @@ def _validate_static_path(static_path: Path | None) -> tuple[str, dict[str, obje def _summarize_command_output(command_result: CompletedProcess[str] | TimeoutExpired) -> str: - output_chunks: list[str] = [] + output_chunks: List[str] = [] stdout = command_result.stdout stderr = command_result.stderr @@ -111,14 +112,18 @@ def _get_preferred_dashboard_package_manager(dashboard_root: Path) -> str: return "npm" -def _get_dashboard_build_command(dashboard_root: Path) -> list[str] | None: +def _get_dashboard_build_command(dashboard_root: Path) -> List[str] | None: if not (dashboard_root / "package.json").exists(): return None preferred_package_manager = _get_preferred_dashboard_package_manager(dashboard_root) package_managers = [ preferred_package_manager, - *[package_manager for package_manager in _DASHBOARD_BUILD_COMMANDS if package_manager != preferred_package_manager], + *[ + package_manager + for package_manager in _DASHBOARD_BUILD_COMMANDS + if package_manager != preferred_package_manager + ], ] for package_manager in package_managers: @@ -128,8 +133,10 @@ def _get_dashboard_build_command(dashboard_root: Path) -> list[str] | None: return None -def _get_dashboard_manual_recovery_command(dashboard_root: Path, build_command: list[str] | None = None) -> str: - package_manager = build_command[0] if build_command is not None else _get_preferred_dashboard_package_manager(dashboard_root) +def _get_dashboard_manual_recovery_command(dashboard_root: Path, build_command: List[str] | None = None) -> str: + package_manager = ( + build_command[0] if build_command is not None else _get_preferred_dashboard_package_manager(dashboard_root) + ) install_command = _DASHBOARD_INSTALL_COMMANDS.get(package_manager) selected_build_command = _DASHBOARD_BUILD_COMMANDS.get(package_manager) @@ -308,8 +315,8 @@ def _setup_cors(app: FastAPI, port: int): def _setup_anti_crawler(app: FastAPI): try: - from src.webui.middleware import AntiCrawlerMiddleware from src.config.config import global_config + from src.webui.middleware import AntiCrawlerMiddleware anti_crawler_mode = global_config.webui.anti_crawler_mode app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode) diff --git a/src/webui/config_schema.py b/src/webui/config_schema.py index 5abd4b1d..6dfe2381 100644 --- a/src/webui/config_schema.py +++ b/src/webui/config_schema.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, get_args, get_origin +from typing import Any, Dict, List, get_args, get_origin from pydantic_core import PydanticUndefined @@ -8,13 +8,13 @@ from src.config.config_base import ConfigBase class ConfigSchemaGenerator: @classmethod - def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]: + def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> Dict[str, Any]: return cls.generate_config_schema(config_class, include_nested=include_nested) @classmethod - def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]: - fields: list[dict[str, Any]] = [] - nested: dict[str, dict[str, Any]] = {} + def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> Dict[str, Any]: + fields: List[Dict[str, Any]] = [] + nested: Dict[str, Dict[str, Any]] = {} for field_name, field_info in config_class.model_fields.items(): if field_name in {"field_docs", "_validate_any", "suppress_any_warning"}: @@ -28,7 +28,7 @@ class ConfigSchemaGenerator: if nested_schema is not None: nested[field_name] = nested_schema - schema: dict[str, Any] = { + schema: Dict[str, Any] = { "className": config_class.__name__, "classDoc": (config_class.__doc__ or "").strip(), "fields": fields, @@ -49,7 +49,7 @@ class ConfigSchemaGenerator: return schema @classmethod - def _build_nested_schema(cls, annotation: Any) -> dict[str, Any] | None: + def _build_nested_schema(cls, annotation: Any) -> Dict[str, Any] | None: origin = get_origin(annotation) args = get_args(annotation) @@ -66,10 +66,10 @@ class ConfigSchemaGenerator: @classmethod def _build_field_schema( cls, config_class: type[ConfigBase], field_name: str, annotation: Any, field_info: Any - ) -> dict[str, Any]: + ) -> Dict[str, Any]: field_docs = config_class.get_class_field_docs() field_type = cls._map_field_type(annotation) - schema: dict[str, Any] = { + schema: Dict[str, Any] = { "name": field_name, "type": field_type, "label": field_name, @@ -86,8 +86,7 @@ class ConfigSchemaGenerator: if origin is list and args: schema["items"] = {"type": cls._map_field_type(args[0])} - options = cls._extract_options(annotation) - if options: + if options := cls._extract_options(annotation): schema["options"] = options # Task 1c: Merge json_schema_extra (x-widget, x-icon, step, etc.) @@ -105,7 +104,7 @@ class ConfigSchemaGenerator: return schema @staticmethod - def _extract_options(annotation: Any) -> list[str] | None: + def _extract_options(annotation: Any) -> List[str] | None: origin = get_origin(annotation) if origin is None: return None diff --git a/src/webui/core/__init__.py b/src/webui/core/__init__.py index d0b7c146..37078485 100644 --- a/src/webui/core/__init__.py +++ b/src/webui/core/__init__.py @@ -1,19 +1,19 @@ -from .security import TokenManager, get_token_manager -from .rate_limiter import ( - RateLimiter, - get_rate_limiter, - check_auth_rate_limit, - check_api_rate_limit, -) from .auth import ( - COOKIE_NAME, COOKIE_MAX_AGE, + COOKIE_NAME, + clear_auth_cookie, get_current_token, is_token_valid, set_auth_cookie, - clear_auth_cookie, verify_auth_token_from_cookie_or_header, ) +from .rate_limiter import ( + RateLimiter, + check_api_rate_limit, + check_auth_rate_limit, + get_rate_limiter, +) +from .security import TokenManager, get_token_manager __all__ = [ "TokenManager", diff --git a/src/webui/core/auth.py b/src/webui/core/auth.py index 73693355..99f7ed94 100644 --- a/src/webui/core/auth.py +++ b/src/webui/core/auth.py @@ -3,8 +3,10 @@ from typing import Optional from fastapi import Cookie, HTTPException, Request, Response + from src.common.logger import get_logger from src.config.config import global_config + from .security import get_token_manager logger = get_logger("webui.auth") @@ -54,6 +56,7 @@ def get_current_token( if not is_token_valid(maibot_session): raise HTTPException(status_code=401, detail="Token 无效或已过期") + assert maibot_session is not None return maibot_session diff --git a/src/webui/core/rate_limiter.py b/src/webui/core/rate_limiter.py index 23cfc0f0..303c0420 100644 --- a/src/webui/core/rate_limiter.py +++ b/src/webui/core/rate_limiter.py @@ -5,8 +5,10 @@ WebUI 请求频率限制模块 import time from collections import defaultdict -from typing import Dict, Tuple, Optional -from fastapi import Request, HTTPException +from typing import Dict, List, Optional, Tuple + +from fastapi import HTTPException, Request + from src.common.logger import get_logger logger = get_logger("webui.rate_limiter") @@ -21,7 +23,7 @@ class RateLimiter: def __init__(self): # 存储格式: {key: [(timestamp, count), ...]} - self._requests: Dict[str, list] = defaultdict(list) + self._requests: Dict[str, List] = defaultdict(list) # 被封禁的 IP: {ip: unblock_timestamp} self._blocked: Dict[str, float] = {} diff --git a/src/webui/core/security.py b/src/webui/core/security.py index 7e5e6891..11dd37f4 100644 --- a/src/webui/core/security.py +++ b/src/webui/core/security.py @@ -6,7 +6,7 @@ WebUI Token 管理模块 import json import secrets from pathlib import Path -from typing import Optional +from typing import Dict, Optional, Tuple from src.common.logger import get_logger @@ -52,7 +52,7 @@ class TokenManager: logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建") self._create_new_token() - def _load_config(self) -> dict: + def _load_config(self) -> Dict: """加载配置文件""" try: with open(self.config_path, "r", encoding="utf-8") as f: @@ -61,7 +61,7 @@ class TokenManager: logger.error(f"加载 WebUI 配置失败: {e}") return {} - def _save_config(self, config: dict): + def _save_config(self, config: Dict): """保存配置文件""" try: with open(self.config_path, "w", encoding="utf-8") as f: @@ -127,7 +127,7 @@ class TokenManager: return is_valid - def update_token(self, new_token: str) -> tuple[bool, str]: + def update_token(self, new_token: str) -> Tuple[bool, str]: """ 更新 token @@ -208,7 +208,7 @@ class TokenManager: except ValueError: return False - def _validate_custom_token(self, token: str) -> tuple[bool, str]: + def _validate_custom_token(self, token: str) -> Tuple[bool, str]: """ 验证自定义 token 格式 diff --git a/src/webui/dependencies.py b/src/webui/dependencies.py index d29663fd..41584a69 100644 --- a/src/webui/dependencies.py +++ b/src/webui/dependencies.py @@ -1,6 +1,7 @@ from typing import Optional from fastapi import Cookie, Depends, Request + from .core import check_auth_rate_limit, get_current_token, is_token_valid diff --git a/src/webui/logs_ws.py b/src/webui/logs_ws.py index 0e707803..1f5d2263 100644 --- a/src/webui/logs_ws.py +++ b/src/webui/logs_ws.py @@ -1,9 +1,11 @@ """WebSocket 日志推送模块""" -from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query -from typing import Set, Optional import json from pathlib import Path +from typing import Dict, List, Optional, Set + +from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect + from src.common.logger import get_logger from src.webui.core import get_token_manager from src.webui.routers.websocket.auth import verify_ws_token @@ -15,7 +17,7 @@ router = APIRouter() active_connections: Set[WebSocket] = set() -def load_recent_logs(limit: int = 100) -> list[dict]: +def load_recent_logs(limit: int = 100) -> List[Dict]: """从日志文件中加载最近的日志 Args: @@ -140,7 +142,7 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None active_connections.discard(websocket) -async def broadcast_log(log_data: dict): +async def broadcast_log(log_data: Dict): """广播日志到所有连接的 WebSocket 客户端 Args: diff --git a/src/webui/middleware/__init__.py b/src/webui/middleware/__init__.py index 275b1daa..c271cc35 100644 --- a/src/webui/middleware/__init__.py +++ b/src/webui/middleware/__init__.py @@ -1,10 +1,10 @@ from .anti_crawler import ( + ALLOWED_IPS, + ANTI_CRAWLER_MODE, + TRUST_XFF, + TRUSTED_PROXIES, AntiCrawlerMiddleware, create_robots_txt_response, - ANTI_CRAWLER_MODE, - ALLOWED_IPS, - TRUSTED_PROXIES, - TRUST_XFF, ) __all__ = [ diff --git a/src/webui/middleware/anti_crawler.py b/src/webui/middleware/anti_crawler.py index 8cb0335f..59bf1599 100644 --- a/src/webui/middleware/anti_crawler.py +++ b/src/webui/middleware/anti_crawler.py @@ -3,11 +3,12 @@ WebUI 防爬虫模块 提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问 """ -import time import ipaddress import re +import time from collections import deque -from typing import Optional +from typing import Dict, List, Optional, Tuple + from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse @@ -131,7 +132,7 @@ SCANNER_SPECIFIC_HEADERS = { # - CIDR格式:192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络) # - 通配符:192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有) # - IPv6:::1, 2001:db8::/32 -def _parse_allowed_ips(ip_string: str) -> list: +def _parse_allowed_ips(ip_string: str) -> List: """ 解析IP白名单字符串,支持精确IP、CIDR格式和通配符 @@ -255,7 +256,7 @@ TRUSTED_PROXIES = _config["trusted_proxies"] TRUST_XFF = _config["trust_xff"] -def _get_mode_config(mode: str) -> dict: +def _get_mode_config(mode: str) -> Dict: """ 根据模式获取配置参数 @@ -338,7 +339,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问 # 用于存储每个IP的请求时间戳(使用deque提高性能) - self.request_times: dict[str, deque] = {} + self.request_times: Dict[str, deque] = {} # 上次清理时间 self.last_cleanup = time.time() # 将关键词列表转换为集合以提高查找性能 @@ -407,7 +408,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware): return False - def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]: + def _detect_asset_scanner(self, request: Request) -> Tuple[bool, Optional[str]]: """ 检测资产测绘工具 diff --git a/src/webui/routers/__init__.py b/src/webui/routers/__init__.py index a306bbe2..65d63d02 100644 --- a/src/webui/routers/__init__.py +++ b/src/webui/routers/__init__.py @@ -1,5 +1,7 @@ """WebUI 路由聚合模块 - 提供统一的路由注册接口""" +from typing import List + from fastapi import APIRouter @@ -10,14 +12,14 @@ def get_api_router() -> APIRouter: return main_router -def get_all_routers() -> list[APIRouter]: +def get_all_routers() -> List[APIRouter]: """获取所有需要独立注册的路由器列表""" - from src.webui.routes import router as main_router - from src.webui.routers.websocket.logs import router as logs_router - from src.webui.routers.knowledge import router as knowledge_router - from src.webui.routers.chat import router as chat_router from src.webui.api.planner import router as planner_router from src.webui.api.replier import router as replier_router + from src.webui.routers.chat import router as chat_router + from src.webui.routers.knowledge import router as knowledge_router + from src.webui.routers.websocket.logs import router as logs_router + from src.webui.routes import router as main_router return [ main_router, diff --git a/src/webui/routers/chat/__init__.py b/src/webui/routers/chat/__init__.py index 1ee04fcf..8033c172 100644 --- a/src/webui/routers/chat/__init__.py +++ b/src/webui/routers/chat/__init__.py @@ -1,10 +1,10 @@ -from fastapi import APIRouter +from typing import Tuple from .routes import router -from .support import ChatConnectionManager, WEBUI_CHAT_PLATFORM, chat_manager +from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager -def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]: +def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]: """获取 WebUI 聊天广播器,供外部模块使用。""" return chat_manager, WEBUI_CHAT_PLATFORM @@ -15,4 +15,4 @@ __all__ = [ "chat_manager", "get_webui_chat_broadcaster", "router", -] \ No newline at end of file +] diff --git a/src/webui/routers/chat/routes.py b/src/webui/routers/chat/routes.py index 805ab3b3..4cf7e2b7 100644 --- a/src/webui/routers/chat/routes.py +++ b/src/webui/routers/chat/routes.py @@ -1,8 +1,7 @@ """本地聊天室路由 - WebUI 与麦麦直接对话。""" import uuid - -from typing import Optional +from typing import Dict, Optional from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect from sqlalchemy import case, func @@ -36,7 +35,7 @@ async def get_chat_history( limit: int = Query(default=50, ge=1, le=200), user_id: Optional[str] = Query(default=None), group_id: Optional[str] = Query(default=None), -) -> dict[str, object]: +) -> Dict[str, object]: """获取聊天历史记录。""" del user_id target_group_id = group_id or WEBUI_CHAT_GROUP_ID @@ -45,7 +44,7 @@ async def get_chat_history( @router.get("/platforms") -async def get_available_platforms() -> dict[str, object]: +async def get_available_platforms() -> Dict[str, object]: """获取可用平台列表。""" try: with get_db_session() as session: @@ -68,7 +67,7 @@ async def get_persons_by_platform( platform: str = Query(..., description="平台名称"), search: Optional[str] = Query(default=None, description="搜索关键词"), limit: int = Query(default=50, ge=1, le=200), -) -> dict[str, object]: +) -> Dict[str, object]: """获取指定平台的用户列表。""" try: statement = select(PersonInfo).where(col(PersonInfo.platform) == platform) @@ -108,7 +107,7 @@ async def get_persons_by_platform( @router.delete("/history") async def clear_chat_history( group_id: Optional[str] = Query(default=None), -) -> dict[str, object]: +) -> Dict[str, object]: """清空聊天历史记录。""" deleted = chat_history.clear_history(group_id) return {"success": True, "message": f"已清空 {deleted} 条聊天记录"} @@ -164,11 +163,11 @@ async def websocket_chat( @router.get("/info") -async def get_chat_info() -> dict[str, object]: +async def get_chat_info() -> Dict[str, object]: """获取聊天室信息。""" return { "bot_name": global_config.bot.nickname, "platform": WEBUI_CHAT_PLATFORM, "group_id": WEBUI_CHAT_GROUP_ID, "active_sessions": len(chat_manager.active_connections), - } \ No newline at end of file + } diff --git a/src/webui/routers/chat/support.py b/src/webui/routers/chat/support.py index d9afa474..94d5723c 100644 --- a/src/webui/routers/chat/support.py +++ b/src/webui/routers/chat/support.py @@ -1,9 +1,8 @@ """WebUI 聊天路由支持逻辑。""" -from typing import Any, Optional, cast - import time import uuid +from typing import Any, Dict, List, Optional, Tuple, cast from fastapi import WebSocket from pydantic import BaseModel @@ -59,7 +58,7 @@ class ChatHistoryManager: def __init__(self, max_messages: int = 200) -> None: self.max_messages = max_messages - def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> dict[str, Any]: + def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]: user_info = msg.message_info.user_info user_id = user_info.user_id or "" is_bot = is_bot_self(msg.platform, user_id) @@ -78,7 +77,7 @@ class ChatHistoryManager: target_group_id = group_id or WEBUI_CHAT_GROUP_ID return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id) - def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> list[dict[str, Any]]: + def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]: target_group_id = group_id or WEBUI_CHAT_GROUP_ID session_id = self._resolve_session_id(target_group_id) try: @@ -114,8 +113,8 @@ class ChatConnectionManager: """聊天连接管理器。""" def __init__(self) -> None: - self.active_connections: dict[str, WebSocket] = {} - self.user_sessions: dict[str, str] = {} + self.active_connections: Dict[str, WebSocket] = {} + self.user_sessions: Dict[str, str] = {} async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None: await websocket.accept() @@ -130,14 +129,14 @@ class ChatConnectionManager: del self.user_sessions[user_id] logger.info(f"WebUI 聊天会话已断开: session={session_id}") - async def send_message(self, session_id: str, message: dict[str, Any]) -> None: + async def send_message(self, session_id: str, message: Dict[str, Any]) -> None: if session_id in self.active_connections: try: await self.active_connections[session_id].send_json(message) except Exception as e: logger.error(f"发送消息失败: {e}") - async def broadcast(self, message: dict[str, Any]) -> None: + async def broadcast(self, message: Dict[str, Any]) -> None: for session_id in list(self.active_connections.keys()): await self.send_message(session_id, message) @@ -224,8 +223,8 @@ def build_session_info_message( user_id: str, user_name: str, virtual_config: Optional[VirtualIdentityConfig], -) -> dict[str, Any]: - session_info_data: dict[str, Any] = { +) -> Dict[str, Any]: + session_info_data: Dict[str, Any] = { "type": "session_info", "session_id": session_id, "user_id": user_id, @@ -314,7 +313,7 @@ def resolve_sender_identity( current_user_name: str, normalized_user_id: str, virtual_config: Optional[VirtualIdentityConfig], -) -> tuple[str, str]: +) -> Tuple[str, str]: if is_virtual_mode_enabled(virtual_config): assert virtual_config is not None return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id @@ -328,7 +327,7 @@ def create_message_data( message_id: Optional[str] = None, is_at_bot: bool = True, virtual_config: Optional[VirtualIdentityConfig] = None, -) -> dict[str, Any]: +) -> Dict[str, Any]: if message_id is None: message_id = str(uuid.uuid4()) @@ -385,7 +384,7 @@ def create_message_data( async def handle_chat_message( session_id: str, - data: dict[str, Any], + data: Dict[str, Any], current_user_name: str, normalized_user_id: str, current_virtual_config: Optional[VirtualIdentityConfig], @@ -443,7 +442,7 @@ async def handle_chat_ping(session_id: str) -> None: await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()}) -async def handle_nickname_update(session_id: str, data: dict[str, Any], current_user_name: str) -> str: +async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str: new_name = str(data.get("user_name", "")).strip() if not new_name: return current_user_name @@ -462,7 +461,7 @@ async def handle_nickname_update(session_id: str, data: dict[str, Any], current_ async def enable_virtual_identity( session_id: str, session_prefix: str, - virtual_data: dict[str, Any], + virtual_data: Dict[str, Any], ) -> Optional[VirtualIdentityConfig]: if not virtual_data.get("platform") or not virtual_data.get("person_id"): await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id") @@ -558,7 +557,7 @@ async def disable_virtual_identity(session_id: str) -> None: async def handle_virtual_identity_update( session_id: str, session_id_prefix: str, - data: dict[str, Any], + data: Dict[str, Any], current_virtual_config: Optional[VirtualIdentityConfig], ) -> Optional[VirtualIdentityConfig]: virtual_data = cast(dict[str, Any], data.get("config", {})) @@ -573,11 +572,11 @@ async def handle_virtual_identity_update( async def dispatch_chat_event( session_id: str, session_id_prefix: str, - data: dict[str, Any], + data: Dict[str, Any], current_user_name: str, normalized_user_id: str, current_virtual_config: Optional[VirtualIdentityConfig], -) -> tuple[str, Optional[VirtualIdentityConfig]]: +) -> Tuple[str, Optional[VirtualIdentityConfig]]: event_type = data.get("type") if event_type == "message": next_user_name = await handle_chat_message( diff --git a/src/webui/routers/config.py b/src/webui/routers/config.py index 78d2f1d2..8d22dc19 100644 --- a/src/webui/routers/config.py +++ b/src/webui/routers/config.py @@ -5,51 +5,51 @@ import copy import os from pathlib import Path -from typing import Any, Annotated, Optional +from typing import Annotated, Any, Dict, List, Tuple import tomlkit from fastapi import APIRouter, Body, Depends, HTTPException from src.common.logger import get_logger -from src.webui.dependencies import require_auth -from src.webui.utils.toml_utils import save_toml_with_format, _update_toml_doc -from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT +from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig from src.config.config_base import AttributeData +from src.config.model_configs import ( + APIProvider, + ModelInfo, + ModelTaskConfig, +) from src.config.official_configs import ( BotConfig, - PersonalityConfig, - RelationshipConfig, ChatConfig, - MessageReceiveConfig, + ChineseTypoConfig, + DebugConfig, EmojiConfig, + ExperimentalConfig, ExpressionConfig, KeywordReactionConfig, - ChineseTypoConfig, + LPMMKnowledgeConfig, + MaimMessageConfig, + MemoryConfig, + MessageReceiveConfig, + PersonalityConfig, + RelationshipConfig, ResponsePostProcessConfig, ResponseSplitterConfig, TelemetryConfig, - ExperimentalConfig, - MaimMessageConfig, - LPMMKnowledgeConfig, ToolConfig, - MemoryConfig, - DebugConfig, VoiceConfig, ) -from src.config.model_configs import ( - ModelTaskConfig, - ModelInfo, - APIProvider, -) from src.webui.config_schema import ConfigSchemaGenerator +from src.webui.dependencies import require_auth +from src.webui.utils.toml_utils import _update_toml_doc, save_toml_with_format logger = get_logger("webui") # 模块级别的类型别名(解决 B008 ruff 错误) -ConfigBody = Annotated[dict[str, Any], Body()] +ConfigBody = Annotated[Dict[str, Any], Body()] SectionBody = Annotated[Any, Body()] RawContentBody = Annotated[str, Body(embed=True)] -PathBody = Annotated[dict[str, str], Body()] +PathBody = Annotated[Dict[str, str], Body()] router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)]) @@ -61,6 +61,8 @@ def _toml_to_plain_dict(obj: Any) -> Any: if isinstance(obj, list): return [_toml_to_plain_dict(v) for v in obj] return obj + + # ===== 架构获取接口 ===== @@ -385,8 +387,12 @@ async def update_model_config_section(section_name: str, section_data: SectionBo if section_name == "api_providers" and "api_provider" in str(e): provider_names = {p.get("name") for p in section_data if isinstance(p, dict)} models = config_data.get("models", []) - orphaned_models = [ - m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names + orphaned_models: List[str] = [ + str(model_name) + for m in models + if isinstance(m, dict) + and m.get("api_provider") not in provider_names + and (model_name := m.get("name")) is not None ] if orphaned_models: error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。" @@ -421,7 +427,7 @@ def _normalize_adapter_path(path: str) -> str: return os.path.normpath(os.path.join(PROJECT_ROOT, path)) -def _get_allowed_adapter_config_roots() -> tuple[Path, ...]: +def _get_allowed_adapter_config_roots() -> Tuple[Path, ...]: project_root = Path(PROJECT_ROOT).resolve() return ( project_root, diff --git a/src/webui/routers/emoji/__init__.py b/src/webui/routers/emoji/__init__.py index f8cdb40b..1b2eccc4 100644 --- a/src/webui/routers/emoji/__init__.py +++ b/src/webui/routers/emoji/__init__.py @@ -1,3 +1,3 @@ from .routes import router -__all__ = ["router"] \ No newline at end of file +__all__ = ["router"] diff --git a/src/webui/routers/emoji/routes.py b/src/webui/routers/emoji/routes.py index 6081dd14..4216802e 100644 --- a/src/webui/routers/emoji/routes.py +++ b/src/webui/routers/emoji/routes.py @@ -1,13 +1,12 @@ """表情包管理 API 路由""" -from datetime import datetime -from pathlib import Path -from typing import Any, Optional - import asyncio import hashlib import io import os +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional from fastapi import APIRouter, Cookie, HTTPException, Query from fastapi.responses import FileResponse, JSONResponse @@ -17,7 +16,8 @@ from sqlmodel import col, select from src.common.database.database import get_db_session from src.common.database.database_model import Images, ImageType -from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header as verify_auth_token +from src.webui.core import get_token_manager +from src.webui.core import verify_auth_token_from_cookie_or_header as verify_auth_token from .schemas import ( BatchDeleteRequest, @@ -219,7 +219,7 @@ async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(Non @router.get("/stats/summary") -async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: """获取表情包统计数据。""" try: verify_auth_token(maibot_session) @@ -247,7 +247,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> dict[ registered = session.exec(registered_statement).one() banned = session.exec(banned_statement).one() - formats: dict[str, int] = {} + formats: Dict[str, int] = {} format_statement = select(Images.full_path).where(col(Images.image_type) == ImageType.EMOJI) for full_path in session.exec(format_statement).all(): suffix = Path(full_path).suffix.lower().lstrip(".") @@ -443,7 +443,7 @@ async def batch_delete_emojis( deleted_count = 0 failed_count = 0 - failed_ids: list[int] = [] + failed_ids: List[int] = [] for emoji_id in request.emoji_ids: try: @@ -582,12 +582,12 @@ async def batch_upload_emoji( emotion: EmotionForm = "", is_registered: IsRegisteredForm = True, maibot_session: Optional[str] = Cookie(None), -) -> dict[str, Any]: +) -> Dict[str, Any]: """批量上传表情包。""" try: verify_auth_token(maibot_session) - results: dict[str, Any] = { + results: Dict[str, Any] = { "success": True, "total": len(files), "uploaded": 0, @@ -614,9 +614,7 @@ async def batch_upload_emoji( file_content = await file.read() if not file_content: results["failed"] += 1 - results["details"].append( - {"filename": file.filename, "success": False, "error": "文件内容为空"} - ) + results["details"].append({"filename": file.filename, "success": False, "error": "文件内容为空"}) continue try: @@ -677,14 +675,10 @@ async def batch_upload_emoji( session.flush() results["uploaded"] += 1 - results["details"].append( - {"filename": file.filename, "success": True, "id": emoji.id} - ) + results["details"].append({"filename": file.filename, "success": True, "id": emoji.id}) except Exception as e: results["failed"] += 1 - results["details"].append( - {"filename": file.filename, "success": False, "error": str(e)} - ) + results["details"].append({"filename": file.filename, "success": False, "error": str(e)}) results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个" return results @@ -787,7 +781,9 @@ async def preheat_thumbnail_cache( try: loop = asyncio.get_event_loop() - await loop.run_in_executor(get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash) + await loop.run_in_executor( + get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash + ) generated += 1 except Exception as e: logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}") @@ -840,4 +836,4 @@ async def clear_all_thumbnail_cache(maibot_session: Optional[str] = Cookie(None) raise except Exception as e: logger.exception(f"清空缩略图缓存失败: {e}") - raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e \ No newline at end of file + raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e diff --git a/src/webui/routers/emoji/schemas.py b/src/webui/routers/emoji/schemas.py index 4eea67c6..fd874841 100644 --- a/src/webui/routers/emoji/schemas.py +++ b/src/webui/routers/emoji/schemas.py @@ -137,4 +137,4 @@ def emoji_to_response(image: Images) -> EmojiResponse: record_time=image.record_time.timestamp() if image.record_time else 0.0, register_time=image.register_time.timestamp() if image.register_time else None, last_used_time=image.last_used_time.timestamp() if image.last_used_time else None, - ) \ No newline at end of file + ) diff --git a/src/webui/routers/emoji/support.py b/src/webui/routers/emoji/support.py index 51790cd4..41b0be97 100644 --- a/src/webui/routers/emoji/support.py +++ b/src/webui/routers/emoji/support.py @@ -1,8 +1,8 @@ -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path - import os import threading +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Dict, Set, Tuple from PIL import Image from sqlmodel import col, select @@ -18,10 +18,10 @@ THUMBNAIL_SIZE = (200, 200) THUMBNAIL_QUALITY = 80 EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed") -_thumbnail_locks: dict[str, threading.Lock] = {} +_thumbnail_locks: Dict[str, threading.Lock] = {} _locks_lock = threading.Lock() _thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail") -_generating_thumbnails: set[str] = set() +_generating_thumbnails: Set[str] = set() _generating_lock = threading.Lock() @@ -35,7 +35,7 @@ def get_generating_lock() -> threading.Lock: return _generating_lock -def get_generating_thumbnails() -> set[str]: +def get_generating_thumbnails() -> Set[str]: """获取正在生成的缩略图哈希集合。""" return _generating_thumbnails @@ -112,7 +112,7 @@ def background_generate_thumbnail(source_path: str, file_hash: str) -> None: _background_generate_thumbnail(source_path, file_hash) -def cleanup_orphaned_thumbnails() -> tuple[int, int]: +def cleanup_orphaned_thumbnails() -> Tuple[int, int]: """清理孤立的缩略图缓存。""" if not THUMBNAIL_CACHE_DIR.exists(): return 0, 0 @@ -139,4 +139,4 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]: if cleaned > 0: logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个") - return cleaned, kept \ No newline at end of file + return cleaned, kept diff --git a/src/webui/routers/expression.py b/src/webui/routers/expression.py index f814751c..94930d7d 100644 --- a/src/webui/routers/expression.py +++ b/src/webui/routers/expression.py @@ -5,14 +5,13 @@ from typing import Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel - from sqlalchemy import case, func -from sqlmodel import col, select, delete +from sqlmodel import col, delete, select -from src.common.logger import get_logger +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.common.database.database import get_db_session from src.common.database.database_model import Expression -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.common.logger import get_logger from src.webui.dependencies import require_auth logger = get_logger("webui.expression") diff --git a/src/webui/routers/jargon.py b/src/webui/routers/jargon.py index 95568d48..ce0ddf4c 100644 --- a/src/webui/routers/jargon.py +++ b/src/webui/routers/jargon.py @@ -1,8 +1,7 @@ """黑话(俚语)管理路由""" import json - -from typing import Annotated, Any, List, Optional +from typing import Annotated, Any, Dict, List, Optional, Set from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field @@ -90,7 +89,7 @@ class JargonListResponse(BaseModel): total: int page: int page_size: int - data: List[dict[str, Any]] + data: List[Dict[str, Any]] class JargonDetailResponse(BaseModel): @@ -153,7 +152,7 @@ class JargonStatsResponse(BaseModel): """黑话统计响应""" success: bool = True - data: dict[str, Any] + data: Dict[str, Any] class ChatInfoResponse(BaseModel): @@ -175,7 +174,7 @@ class ChatListResponse(BaseModel): # ==================== 工具函数 ==================== -def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]: +def parse_session_id_dict(session_id_dict_str: Optional[str]) -> Dict[str, int]: """解析会话计数字典。""" if not session_id_dict_str: return {} @@ -188,7 +187,7 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]: if not isinstance(parsed, dict): return {} - session_counts: dict[str, int] = {} + session_counts: Dict[str, int] = {} for session_id, count in parsed.items(): if not isinstance(session_id, str): continue @@ -202,7 +201,7 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]: return session_counts -def dump_session_id_dict(session_counts: dict[str, int]) -> str: +def dump_session_id_dict(session_counts: Dict[str, int]) -> str: """序列化会话计数字典。""" return json.dumps(session_counts, ensure_ascii=False) @@ -225,7 +224,7 @@ def build_session_id_dict_for_chat(chat_id: str, count: int = 1) -> str: return dump_session_id_dict({chat_id: count}) -def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]: +def jargon_to_dict(jargon: Jargon, session: Session) -> Dict[str, Any]: """将 Jargon ORM 对象转换为字典""" chat_id = get_primary_chat_id(jargon.session_id_dict) chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None @@ -311,14 +310,16 @@ async def get_chat_list(): with get_db_session() as session: jargons = session.exec(select(Jargon)).all() - seen_stream_ids: set[str] = set() + seen_stream_ids: Set[str] = set() for jargon in jargons: seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys()) result = [] with get_db_session() as session: for stream_id in seen_stream_ids: - if chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first(): + if chat_session := session.exec( + select(ChatSession).where(col(ChatSession.session_id) == stream_id) + ).first(): chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20] result.append( ChatInfoResponse( @@ -358,7 +359,7 @@ async def get_jargon_stats(): pending = sum(jargon.is_jargon is None for jargon in jargons) complete_count = sum(jargon.is_complete for jargon in jargons) - top_chats_counter: dict[str, int] = {} + top_chats_counter: Dict[str, int] = {} for jargon in jargons: for session_id in parse_session_id_dict(jargon.session_id_dict): top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1 diff --git a/src/webui/routers/knowledge.py b/src/webui/routers/knowledge.py index cb7f58d9..b6ab500f 100644 --- a/src/webui/routers/knowledge.py +++ b/src/webui/routers/knowledge.py @@ -1,10 +1,11 @@ """知识库图谱可视化 API 路由""" -from typing import Any, List, Optional +import logging +from typing import Any, List, Optional, Tuple from fastapi import APIRouter, Depends, Query from pydantic import BaseModel -import logging + from src.config.config import global_config from src.webui.dependencies import require_auth @@ -59,7 +60,7 @@ def _get_paragraph_store(): return None -def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]: +def _get_paragraph_content(node_id: str) -> Tuple[Optional[str], bool]: """从 embedding store 获取段落完整内容 Args: diff --git a/src/webui/routers/model.py b/src/webui/routers/model.py index 1e16ac4d..2f67aca5 100644 --- a/src/webui/routers/model.py +++ b/src/webui/routers/model.py @@ -5,9 +5,9 @@ """ import os -import httpx -from typing import Optional +from typing import Dict, List, Optional +import httpx import tomlkit from fastapi import APIRouter, Depends, HTTPException, Query @@ -39,7 +39,7 @@ def _normalize_url(url: str) -> str: return url.rstrip("/") if url else "" -def _parse_openai_response(data: dict) -> list[dict]: +def _parse_openai_response(data: Dict) -> List[Dict]: """ 解析 OpenAI 格式的模型列表响应 @@ -59,7 +59,7 @@ def _parse_openai_response(data: dict) -> list[dict]: ] -def _parse_gemini_response(data: dict) -> list[dict]: +def _parse_gemini_response(data: Dict) -> List[Dict]: """ 解析 Gemini 格式的模型列表响应 @@ -89,7 +89,7 @@ async def _fetch_models_from_provider( endpoint: str, parser: str, client_type: str = "openai", -) -> list[dict]: +) -> List[Dict]: """ 从提供商 API 获取模型列表 @@ -154,7 +154,7 @@ async def _fetch_models_from_provider( raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}") -def _get_provider_config(provider_name: str) -> Optional[dict]: +def _get_provider_config(provider_name: str) -> Optional[Dict]: """ 从 model_config.toml 获取指定提供商的配置 diff --git a/src/webui/routers/person.py b/src/webui/routers/person.py index 86927b0e..d797602b 100644 --- a/src/webui/routers/person.py +++ b/src/webui/routers/person.py @@ -1,14 +1,13 @@ """人物信息管理 API 路由""" -from datetime import datetime import json +from datetime import datetime from typing import Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel - from sqlalchemy import case -from sqlmodel import col, select, delete +from sqlmodel import col, delete, select from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo diff --git a/src/webui/routers/plugin/__init__.py b/src/webui/routers/plugin/__init__.py index 1841f690..1be61841 100644 --- a/src/webui/routers/plugin/__init__.py +++ b/src/webui/routers/plugin/__init__.py @@ -14,4 +14,4 @@ router.include_router(config_router) set_update_progress_callback(update_progress) -__all__ = ["get_progress_router", "router"] \ No newline at end of file +__all__ = ["get_progress_router", "router"] diff --git a/src/webui/routers/plugin/catalog.py b/src/webui/routers/plugin/catalog.py index 4a49d048..80077bad 100644 --- a/src/webui/routers/plugin/catalog.py +++ b/src/webui/routers/plugin/catalog.py @@ -1,6 +1,5 @@ -from typing import Any, Optional - import json +from typing import Any, Dict, Optional from fastapi import APIRouter, Cookie, HTTPException @@ -28,7 +27,7 @@ logger = get_logger("webui.plugin_routes") router = APIRouter() -def _mirror_to_response(mirror: dict[str, Any]) -> MirrorConfigResponse: +def _mirror_to_response(mirror: Dict[str, Any]) -> MirrorConfigResponse: return MirrorConfigResponse( id=mirror["id"], name=mirror["name"], @@ -116,7 +115,7 @@ async def update_mirror( @router.delete("/mirrors/{mirror_id}") -async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) service = get_git_mirror_service() @@ -172,12 +171,16 @@ async def fetch_raw_file( loaded_plugins=total, ) except Exception: - await update_progress(stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0) + await update_progress( + stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0 + ) return FetchRawFileResponse(**result) except Exception as e: logger.error(f"获取 Raw 文件失败: {e}") - await update_progress(stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0) + await update_progress( + stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0 + ) raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e @@ -204,4 +207,4 @@ async def clone_repository( return CloneRepositoryResponse(**result) except Exception as e: logger.error(f"克隆仓库失败: {e}") - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e \ No newline at end of file + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e diff --git a/src/webui/routers/plugin/config_routes.py b/src/webui/routers/plugin/config_routes.py index b9c32f2e..a84362bf 100644 --- a/src/webui/routers/plugin/config_routes.py +++ b/src/webui/routers/plugin/config_routes.py @@ -1,8 +1,7 @@ -from typing import Any, Optional, cast - import json -import tomlkit +from typing import Any, Dict, Optional, cast +import tomlkit from fastapi import APIRouter, Cookie, HTTPException from src.common.logger import get_logger @@ -24,8 +23,8 @@ logger = get_logger("webui.plugin_routes") router = APIRouter() -def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> dict[str, Any]: - schema: dict[str, Any] = { +def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> Dict[str, Any]: + schema: Dict[str, Any] = { "plugin_id": plugin_id, "plugin_info": { "name": plugin_id, @@ -41,7 +40,7 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> di for section_name, section_data in current_config.items(): if not isinstance(section_data, dict): continue - section_fields: dict[str, Any] = {} + section_fields: Dict[str, Any] = {} for field_name, field_value in section_data.items(): field_type = type(field_value).__name__ ui_type = "text" @@ -121,7 +120,7 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> di @router.get("/config/{plugin_id}/schema") -async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"获取插件配置 Schema: {plugin_id}") @@ -157,7 +156,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] @router.get("/config/{plugin_id}/raw") -async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"获取插件原始配置: {plugin_id}") @@ -184,7 +183,7 @@ async def update_plugin_config_raw( plugin_id: str, request: UpdatePluginRawConfigRequest, maibot_session: Optional[str] = Cookie(None), -) -> dict[str, Any]: +) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"更新插件原始配置: {plugin_id}") @@ -216,7 +215,7 @@ async def update_plugin_config_raw( @router.get("/config/{plugin_id}") -async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"获取插件配置: {plugin_id}") @@ -244,7 +243,7 @@ async def update_plugin_config( plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), -) -> dict[str, Any]: +) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"更新插件配置: {plugin_id}") @@ -276,7 +275,7 @@ async def update_plugin_config( @router.post("/config/{plugin_id}/reset") -async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"重置插件配置: {plugin_id}") @@ -300,7 +299,7 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co @router.post("/config/{plugin_id}/toggle") -async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"切换插件状态: {plugin_id}") @@ -326,9 +325,14 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N status = "启用" if new_enabled else "禁用" logger.info(f"已{status}插件: {plugin_id}") - return {"success": True, "enabled": new_enabled, "message": f"插件已{status}", "note": "状态更改将在下次加载插件时生效"} + return { + "success": True, + "enabled": new_enabled, + "message": f"插件已{status}", + "note": "状态更改将在下次加载插件时生效", + } except HTTPException: raise except Exception as e: logger.error(f"切换插件状态失败: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e \ No newline at end of file + raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e diff --git a/src/webui/routers/plugin/management.py b/src/webui/routers/plugin/management.py index 7a725c94..2c509215 100644 --- a/src/webui/routers/plugin/management.py +++ b/src/webui/routers/plugin/management.py @@ -1,7 +1,6 @@ -from pathlib import Path -from typing import Any, Optional - import json +from pathlib import Path +from typing import Any, Dict, List, Optional from fastapi import APIRouter, Cookie, HTTPException @@ -18,8 +17,8 @@ from .support import ( parse_repository_url, remove_tree, require_plugin_token, - resolve_plugin_file_path, resolve_installed_plugin_path, + resolve_plugin_file_path, validate_plugin_id, ) @@ -28,7 +27,7 @@ logger = get_logger("webui.plugin_routes") router = APIRouter() -def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path: Path) -> str: +def _infer_plugin_id(folder_name: str, manifest: Dict[str, Any], manifest_path: Path) -> str: if "id" in manifest: return str(manifest["id"]) @@ -66,43 +65,87 @@ def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path: @router.post("/install") -async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"收到安装插件请求: {request.plugin_id}") plugin_id = request.plugin_id try: plugin_id = validate_plugin_id(request.plugin_id) - await update_progress(stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id) + await update_progress( + stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id + ) repo_url, owner, repo = parse_repository_url(request.repository_url) - await update_progress(stage="loading", progress=10, message=f"解析仓库信息: {owner}/{repo}", operation="install", plugin_id=plugin_id) + await update_progress( + stage="loading", + progress=10, + message=f"解析仓库信息: {owner}/{repo}", + operation="install", + plugin_id=plugin_id, + ) target_path, old_format_path = get_plugin_candidate_paths(plugin_id) if target_path.exists() or old_format_path.exists(): - await update_progress(stage="error", progress=0, message="插件已存在", operation="install", plugin_id=plugin_id, error="插件已安装,请先卸载") + await update_progress( + stage="error", + progress=0, + message="插件已存在", + operation="install", + plugin_id=plugin_id, + error="插件已安装,请先卸载", + ) raise HTTPException(status_code=400, detail="插件已安装") - await update_progress(stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id) + await update_progress( + stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id + ) service = get_git_mirror_service() if "github.com" in repo_url: - result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, mirror_id=request.mirror_id, depth=1) + result = await service.clone_repository( + owner=owner, + repo=repo, + target_path=target_path, + branch=request.branch, + mirror_id=request.mirror_id, + depth=1, + ) else: - result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1) + result = await service.clone_repository( + owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1 + ) if not result.get("success"): error_msg = str(result.get("error", "克隆失败")) - await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg) + await update_progress( + stage="error", + progress=0, + message="克隆仓库失败", + operation="install", + plugin_id=plugin_id, + error=error_msg, + ) raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg) - await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id) + await update_progress( + stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id + ) manifest_path = resolve_plugin_file_path(target_path, "_manifest.json") if not manifest_path.exists(): remove_tree(target_path) - await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式") + await update_progress( + stage="error", + progress=0, + message="插件缺少 _manifest.json", + operation="install", + plugin_id=plugin_id, + error="无效的插件格式", + ) raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") - await update_progress(stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id) + await update_progress( + stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id + ) try: with open(manifest_path, "r", encoding="utf-8") as file_obj: manifest = json.load(file_obj) @@ -114,91 +157,199 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional json.dump(manifest, file_obj, ensure_ascii=False, indent=2) except Exception as e: remove_tree(target_path) - await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="install", plugin_id=plugin_id, error=str(e)) + await update_progress( + stage="error", + progress=0, + message="_manifest.json 无效", + operation="install", + plugin_id=plugin_id, + error=str(e), + ) raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e - await update_progress(stage="success", progress=100, message=f"成功安装插件: {manifest['name']} v{manifest['version']}", operation="install", plugin_id=plugin_id) - return {"success": True, "message": "插件安装成功", "plugin_id": plugin_id, "plugin_name": manifest["name"], "version": manifest["version"], "path": str(target_path)} + await update_progress( + stage="success", + progress=100, + message=f"成功安装插件: {manifest['name']} v{manifest['version']}", + operation="install", + plugin_id=plugin_id, + ) + return { + "success": True, + "message": "插件安装成功", + "plugin_id": plugin_id, + "plugin_name": manifest["name"], + "version": manifest["version"], + "path": str(target_path), + } except HTTPException: raise except Exception as e: logger.error(f"安装插件失败: {e}", exc_info=True) - await update_progress(stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e)) + await update_progress( + stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e) + ) raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e @router.post("/uninstall") -async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def uninstall_plugin( + request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None) +) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"收到卸载插件请求: {request.plugin_id}") plugin_id = request.plugin_id try: plugin_id = validate_plugin_id(request.plugin_id) - await update_progress(stage="loading", progress=10, message=f"开始卸载插件: {plugin_id}", operation="uninstall", plugin_id=plugin_id) + await update_progress( + stage="loading", + progress=10, + message=f"开始卸载插件: {plugin_id}", + operation="uninstall", + plugin_id=plugin_id, + ) plugin_path = resolve_installed_plugin_path(plugin_id) if plugin_path is None: - await update_progress(stage="error", progress=0, message="插件不存在", operation="uninstall", plugin_id=plugin_id, error="插件未安装或已被删除") + await update_progress( + stage="error", + progress=0, + message="插件不存在", + operation="uninstall", + plugin_id=plugin_id, + error="插件未安装或已被删除", + ) raise HTTPException(status_code=404, detail="插件未安装") - await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id) + await update_progress( + stage="loading", + progress=30, + message=f"正在删除插件文件: {plugin_path}", + operation="uninstall", + plugin_id=plugin_id, + ) manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json")) plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id - await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id) + await update_progress( + stage="loading", + progress=50, + message=f"正在删除 {plugin_name}...", + operation="uninstall", + plugin_id=plugin_id, + ) remove_tree(plugin_path) logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})") - await update_progress(stage="success", progress=100, message=f"成功卸载插件: {plugin_name}", operation="uninstall", plugin_id=plugin_id) + await update_progress( + stage="success", + progress=100, + message=f"成功卸载插件: {plugin_name}", + operation="uninstall", + plugin_id=plugin_id, + ) return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name} except HTTPException: raise except PermissionError as e: logger.error(f"卸载插件失败(权限错误): {e}") - await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error="权限不足,无法删除插件文件") + await update_progress( + stage="error", + progress=0, + message="卸载失败", + operation="uninstall", + plugin_id=plugin_id, + error="权限不足,无法删除插件文件", + ) raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e except Exception as e: logger.error(f"卸载插件失败: {e}", exc_info=True) - await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e)) + await update_progress( + stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e) + ) raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e @router.post("/update") -async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"收到更新插件请求: {request.plugin_id}") plugin_id = request.plugin_id try: plugin_id = validate_plugin_id(request.plugin_id) - await update_progress(stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id) + await update_progress( + stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id + ) plugin_path = resolve_installed_plugin_path(plugin_id) if plugin_path is None: - await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装") + await update_progress( + stage="error", + progress=0, + message="插件不存在", + operation="update", + plugin_id=plugin_id, + error="插件未安装,请先安装", + ) raise HTTPException(status_code=404, detail="插件未安装") manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json")) old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown" - await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id) - await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id) + await update_progress( + stage="loading", + progress=10, + message=f"当前版本: {old_version},准备更新...", + operation="update", + plugin_id=plugin_id, + ) + await update_progress( + stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id + ) remove_tree(plugin_path) - await update_progress(stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id) + await update_progress( + stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id + ) repo_url, owner, repo = parse_repository_url(request.repository_url) service = get_git_mirror_service() if "github.com" in repo_url: - result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, mirror_id=request.mirror_id, depth=1) + result = await service.clone_repository( + owner=owner, + repo=repo, + target_path=plugin_path, + branch=request.branch, + mirror_id=request.mirror_id, + depth=1, + ) else: - result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1) + result = await service.clone_repository( + owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1 + ) if not result.get("success"): error_msg = str(result.get("error", "克隆失败")) - await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg) + await update_progress( + stage="error", + progress=0, + message="下载新版本失败", + operation="update", + plugin_id=plugin_id, + error=error_msg, + ) raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg) - await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id) + await update_progress( + stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id + ) new_manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json") if not new_manifest_path.exists(): remove_tree(plugin_path) - await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式") + await update_progress( + stage="error", + progress=0, + message="新版本缺少 _manifest.json", + operation="update", + plugin_id=plugin_id, + error="无效的插件格式", + ) raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json") try: @@ -207,27 +358,49 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s new_version = str(new_manifest.get("version", "unknown")) new_name = str(new_manifest.get("name", plugin_id)) logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}") - await update_progress(stage="success", progress=100, message=f"成功更新 {new_name}: {old_version} → {new_version}", operation="update", plugin_id=plugin_id) - return {"success": True, "message": "插件更新成功", "plugin_id": plugin_id, "plugin_name": new_name, "old_version": old_version, "new_version": new_version} + await update_progress( + stage="success", + progress=100, + message=f"成功更新 {new_name}: {old_version} → {new_version}", + operation="update", + plugin_id=plugin_id, + ) + return { + "success": True, + "message": "插件更新成功", + "plugin_id": plugin_id, + "plugin_name": new_name, + "old_version": old_version, + "new_version": new_version, + } except Exception as e: remove_tree(plugin_path) - await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="update", plugin_id=plugin_id, error=str(e)) + await update_progress( + stage="error", + progress=0, + message="_manifest.json 无效", + operation="update", + plugin_id=plugin_id, + error=str(e), + ) raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e except HTTPException: raise except Exception as e: logger.error(f"更新插件失败: {e}", exc_info=True) - await update_progress(stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)) + await update_progress( + stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e) + ) raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e @router.get("/installed") -async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info("收到获取已安装插件列表请求") try: - installed_plugins: list[dict[str, Any]] = [] + installed_plugins: List[Dict[str, Any]] = [] for plugin_path in iter_plugin_directories(): folder_name = plugin_path.name if folder_name.startswith(".") or folder_name.startswith("__"): @@ -253,9 +426,9 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> except Exception as e: logger.error(f"读取插件 {folder_name} 信息时出错: {e}") - seen_ids: dict[str, str] = {} - unique_plugins: list[dict[str, Any]] = [] - duplicates: list[dict[str, Any]] = [] + seen_ids: Dict[str, str] = {} + unique_plugins: List[Dict[str, Any]] = [] + duplicates: List[Dict[str, Any]] = [] for plugin in installed_plugins: plugin_id = str(plugin["id"]) plugin_path = str(plugin["path"]) @@ -277,7 +450,7 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> @router.get("/local-readme/{plugin_id}") -async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]: +async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]: require_plugin_token(maibot_session) logger.info(f"获取本地插件 README: {plugin_id}") @@ -300,4 +473,4 @@ async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] return {"success": False, "error": "本地未找到 README 文件"} except Exception as e: logger.error(f"获取本地 README 失败: {e}", exc_info=True) - return {"success": False, "error": str(e)} \ No newline at end of file + return {"success": False, "error": str(e)} diff --git a/src/webui/routers/plugin/progress.py b/src/webui/routers/plugin/progress.py index f12406b7..9e24ace7 100644 --- a/src/webui/routers/plugin/progress.py +++ b/src/webui/routers/plugin/progress.py @@ -1,7 +1,6 @@ import asyncio import json - -from typing import Any, Optional, Set +from typing import Any, Dict, Optional, Set from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect @@ -14,7 +13,7 @@ logger = get_logger("webui.plugin_progress") router = APIRouter() active_connections: Set[WebSocket] = set() -current_progress: dict[str, Any] = { +current_progress: Dict[str, Any] = { "operation": "idle", "stage": "idle", "progress": 0, @@ -26,7 +25,7 @@ current_progress: dict[str, Any] = { } -async def broadcast_progress(progress_data: dict[str, Any]) -> None: +async def broadcast_progress(progress_data: Dict[str, Any]) -> None: global current_progress current_progress = progress_data.copy() @@ -34,7 +33,7 @@ async def broadcast_progress(progress_data: dict[str, Any]) -> None: return message = json.dumps(progress_data, ensure_ascii=False) - disconnected: set[WebSocket] = set() + disconnected: Set[WebSocket] = set() for websocket in active_connections: try: @@ -119,4 +118,4 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = def get_progress_router() -> APIRouter: - return router \ No newline at end of file + return router diff --git a/src/webui/routers/plugin/schemas.py b/src/webui/routers/plugin/schemas.py index c7f6c252..0d431dc0 100644 --- a/src/webui/routers/plugin/schemas.py +++ b/src/webui/routers/plugin/schemas.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field @@ -51,8 +51,8 @@ class MirrorConfigResponse(BaseModel): class AvailableMirrorsResponse(BaseModel): - mirrors: list[MirrorConfigResponse] = Field(..., description="镜像源列表") - default_priority: list[str] = Field(..., description="默认优先级顺序(ID 列表)") + mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表") + default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)") class AddMirrorRequest(BaseModel): @@ -106,8 +106,8 @@ class UpdatePluginRequest(BaseModel): class UpdatePluginConfigRequest(BaseModel): enabled: Optional[bool] = None - config: Optional[dict[str, Any]] = None + config: Optional[Dict[str, Any]] = None class UpdatePluginRawConfigRequest(BaseModel): - config: str = Field(..., description="原始 TOML 配置内容") \ No newline at end of file + config: str = Field(..., description="原始 TOML 配置内容") diff --git a/src/webui/routers/plugin/support.py b/src/webui/routers/plugin/support.py index 46814709..9001c7ee 100644 --- a/src/webui/routers/plugin/support.py +++ b/src/webui/routers/plugin/support.py @@ -1,12 +1,11 @@ -from datetime import datetime -from pathlib import Path -from typing import Any, Optional, cast, get_origin - import json import os import re import shutil import stat +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, cast, get_origin from fastapi import HTTPException @@ -53,10 +52,7 @@ def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict: resolved_plugin_path = plugin_path.resolve() resolved_plugin_path.relative_to(resolved_plugins_dir) - if not resolved_plugin_path.is_dir(): - return None - - return resolved_plugin_path + return resolved_plugin_path if resolved_plugin_path.is_dir() else None except HTTPException: if strict: raise @@ -64,7 +60,7 @@ def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict: return None except (OSError, RuntimeError, ValueError): if strict: - raise HTTPException(status_code=400, detail="插件目录超出允许范围") + raise HTTPException(status_code=400, detail="插件目录超出允许范围") from None logger.warning(f"已跳过越界的插件目录: {plugin_path}") return None @@ -109,7 +105,7 @@ def validate_plugin_id(plugin_id: str) -> str: return plugin_id -def parse_version(version_str: str) -> tuple[int, int, int]: +def parse_version(version_str: str) -> Tuple[int, int, int]: base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0] parts = base_version.split(".") if len(parts) < 3: @@ -122,7 +118,7 @@ def parse_version(version_str: str) -> tuple[int, int, int]: return 0, 0, 0 -def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None: +def deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None: for key, value in src.items(): if key in dst and isinstance(dst[key], dict) and isinstance(value, dict): deep_merge(dst[key], value) @@ -130,9 +126,9 @@ def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None: dst[key] = value -def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]: - result: dict[str, Any] = {} - dotted_items: list[tuple[str, Any]] = [] +def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]: + result: Dict[str, Any] = {} + dotted_items: List[Tuple[str, Any]] = [] for key, value in obj.items(): if "." in key: @@ -167,7 +163,7 @@ def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]: return result -def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None: +def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None: def is_list_type(tp: Any) -> bool: origin = get_origin(tp) return tp is list or origin is list @@ -200,7 +196,7 @@ def get_plugins_dir() -> Path: return plugins_dir -def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]: +def get_plugin_candidate_paths(plugin_id: str) -> Tuple[Path, Path]: plugins_dir = get_plugins_dir() folder_name = plugin_id.replace(".", "_") return validate_safe_path(folder_name, plugins_dir), validate_safe_path(plugin_id, plugins_dir) @@ -217,7 +213,7 @@ def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]: return None -def parse_repository_url(repository_url: str) -> tuple[str, str, str]: +def parse_repository_url(repository_url: str) -> Tuple[str, str, str]: repo_url = repository_url.rstrip("/").removesuffix(".git") parts = repo_url.split("/") if len(parts) < 2: @@ -225,7 +221,7 @@ def parse_repository_url(repository_url: str) -> tuple[str, str, str]: return repo_url, parts[-2], parts[-1] -def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]: +def load_manifest_json(manifest_path: Path) -> Optional[Dict[str, Any]]: if not manifest_path.exists(): return None @@ -246,9 +242,9 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]: return None -def iter_plugin_directories() -> list[Path]: +def iter_plugin_directories() -> List[Path]: plugins_dir = get_plugins_dir() - plugin_directories: list[Path] = [] + plugin_directories: List[Path] = [] for path in plugins_dir.iterdir(): safe_path = _resolve_safe_plugin_directory(path, plugins_dir, strict=False) if safe_path is not None: @@ -286,4 +282,4 @@ def remove_tree(path: Path) -> None: os.chmod(target_path, stat.S_IWRITE) func(target_path) - shutil.rmtree(path, onerror=remove_readonly) \ No newline at end of file + shutil.rmtree(path, onerror=remove_readonly) diff --git a/src/webui/routers/statistics.py b/src/webui/routers/statistics.py index b91e3244..c101eba3 100644 --- a/src/webui/routers/statistics.py +++ b/src/webui/routers/statistics.py @@ -1,7 +1,7 @@ """统计数据 API 路由""" from datetime import datetime, timedelta -from typing import Any +from typing import Any, Dict, List from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field @@ -56,10 +56,10 @@ class DashboardData(BaseModel): """仪表盘数据""" summary: StatisticsSummary - model_stats: list[ModelStatistics] - hourly_data: list[TimeSeriesData] - daily_data: list[TimeSeriesData] - recent_activity: list[dict[str, Any]] + model_stats: List[ModelStatistics] + hourly_data: List[TimeSeriesData] + daily_data: List[TimeSeriesData] + recent_activity: List[Dict[str, Any]] @router.get("/dashboard", response_model=DashboardData) @@ -168,7 +168,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S return summary -async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]: +async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]: """获取模型统计数据(优化:使用数据库聚合和分组)""" # 使用GROUP BY聚合,避免全量加载 statement = ( @@ -181,7 +181,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]: with get_db_session() as session: rows = session.exec(statement).all() - aggregates: dict[str, dict[str, float | int]] = {} + aggregates: Dict[str, Dict[str, float | int]] = {} for record in rows: model_name = record.model_assign_name or record.model_name or "unknown" if model_name not in aggregates: @@ -200,7 +200,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]: bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost) bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1 - result: list[ModelStatistics] = [] + result: List[ModelStatistics] = [] for model_name, bucket in sorted( aggregates.items(), key=lambda item: float(item[1]["request_count"]), @@ -221,7 +221,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]: return result -async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]: +async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]: """获取小时级统计数据(优化:使用数据库聚合)""" # SQLite的日期时间函数进行小时分组 # 使用strftime将timestamp格式化为小时级别 @@ -265,7 +265,7 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> li return result -async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]: +async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]: """获取日级统计数据(优化:使用数据库聚合)""" # 使用strftime按日期分组 day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp)) @@ -308,7 +308,7 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> lis return result -async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]: +async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]: """获取最近活动""" with get_db_session() as session: statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit) diff --git a/src/webui/routers/system.py b/src/webui/routers/system.py index 7bd27347..837939c3 100644 --- a/src/webui/routers/system.py +++ b/src/webui/routers/system.py @@ -10,8 +10,9 @@ from datetime import datetime from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel -from src.config.config import MMC_VERSION + from src.common.logger import get_logger +from src.config.config import MMC_VERSION from src.webui.dependencies import require_auth router = APIRouter(prefix="/system", tags=["system"], dependencies=[Depends(require_auth)]) diff --git a/src/webui/routers/websocket/__init__.py b/src/webui/routers/websocket/__init__.py index 9dc576f2..dced3cf1 100644 --- a/src/webui/routers/websocket/__init__.py +++ b/src/webui/routers/websocket/__init__.py @@ -1,5 +1,5 @@ -from .logs import router as logs_router from .auth import router as ws_auth_router +from .logs import router as logs_router __all__ = [ "logs_router", diff --git a/src/webui/routers/websocket/auth.py b/src/webui/routers/websocket/auth.py index a8e33e7f..1cb24f6c 100644 --- a/src/webui/routers/websocket/auth.py +++ b/src/webui/routers/websocket/auth.py @@ -1,10 +1,11 @@ """WebSocket 认证模块。""" -from typing import Optional - -from fastapi import APIRouter, Cookie import secrets import time +from typing import Dict, Optional, Tuple + +from fastapi import APIRouter, Cookie + from src.common.logger import get_logger from src.webui.core import get_token_manager @@ -13,7 +14,7 @@ router = APIRouter() # WebSocket 临时 token 存储 {token: (expire_time, session_token)} # 临时 token 有效期 60 秒,仅用于 WebSocket 握手 -_ws_temp_tokens: dict[str, tuple[float, str]] = {} +_ws_temp_tokens: Dict[str, Tuple[float, str]] = {} _WS_TOKEN_EXPIRE_SECONDS = 60 diff --git a/src/webui/routes.py b/src/webui/routes.py index c3e4c1fc..c1a7e446 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -2,24 +2,26 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Response from pydantic import BaseModel, Field + from src.common.logger import get_logger from src.webui.core import ( - clear_auth_cookie, check_auth_rate_limit, + clear_auth_cookie, get_rate_limiter, get_token_manager, set_auth_cookie, ) from src.webui.dependencies import require_auth, verify_token_optional from src.webui.routers.config import router as config_router -from src.webui.routers.statistics import router as statistics_router -from src.webui.routers.person import router as person_router +from src.webui.routers.emoji import router as emoji_router from src.webui.routers.expression import router as expression_router from src.webui.routers.jargon import router as jargon_router -from src.webui.routers.emoji import router as emoji_router -from src.webui.routers.plugin import get_progress_router, router as plugin_router -from src.webui.routers.system import router as system_router from src.webui.routers.model import router as model_router +from src.webui.routers.person import router as person_router +from src.webui.routers.plugin import get_progress_router +from src.webui.routers.plugin import router as plugin_router +from src.webui.routers.statistics import router as statistics_router +from src.webui.routers.system import router as system_router from src.webui.routers.websocket.auth import router as ws_auth_router logger = get_logger("webui.api") diff --git a/src/webui/schemas/__init__.py b/src/webui/schemas/__init__.py index c9d12a41..8eb337a1 100644 --- a/src/webui/schemas/__init__.py +++ b/src/webui/schemas/__init__.py @@ -2,62 +2,62 @@ # Auth schemas from .auth import ( - TokenVerifyRequest, - TokenVerifyResponse, + CompleteSetupResponse, + FirstSetupStatusResponse, + ResetSetupResponse, + TokenRegenerateResponse, TokenUpdateRequest, TokenUpdateResponse, - TokenRegenerateResponse, - FirstSetupStatusResponse, - CompleteSetupResponse, - ResetSetupResponse, + TokenVerifyRequest, + TokenVerifyResponse, ) -# Statistics schemas -from .statistics import ( - StatisticsSummary, - ModelStatistics, - TimeSeriesData, - DashboardData, +# Chat schemas +from .chat import ( + ChatHistoryMessage, + VirtualIdentityConfig, ) # Emoji schemas from .emoji import ( - EmojiResponse, - EmojiListResponse, - EmojiDetailResponse, - EmojiUpdateRequest, - EmojiUpdateResponse, - EmojiDeleteResponse, BatchDeleteRequest, BatchDeleteResponse, + EmojiDeleteResponse, + EmojiDetailResponse, + EmojiListResponse, + EmojiResponse, + EmojiUpdateRequest, + EmojiUpdateResponse, EmojiUploadResponse, ThumbnailCacheStatsResponse, ThumbnailCleanupResponse, ThumbnailPreheatResponse, ) -# Chat schemas -from .chat import ( - VirtualIdentityConfig, - ChatHistoryMessage, -) - # Plugin schemas from .plugin import ( - FetchRawFileRequest, - FetchRawFileResponse, + AddMirrorRequest, + AvailableMirrorsResponse, CloneRepositoryRequest, CloneRepositoryResponse, - MirrorConfigResponse, - AvailableMirrorsResponse, - AddMirrorRequest, - UpdateMirrorRequest, + FetchRawFileRequest, + FetchRawFileResponse, GitStatusResponse, InstallPluginRequest, - VersionResponse, + MirrorConfigResponse, UninstallPluginRequest, - UpdatePluginRequest, + UpdateMirrorRequest, UpdatePluginConfigRequest, + UpdatePluginRequest, + VersionResponse, +) + +# Statistics schemas +from .statistics import ( + DashboardData, + ModelStatistics, + StatisticsSummary, + TimeSeriesData, ) __all__ = [ diff --git a/src/webui/schemas/chat.py b/src/webui/schemas/chat.py index 786792be..899dfe22 100644 --- a/src/webui/schemas/chat.py +++ b/src/webui/schemas/chat.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel from typing import Optional +from pydantic import BaseModel + class VirtualIdentityConfig(BaseModel): """虚拟身份配置""" diff --git a/src/webui/schemas/emoji.py b/src/webui/schemas/emoji.py index 70787975..27937f09 100644 --- a/src/webui/schemas/emoji.py +++ b/src/webui/schemas/emoji.py @@ -1,5 +1,6 @@ +from typing import List, Optional + from pydantic import BaseModel -from typing import Optional, List class EmojiResponse(BaseModel): diff --git a/src/webui/schemas/plugin.py b/src/webui/schemas/plugin.py index 1a75a38e..e2fea32d 100644 --- a/src/webui/schemas/plugin.py +++ b/src/webui/schemas/plugin.py @@ -1,5 +1,6 @@ +from typing import Any, Dict, List, Optional + from pydantic import BaseModel, Field -from typing import Optional, List, Dict, Any class FetchRawFileRequest(BaseModel): diff --git a/src/webui/schemas/statistics.py b/src/webui/schemas/statistics.py index 278d7251..36ce36b6 100644 --- a/src/webui/schemas/statistics.py +++ b/src/webui/schemas/statistics.py @@ -1,5 +1,6 @@ +from typing import Any, Dict, List + from pydantic import BaseModel, Field -from typing import Dict, Any, List class StatisticsSummary(BaseModel): diff --git a/src/webui/services/git_mirror_service.py b/src/webui/services/git_mirror_service.py index 5052869d..b40e1c3f 100644 --- a/src/webui/services/git_mirror_service.py +++ b/src/webui/services/git_mirror_service.py @@ -1,14 +1,16 @@ """Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取""" -from typing import Optional, List, Dict, Any -from enum import Enum -import httpx -import json import asyncio -import subprocess +import json import shutil -from pathlib import Path +import subprocess from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional + +import httpx + from src.common.logger import get_logger from src.webui.utils.network_security import validate_public_url @@ -188,10 +190,8 @@ class GitMirrorConfig: def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]: """根据 ID 获取镜像源""" - for mirror in self.mirrors: - if mirror.get("id") == mirror_id: - return mirror.copy() - return None + matched_mirror = next((mirror for mirror in self.mirrors if mirror.get("id") == mirror_id), None) + return matched_mirror.copy() if matched_mirror is not None else None def add_mirror( self, @@ -332,8 +332,8 @@ class GitMirrorService: - path: str - Git 可执行文件路径(如果已安装) - error: str - 错误信息(如果未安装或检测失败) """ - import subprocess import shutil + import subprocess try: # 查找 git 可执行文件路径 @@ -409,8 +409,7 @@ class GitMirrorService: # 确定要使用的镜像源列表 if mirror_id: # 使用指定的镜像源 - mirror = self.config.get_mirror_by_id(mirror_id) - if not mirror: + if (mirror := self.config.get_mirror_by_id(mirror_id)) is None: return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0} mirrors_to_try = [mirror] else: @@ -477,7 +476,13 @@ class GitMirrorService: try: raw_prefix = _validate_mirror_prefix(mirror["raw_prefix"], "镜像 Raw 前缀") except ValueError as e: - return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400} + return { + "success": False, + "error": str(e), + "mirror_used": mirror.get("id"), + "attempts": 0, + "status_code": 400, + } url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}" @@ -566,8 +571,7 @@ class GitMirrorService: # 确定要使用的镜像源列表 if mirror_id: # 使用指定的镜像源 - mirror = self.config.get_mirror_by_id(mirror_id) - if not mirror: + if (mirror := self.config.get_mirror_by_id(mirror_id)) is None: return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0} mirrors_to_try = [mirror] else: @@ -597,7 +601,13 @@ class GitMirrorService: try: clone_prefix = _validate_mirror_prefix(mirror["clone_prefix"], "镜像克隆前缀") except ValueError as e: - return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400} + return { + "success": False, + "error": str(e), + "mirror_used": mirror.get("id"), + "attempts": 0, + "status_code": 400, + } url = f"{clone_prefix}/{owner}/{repo}.git" diff --git a/src/webui/utils/network_security.py b/src/webui/utils/network_security.py index 1cb7e4ea..f8cd580f 100644 --- a/src/webui/utils/network_security.py +++ b/src/webui/utils/network_security.py @@ -1,18 +1,16 @@ -from typing import Iterable - import ipaddress import socket - +from typing import Iterable, Set from urllib.parse import urlparse -def _resolve_ip_addresses(hostname: str, port: int) -> set[ipaddress.IPv4Address | ipaddress.IPv6Address]: +def _resolve_ip_addresses(hostname: str, port: int) -> Set[ipaddress.IPv4Address | ipaddress.IPv6Address]: try: address_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM) except socket.gaierror as exc: raise ValueError(f"无法解析主机名: {hostname}") from exc - resolved_addresses: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set() + resolved_addresses: Set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set() for _, _, _, _, sockaddr in address_infos: host_address = sockaddr[0] if not isinstance(host_address, str): @@ -73,4 +71,4 @@ def validate_public_url(url: str, allowed_schemes: Iterable[str] = ("https",)) - if _is_forbidden_ip_address(address): raise ValueError(f"禁止访问非公网地址: {address}") - return normalized_url \ No newline at end of file + return normalized_url diff --git a/src/webui/utils/toml_utils.py b/src/webui/utils/toml_utils.py index 7543979a..ecabdf0b 100644 --- a/src/webui/utils/toml_utils.py +++ b/src/webui/utils/toml_utils.py @@ -4,8 +4,9 @@ TOML 工具函数 提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。 """ -from typing import Any import re +from typing import Any + import tomlkit from tomlkit.items import AoT, Array, Table diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index fbc30cbf..436f941e 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -1,9 +1,10 @@ """独立的 WebUI 服务器 - 运行在 0.0.0.0:8001""" -from uvicorn import Config, Server as UvicornServer - import asyncio +from uvicorn import Config +from uvicorn import Server as UvicornServer + from src.common.logger import get_logger from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict from src.config.config import config_manager