Files
mai-bot/src/webui/routers/plugin/support.py

289 lines
11 KiB
Python

from datetime import datetime
from pathlib import Path
from typing import Any, Optional, cast, get_origin
import json
import os
import re
import shutil
import stat
from fastapi import HTTPException
from src.common.logger import get_logger
from src.core.config_types import ConfigField
from src.webui.core import get_token_manager
# Shared logger for this module, emitting under the "webui.plugin_routes" name.
logger = get_logger("webui.plugin_routes")
def require_plugin_token(maibot_session: Optional[str]) -> str:
    """Return ``maibot_session`` if it is a valid access token, otherwise reject with 401.

    The token manager is always fetched; ``verify_token`` is only consulted when a
    non-empty session value was supplied.

    Raises:
        HTTPException: 401 when the session is missing/empty or fails verification.
    """
    manager = get_token_manager()
    if maibot_session and manager.verify_token(maibot_session):
        return maibot_session
    raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
def validate_safe_path(user_path: str, base_path: Path) -> Path:
    """Resolve a user-supplied relative path under ``base_path`` and reject escapes.

    Raises ``HTTPException`` (400) when the path contains ``..`` or a NUL byte, when
    it looks absolute (leading slash or backslash, or a drive letter such as ``C:``),
    or when its resolved location is not inside the resolved ``base_path``.

    Returns:
        The fully resolved target path.
    """
    root = base_path.resolve()

    if "\x00" in user_path or ".." in user_path:
        logger.warning(f"检测到可疑路径: {user_path}")
        raise HTTPException(status_code=400, detail="路径包含非法字符")

    # Second character ':' means a Windows drive prefix like "C:".
    has_drive_prefix = user_path[1:2] == ":"
    if user_path.startswith(("/", "\\")) or has_drive_prefix:
        logger.warning(f"检测到绝对路径: {user_path}")
        raise HTTPException(status_code=400, detail="不允许使用绝对路径")

    resolved = (base_path / user_path).resolve()
    try:
        # Raises ValueError when the resolved path lies outside the root.
        resolved.relative_to(root)
    except ValueError as err:
        logger.warning(f"路径遍历攻击检测: {user_path} -> {resolved}")
        raise HTTPException(status_code=400, detail="路径超出允许范围") from err
    return resolved
def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict: bool) -> Optional[Path]:
    """Resolve ``plugin_path`` and confirm it is a real directory inside ``plugins_dir``.

    Args:
        plugin_path: Candidate plugin directory.
        plugins_dir: Root directory that plugin directories must live under.
        strict: When ``True``, unsafe paths raise ``HTTPException``; when ``False``,
            they are logged and skipped (``None`` is returned).

    Returns:
        The resolved directory path, or ``None`` if the path is not a directory
        (or, in non-strict mode, if it is unsafe).

    Raises:
        HTTPException: 400 in strict mode when ``plugin_path`` is a symlink or
            resolves outside ``plugins_dir``.
    """
    try:
        if plugin_path.is_symlink():
            raise HTTPException(status_code=400, detail="插件目录不能是符号链接")
        resolved_plugins_dir = plugins_dir.resolve()
        resolved_plugin_path = plugin_path.resolve()
        # Raises ValueError when the resolved path escapes the plugins directory.
        resolved_plugin_path.relative_to(resolved_plugins_dir)
        if not resolved_plugin_path.is_dir():
            return None
        return resolved_plugin_path
    except HTTPException:
        if strict:
            raise
        logger.warning(f"已跳过不安全的插件目录: {plugin_path}")
        return None
    except (OSError, RuntimeError, ValueError) as exc:
        if strict:
            # Chain the original error so the root cause is kept in tracebacks.
            raise HTTPException(status_code=400, detail="插件目录超出允许范围") from exc
        logger.warning(f"已跳过越界的插件目录: {plugin_path}")
        return None
def resolve_plugin_file_path(plugin_path: Path, relative_path: str, allow_missing: bool = True) -> Path:
    """Resolve ``relative_path`` inside a plugin directory, refusing symlinks and escapes.

    Args:
        plugin_path: The plugin's root directory.
        relative_path: File path relative to ``plugin_path``.
        allow_missing: When ``False``, a target that does not exist is rejected.

    Returns:
        The fully resolved target path, which lies under the resolved plugin root.

    Raises:
        HTTPException: 400 if the target is a symlink (including a dangling one) or
            resolves outside the plugin root; 404 if ``allow_missing`` is ``False``
            and the target does not exist.
    """
    plugin_root = plugin_path.resolve()
    target_path = plugin_root / relative_path
    # Check the link itself: ``exists()`` follows links, so guarding with it would
    # let a dangling symlink slip through.
    if target_path.is_symlink():
        raise HTTPException(status_code=400, detail=f"插件文件不能是符号链接: {relative_path}")
    try:
        resolved_target_path = target_path.resolve()
        # Raises ValueError when the resolved target escapes the plugin root.
        resolved_target_path.relative_to(plugin_root)
    except (OSError, RuntimeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"插件文件超出允许范围: {relative_path}") from e
    if not allow_missing and not resolved_target_path.exists():
        raise HTTPException(status_code=404, detail=f"插件文件不存在: {relative_path}")
    return resolved_target_path
