289 lines
11 KiB
Python
289 lines
11 KiB
Python
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Optional, cast, get_origin
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
import shutil
|
|
import stat
|
|
|
|
from fastapi import HTTPException
|
|
|
|
from src.common.logger import get_logger
|
|
from src.core.config_types import ConfigField
|
|
from src.webui.core import get_token_manager
|
|
|
|
# Module-level logger shared by the plugin route helpers in this file.
logger = get_logger("webui.plugin_routes")
|
|
|
|
|
|
def require_plugin_token(maibot_session: Optional[str]) -> str:
    """Validate the WebUI session token and return it unchanged.

    Args:
        maibot_session: Token taken from the request; may be ``None`` or empty.

    Returns:
        The same token string, once the token manager has accepted it.

    Raises:
        HTTPException: 401 when the token is missing or rejected by the token manager.
    """
    manager = get_token_manager()
    # Short-circuit: an empty/missing token is never handed to the verifier.
    if maibot_session and manager.verify_token(maibot_session):
        return maibot_session
    raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
|
|
|
|
|
def validate_safe_path(user_path: str, base_path: Path) -> Path:
    """Resolve a user-supplied relative path beneath ``base_path``.

    Args:
        user_path: Relative path coming from the client.
        base_path: Directory the result must stay inside.

    Returns:
        The fully resolved target path (it does not need to exist).

    Raises:
        HTTPException: 400 if ``user_path`` contains ``..`` or a NUL byte, looks
            absolute (leading slash/backslash or a ``X:`` drive prefix), or
            resolves to somewhere outside ``base_path``.
    """
    root = base_path.resolve()

    # Cheap textual screening before touching the filesystem.
    if ".." in user_path or "\x00" in user_path:
        logger.warning(f"检测到可疑路径: {user_path}")
        raise HTTPException(status_code=400, detail="路径包含非法字符")

    has_drive_prefix = len(user_path) > 1 and user_path[1] == ":"
    if user_path.startswith(("/", "\\")) or has_drive_prefix:
        logger.warning(f"检测到绝对路径: {user_path}")
        raise HTTPException(status_code=400, detail="不允许使用绝对路径")

    # Final authority: the resolved path (symlinks followed) must live under root.
    candidate = (base_path / user_path).resolve()
    try:
        candidate.relative_to(root)
    except ValueError as exc:
        logger.warning(f"路径遍历攻击检测: {user_path} -> {candidate}")
        raise HTTPException(status_code=400, detail="路径超出允许范围") from exc

    return candidate
|
|
|
|
|
|
def _resolve_safe_plugin_directory(plugin_path: Path, plugins_dir: Path, strict: bool) -> Optional[Path]:
|
|
try:
|
|
if plugin_path.is_symlink():
|
|
raise HTTPException(status_code=400, detail="插件目录不能是符号链接")
|
|
|
|
resolved_plugins_dir = plugins_dir.resolve()
|
|
resolved_plugin_path = plugin_path.resolve()
|
|
resolved_plugin_path.relative_to(resolved_plugins_dir)
|
|
|
|
if not resolved_plugin_path.is_dir():
|
|
return None
|
|
|
|
return resolved_plugin_path
|
|
except HTTPException:
|
|
if strict:
|
|
raise
|
|
logger.warning(f"已跳过不安全的插件目录: {plugin_path}")
|
|
return None
|
|
except (OSError, RuntimeError, ValueError):
|
|
if strict:
|
|
raise HTTPException(status_code=400, detail="插件目录超出允许范围")
|
|
logger.warning(f"已跳过越界的插件目录: {plugin_path}")
|
|
return None
|
|
|
|
|
|
def resolve_plugin_file_path(plugin_path: Path, relative_path: str, allow_missing: bool = True) -> Path:
    """Resolve ``relative_path`` inside a plugin directory, refusing escapes and symlinks.

    Only the final path component is checked for being a symlink. Intermediate
    symlinked directories are tolerated as long as the fully resolved result
    still lies inside the plugin root.

    Args:
        plugin_path: The plugin's directory.
        relative_path: File path relative to the plugin directory.
        allow_missing: If False, a non-existent target is an error.

    Returns:
        The resolved absolute path of the file.

    Raises:
        HTTPException: 400 if the target is a symlink (including a dangling one) or
            resolves outside the plugin directory; 404 if ``allow_missing`` is
            False and the file does not exist.
    """
    plugin_root = plugin_path.resolve()
    target_path = plugin_root / relative_path

    # is_symlink() uses lstat(), so it also detects dangling links. The previous
    # `exists() and is_symlink()` follows the link first and let dangling ones through.
    if target_path.is_symlink():
        raise HTTPException(status_code=400, detail=f"插件文件不能是符号链接: {relative_path}")

    try:
        resolved_target_path = target_path.resolve()
        resolved_target_path.relative_to(plugin_root)
    except (OSError, RuntimeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"插件文件超出允许范围: {relative_path}") from e

    if not allow_missing and not resolved_target_path.exists():
        raise HTTPException(status_code=404, detail=f"插件文件不存在: {relative_path}")

    return resolved_target_path
