WebUI 前端 & 后端超级大重构
This commit is contained in:
221
src/webui/routers/plugin/support.py
Normal file
221
src/webui/routers/plugin/support.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, cast, get_origin
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.config_types import ConfigField
|
||||
from src.webui.core import get_token_manager
|
||||
|
||||
logger = get_logger("webui.plugin_routes")
|
||||
|
||||
|
||||
def require_plugin_token(maibot_session: Optional[str]) -> str:
|
||||
token_manager = get_token_manager()
|
||||
if not maibot_session or not token_manager.verify_token(maibot_session):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
return maibot_session
|
||||
|
||||
|
||||
def validate_safe_path(user_path: str, base_path: Path) -> Path:
|
||||
base_resolved = base_path.resolve()
|
||||
if any(pattern in user_path for pattern in ["..", "\x00"]):
|
||||
logger.warning(f"检测到可疑路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="路径包含非法字符")
|
||||
|
||||
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
|
||||
logger.warning(f"检测到绝对路径: {user_path}")
|
||||
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
|
||||
|
||||
target_path = (base_path / user_path).resolve()
|
||||
try:
|
||||
target_path.relative_to(base_resolved)
|
||||
except ValueError as e:
|
||||
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
|
||||
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
|
||||
|
||||
return target_path
|
||||
|
||||
|
||||
def validate_plugin_id(plugin_id: str) -> str:
|
||||
if not plugin_id or not plugin_id.strip():
|
||||
logger.warning("非法插件 ID: 空字符串")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
|
||||
|
||||
for pattern in ["/", "\\", "\x00", "..", "\n", "\r", "\t"]:
|
||||
if pattern in plugin_id:
|
||||
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
|
||||
|
||||
if plugin_id.startswith(".") or plugin_id.endswith("."):
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
|
||||
|
||||
if plugin_id in {".", ".."}:
|
||||
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
|
||||
|
||||
return plugin_id
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||
parts = base_version.split(".")
|
||||
if len(parts) < 3:
|
||||
parts.extend(["0"] * (3 - len(parts)))
|
||||
|
||||
try:
|
||||
return int(parts[0]), int(parts[1]), int(parts[2])
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"无法解析版本号: {version_str},返回默认值 (0, 0, 0)")
|
||||
return 0, 0, 0
|
||||
|
||||
|
||||
def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
|
||||
for key, value in src.items():
|
||||
if key in dst and isinstance(dst[key], dict) and isinstance(value, dict):
|
||||
deep_merge(dst[key], value)
|
||||
else:
|
||||
dst[key] = value
|
||||
|
||||
|
||||
def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {}
|
||||
dotted_items: list[tuple[str, Any]] = []
|
||||
|
||||
for key, value in obj.items():
|
||||
if "." in key:
|
||||
dotted_items.append((key, value))
|
||||
else:
|
||||
result[key] = normalize_dotted_keys(value) if isinstance(value, dict) else value
|
||||
|
||||
for dotted_key, value in dotted_items:
|
||||
normalized_value = normalize_dotted_keys(value) if isinstance(value, dict) else value
|
||||
parts = dotted_key.split(".")
|
||||
if "" in parts:
|
||||
logger.warning(f"键路径包含空段: '{dotted_key}'")
|
||||
parts = [part for part in parts if part]
|
||||
if not parts:
|
||||
logger.warning(f"忽略空键路径: '{dotted_key}'")
|
||||
continue
|
||||
|
||||
current = result
|
||||
for index, part in enumerate(parts[:-1]):
|
||||
if part in current and not isinstance(current[part], dict):
|
||||
path_ctx = ".".join(parts[: index + 1])
|
||||
logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key} (路径 {path_ctx})")
|
||||
current[part] = {}
|
||||
current = current.setdefault(part, {})
|
||||
|
||||
last_part = parts[-1]
|
||||
if last_part in current and isinstance(current[last_part], dict) and isinstance(normalized_value, dict):
|
||||
deep_merge(current[last_part], normalized_value)
|
||||
else:
|
||||
current[last_part] = normalized_value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None:
|
||||
def is_list_type(tp: Any) -> bool:
|
||||
origin = get_origin(tp)
|
||||
return tp is list or origin is list
|
||||
|
||||
for key, schema_val in schema_part.items():
|
||||
if key not in config_part:
|
||||
continue
|
||||
value = config_part[key]
|
||||
if isinstance(schema_val, ConfigField):
|
||||
if is_list_type(schema_val.type) and isinstance(value, str):
|
||||
config_part[key] = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(schema_val, dict) and isinstance(value, dict):
|
||||
coerce_types(schema_val, value)
|
||||
|
||||
|
||||
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
manager = get_plugin_runtime_manager()
|
||||
for supervisor in manager.supervisors:
|
||||
registered = supervisor._registered_plugins.get(plugin_id)
|
||||
if registered is not None:
|
||||
return registered
|
||||
return None
|
||||
|
||||
|
||||
def get_plugins_dir() -> Path:
|
||||
plugins_dir = Path("plugins").resolve()
|
||||
plugins_dir.mkdir(exist_ok=True)
|
||||
return plugins_dir
|
||||
|
||||
|
||||
def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
|
||||
plugins_dir = get_plugins_dir()
|
||||
folder_name = plugin_id.replace(".", "_")
|
||||
return validate_safe_path(folder_name, plugins_dir), validate_safe_path(plugin_id, plugins_dir)
|
||||
|
||||
|
||||
def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
|
||||
new_format_path, old_format_path = get_plugin_candidate_paths(plugin_id)
|
||||
if new_format_path.exists():
|
||||
return new_format_path
|
||||
return old_format_path if old_format_path.exists() else None
|
||||
|
||||
|
||||
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
|
||||
repo_url = repository_url.rstrip("/").removesuffix(".git")
|
||||
parts = repo_url.split("/")
|
||||
if len(parts) < 2:
|
||||
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
||||
return repo_url, parts[-2], parts[-1]
|
||||
|
||||
|
||||
def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
|
||||
if not manifest_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as file_obj:
|
||||
return cast(dict[str, Any], json.load(file_obj))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def iter_plugin_directories() -> list[Path]:
|
||||
return [path for path in get_plugins_dir().iterdir() if path.is_dir()]
|
||||
|
||||
|
||||
def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]:
|
||||
for plugin_path in iter_plugin_directories():
|
||||
manifest = load_manifest_json(plugin_path / "_manifest.json")
|
||||
if manifest is not None and (manifest.get("id") == plugin_id or plugin_path.name == plugin_id):
|
||||
return plugin_path
|
||||
return None
|
||||
|
||||
|
||||
def backup_file(file_path: Path, action: str, move_file: bool = False) -> Optional[Path]:
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
backup_name = f"{file_path.name}.{action}.{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
backup_path = file_path.parent / backup_name
|
||||
if move_file:
|
||||
shutil.move(file_path, backup_path)
|
||||
else:
|
||||
shutil.copy(file_path, backup_path)
|
||||
return backup_path
|
||||
|
||||
|
||||
def remove_tree(path: Path) -> None:
|
||||
def remove_readonly(func: Any, target_path: str, _: Any) -> None:
|
||||
os.chmod(target_path, stat.S_IWRITE)
|
||||
func(target_path)
|
||||
|
||||
shutil.rmtree(path, onerror=remove_readonly)
|
||||
Reference in New Issue
Block a user