def validate_plugin_id(plugin_id: str) -> str:
    """Validate a plugin ID, rejecting empty values and path-like or control characters.

    Rejected IDs:
      * empty or whitespace-only;
      * containing ``/``, ``\\``, NUL, ``..``, newline, carriage return or tab;
      * starting or ending with ``.``.
    Together, the last two rules already exclude the special names ``.`` and ``..``.

    Returns:
        ``plugin_id`` unchanged.

    Raises:
        HTTPException: 400 for any rejected ID.
    """
    if not plugin_id or not plugin_id.strip():
        logger.warning("非法插件 ID: 空字符串")
        raise HTTPException(status_code=400, detail="插件 ID 不能为空")
    forbidden_fragments = ("/", "\\", "\x00", "..", "\n", "\r", "\t")
    if any(fragment in plugin_id for fragment in forbidden_fragments):
        logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
        raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
    if plugin_id.startswith(".") or plugin_id.endswith("."):
        logger.warning(f"非法插件 ID: {plugin_id}")
        raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
    return plugin_id
def parse_version(version_str: str) -> tuple[int, int, int]:
    """Parse a version string into a ``(major, minor, patch)`` tuple.

    Handling rules:
      * A pre-release tag introduced by ``-`` or ``.`` (snapshot, dev, alpha, beta, rc;
        case-insensitive) is dropped.
      * SemVer build metadata (``+...``) and a leading ``v``/``V`` are dropped.
      * Missing components default to 0; components beyond the third are ignored.

    Returns:
        The parsed triple, or ``(0, 0, 0)`` (with a warning) if it cannot be parsed.
    """
    base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
    # "1.2.3+build.5" -> "1.2.3"; " v1.2.3 " -> "1.2.3"
    base_version = base_version.split("+", 1)[0].strip()
    if base_version[:1] in ("v", "V"):
        base_version = base_version[1:]
    parts = base_version.split(".")
    # Pad to three components; a non-positive multiplier yields an empty list.
    parts.extend(["0"] * (3 - len(parts)))
    try:
        return int(parts[0]), int(parts[1]), int(parts[2])
    except ValueError:
        logger.warning(f"无法解析版本号: {version_str},返回默认值 (0, 0, 0)")
        return 0, 0, 0
def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
    """Recursively merge ``src`` into ``dst`` in place.

    When both sides hold a dict under the same key, the dicts are merged
    recursively. Otherwise the value from ``src`` replaces (or adds) the entry in
    ``dst``. Values from ``src`` are inserted by reference, not copied.
    """
    for key, incoming in src.items():
        existing = dst.get(key)
        if isinstance(existing, dict) and isinstance(incoming, dict):
            deep_merge(existing, incoming)
            continue
        dst[key] = incoming
def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
    """Expand dotted keys such as ``"a.b.c"`` into nested dictionaries, recursively.

    Plain keys are copied first, with their dict values normalized recursively.
    Dotted keys are then expanded in their original order on top of that result.

    Handling rules:
      * Empty path segments (e.g. ``"a..b"``) are dropped with a warning.
      * A path that is entirely empty is ignored with a warning.
      * A non-dict value sitting on an intermediate segment is replaced by a dict,
        with a warning.
      * A dict value landing on an existing dict is merged into it via ``deep_merge``.
      * Any other value overwrites the leaf.

    Args:
        obj: Mapping that may contain dotted keys. The mapping itself is not mutated.

    Returns:
        A new dictionary with every dotted key expanded.
    """
    def _normalized(value: Any) -> Any:
        return normalize_dotted_keys(value) if isinstance(value, dict) else value

    result: dict[str, Any] = {key: _normalized(value) for key, value in obj.items() if "." not in key}
    pending = [(key, value) for key, value in obj.items() if "." in key]

    for dotted_key, raw_value in pending:
        new_value = _normalized(raw_value)
        segments = dotted_key.split(".")
        if "" in segments:
            logger.warning(f"键路径包含空段: '{dotted_key}'")
            segments = [segment for segment in segments if segment]
        if not segments:
            logger.warning(f"忽略空键路径: '{dotted_key}'")
            continue

        *parents, leaf = segments
        node = result
        for depth, segment in enumerate(parents, start=1):
            if segment not in node:
                node[segment] = {}
            elif not isinstance(node[segment], dict):
                path_ctx = ".".join(segments[:depth])
                logger.warning(f"键冲突:{segment} 已存在且非字典,覆盖为字典以展开 {dotted_key} (路径 {path_ctx})")
                node[segment] = {}
            node = node[segment]

        existing_leaf = node.get(leaf)
        if isinstance(existing_leaf, dict) and isinstance(new_value, dict):
            deep_merge(existing_leaf, new_value)
        else:
            node[leaf] = new_value
    return result