|
|
|
|
|
|
def validate_plugin_id(plugin_id: str) -> str:
    """Validate that a plugin ID is safe to use as a directory name.

    Args:
        plugin_id: The plugin ID supplied by the client.

    Returns:
        The unchanged ``plugin_id`` when it passes validation.

    Raises:
        HTTPException: 400 if the ID is empty or whitespace-only, contains a path
            separator, NUL, ``..``, or a newline/carriage-return/tab, or starts or
            ends with ``.``. Together these checks also exclude the special names
            ``.`` and ``..``.
    """
    if not plugin_id or not plugin_id.strip():
        logger.warning("非法插件 ID: 空字符串")
        raise HTTPException(status_code=400, detail="插件 ID 不能为空")

    for pattern in ("/", "\\", "\x00", "..", "\n", "\r", "\t"):
        if pattern in plugin_id:
            logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
            raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")

    # This check also covers "." itself. ".." was already rejected by the pattern scan,
    # so a separate special-directory-name check would be unreachable.
    if plugin_id.startswith(".") or plugin_id.endswith("."):
        logger.warning(f"非法插件 ID: {plugin_id}")
        raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")

    return plugin_id
|
|
|
|
|
|
def parse_version(version_str: str) -> tuple[int, int, int]:
    """Parse a version string into a ``(major, minor, patch)`` tuple.

    Accepted forms include ``1.2.3``, ``1.2`` and ``2``. Missing components
    default to 0, and components after the third are ignored. The following
    are removed before parsing:

    * surrounding whitespace and a leading ``v``/``V``;
    * SemVer build metadata (``+...``);
    * a pre-release tag introduced by ``-`` or ``.`` (snapshot/dev/alpha/beta/rc,
      case-insensitive).

    Args:
        version_str: Version string to parse.

    Returns:
        The parsed triple, or ``(0, 0, 0)`` (with a warning) if it cannot be parsed.
    """
    cleaned = version_str.strip()
    if cleaned[:1] in ("v", "V"):
        cleaned = cleaned[1:]
    # Build metadata never participates in version comparison.
    cleaned = cleaned.split("+", 1)[0]

    base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", cleaned, flags=re.IGNORECASE)[0]
    parts = base_version.split(".")
    if len(parts) < 3:
        parts.extend(["0"] * (3 - len(parts)))

    try:
        return int(parts[0]), int(parts[1]), int(parts[2])
    except ValueError:
        logger.warning(f"无法解析版本号: {version_str},返回默认值 (0, 0, 0)")
        return 0, 0, 0
|
|
|
|
|
|
def deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> None:
    """Recursively merge ``src`` into ``dst`` in place.

    When both sides hold a dict under the same key, those dicts are merged key by
    key. Otherwise the value from ``src`` replaces or adds the entry in ``dst``.
    Values are inserted by reference, not copied.
    """
    for key, incoming in src.items():
        current = dst.get(key)
        if not (isinstance(current, dict) and isinstance(incoming, dict)):
            dst[key] = incoming
            continue
        deep_merge(current, incoming)
