WebUI 后端类型注解补全,使用全 typing 库类型注解
This commit is contained in:
@@ -8,11 +8,10 @@
|
|||||||
3. 详情按需加载
|
3. 详情按需加载
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -272,8 +271,7 @@ async def get_all_logs(page: int = Query(1, ge=1), page_size: int = Query(20, ge
|
|||||||
all_files = []
|
all_files = []
|
||||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||||
if chat_dir.is_dir():
|
if chat_dir.is_dir():
|
||||||
for log_file in chat_dir.glob("*.json"):
|
all_files.extend((chat_dir.name, log_file) for log_file in chat_dir.glob("*.json"))
|
||||||
all_files.append((chat_dir.name, log_file))
|
|
||||||
|
|
||||||
# 按时间戳排序
|
# 按时间戳排序
|
||||||
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
|
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
|
||||||
|
|||||||
@@ -8,11 +8,10 @@
|
|||||||
3. 详情按需加载
|
3. 详情按需加载
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
"""FastAPI 应用工厂 - 创建和配置 WebUI 应用实例"""
|
"""FastAPI 应用工厂 - 创建和配置 WebUI 应用实例"""
|
||||||
|
|
||||||
|
import mimetypes
|
||||||
|
import shutil
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from subprocess import CompletedProcess, TimeoutExpired, run
|
from subprocess import CompletedProcess, TimeoutExpired, run
|
||||||
import mimetypes
|
from typing import Any, Dict, List, Tuple
|
||||||
import shutil
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@@ -61,12 +62,12 @@ def _get_dashboard_root() -> Path:
|
|||||||
return _get_project_root() / "dashboard"
|
return _get_project_root() / "dashboard"
|
||||||
|
|
||||||
|
|
||||||
def _format_dashboard_shell_commands(*commands: list[str]) -> str:
|
def _format_dashboard_shell_commands(*commands: List[str]) -> str:
|
||||||
formatted_commands = " && ".join(" ".join(command) for command in commands)
|
formatted_commands = " && ".join(" ".join(command) for command in commands)
|
||||||
return f"cd dashboard && {formatted_commands}"
|
return f"cd dashboard && {formatted_commands}"
|
||||||
|
|
||||||
|
|
||||||
def _validate_static_path(static_path: Path | None) -> tuple[str, dict[str, object]] | None:
|
def _validate_static_path(static_path: Path | None) -> Tuple[str, Dict[str, Any]] | None:
|
||||||
if static_path is None:
|
if static_path is None:
|
||||||
return "startup.webui_static_dir_missing", {}
|
return "startup.webui_static_dir_missing", {}
|
||||||
|
|
||||||
@@ -81,7 +82,7 @@ def _validate_static_path(static_path: Path | None) -> tuple[str, dict[str, obje
|
|||||||
|
|
||||||
|
|
||||||
def _summarize_command_output(command_result: CompletedProcess[str] | TimeoutExpired) -> str:
|
def _summarize_command_output(command_result: CompletedProcess[str] | TimeoutExpired) -> str:
|
||||||
output_chunks: list[str] = []
|
output_chunks: List[str] = []
|
||||||
stdout = command_result.stdout
|
stdout = command_result.stdout
|
||||||
stderr = command_result.stderr
|
stderr = command_result.stderr
|
||||||
|
|
||||||
@@ -111,14 +112,18 @@ def _get_preferred_dashboard_package_manager(dashboard_root: Path) -> str:
|
|||||||
return "npm"
|
return "npm"
|
||||||
|
|
||||||
|
|
||||||
def _get_dashboard_build_command(dashboard_root: Path) -> list[str] | None:
|
def _get_dashboard_build_command(dashboard_root: Path) -> List[str] | None:
|
||||||
if not (dashboard_root / "package.json").exists():
|
if not (dashboard_root / "package.json").exists():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
preferred_package_manager = _get_preferred_dashboard_package_manager(dashboard_root)
|
preferred_package_manager = _get_preferred_dashboard_package_manager(dashboard_root)
|
||||||
package_managers = [
|
package_managers = [
|
||||||
preferred_package_manager,
|
preferred_package_manager,
|
||||||
*[package_manager for package_manager in _DASHBOARD_BUILD_COMMANDS if package_manager != preferred_package_manager],
|
*[
|
||||||
|
package_manager
|
||||||
|
for package_manager in _DASHBOARD_BUILD_COMMANDS
|
||||||
|
if package_manager != preferred_package_manager
|
||||||
|
],
|
||||||
]
|
]
|
||||||
|
|
||||||
for package_manager in package_managers:
|
for package_manager in package_managers:
|
||||||
@@ -128,8 +133,10 @@ def _get_dashboard_build_command(dashboard_root: Path) -> list[str] | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _get_dashboard_manual_recovery_command(dashboard_root: Path, build_command: list[str] | None = None) -> str:
|
def _get_dashboard_manual_recovery_command(dashboard_root: Path, build_command: List[str] | None = None) -> str:
|
||||||
package_manager = build_command[0] if build_command is not None else _get_preferred_dashboard_package_manager(dashboard_root)
|
package_manager = (
|
||||||
|
build_command[0] if build_command is not None else _get_preferred_dashboard_package_manager(dashboard_root)
|
||||||
|
)
|
||||||
install_command = _DASHBOARD_INSTALL_COMMANDS.get(package_manager)
|
install_command = _DASHBOARD_INSTALL_COMMANDS.get(package_manager)
|
||||||
selected_build_command = _DASHBOARD_BUILD_COMMANDS.get(package_manager)
|
selected_build_command = _DASHBOARD_BUILD_COMMANDS.get(package_manager)
|
||||||
|
|
||||||
@@ -308,8 +315,8 @@ def _setup_cors(app: FastAPI, port: int):
|
|||||||
|
|
||||||
def _setup_anti_crawler(app: FastAPI):
|
def _setup_anti_crawler(app: FastAPI):
|
||||||
try:
|
try:
|
||||||
from src.webui.middleware import AntiCrawlerMiddleware
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.webui.middleware import AntiCrawlerMiddleware
|
||||||
|
|
||||||
anti_crawler_mode = global_config.webui.anti_crawler_mode
|
anti_crawler_mode = global_config.webui.anti_crawler_mode
|
||||||
app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
|
app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from typing import Any, get_args, get_origin
|
from typing import Any, Dict, List, get_args, get_origin
|
||||||
|
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
@@ -8,13 +8,13 @@ from src.config.config_base import ConfigBase
|
|||||||
|
|
||||||
class ConfigSchemaGenerator:
|
class ConfigSchemaGenerator:
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
|
def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> Dict[str, Any]:
|
||||||
return cls.generate_config_schema(config_class, include_nested=include_nested)
|
return cls.generate_config_schema(config_class, include_nested=include_nested)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
|
def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> Dict[str, Any]:
|
||||||
fields: list[dict[str, Any]] = []
|
fields: List[Dict[str, Any]] = []
|
||||||
nested: dict[str, dict[str, Any]] = {}
|
nested: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
for field_name, field_info in config_class.model_fields.items():
|
for field_name, field_info in config_class.model_fields.items():
|
||||||
if field_name in {"field_docs", "_validate_any", "suppress_any_warning"}:
|
if field_name in {"field_docs", "_validate_any", "suppress_any_warning"}:
|
||||||
@@ -28,7 +28,7 @@ class ConfigSchemaGenerator:
|
|||||||
if nested_schema is not None:
|
if nested_schema is not None:
|
||||||
nested[field_name] = nested_schema
|
nested[field_name] = nested_schema
|
||||||
|
|
||||||
schema: dict[str, Any] = {
|
schema: Dict[str, Any] = {
|
||||||
"className": config_class.__name__,
|
"className": config_class.__name__,
|
||||||
"classDoc": (config_class.__doc__ or "").strip(),
|
"classDoc": (config_class.__doc__ or "").strip(),
|
||||||
"fields": fields,
|
"fields": fields,
|
||||||
@@ -49,7 +49,7 @@ class ConfigSchemaGenerator:
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_nested_schema(cls, annotation: Any) -> dict[str, Any] | None:
|
def _build_nested_schema(cls, annotation: Any) -> Dict[str, Any] | None:
|
||||||
origin = get_origin(annotation)
|
origin = get_origin(annotation)
|
||||||
args = get_args(annotation)
|
args = get_args(annotation)
|
||||||
|
|
||||||
@@ -66,10 +66,10 @@ class ConfigSchemaGenerator:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _build_field_schema(
|
def _build_field_schema(
|
||||||
cls, config_class: type[ConfigBase], field_name: str, annotation: Any, field_info: Any
|
cls, config_class: type[ConfigBase], field_name: str, annotation: Any, field_info: Any
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
field_docs = config_class.get_class_field_docs()
|
field_docs = config_class.get_class_field_docs()
|
||||||
field_type = cls._map_field_type(annotation)
|
field_type = cls._map_field_type(annotation)
|
||||||
schema: dict[str, Any] = {
|
schema: Dict[str, Any] = {
|
||||||
"name": field_name,
|
"name": field_name,
|
||||||
"type": field_type,
|
"type": field_type,
|
||||||
"label": field_name,
|
"label": field_name,
|
||||||
@@ -86,8 +86,7 @@ class ConfigSchemaGenerator:
|
|||||||
if origin is list and args:
|
if origin is list and args:
|
||||||
schema["items"] = {"type": cls._map_field_type(args[0])}
|
schema["items"] = {"type": cls._map_field_type(args[0])}
|
||||||
|
|
||||||
options = cls._extract_options(annotation)
|
if options := cls._extract_options(annotation):
|
||||||
if options:
|
|
||||||
schema["options"] = options
|
schema["options"] = options
|
||||||
|
|
||||||
# Task 1c: Merge json_schema_extra (x-widget, x-icon, step, etc.)
|
# Task 1c: Merge json_schema_extra (x-widget, x-icon, step, etc.)
|
||||||
@@ -105,7 +104,7 @@ class ConfigSchemaGenerator:
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_options(annotation: Any) -> list[str] | None:
|
def _extract_options(annotation: Any) -> List[str] | None:
|
||||||
origin = get_origin(annotation)
|
origin = get_origin(annotation)
|
||||||
if origin is None:
|
if origin is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
from .security import TokenManager, get_token_manager
|
|
||||||
from .rate_limiter import (
|
|
||||||
RateLimiter,
|
|
||||||
get_rate_limiter,
|
|
||||||
check_auth_rate_limit,
|
|
||||||
check_api_rate_limit,
|
|
||||||
)
|
|
||||||
from .auth import (
|
from .auth import (
|
||||||
COOKIE_NAME,
|
|
||||||
COOKIE_MAX_AGE,
|
COOKIE_MAX_AGE,
|
||||||
|
COOKIE_NAME,
|
||||||
|
clear_auth_cookie,
|
||||||
get_current_token,
|
get_current_token,
|
||||||
is_token_valid,
|
is_token_valid,
|
||||||
set_auth_cookie,
|
set_auth_cookie,
|
||||||
clear_auth_cookie,
|
|
||||||
verify_auth_token_from_cookie_or_header,
|
verify_auth_token_from_cookie_or_header,
|
||||||
)
|
)
|
||||||
|
from .rate_limiter import (
|
||||||
|
RateLimiter,
|
||||||
|
check_api_rate_limit,
|
||||||
|
check_auth_rate_limit,
|
||||||
|
get_rate_limiter,
|
||||||
|
)
|
||||||
|
from .security import TokenManager, get_token_manager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TokenManager",
|
"TokenManager",
|
||||||
|
|||||||
@@ -3,8 +3,10 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Cookie, HTTPException, Request, Response
|
from fastapi import Cookie, HTTPException, Request, Response
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
from .security import get_token_manager
|
from .security import get_token_manager
|
||||||
|
|
||||||
logger = get_logger("webui.auth")
|
logger = get_logger("webui.auth")
|
||||||
@@ -54,6 +56,7 @@ def get_current_token(
|
|||||||
if not is_token_valid(maibot_session):
|
if not is_token_valid(maibot_session):
|
||||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||||
|
|
||||||
|
assert maibot_session is not None
|
||||||
return maibot_session
|
return maibot_session
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ WebUI 请求频率限制模块
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, Tuple, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
from fastapi import Request, HTTPException
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("webui.rate_limiter")
|
logger = get_logger("webui.rate_limiter")
|
||||||
@@ -21,7 +23,7 @@ class RateLimiter:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 存储格式: {key: [(timestamp, count), ...]}
|
# 存储格式: {key: [(timestamp, count), ...]}
|
||||||
self._requests: Dict[str, list] = defaultdict(list)
|
self._requests: Dict[str, List] = defaultdict(list)
|
||||||
# 被封禁的 IP: {ip: unblock_timestamp}
|
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||||
self._blocked: Dict[str, float] = {}
|
self._blocked: Dict[str, float] = {}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ WebUI Token 管理模块
|
|||||||
import json
|
import json
|
||||||
import secrets
|
import secrets
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ class TokenManager:
|
|||||||
logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
|
logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
|
||||||
self._create_new_token()
|
self._create_new_token()
|
||||||
|
|
||||||
def _load_config(self) -> dict:
|
def _load_config(self) -> Dict:
|
||||||
"""加载配置文件"""
|
"""加载配置文件"""
|
||||||
try:
|
try:
|
||||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||||
@@ -61,7 +61,7 @@ class TokenManager:
|
|||||||
logger.error(f"加载 WebUI 配置失败: {e}")
|
logger.error(f"加载 WebUI 配置失败: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _save_config(self, config: dict):
|
def _save_config(self, config: Dict):
|
||||||
"""保存配置文件"""
|
"""保存配置文件"""
|
||||||
try:
|
try:
|
||||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||||
@@ -127,7 +127,7 @@ class TokenManager:
|
|||||||
|
|
||||||
return is_valid
|
return is_valid
|
||||||
|
|
||||||
def update_token(self, new_token: str) -> tuple[bool, str]:
|
def update_token(self, new_token: str) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
更新 token
|
更新 token
|
||||||
|
|
||||||
@@ -208,7 +208,7 @@ class TokenManager:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
|
def _validate_custom_token(self, token: str) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
验证自定义 token 格式
|
验证自定义 token 格式
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Cookie, Depends, Request
|
from fastapi import Cookie, Depends, Request
|
||||||
|
|
||||||
from .core import check_auth_rate_limit, get_current_token, is_token_valid
|
from .core import check_auth_rate_limit, get_current_token, is_token_valid
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
"""WebSocket 日志推送模块"""
|
"""WebSocket 日志推送模块"""
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
|
||||||
from typing import Set, Optional
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.webui.core import get_token_manager
|
from src.webui.core import get_token_manager
|
||||||
from src.webui.routers.websocket.auth import verify_ws_token
|
from src.webui.routers.websocket.auth import verify_ws_token
|
||||||
@@ -15,7 +17,7 @@ router = APIRouter()
|
|||||||
active_connections: Set[WebSocket] = set()
|
active_connections: Set[WebSocket] = set()
|
||||||
|
|
||||||
|
|
||||||
def load_recent_logs(limit: int = 100) -> list[dict]:
|
def load_recent_logs(limit: int = 100) -> List[Dict]:
|
||||||
"""从日志文件中加载最近的日志
|
"""从日志文件中加载最近的日志
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -140,7 +142,7 @@ async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None
|
|||||||
active_connections.discard(websocket)
|
active_connections.discard(websocket)
|
||||||
|
|
||||||
|
|
||||||
async def broadcast_log(log_data: dict):
|
async def broadcast_log(log_data: Dict):
|
||||||
"""广播日志到所有连接的 WebSocket 客户端
|
"""广播日志到所有连接的 WebSocket 客户端
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from .anti_crawler import (
|
from .anti_crawler import (
|
||||||
|
ALLOWED_IPS,
|
||||||
|
ANTI_CRAWLER_MODE,
|
||||||
|
TRUST_XFF,
|
||||||
|
TRUSTED_PROXIES,
|
||||||
AntiCrawlerMiddleware,
|
AntiCrawlerMiddleware,
|
||||||
create_robots_txt_response,
|
create_robots_txt_response,
|
||||||
ANTI_CRAWLER_MODE,
|
|
||||||
ALLOWED_IPS,
|
|
||||||
TRUSTED_PROXIES,
|
|
||||||
TRUST_XFF,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ WebUI 防爬虫模块
|
|||||||
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
|
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import PlainTextResponse
|
from starlette.responses import PlainTextResponse
|
||||||
@@ -131,7 +132,7 @@ SCANNER_SPECIFIC_HEADERS = {
|
|||||||
# - CIDR格式:192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
|
# - CIDR格式:192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
|
||||||
# - 通配符:192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
|
# - 通配符:192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
|
||||||
# - IPv6:::1, 2001:db8::/32
|
# - IPv6:::1, 2001:db8::/32
|
||||||
def _parse_allowed_ips(ip_string: str) -> list:
|
def _parse_allowed_ips(ip_string: str) -> List:
|
||||||
"""
|
"""
|
||||||
解析IP白名单字符串,支持精确IP、CIDR格式和通配符
|
解析IP白名单字符串,支持精确IP、CIDR格式和通配符
|
||||||
|
|
||||||
@@ -255,7 +256,7 @@ TRUSTED_PROXIES = _config["trusted_proxies"]
|
|||||||
TRUST_XFF = _config["trust_xff"]
|
TRUST_XFF = _config["trust_xff"]
|
||||||
|
|
||||||
|
|
||||||
def _get_mode_config(mode: str) -> dict:
|
def _get_mode_config(mode: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
根据模式获取配置参数
|
根据模式获取配置参数
|
||||||
|
|
||||||
@@ -338,7 +339,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||||||
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
|
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
|
||||||
|
|
||||||
# 用于存储每个IP的请求时间戳(使用deque提高性能)
|
# 用于存储每个IP的请求时间戳(使用deque提高性能)
|
||||||
self.request_times: dict[str, deque] = {}
|
self.request_times: Dict[str, deque] = {}
|
||||||
# 上次清理时间
|
# 上次清理时间
|
||||||
self.last_cleanup = time.time()
|
self.last_cleanup = time.time()
|
||||||
# 将关键词列表转换为集合以提高查找性能
|
# 将关键词列表转换为集合以提高查找性能
|
||||||
@@ -407,7 +408,7 @@ class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
|
def _detect_asset_scanner(self, request: Request) -> Tuple[bool, Optional[str]]:
|
||||||
"""
|
"""
|
||||||
检测资产测绘工具
|
检测资产测绘工具
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""WebUI 路由聚合模块 - 提供统一的路由注册接口"""
|
"""WebUI 路由聚合模块 - 提供统一的路由注册接口"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
|
||||||
@@ -10,14 +12,14 @@ def get_api_router() -> APIRouter:
|
|||||||
return main_router
|
return main_router
|
||||||
|
|
||||||
|
|
||||||
def get_all_routers() -> list[APIRouter]:
|
def get_all_routers() -> List[APIRouter]:
|
||||||
"""获取所有需要独立注册的路由器列表"""
|
"""获取所有需要独立注册的路由器列表"""
|
||||||
from src.webui.routes import router as main_router
|
|
||||||
from src.webui.routers.websocket.logs import router as logs_router
|
|
||||||
from src.webui.routers.knowledge import router as knowledge_router
|
|
||||||
from src.webui.routers.chat import router as chat_router
|
|
||||||
from src.webui.api.planner import router as planner_router
|
from src.webui.api.planner import router as planner_router
|
||||||
from src.webui.api.replier import router as replier_router
|
from src.webui.api.replier import router as replier_router
|
||||||
|
from src.webui.routers.chat import router as chat_router
|
||||||
|
from src.webui.routers.knowledge import router as knowledge_router
|
||||||
|
from src.webui.routers.websocket.logs import router as logs_router
|
||||||
|
from src.webui.routes import router as main_router
|
||||||
|
|
||||||
return [
|
return [
|
||||||
main_router,
|
main_router,
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from fastapi import APIRouter
|
from typing import Tuple
|
||||||
|
|
||||||
from .routes import router
|
from .routes import router
|
||||||
from .support import ChatConnectionManager, WEBUI_CHAT_PLATFORM, chat_manager
|
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||||
|
|
||||||
|
|
||||||
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
|
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:
|
||||||
"""获取 WebUI 聊天广播器,供外部模块使用。"""
|
"""获取 WebUI 聊天广播器,供外部模块使用。"""
|
||||||
return chat_manager, WEBUI_CHAT_PLATFORM
|
return chat_manager, WEBUI_CHAT_PLATFORM
|
||||||
|
|
||||||
@@ -15,4 +15,4 @@ __all__ = [
|
|||||||
"chat_manager",
|
"chat_manager",
|
||||||
"get_webui_chat_broadcaster",
|
"get_webui_chat_broadcaster",
|
||||||
"router",
|
"router",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
|
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Dict, Optional
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
||||||
from sqlalchemy import case, func
|
from sqlalchemy import case, func
|
||||||
@@ -36,7 +35,7 @@ async def get_chat_history(
|
|||||||
limit: int = Query(default=50, ge=1, le=200),
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
user_id: Optional[str] = Query(default=None),
|
user_id: Optional[str] = Query(default=None),
|
||||||
group_id: Optional[str] = Query(default=None),
|
group_id: Optional[str] = Query(default=None),
|
||||||
) -> dict[str, object]:
|
) -> Dict[str, object]:
|
||||||
"""获取聊天历史记录。"""
|
"""获取聊天历史记录。"""
|
||||||
del user_id
|
del user_id
|
||||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||||
@@ -45,7 +44,7 @@ async def get_chat_history(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/platforms")
|
@router.get("/platforms")
|
||||||
async def get_available_platforms() -> dict[str, object]:
|
async def get_available_platforms() -> Dict[str, object]:
|
||||||
"""获取可用平台列表。"""
|
"""获取可用平台列表。"""
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
@@ -68,7 +67,7 @@ async def get_persons_by_platform(
|
|||||||
platform: str = Query(..., description="平台名称"),
|
platform: str = Query(..., description="平台名称"),
|
||||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||||
limit: int = Query(default=50, ge=1, le=200),
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
) -> dict[str, object]:
|
) -> Dict[str, object]:
|
||||||
"""获取指定平台的用户列表。"""
|
"""获取指定平台的用户列表。"""
|
||||||
try:
|
try:
|
||||||
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
|
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
|
||||||
@@ -108,7 +107,7 @@ async def get_persons_by_platform(
|
|||||||
@router.delete("/history")
|
@router.delete("/history")
|
||||||
async def clear_chat_history(
|
async def clear_chat_history(
|
||||||
group_id: Optional[str] = Query(default=None),
|
group_id: Optional[str] = Query(default=None),
|
||||||
) -> dict[str, object]:
|
) -> Dict[str, object]:
|
||||||
"""清空聊天历史记录。"""
|
"""清空聊天历史记录。"""
|
||||||
deleted = chat_history.clear_history(group_id)
|
deleted = chat_history.clear_history(group_id)
|
||||||
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
||||||
@@ -164,11 +163,11 @@ async def websocket_chat(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/info")
|
@router.get("/info")
|
||||||
async def get_chat_info() -> dict[str, object]:
|
async def get_chat_info() -> Dict[str, object]:
|
||||||
"""获取聊天室信息。"""
|
"""获取聊天室信息。"""
|
||||||
return {
|
return {
|
||||||
"bot_name": global_config.bot.nickname,
|
"bot_name": global_config.bot.nickname,
|
||||||
"platform": WEBUI_CHAT_PLATFORM,
|
"platform": WEBUI_CHAT_PLATFORM,
|
||||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||||
"active_sessions": len(chat_manager.active_connections),
|
"active_sessions": len(chat_manager.active_connections),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
"""WebUI 聊天路由支持逻辑。"""
|
"""WebUI 聊天路由支持逻辑。"""
|
||||||
|
|
||||||
from typing import Any, Optional, cast
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -59,7 +58,7 @@ class ChatHistoryManager:
|
|||||||
def __init__(self, max_messages: int = 200) -> None:
|
def __init__(self, max_messages: int = 200) -> None:
|
||||||
self.max_messages = max_messages
|
self.max_messages = max_messages
|
||||||
|
|
||||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> dict[str, Any]:
|
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
user_info = msg.message_info.user_info
|
user_info = msg.message_info.user_info
|
||||||
user_id = user_info.user_id or ""
|
user_id = user_info.user_id or ""
|
||||||
is_bot = is_bot_self(msg.platform, user_id)
|
is_bot = is_bot_self(msg.platform, user_id)
|
||||||
@@ -78,7 +77,7 @@ class ChatHistoryManager:
|
|||||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
||||||
|
|
||||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> list[dict[str, Any]]:
|
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||||
session_id = self._resolve_session_id(target_group_id)
|
session_id = self._resolve_session_id(target_group_id)
|
||||||
try:
|
try:
|
||||||
@@ -114,8 +113,8 @@ class ChatConnectionManager:
|
|||||||
"""聊天连接管理器。"""
|
"""聊天连接管理器。"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.active_connections: dict[str, WebSocket] = {}
|
self.active_connections: Dict[str, WebSocket] = {}
|
||||||
self.user_sessions: dict[str, str] = {}
|
self.user_sessions: Dict[str, str] = {}
|
||||||
|
|
||||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
|
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
@@ -130,14 +129,14 @@ class ChatConnectionManager:
|
|||||||
del self.user_sessions[user_id]
|
del self.user_sessions[user_id]
|
||||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||||
|
|
||||||
async def send_message(self, session_id: str, message: dict[str, Any]) -> None:
|
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||||
if session_id in self.active_connections:
|
if session_id in self.active_connections:
|
||||||
try:
|
try:
|
||||||
await self.active_connections[session_id].send_json(message)
|
await self.active_connections[session_id].send_json(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送消息失败: {e}")
|
logger.error(f"发送消息失败: {e}")
|
||||||
|
|
||||||
async def broadcast(self, message: dict[str, Any]) -> None:
|
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||||
for session_id in list(self.active_connections.keys()):
|
for session_id in list(self.active_connections.keys()):
|
||||||
await self.send_message(session_id, message)
|
await self.send_message(session_id, message)
|
||||||
|
|
||||||
@@ -224,8 +223,8 @@ def build_session_info_message(
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
user_name: str,
|
user_name: str,
|
||||||
virtual_config: Optional[VirtualIdentityConfig],
|
virtual_config: Optional[VirtualIdentityConfig],
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
session_info_data: dict[str, Any] = {
|
session_info_data: Dict[str, Any] = {
|
||||||
"type": "session_info",
|
"type": "session_info",
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@@ -314,7 +313,7 @@ def resolve_sender_identity(
|
|||||||
current_user_name: str,
|
current_user_name: str,
|
||||||
normalized_user_id: str,
|
normalized_user_id: str,
|
||||||
virtual_config: Optional[VirtualIdentityConfig],
|
virtual_config: Optional[VirtualIdentityConfig],
|
||||||
) -> tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
if is_virtual_mode_enabled(virtual_config):
|
if is_virtual_mode_enabled(virtual_config):
|
||||||
assert virtual_config is not None
|
assert virtual_config is not None
|
||||||
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
|
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
|
||||||
@@ -328,7 +327,7 @@ def create_message_data(
|
|||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
is_at_bot: bool = True,
|
is_at_bot: bool = True,
|
||||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if message_id is None:
|
if message_id is None:
|
||||||
message_id = str(uuid.uuid4())
|
message_id = str(uuid.uuid4())
|
||||||
|
|
||||||
@@ -385,7 +384,7 @@ def create_message_data(
|
|||||||
|
|
||||||
async def handle_chat_message(
|
async def handle_chat_message(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
data: dict[str, Any],
|
data: Dict[str, Any],
|
||||||
current_user_name: str,
|
current_user_name: str,
|
||||||
normalized_user_id: str,
|
normalized_user_id: str,
|
||||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||||
@@ -443,7 +442,7 @@ async def handle_chat_ping(session_id: str) -> None:
|
|||||||
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
|
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
|
||||||
|
|
||||||
|
|
||||||
async def handle_nickname_update(session_id: str, data: dict[str, Any], current_user_name: str) -> str:
|
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
|
||||||
new_name = str(data.get("user_name", "")).strip()
|
new_name = str(data.get("user_name", "")).strip()
|
||||||
if not new_name:
|
if not new_name:
|
||||||
return current_user_name
|
return current_user_name
|
||||||
@@ -462,7 +461,7 @@ async def handle_nickname_update(session_id: str, data: dict[str, Any], current_
|
|||||||
async def enable_virtual_identity(
|
async def enable_virtual_identity(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
session_prefix: str,
|
session_prefix: str,
|
||||||
virtual_data: dict[str, Any],
|
virtual_data: Dict[str, Any],
|
||||||
) -> Optional[VirtualIdentityConfig]:
|
) -> Optional[VirtualIdentityConfig]:
|
||||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||||
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
|
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
|
||||||
@@ -558,7 +557,7 @@ async def disable_virtual_identity(session_id: str) -> None:
|
|||||||
async def handle_virtual_identity_update(
|
async def handle_virtual_identity_update(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
session_id_prefix: str,
|
session_id_prefix: str,
|
||||||
data: dict[str, Any],
|
data: Dict[str, Any],
|
||||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||||
) -> Optional[VirtualIdentityConfig]:
|
) -> Optional[VirtualIdentityConfig]:
|
||||||
virtual_data = cast(dict[str, Any], data.get("config", {}))
|
virtual_data = cast(dict[str, Any], data.get("config", {}))
|
||||||
@@ -573,11 +572,11 @@ async def handle_virtual_identity_update(
|
|||||||
async def dispatch_chat_event(
|
async def dispatch_chat_event(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
session_id_prefix: str,
|
session_id_prefix: str,
|
||||||
data: dict[str, Any],
|
data: Dict[str, Any],
|
||||||
current_user_name: str,
|
current_user_name: str,
|
||||||
normalized_user_id: str,
|
normalized_user_id: str,
|
||||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||||
) -> tuple[str, Optional[VirtualIdentityConfig]]:
|
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
|
||||||
event_type = data.get("type")
|
event_type = data.get("type")
|
||||||
if event_type == "message":
|
if event_type == "message":
|
||||||
next_user_name = await handle_chat_message(
|
next_user_name = await handle_chat_message(
|
||||||
|
|||||||
@@ -5,51 +5,51 @@
|
|||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Annotated, Optional
|
from typing import Annotated, Any, Dict, List, Tuple
|
||||||
|
|
||||||
import tomlkit
|
import tomlkit
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.webui.dependencies import require_auth
|
from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig
|
||||||
from src.webui.utils.toml_utils import save_toml_with_format, _update_toml_doc
|
|
||||||
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
|
|
||||||
from src.config.config_base import AttributeData
|
from src.config.config_base import AttributeData
|
||||||
|
from src.config.model_configs import (
|
||||||
|
APIProvider,
|
||||||
|
ModelInfo,
|
||||||
|
ModelTaskConfig,
|
||||||
|
)
|
||||||
from src.config.official_configs import (
|
from src.config.official_configs import (
|
||||||
BotConfig,
|
BotConfig,
|
||||||
PersonalityConfig,
|
|
||||||
RelationshipConfig,
|
|
||||||
ChatConfig,
|
ChatConfig,
|
||||||
MessageReceiveConfig,
|
ChineseTypoConfig,
|
||||||
|
DebugConfig,
|
||||||
EmojiConfig,
|
EmojiConfig,
|
||||||
|
ExperimentalConfig,
|
||||||
ExpressionConfig,
|
ExpressionConfig,
|
||||||
KeywordReactionConfig,
|
KeywordReactionConfig,
|
||||||
ChineseTypoConfig,
|
LPMMKnowledgeConfig,
|
||||||
|
MaimMessageConfig,
|
||||||
|
MemoryConfig,
|
||||||
|
MessageReceiveConfig,
|
||||||
|
PersonalityConfig,
|
||||||
|
RelationshipConfig,
|
||||||
ResponsePostProcessConfig,
|
ResponsePostProcessConfig,
|
||||||
ResponseSplitterConfig,
|
ResponseSplitterConfig,
|
||||||
TelemetryConfig,
|
TelemetryConfig,
|
||||||
ExperimentalConfig,
|
|
||||||
MaimMessageConfig,
|
|
||||||
LPMMKnowledgeConfig,
|
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
MemoryConfig,
|
|
||||||
DebugConfig,
|
|
||||||
VoiceConfig,
|
VoiceConfig,
|
||||||
)
|
)
|
||||||
from src.config.model_configs import (
|
|
||||||
ModelTaskConfig,
|
|
||||||
ModelInfo,
|
|
||||||
APIProvider,
|
|
||||||
)
|
|
||||||
from src.webui.config_schema import ConfigSchemaGenerator
|
from src.webui.config_schema import ConfigSchemaGenerator
|
||||||
|
from src.webui.dependencies import require_auth
|
||||||
|
from src.webui.utils.toml_utils import _update_toml_doc, save_toml_with_format
|
||||||
|
|
||||||
logger = get_logger("webui")
|
logger = get_logger("webui")
|
||||||
|
|
||||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||||
ConfigBody = Annotated[dict[str, Any], Body()]
|
ConfigBody = Annotated[Dict[str, Any], Body()]
|
||||||
SectionBody = Annotated[Any, Body()]
|
SectionBody = Annotated[Any, Body()]
|
||||||
RawContentBody = Annotated[str, Body(embed=True)]
|
RawContentBody = Annotated[str, Body(embed=True)]
|
||||||
PathBody = Annotated[dict[str, str], Body()]
|
PathBody = Annotated[Dict[str, str], Body()]
|
||||||
|
|
||||||
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
|
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
|
||||||
|
|
||||||
@@ -61,6 +61,8 @@ def _toml_to_plain_dict(obj: Any) -> Any:
|
|||||||
if isinstance(obj, list):
|
if isinstance(obj, list):
|
||||||
return [_toml_to_plain_dict(v) for v in obj]
|
return [_toml_to_plain_dict(v) for v in obj]
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
# ===== 架构获取接口 =====
|
# ===== 架构获取接口 =====
|
||||||
|
|
||||||
|
|
||||||
@@ -385,8 +387,12 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
|||||||
if section_name == "api_providers" and "api_provider" in str(e):
|
if section_name == "api_providers" and "api_provider" in str(e):
|
||||||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||||
models = config_data.get("models", [])
|
models = config_data.get("models", [])
|
||||||
orphaned_models = [
|
orphaned_models: List[str] = [
|
||||||
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
str(model_name)
|
||||||
|
for m in models
|
||||||
|
if isinstance(m, dict)
|
||||||
|
and m.get("api_provider") not in provider_names
|
||||||
|
and (model_name := m.get("name")) is not None
|
||||||
]
|
]
|
||||||
if orphaned_models:
|
if orphaned_models:
|
||||||
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
||||||
@@ -421,7 +427,7 @@ def _normalize_adapter_path(path: str) -> str:
|
|||||||
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
||||||
|
|
||||||
|
|
||||||
def _get_allowed_adapter_config_roots() -> tuple[Path, ...]:
|
def _get_allowed_adapter_config_roots() -> Tuple[Path, ...]:
|
||||||
project_root = Path(PROJECT_ROOT).resolve()
|
project_root = Path(PROJECT_ROOT).resolve()
|
||||||
return (
|
return (
|
||||||
project_root,
|
project_root,
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from .routes import router
|
from .routes import router
|
||||||
|
|
||||||
__all__ = ["router"]
|
__all__ = ["router"]
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
"""表情包管理 API 路由"""
|
"""表情包管理 API 路由"""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Cookie, HTTPException, Query
|
from fastapi import APIRouter, Cookie, HTTPException, Query
|
||||||
from fastapi.responses import FileResponse, JSONResponse
|
from fastapi.responses import FileResponse, JSONResponse
|
||||||
@@ -17,7 +16,8 @@ from sqlmodel import col, select
|
|||||||
|
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import Images, ImageType
|
from src.common.database.database_model import Images, ImageType
|
||||||
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header as verify_auth_token
|
from src.webui.core import get_token_manager
|
||||||
|
from src.webui.core import verify_auth_token_from_cookie_or_header as verify_auth_token
|
||||||
|
|
||||||
from .schemas import (
|
from .schemas import (
|
||||||
BatchDeleteRequest,
|
BatchDeleteRequest,
|
||||||
@@ -219,7 +219,7 @@ async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(Non
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats/summary")
|
@router.get("/stats/summary")
|
||||||
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
"""获取表情包统计数据。"""
|
"""获取表情包统计数据。"""
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session)
|
verify_auth_token(maibot_session)
|
||||||
@@ -247,7 +247,7 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None)) -> dict[
|
|||||||
registered = session.exec(registered_statement).one()
|
registered = session.exec(registered_statement).one()
|
||||||
banned = session.exec(banned_statement).one()
|
banned = session.exec(banned_statement).one()
|
||||||
|
|
||||||
formats: dict[str, int] = {}
|
formats: Dict[str, int] = {}
|
||||||
format_statement = select(Images.full_path).where(col(Images.image_type) == ImageType.EMOJI)
|
format_statement = select(Images.full_path).where(col(Images.image_type) == ImageType.EMOJI)
|
||||||
for full_path in session.exec(format_statement).all():
|
for full_path in session.exec(format_statement).all():
|
||||||
suffix = Path(full_path).suffix.lower().lstrip(".")
|
suffix = Path(full_path).suffix.lower().lstrip(".")
|
||||||
@@ -443,7 +443,7 @@ async def batch_delete_emojis(
|
|||||||
|
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
failed_count = 0
|
failed_count = 0
|
||||||
failed_ids: list[int] = []
|
failed_ids: List[int] = []
|
||||||
|
|
||||||
for emoji_id in request.emoji_ids:
|
for emoji_id in request.emoji_ids:
|
||||||
try:
|
try:
|
||||||
@@ -582,12 +582,12 @@ async def batch_upload_emoji(
|
|||||||
emotion: EmotionForm = "",
|
emotion: EmotionForm = "",
|
||||||
is_registered: IsRegisteredForm = True,
|
is_registered: IsRegisteredForm = True,
|
||||||
maibot_session: Optional[str] = Cookie(None),
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""批量上传表情包。"""
|
"""批量上传表情包。"""
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session)
|
verify_auth_token(maibot_session)
|
||||||
|
|
||||||
results: dict[str, Any] = {
|
results: Dict[str, Any] = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"total": len(files),
|
"total": len(files),
|
||||||
"uploaded": 0,
|
"uploaded": 0,
|
||||||
@@ -614,9 +614,7 @@ async def batch_upload_emoji(
|
|||||||
file_content = await file.read()
|
file_content = await file.read()
|
||||||
if not file_content:
|
if not file_content:
|
||||||
results["failed"] += 1
|
results["failed"] += 1
|
||||||
results["details"].append(
|
results["details"].append({"filename": file.filename, "success": False, "error": "文件内容为空"})
|
||||||
{"filename": file.filename, "success": False, "error": "文件内容为空"}
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -677,14 +675,10 @@ async def batch_upload_emoji(
|
|||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
results["uploaded"] += 1
|
results["uploaded"] += 1
|
||||||
results["details"].append(
|
results["details"].append({"filename": file.filename, "success": True, "id": emoji.id})
|
||||||
{"filename": file.filename, "success": True, "id": emoji.id}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results["failed"] += 1
|
results["failed"] += 1
|
||||||
results["details"].append(
|
results["details"].append({"filename": file.filename, "success": False, "error": str(e)})
|
||||||
{"filename": file.filename, "success": False, "error": str(e)}
|
|
||||||
)
|
|
||||||
|
|
||||||
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个"
|
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个"
|
||||||
return results
|
return results
|
||||||
@@ -787,7 +781,9 @@ async def preheat_thumbnail_cache(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
await loop.run_in_executor(get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash)
|
await loop.run_in_executor(
|
||||||
|
get_thumbnail_executor(), generate_thumbnail, emoji.full_path, emoji.image_hash
|
||||||
|
)
|
||||||
generated += 1
|
generated += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}")
|
logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}")
|
||||||
@@ -840,4 +836,4 @@ async def clear_all_thumbnail_cache(maibot_session: Optional[str] = Cookie(None)
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"清空缩略图缓存失败: {e}")
|
logger.exception(f"清空缩略图缓存失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e
|
||||||
|
|||||||
@@ -137,4 +137,4 @@ def emoji_to_response(image: Images) -> EmojiResponse:
|
|||||||
record_time=image.record_time.timestamp() if image.record_time else 0.0,
|
record_time=image.record_time.timestamp() if image.record_time else 0.0,
|
||||||
register_time=image.register_time.timestamp() if image.register_time else None,
|
register_time=image.register_time.timestamp() if image.register_time else None,
|
||||||
last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,
|
last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Set, Tuple
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
@@ -18,10 +18,10 @@ THUMBNAIL_SIZE = (200, 200)
|
|||||||
THUMBNAIL_QUALITY = 80
|
THUMBNAIL_QUALITY = 80
|
||||||
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
|
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
|
||||||
|
|
||||||
_thumbnail_locks: dict[str, threading.Lock] = {}
|
_thumbnail_locks: Dict[str, threading.Lock] = {}
|
||||||
_locks_lock = threading.Lock()
|
_locks_lock = threading.Lock()
|
||||||
_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail")
|
_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail")
|
||||||
_generating_thumbnails: set[str] = set()
|
_generating_thumbnails: Set[str] = set()
|
||||||
_generating_lock = threading.Lock()
|
_generating_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ def get_generating_lock() -> threading.Lock:
|
|||||||
return _generating_lock
|
return _generating_lock
|
||||||
|
|
||||||
|
|
||||||
def get_generating_thumbnails() -> set[str]:
|
def get_generating_thumbnails() -> Set[str]:
|
||||||
"""获取正在生成的缩略图哈希集合。"""
|
"""获取正在生成的缩略图哈希集合。"""
|
||||||
return _generating_thumbnails
|
return _generating_thumbnails
|
||||||
|
|
||||||
@@ -112,7 +112,7 @@ def background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
|||||||
_background_generate_thumbnail(source_path, file_hash)
|
_background_generate_thumbnail(source_path, file_hash)
|
||||||
|
|
||||||
|
|
||||||
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
def cleanup_orphaned_thumbnails() -> Tuple[int, int]:
|
||||||
"""清理孤立的缩略图缓存。"""
|
"""清理孤立的缩略图缓存。"""
|
||||||
if not THUMBNAIL_CACHE_DIR.exists():
|
if not THUMBNAIL_CACHE_DIR.exists():
|
||||||
return 0, 0
|
return 0, 0
|
||||||
@@ -139,4 +139,4 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
|||||||
if cleaned > 0:
|
if cleaned > 0:
|
||||||
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个")
|
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个")
|
||||||
|
|
||||||
return cleaned, kept
|
return cleaned, kept
|
||||||
|
|||||||
@@ -5,14 +5,13 @@ from typing import Dict, List, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from sqlalchemy import case, func
|
from sqlalchemy import case, func
|
||||||
from sqlmodel import col, select, delete
|
from sqlmodel import col, delete, select
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
from src.common.logger import get_logger
|
||||||
from src.webui.dependencies import require_auth
|
from src.webui.dependencies import require_auth
|
||||||
|
|
||||||
logger = get_logger("webui.expression")
|
logger = get_logger("webui.expression")
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""黑话(俚语)管理路由"""
|
"""黑话(俚语)管理路由"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import Annotated, Any, Dict, List, Optional, Set
|
||||||
from typing import Annotated, Any, List, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -90,7 +89,7 @@ class JargonListResponse(BaseModel):
|
|||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
page_size: int
|
page_size: int
|
||||||
data: List[dict[str, Any]]
|
data: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class JargonDetailResponse(BaseModel):
|
class JargonDetailResponse(BaseModel):
|
||||||
@@ -153,7 +152,7 @@ class JargonStatsResponse(BaseModel):
|
|||||||
"""黑话统计响应"""
|
"""黑话统计响应"""
|
||||||
|
|
||||||
success: bool = True
|
success: bool = True
|
||||||
data: dict[str, Any]
|
data: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class ChatInfoResponse(BaseModel):
|
class ChatInfoResponse(BaseModel):
|
||||||
@@ -175,7 +174,7 @@ class ChatListResponse(BaseModel):
|
|||||||
# ==================== 工具函数 ====================
|
# ==================== 工具函数 ====================
|
||||||
|
|
||||||
|
|
||||||
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
|
def parse_session_id_dict(session_id_dict_str: Optional[str]) -> Dict[str, int]:
|
||||||
"""解析会话计数字典。"""
|
"""解析会话计数字典。"""
|
||||||
if not session_id_dict_str:
|
if not session_id_dict_str:
|
||||||
return {}
|
return {}
|
||||||
@@ -188,7 +187,7 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
|
|||||||
if not isinstance(parsed, dict):
|
if not isinstance(parsed, dict):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
session_counts: dict[str, int] = {}
|
session_counts: Dict[str, int] = {}
|
||||||
for session_id, count in parsed.items():
|
for session_id, count in parsed.items():
|
||||||
if not isinstance(session_id, str):
|
if not isinstance(session_id, str):
|
||||||
continue
|
continue
|
||||||
@@ -202,7 +201,7 @@ def parse_session_id_dict(session_id_dict_str: Optional[str]) -> dict[str, int]:
|
|||||||
return session_counts
|
return session_counts
|
||||||
|
|
||||||
|
|
||||||
def dump_session_id_dict(session_counts: dict[str, int]) -> str:
|
def dump_session_id_dict(session_counts: Dict[str, int]) -> str:
|
||||||
"""序列化会话计数字典。"""
|
"""序列化会话计数字典。"""
|
||||||
return json.dumps(session_counts, ensure_ascii=False)
|
return json.dumps(session_counts, ensure_ascii=False)
|
||||||
|
|
||||||
@@ -225,7 +224,7 @@ def build_session_id_dict_for_chat(chat_id: str, count: int = 1) -> str:
|
|||||||
return dump_session_id_dict({chat_id: count})
|
return dump_session_id_dict({chat_id: count})
|
||||||
|
|
||||||
|
|
||||||
def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
|
def jargon_to_dict(jargon: Jargon, session: Session) -> Dict[str, Any]:
|
||||||
"""将 Jargon ORM 对象转换为字典"""
|
"""将 Jargon ORM 对象转换为字典"""
|
||||||
chat_id = get_primary_chat_id(jargon.session_id_dict)
|
chat_id = get_primary_chat_id(jargon.session_id_dict)
|
||||||
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
|
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
|
||||||
@@ -311,14 +310,16 @@ async def get_chat_list():
|
|||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
jargons = session.exec(select(Jargon)).all()
|
jargons = session.exec(select(Jargon)).all()
|
||||||
|
|
||||||
seen_stream_ids: set[str] = set()
|
seen_stream_ids: Set[str] = set()
|
||||||
for jargon in jargons:
|
for jargon in jargons:
|
||||||
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
|
seen_stream_ids.update(parse_session_id_dict(jargon.session_id_dict).keys())
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
for stream_id in seen_stream_ids:
|
for stream_id in seen_stream_ids:
|
||||||
if chat_session := session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first():
|
if chat_session := session.exec(
|
||||||
|
select(ChatSession).where(col(ChatSession.session_id) == stream_id)
|
||||||
|
).first():
|
||||||
chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
|
chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
|
||||||
result.append(
|
result.append(
|
||||||
ChatInfoResponse(
|
ChatInfoResponse(
|
||||||
@@ -358,7 +359,7 @@ async def get_jargon_stats():
|
|||||||
pending = sum(jargon.is_jargon is None for jargon in jargons)
|
pending = sum(jargon.is_jargon is None for jargon in jargons)
|
||||||
complete_count = sum(jargon.is_complete for jargon in jargons)
|
complete_count = sum(jargon.is_complete for jargon in jargons)
|
||||||
|
|
||||||
top_chats_counter: dict[str, int] = {}
|
top_chats_counter: Dict[str, int] = {}
|
||||||
for jargon in jargons:
|
for jargon in jargons:
|
||||||
for session_id in parse_session_id_dict(jargon.session_id_dict):
|
for session_id in parse_session_id_dict(jargon.session_id_dict):
|
||||||
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1
|
top_chats_counter[session_id] = top_chats_counter.get(session_id, 0) + 1
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
"""知识库图谱可视化 API 路由"""
|
"""知识库图谱可视化 API 路由"""
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
import logging
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import logging
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.webui.dependencies import require_auth
|
from src.webui.dependencies import require_auth
|
||||||
|
|
||||||
@@ -59,7 +60,7 @@ def _get_paragraph_store():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
|
def _get_paragraph_content(node_id: str) -> Tuple[Optional[str], bool]:
|
||||||
"""从 embedding store 获取段落完整内容
|
"""从 embedding store 获取段落完整内容
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -5,9 +5,9 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import httpx
|
from typing import Dict, List, Optional
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
import httpx
|
||||||
import tomlkit
|
import tomlkit
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ def _normalize_url(url: str) -> str:
|
|||||||
return url.rstrip("/") if url else ""
|
return url.rstrip("/") if url else ""
|
||||||
|
|
||||||
|
|
||||||
def _parse_openai_response(data: dict) -> list[dict]:
|
def _parse_openai_response(data: Dict) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
解析 OpenAI 格式的模型列表响应
|
解析 OpenAI 格式的模型列表响应
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ def _parse_openai_response(data: dict) -> list[dict]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
def _parse_gemini_response(data: Dict) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
解析 Gemini 格式的模型列表响应
|
解析 Gemini 格式的模型列表响应
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ async def _fetch_models_from_provider(
|
|||||||
endpoint: str,
|
endpoint: str,
|
||||||
parser: str,
|
parser: str,
|
||||||
client_type: str = "openai",
|
client_type: str = "openai",
|
||||||
) -> list[dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
从提供商 API 获取模型列表
|
从提供商 API 获取模型列表
|
||||||
|
|
||||||
@@ -154,7 +154,7 @@ async def _fetch_models_from_provider(
|
|||||||
raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")
|
raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")
|
||||||
|
|
||||||
|
|
||||||
def _get_provider_config(provider_name: str) -> Optional[dict]:
|
def _get_provider_config(provider_name: str) -> Optional[Dict]:
|
||||||
"""
|
"""
|
||||||
从 model_config.toml 获取指定提供商的配置
|
从 model_config.toml 获取指定提供商的配置
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
"""人物信息管理 API 路由"""
|
"""人物信息管理 API 路由"""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
import json
|
import json
|
||||||
|
from datetime import datetime
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from sqlalchemy import case
|
from sqlalchemy import case
|
||||||
from sqlmodel import col, select, delete
|
from sqlmodel import col, delete, select
|
||||||
|
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import PersonInfo
|
from src.common.database.database_model import PersonInfo
|
||||||
|
|||||||
@@ -14,4 +14,4 @@ router.include_router(config_router)
|
|||||||
|
|
||||||
set_update_progress_callback(update_progress)
|
set_update_progress_callback(update_progress)
|
||||||
|
|
||||||
__all__ = ["get_progress_router", "router"]
|
__all__ = ["get_progress_router", "router"]
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Cookie, HTTPException
|
from fastapi import APIRouter, Cookie, HTTPException
|
||||||
|
|
||||||
@@ -28,7 +27,7 @@ logger = get_logger("webui.plugin_routes")
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def _mirror_to_response(mirror: dict[str, Any]) -> MirrorConfigResponse:
|
def _mirror_to_response(mirror: Dict[str, Any]) -> MirrorConfigResponse:
|
||||||
return MirrorConfigResponse(
|
return MirrorConfigResponse(
|
||||||
id=mirror["id"],
|
id=mirror["id"],
|
||||||
name=mirror["name"],
|
name=mirror["name"],
|
||||||
@@ -116,7 +115,7 @@ async def update_mirror(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/mirrors/{mirror_id}")
|
@router.delete("/mirrors/{mirror_id}")
|
||||||
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
|
|
||||||
service = get_git_mirror_service()
|
service = get_git_mirror_service()
|
||||||
@@ -172,12 +171,16 @@ async def fetch_raw_file(
|
|||||||
loaded_plugins=total,
|
loaded_plugins=total,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
await update_progress(stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0)
|
await update_progress(
|
||||||
|
stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0
|
||||||
|
)
|
||||||
|
|
||||||
return FetchRawFileResponse(**result)
|
return FetchRawFileResponse(**result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取 Raw 文件失败: {e}")
|
logger.error(f"获取 Raw 文件失败: {e}")
|
||||||
await update_progress(stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0)
|
await update_progress(
|
||||||
|
stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0
|
||||||
|
)
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@@ -204,4 +207,4 @@ async def clone_repository(
|
|||||||
return CloneRepositoryResponse(**result)
|
return CloneRepositoryResponse(**result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"克隆仓库失败: {e}")
|
logger.error(f"克隆仓库失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from typing import Any, Optional, cast
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import tomlkit
|
from typing import Any, Dict, Optional, cast
|
||||||
|
|
||||||
|
import tomlkit
|
||||||
from fastapi import APIRouter, Cookie, HTTPException
|
from fastapi import APIRouter, Cookie, HTTPException
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -24,8 +23,8 @@ logger = get_logger("webui.plugin_routes")
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> dict[str, Any]:
|
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> Dict[str, Any]:
|
||||||
schema: dict[str, Any] = {
|
schema: Dict[str, Any] = {
|
||||||
"plugin_id": plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"plugin_info": {
|
"plugin_info": {
|
||||||
"name": plugin_id,
|
"name": plugin_id,
|
||||||
@@ -41,7 +40,7 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> di
|
|||||||
for section_name, section_data in current_config.items():
|
for section_name, section_data in current_config.items():
|
||||||
if not isinstance(section_data, dict):
|
if not isinstance(section_data, dict):
|
||||||
continue
|
continue
|
||||||
section_fields: dict[str, Any] = {}
|
section_fields: Dict[str, Any] = {}
|
||||||
for field_name, field_value in section_data.items():
|
for field_name, field_value in section_data.items():
|
||||||
field_type = type(field_value).__name__
|
field_type = type(field_value).__name__
|
||||||
ui_type = "text"
|
ui_type = "text"
|
||||||
@@ -121,7 +120,7 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> di
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/config/{plugin_id}/schema")
|
@router.get("/config/{plugin_id}/schema")
|
||||||
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
||||||
|
|
||||||
@@ -157,7 +156,7 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/config/{plugin_id}/raw")
|
@router.get("/config/{plugin_id}/raw")
|
||||||
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"获取插件原始配置: {plugin_id}")
|
logger.info(f"获取插件原始配置: {plugin_id}")
|
||||||
|
|
||||||
@@ -184,7 +183,7 @@ async def update_plugin_config_raw(
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
request: UpdatePluginRawConfigRequest,
|
request: UpdatePluginRawConfigRequest,
|
||||||
maibot_session: Optional[str] = Cookie(None),
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"更新插件原始配置: {plugin_id}")
|
logger.info(f"更新插件原始配置: {plugin_id}")
|
||||||
|
|
||||||
@@ -216,7 +215,7 @@ async def update_plugin_config_raw(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/config/{plugin_id}")
|
@router.get("/config/{plugin_id}")
|
||||||
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"获取插件配置: {plugin_id}")
|
logger.info(f"获取插件配置: {plugin_id}")
|
||||||
|
|
||||||
@@ -244,7 +243,7 @@ async def update_plugin_config(
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
request: UpdatePluginConfigRequest,
|
request: UpdatePluginConfigRequest,
|
||||||
maibot_session: Optional[str] = Cookie(None),
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"更新插件配置: {plugin_id}")
|
logger.info(f"更新插件配置: {plugin_id}")
|
||||||
|
|
||||||
@@ -276,7 +275,7 @@ async def update_plugin_config(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/config/{plugin_id}/reset")
|
@router.post("/config/{plugin_id}/reset")
|
||||||
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"重置插件配置: {plugin_id}")
|
logger.info(f"重置插件配置: {plugin_id}")
|
||||||
|
|
||||||
@@ -300,7 +299,7 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/config/{plugin_id}/toggle")
|
@router.post("/config/{plugin_id}/toggle")
|
||||||
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"切换插件状态: {plugin_id}")
|
logger.info(f"切换插件状态: {plugin_id}")
|
||||||
|
|
||||||
@@ -326,9 +325,14 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
|
|||||||
|
|
||||||
status = "启用" if new_enabled else "禁用"
|
status = "启用" if new_enabled else "禁用"
|
||||||
logger.info(f"已{status}插件: {plugin_id}")
|
logger.info(f"已{status}插件: {plugin_id}")
|
||||||
return {"success": True, "enabled": new_enabled, "message": f"插件已{status}", "note": "状态更改将在下次加载插件时生效"}
|
return {
|
||||||
|
"success": True,
|
||||||
|
"enabled": new_enabled,
|
||||||
|
"message": f"插件已{status}",
|
||||||
|
"note": "状态更改将在下次加载插件时生效",
|
||||||
|
}
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"切换插件状态失败: {e}", exc_info=True)
|
logger.error(f"切换插件状态失败: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Cookie, HTTPException
|
from fastapi import APIRouter, Cookie, HTTPException
|
||||||
|
|
||||||
@@ -18,8 +17,8 @@ from .support import (
|
|||||||
parse_repository_url,
|
parse_repository_url,
|
||||||
remove_tree,
|
remove_tree,
|
||||||
require_plugin_token,
|
require_plugin_token,
|
||||||
resolve_plugin_file_path,
|
|
||||||
resolve_installed_plugin_path,
|
resolve_installed_plugin_path,
|
||||||
|
resolve_plugin_file_path,
|
||||||
validate_plugin_id,
|
validate_plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,7 +27,7 @@ logger = get_logger("webui.plugin_routes")
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path: Path) -> str:
|
def _infer_plugin_id(folder_name: str, manifest: Dict[str, Any], manifest_path: Path) -> str:
|
||||||
if "id" in manifest:
|
if "id" in manifest:
|
||||||
return str(manifest["id"])
|
return str(manifest["id"])
|
||||||
|
|
||||||
@@ -66,43 +65,87 @@ def _infer_plugin_id(folder_name: str, manifest: dict[str, Any], manifest_path:
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/install")
|
@router.post("/install")
|
||||||
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
||||||
plugin_id = request.plugin_id
|
plugin_id = request.plugin_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
plugin_id = validate_plugin_id(request.plugin_id)
|
plugin_id = validate_plugin_id(request.plugin_id)
|
||||||
await update_progress(stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading", progress=5, message=f"开始安装插件: {plugin_id}", operation="install", plugin_id=plugin_id
|
||||||
|
)
|
||||||
|
|
||||||
repo_url, owner, repo = parse_repository_url(request.repository_url)
|
repo_url, owner, repo = parse_repository_url(request.repository_url)
|
||||||
await update_progress(stage="loading", progress=10, message=f"解析仓库信息: {owner}/{repo}", operation="install", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading",
|
||||||
|
progress=10,
|
||||||
|
message=f"解析仓库信息: {owner}/{repo}",
|
||||||
|
operation="install",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
|
|
||||||
target_path, old_format_path = get_plugin_candidate_paths(plugin_id)
|
target_path, old_format_path = get_plugin_candidate_paths(plugin_id)
|
||||||
if target_path.exists() or old_format_path.exists():
|
if target_path.exists() or old_format_path.exists():
|
||||||
await update_progress(stage="error", progress=0, message="插件已存在", operation="install", plugin_id=plugin_id, error="插件已安装,请先卸载")
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="插件已存在",
|
||||||
|
operation="install",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error="插件已安装,请先卸载",
|
||||||
|
)
|
||||||
raise HTTPException(status_code=400, detail="插件已安装")
|
raise HTTPException(status_code=400, detail="插件已安装")
|
||||||
|
|
||||||
await update_progress(stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading", progress=15, message=f"准备克隆到: {target_path}", operation="install", plugin_id=plugin_id
|
||||||
|
)
|
||||||
service = get_git_mirror_service()
|
service = get_git_mirror_service()
|
||||||
if "github.com" in repo_url:
|
if "github.com" in repo_url:
|
||||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
|
result = await service.clone_repository(
|
||||||
|
owner=owner,
|
||||||
|
repo=repo,
|
||||||
|
target_path=target_path,
|
||||||
|
branch=request.branch,
|
||||||
|
mirror_id=request.mirror_id,
|
||||||
|
depth=1,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1)
|
result = await service.clone_repository(
|
||||||
|
owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||||
|
)
|
||||||
|
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
error_msg = str(result.get("error", "克隆失败"))
|
error_msg = str(result.get("error", "克隆失败"))
|
||||||
await update_progress(stage="error", progress=0, message="克隆仓库失败", operation="install", plugin_id=plugin_id, error=error_msg)
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="克隆仓库失败",
|
||||||
|
operation="install",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error=error_msg,
|
||||||
|
)
|
||||||
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
|
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
|
||||||
|
|
||||||
await update_progress(stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
|
||||||
|
)
|
||||||
manifest_path = resolve_plugin_file_path(target_path, "_manifest.json")
|
manifest_path = resolve_plugin_file_path(target_path, "_manifest.json")
|
||||||
if not manifest_path.exists():
|
if not manifest_path.exists():
|
||||||
remove_tree(target_path)
|
remove_tree(target_path)
|
||||||
await update_progress(stage="error", progress=0, message="插件缺少 _manifest.json", operation="install", plugin_id=plugin_id, error="无效的插件格式")
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="插件缺少 _manifest.json",
|
||||||
|
operation="install",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error="无效的插件格式",
|
||||||
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
|
|
||||||
await update_progress(stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
with open(manifest_path, "r", encoding="utf-8") as file_obj:
|
with open(manifest_path, "r", encoding="utf-8") as file_obj:
|
||||||
manifest = json.load(file_obj)
|
manifest = json.load(file_obj)
|
||||||
@@ -114,91 +157,199 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
|
json.dump(manifest, file_obj, ensure_ascii=False, indent=2)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
remove_tree(target_path)
|
remove_tree(target_path)
|
||||||
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="install", plugin_id=plugin_id, error=str(e))
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="_manifest.json 无效",
|
||||||
|
operation="install",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
|
|
||||||
await update_progress(stage="success", progress=100, message=f"成功安装插件: {manifest['name']} v{manifest['version']}", operation="install", plugin_id=plugin_id)
|
await update_progress(
|
||||||
return {"success": True, "message": "插件安装成功", "plugin_id": plugin_id, "plugin_name": manifest["name"], "version": manifest["version"], "path": str(target_path)}
|
stage="success",
|
||||||
|
progress=100,
|
||||||
|
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||||
|
operation="install",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "插件安装成功",
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"plugin_name": manifest["name"],
|
||||||
|
"version": manifest["version"],
|
||||||
|
"path": str(target_path),
|
||||||
|
}
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"安装插件失败: {e}", exc_info=True)
|
logger.error(f"安装插件失败: {e}", exc_info=True)
|
||||||
await update_progress(stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e))
|
await update_progress(
|
||||||
|
stage="error", progress=0, message="安装失败", operation="install", plugin_id=plugin_id, error=str(e)
|
||||||
|
)
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/uninstall")
|
@router.post("/uninstall")
|
||||||
async def uninstall_plugin(request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def uninstall_plugin(
|
||||||
|
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
||||||
plugin_id = request.plugin_id
|
plugin_id = request.plugin_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
plugin_id = validate_plugin_id(request.plugin_id)
|
plugin_id = validate_plugin_id(request.plugin_id)
|
||||||
await update_progress(stage="loading", progress=10, message=f"开始卸载插件: {plugin_id}", operation="uninstall", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading",
|
||||||
|
progress=10,
|
||||||
|
message=f"开始卸载插件: {plugin_id}",
|
||||||
|
operation="uninstall",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
plugin_path = resolve_installed_plugin_path(plugin_id)
|
plugin_path = resolve_installed_plugin_path(plugin_id)
|
||||||
if plugin_path is None:
|
if plugin_path is None:
|
||||||
await update_progress(stage="error", progress=0, message="插件不存在", operation="uninstall", plugin_id=plugin_id, error="插件未安装或已被删除")
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="插件不存在",
|
||||||
|
operation="uninstall",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error="插件未安装或已被删除",
|
||||||
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
|
|
||||||
await update_progress(stage="loading", progress=30, message=f"正在删除插件文件: {plugin_path}", operation="uninstall", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading",
|
||||||
|
progress=30,
|
||||||
|
message=f"正在删除插件文件: {plugin_path}",
|
||||||
|
operation="uninstall",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
|
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
|
||||||
plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id
|
plugin_name = str(manifest.get("name", plugin_id)) if manifest is not None else plugin_id
|
||||||
await update_progress(stage="loading", progress=50, message=f"正在删除 {plugin_name}...", operation="uninstall", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading",
|
||||||
|
progress=50,
|
||||||
|
message=f"正在删除 {plugin_name}...",
|
||||||
|
operation="uninstall",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
remove_tree(plugin_path)
|
remove_tree(plugin_path)
|
||||||
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
|
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
|
||||||
await update_progress(stage="success", progress=100, message=f"成功卸载插件: {plugin_name}", operation="uninstall", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="success",
|
||||||
|
progress=100,
|
||||||
|
message=f"成功卸载插件: {plugin_name}",
|
||||||
|
operation="uninstall",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
|
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
logger.error(f"卸载插件失败(权限错误): {e}")
|
logger.error(f"卸载插件失败(权限错误): {e}")
|
||||||
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error="权限不足,无法删除插件文件")
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="卸载失败",
|
||||||
|
operation="uninstall",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error="权限不足,无法删除插件文件",
|
||||||
|
)
|
||||||
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
|
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"卸载插件失败: {e}", exc_info=True)
|
logger.error(f"卸载插件失败: {e}", exc_info=True)
|
||||||
await update_progress(stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e))
|
await update_progress(
|
||||||
|
stage="error", progress=0, message="卸载失败", operation="uninstall", plugin_id=plugin_id, error=str(e)
|
||||||
|
)
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/update")
|
@router.post("/update")
|
||||||
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
||||||
plugin_id = request.plugin_id
|
plugin_id = request.plugin_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
plugin_id = validate_plugin_id(request.plugin_id)
|
plugin_id = validate_plugin_id(request.plugin_id)
|
||||||
await update_progress(stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading", progress=5, message=f"开始更新插件: {plugin_id}", operation="update", plugin_id=plugin_id
|
||||||
|
)
|
||||||
plugin_path = resolve_installed_plugin_path(plugin_id)
|
plugin_path = resolve_installed_plugin_path(plugin_id)
|
||||||
if plugin_path is None:
|
if plugin_path is None:
|
||||||
await update_progress(stage="error", progress=0, message="插件不存在", operation="update", plugin_id=plugin_id, error="插件未安装,请先安装")
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="插件不存在",
|
||||||
|
operation="update",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error="插件未安装,请先安装",
|
||||||
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
|
|
||||||
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
|
manifest = load_manifest_json(resolve_plugin_file_path(plugin_path, "_manifest.json"))
|
||||||
old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown"
|
old_version = str(manifest.get("version", "unknown")) if manifest is not None else "unknown"
|
||||||
await update_progress(stage="loading", progress=10, message=f"当前版本: {old_version},准备更新...", operation="update", plugin_id=plugin_id)
|
await update_progress(
|
||||||
await update_progress(stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id)
|
stage="loading",
|
||||||
|
progress=10,
|
||||||
|
message=f"当前版本: {old_version},准备更新...",
|
||||||
|
operation="update",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
|
await update_progress(
|
||||||
|
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
|
||||||
|
)
|
||||||
remove_tree(plugin_path)
|
remove_tree(plugin_path)
|
||||||
|
|
||||||
await update_progress(stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading", progress=30, message="正在准备下载新版本...", operation="update", plugin_id=plugin_id
|
||||||
|
)
|
||||||
repo_url, owner, repo = parse_repository_url(request.repository_url)
|
repo_url, owner, repo = parse_repository_url(request.repository_url)
|
||||||
service = get_git_mirror_service()
|
service = get_git_mirror_service()
|
||||||
if "github.com" in repo_url:
|
if "github.com" in repo_url:
|
||||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, mirror_id=request.mirror_id, depth=1)
|
result = await service.clone_repository(
|
||||||
|
owner=owner,
|
||||||
|
repo=repo,
|
||||||
|
target_path=plugin_path,
|
||||||
|
branch=request.branch,
|
||||||
|
mirror_id=request.mirror_id,
|
||||||
|
depth=1,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result = await service.clone_repository(owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1)
|
result = await service.clone_repository(
|
||||||
|
owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||||
|
)
|
||||||
|
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
error_msg = str(result.get("error", "克隆失败"))
|
error_msg = str(result.get("error", "克隆失败"))
|
||||||
await update_progress(stage="error", progress=0, message="下载新版本失败", operation="update", plugin_id=plugin_id, error=error_msg)
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="下载新版本失败",
|
||||||
|
operation="update",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error=error_msg,
|
||||||
|
)
|
||||||
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
|
raise HTTPException(status_code=int(result.get("status_code", 500)), detail=error_msg)
|
||||||
|
|
||||||
await update_progress(stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id)
|
await update_progress(
|
||||||
|
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
|
||||||
|
)
|
||||||
new_manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
|
new_manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
|
||||||
if not new_manifest_path.exists():
|
if not new_manifest_path.exists():
|
||||||
remove_tree(plugin_path)
|
remove_tree(plugin_path)
|
||||||
await update_progress(stage="error", progress=0, message="新版本缺少 _manifest.json", operation="update", plugin_id=plugin_id, error="无效的插件格式")
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="新版本缺少 _manifest.json",
|
||||||
|
operation="update",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error="无效的插件格式",
|
||||||
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -207,27 +358,49 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
new_version = str(new_manifest.get("version", "unknown"))
|
new_version = str(new_manifest.get("version", "unknown"))
|
||||||
new_name = str(new_manifest.get("name", plugin_id))
|
new_name = str(new_manifest.get("name", plugin_id))
|
||||||
logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}")
|
logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}")
|
||||||
await update_progress(stage="success", progress=100, message=f"成功更新 {new_name}: {old_version} → {new_version}", operation="update", plugin_id=plugin_id)
|
await update_progress(
|
||||||
return {"success": True, "message": "插件更新成功", "plugin_id": plugin_id, "plugin_name": new_name, "old_version": old_version, "new_version": new_version}
|
stage="success",
|
||||||
|
progress=100,
|
||||||
|
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||||
|
operation="update",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "插件更新成功",
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"plugin_name": new_name,
|
||||||
|
"old_version": old_version,
|
||||||
|
"new_version": new_version,
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
remove_tree(plugin_path)
|
remove_tree(plugin_path)
|
||||||
await update_progress(stage="error", progress=0, message="_manifest.json 无效", operation="update", plugin_id=plugin_id, error=str(e))
|
await update_progress(
|
||||||
|
stage="error",
|
||||||
|
progress=0,
|
||||||
|
message="_manifest.json 无效",
|
||||||
|
operation="update",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||||
await update_progress(stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e))
|
await update_progress(
|
||||||
|
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
|
||||||
|
)
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/installed")
|
@router.get("/installed")
|
||||||
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info("收到获取已安装插件列表请求")
|
logger.info("收到获取已安装插件列表请求")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installed_plugins: list[dict[str, Any]] = []
|
installed_plugins: List[Dict[str, Any]] = []
|
||||||
for plugin_path in iter_plugin_directories():
|
for plugin_path in iter_plugin_directories():
|
||||||
folder_name = plugin_path.name
|
folder_name = plugin_path.name
|
||||||
if folder_name.startswith(".") or folder_name.startswith("__"):
|
if folder_name.startswith(".") or folder_name.startswith("__"):
|
||||||
@@ -253,9 +426,9 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取插件 {folder_name} 信息时出错: {e}")
|
logger.error(f"读取插件 {folder_name} 信息时出错: {e}")
|
||||||
|
|
||||||
seen_ids: dict[str, str] = {}
|
seen_ids: Dict[str, str] = {}
|
||||||
unique_plugins: list[dict[str, Any]] = []
|
unique_plugins: List[Dict[str, Any]] = []
|
||||||
duplicates: list[dict[str, Any]] = []
|
duplicates: List[Dict[str, Any]] = []
|
||||||
for plugin in installed_plugins:
|
for plugin in installed_plugins:
|
||||||
plugin_id = str(plugin["id"])
|
plugin_id = str(plugin["id"])
|
||||||
plugin_path = str(plugin["path"])
|
plugin_path = str(plugin["path"])
|
||||||
@@ -277,7 +450,7 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/local-readme/{plugin_id}")
|
@router.get("/local-readme/{plugin_id}")
|
||||||
async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> dict[str, Any]:
|
async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||||
require_plugin_token(maibot_session)
|
require_plugin_token(maibot_session)
|
||||||
logger.info(f"获取本地插件 README: {plugin_id}")
|
logger.info(f"获取本地插件 README: {plugin_id}")
|
||||||
|
|
||||||
@@ -300,4 +473,4 @@ async def get_local_plugin_readme(plugin_id: str, maibot_session: Optional[str]
|
|||||||
return {"success": False, "error": "本地未找到 README 文件"}
|
return {"success": False, "error": "本地未找到 README 文件"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取本地 README 失败: {e}", exc_info=True)
|
logger.error(f"获取本地 README 失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, Dict, Optional, Set
|
||||||
from typing import Any, Optional, Set
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
@@ -14,7 +13,7 @@ logger = get_logger("webui.plugin_progress")
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
active_connections: Set[WebSocket] = set()
|
active_connections: Set[WebSocket] = set()
|
||||||
current_progress: dict[str, Any] = {
|
current_progress: Dict[str, Any] = {
|
||||||
"operation": "idle",
|
"operation": "idle",
|
||||||
"stage": "idle",
|
"stage": "idle",
|
||||||
"progress": 0,
|
"progress": 0,
|
||||||
@@ -26,7 +25,7 @@ current_progress: dict[str, Any] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def broadcast_progress(progress_data: dict[str, Any]) -> None:
|
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
|
||||||
global current_progress
|
global current_progress
|
||||||
current_progress = progress_data.copy()
|
current_progress = progress_data.copy()
|
||||||
|
|
||||||
@@ -34,7 +33,7 @@ async def broadcast_progress(progress_data: dict[str, Any]) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
message = json.dumps(progress_data, ensure_ascii=False)
|
message = json.dumps(progress_data, ensure_ascii=False)
|
||||||
disconnected: set[WebSocket] = set()
|
disconnected: Set[WebSocket] = set()
|
||||||
|
|
||||||
for websocket in active_connections:
|
for websocket in active_connections:
|
||||||
try:
|
try:
|
||||||
@@ -119,4 +118,4 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
|||||||
|
|
||||||
|
|
||||||
def get_progress_router() -> APIRouter:
|
def get_progress_router() -> APIRouter:
|
||||||
return router
|
return router
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -51,8 +51,8 @@ class MirrorConfigResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AvailableMirrorsResponse(BaseModel):
|
class AvailableMirrorsResponse(BaseModel):
|
||||||
mirrors: list[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
||||||
default_priority: list[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
||||||
|
|
||||||
|
|
||||||
class AddMirrorRequest(BaseModel):
|
class AddMirrorRequest(BaseModel):
|
||||||
@@ -106,8 +106,8 @@ class UpdatePluginRequest(BaseModel):
|
|||||||
|
|
||||||
class UpdatePluginConfigRequest(BaseModel):
|
class UpdatePluginConfigRequest(BaseModel):
|
||||||
enabled: Optional[bool] = None
|
enabled: Optional[bool] = None
|
||||||
config: Optional[dict[str, Any]] = None
|
config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class UpdatePluginRawConfigRequest(BaseModel):
|
class UpdatePluginRawConfigRequest(BaseModel):
|
||||||
config: str = Field(..., description="原始 TOML 配置内容")
|
config: str = Field(..., description="原始 TOML 配置内容")
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional, cast, get_origin
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import stat
|
import stat
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, cast, get_origin
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
@@ -53,10 +52,7 @@ def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict:
|
|||||||
resolved_plugin_path = plugin_path.resolve()
|
resolved_plugin_path = plugin_path.resolve()
|
||||||
resolved_plugin_path.relative_to(resolved_plugins_dir)
|
resolved_plugin_path.relative_to(resolved_plugins_dir)
|
||||||
|
|
||||||
if not resolved_plugin_path.is_dir():
|
return resolved_plugin_path if resolved_plugin_path.is_dir() else None
|
||||||
return None
|
|
||||||
|
|
||||||
return resolved_plugin_path
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
if strict:
|
if strict:
|
||||||
raise
|
raise
|
||||||
@@ -64,7 +60,7 @@ def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict:
|
|||||||
return None
|
return None
|
||||||
except (OSError, RuntimeError, ValueError):
|
except (OSError, RuntimeError, ValueError):
|
||||||
if strict:
|
if strict:
|
||||||
raise HTTPException(status_code=400, detail="插件目录超出允许范围")
|
raise HTTPException(status_code=400, detail="插件目录超出允许范围") from None
|
||||||
logger.warning(f"已跳过越界的插件目录: {plugin_path}")
|
logger.warning(f"已跳过越界的插件目录: {plugin_path}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -109,7 +105,7 @@ def validate_plugin_id(plugin_id: str) -> str:
|
|||||||
return plugin_id
|
return plugin_id
|
||||||
|
|
||||||
|
|
||||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
def parse_version(version_str: str) -> Tuple[int, int, int]:
|
||||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||||
parts = base_version.split(".")
|
parts = base_version.split(".")
|
||||||
if len(parts) < 3:
|
if len(parts) < 3:
|
||||||
@@ -122,7 +118,7 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
|||||||
return 0, 0, 0
|
return 0, 0, 0
|
||||||
|
|
||||||
|
|
||||||
def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
|
def deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
|
||||||
for key, value in src.items():
|
for key, value in src.items():
|
||||||
if key in dst and isinstance(dst[key], dict) and isinstance(value, dict):
|
if key in dst and isinstance(dst[key], dict) and isinstance(value, dict):
|
||||||
deep_merge(dst[key], value)
|
deep_merge(dst[key], value)
|
||||||
@@ -130,9 +126,9 @@ def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
|
|||||||
dst[key] = value
|
dst[key] = value
|
||||||
|
|
||||||
|
|
||||||
def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
|
def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
result: dict[str, Any] = {}
|
result: Dict[str, Any] = {}
|
||||||
dotted_items: list[tuple[str, Any]] = []
|
dotted_items: List[Tuple[str, Any]] = []
|
||||||
|
|
||||||
for key, value in obj.items():
|
for key, value in obj.items():
|
||||||
if "." in key:
|
if "." in key:
|
||||||
@@ -167,7 +163,7 @@ def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None:
|
def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None:
|
||||||
def is_list_type(tp: Any) -> bool:
|
def is_list_type(tp: Any) -> bool:
|
||||||
origin = get_origin(tp)
|
origin = get_origin(tp)
|
||||||
return tp is list or origin is list
|
return tp is list or origin is list
|
||||||
@@ -200,7 +196,7 @@ def get_plugins_dir() -> Path:
|
|||||||
return plugins_dir
|
return plugins_dir
|
||||||
|
|
||||||
|
|
||||||
def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
|
def get_plugin_candidate_paths(plugin_id: str) -> Tuple[Path, Path]:
|
||||||
plugins_dir = get_plugins_dir()
|
plugins_dir = get_plugins_dir()
|
||||||
folder_name = plugin_id.replace(".", "_")
|
folder_name = plugin_id.replace(".", "_")
|
||||||
return validate_safe_path(folder_name, plugins_dir), validate_safe_path(plugin_id, plugins_dir)
|
return validate_safe_path(folder_name, plugins_dir), validate_safe_path(plugin_id, plugins_dir)
|
||||||
@@ -217,7 +213,7 @@ def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
|
def parse_repository_url(repository_url: str) -> Tuple[str, str, str]:
|
||||||
repo_url = repository_url.rstrip("/").removesuffix(".git")
|
repo_url = repository_url.rstrip("/").removesuffix(".git")
|
||||||
parts = repo_url.split("/")
|
parts = repo_url.split("/")
|
||||||
if len(parts) < 2:
|
if len(parts) < 2:
|
||||||
@@ -225,7 +221,7 @@ def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
|
|||||||
return repo_url, parts[-2], parts[-1]
|
return repo_url, parts[-2], parts[-1]
|
||||||
|
|
||||||
|
|
||||||
def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
|
def load_manifest_json(manifest_path: Path) -> Optional[Dict[str, Any]]:
|
||||||
if not manifest_path.exists():
|
if not manifest_path.exists():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -246,9 +242,9 @@ def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def iter_plugin_directories() -> list[Path]:
|
def iter_plugin_directories() -> List[Path]:
|
||||||
plugins_dir = get_plugins_dir()
|
plugins_dir = get_plugins_dir()
|
||||||
plugin_directories: list[Path] = []
|
plugin_directories: List[Path] = []
|
||||||
for path in plugins_dir.iterdir():
|
for path in plugins_dir.iterdir():
|
||||||
safe_path = _resolve_safe_plugin_directory(path, plugins_dir, strict=False)
|
safe_path = _resolve_safe_plugin_directory(path, plugins_dir, strict=False)
|
||||||
if safe_path is not None:
|
if safe_path is not None:
|
||||||
@@ -286,4 +282,4 @@ def remove_tree(path: Path) -> None:
|
|||||||
os.chmod(target_path, stat.S_IWRITE)
|
os.chmod(target_path, stat.S_IWRITE)
|
||||||
func(target_path)
|
func(target_path)
|
||||||
|
|
||||||
shutil.rmtree(path, onerror=remove_readonly)
|
shutil.rmtree(path, onerror=remove_readonly)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""统计数据 API 路由"""
|
"""统计数据 API 路由"""
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -56,10 +56,10 @@ class DashboardData(BaseModel):
|
|||||||
"""仪表盘数据"""
|
"""仪表盘数据"""
|
||||||
|
|
||||||
summary: StatisticsSummary
|
summary: StatisticsSummary
|
||||||
model_stats: list[ModelStatistics]
|
model_stats: List[ModelStatistics]
|
||||||
hourly_data: list[TimeSeriesData]
|
hourly_data: List[TimeSeriesData]
|
||||||
daily_data: list[TimeSeriesData]
|
daily_data: List[TimeSeriesData]
|
||||||
recent_activity: list[dict[str, Any]]
|
recent_activity: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/dashboard", response_model=DashboardData)
|
@router.get("/dashboard", response_model=DashboardData)
|
||||||
@@ -168,7 +168,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
|
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||||
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
||||||
# 使用GROUP BY聚合,避免全量加载
|
# 使用GROUP BY聚合,避免全量加载
|
||||||
statement = (
|
statement = (
|
||||||
@@ -181,7 +181,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
|
|||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
rows = session.exec(statement).all()
|
rows = session.exec(statement).all()
|
||||||
|
|
||||||
aggregates: dict[str, dict[str, float | int]] = {}
|
aggregates: Dict[str, Dict[str, float | int]] = {}
|
||||||
for record in rows:
|
for record in rows:
|
||||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||||
if model_name not in aggregates:
|
if model_name not in aggregates:
|
||||||
@@ -200,7 +200,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
|
|||||||
bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost)
|
bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost)
|
||||||
bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1
|
bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1
|
||||||
|
|
||||||
result: list[ModelStatistics] = []
|
result: List[ModelStatistics] = []
|
||||||
for model_name, bucket in sorted(
|
for model_name, bucket in sorted(
|
||||||
aggregates.items(),
|
aggregates.items(),
|
||||||
key=lambda item: float(item[1]["request_count"]),
|
key=lambda item: float(item[1]["request_count"]),
|
||||||
@@ -221,7 +221,7 @@ async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
|
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||||
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
||||||
# SQLite的日期时间函数进行小时分组
|
# SQLite的日期时间函数进行小时分组
|
||||||
# 使用strftime将timestamp格式化为小时级别
|
# 使用strftime将timestamp格式化为小时级别
|
||||||
@@ -265,7 +265,7 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> li
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
|
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||||
"""获取日级统计数据(优化:使用数据库聚合)"""
|
"""获取日级统计数据(优化:使用数据库聚合)"""
|
||||||
# 使用strftime按日期分组
|
# 使用strftime按日期分组
|
||||||
day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp))
|
day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp))
|
||||||
@@ -308,7 +308,7 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> lis
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]:
|
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||||
"""获取最近活动"""
|
"""获取最近活动"""
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit)
|
statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit)
|
||||||
|
|||||||
@@ -10,8 +10,9 @@ from datetime import datetime
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from src.config.config import MMC_VERSION
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import MMC_VERSION
|
||||||
from src.webui.dependencies import require_auth
|
from src.webui.dependencies import require_auth
|
||||||
|
|
||||||
router = APIRouter(prefix="/system", tags=["system"], dependencies=[Depends(require_auth)])
|
router = APIRouter(prefix="/system", tags=["system"], dependencies=[Depends(require_auth)])
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from .logs import router as logs_router
|
|
||||||
from .auth import router as ws_auth_router
|
from .auth import router as ws_auth_router
|
||||||
|
from .logs import router as logs_router
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"logs_router",
|
"logs_router",
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
"""WebSocket 认证模块。"""
|
"""WebSocket 认证模块。"""
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Cookie
|
|
||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Cookie
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.webui.core import get_token_manager
|
from src.webui.core import get_token_manager
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ router = APIRouter()
|
|||||||
|
|
||||||
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
||||||
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
||||||
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
|
_ws_temp_tokens: Dict[str, Tuple[float, str]] = {}
|
||||||
_WS_TOKEN_EXPIRE_SECONDS = 60
|
_WS_TOKEN_EXPIRE_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,24 +2,26 @@
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.webui.core import (
|
from src.webui.core import (
|
||||||
clear_auth_cookie,
|
|
||||||
check_auth_rate_limit,
|
check_auth_rate_limit,
|
||||||
|
clear_auth_cookie,
|
||||||
get_rate_limiter,
|
get_rate_limiter,
|
||||||
get_token_manager,
|
get_token_manager,
|
||||||
set_auth_cookie,
|
set_auth_cookie,
|
||||||
)
|
)
|
||||||
from src.webui.dependencies import require_auth, verify_token_optional
|
from src.webui.dependencies import require_auth, verify_token_optional
|
||||||
from src.webui.routers.config import router as config_router
|
from src.webui.routers.config import router as config_router
|
||||||
from src.webui.routers.statistics import router as statistics_router
|
from src.webui.routers.emoji import router as emoji_router
|
||||||
from src.webui.routers.person import router as person_router
|
|
||||||
from src.webui.routers.expression import router as expression_router
|
from src.webui.routers.expression import router as expression_router
|
||||||
from src.webui.routers.jargon import router as jargon_router
|
from src.webui.routers.jargon import router as jargon_router
|
||||||
from src.webui.routers.emoji import router as emoji_router
|
|
||||||
from src.webui.routers.plugin import get_progress_router, router as plugin_router
|
|
||||||
from src.webui.routers.system import router as system_router
|
|
||||||
from src.webui.routers.model import router as model_router
|
from src.webui.routers.model import router as model_router
|
||||||
|
from src.webui.routers.person import router as person_router
|
||||||
|
from src.webui.routers.plugin import get_progress_router
|
||||||
|
from src.webui.routers.plugin import router as plugin_router
|
||||||
|
from src.webui.routers.statistics import router as statistics_router
|
||||||
|
from src.webui.routers.system import router as system_router
|
||||||
from src.webui.routers.websocket.auth import router as ws_auth_router
|
from src.webui.routers.websocket.auth import router as ws_auth_router
|
||||||
|
|
||||||
logger = get_logger("webui.api")
|
logger = get_logger("webui.api")
|
||||||
|
|||||||
@@ -2,62 +2,62 @@
|
|||||||
|
|
||||||
# Auth schemas
|
# Auth schemas
|
||||||
from .auth import (
|
from .auth import (
|
||||||
TokenVerifyRequest,
|
CompleteSetupResponse,
|
||||||
TokenVerifyResponse,
|
FirstSetupStatusResponse,
|
||||||
|
ResetSetupResponse,
|
||||||
|
TokenRegenerateResponse,
|
||||||
TokenUpdateRequest,
|
TokenUpdateRequest,
|
||||||
TokenUpdateResponse,
|
TokenUpdateResponse,
|
||||||
TokenRegenerateResponse,
|
TokenVerifyRequest,
|
||||||
FirstSetupStatusResponse,
|
TokenVerifyResponse,
|
||||||
CompleteSetupResponse,
|
|
||||||
ResetSetupResponse,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Statistics schemas
|
# Chat schemas
|
||||||
from .statistics import (
|
from .chat import (
|
||||||
StatisticsSummary,
|
ChatHistoryMessage,
|
||||||
ModelStatistics,
|
VirtualIdentityConfig,
|
||||||
TimeSeriesData,
|
|
||||||
DashboardData,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Emoji schemas
|
# Emoji schemas
|
||||||
from .emoji import (
|
from .emoji import (
|
||||||
EmojiResponse,
|
|
||||||
EmojiListResponse,
|
|
||||||
EmojiDetailResponse,
|
|
||||||
EmojiUpdateRequest,
|
|
||||||
EmojiUpdateResponse,
|
|
||||||
EmojiDeleteResponse,
|
|
||||||
BatchDeleteRequest,
|
BatchDeleteRequest,
|
||||||
BatchDeleteResponse,
|
BatchDeleteResponse,
|
||||||
|
EmojiDeleteResponse,
|
||||||
|
EmojiDetailResponse,
|
||||||
|
EmojiListResponse,
|
||||||
|
EmojiResponse,
|
||||||
|
EmojiUpdateRequest,
|
||||||
|
EmojiUpdateResponse,
|
||||||
EmojiUploadResponse,
|
EmojiUploadResponse,
|
||||||
ThumbnailCacheStatsResponse,
|
ThumbnailCacheStatsResponse,
|
||||||
ThumbnailCleanupResponse,
|
ThumbnailCleanupResponse,
|
||||||
ThumbnailPreheatResponse,
|
ThumbnailPreheatResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Chat schemas
|
|
||||||
from .chat import (
|
|
||||||
VirtualIdentityConfig,
|
|
||||||
ChatHistoryMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Plugin schemas
|
# Plugin schemas
|
||||||
from .plugin import (
|
from .plugin import (
|
||||||
FetchRawFileRequest,
|
AddMirrorRequest,
|
||||||
FetchRawFileResponse,
|
AvailableMirrorsResponse,
|
||||||
CloneRepositoryRequest,
|
CloneRepositoryRequest,
|
||||||
CloneRepositoryResponse,
|
CloneRepositoryResponse,
|
||||||
MirrorConfigResponse,
|
FetchRawFileRequest,
|
||||||
AvailableMirrorsResponse,
|
FetchRawFileResponse,
|
||||||
AddMirrorRequest,
|
|
||||||
UpdateMirrorRequest,
|
|
||||||
GitStatusResponse,
|
GitStatusResponse,
|
||||||
InstallPluginRequest,
|
InstallPluginRequest,
|
||||||
VersionResponse,
|
MirrorConfigResponse,
|
||||||
UninstallPluginRequest,
|
UninstallPluginRequest,
|
||||||
UpdatePluginRequest,
|
UpdateMirrorRequest,
|
||||||
UpdatePluginConfigRequest,
|
UpdatePluginConfigRequest,
|
||||||
|
UpdatePluginRequest,
|
||||||
|
VersionResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Statistics schemas
|
||||||
|
from .statistics import (
|
||||||
|
DashboardData,
|
||||||
|
ModelStatistics,
|
||||||
|
StatisticsSummary,
|
||||||
|
TimeSeriesData,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class VirtualIdentityConfig(BaseModel):
|
class VirtualIdentityConfig(BaseModel):
|
||||||
"""虚拟身份配置"""
|
"""虚拟身份配置"""
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
|
|
||||||
class EmojiResponse(BaseModel):
|
class EmojiResponse(BaseModel):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
|
|
||||||
|
|
||||||
class FetchRawFileRequest(BaseModel):
|
class FetchRawFileRequest(BaseModel):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, Any, List
|
|
||||||
|
|
||||||
|
|
||||||
class StatisticsSummary(BaseModel):
|
class StatisticsSummary(BaseModel):
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||||
|
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
from enum import Enum
|
|
||||||
import httpx
|
|
||||||
import json
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import subprocess
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
import subprocess
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.webui.utils.network_security import validate_public_url
|
from src.webui.utils.network_security import validate_public_url
|
||||||
|
|
||||||
@@ -188,10 +190,8 @@ class GitMirrorConfig:
|
|||||||
|
|
||||||
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
|
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""根据 ID 获取镜像源"""
|
"""根据 ID 获取镜像源"""
|
||||||
for mirror in self.mirrors:
|
matched_mirror = next((mirror for mirror in self.mirrors if mirror.get("id") == mirror_id), None)
|
||||||
if mirror.get("id") == mirror_id:
|
return matched_mirror.copy() if matched_mirror is not None else None
|
||||||
return mirror.copy()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def add_mirror(
|
def add_mirror(
|
||||||
self,
|
self,
|
||||||
@@ -332,8 +332,8 @@ class GitMirrorService:
|
|||||||
- path: str - Git 可执行文件路径(如果已安装)
|
- path: str - Git 可执行文件路径(如果已安装)
|
||||||
- error: str - 错误信息(如果未安装或检测失败)
|
- error: str - 错误信息(如果未安装或检测失败)
|
||||||
"""
|
"""
|
||||||
import subprocess
|
|
||||||
import shutil
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 查找 git 可执行文件路径
|
# 查找 git 可执行文件路径
|
||||||
@@ -409,8 +409,7 @@ class GitMirrorService:
|
|||||||
# 确定要使用的镜像源列表
|
# 确定要使用的镜像源列表
|
||||||
if mirror_id:
|
if mirror_id:
|
||||||
# 使用指定的镜像源
|
# 使用指定的镜像源
|
||||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
if (mirror := self.config.get_mirror_by_id(mirror_id)) is None:
|
||||||
if not mirror:
|
|
||||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||||
mirrors_to_try = [mirror]
|
mirrors_to_try = [mirror]
|
||||||
else:
|
else:
|
||||||
@@ -477,7 +476,13 @@ class GitMirrorService:
|
|||||||
try:
|
try:
|
||||||
raw_prefix = _validate_mirror_prefix(mirror["raw_prefix"], "镜像 Raw 前缀")
|
raw_prefix = _validate_mirror_prefix(mirror["raw_prefix"], "镜像 Raw 前缀")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"mirror_used": mirror.get("id"),
|
||||||
|
"attempts": 0,
|
||||||
|
"status_code": 400,
|
||||||
|
}
|
||||||
|
|
||||||
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
|
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
|
||||||
|
|
||||||
@@ -566,8 +571,7 @@ class GitMirrorService:
|
|||||||
# 确定要使用的镜像源列表
|
# 确定要使用的镜像源列表
|
||||||
if mirror_id:
|
if mirror_id:
|
||||||
# 使用指定的镜像源
|
# 使用指定的镜像源
|
||||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
if (mirror := self.config.get_mirror_by_id(mirror_id)) is None:
|
||||||
if not mirror:
|
|
||||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||||
mirrors_to_try = [mirror]
|
mirrors_to_try = [mirror]
|
||||||
else:
|
else:
|
||||||
@@ -597,7 +601,13 @@ class GitMirrorService:
|
|||||||
try:
|
try:
|
||||||
clone_prefix = _validate_mirror_prefix(mirror["clone_prefix"], "镜像克隆前缀")
|
clone_prefix = _validate_mirror_prefix(mirror["clone_prefix"], "镜像克隆前缀")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"success": False, "error": str(e), "mirror_used": mirror.get("id"), "attempts": 0, "status_code": 400}
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"mirror_used": mirror.get("id"),
|
||||||
|
"attempts": 0,
|
||||||
|
"status_code": 400,
|
||||||
|
}
|
||||||
|
|
||||||
url = f"{clone_prefix}/{owner}/{repo}.git"
|
url = f"{clone_prefix}/{owner}/{repo}.git"
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,16 @@
|
|||||||
from typing import Iterable
|
|
||||||
|
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import socket
|
import socket
|
||||||
|
from typing import Iterable, Set
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
def _resolve_ip_addresses(hostname: str, port: int) -> set[ipaddress.IPv4Address | ipaddress.IPv6Address]:
|
def _resolve_ip_addresses(hostname: str, port: int) -> Set[ipaddress.IPv4Address | ipaddress.IPv6Address]:
|
||||||
try:
|
try:
|
||||||
address_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
|
address_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
|
||||||
except socket.gaierror as exc:
|
except socket.gaierror as exc:
|
||||||
raise ValueError(f"无法解析主机名: {hostname}") from exc
|
raise ValueError(f"无法解析主机名: {hostname}") from exc
|
||||||
|
|
||||||
resolved_addresses: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
|
resolved_addresses: Set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
|
||||||
for _, _, _, _, sockaddr in address_infos:
|
for _, _, _, _, sockaddr in address_infos:
|
||||||
host_address = sockaddr[0]
|
host_address = sockaddr[0]
|
||||||
if not isinstance(host_address, str):
|
if not isinstance(host_address, str):
|
||||||
@@ -73,4 +71,4 @@ def validate_public_url(url: str, allowed_schemes: Iterable[str] = ("https",)) -
|
|||||||
if _is_forbidden_ip_address(address):
|
if _is_forbidden_ip_address(address):
|
||||||
raise ValueError(f"禁止访问非公网地址: {address}")
|
raise ValueError(f"禁止访问非公网地址: {address}")
|
||||||
|
|
||||||
return normalized_url
|
return normalized_url
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ TOML 工具函数
|
|||||||
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
|
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
import re
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import tomlkit
|
import tomlkit
|
||||||
from tomlkit.items import AoT, Array, Table
|
from tomlkit.items import AoT, Array, Table
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
|
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
|
||||||
|
|
||||||
from uvicorn import Config, Server as UvicornServer
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from uvicorn import Config
|
||||||
|
from uvicorn import Server as UvicornServer
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
|
from src.common.utils.port_checker import assert_port_available, is_port_conflict_error, log_port_conflict
|
||||||
from src.config.config import config_manager
|
from src.config.config import config_manager
|
||||||
|
|||||||
Reference in New Issue
Block a user