WebUI 后端类型注解补全,使用全 typing 库类型注解

This commit is contained in:
DrSmoothl
2026-03-16 13:09:12 +08:00
parent df088205dd
commit e7ac064a80
47 changed files with 572 additions and 365 deletions

View File

@@ -8,11 +8,10 @@
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import Dict, List, Optional
import json
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
@@ -272,8 +271,7 @@ async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge
all_files = []
for chat_dir in PLAN_LOG_DIR.iterdir():
if chat_dir.is_dir():
for log_file in chat_dir.glob("*.json"):
all_files.append((chat_dir.name, log_file))
all_files.extend((chat_dir.name, log_file) for log_file in chat_dir.glob("*.json"))
# 按时间戳排序
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)

View File

@@ -8,11 +8,10 @@
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import Dict, List, Optional
import json
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel

View File

@@ -1,11 +1,12 @@
"""FastAPI 应用工厂 - 创建和配置 WebUI 应用实例"""
import mimetypes
import shutil
from dataclasses import dataclass
from importlib import import_module
from pathlib import Path
from subprocess import CompletedProcess, TimeoutExpired, run
import mimetypes
import shutil
from typing import Any, Dict, List, Tuple
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
@@ -61,12 +62,12 @@ def _get_dashboard_root() -> Path:
return _get_project_root() / "dashboard"
def _format_dashboard_shell_commands(*commands: list[str]) -> str:
def _format_dashboard_shell_commands(*commands: List[str]) -> str:
formatted_commands = " && ".join(" ".join(command) for command in commands)
return f"cd dashboard && {formatted_commands}"
def _validate_static_path(static_path: Path | None) -> tuple[str, dict[str, object]] | None:
def _validate_static_path(static_path: Path | None) -> Tuple[str, Dict[str, Any]] | None:
if static_path is None:
return "startup.webui_static_dir_missing", {}
@@ -81,7 +82,7 @@ def _validate_static_path(static_path: Path | None) -> tuple[str, dict[str, obje
def _summarize_command_output(command_result: CompletedProcess[str] | TimeoutExpired) -> str:
output_chunks: list[str] = []
output_chunks: List[str] = []
stdout = command_result.stdout
stderr = command_result.stderr
@@ -111,14 +112,18 @@ def _get_preferred_dashboard_package_manager(dashboard_root: Path) -> str:
return "npm"
def _get_dashboard_build_command(dashboard_root: Path) -> list[str] | None:
def _get_dashboard_build_command(dashboard_root: Path) -> List[str] | None:
if not (dashboard_root / "package.json").exists():
return None
preferred_package_manager = _get_preferred_dashboard_package_manager(dashboard_root)
package_managers = [
preferred_package_manager,
*[package_manager for package_manager in _DASHBOARD_BUILD_COMMANDS if package_manager != preferred_package_manager],
*[
package_manager
for package_manager in _DASHBOARD_BUILD_COMMANDS
if package_manager != preferred_package_manager
],
]
for package_manager in package_managers:
@@ -128,8 +133,10 @@ def _get_dashboard_build_command(dashboard_root: Path) -> list[str] | None:
return None
def _get_dashboard_manual_recovery_command(dashboard_root: Path, build_command: list[str] | None = None) -> str:
package_manager = build_command[0] if build_command is not None else _get_preferred_dashboard_package_manager(dashboard_root)
def _get_dashboard_manual_recovery_command(dashboard_root: Path, build_command: List[str] | None = None) -> str:
package_manager = (
build_command[0] if build_command is not None else _get_preferred_dashboard_package_manager(dashboard_root)
)
install_command = _DASHBOARD_INSTALL_COMMANDS.get(package_manager)
selected_build_command = _DASHBOARD_BUILD_COMMANDS.get(package_manager)
@@ -308,8 +315,8 @@ def _setup_cors(app: FastAPI, port: int):
def _setup_anti_crawler(app: FastAPI):
try:
from src.webui.middleware import AntiCrawlerMiddleware
from src.config.config import global_config
from src.webui.middleware import AntiCrawlerMiddleware
anti_crawler_mode = global_config.webui.anti_crawler_mode
app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)

View File

@@ -1,5 +1,5 @@
import inspect
from typing import Any, get_args, get_origin
from typing import Any, Dict, List, get_args, get_origin
from pydantic_core import PydanticUndefined
@@ -8,13 +8,13 @@ from src.config.config_base import ConfigBase
class ConfigSchemaGenerator:
@classmethod
def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> Dict[str, Any]:
return cls.generate_config_schema(config_class, include_nested=include_nested)
@classmethod
def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
fields: list[dict[str, Any]] = []
nested: dict[str, dict[str, Any]] = {}
def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> Dict[str, Any]:
fields: List[Dict[str, Any]] = []
nested: Dict[str, Dict[str, Any]] = {}
for field_name, field_info in config_class.model_fields.items():
if field_name in {"field_docs", "_validate_any", "suppress_any_warning"}:
@@ -28,7 +28,7 @@ class ConfigSchemaGenerator:
if nested_schema is not None:
nested[field_name] = nested_schema
schema: dict[str, Any] = {
schema: Dict[str, Any] = {
"className": config_class.__name__,
"classDoc": (config_class.__doc__ or "").strip(),
"fields": fields,
@@ -49,7 +49,7 @@ class ConfigSchemaGenerator:
return schema
@classmethod
def _build_nested_schema(cls, annotation: Any) -> dict[str, Any] | None:
def _build_nested_schema(cls, annotation: Any) -> Dict[str, Any] | None:
origin = get_origin(annotation)
args = get_args(annotation)
@@ -66,10 +66,10 @@ class ConfigSchemaGenerator:
@classmethod
def _build_field_schema(
cls, config_class: type[ConfigBase], field_name: str, annotation: Any, field_info: Any
) -> dict[str, Any]:
) -> Dict[str, Any]:
field_docs = config_class.get_class_field_docs()
field_type = cls._map_field_type(annotation)
schema: dict[str, Any] = {
schema: Dict[str, Any] = {
"name": field_name,
"type": field_type,
"label": field_name,
@@ -86,8 +86,7 @@ class ConfigSchemaGenerator:
if origin is list and args:
schema["items"] = {"type": cls._map_field_type(args[0])}
options = cls._extract_options(annotation)
if options:
if options := cls._extract_options(annotation):
schema["options"] = options
# Task 1c: Merge json_schema_extra (x-widget, x-icon, step, etc.)
@@ -105,7 +104,7 @@ class ConfigSchemaGenerator:
return schema
@staticmethod
def _extract_options(annotation: Any) -> list[str] | None:
def _extract_options(annotation: Any) -> List[str] | None:
origin = get_origin(annotation)
if origin is None:
return None

View File

@@ -1,19 +1,19 @@
from .security import TokenManager, get_token_manager
from .rate_limiter import (
RateLimiter,
get_rate_limiter,
check_auth_rate_limit,
check_api_rate_limit,
)
from .auth import (
COOKIE_NAME,
COOKIE_MAX_AGE,
COOKIE_NAME,
clear_auth_cookie,
get_current_token,
is_token_valid,
set_auth_cookie,
clear_auth_cookie,
verify_auth_token_from_cookie_or_header,
)
from .rate_limiter import (
RateLimiter,
check_api_rate_limit,
check_auth_rate_limit,
get_rate_limiter,
)
from .security import TokenManager, get_token_manager
__all__ = [
"TokenManager",

View File

@@ -3,8 +3,10 @@
from typing import Optional
from fastapi import Cookie, HTTPException, Request, Response
from src.common.logger import get_logger
from src.config.config import global_config
from .security import get_token_manager
logger = get_logger("webui.auth")
@@ -54,6 +56,7 @@ def get_current_token(
if not is_token_valid(maibot_session):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
assert maibot_session is not None
return maibot_session

View File

@@ -5,8 +5,10 @@ WebUI 请求频率限制模块
import time
from collections import defaultdict
from typing import Dict, Tuple, Optional
from fastapi import Request, HTTPException
from typing import Dict, List, Optional, Tuple
from fastapi import HTTPException, Request
from src.common.logger import get_logger
logger = get_logger("webui.rate_limiter")
@@ -21,7 +23,7 @@ class RateLimiter:
def __init__(self):
# 存储格式: {key: [(timestamp, count), ...]}
self._requests: Dict[str, list] = defaultdict(list)
self._requests: Dict[str, List] = defaultdict(list)
# 被封禁的 IP: {ip: unblock_timestamp}
self._blocked: Dict[str, float] = {}

View File

@@ -6,7 +6,7 @@ WebUI Token 管理模块
import json
import secrets
from pathlib import Path
from typing import Optional
from typing import Dict, Optional, Tuple
from src.common.logger import get_logger
@@ -52,7 +52,7 @@ class TokenManager:
logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
self._create_new_token()
def _load_config(self) -> dict:
def _load_config(self) -> Dict:
"""加载配置文件"""
try:
with open(self.config_path, "r", encoding="utf-8") as f:
@@ -61,7 +61,7 @@ class TokenManager:
logger.error(f"加载 WebUI 配置失败: {e}")
return {}
def _save_config(self, config: dict):
def _save_config(self, config: Dict):
"""保存配置文件"""
try:
with open(self.config_path, "w", encoding="utf-8") as f:
@@ -127,7 +127,7 @@ class TokenManager:
return is_valid
def update_token(self, new_token: str) -> tuple[bool, str]:
def update_token(self, new_token: str) -> Tuple[bool, str]:
"""
更新 token
@@ -208,7 +208,7 @@ class TokenManager:
except ValueError:
return False
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
def _validate_custom_token(self, token: str) -> Tuple[bool, str]:
"""
验证自定义 token 格式

View File

@@ -1,6 +1,7 @@
from typing import Optional
from fastapi import Cookie, Depends, Request
from .core import check_auth_rate_limit, get_current_token, is_token_valid

View File

@@ -1,9 +1,11 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Optional
import json
from pathlib import Path
from typing import Dict, List, Optional, Set
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
@@ -15,7 +17,7 @@ router = APIRouter()
active_connections: Set[WebSocket] = set()
def load_recent_logs(limit: int = 100) -> list[dict]:
def load_recent_logs(limit: int = 100) -> List[Dict]:
"""从日志文件中加载最近的日志
Args:
@@ -140,7 +142,7 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
active_connections.discard(websocket)
async def broadcast_log(log_data: dict):
async def broadcast_log(log_data: Dict):
"""广播日志到所有连接的 WebSocket 客户端
Args:

View File

@@ -1,10 +1,10 @@
from .anti_crawler import (
ALLOWED_IPS,
ANTI_CRAWLER_MODE,
TRUST_XFF,
TRUSTED_PROXIES,
AntiCrawlerMiddleware,
create_robots_txt_response,
ANTI_CRAWLER_MODE,
ALLOWED_IPS,
TRUSTED_PROXIES,
TRUST_XFF,
)
__all__ = [

View File

@@ -3,11 +3,12 @@ WebUI 防爬虫模块
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
"""
import time
import ipaddress
import re
import time
from collections import deque
from typing import Optional
from typing import Dict, List, Optional, Tuple
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
@@ -131,7 +132,7 @@ SCANNER_SPECIFIC_HEADERS = {
# - CIDR格式192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
# - 通配符192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
# - IPv6::1, 2001:db8::/32
def _parse_allowed_ips(ip_string: str) -> list:
def _parse_allowed_ips(ip_string: str) -> List:
"""
解析IP白名单字符串支持精确IP、CIDR格式和通配符
@@ -255,7 +256,7 @@ TRUSTED_PROXIES = _config["trusted_proxies"]
TRUST_XFF = _config["trust_xff"]
def _get_mode_config(mode: str) -> dict:
def _get_mode_config(mode: str) -> Dict:
"""
根据模式获取配置参数
@@ -338,7 +339,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
# 用于存储每个IP的请求时间戳使用deque提高性能
self.request_times: dict[str, deque] = {}
self.request_times: Dict[str, deque] = {}
# 上次清理时间
self.last_cleanup = time.time()
# 将关键词列表转换为集合以提高查找性能
@@ -407,7 +408,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
return False
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
def _detect_asset_scanner(self, request: Request) -> Tuple[bool, Optional[str]]:
"""
检测资产测绘工具

View File

@@ -1,5 +1,7 @@
"""WebUI 路由聚合模块 - 提供统一的路由注册接口"""
from typing import List
from fastapi import APIRouter
@@ -10,14 +12,14 @@ def get_api_router() -> APIRouter:
return main_router
def get_all_routers() -> list[APIRouter]:
def get_all_routers() -> List[APIRouter]:
"""获取所有需要独立注册的路由器列表"""
from src.webui.routes import router as main_router
from src.webui.routers.websocket.logs import router as logs_router
from src.webui.routers.knowledge import router as knowledge_router
from src.webui.routers.chat import router as chat_router
from src.webui.api.planner import router as planner_router
from src.webui.api.replier import router as replier_router
from src.webui.routers.chat import router as chat_router
from src.webui.routers.knowledge import router as knowledge_router
from src.webui.routers.websocket.logs import router as logs_router
from src.webui.routes import router as main_router
return [
main_router,

View File

@@ -1,10 +1,10 @@
from fastapi import APIRouter
from typing import Tuple
from .routes import router
from .support import ChatConnectionManager, WEBUI_CHAT_PLATFORM, chat_manager
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:
"""获取 WebUI 聊天广播器,供外部模块使用。"""
return chat_manager, WEBUI_CHAT_PLATFORM
@@ -15,4 +15,4 @@ __all__ = [
"chat_manager",
"get_webui_chat_broadcaster",
"router",
]
]

View File

@@ -1,8 +1,7 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
import uuid
from typing import Optional
from typing import Dict, Optional
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from sqlalchemy import case, func
@@ -36,7 +35,7 @@ async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
) -> dict[str, object]:
) -> Dict[str, object]:
"""获取聊天历史记录。"""
del user_id
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
@@ -45,7 +44,7 @@ async def get_chat_history(
@router.get("/platforms")
async def get_available_platforms() -> dict[str, object]:
async def get_available_platforms() -> Dict[str, object]:
"""获取可用平台列表。"""
try:
with get_db_session() as session:
@@ -68,7 +67,7 @@ async def get_persons_by_platform(
platform: str = Query(..., description="平台名称"),
search: Optional[str] = Query(default=None, description="搜索关键词"),
limit: int = Query(default=50, ge=1, le=200),
) -> dict[str, object]:
) -> Dict[str, object]:
"""获取指定平台的用户列表。"""
try:
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
@@ -108,7 +107,7 @@ async def get_persons_by_platform(
@router.delete("/history")
async def clear_chat_history(
group_id: Optional[str] = Query(default=None),
) -> dict[str, object]:
) -> Dict[str, object]:
"""清空聊天历史记录。"""
deleted = chat_history.clear_history(group_id)
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
@@ -164,11 +163,11 @@ async def websocket_chat(
@router.get("/info")
async def get_chat_info() -> dict[str, object]:
async def get_chat_info() -> Dict[str, object]:
"""获取聊天室信息。"""
return {
"bot_name": global_config.bot.nickname,
"platform": WEBUI_CHAT_PLATFORM,
"group_id": WEBUI_CHAT_GROUP_ID,
"active_sessions": len(chat_manager.active_connections),
}
}

View File

@@ -1,9 +1,8 @@
"""WebUI 聊天路由支持逻辑。"""
from typing import Any, Optional, cast
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, cast
from fastapi import WebSocket
from pydantic import BaseModel
@@ -59,7 +58,7 @@ class ChatHistoryManager:
def __init__(self, max_messages: int = 200) -> None:
self.max_messages = max_messages
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> dict[str, Any]:
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
user_info = msg.message_info.user_info
user_id = user_info.user_id or ""
is_bot = is_bot_self(msg.platform, user_id)
@@ -78,7 +77,7 @@ class ChatHistoryManager:
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> list[dict[str, Any]]:
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -114,8 +113,8 @@ class ChatConnectionManager:
"""聊天连接管理器。"""
def __init__(self) -> None:
self.active_connections: dict[str, WebSocket] = {}
self.user_sessions: dict[str, str] = {}
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {}
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
await websocket.accept()
@@ -130,14 +129,14 @@ class ChatConnectionManager:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
async def send_message(self, session_id: str, message: dict[str, Any]) -> None:
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
async def broadcast(self, message: dict[str, Any]) -> None:
async def broadcast(self, message: Dict[str, Any]) -> None:
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
@@ -224,8 +223,8 @@ def build_session_info_message(
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> dict[str, Any]:
session_info_data: dict[str, Any] = {
) -> Dict[str, Any]:
session_info_data: Dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
"user_id": user_id,
@@ -314,7 +313,7 @@ def resolve_sender_identity(
current_user_name: str,
normalized_user_id: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> tuple[str, str]:
) -> Tuple[str, str]:
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
@@ -328,7 +327,7 @@ def create_message_data(
message_id: Optional[str] = None,
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> dict[str, Any]:
) -> Dict[str, Any]:
if message_id is None:
message_id = str(uuid.uuid4())
@@ -385,7 +384,7 @@ def create_message_data(
async def handle_chat_message(
session_id: str,
data: dict[str, Any],
data: Dict[str, Any],
current_user_name: str,
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
@@ -443,7 +442,7 @@ async def handle_chat_ping(session_id: str) -> None:
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
async def handle_nickname_update(session_id: str, data: dict[str, Any], current_user_name: str) -> str:
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
new_name = str(data.get("user_name", "")).strip()
if not new_name:
return current_user_name
@@ -462,7 +461,7 @@ async def handle_nickname_update(session_id: str, data: dict[str, Any], current_
async def enable_virtual_identity(
session_id: str,
session_prefix: str,
virtual_data: dict[str, Any],
virtual_data: Dict[str, Any],
) -> Optional[VirtualIdentityConfig]:
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
@@ -558,7 +557,7 @@ async def disable_virtual_identity(session_id: str) -> None:
async def handle_virtual_identity_update(
session_id: str,
session_id_prefix: str,
data: dict[str, Any],
data: Dict[str, Any],
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Optional[VirtualIdentityConfig]:
virtual_data = cast(dict[str, Any], data.get("config", {}))
@@ -573,11 +572,11 @@ async def handle_virtual_identity_update(
async def dispatch_chat_event(
session_id: str,
session_id_prefix: str,
data: dict[str, Any],
data: Dict[str, Any],
current_user_name: str,
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> tuple[str, Optional[VirtualIdentityConfig]]:
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
event_type = data.get("type")
if event_type == "message":
next_user_name = await handle_chat_message(

View File

@@ -5,51 +5,51 @@
import copy
import os
from pathlib import Path
from typing import Any, Annotated, Optional
from typing import Annotated, Any, Dict, List, Tuple
import tomlkit
from fastapi import APIRouter, Body, Depends, HTTPException
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
from src.webui.utils.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig
from src.config.config_base import AttributeData
from src.config.model_configs import (
APIProvider,
ModelInfo,
ModelTaskConfig,
)
from src.config.official_configs import (
BotConfig,
PersonalityConfig,
RelationshipConfig,
ChatConfig,
MessageReceiveConfig,
ChineseTypoConfig,
DebugConfig,
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
KeywordReactionConfig,
ChineseTypoConfig,
LPMMKnowledgeConfig,
MaimMessageConfig,
MemoryConfig,
MessageReceiveConfig,
PersonalityConfig,
RelationshipConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
ToolConfig,
MemoryConfig,
DebugConfig,
VoiceConfig,
)
from src.config.model_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
from src.webui.config_schema import ConfigSchemaGenerator
from src.webui.dependencies import require_auth
from src.webui.utils.toml_utils import _update_toml_doc, save_toml_with_format
logger = get_logger("webui")
# 模块级别的类型别名(解决 B008 ruff 错误)
ConfigBody = Annotated[dict[str, Any], Body()]
ConfigBody = Annotated[Dict[str, Any], Body()]
SectionBody = Annotated[Any, Body()]
RawContentBody = Annotated[str, Body(embed=True)]
PathBody = Annotated[dict[str, str], Body()]
PathBody = Annotated[Dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
@@ -61,6 +61,8 @@ def _toml_to_plain_dict(obj: Any) -> Any:
if isinstance(obj, list):
return [_toml_to_plain_dict(v) for v in obj]
return obj
# ===== 架构获取接口 =====
@@ -385,8 +387,12 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
if section_name == "api_providers" and "api_provider" in str(e):
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", [])
orphaned_models = [
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
orphaned_models: List[str] = [
str(model_name)
for m in models
if isinstance(m, dict)
and m.get("api_provider") not in provider_names
and (model_name := m.get("name")) is not None
]
if orphaned_models:
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
@@ -421,7 +427,7 @@ def _normalize_adapter_path(path: str) -> str:
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
def _get_allowed_adapter_config_roots() -> tuple[Path, ...]:
def _get_allowed_adapter_config_roots() -> Tuple[Path, ...]:
project_root = Path(PROJECT_ROOT).resolve()
return (
project_root,

View File

@@ -1,3 +1,3 @@
from .routes import router
__all__ = ["router"]
__all__ = ["router"]

View File

@@ -1,13 +1,12 @@
"""表情包管理 API 路由"""
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
import asyncio
import hashlib
import io
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Cookie, HTTPException, Query
from fastapi.responses import FileResponse, JSONResponse
@@ -17,7 +16,8 @@ from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header as verify_auth_token
from src.webui.core import get_token_manager
from src.webui.core import verify_auth_token_from_cookie_or_header as verify_auth_token
from .schemas import (
BatchDeleteRequest,
@@ -219,7 +219,7 @@ async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(Non
@router.get("/stats/summary")
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
"""获取表情包统计数据。"""
try:
verify_auth_token(maibot_session)
@@ -247,7 +247,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> dict[
registered = session.exec(registered_statement).one()
banned = session.exec(banned_statement).one()
formats: dict[str, int] = {}
formats: Dict[str, int] = {}
format_statement = select(Images.full_path).where(col(Images.image_type) == ImageType.EMOJI)
for full_path in session.exec(format_statement).all():
suffix = Path(full_path).suffix.lower().lstrip(".")
@@ -443,7 +443,7 @@ async def batch_delete_emojis(
deleted_count = 0
failed_count = 0
failed_ids: list[int] = []
failed_ids: List[int] = []
for emoji_id in request.emoji_ids:
try:
@@ -582,12 +582,12 @@ async def batch_upload_emoji(
emotion: EmotionForm = "",
is_registered: IsRegisteredForm = True,
maibot_session: Optional[str] = Cookie(None),
) -> dict[str, Any]:
) -> Dict[str, Any]:
"""批量上传表情包。"""
try:
verify_auth_token(maibot_session)
results: dict[str, Any] = {
results: Dict[str, Any] = {
"success": True,
"total": len(files),
"uploaded": 0,
@@ -614,9 +614,7 @@ async def batch_upload_emoji(
file_content = await file.read()
if not file_content:
results["failed"] += 1
results["details"].append(
{"filename": file.filename, "success": False, "error": "文件内容为空"}
)
results["details"].append({"filename": file.filename, "success": False, "error": "文件内容为空"})
continue
try:
@@ -677,14 +675,10 @@ async def batch_upload_emoji(
session.flush()
results["uploaded"] += 1
results["details"].append(
{"filename": file.filename, "success": True, "id": emoji.id}
)
results["details"].append({"filename": file.filename, "success": True, "id": emoji.id})
except Exception as e:
results["failed"] += 1
results["details"].append(
{"filename": file.filename, "success": False, "error": str(e)}
)
results["details"].append({"filename": file.filename, "success": False, "error": str(e)})
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']}"
return results
@@ -787,7 +781,9 @@ async def preheat_thumbnail_cache(
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash)
await loop.run_in_executor(
get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash
)
generated += 1
except Exception as e:
logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}")
@@ -840,4 +836,4 @@ async def clear_all_thumbnail_cache(maibot_session: Optional[str] = Cookie(None)
raise
except Exception as e:
logger.exception(f"清空缩略图缓存失败: {e}")
raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e
raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e

View File

@@ -137,4 +137,4 @@ def emoji_to_response(image: Images) -> EmojiResponse:
record_time=image.record_time.timestamp() if image.record_time else 0.0,
register_time=image.register_time.timestamp() if image.register_time else None,
last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,
)
)

View File

@@ -1,8 +1,8 @@
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, Set, Tuple
from PIL import Image
from sqlmodel import col, select
@@ -18,10 +18,10 @@ THUMBNAIL_SIZE = (200, 200)
THUMBNAIL_QUALITY = 80
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
_thumbnail_locks: dict[str, threading.Lock] = {}
_thumbnail_locks: Dict[str, threading.Lock] = {}
_locks_lock = threading.Lock()
_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail")
_generating_thumbnails: set[str] = set()
_generating_thumbnails: Set[str] = set()
_generating_lock = threading.Lock()
@@ -35,7 +35,7 @@ def get_generating_lock() -> threading.Lock:
return _generating_lock
def get_generating_thumbnails() -> set[str]:
def get_generating_thumbnails() -> Set[str]:
"""获取正在生成的缩略图哈希集合。"""
return _generating_thumbnails
@@ -112,7 +112,7 @@ def background_generate_thumbnail(source_path: str, file_hash: str) -> None:
_background_generate_thumbnail(source_path, file_hash)
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
def cleanup_orphaned_thumbnails() -> Tuple[int, int]:
"""清理孤立的缩略图缓存。"""
if not THUMBNAIL_CACHE_DIR.exists():
return 0, 0
@@ -139,4 +139,4 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
if cleaned > 0:
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept}")
return cleaned, kept
return cleaned, kept

View File

@@ -5,14 +5,13 @@ from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import case, func
from sqlmodel import col, select, delete
from sqlmodel import col, delete, select
from src.common.logger import get_logger
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
logger = get_logger("webui.expression")

View File

@@ -1,8 +1,7 @@
"""黑话(俚语)管理路由"""
import json
from typing import Annotated, Any, List, Optional
from typing import Annotated, Any, Dict, List, Optional, Set
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
@@ -90,7 +89,7 @@ class JargonListResponse(BaseModel):
total: int
page: int
page_size: int
data: List[dict[str, Any]]
data: List[Dict[str, Any]]
class JargonDetailResponse(BaseModel):
@@ -153,7 +152,7 @@ class JargonStatsResponse(BaseModel):
"""黑话统计响应"""
success: bool = True
data: dict[str, Any]
data: Dict[str, Any]
class ChatInfoResponse(BaseModel):
@@ -175,7 +174,7 @@ class ChatListResponse(BaseModel):
# ==================== 工具函数 ====================
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> Dict[str, int]:
"""解析会话计数字典。"""
if not session_id_dict_str:
return {}
@@ -188,7 +187,7 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
if not isinstance(parsed, dict):
return {}
session_counts: dict[str, int] = {}
session_counts: Dict[str, int] = {}
for session_id, count in parsed.items():
if not isinstance(session_id, str):
continue
@@ -202,7 +201,7 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
return session_counts
def dump_session_id_dict(session_counts: dict[str, int]) -> str:
def dump_session_id_dict(session_counts: Dict[str, int]) -> str:
"""序列化会话计数字典。"""
return json.dumps(session_counts, ensure_ascii=False)
@@ -225,7 +224,7 @@ def build_session_id_dict_for_chat(chat_id: str, count: int = 1) -> str:
return dump_session_id_dict({chat_id: count})
def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
def jargon_to_dict(jargon: Jargon, session: Session) -> Dict[str, Any]:
"""将 Jargon ORM 对象转换为字典"""
chat_id = get_primary_chat_id(jargon.session_id_dict)
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
@@ -311,14 +310,16 @@ async def get_chat_list():
with get_db_session() as session:
jargons = session.exec(select(Jargon)).all()
seen_stream_ids: set[str] = set()
seen_stream_ids: Set[str] = set()
for jargon in jargons:
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
result = []
with get_db_session() as session:
for stream_id in seen_stream_ids:
if chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first():
if chat_session := session.exec(
select(ChatSession).where(col(ChatSession.session_id) == stream_id)
).first():
chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
result.append(
ChatInfoResponse(
@@ -358,7 +359,7 @@ async def get_jargon_stats():
pending = sum(jargon.is_jargon is None for jargon in jargons)
complete_count = sum(jargon.is_complete for jargon in jargons)
top_chats_counter: dict[str, int] = {}
top_chats_counter: Dict[str, int] = {}
for jargon in jargons:
for session_id in parse_session_id_dict(jargon.session_id_dict):
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1

View File

@@ -1,10 +1,11 @@
"""知识库图谱可视化 API 路由"""
from typing import Any, List, Optional
import logging
from typing import Any, List, Optional, Tuple
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
import logging
from src.config.config import global_config
from src.webui.dependencies import require_auth
@@ -59,7 +60,7 @@ def _get_paragraph_store():
return None
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
def _get_paragraph_content(node_id: str) -> Tuple[Optional[str], bool]:
"""从 embedding store 获取段落完整内容
Args:

View File

@@ -5,9 +5,9 @@
"""
import os
import httpx
from typing import Optional
from typing import Dict, List, Optional
import httpx
import tomlkit
from fastapi import APIRouter, Depends, HTTPException, Query
@@ -39,7 +39,7 @@ def _normalize_url(url: str) -> str:
return url.rstrip("/") if url else ""
def _parse_openai_response(data: dict) -> list[dict]:
def _parse_openai_response(data: Dict) -> List[Dict]:
"""
解析 OpenAI 格式的模型列表响应
@@ -59,7 +59,7 @@ def _parse_openai_response(data: dict) -> list[dict]:
]
def _parse_gemini_response(data: dict) -> list[dict]:
def _parse_gemini_response(data: Dict) -> List[Dict]:
"""
解析 Gemini 格式的模型列表响应
@@ -89,7 +89,7 @@ async def _fetch_models_from_provider(
endpoint: str,
parser: str,
client_type: str = "openai",
) -> list[dict]:
) -> List[Dict]:
"""
从提供商 API 获取模型列表
@@ -154,7 +154,7 @@ async def _fetch_models_from_provider(
raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")
def _get_provider_config(provider_name: str) -> Optional[dict]:
def _get_provider_config(provider_name: str) -> Optional[Dict]:
"""
从 model_config.toml 获取指定提供商的配置

View File

@@ -1,14 +1,13 @@
"""人物信息管理 API 路由"""
from datetime import datetime
import json
from datetime import datetime
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import case
from sqlmodel import col, select, delete
from sqlmodel import col, delete, select
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo

View File

@@ -14,4 +14,4 @@ router.include_router(config_router)
set_update_progress_callback(update_progress)
__all__ = ["get_progress_router", "router"]
__all__ = ["get_progress_router", "router"]

View File

@@ -1,6 +1,5 @@
from typing import Any, Optional
import json
from typing import Any, Dict, Optional
from fastapi import APIRouter, Cookie, HTTPException
@@ -28,7 +27,7 @@ logger = get_logger("webui.plugin_routes")
router = APIRouter()
def _mirror_to_response(mirror: dict[str, Any]) -> MirrorConfigResponse:
def _mirror_to_response(mirror: Dict[str, Any]) -> MirrorConfigResponse:
return MirrorConfigResponse(
id=mirror["id"],
name=mirror["name"],
@@ -116,7 +115,7 @@ async def update_mirror(
@router.delete("/mirrors/{mirror_id}")
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
service = get_git_mirror_service()
@@ -172,12 +171,16 @@ async def fetch_raw_file(
loaded_plugins=total,
)
except Exception:
await update_progress(stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0)
await update_progress(
stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0
)
return FetchRawFileResponse(**result)
except Exception as e:
logger.error(f"获取 Raw 文件失败: {e}")
await update_progress(stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0)
await update_progress(
stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@@ -204,4 +207,4 @@ async def clone_repository(
return CloneRepositoryResponse(**result)
except Exception as e:
logger.error(f"克隆仓库失败: {e}")
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e

View File

@@ -1,8 +1,7 @@
from typing import Any, Optional, cast
import json
import tomlkit
from typing import Any, Dict, Optional, cast
import tomlkit
from fastapi import APIRouter, Cookie, HTTPException
from src.common.logger import get_logger
@@ -24,8 +23,8 @@ logger = get_logger("webui.plugin_routes")
router = APIRouter()
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> dict[str, Any]:
schema: dict[str, Any] = {
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> Dict[str, Any]:
schema: Dict[str, Any] = {
"plugin_id": plugin_id,
"plugin_info": {
"name": plugin_id,
@@ -41,7 +40,7 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> di
for section_name, section_data in current_config.items():
if not isinstance(section_data, dict):
continue
section_fields: dict[str, Any] = {}
section_fields: Dict[str, Any] = {}
for field_name, field_value in section_data.items():
field_type = type(field_value).__name__
ui_type = "text"
@@ -121,7 +120,7 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> di
@router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"获取插件配置 Schema: {plugin_id}")
@@ -157,7 +156,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
@router.get("/config/{plugin_id}/raw")
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"获取插件原始配置: {plugin_id}")
@@ -184,7 +183,7 @@ async def update_plugin_config_raw(
plugin_id: str,
request: UpdatePluginRawConfigRequest,
maibot_session: Optional[str] = Cookie(None),
) -> dict[str, Any]:
) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"更新插件原始配置: {plugin_id}")
@@ -216,7 +215,7 @@ async def update_plugin_config_raw(
@router.get("/config/{plugin_id}")
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"获取插件配置: {plugin_id}")
@@ -244,7 +243,7 @@ async def update_plugin_config(
plugin_id: str,
request: UpdatePluginConfigRequest,
maibot_session: Optional[str] = Cookie(None),
) -> dict[str, Any]:
) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"更新插件配置: {plugin_id}")
@@ -276,7 +275,7 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset")
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"重置插件配置: {plugin_id}")
@@ -300,7 +299,7 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
@router.post("/config/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"切换插件状态: {plugin_id}")
@@ -326,9 +325,14 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
status = "启用" if new_enabled else "禁用"
logger.info(f"{status}插件: {plugin_id}")
return {"success": True, "enabled": new_enabled, "message": f"插件已{status}", "note": "状态更改将在下次加载插件时生效"}
return {
"success": True,
"enabled": new_enabled,
"message": f"插件已{status}",
"note": "状态更改将在下次加载插件时生效",
}
except HTTPException:
raise
except Exception as e:
logger.error(f"切换插件状态失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e

View File

@@ -1,7 +1,6 @@
from pathlib import Path
from typing import Any, Optional
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Cookie, HTTPException
@@ -18,8 +17,8 @@ from .support import (
parse_repository_url,
remove_tree,
require_plugin_token,
resolve_plugin_file_path,
resolve_installed_plugin_path,
resolve_plugin_file_path,
validate_plugin_id,
)
@@ -28,7 +27,7 @@ logger = get_logger("webui.plugin_routes")
router = APIRouter()
def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path: Path) -> str:
def _infer_plugin_id(folder_name: str, manifest: Dict[str, Any], manifest_path: Path) -> str:
if "id" in manifest:
return str(manifest["id"])
@@ -66,43 +65,87 @@ def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path:
@router.post("/install")
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"收到安装插件请求: {request.plugin_id}")
plugin_id = request.plugin_id
try:
plugin_id = validate_plugin_id(request.plugin_id)
await update_progress(stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id)
await update_progress(
stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id
)
repo_url, owner, repo = parse_repository_url(request.repository_url)
await update_progress(stage="loading", progress=10, message=f"解析仓库信息: {owner}/{repo}", operation="install", plugin_id=plugin_id)
await update_progress(
stage="loading",
progress=10,
message=f"解析仓库信息: {owner}/{repo}",
operation="install",
plugin_id=plugin_id,
)
target_path, old_format_path = get_plugin_candidate_paths(plugin_id)
if target_path.exists() or old_format_path.exists():
await update_progress(stage="error", progress=0, message="插件已存在", operation="install", plugin_id=plugin_id, error="插件已安装,请先卸载")
await update_progress(
stage="error",
progress=0,
message="插件已存在",
operation="install",
plugin_id=plugin_id,
error="插件已安装,请先卸载",
)
raise HTTPException(status_code=400, detail="插件已安装")
await update_progress(stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id)
await update_progress(
stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id
)
service = get_git_mirror_service()
if "github.com" in repo_url:
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
result = await service.clone_repository(
owner=owner,
repo=repo,
target_path=target_path,
branch=request.branch,
mirror_id=request.mirror_id,
depth=1,
)
else:
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1)
result = await service.clone_repository(
owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1
)
if not result.get("success"):
error_msg = str(result.get("error", "克隆失败"))
await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg)
await update_progress(
stage="error",
progress=0,
message="克隆仓库失败",
operation="install",
plugin_id=plugin_id,
error=error_msg,
)
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id)
await update_progress(
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
)
manifest_path = resolve_plugin_file_path(target_path, "_manifest.json")
if not manifest_path.exists():
remove_tree(target_path)
await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式")
await update_progress(
stage="error",
progress=0,
message="插件缺少 _manifest.json",
operation="install",
plugin_id=plugin_id,
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
await update_progress(stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id)
await update_progress(
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
)
try:
with open(manifest_path, "r", encoding="utf-8") as file_obj:
manifest = json.load(file_obj)
@@ -114,91 +157,199 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
except Exception as e:
remove_tree(target_path)
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="install", plugin_id=plugin_id, error=str(e))
await update_progress(
stage="error",
progress=0,
message="_manifest.json 无效",
operation="install",
plugin_id=plugin_id,
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
await update_progress(stage="success", progress=100, message=f"成功安装插件: {manifest['name']} v{manifest['version']}", operation="install", plugin_id=plugin_id)
return {"success": True, "message": "插件安装成功", "plugin_id": plugin_id, "plugin_name": manifest["name"], "version": manifest["version"], "path": str(target_path)}
await update_progress(
stage="success",
progress=100,
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
operation="install",
plugin_id=plugin_id,
)
return {
"success": True,
"message": "插件安装成功",
"plugin_id": plugin_id,
"plugin_name": manifest["name"],
"version": manifest["version"],
"path": str(target_path),
}
except HTTPException:
raise
except Exception as e:
logger.error(f"安装插件失败: {e}", exc_info=True)
await update_progress(stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e))
await update_progress(
stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e)
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.post("/uninstall")
async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def uninstall_plugin(
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)
) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"收到卸载插件请求: {request.plugin_id}")
plugin_id = request.plugin_id
try:
plugin_id = validate_plugin_id(request.plugin_id)
await update_progress(stage="loading", progress=10, message=f"开始卸载插件: {plugin_id}", operation="uninstall", plugin_id=plugin_id)
await update_progress(
stage="loading",
progress=10,
message=f"开始卸载插件: {plugin_id}",
operation="uninstall",
plugin_id=plugin_id,
)
plugin_path = resolve_installed_plugin_path(plugin_id)
if plugin_path is None:
await update_progress(stage="error", progress=0, message="插件不存在", operation="uninstall", plugin_id=plugin_id, error="插件未安装或已被删除")
await update_progress(
stage="error",
progress=0,
message="插件不存在",
operation="uninstall",
plugin_id=plugin_id,
error="插件未安装或已被删除",
)
raise HTTPException(status_code=404, detail="插件未安装")
await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id)
await update_progress(
stage="loading",
progress=30,
message=f"正在删除插件文件: {plugin_path}",
operation="uninstall",
plugin_id=plugin_id,
)
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id
await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id)
await update_progress(
stage="loading",
progress=50,
message=f"正在删除 {plugin_name}...",
operation="uninstall",
plugin_id=plugin_id,
)
remove_tree(plugin_path)
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
await update_progress(stage="success", progress=100, message=f"成功卸载插件: {plugin_name}", operation="uninstall", plugin_id=plugin_id)
await update_progress(
stage="success",
progress=100,
message=f"成功卸载插件: {plugin_name}",
operation="uninstall",
plugin_id=plugin_id,
)
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
except HTTPException:
raise
except PermissionError as e:
logger.error(f"卸载插件失败(权限错误): {e}")
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error="权限不足,无法删除插件文件")
await update_progress(
stage="error",
progress=0,
message="卸载失败",
operation="uninstall",
plugin_id=plugin_id,
error="权限不足,无法删除插件文件",
)
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
except Exception as e:
logger.error(f"卸载插件失败: {e}", exc_info=True)
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e))
await update_progress(
stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e)
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.post("/update")
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"收到更新插件请求: {request.plugin_id}")
plugin_id = request.plugin_id
try:
plugin_id = validate_plugin_id(request.plugin_id)
await update_progress(stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id)
await update_progress(
stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id
)
plugin_path = resolve_installed_plugin_path(plugin_id)
if plugin_path is None:
await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装")
await update_progress(
stage="error",
progress=0,
message="插件不存在",
operation="update",
plugin_id=plugin_id,
error="插件未安装,请先安装",
)
raise HTTPException(status_code=404, detail="插件未安装")
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown"
await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id)
await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id)
await update_progress(
stage="loading",
progress=10,
message=f"当前版本: {old_version},准备更新...",
operation="update",
plugin_id=plugin_id,
)
await update_progress(
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
)
remove_tree(plugin_path)
await update_progress(stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id)
await update_progress(
stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id
)
repo_url, owner, repo = parse_repository_url(request.repository_url)
service = get_git_mirror_service()
if "github.com" in repo_url:
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
result = await service.clone_repository(
owner=owner,
repo=repo,
target_path=plugin_path,
branch=request.branch,
mirror_id=request.mirror_id,
depth=1,
)
else:
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1)
result = await service.clone_repository(
owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1
)
if not result.get("success"):
error_msg = str(result.get("error", "克隆失败"))
await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg)
await update_progress(
stage="error",
progress=0,
message="下载新版本失败",
operation="update",
plugin_id=plugin_id,
error=error_msg,
)
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id)
await update_progress(
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
)
new_manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
if not new_manifest_path.exists():
remove_tree(plugin_path)
await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式")
await update_progress(
stage="error",
progress=0,
message="新版本缺少 _manifest.json",
operation="update",
plugin_id=plugin_id,
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
try:
@@ -207,27 +358,49 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
new_version = str(new_manifest.get("version", "unknown"))
new_name = str(new_manifest.get("name", plugin_id))
logger.info(f"成功更新插件: {plugin_id} {old_version}{new_version}")
await update_progress(stage="success", progress=100, message=f"成功更新 {new_name}: {old_version}{new_version}", operation="update", plugin_id=plugin_id)
return {"success": True, "message": "插件更新成功", "plugin_id": plugin_id, "plugin_name": new_name, "old_version": old_version, "new_version": new_version}
await update_progress(
stage="success",
progress=100,
message=f"成功更新 {new_name}: {old_version}{new_version}",
operation="update",
plugin_id=plugin_id,
)
return {
"success": True,
"message": "插件更新成功",
"plugin_id": plugin_id,
"plugin_name": new_name,
"old_version": old_version,
"new_version": new_version,
}
except Exception as e:
remove_tree(plugin_path)
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="update", plugin_id=plugin_id, error=str(e))
await update_progress(
stage="error",
progress=0,
message="_manifest.json 无效",
operation="update",
plugin_id=plugin_id,
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
except HTTPException:
raise
except Exception as e:
logger.error(f"更新插件失败: {e}", exc_info=True)
await update_progress(stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e))
await update_progress(
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.get("/installed")
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info("收到获取已安装插件列表请求")
try:
installed_plugins: list[dict[str, Any]] = []
installed_plugins: List[Dict[str, Any]] = []
for plugin_path in iter_plugin_directories():
folder_name = plugin_path.name
if folder_name.startswith(".") or folder_name.startswith("__"):
@@ -253,9 +426,9 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
except Exception as e:
logger.error(f"读取插件 {folder_name} 信息时出错: {e}")
seen_ids: dict[str, str] = {}
unique_plugins: list[dict[str, Any]] = []
duplicates: list[dict[str, Any]] = []
seen_ids: Dict[str, str] = {}
unique_plugins: List[Dict[str, Any]] = []
duplicates: List[Dict[str, Any]] = []
for plugin in installed_plugins:
plugin_id = str(plugin["id"])
plugin_path = str(plugin["path"])
@@ -277,7 +450,7 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
@router.get("/local-readme/{plugin_id}")
async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
require_plugin_token(maibot_session)
logger.info(f"获取本地插件 README: {plugin_id}")
@@ -300,4 +473,4 @@ async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str]
return {"success": False, "error": "本地未找到 README 文件"}
except Exception as e:
logger.error(f"获取本地 README 失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
return {"success": False, "error": str(e)}

View File

@@ -1,7 +1,6 @@
import asyncio
import json
from typing import Any, Optional, Set
from typing import Any, Dict, Optional, Set
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
@@ -14,7 +13,7 @@ logger = get_logger("webui.plugin_progress")
router = APIRouter()
active_connections: Set[WebSocket] = set()
current_progress: dict[str, Any] = {
current_progress: Dict[str, Any] = {
"operation": "idle",
"stage": "idle",
"progress": 0,
@@ -26,7 +25,7 @@ current_progress: dict[str, Any] = {
}
async def broadcast_progress(progress_data: dict[str, Any]) -> None:
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
global current_progress
current_progress = progress_data.copy()
@@ -34,7 +33,7 @@ async def broadcast_progress(progress_data: dict[str, Any]) -> None:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected: set[WebSocket] = set()
disconnected: Set[WebSocket] = set()
for websocket in active_connections:
try:
@@ -119,4 +118,4 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
def get_progress_router() -> APIRouter:
return router
return router

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
@@ -51,8 +51,8 @@ class MirrorConfigResponse(BaseModel):
class AvailableMirrorsResponse(BaseModel):
mirrors: list[MirrorConfigResponse] = Field(..., description="镜像源列表")
default_priority: list[str] = Field(..., description="默认优先级顺序ID 列表)")
mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
default_priority: List[str] = Field(..., description="默认优先级顺序ID 列表)")
class AddMirrorRequest(BaseModel):
@@ -106,8 +106,8 @@ class UpdatePluginRequest(BaseModel):
class UpdatePluginConfigRequest(BaseModel):
enabled: Optional[bool] = None
config: Optional[dict[str, Any]] = None
config: Optional[Dict[str, Any]] = None
class UpdatePluginRawConfigRequest(BaseModel):
config: str = Field(..., description="原始 TOML 配置内容")
config: str = Field(..., description="原始 TOML 配置内容")

View File

@@ -1,12 +1,11 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Optional, cast, get_origin
import json
import os
import re
import shutil
import stat
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, cast, get_origin
from fastapi import HTTPException
@@ -53,10 +52,7 @@ def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict:
resolved_plugin_path = plugin_path.resolve()
resolved_plugin_path.relative_to(resolved_plugins_dir)
if not resolved_plugin_path.is_dir():
return None
return resolved_plugin_path
return resolved_plugin_path if resolved_plugin_path.is_dir() else None
except HTTPException:
if strict:
raise
@@ -64,7 +60,7 @@ def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict:
return None
except (OSError, RuntimeError, ValueError):
if strict:
raise HTTPException(status_code=400, detail="插件目录超出允许范围")
raise HTTPException(status_code=400, detail="插件目录超出允许范围") from None
logger.warning(f"已跳过越界的插件目录: {plugin_path}")
return None
@@ -109,7 +105,7 @@ def validate_plugin_id(plugin_id: str) -> str:
return plugin_id
def parse_version(version_str: str) -> tuple[int, int, int]:
def parse_version(version_str: str) -> Tuple[int, int, int]:
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
parts = base_version.split(".")
if len(parts) < 3:
@@ -122,7 +118,7 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
return 0, 0, 0
def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
def deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
for key, value in src.items():
if key in dst and isinstance(dst[key], dict) and isinstance(value, dict):
deep_merge(dst[key], value)
@@ -130,9 +126,9 @@ def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
dst[key] = value
def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
result: dict[str, Any] = {}
dotted_items: list[tuple[str, Any]] = []
def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
result: Dict[str, Any] = {}
dotted_items: List[Tuple[str, Any]] = []
for key, value in obj.items():
if "." in key:
@@ -167,7 +163,7 @@ def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
return result
def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None:
def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None:
def is_list_type(tp: Any) -> bool:
origin = get_origin(tp)
return tp is list or origin is list
@@ -200,7 +196,7 @@ def get_plugins_dir() -> Path:
return plugins_dir
def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
def get_plugin_candidate_paths(plugin_id: str) -> Tuple[Path, Path]:
plugins_dir = get_plugins_dir()
folder_name = plugin_id.replace(".", "_")
return validate_safe_path(folder_name, plugins_dir), validate_safe_path(plugin_id, plugins_dir)
@@ -217,7 +213,7 @@ def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
return None
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
def parse_repository_url(repository_url: str) -> Tuple[str, str, str]:
repo_url = repository_url.rstrip("/").removesuffix(".git")
parts = repo_url.split("/")
if len(parts) < 2:
@@ -225,7 +221,7 @@ def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
return repo_url, parts[-2], parts[-1]
def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
def load_manifest_json(manifest_path: Path) -> Optional[Dict[str, Any]]:
if not manifest_path.exists():
return None
@@ -246,9 +242,9 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
return None
def iter_plugin_directories() -> list[Path]:
def iter_plugin_directories() -> List[Path]:
plugins_dir = get_plugins_dir()
plugin_directories: list[Path] = []
plugin_directories: List[Path] = []
for path in plugins_dir.iterdir():
safe_path = _resolve_safe_plugin_directory(path, plugins_dir, strict=False)
if safe_path is not None:
@@ -286,4 +282,4 @@ def remove_tree(path: Path) -> None:
os.chmod(target_path, stat.S_IWRITE)
func(target_path)
shutil.rmtree(path, onerror=remove_readonly)
shutil.rmtree(path, onerror=remove_readonly)

View File

@@ -1,7 +1,7 @@
"""统计数据 API 路由"""
from datetime import datetime, timedelta
from typing import Any
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
@@ -56,10 +56,10 @@ class DashboardData(BaseModel):
"""仪表盘数据"""
summary: StatisticsSummary
model_stats: list[ModelStatistics]
hourly_data: list[TimeSeriesData]
daily_data: list[TimeSeriesData]
recent_activity: list[dict[str, Any]]
model_stats: List[ModelStatistics]
hourly_data: List[TimeSeriesData]
daily_data: List[TimeSeriesData]
recent_activity: List[Dict[str, Any]]
@router.get("/dashboard", response_model=DashboardData)
@@ -168,7 +168,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
return summary
async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
# 使用GROUP BY聚合避免全量加载
statement = (
@@ -181,7 +181,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
with get_db_session() as session:
rows = session.exec(statement).all()
aggregates: dict[str, dict[str, float | int]] = {}
aggregates: Dict[str, Dict[str, float | int]] = {}
for record in rows:
model_name = record.model_assign_name or record.model_name or "unknown"
if model_name not in aggregates:
@@ -200,7 +200,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost)
bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1
result: list[ModelStatistics] = []
result: List[ModelStatistics] = []
for model_name, bucket in sorted(
aggregates.items(),
key=lambda item: float(item[1]["request_count"]),
@@ -221,7 +221,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
return result
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取小时级统计数据(优化:使用数据库聚合)"""
# SQLite的日期时间函数进行小时分组
# 使用strftime将timestamp格式化为小时级别
@@ -265,7 +265,7 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> li
return result
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取日级统计数据(优化:使用数据库聚合)"""
# 使用strftime按日期分组
day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp))
@@ -308,7 +308,7 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> lis
return result
async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]:
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
"""获取最近活动"""
with get_db_session() as session:
statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit)

View File

@@ -10,8 +10,9 @@ from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.config.config import MMC_VERSION
from src.webui.dependencies import require_auth
router = APIRouter(prefix="/system", tags=["system"], dependencies=[Depends(require_auth)])

View File

@@ -1,5 +1,5 @@
from .logs import router as logs_router
from .auth import router as ws_auth_router
from .logs import router as logs_router
__all__ = [
"logs_router",

View File

@@ -1,10 +1,11 @@
"""WebSocket 认证模块。"""
from typing import Optional
from fastapi import APIRouter, Cookie
import secrets
import time
from typing import Dict, Optional, Tuple
from fastapi import APIRouter, Cookie
from src.common.logger import get_logger
from src.webui.core import get_token_manager
@@ -13,7 +14,7 @@ router = APIRouter()
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
_ws_temp_tokens: Dict[str, Tuple[float, str]] = {}
_WS_TOKEN_EXPIRE_SECONDS = 60

View File

@@ -2,24 +2,26 @@
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from pydantic import BaseModel, Field
from src.common.logger import get_logger
from src.webui.core import (
clear_auth_cookie,
check_auth_rate_limit,
clear_auth_cookie,
get_rate_limiter,
get_token_manager,
set_auth_cookie,
)
from src.webui.dependencies import require_auth, verify_token_optional
from src.webui.routers.config import router as config_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.person import router as person_router
from src.webui.routers.emoji import router as emoji_router
from src.webui.routers.expression import router as expression_router
from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.emoji import router as emoji_router
from src.webui.routers.plugin import get_progress_router, router as plugin_router
from src.webui.routers.system import router as system_router
from src.webui.routers.model import router as model_router
from src.webui.routers.person import router as person_router
from src.webui.routers.plugin import get_progress_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.system import router as system_router
from src.webui.routers.websocket.auth import router as ws_auth_router
logger = get_logger("webui.api")

View File

@@ -2,62 +2,62 @@
# Auth schemas
from .auth import (
TokenVerifyRequest,
TokenVerifyResponse,
CompleteSetupResponse,
FirstSetupStatusResponse,
ResetSetupResponse,
TokenRegenerateResponse,
TokenUpdateRequest,
TokenUpdateResponse,
TokenRegenerateResponse,
FirstSetupStatusResponse,
CompleteSetupResponse,
ResetSetupResponse,
TokenVerifyRequest,
TokenVerifyResponse,
)
# Statistics schemas
from .statistics import (
StatisticsSummary,
ModelStatistics,
TimeSeriesData,
DashboardData,
# Chat schemas
from .chat import (
ChatHistoryMessage,
VirtualIdentityConfig,
)
# Emoji schemas
from .emoji import (
EmojiResponse,
EmojiListResponse,
EmojiDetailResponse,
EmojiUpdateRequest,
EmojiUpdateResponse,
EmojiDeleteResponse,
BatchDeleteRequest,
BatchDeleteResponse,
EmojiDeleteResponse,
EmojiDetailResponse,
EmojiListResponse,
EmojiResponse,
EmojiUpdateRequest,
EmojiUpdateResponse,
EmojiUploadResponse,
ThumbnailCacheStatsResponse,
ThumbnailCleanupResponse,
ThumbnailPreheatResponse,
)
# Chat schemas
from .chat import (
VirtualIdentityConfig,
ChatHistoryMessage,
)
# Plugin schemas
from .plugin import (
FetchRawFileRequest,
FetchRawFileResponse,
AddMirrorRequest,
AvailableMirrorsResponse,
CloneRepositoryRequest,
CloneRepositoryResponse,
MirrorConfigResponse,
AvailableMirrorsResponse,
AddMirrorRequest,
UpdateMirrorRequest,
FetchRawFileRequest,
FetchRawFileResponse,
GitStatusResponse,
InstallPluginRequest,
VersionResponse,
MirrorConfigResponse,
UninstallPluginRequest,
UpdatePluginRequest,
UpdateMirrorRequest,
UpdatePluginConfigRequest,
UpdatePluginRequest,
VersionResponse,
)
# Statistics schemas
from .statistics import (
DashboardData,
ModelStatistics,
StatisticsSummary,
TimeSeriesData,
)
__all__ = [

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel
from typing import Optional
from pydantic import BaseModel
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置"""

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
from pydantic import BaseModel
from typing import Optional, List
class EmojiResponse(BaseModel):

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
class FetchRawFileRequest(BaseModel):

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List
from pydantic import BaseModel, Field
from typing import Dict, Any, List
class StatisticsSummary(BaseModel):

View File

@@ -1,14 +1,16 @@
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
import json
import asyncio
import subprocess
import json
import shutil
from pathlib import Path
import subprocess
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
import httpx
from src.common.logger import get_logger
from src.webui.utils.network_security import validate_public_url
@@ -188,10 +190,8 @@ class GitMirrorConfig:
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
"""根据 ID 获取镜像源"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
return mirror.copy()
return None
matched_mirror = next((mirror for mirror in self.mirrors if mirror.get("id") == mirror_id), None)
return matched_mirror.copy() if matched_mirror is not None else None
def add_mirror(
self,
@@ -332,8 +332,8 @@ class GitMirrorService:
- path: str - Git 可执行文件路径(如果已安装)
- error: str - 错误信息(如果未安装或检测失败)
"""
import subprocess
import shutil
import subprocess
try:
# 查找 git 可执行文件路径
@@ -409,8 +409,7 @@ class GitMirrorService:
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
if (mirror := self.config.get_mirror_by_id(mirror_id)) is None:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
@@ -477,7 +476,13 @@ class GitMirrorService:
try:
raw_prefix = _validate_mirror_prefix(mirror["raw_prefix"], "镜像 Raw 前缀")
except ValueError as e:
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
return {
"success": False,
"error": str(e),
"mirror_used": mirror.get("id"),
"attempts": 0,
"status_code": 400,
}
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
@@ -566,8 +571,7 @@ class GitMirrorService:
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
if (mirror := self.config.get_mirror_by_id(mirror_id)) is None:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
@@ -597,7 +601,13 @@ class GitMirrorService:
try:
clone_prefix = _validate_mirror_prefix(mirror["clone_prefix"], "镜像克隆前缀")
except ValueError as e:
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
return {
"success": False,
"error": str(e),
"mirror_used": mirror.get("id"),
"attempts": 0,
"status_code": 400,
}
url = f"{clone_prefix}/{owner}/{repo}.git"

View File

@@ -1,18 +1,16 @@
from typing import Iterable
import ipaddress
import socket
from typing import Iterable, Set
from urllib.parse import urlparse
def _resolve_ip_addresses(hostname: str, port: int) -> set[ipaddress.IPv4Address | ipaddress.IPv6Address]:
def _resolve_ip_addresses(hostname: str, port: int) -> Set[ipaddress.IPv4Address | ipaddress.IPv6Address]:
try:
address_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
except socket.gaierror as exc:
raise ValueError(f"无法解析主机名: {hostname}") from exc
resolved_addresses: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
resolved_addresses: Set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
for _, _, _, _, sockaddr in address_infos:
host_address = sockaddr[0]
if not isinstance(host_address, str):
@@ -73,4 +71,4 @@ def validate_public_url(url: str, allowed_schemes: Iterable[str] = ("https",)) -
if _is_forbidden_ip_address(address):
raise ValueError(f"禁止访问非公网地址: {address}")
return normalized_url
return normalized_url

View File

@@ -4,8 +4,9 @@ TOML 工具函数
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
"""
from typing import Any
import re
from typing import Any
import tomlkit
from tomlkit.items import AoT, Array, Table

View File

@@ -1,9 +1,10 @@
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
from uvicorn import Config, Server as UvicornServer
import asyncio
from uvicorn import Config
from uvicorn import Server as UvicornServer
from src.common.logger import get_logger
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
from src.config.config import config_manager