|
|
|
|
|
|
def normalize_dotted_keys(obj: dict[str, Any]) -> dict[str, Any]:
    """Expand dotted keys such as ``"a.b.c"`` into nested dicts, recursively.

    Keys without a dot are copied first, with nested dict values normalized
    recursively. Dotted keys are then expanded on top, in their original order:

    * empty segments (``"a..b"``) are dropped with a warning;
    * a key consisting only of dots is ignored with a warning;
    * an intermediate segment that already holds a non-dict value is replaced
      by a dict, with a warning;
    * at the leaf, a dict value is deep-merged into an existing dict, and any
      other value overwrites what was there.

    Args:
        obj: Mapping whose keys may contain ``.`` separators.

    Returns:
        A new dict with every dotted key expanded.
    """

    def _normalized(value: Any) -> Any:
        return normalize_dotted_keys(value) if isinstance(value, dict) else value

    result: dict[str, Any] = {key: _normalized(value) for key, value in obj.items() if "." not in key}
    pending = [(key, value) for key, value in obj.items() if "." in key]

    for dotted_key, raw_value in pending:
        value = _normalized(raw_value)

        segments = dotted_key.split(".")
        if "" in segments:
            logger.warning(f"键路径包含空段: '{dotted_key}'")
            segments = [segment for segment in segments if segment]
        if not segments:
            logger.warning(f"忽略空键路径: '{dotted_key}'")
            continue

        *parents, leaf = segments
        node = result
        for depth, part in enumerate(parents, start=1):
            if part in node and not isinstance(node[part], dict):
                path_ctx = ".".join(segments[:depth])
                logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key} (路径 {path_ctx})")
                node[part] = {}
            node = node.setdefault(part, {})

        existing = node.get(leaf)
        if isinstance(existing, dict) and isinstance(value, dict):
            deep_merge(existing, value)
        else:
            node[leaf] = value

    return result
|
|
|
|
|
|
def coerce_types(schema_part: dict[str, Any], config_part: dict[str, Any]) -> None:
    """Coerce config values in place so they match the shapes declared in a schema.

    For every key present in both mappings:

    * if the schema entry is a ``ConfigField`` whose ``type`` is ``list`` or a
      parameterized ``list[...]``, and the config value is a string, the string
      is split on commas into a list of stripped, non-empty items;
    * if both entries are dicts, the function recurses into them.

    Keys missing from ``config_part``, and all other combinations, are left untouched.
    """

    def _declares_list(annotation: Any) -> bool:
        return annotation is list or get_origin(annotation) is list

    for key, schema_val in schema_part.items():
        if key not in config_part:
            continue
        current = config_part[key]

        if isinstance(schema_val, ConfigField):
            if _declares_list(schema_val.type) and isinstance(current, str):
                stripped = (chunk.strip() for chunk in current.split(","))
                config_part[key] = [chunk for chunk in stripped if chunk]
            continue

        if isinstance(schema_val, dict) and isinstance(current, dict):
            coerce_types(schema_val, current)
|
|
|
|
|
|
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
    """Return the first registered plugin entry for ``plugin_id`` across all runtime supervisors.

    Supervisors are checked in the order ``manager.supervisors`` yields them. This
    reads each supervisor's private ``_registered_plugins`` mapping directly.

    Returns:
        The first non-``None`` registration found, or ``None`` if no supervisor has it.
    """
    # Imported lazily, as in the original, so loading this module does not pull in the plugin runtime.
    from src.plugin_runtime.integration import get_plugin_runtime_manager

    supervisors = get_plugin_runtime_manager().supervisors
    lookups = (supervisor._registered_plugins.get(plugin_id) for supervisor in supervisors)
    return next((entry for entry in lookups if entry is not None), None)
