WebUI 后端类型注解补全,使用全 typing 库类型注解
This commit is contained in:
@@ -8,11 +8,10 @@
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -272,8 +271,7 @@ async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge
|
||||
all_files = []
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if chat_dir.is_dir():
|
||||
for log_file in chat_dir.glob("*.json"):
|
||||
all_files.append((chat_dir.name, log_file))
|
||||
all_files.extend((chat_dir.name, log_file) for log_file in chat_dir.glob("*.json"))
|
||||
|
||||
# 按时间戳排序
|
||||
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
|
||||
|
||||
@@ -8,11 +8,10 @@
|
||||
3. 详情按需加载
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""FastAPI 应用工厂 - 创建和配置 WebUI 应用实例"""
|
||||
|
||||
import mimetypes
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from subprocess import CompletedProcess, TimeoutExpired, run
|
||||
import mimetypes
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -61,12 +62,12 @@ def _get_dashboard_root() -> Path:
|
||||
return _get_project_root() / "dashboard"
|
||||
|
||||
|
||||
def _format_dashboard_shell_commands(*commands: list[str]) -> str:
|
||||
def _format_dashboard_shell_commands(*commands: List[str]) -> str:
|
||||
formatted_commands = " && ".join(" ".join(command) for command in commands)
|
||||
return f"cd dashboard && {formatted_commands}"
|
||||
|
||||
|
||||
def _validate_static_path(static_path: Path | None) -> tuple[str, dict[str, object]] | None:
|
||||
def _validate_static_path(static_path: Path | None) -> Tuple[str, Dict[str, Any]] | None:
|
||||
if static_path is None:
|
||||
return "startup.webui_static_dir_missing", {}
|
||||
|
||||
@@ -81,7 +82,7 @@ def _validate_static_path(static_path: Path | None) -> tuple[str, dict[str, obje
|
||||
|
||||
|
||||
def _summarize_command_output(command_result: CompletedProcess[str] | TimeoutExpired) -> str:
|
||||
output_chunks: list[str] = []
|
||||
output_chunks: List[str] = []
|
||||
stdout = command_result.stdout
|
||||
stderr = command_result.stderr
|
||||
|
||||
@@ -111,14 +112,18 @@ def _get_preferred_dashboard_package_manager(dashboard_root: Path) -> str:
|
||||
return "npm"
|
||||
|
||||
|
||||
def _get_dashboard_build_command(dashboard_root: Path) -> list[str] | None:
|
||||
def _get_dashboard_build_command(dashboard_root: Path) -> List[str] | None:
|
||||
if not (dashboard_root / "package.json").exists():
|
||||
return None
|
||||
|
||||
preferred_package_manager = _get_preferred_dashboard_package_manager(dashboard_root)
|
||||
package_managers = [
|
||||
preferred_package_manager,
|
||||
*[package_manager for package_manager in _DASHBOARD_BUILD_COMMANDS if package_manager != preferred_package_manager],
|
||||
*[
|
||||
package_manager
|
||||
for package_manager in _DASHBOARD_BUILD_COMMANDS
|
||||
if package_manager != preferred_package_manager
|
||||
],
|
||||
]
|
||||
|
||||
for package_manager in package_managers:
|
||||
@@ -128,8 +133,10 @@ def _get_dashboard_build_command(dashboard_root: Path) -> list[str] | None:
|
||||
return None
|
||||
|
||||
|
||||
def _get_dashboard_manual_recovery_command(dashboard_root: Path, build_command: list[str] | None = None) -> str:
|
||||
package_manager = build_command[0] if build_command is not None else _get_preferred_dashboard_package_manager(dashboard_root)
|
||||
def _get_dashboard_manual_recovery_command(dashboard_root: Path, build_command: List[str] | None = None) -> str:
|
||||
package_manager = (
|
||||
build_command[0] if build_command is not None else _get_preferred_dashboard_package_manager(dashboard_root)
|
||||
)
|
||||
install_command = _DASHBOARD_INSTALL_COMMANDS.get(package_manager)
|
||||
selected_build_command = _DASHBOARD_BUILD_COMMANDS.get(package_manager)
|
||||
|
||||
@@ -308,8 +315,8 @@ def _setup_cors(app: FastAPI, port: int):
|
||||
|
||||
def _setup_anti_crawler(app: FastAPI):
|
||||
try:
|
||||
from src.webui.middleware import AntiCrawlerMiddleware
|
||||
from src.config.config import global_config
|
||||
from src.webui.middleware import AntiCrawlerMiddleware
|
||||
|
||||
anti_crawler_mode = global_config.webui.anti_crawler_mode
|
||||
app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Any, get_args, get_origin
|
||||
from typing import Any, Dict, List, get_args, get_origin
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
@@ -8,13 +8,13 @@ from src.config.config_base import ConfigBase
|
||||
|
||||
class ConfigSchemaGenerator:
|
||||
@classmethod
|
||||
def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
|
||||
def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> Dict[str, Any]:
|
||||
return cls.generate_config_schema(config_class, include_nested=include_nested)
|
||||
|
||||
@classmethod
|
||||
def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
|
||||
fields: list[dict[str, Any]] = []
|
||||
nested: dict[str, dict[str, Any]] = {}
|
||||
def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> Dict[str, Any]:
|
||||
fields: List[Dict[str, Any]] = []
|
||||
nested: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
for field_name, field_info in config_class.model_fields.items():
|
||||
if field_name in {"field_docs", "_validate_any", "suppress_any_warning"}:
|
||||
@@ -28,7 +28,7 @@ class ConfigSchemaGenerator:
|
||||
if nested_schema is not None:
|
||||
nested[field_name] = nested_schema
|
||||
|
||||
schema: dict[str, Any] = {
|
||||
schema: Dict[str, Any] = {
|
||||
"className": config_class.__name__,
|
||||
"classDoc": (config_class.__doc__ or "").strip(),
|
||||
"fields": fields,
|
||||
@@ -49,7 +49,7 @@ class ConfigSchemaGenerator:
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def _build_nested_schema(cls, annotation: Any) -> dict[str, Any] | None:
|
||||
def _build_nested_schema(cls, annotation: Any) -> Dict[str, Any] | None:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
@@ -66,10 +66,10 @@ class ConfigSchemaGenerator:
|
||||
@classmethod
|
||||
def _build_field_schema(
|
||||
cls, config_class: type[ConfigBase], field_name: str, annotation: Any, field_info: Any
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
field_docs = config_class.get_class_field_docs()
|
||||
field_type = cls._map_field_type(annotation)
|
||||
schema: dict[str, Any] = {
|
||||
schema: Dict[str, Any] = {
|
||||
"name": field_name,
|
||||
"type": field_type,
|
||||
"label": field_name,
|
||||
@@ -86,8 +86,7 @@ class ConfigSchemaGenerator:
|
||||
if origin is list and args:
|
||||
schema["items"] = {"type": cls._map_field_type(args[0])}
|
||||
|
||||
options = cls._extract_options(annotation)
|
||||
if options:
|
||||
if options := cls._extract_options(annotation):
|
||||
schema["options"] = options
|
||||
|
||||
# Task 1c: Merge json_schema_extra (x-widget, x-icon, step, etc.)
|
||||
@@ -105,7 +104,7 @@ class ConfigSchemaGenerator:
|
||||
return schema
|
||||
|
||||
@staticmethod
|
||||
def _extract_options(annotation: Any) -> list[str] | None:
|
||||
def _extract_options(annotation: Any) -> List[str] | None:
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return None
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
from .security import TokenManager, get_token_manager
|
||||
from .rate_limiter import (
|
||||
RateLimiter,
|
||||
get_rate_limiter,
|
||||
check_auth_rate_limit,
|
||||
check_api_rate_limit,
|
||||
)
|
||||
from .auth import (
|
||||
COOKIE_NAME,
|
||||
COOKIE_MAX_AGE,
|
||||
COOKIE_NAME,
|
||||
clear_auth_cookie,
|
||||
get_current_token,
|
||||
is_token_valid,
|
||||
set_auth_cookie,
|
||||
clear_auth_cookie,
|
||||
verify_auth_token_from_cookie_or_header,
|
||||
)
|
||||
from .rate_limiter import (
|
||||
RateLimiter,
|
||||
check_api_rate_limit,
|
||||
check_auth_rate_limit,
|
||||
get_rate_limiter,
|
||||
)
|
||||
from .security import TokenManager, get_token_manager
|
||||
|
||||
__all__ = [
|
||||
"TokenManager",
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Cookie, HTTPException, Request, Response
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
from .security import get_token_manager
|
||||
|
||||
logger = get_logger("webui.auth")
|
||||
@@ -54,6 +56,7 @@ def get_current_token(
|
||||
if not is_token_valid(maibot_session):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
assert maibot_session is not None
|
||||
return maibot_session
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ WebUI 请求频率限制模块
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Tuple, Optional
|
||||
from fastapi import Request, HTTPException
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.rate_limiter")
|
||||
@@ -21,7 +23,7 @@ class RateLimiter:
|
||||
|
||||
def __init__(self):
|
||||
# 存储格式: {key: [(timestamp, count), ...]}
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
self._requests: Dict[str, List] = defaultdict(list)
|
||||
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||
self._blocked: Dict[str, float] = {}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ WebUI Token 管理模块
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -52,7 +52,7 @@ class TokenManager:
|
||||
logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
|
||||
self._create_new_token()
|
||||
|
||||
def _load_config(self) -> dict:
|
||||
def _load_config(self) -> Dict:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
@@ -61,7 +61,7 @@ class TokenManager:
|
||||
logger.error(f"加载 WebUI 配置失败: {e}")
|
||||
return {}
|
||||
|
||||
def _save_config(self, config: dict):
|
||||
def _save_config(self, config: Dict):
|
||||
"""保存配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
@@ -127,7 +127,7 @@ class TokenManager:
|
||||
|
||||
return is_valid
|
||||
|
||||
def update_token(self, new_token: str) -> tuple[bool, str]:
|
||||
def update_token(self, new_token: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
更新 token
|
||||
|
||||
@@ -208,7 +208,7 @@ class TokenManager:
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
|
||||
def _validate_custom_token(self, token: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
验证自定义 token 格式
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Cookie, Depends, Request
|
||||
|
||||
from .core import check_auth_rate_limit, get_current_token, is_token_valid
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""WebSocket 日志推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Optional
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
@@ -15,7 +17,7 @@ router = APIRouter()
|
||||
active_connections: Set[WebSocket] = set()
|
||||
|
||||
|
||||
def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
def load_recent_logs(limit: int = 100) -> List[Dict]:
|
||||
"""从日志文件中加载最近的日志
|
||||
|
||||
Args:
|
||||
@@ -140,7 +142,7 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
async def broadcast_log(log_data: dict):
|
||||
async def broadcast_log(log_data: Dict):
|
||||
"""广播日志到所有连接的 WebSocket 客户端
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from .anti_crawler import (
|
||||
ALLOWED_IPS,
|
||||
ANTI_CRAWLER_MODE,
|
||||
TRUST_XFF,
|
||||
TRUSTED_PROXIES,
|
||||
AntiCrawlerMiddleware,
|
||||
create_robots_txt_response,
|
||||
ANTI_CRAWLER_MODE,
|
||||
ALLOWED_IPS,
|
||||
TRUSTED_PROXIES,
|
||||
TRUST_XFF,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -3,11 +3,12 @@ WebUI 防爬虫模块
|
||||
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
|
||||
"""
|
||||
|
||||
import time
|
||||
import ipaddress
|
||||
import re
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
@@ -131,7 +132,7 @@ SCANNER_SPECIFIC_HEADERS = {
|
||||
# - CIDR格式:192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
|
||||
# - 通配符:192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
|
||||
# - IPv6:::1, 2001:db8::/32
|
||||
def _parse_allowed_ips(ip_string: str) -> list:
|
||||
def _parse_allowed_ips(ip_string: str) -> List:
|
||||
"""
|
||||
解析IP白名单字符串,支持精确IP、CIDR格式和通配符
|
||||
|
||||
@@ -255,7 +256,7 @@ TRUSTED_PROXIES = _config["trusted_proxies"]
|
||||
TRUST_XFF = _config["trust_xff"]
|
||||
|
||||
|
||||
def _get_mode_config(mode: str) -> dict:
|
||||
def _get_mode_config(mode: str) -> Dict:
|
||||
"""
|
||||
根据模式获取配置参数
|
||||
|
||||
@@ -338,7 +339,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
||||
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
|
||||
|
||||
# 用于存储每个IP的请求时间戳(使用deque提高性能)
|
||||
self.request_times: dict[str, deque] = {}
|
||||
self.request_times: Dict[str, deque] = {}
|
||||
# 上次清理时间
|
||||
self.last_cleanup = time.time()
|
||||
# 将关键词列表转换为集合以提高查找性能
|
||||
@@ -407,7 +408,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return False
|
||||
|
||||
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
|
||||
def _detect_asset_scanner(self, request: Request) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
检测资产测绘工具
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""WebUI 路由聚合模块 - 提供统一的路由注册接口"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
@@ -10,14 +12,14 @@ def get_api_router() -> APIRouter:
|
||||
return main_router
|
||||
|
||||
|
||||
def get_all_routers() -> list[APIRouter]:
|
||||
def get_all_routers() -> List[APIRouter]:
|
||||
"""获取所有需要独立注册的路由器列表"""
|
||||
from src.webui.routes import router as main_router
|
||||
from src.webui.routers.websocket.logs import router as logs_router
|
||||
from src.webui.routers.knowledge import router as knowledge_router
|
||||
from src.webui.routers.chat import router as chat_router
|
||||
from src.webui.api.planner import router as planner_router
|
||||
from src.webui.api.replier import router as replier_router
|
||||
from src.webui.routers.chat import router as chat_router
|
||||
from src.webui.routers.knowledge import router as knowledge_router
|
||||
from src.webui.routers.websocket.logs import router as logs_router
|
||||
from src.webui.routes import router as main_router
|
||||
|
||||
return [
|
||||
main_router,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
from typing import Tuple
|
||||
|
||||
from .routes import router
|
||||
from .support import ChatConnectionManager, WEBUI_CHAT_PLATFORM, chat_manager
|
||||
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
|
||||
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:
|
||||
"""获取 WebUI 聊天广播器,供外部模块使用。"""
|
||||
return chat_manager, WEBUI_CHAT_PLATFORM
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
|
||||
|
||||
import uuid
|
||||
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from sqlalchemy import case, func
|
||||
@@ -36,7 +35,7 @@ async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
) -> Dict[str, object]:
|
||||
"""获取聊天历史记录。"""
|
||||
del user_id
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
@@ -45,7 +44,7 @@ async def get_chat_history(
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms() -> dict[str, object]:
|
||||
async def get_available_platforms() -> Dict[str, object]:
|
||||
"""获取可用平台列表。"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
@@ -68,7 +67,7 @@ async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
) -> dict[str, object]:
|
||||
) -> Dict[str, object]:
|
||||
"""获取指定平台的用户列表。"""
|
||||
try:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
|
||||
@@ -108,7 +107,7 @@ async def get_persons_by_platform(
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
) -> Dict[str, object]:
|
||||
"""清空聊天历史记录。"""
|
||||
deleted = chat_history.clear_history(group_id)
|
||||
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
||||
@@ -164,7 +163,7 @@ async def websocket_chat(
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info() -> dict[str, object]:
|
||||
async def get_chat_info() -> Dict[str, object]:
|
||||
"""获取聊天室信息。"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""WebUI 聊天路由支持逻辑。"""
|
||||
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
@@ -59,7 +58,7 @@ class ChatHistoryManager:
|
||||
def __init__(self, max_messages: int = 200) -> None:
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> dict[str, Any]:
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
user_info = msg.message_info.user_info
|
||||
user_id = user_info.user_id or ""
|
||||
is_bot = is_bot_self(msg.platform, user_id)
|
||||
@@ -78,7 +77,7 @@ class ChatHistoryManager:
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> list[dict[str, Any]]:
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -114,8 +113,8 @@ class ChatConnectionManager:
|
||||
"""聊天连接管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: dict[str, WebSocket] = {}
|
||||
self.user_sessions: dict[str, str] = {}
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
|
||||
await websocket.accept()
|
||||
@@ -130,14 +129,14 @@ class ChatConnectionManager:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
async def send_message(self, session_id: str, message: dict[str, Any]) -> None:
|
||||
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
async def broadcast(self, message: dict[str, Any]) -> None:
|
||||
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
@@ -224,8 +223,8 @@ def build_session_info_message(
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> dict[str, Any]:
|
||||
session_info_data: dict[str, Any] = {
|
||||
) -> Dict[str, Any]:
|
||||
session_info_data: Dict[str, Any] = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
@@ -314,7 +313,7 @@ def resolve_sender_identity(
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> tuple[str, str]:
|
||||
) -> Tuple[str, str]:
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
|
||||
@@ -328,7 +327,7 @@ def create_message_data(
|
||||
message_id: Optional[str] = None,
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
@@ -385,7 +384,7 @@ def create_message_data(
|
||||
|
||||
async def handle_chat_message(
|
||||
session_id: str,
|
||||
data: dict[str, Any],
|
||||
data: Dict[str, Any],
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
@@ -443,7 +442,7 @@ async def handle_chat_ping(session_id: str) -> None:
|
||||
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
|
||||
|
||||
|
||||
async def handle_nickname_update(session_id: str, data: dict[str, Any], current_user_name: str) -> str:
|
||||
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
|
||||
new_name = str(data.get("user_name", "")).strip()
|
||||
if not new_name:
|
||||
return current_user_name
|
||||
@@ -462,7 +461,7 @@ async def handle_nickname_update(session_id: str, data: dict[str, Any], current_
|
||||
async def enable_virtual_identity(
|
||||
session_id: str,
|
||||
session_prefix: str,
|
||||
virtual_data: dict[str, Any],
|
||||
virtual_data: Dict[str, Any],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
|
||||
@@ -558,7 +557,7 @@ async def disable_virtual_identity(session_id: str) -> None:
|
||||
async def handle_virtual_identity_update(
|
||||
session_id: str,
|
||||
session_id_prefix: str,
|
||||
data: dict[str, Any],
|
||||
data: Dict[str, Any],
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
virtual_data = cast(dict[str, Any], data.get("config", {}))
|
||||
@@ -573,11 +572,11 @@ async def handle_virtual_identity_update(
|
||||
async def dispatch_chat_event(
|
||||
session_id: str,
|
||||
session_id_prefix: str,
|
||||
data: dict[str, Any],
|
||||
data: Dict[str, Any],
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> tuple[str, Optional[VirtualIdentityConfig]]:
|
||||
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
|
||||
event_type = data.get("type")
|
||||
if event_type == "message":
|
||||
next_user_name = await handle_chat_message(
|
||||
|
||||
@@ -5,51 +5,51 @@
|
||||
import copy
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Annotated, Optional
|
||||
from typing import Annotated, Any, Dict, List, Tuple
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.utils.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig
|
||||
from src.config.config_base import AttributeData
|
||||
from src.config.model_configs import (
|
||||
APIProvider,
|
||||
ModelInfo,
|
||||
ModelTaskConfig,
|
||||
)
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
RelationshipConfig,
|
||||
ChatConfig,
|
||||
MessageReceiveConfig,
|
||||
ChineseTypoConfig,
|
||||
DebugConfig,
|
||||
EmojiConfig,
|
||||
ExperimentalConfig,
|
||||
ExpressionConfig,
|
||||
KeywordReactionConfig,
|
||||
ChineseTypoConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MaimMessageConfig,
|
||||
MemoryConfig,
|
||||
MessageReceiveConfig,
|
||||
PersonalityConfig,
|
||||
RelationshipConfig,
|
||||
ResponsePostProcessConfig,
|
||||
ResponseSplitterConfig,
|
||||
TelemetryConfig,
|
||||
ExperimentalConfig,
|
||||
MaimMessageConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
ToolConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
VoiceConfig,
|
||||
)
|
||||
from src.config.model_configs import (
|
||||
ModelTaskConfig,
|
||||
ModelInfo,
|
||||
APIProvider,
|
||||
)
|
||||
from src.webui.config_schema import ConfigSchemaGenerator
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.utils.toml_utils import _update_toml_doc, save_toml_with_format
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
ConfigBody = Annotated[dict[str, Any], Body()]
|
||||
ConfigBody = Annotated[Dict[str, Any], Body()]
|
||||
SectionBody = Annotated[Any, Body()]
|
||||
RawContentBody = Annotated[str, Body(embed=True)]
|
||||
PathBody = Annotated[dict[str, str], Body()]
|
||||
PathBody = Annotated[Dict[str, str], Body()]
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
|
||||
|
||||
@@ -61,6 +61,8 @@ def _toml_to_plain_dict(obj: Any) -> Any:
|
||||
if isinstance(obj, list):
|
||||
return [_toml_to_plain_dict(v) for v in obj]
|
||||
return obj
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@@ -385,8 +387,12 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||
if section_name == "api_providers" and "api_provider" in str(e):
|
||||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||
models = config_data.get("models", [])
|
||||
orphaned_models = [
|
||||
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
||||
orphaned_models: List[str] = [
|
||||
str(model_name)
|
||||
for m in models
|
||||
if isinstance(m, dict)
|
||||
and m.get("api_provider") not in provider_names
|
||||
and (model_name := m.get("name")) is not None
|
||||
]
|
||||
if orphaned_models:
|
||||
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
||||
@@ -421,7 +427,7 @@ def _normalize_adapter_path(path: str) -> str:
|
||||
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
||||
|
||||
|
||||
def _get_allowed_adapter_config_roots() -> tuple[Path, ...]:
|
||||
def _get_allowed_adapter_config_roots() -> Tuple[Path, ...]:
|
||||
project_root = Path(PROJECT_ROOT).resolve()
|
||||
return (
|
||||
project_root,
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
"""表情包管理 API 路由"""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie, HTTPException, Query
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
@@ -17,7 +16,8 @@ from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header as verify_auth_token
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header as verify_auth_token
|
||||
|
||||
from .schemas import (
|
||||
BatchDeleteRequest,
|
||||
@@ -219,7 +219,7 @@ async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(Non
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
"""获取表情包统计数据。"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
@@ -247,7 +247,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> dict[
|
||||
registered = session.exec(registered_statement).one()
|
||||
banned = session.exec(banned_statement).one()
|
||||
|
||||
formats: dict[str, int] = {}
|
||||
formats: Dict[str, int] = {}
|
||||
format_statement = select(Images.full_path).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
for full_path in session.exec(format_statement).all():
|
||||
suffix = Path(full_path).suffix.lower().lstrip(".")
|
||||
@@ -443,7 +443,7 @@ async def batch_delete_emojis(
|
||||
|
||||
deleted_count = 0
|
||||
failed_count = 0
|
||||
failed_ids: list[int] = []
|
||||
failed_ids: List[int] = []
|
||||
|
||||
for emoji_id in request.emoji_ids:
|
||||
try:
|
||||
@@ -582,12 +582,12 @@ async def batch_upload_emoji(
|
||||
emotion: EmotionForm = "",
|
||||
is_registered: IsRegisteredForm = True,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
"""批量上传表情包。"""
|
||||
try:
|
||||
verify_auth_token(maibot_session)
|
||||
|
||||
results: dict[str, Any] = {
|
||||
results: Dict[str, Any] = {
|
||||
"success": True,
|
||||
"total": len(files),
|
||||
"uploaded": 0,
|
||||
@@ -614,9 +614,7 @@ async def batch_upload_emoji(
|
||||
file_content = await file.read()
|
||||
if not file_content:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{"filename": file.filename, "success": False, "error": "文件内容为空"}
|
||||
)
|
||||
results["details"].append({"filename": file.filename, "success": False, "error": "文件内容为空"})
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -677,14 +675,10 @@ async def batch_upload_emoji(
|
||||
session.flush()
|
||||
|
||||
results["uploaded"] += 1
|
||||
results["details"].append(
|
||||
{"filename": file.filename, "success": True, "id": emoji.id}
|
||||
)
|
||||
results["details"].append({"filename": file.filename, "success": True, "id": emoji.id})
|
||||
except Exception as e:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{"filename": file.filename, "success": False, "error": str(e)}
|
||||
)
|
||||
results["details"].append({"filename": file.filename, "success": False, "error": str(e)})
|
||||
|
||||
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个"
|
||||
return results
|
||||
@@ -787,7 +781,9 @@ async def preheat_thumbnail_cache(
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash)
|
||||
await loop.run_in_executor(
|
||||
get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash
|
||||
)
|
||||
generated += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}")
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set, Tuple
|
||||
|
||||
from PIL import Image
|
||||
from sqlmodel import col, select
|
||||
@@ -18,10 +18,10 @@ THUMBNAIL_SIZE = (200, 200)
|
||||
THUMBNAIL_QUALITY = 80
|
||||
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
|
||||
|
||||
_thumbnail_locks: dict[str, threading.Lock] = {}
|
||||
_thumbnail_locks: Dict[str, threading.Lock] = {}
|
||||
_locks_lock = threading.Lock()
|
||||
_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail")
|
||||
_generating_thumbnails: set[str] = set()
|
||||
_generating_thumbnails: Set[str] = set()
|
||||
_generating_lock = threading.Lock()
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ def get_generating_lock() -> threading.Lock:
|
||||
return _generating_lock
|
||||
|
||||
|
||||
def get_generating_thumbnails() -> set[str]:
|
||||
def get_generating_thumbnails() -> Set[str]:
|
||||
"""获取正在生成的缩略图哈希集合。"""
|
||||
return _generating_thumbnails
|
||||
|
||||
@@ -112,7 +112,7 @@ def background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
||||
_background_generate_thumbnail(source_path, file_hash)
|
||||
|
||||
|
||||
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
||||
def cleanup_orphaned_thumbnails() -> Tuple[int, int]:
|
||||
"""清理孤立的缩略图缓存。"""
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return 0, 0
|
||||
|
||||
@@ -5,14 +5,13 @@ from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select, delete
|
||||
from sqlmodel import col, delete, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""黑话(俚语)管理路由"""
|
||||
|
||||
import json
|
||||
|
||||
from typing import Annotated, Any, List, Optional
|
||||
from typing import Annotated, Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -90,7 +89,7 @@ class JargonListResponse(BaseModel):
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[dict[str, Any]]
|
||||
data: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class JargonDetailResponse(BaseModel):
|
||||
@@ -153,7 +152,7 @@ class JargonStatsResponse(BaseModel):
|
||||
"""黑话统计响应"""
|
||||
|
||||
success: bool = True
|
||||
data: dict[str, Any]
|
||||
data: Dict[str, Any]
|
||||
|
||||
|
||||
class ChatInfoResponse(BaseModel):
|
||||
@@ -175,7 +174,7 @@ class ChatListResponse(BaseModel):
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
|
||||
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
|
||||
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> Dict[str, int]:
|
||||
"""解析会话计数字典。"""
|
||||
if not session_id_dict_str:
|
||||
return {}
|
||||
@@ -188,7 +187,7 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
|
||||
if not isinstance(parsed, dict):
|
||||
return {}
|
||||
|
||||
session_counts: dict[str, int] = {}
|
||||
session_counts: Dict[str, int] = {}
|
||||
for session_id, count in parsed.items():
|
||||
if not isinstance(session_id, str):
|
||||
continue
|
||||
@@ -202,7 +201,7 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
|
||||
return session_counts
|
||||
|
||||
|
||||
def dump_session_id_dict(session_counts: dict[str, int]) -> str:
|
||||
def dump_session_id_dict(session_counts: Dict[str, int]) -> str:
|
||||
"""序列化会话计数字典。"""
|
||||
return json.dumps(session_counts, ensure_ascii=False)
|
||||
|
||||
@@ -225,7 +224,7 @@ def build_session_id_dict_for_chat(chat_id: str, count: int = 1) -> str:
|
||||
return dump_session_id_dict({chat_id: count})
|
||||
|
||||
|
||||
def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
|
||||
def jargon_to_dict(jargon: Jargon, session: Session) -> Dict[str, Any]:
|
||||
"""将 Jargon ORM 对象转换为字典"""
|
||||
chat_id = get_primary_chat_id(jargon.session_id_dict)
|
||||
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
|
||||
@@ -311,14 +310,16 @@ async def get_chat_list():
|
||||
with get_db_session() as session:
|
||||
jargons = session.exec(select(Jargon)).all()
|
||||
|
||||
seen_stream_ids: set[str] = set()
|
||||
seen_stream_ids: Set[str] = set()
|
||||
for jargon in jargons:
|
||||
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
|
||||
|
||||
result = []
|
||||
with get_db_session() as session:
|
||||
for stream_id in seen_stream_ids:
|
||||
if chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first():
|
||||
if chat_session := session.exec(
|
||||
select(ChatSession).where(col(ChatSession.session_id) == stream_id)
|
||||
).first():
|
||||
chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
@@ -358,7 +359,7 @@ async def get_jargon_stats():
|
||||
pending = sum(jargon.is_jargon is None for jargon in jargons)
|
||||
complete_count = sum(jargon.is_complete for jargon in jargons)
|
||||
|
||||
top_chats_counter: dict[str, int] = {}
|
||||
top_chats_counter: Dict[str, int] = {}
|
||||
for jargon in jargons:
|
||||
for session_id in parse_session_id_dict(jargon.session_id_dict):
|
||||
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""知识库图谱可视化 API 路由"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
import logging
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
@@ -59,7 +60,7 @@ def _get_paragraph_store():
|
||||
return None
|
||||
|
||||
|
||||
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
||||
def _get_paragraph_content(node_id: str) -> Tuple[Optional[str], bool]:
|
||||
"""从 embedding store 获取段落完整内容
|
||||
|
||||
Args:
|
||||
|
||||
@@ -5,9 +5,9 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import httpx
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
@@ -39,7 +39,7 @@ def _normalize_url(url: str) -> str:
|
||||
return url.rstrip("/") if url else ""
|
||||
|
||||
|
||||
def _parse_openai_response(data: dict) -> list[dict]:
|
||||
def _parse_openai_response(data: Dict) -> List[Dict]:
|
||||
"""
|
||||
解析 OpenAI 格式的模型列表响应
|
||||
|
||||
@@ -59,7 +59,7 @@ def _parse_openai_response(data: dict) -> list[dict]:
|
||||
]
|
||||
|
||||
|
||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
def _parse_gemini_response(data: Dict) -> List[Dict]:
|
||||
"""
|
||||
解析 Gemini 格式的模型列表响应
|
||||
|
||||
@@ -89,7 +89,7 @@ async def _fetch_models_from_provider(
|
||||
endpoint: str,
|
||||
parser: str,
|
||||
client_type: str = "openai",
|
||||
) -> list[dict]:
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
从提供商 API 获取模型列表
|
||||
|
||||
@@ -154,7 +154,7 @@ async def _fetch_models_from_provider(
|
||||
raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")
|
||||
|
||||
|
||||
def _get_provider_config(provider_name: str) -> Optional[dict]:
|
||||
def _get_provider_config(provider_name: str) -> Optional[Dict]:
|
||||
"""
|
||||
从 model_config.toml 获取指定提供商的配置
|
||||
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
"""人物信息管理 API 路由"""
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sqlalchemy import case
|
||||
from sqlmodel import col, select, delete
|
||||
from sqlmodel import col, delete, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie, HTTPException
|
||||
|
||||
@@ -28,7 +27,7 @@ logger = get_logger("webui.plugin_routes")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _mirror_to_response(mirror: dict[str, Any]) -> MirrorConfigResponse:
|
||||
def _mirror_to_response(mirror: Dict[str, Any]) -> MirrorConfigResponse:
|
||||
return MirrorConfigResponse(
|
||||
id=mirror["id"],
|
||||
name=mirror["name"],
|
||||
@@ -116,7 +115,7 @@ async def update_mirror(
|
||||
|
||||
|
||||
@router.delete("/mirrors/{mirror_id}")
|
||||
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
|
||||
service = get_git_mirror_service()
|
||||
@@ -172,12 +171,16 @@ async def fetch_raw_file(
|
||||
loaded_plugins=total,
|
||||
)
|
||||
except Exception:
|
||||
await update_progress(stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0)
|
||||
await update_progress(
|
||||
stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0
|
||||
)
|
||||
|
||||
return FetchRawFileResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Raw 文件失败: {e}")
|
||||
await update_progress(stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0)
|
||||
await update_progress(
|
||||
stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import json
|
||||
import tomlkit
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Cookie, HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -24,8 +23,8 @@ logger = get_logger("webui.plugin_routes")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> dict[str, Any]:
|
||||
schema: dict[str, Any] = {
|
||||
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> Dict[str, Any]:
|
||||
schema: Dict[str, Any] = {
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_info": {
|
||||
"name": plugin_id,
|
||||
@@ -41,7 +40,7 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> di
|
||||
for section_name, section_data in current_config.items():
|
||||
if not isinstance(section_data, dict):
|
||||
continue
|
||||
section_fields: dict[str, Any] = {}
|
||||
section_fields: Dict[str, Any] = {}
|
||||
for field_name, field_value in section_data.items():
|
||||
field_type = type(field_value).__name__
|
||||
ui_type = "text"
|
||||
@@ -121,7 +120,7 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> di
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}/schema")
|
||||
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
||||
|
||||
@@ -157,7 +156,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}/raw")
|
||||
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取插件原始配置: {plugin_id}")
|
||||
|
||||
@@ -184,7 +183,7 @@ async def update_plugin_config_raw(
|
||||
plugin_id: str,
|
||||
request: UpdatePluginRawConfigRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"更新插件原始配置: {plugin_id}")
|
||||
|
||||
@@ -216,7 +215,7 @@ async def update_plugin_config_raw(
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}")
|
||||
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取插件配置: {plugin_id}")
|
||||
|
||||
@@ -244,7 +243,7 @@ async def update_plugin_config(
|
||||
plugin_id: str,
|
||||
request: UpdatePluginConfigRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"更新插件配置: {plugin_id}")
|
||||
|
||||
@@ -276,7 +275,7 @@ async def update_plugin_config(
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/reset")
|
||||
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"重置插件配置: {plugin_id}")
|
||||
|
||||
@@ -300,7 +299,7 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/toggle")
|
||||
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"切换插件状态: {plugin_id}")
|
||||
|
||||
@@ -326,7 +325,12 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
|
||||
|
||||
status = "启用" if new_enabled else "禁用"
|
||||
logger.info(f"已{status}插件: {plugin_id}")
|
||||
return {"success": True, "enabled": new_enabled, "message": f"插件已{status}", "note": "状态更改将在下次加载插件时生效"}
|
||||
return {
|
||||
"success": True,
|
||||
"enabled": new_enabled,
|
||||
"message": f"插件已{status}",
|
||||
"note": "状态更改将在下次加载插件时生效",
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie, HTTPException
|
||||
|
||||
@@ -18,8 +17,8 @@ from .support import (
|
||||
parse_repository_url,
|
||||
remove_tree,
|
||||
require_plugin_token,
|
||||
resolve_plugin_file_path,
|
||||
resolve_installed_plugin_path,
|
||||
resolve_plugin_file_path,
|
||||
validate_plugin_id,
|
||||
)
|
||||
|
||||
@@ -28,7 +27,7 @@ logger = get_logger("webui.plugin_routes")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path: Path) -> str:
|
||||
def _infer_plugin_id(folder_name: str, manifest: Dict[str, Any], manifest_path: Path) -> str:
|
||||
if "id" in manifest:
|
||||
return str(manifest["id"])
|
||||
|
||||
@@ -66,43 +65,87 @@ def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path:
|
||||
|
||||
|
||||
@router.post("/install")
|
||||
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
||||
plugin_id = request.plugin_id
|
||||
|
||||
try:
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
await update_progress(stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
|
||||
repo_url, owner, repo = parse_repository_url(request.repository_url)
|
||||
await update_progress(stage="loading", progress=10, message=f"解析仓库信息: {owner}/{repo}", operation="install", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=10,
|
||||
message=f"解析仓库信息: {owner}/{repo}",
|
||||
operation="install",
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
|
||||
target_path, old_format_path = get_plugin_candidate_paths(plugin_id)
|
||||
if target_path.exists() or old_format_path.exists():
|
||||
await update_progress(stage="error", progress=0, message="插件已存在", operation="install", plugin_id=plugin_id, error="插件已安装,请先卸载")
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="插件已存在",
|
||||
operation="install",
|
||||
plugin_id=plugin_id,
|
||||
error="插件已安装,请先卸载",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="插件已安装")
|
||||
|
||||
await update_progress(stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
service = get_git_mirror_service()
|
||||
if "github.com" in repo_url:
|
||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
|
||||
result = await service.clone_repository(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
target_path=target_path,
|
||||
branch=request.branch,
|
||||
mirror_id=request.mirror_id,
|
||||
depth=1,
|
||||
)
|
||||
else:
|
||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1)
|
||||
result = await service.clone_repository(
|
||||
owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
error_msg = str(result.get("error", "克隆失败"))
|
||||
await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg)
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="克隆仓库失败",
|
||||
operation="install",
|
||||
plugin_id=plugin_id,
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
|
||||
|
||||
await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
manifest_path = resolve_plugin_file_path(target_path, "_manifest.json")
|
||||
if not manifest_path.exists():
|
||||
remove_tree(target_path)
|
||||
await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式")
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="插件缺少 _manifest.json",
|
||||
operation="install",
|
||||
plugin_id=plugin_id,
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
||||
await update_progress(stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
|
||||
)
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as file_obj:
|
||||
manifest = json.load(file_obj)
|
||||
@@ -114,91 +157,199 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
||||
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
remove_tree(target_path)
|
||||
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="install", plugin_id=plugin_id, error=str(e))
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="_manifest.json 无效",
|
||||
operation="install",
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
|
||||
await update_progress(stage="success", progress=100, message=f"成功安装插件: {manifest['name']} v{manifest['version']}", operation="install", plugin_id=plugin_id)
|
||||
return {"success": True, "message": "插件安装成功", "plugin_id": plugin_id, "plugin_name": manifest["name"], "version": manifest["version"], "path": str(target_path)}
|
||||
await update_progress(
|
||||
stage="success",
|
||||
progress=100,
|
||||
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||
operation="install",
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件安装成功",
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": manifest["name"],
|
||||
"version": manifest["version"],
|
||||
"path": str(target_path),
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"安装插件失败: {e}", exc_info=True)
|
||||
await update_progress(stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e))
|
||||
await update_progress(
|
||||
stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e)
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/uninstall")
|
||||
async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def uninstall_plugin(
|
||||
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)
|
||||
) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
||||
plugin_id = request.plugin_id
|
||||
|
||||
try:
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
await update_progress(stage="loading", progress=10, message=f"开始卸载插件: {plugin_id}", operation="uninstall", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=10,
|
||||
message=f"开始卸载插件: {plugin_id}",
|
||||
operation="uninstall",
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
plugin_path = resolve_installed_plugin_path(plugin_id)
|
||||
if plugin_path is None:
|
||||
await update_progress(stage="error", progress=0, message="插件不存在", operation="uninstall", plugin_id=plugin_id, error="插件未安装或已被删除")
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="插件不存在",
|
||||
operation="uninstall",
|
||||
plugin_id=plugin_id,
|
||||
error="插件未安装或已被删除",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
||||
await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=30,
|
||||
message=f"正在删除插件文件: {plugin_path}",
|
||||
operation="uninstall",
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
|
||||
plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id
|
||||
await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=50,
|
||||
message=f"正在删除 {plugin_name}...",
|
||||
operation="uninstall",
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
remove_tree(plugin_path)
|
||||
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
|
||||
await update_progress(stage="success", progress=100, message=f"成功卸载插件: {plugin_name}", operation="uninstall", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="success",
|
||||
progress=100,
|
||||
message=f"成功卸载插件: {plugin_name}",
|
||||
operation="uninstall",
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
|
||||
except HTTPException:
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error(f"卸载插件失败(权限错误): {e}")
|
||||
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error="权限不足,无法删除插件文件")
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="卸载失败",
|
||||
operation="uninstall",
|
||||
plugin_id=plugin_id,
|
||||
error="权限不足,无法删除插件文件",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
|
||||
except Exception as e:
|
||||
logger.error(f"卸载插件失败: {e}", exc_info=True)
|
||||
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e))
|
||||
await update_progress(
|
||||
stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e)
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/update")
|
||||
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
||||
plugin_id = request.plugin_id
|
||||
|
||||
try:
|
||||
plugin_id = validate_plugin_id(request.plugin_id)
|
||||
await update_progress(stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
plugin_path = resolve_installed_plugin_path(plugin_id)
|
||||
if plugin_path is None:
|
||||
await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装")
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="插件不存在",
|
||||
operation="update",
|
||||
plugin_id=plugin_id,
|
||||
error="插件未安装,请先安装",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
||||
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
|
||||
old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown"
|
||||
await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id)
|
||||
await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=10,
|
||||
message=f"当前版本: {old_version},准备更新...",
|
||||
operation="update",
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
await update_progress(
|
||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
remove_tree(plugin_path)
|
||||
|
||||
await update_progress(stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
repo_url, owner, repo = parse_repository_url(request.repository_url)
|
||||
service = get_git_mirror_service()
|
||||
if "github.com" in repo_url:
|
||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
|
||||
result = await service.clone_repository(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
target_path=plugin_path,
|
||||
branch=request.branch,
|
||||
mirror_id=request.mirror_id,
|
||||
depth=1,
|
||||
)
|
||||
else:
|
||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1)
|
||||
result = await service.clone_repository(
|
||||
owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
error_msg = str(result.get("error", "克隆失败"))
|
||||
await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg)
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="下载新版本失败",
|
||||
operation="update",
|
||||
plugin_id=plugin_id,
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
|
||||
|
||||
await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id)
|
||||
await update_progress(
|
||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
|
||||
)
|
||||
new_manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
|
||||
if not new_manifest_path.exists():
|
||||
remove_tree(plugin_path)
|
||||
await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式")
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="新版本缺少 _manifest.json",
|
||||
operation="update",
|
||||
plugin_id=plugin_id,
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
||||
try:
|
||||
@@ -207,27 +358,49 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
||||
new_version = str(new_manifest.get("version", "unknown"))
|
||||
new_name = str(new_manifest.get("name", plugin_id))
|
||||
logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}")
|
||||
await update_progress(stage="success", progress=100, message=f"成功更新 {new_name}: {old_version} → {new_version}", operation="update", plugin_id=plugin_id)
|
||||
return {"success": True, "message": "插件更新成功", "plugin_id": plugin_id, "plugin_name": new_name, "old_version": old_version, "new_version": new_version}
|
||||
await update_progress(
|
||||
stage="success",
|
||||
progress=100,
|
||||
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||
operation="update",
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件更新成功",
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_name": new_name,
|
||||
"old_version": old_version,
|
||||
"new_version": new_version,
|
||||
}
|
||||
except Exception as e:
|
||||
remove_tree(plugin_path)
|
||||
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="update", plugin_id=plugin_id, error=str(e))
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="_manifest.json 无效",
|
||||
operation="update",
|
||||
plugin_id=plugin_id,
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||
await update_progress(stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e))
|
||||
await update_progress(
|
||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/installed")
|
||||
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info("收到获取已安装插件列表请求")
|
||||
|
||||
try:
|
||||
installed_plugins: list[dict[str, Any]] = []
|
||||
installed_plugins: List[Dict[str, Any]] = []
|
||||
for plugin_path in iter_plugin_directories():
|
||||
folder_name = plugin_path.name
|
||||
if folder_name.startswith(".") or folder_name.startswith("__"):
|
||||
@@ -253,9 +426,9 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
|
||||
except Exception as e:
|
||||
logger.error(f"读取插件 {folder_name} 信息时出错: {e}")
|
||||
|
||||
seen_ids: dict[str, str] = {}
|
||||
unique_plugins: list[dict[str, Any]] = []
|
||||
duplicates: list[dict[str, Any]] = []
|
||||
seen_ids: Dict[str, str] = {}
|
||||
unique_plugins: List[Dict[str, Any]] = []
|
||||
duplicates: List[Dict[str, Any]] = []
|
||||
for plugin in installed_plugins:
|
||||
plugin_id = str(plugin["id"])
|
||||
plugin_path = str(plugin["path"])
|
||||
@@ -277,7 +450,7 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
|
||||
|
||||
|
||||
@router.get("/local-readme/{plugin_id}")
|
||||
async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
||||
async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取本地插件 README: {plugin_id}")
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from typing import Any, Optional, Set
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
@@ -14,7 +13,7 @@ logger = get_logger("webui.plugin_progress")
|
||||
router = APIRouter()
|
||||
|
||||
active_connections: Set[WebSocket] = set()
|
||||
current_progress: dict[str, Any] = {
|
||||
current_progress: Dict[str, Any] = {
|
||||
"operation": "idle",
|
||||
"stage": "idle",
|
||||
"progress": 0,
|
||||
@@ -26,7 +25,7 @@ current_progress: dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
async def broadcast_progress(progress_data: dict[str, Any]) -> None:
|
||||
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
|
||||
global current_progress
|
||||
current_progress = progress_data.copy()
|
||||
|
||||
@@ -34,7 +33,7 @@ async def broadcast_progress(progress_data: dict[str, Any]) -> None:
|
||||
return
|
||||
|
||||
message = json.dumps(progress_data, ensure_ascii=False)
|
||||
disconnected: set[WebSocket] = set()
|
||||
disconnected: Set[WebSocket] = set()
|
||||
|
||||
for websocket in active_connections:
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -51,8 +51,8 @@ class MirrorConfigResponse(BaseModel):
|
||||
|
||||
|
||||
class AvailableMirrorsResponse(BaseModel):
|
||||
mirrors: list[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
||||
default_priority: list[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
||||
mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
||||
default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
||||
|
||||
|
||||
class AddMirrorRequest(BaseModel):
|
||||
@@ -106,7 +106,7 @@ class UpdatePluginRequest(BaseModel):
|
||||
|
||||
class UpdatePluginConfigRequest(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
config: Optional[dict[str, Any]] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class UpdatePluginRawConfigRequest(BaseModel):
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, cast, get_origin
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast, get_origin
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
@@ -53,10 +52,7 @@ def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict:
|
||||
resolved_plugin_path = plugin_path.resolve()
|
||||
resolved_plugin_path.relative_to(resolved_plugins_dir)
|
||||
|
||||
if not resolved_plugin_path.is_dir():
|
||||
return None
|
||||
|
||||
return resolved_plugin_path
|
||||
return resolved_plugin_path if resolved_plugin_path.is_dir() else None
|
||||
except HTTPException:
|
||||
if strict:
|
||||
raise
|
||||
@@ -64,7 +60,7 @@ def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict:
|
||||
return None
|
||||
except (OSError, RuntimeError, ValueError):
|
||||
if strict:
|
||||
raise HTTPException(status_code=400, detail="插件目录超出允许范围")
|
||||
raise HTTPException(status_code=400, detail="插件目录超出允许范围") from None
|
||||
logger.warning(f"已跳过越界的插件目录: {plugin_path}")
|
||||
return None
|
||||
|
||||
@@ -109,7 +105,7 @@ def validate_plugin_id(plugin_id: str) -> str:
|
||||
return plugin_id
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
def parse_version(version_str: str) -> Tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||
parts = base_version.split(".")
|
||||
if len(parts) < 3:
|
||||
@@ -122,7 +118,7 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
return 0, 0, 0
|
||||
|
||||
|
||||
def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
|
||||
def deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
|
||||
for key, value in src.items():
|
||||
if key in dst and isinstance(dst[key], dict) and isinstance(value, dict):
|
||||
deep_merge(dst[key], value)
|
||||
@@ -130,9 +126,9 @@ def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
|
||||
dst[key] = value
|
||||
|
||||
|
||||
def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {}
|
||||
dotted_items: list[tuple[str, Any]] = []
|
||||
def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result: Dict[str, Any] = {}
|
||||
dotted_items: List[Tuple[str, Any]] = []
|
||||
|
||||
for key, value in obj.items():
|
||||
if "." in key:
|
||||
@@ -167,7 +163,7 @@ def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None:
|
||||
def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None:
|
||||
def is_list_type(tp: Any) -> bool:
|
||||
origin = get_origin(tp)
|
||||
return tp is list or origin is list
|
||||
@@ -200,7 +196,7 @@ def get_plugins_dir() -> Path:
|
||||
return plugins_dir
|
||||
|
||||
|
||||
def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
|
||||
def get_plugin_candidate_paths(plugin_id: str) -> Tuple[Path, Path]:
|
||||
plugins_dir = get_plugins_dir()
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
return validate_safe_path(folder_name, plugins_dir), validate_safe_path(plugin_id, plugins_dir)
|
||||
@@ -217,7 +213,7 @@ def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
|
||||
return None
|
||||
|
||||
|
||||
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
|
||||
def parse_repository_url(repository_url: str) -> Tuple[str, str, str]:
|
||||
repo_url = repository_url.rstrip("/").removesuffix(".git")
|
||||
parts = repo_url.split("/")
|
||||
if len(parts) < 2:
|
||||
@@ -225,7 +221,7 @@ def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
|
||||
return repo_url, parts[-2], parts[-1]
|
||||
|
||||
|
||||
def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
|
||||
def load_manifest_json(manifest_path: Path) -> Optional[Dict[str, Any]]:
|
||||
if not manifest_path.exists():
|
||||
return None
|
||||
|
||||
@@ -246,9 +242,9 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
|
||||
return None
|
||||
|
||||
|
||||
def iter_plugin_directories() -> list[Path]:
|
||||
def iter_plugin_directories() -> List[Path]:
|
||||
plugins_dir = get_plugins_dir()
|
||||
plugin_directories: list[Path] = []
|
||||
plugin_directories: List[Path] = []
|
||||
for path in plugins_dir.iterdir():
|
||||
safe_path = _resolve_safe_plugin_directory(path, plugins_dir, strict=False)
|
||||
if safe_path is not None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""统计数据 API 路由"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -56,10 +56,10 @@ class DashboardData(BaseModel):
|
||||
"""仪表盘数据"""
|
||||
|
||||
summary: StatisticsSummary
|
||||
model_stats: list[ModelStatistics]
|
||||
hourly_data: list[TimeSeriesData]
|
||||
daily_data: list[TimeSeriesData]
|
||||
recent_activity: list[dict[str, Any]]
|
||||
model_stats: List[ModelStatistics]
|
||||
hourly_data: List[TimeSeriesData]
|
||||
daily_data: List[TimeSeriesData]
|
||||
recent_activity: List[Dict[str, Any]]
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardData)
|
||||
@@ -168,7 +168,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
||||
return summary
|
||||
|
||||
|
||||
async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
|
||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
||||
# 使用GROUP BY聚合,避免全量加载
|
||||
statement = (
|
||||
@@ -181,7 +181,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
|
||||
with get_db_session() as session:
|
||||
rows = session.exec(statement).all()
|
||||
|
||||
aggregates: dict[str, dict[str, float | int]] = {}
|
||||
aggregates: Dict[str, Dict[str, float | int]] = {}
|
||||
for record in rows:
|
||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||
if model_name not in aggregates:
|
||||
@@ -200,7 +200,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
|
||||
bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost)
|
||||
bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1
|
||||
|
||||
result: list[ModelStatistics] = []
|
||||
result: List[ModelStatistics] = []
|
||||
for model_name, bucket in sorted(
|
||||
aggregates.items(),
|
||||
key=lambda item: float(item[1]["request_count"]),
|
||||
@@ -221,7 +221,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
|
||||
return result
|
||||
|
||||
|
||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
|
||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
||||
# SQLite的日期时间函数进行小时分组
|
||||
# 使用strftime将timestamp格式化为小时级别
|
||||
@@ -265,7 +265,7 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> li
|
||||
return result
|
||||
|
||||
|
||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
|
||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取日级统计数据(优化:使用数据库聚合)"""
|
||||
# 使用strftime按日期分组
|
||||
day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp))
|
||||
@@ -308,7 +308,7 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> lis
|
||||
return result
|
||||
|
||||
|
||||
async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]:
|
||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取最近活动"""
|
||||
with get_db_session() as session:
|
||||
statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit)
|
||||
|
||||
@@ -10,8 +10,9 @@ from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from src.config.config import MMC_VERSION
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"], dependencies=[Depends(require_auth)])
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .logs import router as logs_router
|
||||
from .auth import router as ws_auth_router
|
||||
from .logs import router as logs_router
|
||||
|
||||
__all__ = [
|
||||
"logs_router",
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""WebSocket 认证模块。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie
|
||||
import secrets
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Cookie
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
|
||||
@@ -13,7 +14,7 @@ router = APIRouter()
|
||||
|
||||
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
||||
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
||||
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
|
||||
_ws_temp_tokens: Dict[str, Tuple[float, str]] = {}
|
||||
_WS_TOKEN_EXPIRE_SECONDS = 60
|
||||
|
||||
|
||||
|
||||
@@ -2,24 +2,26 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import (
|
||||
clear_auth_cookie,
|
||||
check_auth_rate_limit,
|
||||
clear_auth_cookie,
|
||||
get_rate_limiter,
|
||||
get_token_manager,
|
||||
set_auth_cookie,
|
||||
)
|
||||
from src.webui.dependencies import require_auth, verify_token_optional
|
||||
from src.webui.routers.config import router as config_router
|
||||
from src.webui.routers.statistics import router as statistics_router
|
||||
from src.webui.routers.person import router as person_router
|
||||
from src.webui.routers.emoji import router as emoji_router
|
||||
from src.webui.routers.expression import router as expression_router
|
||||
from src.webui.routers.jargon import router as jargon_router
|
||||
from src.webui.routers.emoji import router as emoji_router
|
||||
from src.webui.routers.plugin import get_progress_router, router as plugin_router
|
||||
from src.webui.routers.system import router as system_router
|
||||
from src.webui.routers.model import router as model_router
|
||||
from src.webui.routers.person import router as person_router
|
||||
from src.webui.routers.plugin import get_progress_router
|
||||
from src.webui.routers.plugin import router as plugin_router
|
||||
from src.webui.routers.statistics import router as statistics_router
|
||||
from src.webui.routers.system import router as system_router
|
||||
from src.webui.routers.websocket.auth import router as ws_auth_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
@@ -2,62 +2,62 @@
|
||||
|
||||
# Auth schemas
|
||||
from .auth import (
|
||||
TokenVerifyRequest,
|
||||
TokenVerifyResponse,
|
||||
CompleteSetupResponse,
|
||||
FirstSetupStatusResponse,
|
||||
ResetSetupResponse,
|
||||
TokenRegenerateResponse,
|
||||
TokenUpdateRequest,
|
||||
TokenUpdateResponse,
|
||||
TokenRegenerateResponse,
|
||||
FirstSetupStatusResponse,
|
||||
CompleteSetupResponse,
|
||||
ResetSetupResponse,
|
||||
TokenVerifyRequest,
|
||||
TokenVerifyResponse,
|
||||
)
|
||||
|
||||
# Statistics schemas
|
||||
from .statistics import (
|
||||
StatisticsSummary,
|
||||
ModelStatistics,
|
||||
TimeSeriesData,
|
||||
DashboardData,
|
||||
# Chat schemas
|
||||
from .chat import (
|
||||
ChatHistoryMessage,
|
||||
VirtualIdentityConfig,
|
||||
)
|
||||
|
||||
# Emoji schemas
|
||||
from .emoji import (
|
||||
EmojiResponse,
|
||||
EmojiListResponse,
|
||||
EmojiDetailResponse,
|
||||
EmojiUpdateRequest,
|
||||
EmojiUpdateResponse,
|
||||
EmojiDeleteResponse,
|
||||
BatchDeleteRequest,
|
||||
BatchDeleteResponse,
|
||||
EmojiDeleteResponse,
|
||||
EmojiDetailResponse,
|
||||
EmojiListResponse,
|
||||
EmojiResponse,
|
||||
EmojiUpdateRequest,
|
||||
EmojiUpdateResponse,
|
||||
EmojiUploadResponse,
|
||||
ThumbnailCacheStatsResponse,
|
||||
ThumbnailCleanupResponse,
|
||||
ThumbnailPreheatResponse,
|
||||
)
|
||||
|
||||
# Chat schemas
|
||||
from .chat import (
|
||||
VirtualIdentityConfig,
|
||||
ChatHistoryMessage,
|
||||
)
|
||||
|
||||
# Plugin schemas
|
||||
from .plugin import (
|
||||
FetchRawFileRequest,
|
||||
FetchRawFileResponse,
|
||||
AddMirrorRequest,
|
||||
AvailableMirrorsResponse,
|
||||
CloneRepositoryRequest,
|
||||
CloneRepositoryResponse,
|
||||
MirrorConfigResponse,
|
||||
AvailableMirrorsResponse,
|
||||
AddMirrorRequest,
|
||||
UpdateMirrorRequest,
|
||||
FetchRawFileRequest,
|
||||
FetchRawFileResponse,
|
||||
GitStatusResponse,
|
||||
InstallPluginRequest,
|
||||
VersionResponse,
|
||||
MirrorConfigResponse,
|
||||
UninstallPluginRequest,
|
||||
UpdatePluginRequest,
|
||||
UpdateMirrorRequest,
|
||||
UpdatePluginConfigRequest,
|
||||
UpdatePluginRequest,
|
||||
VersionResponse,
|
||||
)
|
||||
|
||||
# Statistics schemas
|
||||
from .statistics import (
|
||||
DashboardData,
|
||||
ModelStatistics,
|
||||
StatisticsSummary,
|
||||
TimeSeriesData,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
class EmojiResponse(BaseModel):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
|
||||
class FetchRawFileRequest(BaseModel):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class StatisticsSummary(BaseModel):
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
import httpx
|
||||
import json
|
||||
import asyncio
|
||||
import subprocess
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.utils.network_security import validate_public_url
|
||||
|
||||
@@ -188,10 +190,8 @@ class GitMirrorConfig:
|
||||
|
||||
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""根据 ID 获取镜像源"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
return mirror.copy()
|
||||
return None
|
||||
matched_mirror = next((mirror for mirror in self.mirrors if mirror.get("id") == mirror_id), None)
|
||||
return matched_mirror.copy() if matched_mirror is not None else None
|
||||
|
||||
def add_mirror(
|
||||
self,
|
||||
@@ -332,8 +332,8 @@ class GitMirrorService:
|
||||
- path: str - Git 可执行文件路径(如果已安装)
|
||||
- error: str - 错误信息(如果未安装或检测失败)
|
||||
"""
|
||||
import subprocess
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
# 查找 git 可执行文件路径
|
||||
@@ -409,8 +409,7 @@ class GitMirrorService:
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
if (mirror := self.config.get_mirror_by_id(mirror_id)) is None:
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
@@ -477,7 +476,13 @@ class GitMirrorService:
|
||||
try:
|
||||
raw_prefix = _validate_mirror_prefix(mirror["raw_prefix"], "镜像 Raw 前缀")
|
||||
except ValueError as e:
|
||||
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"mirror_used": mirror.get("id"),
|
||||
"attempts": 0,
|
||||
"status_code": 400,
|
||||
}
|
||||
|
||||
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
|
||||
|
||||
@@ -566,8 +571,7 @@ class GitMirrorService:
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
if (mirror := self.config.get_mirror_by_id(mirror_id)) is None:
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
@@ -597,7 +601,13 @@ class GitMirrorService:
|
||||
try:
|
||||
clone_prefix = _validate_mirror_prefix(mirror["clone_prefix"], "镜像克隆前缀")
|
||||
except ValueError as e:
|
||||
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"mirror_used": mirror.get("id"),
|
||||
"attempts": 0,
|
||||
"status_code": 400,
|
||||
}
|
||||
|
||||
url = f"{clone_prefix}/{owner}/{repo}.git"
|
||||
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
from typing import Iterable
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
|
||||
from typing import Iterable, Set
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def _resolve_ip_addresses(hostname: str, port: int) -> set[ipaddress.IPv4Address | ipaddress.IPv6Address]:
|
||||
def _resolve_ip_addresses(hostname: str, port: int) -> Set[ipaddress.IPv4Address | ipaddress.IPv6Address]:
|
||||
try:
|
||||
address_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror as exc:
|
||||
raise ValueError(f"无法解析主机名: {hostname}") from exc
|
||||
|
||||
resolved_addresses: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
|
||||
resolved_addresses: Set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
|
||||
for _, _, _, _, sockaddr in address_infos:
|
||||
host_address = sockaddr[0]
|
||||
if not isinstance(host_address, str):
|
||||
|
||||
@@ -4,8 +4,9 @@ TOML 工具函数
|
||||
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import tomlkit
|
||||
from tomlkit.items import AoT, Array, Table
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
|
||||
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
|
||||
import asyncio
|
||||
|
||||
from uvicorn import Config
|
||||
from uvicorn import Server as UvicornServer
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
|
||||
from src.config.config import config_manager
|
||||
|
||||
Reference in New Issue
Block a user