def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None:
    """Turn comma-separated strings into lists wherever the schema declares a list type.

    Walks ``schema_part`` alongside ``config_part`` and mutates ``config_part`` in place:

      * For a ``ConfigField`` whose ``type`` is ``list`` or a parametrized ``list[...]``,
        a string value such as ``"a, b,,c"`` becomes ``["a", "b", "c"]``. Items are
        stripped and blank items are dropped.
      * For a nested schema dict paired with a nested config dict, the walk recurses.

    Keys absent from ``config_part`` are skipped.
    """
    def _declares_list(annotation: Any) -> bool:
        return annotation is list or get_origin(annotation) is list

    for key, schema_val in schema_part.items():
        if key not in config_part:
            continue
        current = config_part[key]
        if isinstance(schema_val, ConfigField):
            if _declares_list(schema_val.type) and isinstance(current, str):
                pieces = (piece.strip() for piece in current.split(","))
                config_part[key] = [piece for piece in pieces if piece]
        elif isinstance(schema_val, dict) and isinstance(current, dict):
            coerce_types(schema_val, current)
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
    """Return the registered plugin entry for ``plugin_id`` from the runtime supervisors.

    Supervisors are searched in order, and the first non-``None`` entry is returned.
    Returns ``None`` when no supervisor has the plugin registered.

    NOTE(review): this reads each supervisor's private ``_registered_plugins`` mapping.
    """
    # Imported lazily; presumably to avoid an import cycle at module load (confirm).
    from src.plugin_runtime.integration import get_plugin_runtime_manager

    supervisors = get_plugin_runtime_manager().supervisors
    lookups = (supervisor._registered_plugins.get(plugin_id) for supervisor in supervisors)
    return next((entry for entry in lookups if entry is not None), None)
def get_plugins_dir() -> Path:
    """Return the absolute ``plugins`` directory under the current working directory.

    The directory itself is created if missing; parent directories are not.
    """
    directory = (Path.cwd() / "plugins").resolve()
    directory.mkdir(exist_ok=True)
    return directory
def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
    """Return the two directory paths a plugin with ``plugin_id`` may be installed under.

    The first candidate replaces every ``.`` in the ID with ``_``; the second uses the
    ID verbatim. Both are validated with ``validate_safe_path`` against the plugins
    directory, so an unsafe ID raises ``HTTPException`` (400).
    """
    base_dir = get_plugins_dir()
    underscored_path = validate_safe_path(plugin_id.replace(".", "_"), base_dir)
    verbatim_path = validate_safe_path(plugin_id, base_dir)
    return underscored_path, verbatim_path
def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
    """Locate the installed directory of ``plugin_id``.

    The underscore-named candidate is tried first, then the verbatim-ID candidate.
    The first one that exists is validated strictly; symlinks or out-of-bounds paths
    raise ``HTTPException``. Validation returns ``None`` if that path is not a
    directory.

    Returns:
        The resolved plugin directory, or ``None`` if neither candidate exists.
    """
    candidates = get_plugin_candidate_paths(plugin_id)
    plugins_dir = get_plugins_dir()
    found = next((candidate for candidate in candidates if candidate.exists()), None)
    if found is None:
        return None
    return _resolve_safe_plugin_directory(found, plugins_dir, strict=True)
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
    """Split a repository URL into ``(normalized_url, owner, repo)``.

    Surrounding whitespace, trailing slashes and a trailing ``.git`` are removed.
    ``owner`` and ``repo`` are the last two ``/``-separated segments of the result.

    Raises:
        HTTPException: 400 when the URL has fewer than two segments, or when the
            owner or repo segment is empty (e.g. ``"https://github.com/"``).
    """
    repo_url = repository_url.strip().rstrip("/").removesuffix(".git")
    parts = repo_url.split("/")
    if len(parts) < 2 or not parts[-2] or not parts[-1]:
        raise HTTPException(status_code=400, detail="无效的仓库 URL")
    return repo_url, parts[-2], parts[-1]