|
|
|
|
|
|
def get_plugins_dir() -> Path:
    """Return the resolved ``plugins`` directory, relative to the current working directory.

    The directory is created (non-recursively) if it does not exist yet.
    """
    (directory := Path("plugins").resolve()).mkdir(exist_ok=True)
    return directory
|
|
|
|
|
|
def get_plugin_candidate_paths(plugin_id: str) -> tuple[Path, Path]:
    """Return the two directory locations a plugin may be installed under.

    Returns:
        ``(new_format_path, old_format_path)``. The new format uses ``plugin_id``
        with ``.`` replaced by ``_``; the old format uses ``plugin_id`` verbatim.
        Both are validated to stay inside the plugins directory.
    """
    root = get_plugins_dir()
    new_format_path = validate_safe_path(plugin_id.replace(".", "_"), root)
    old_format_path = validate_safe_path(plugin_id, root)
    return new_format_path, old_format_path
|
|
|
|
|
|
def resolve_installed_plugin_path(plugin_id: str) -> Optional[Path]:
    """Locate an installed plugin directory by ID.

    The new-format folder name is tried before the old one. The first candidate
    that exists is passed through the strict safety check, so a symlinked or
    escaping directory raises instead of being returned.

    Returns:
        The resolved plugin directory. Returns ``None`` if no candidate exists, or
        if the existing candidate is not a directory.
    """
    candidates = get_plugin_candidate_paths(plugin_id)
    plugins_dir = get_plugins_dir()

    for candidate in candidates:
        if candidate.exists():
            return _resolve_safe_plugin_directory(candidate, plugins_dir, strict=True)

    return None
|
|
|
|
|
|
def parse_repository_url(repository_url: str) -> tuple[str, str, str]:
    """Split a repository URL into ``(normalized_url, owner, repo_name)``.

    Surrounding whitespace, trailing slashes and a trailing ``.git`` suffix are
    removed. The owner and repository name are the last two ``/``-separated
    segments of what remains.

    Args:
        repository_url: Repository URL, e.g. ``https://github.com/owner/repo.git``.

    Returns:
        The normalized URL, the owner segment and the repository segment.

    Raises:
        HTTPException: 400 if there are fewer than two segments, or if the owner
            or repository segment is empty (e.g. ``https://github.com``).
    """
    repo_url = repository_url.strip().rstrip("/").removesuffix(".git")
    parts = repo_url.split("/")
    if len(parts) < 2 or not parts[-2] or not parts[-1]:
        raise HTTPException(status_code=400, detail="无效的仓库 URL")
    return repo_url, parts[-2], parts[-1]
|
|
|
|
|
|
def load_manifest_json(manifest_path: Path) -> Optional[dict[str, Any]]:
    """Read a plugin manifest file and return its top-level JSON object.

    Args:
        manifest_path: Path of the ``_manifest.json`` file.

    Returns:
        The parsed dict. Returns ``None`` if the file does not exist, is a
        symbolic link, resolves outside its parent directory, cannot be read or
        parsed, or its top-level JSON value is not an object.
    """
    if not manifest_path.exists():
        return None

    if manifest_path.is_symlink():
        logger.warning(f"已拒绝读取符号链接 manifest: {manifest_path}")
        return None

    try:
        manifest_path.resolve().relative_to(manifest_path.parent.resolve())
    except (OSError, RuntimeError, ValueError):
        logger.warning(f"已拒绝读取越界 manifest: {manifest_path}")
        return None

    try:
        with open(manifest_path, "r", encoding="utf-8") as file_obj:
            data = json.load(file_obj)
    except Exception as e:
        # Best-effort: a broken manifest must not abort callers such as directory scans,
        # but the reason it was skipped is now recorded instead of silently discarded.
        logger.warning(f"读取 manifest 失败: {manifest_path} ({e})")
        return None

    # Callers use dict methods (e.g. .get) on the result, so a JSON array,
    # string or number must not be passed through as if it were a dict.
    if not isinstance(data, dict):
        logger.warning(f"manifest 顶层必须是 JSON 对象: {manifest_path}")
        return None

    return data