def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
    """Load a plugin manifest and return its top-level JSON object.

    Returns ``None`` when the manifest cannot be used:
      * the file is missing;
      * it is a symlink, or resolves outside its own directory;
      * it cannot be read or decoded;
      * its top level is not a JSON object.
    Every case except a missing file is logged. A leading UTF-8 byte-order mark
    is tolerated.
    """
    if not manifest_path.exists():
        return None
    if manifest_path.is_symlink():
        logger.warning(f"已拒绝读取符号链接 manifest: {manifest_path}")
        return None
    try:
        manifest_path.resolve().relative_to(manifest_path.parent.resolve())
    except (OSError, RuntimeError, ValueError):
        logger.warning(f"已拒绝读取越界 manifest: {manifest_path}")
        return None
    try:
        # "utf-8-sig" decodes plain UTF-8 identically but also strips a BOM,
        # which json.load would otherwise reject.
        with open(manifest_path, "r", encoding="utf-8-sig") as file_obj:
            data = json.load(file_obj)
    except (OSError, ValueError, RecursionError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"无法读取 manifest: {manifest_path} ({exc})")
        return None
    if not isinstance(data, dict):
        # Callers use dict methods (e.g. .get), so a list/str/number top level is unusable.
        logger.warning(f"manifest 顶层不是 JSON 对象: {manifest_path}")
        return None
    return data
def iter_plugin_directories() -> list[Path]:
    """Return resolved plugin directories found directly under the plugins directory.

    Each entry goes through the non-strict safety check. Symlinks and out-of-bounds
    paths are logged and skipped; entries that are not directories are skipped
    silently. Order follows ``Path.iterdir``.
    """
    root = get_plugins_dir()
    return [
        resolved
        for entry in root.iterdir()
        if (resolved := _resolve_safe_plugin_directory(entry, root, strict=False)) is not None
    ]
def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]:
    """Find an installed plugin directory by manifest ID or directory name.

    A directory matches when its ``_manifest.json`` loads as a JSON object and
    either the manifest's ``"id"`` equals ``plugin_id`` or the directory name equals
    ``plugin_id``. Directories whose manifest path is unsafe (symlink or out of
    bounds) are logged and skipped, so one broken plugin cannot abort the search.

    Returns:
        The first matching directory, or ``None``.
    """
    for plugin_path in iter_plugin_directories():
        try:
            manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
        except HTTPException:
            logger.warning(f"已跳过 manifest 不安全的插件目录: {plugin_path}")
            continue
        manifest = load_manifest_json(manifest_path)
        # Guard against non-object manifests, which have no .get().
        if not isinstance(manifest, dict):
            continue
        if manifest.get("id") == plugin_id or plugin_path.name == plugin_id:
            return plugin_path
    return None
def backup_file(file_path: Path, action: str, move_file: bool = False) -> Optional[Path]:
    """Create a timestamped backup of ``file_path`` in the same directory.

    The backup is named ``<name>.<action>.<YYYYmmddHHMMSS>``. If that name is already
    taken (e.g. two backups within the same second), a ``.<n>`` counter is appended,
    so an earlier backup is never overwritten.

    Args:
        file_path: File to back up.
        action: Label embedded in the backup name.
        move_file: Move the file to the backup location instead of copying it.

    Returns:
        The backup path, or ``None`` if ``file_path`` does not exist.
    """
    if not file_path.exists():
        return None
    base_name = f"{file_path.name}.{action}.{datetime.now().strftime('%Y%m%d%H%M%S')}"
    backup_path = file_path.parent / base_name
    counter = 1
    # is_symlink() also catches dangling links, which exists() misses and which
    # copy/move would otherwise write through.
    while backup_path.exists() or backup_path.is_symlink():
        backup_path = file_path.parent / f"{base_name}.{counter}"
        counter += 1
    if move_file:
        shutil.move(file_path, backup_path)
    else:
        shutil.copy(file_path, backup_path)
    return backup_path
def remove_tree(path: Path) -> None:
    """Recursively delete the directory ``path``, clearing read-only bits on failure.

    When a removal step fails, the failing entry is made writable and the step is
    retried once. Any error from that retry propagates.

    Raises:
        ValueError: If ``path`` itself is a symlink; symlinked roots are refused outright.
    """
    import sys

    if path.is_symlink():
        raise ValueError(f"拒绝删除符号链接路径: {path}")

    def remove_readonly(func: Any, target_path: str, _: Any) -> None:
        # Make the entry writable, then retry the operation that failed.
        os.chmod(target_path, stat.S_IWRITE)
        func(target_path)

    # ``onerror`` is deprecated since Python 3.12 in favour of ``onexc``. The handler
    # ignores its third argument (exc_info tuple vs. exception), so it fits both.
    if sys.version_info >= (3, 12):
        shutil.rmtree(path, onexc=remove_readonly)
    else:
        shutil.rmtree(path, onerror=remove_readonly)