|
|
|
|
|
|
def iter_plugin_directories() -> list[Path]:
    """List the resolved plugin directories found directly under the plugins directory.

    Each entry goes through the non-strict safety check: non-directories are
    dropped, and symlinked or escaping entries are logged and dropped. Order
    follows ``Path.iterdir()``.
    """
    root = get_plugins_dir()
    checked = (_resolve_safe_plugin_directory(entry, root, strict=False) for entry in root.iterdir())
    return [directory for directory in checked if directory is not None]
|
|
|
|
|
|
def find_plugin_path_by_id(plugin_id: str) -> Optional[Path]:
    """Find an installed plugin directory by its manifest ``id`` or its folder name.

    Only directories with a readable manifest are considered. A directory matches
    if the manifest's ``id`` equals ``plugin_id``, or if the directory name does.
    Plugins whose manifest path is unsafe (symlinked or escaping) are logged and
    skipped, rather than aborting the whole search. This mirrors how
    ``iter_plugin_directories`` skips unsafe directories.

    Returns:
        The first matching plugin directory, or ``None``.
    """
    for plugin_path in iter_plugin_directories():
        try:
            manifest_path = resolve_plugin_file_path(plugin_path, "_manifest.json")
        except HTTPException as e:
            logger.warning(f"已跳过 manifest 不安全的插件目录: {plugin_path} ({e.detail})")
            continue

        manifest = load_manifest_json(manifest_path)
        # Guard against non-object manifests so .get() below cannot fail.
        if not isinstance(manifest, dict):
            continue

        if manifest.get("id") == plugin_id or plugin_path.name == plugin_id:
            return plugin_path

    return None
|
|
|
|
|
|
def backup_file(file_path: Path, action: str, move_file: bool = False) -> Optional[Path]:
    """Create a timestamped backup next to ``file_path``.

    The backup is named ``<name>.<action>.<YYYYmmddHHMMSS>``. If that name is
    already taken (e.g. two backups within the same second, or a leftover
    symlink), a ``.1``, ``.2``, ... suffix is appended so an earlier backup is
    never overwritten or written through.

    Args:
        file_path: File to back up.
        action: Label embedded in the backup name (e.g. the operation being performed).
        move_file: If True, move the file to the backup path instead of copying it.

    Returns:
        The backup path, or ``None`` if ``file_path`` does not exist.
    """
    if not file_path.exists():
        return None

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    base_name = f"{file_path.name}.{action}.{timestamp}"
    backup_path = file_path.parent / base_name
    counter = 1
    # lexists also detects dangling symlinks, which copy/move would otherwise follow.
    # NOTE: exists-then-write is not atomic; concurrent backups of one file can still race.
    while os.path.lexists(backup_path):
        backup_path = file_path.parent / f"{base_name}.{counter}"
        counter += 1

    if move_file:
        shutil.move(str(file_path), str(backup_path))
    else:
        shutil.copy(file_path, backup_path)

    return backup_path
|
|
|
|
|
|
def remove_tree(path: Path) -> None:
    """Recursively delete the directory tree at ``path``.

    When an entry cannot be removed, its mode is set to ``stat.S_IWRITE`` and the
    failed operation is retried once. This clears read-only files, e.g. on Windows.

    Raises:
        ValueError: If ``path`` itself is a symbolic link.
        OSError: If an entry still cannot be removed after the retry.
    """
    import sys

    if path.is_symlink():
        raise ValueError(f"拒绝删除符号链接路径: {path}")

    def remove_readonly(func: Any, target_path: str, _: Any) -> None:
        # Clear the read-only bit, then retry the operation that failed.
        os.chmod(target_path, stat.S_IWRITE)
        func(target_path)

    # `onerror` is deprecated since Python 3.12 in favour of `onexc`. Both pass
    # (func, path, <error info>), and the handler ignores the third argument.
    if sys.version_info >= (3, 12):
        shutil.rmtree(path, onexc=remove_readonly)
    else:
        shutil.rmtree(path, onerror=remove_readonly)