Merge branch 'dev'
This commit is contained in:
796
src/webui/anti_crawler.py
Normal file
796
src/webui/anti_crawler.py
Normal file
@@ -0,0 +1,796 @@
|
||||
"""
|
||||
WebUI 防爬虫模块
|
||||
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import ipaddress
|
||||
import re
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.anti_crawler")
|
||||
|
||||
# 常见爬虫 User-Agent 列表(使用更精确的关键词,避免误报)
|
||||
CRAWLER_USER_AGENTS = {
|
||||
# 搜索引擎爬虫(精确匹配)
|
||||
"googlebot",
|
||||
"bingbot",
|
||||
"baiduspider",
|
||||
"yandexbot",
|
||||
"slurp", # Yahoo
|
||||
"duckduckbot",
|
||||
"sogou",
|
||||
"exabot",
|
||||
"facebot",
|
||||
"ia_archiver", # Internet Archive
|
||||
# 通用爬虫(移除过于宽泛的关键词)
|
||||
"crawler",
|
||||
"spider",
|
||||
"scraper",
|
||||
"wget", # 保留wget,因为通常用于自动化脚本
|
||||
"scrapy", # 保留scrapy,因为这是爬虫框架
|
||||
# 安全扫描工具(这些是明确的扫描工具)
|
||||
"masscan",
|
||||
"nmap",
|
||||
"nikto",
|
||||
"sqlmap",
|
||||
# 注意:移除了以下过于宽泛的关键词以避免误报:
|
||||
# - "bot" (会误匹配GitHub-Robot等)
|
||||
# - "curl" (正常工具)
|
||||
# - "python-requests" (正常库)
|
||||
# - "httpx" (正常库)
|
||||
# - "aiohttp" (正常库)
|
||||
}
|
||||
|
||||
# 资产测绘工具 User-Agent 标识
|
||||
ASSET_SCANNER_USER_AGENTS = {
|
||||
# 知名资产测绘平台
|
||||
"shodan",
|
||||
"censys",
|
||||
"zoomeye",
|
||||
"fofa",
|
||||
"quake",
|
||||
"hunter",
|
||||
"binaryedge",
|
||||
"onyphe",
|
||||
"securitytrails",
|
||||
"virustotal",
|
||||
"passivetotal",
|
||||
# 安全扫描工具
|
||||
"acunetix",
|
||||
"appscan",
|
||||
"burpsuite",
|
||||
"nessus",
|
||||
"openvas",
|
||||
"qualys",
|
||||
"rapid7",
|
||||
"tenable",
|
||||
"veracode",
|
||||
"zap",
|
||||
"awvs", # Acunetix Web Vulnerability Scanner
|
||||
"netsparker",
|
||||
"skipfish",
|
||||
"w3af",
|
||||
"arachni",
|
||||
# 其他扫描工具
|
||||
"masscan",
|
||||
"zmap",
|
||||
"nmap",
|
||||
"whatweb",
|
||||
"wpscan",
|
||||
"joomscan",
|
||||
"dnsenum",
|
||||
"subfinder",
|
||||
"amass",
|
||||
"sublist3r",
|
||||
"theharvester",
|
||||
}
|
||||
|
||||
# 资产测绘工具常用的HTTP头标识
|
||||
ASSET_SCANNER_HEADERS = {
|
||||
# 常见的扫描工具自定义头
|
||||
"x-scan": {"shodan", "censys", "zoomeye", "fofa"},
|
||||
"x-scanner": {"nmap", "masscan", "zmap"},
|
||||
"x-probe": {"masscan", "zmap"},
|
||||
# 其他可疑头(移除反向代理标准头)
|
||||
"x-originating-ip": set(),
|
||||
"x-remote-ip": set(),
|
||||
"x-remote-addr": set(),
|
||||
# 注意:移除了以下反向代理标准头以避免误报:
|
||||
# - "x-forwarded-proto" (反向代理标准头)
|
||||
# - "x-real-ip" (反向代理标准头,已在_get_client_ip中使用)
|
||||
}
|
||||
|
||||
# 仅检查特定HTTP头中的可疑模式(收紧匹配范围)
|
||||
# 只检查这些特定头,不检查所有头
|
||||
SCANNER_SPECIFIC_HEADERS = {
|
||||
"x-scan",
|
||||
"x-scanner",
|
||||
"x-probe",
|
||||
"x-originating-ip",
|
||||
"x-remote-ip",
|
||||
"x-remote-addr",
|
||||
}
|
||||
|
||||
# 防爬虫模式配置
|
||||
# false: 禁用
|
||||
# strict: 严格模式(更严格的检测,更低的频率限制)
|
||||
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
||||
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||
|
||||
# IP白名单配置(从配置文件读取,逗号分隔)
|
||||
# 支持格式:
|
||||
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||
# - CIDR格式:192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
|
||||
# - 通配符:192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
|
||||
# - IPv6:::1, 2001:db8::/32
|
||||
def _parse_allowed_ips(ip_string: str) -> list:
|
||||
"""
|
||||
解析IP白名单字符串,支持精确IP、CIDR格式和通配符
|
||||
|
||||
Args:
|
||||
ip_string: 逗号分隔的IP字符串
|
||||
|
||||
Returns:
|
||||
IP白名单列表,每个元素可能是:
|
||||
- ipaddress.IPv4Network/IPv6Network对象(CIDR格式)
|
||||
- ipaddress.IPv4Address/IPv6Address对象(精确IP)
|
||||
- str(通配符模式,已转换为正则表达式)
|
||||
"""
|
||||
allowed = []
|
||||
if not ip_string:
|
||||
return allowed
|
||||
|
||||
for ip_entry in ip_string.split(","):
|
||||
ip_entry = ip_entry.strip() # 去除空格
|
||||
if not ip_entry:
|
||||
continue
|
||||
|
||||
# 跳过注释行(以#开头)
|
||||
if ip_entry.startswith("#"):
|
||||
continue
|
||||
|
||||
# 检查通配符格式(包含*)
|
||||
if "*" in ip_entry:
|
||||
# 处理通配符
|
||||
pattern = _convert_wildcard_to_regex(ip_entry)
|
||||
if pattern:
|
||||
allowed.append(pattern)
|
||||
else:
|
||||
logger.warning(f"无效的通配符IP格式,已忽略: {ip_entry}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# 尝试解析为CIDR格式(包含/)
|
||||
if "/" in ip_entry:
|
||||
allowed.append(ipaddress.ip_network(ip_entry, strict=False))
|
||||
else:
|
||||
# 精确IP地址
|
||||
allowed.append(ipaddress.ip_address(ip_entry))
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"无效的IP白名单条目,已忽略: {ip_entry} ({e})")
|
||||
|
||||
return allowed
|
||||
|
||||
|
||||
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
||||
"""
|
||||
将通配符IP模式转换为正则表达式
|
||||
|
||||
支持的格式:
|
||||
- 192.168.*.* 或 192.168.*
|
||||
- 10.*.*.* 或 10.*
|
||||
- *.*.*.* 或 *
|
||||
|
||||
Args:
|
||||
wildcard_pattern: 通配符模式字符串
|
||||
|
||||
Returns:
|
||||
正则表达式字符串,如果格式无效则返回None
|
||||
"""
|
||||
# 去除空格
|
||||
pattern = wildcard_pattern.strip()
|
||||
|
||||
# 处理单个*(匹配所有)
|
||||
if pattern == "*":
|
||||
return r".*"
|
||||
|
||||
# 处理IPv4通配符格式
|
||||
# 支持:192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
|
||||
parts = pattern.split(".")
|
||||
|
||||
if len(parts) > 4:
|
||||
return None # IPv4最多4段
|
||||
|
||||
# 构建正则表达式
|
||||
regex_parts = []
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part == "*":
|
||||
regex_parts.append(r"\d+") # 匹配任意数字
|
||||
elif part.isdigit():
|
||||
# 验证数字范围(0-255)
|
||||
num = int(part)
|
||||
if 0 <= num <= 255:
|
||||
regex_parts.append(re.escape(part))
|
||||
else:
|
||||
return None # 无效的数字
|
||||
else:
|
||||
return None # 无效的格式
|
||||
|
||||
# 如果部分少于4段,补充.*
|
||||
while len(regex_parts) < 4:
|
||||
regex_parts.append(r"\d+")
|
||||
|
||||
# 组合成正则表达式
|
||||
regex = r"^" + r"\.".join(regex_parts) + r"$"
|
||||
return regex
|
||||
|
||||
|
||||
# 从配置读取防爬虫设置(延迟导入避免循环依赖)
|
||||
def _get_anti_crawler_config():
|
||||
"""获取防爬虫配置"""
|
||||
from src.config.config import global_config
|
||||
return {
|
||||
'mode': global_config.webui.anti_crawler_mode,
|
||||
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
'trust_xff': global_config.webui.trust_xff
|
||||
}
|
||||
|
||||
# 初始化配置(将在模块加载时执行)
|
||||
_config = _get_anti_crawler_config()
|
||||
ANTI_CRAWLER_MODE = _config['mode']
|
||||
ALLOWED_IPS = _config['allowed_ips']
|
||||
TRUSTED_PROXIES = _config['trusted_proxies']
|
||||
TRUST_XFF = _config['trust_xff']
|
||||
|
||||
|
||||
def _get_mode_config(mode: str) -> dict:
|
||||
"""
|
||||
根据模式获取配置参数
|
||||
|
||||
Args:
|
||||
mode: 防爬虫模式 (false/strict/loose/basic)
|
||||
|
||||
Returns:
|
||||
配置字典,包含所有相关参数
|
||||
"""
|
||||
mode = mode.lower()
|
||||
|
||||
if mode == "false":
|
||||
return {
|
||||
"enabled": False,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 1000, # 禁用时设置很高的值
|
||||
"max_tracked_ips": 0,
|
||||
"check_user_agent": False,
|
||||
"check_asset_scanner": False,
|
||||
"check_rate_limit": False,
|
||||
"block_on_detect": False, # 不阻止
|
||||
}
|
||||
elif mode == "strict":
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 15, # 严格模式:更低的请求数
|
||||
"max_tracked_ips": 20000,
|
||||
"check_user_agent": True,
|
||||
"check_asset_scanner": True,
|
||||
"check_rate_limit": True,
|
||||
"block_on_detect": True, # 阻止恶意访问
|
||||
}
|
||||
elif mode == "loose":
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 60, # 宽松模式:更高的请求数
|
||||
"max_tracked_ips": 5000,
|
||||
"check_user_agent": True,
|
||||
"check_asset_scanner": True,
|
||||
"check_rate_limit": True,
|
||||
"block_on_detect": True, # 阻止恶意访问
|
||||
}
|
||||
else: # basic (默认模式)
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 1000, # 不限制请求数
|
||||
"max_tracked_ips": 0, # 不跟踪IP
|
||||
"check_user_agent": True, # 检测但不阻止
|
||||
"check_asset_scanner": True, # 检测但不阻止
|
||||
"check_rate_limit": False, # 不限制请求频率
|
||||
"block_on_detect": False, # 只记录,不阻止
|
||||
}
|
||||
|
||||
|
||||
class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
||||
"""防爬虫中间件"""
|
||||
|
||||
def __init__(self, app, mode: str = "standard"):
|
||||
"""
|
||||
初始化防爬虫中间件
|
||||
|
||||
Args:
|
||||
app: FastAPI 应用实例
|
||||
mode: 防爬虫模式 (false/strict/loose/standard)
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.mode = mode.lower()
|
||||
# 根据模式获取配置
|
||||
config = _get_mode_config(self.mode)
|
||||
self.enabled = config["enabled"]
|
||||
self.rate_limit_window = config["rate_limit_window"]
|
||||
self.rate_limit_max_requests = config["rate_limit_max_requests"]
|
||||
self.max_tracked_ips = config["max_tracked_ips"]
|
||||
self.check_user_agent = config["check_user_agent"]
|
||||
self.check_asset_scanner = config["check_asset_scanner"]
|
||||
self.check_rate_limit = config["check_rate_limit"]
|
||||
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
|
||||
|
||||
# 用于存储每个IP的请求时间戳(使用deque提高性能)
|
||||
self.request_times: dict[str, deque] = {}
|
||||
# 上次清理时间
|
||||
self.last_cleanup = time.time()
|
||||
# 将关键词列表转换为集合以提高查找性能
|
||||
self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
|
||||
self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
|
||||
|
||||
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
|
||||
"""
|
||||
检测是否为爬虫 User-Agent
|
||||
|
||||
Args:
|
||||
user_agent: User-Agent 字符串
|
||||
|
||||
Returns:
|
||||
如果是爬虫则返回 True
|
||||
"""
|
||||
if not user_agent:
|
||||
# 没有 User-Agent 的请求记录日志但不直接阻止
|
||||
# 改为只记录,让频率限制来处理
|
||||
logger.debug("请求缺少User-Agent")
|
||||
return False # 不再直接阻止无User-Agent的请求
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# 使用集合查找提高性能(检查是否包含爬虫关键词)
|
||||
for crawler_keyword in self.crawler_keywords_set:
|
||||
if crawler_keyword in user_agent_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_asset_scanner_header(self, request: Request) -> bool:
|
||||
"""
|
||||
检测是否为资产测绘工具的HTTP头(只检查特定头,收紧匹配)
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
如果检测到资产测绘工具头则返回 True
|
||||
"""
|
||||
# 只检查特定的扫描工具头,不检查所有头
|
||||
for header_name, header_value in request.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
header_value_lower = header_value.lower() if header_value else ""
|
||||
|
||||
# 检查已知的扫描工具头
|
||||
if header_name_lower in ASSET_SCANNER_HEADERS:
|
||||
# 如果该头有特定的工具集合,检查值是否匹配
|
||||
expected_tools = ASSET_SCANNER_HEADERS[header_name_lower]
|
||||
if expected_tools:
|
||||
for tool in expected_tools:
|
||||
if tool in header_value_lower:
|
||||
return True
|
||||
else:
|
||||
# 如果没有特定工具集合,只要存在该头就视为可疑
|
||||
if header_value_lower:
|
||||
return True
|
||||
|
||||
# 只检查特定头中的可疑模式(收紧匹配)
|
||||
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||
# 检查头值中是否包含已知扫描工具名称
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in header_value_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
检测资产测绘工具
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
(是否检测到, 检测到的工具名称)
|
||||
"""
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
|
||||
# 检查 User-Agent(使用集合查找提高性能)
|
||||
if user_agent:
|
||||
user_agent_lower = user_agent.lower()
|
||||
for scanner_keyword in self.scanner_keywords_set:
|
||||
if scanner_keyword in user_agent_lower:
|
||||
return True, scanner_keyword
|
||||
|
||||
# 检查HTTP头
|
||||
if self._is_asset_scanner_header(request):
|
||||
# 尝试从User-Agent或头中提取工具名称
|
||||
detected_tool = None
|
||||
if user_agent:
|
||||
user_agent_lower = user_agent.lower()
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in user_agent_lower:
|
||||
detected_tool = tool
|
||||
break
|
||||
|
||||
# 检查HTTP头中的工具标识(只检查特定头)
|
||||
if not detected_tool:
|
||||
for header_name, header_value in request.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||
header_value_lower = (header_value or "").lower()
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in header_value_lower:
|
||||
detected_tool = tool
|
||||
break
|
||||
if detected_tool:
|
||||
break
|
||||
|
||||
return True, detected_tool or "unknown_scanner"
|
||||
|
||||
return False, None
|
||||
|
||||
def _check_rate_limit(self, client_ip: str) -> bool:
|
||||
"""
|
||||
检查请求频率限制
|
||||
|
||||
Args:
|
||||
client_ip: 客户端IP地址
|
||||
|
||||
Returns:
|
||||
如果超过限制则返回 True(需要阻止)
|
||||
"""
|
||||
# 检查IP白名单
|
||||
if self._is_ip_allowed(client_ip):
|
||||
return False
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 定期清理过期的请求记录(每5分钟清理一次)
|
||||
if current_time - self.last_cleanup > 300:
|
||||
self._cleanup_old_requests(current_time)
|
||||
self.last_cleanup = current_time
|
||||
|
||||
# 限制跟踪的IP数量,防止内存泄漏
|
||||
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
|
||||
# 清理最旧的记录(删除最久未访问的IP)
|
||||
self._cleanup_oldest_ips()
|
||||
|
||||
# 获取或创建该IP的请求时间deque(不使用maxlen,避免限流变松)
|
||||
if client_ip not in self.request_times:
|
||||
self.request_times[client_ip] = deque()
|
||||
|
||||
request_times = self.request_times[client_ip]
|
||||
|
||||
# 移除时间窗口外的请求记录(从左侧弹出过期记录)
|
||||
while request_times and current_time - request_times[0] >= self.rate_limit_window:
|
||||
request_times.popleft()
|
||||
|
||||
# 检查是否超过限制
|
||||
if len(request_times) >= self.rate_limit_max_requests:
|
||||
return True
|
||||
|
||||
# 记录当前请求时间
|
||||
request_times.append(current_time)
|
||||
return False
|
||||
|
||||
def _cleanup_old_requests(self, current_time: float):
|
||||
"""清理过期的请求记录(只清理当前需要检查的IP,不全量遍历)"""
|
||||
# 这个方法现在主要用于定期清理,实际清理在_check_rate_limit中按需进行
|
||||
# 清理最久未访问的IP记录
|
||||
if len(self.request_times) > self.max_tracked_ips * 0.8:
|
||||
self._cleanup_oldest_ips()
|
||||
|
||||
def _cleanup_oldest_ips(self):
|
||||
"""清理最久未访问的IP记录(全量遍历找真正的oldest)"""
|
||||
if not self.request_times:
|
||||
return
|
||||
|
||||
# 先收集空deque的IP(优先删除)
|
||||
empty_ips = []
|
||||
# 找到最久未访问的IP(最旧时间戳)
|
||||
oldest_ip = None
|
||||
oldest_time = float("inf")
|
||||
|
||||
# 全量遍历找真正的oldest(超限时性能可接受)
|
||||
for ip, times in self.request_times.items():
|
||||
if not times:
|
||||
# 空deque,记录待删除
|
||||
empty_ips.append(ip)
|
||||
else:
|
||||
# 找到最旧的时间戳
|
||||
if times[0] < oldest_time:
|
||||
oldest_time = times[0]
|
||||
oldest_ip = ip
|
||||
|
||||
# 先删除空deque的IP
|
||||
for ip in empty_ips:
|
||||
del self.request_times[ip]
|
||||
|
||||
# 如果没有空deque可删除,且仍需要清理,删除最旧的一个IP
|
||||
if not empty_ips and oldest_ip:
|
||||
del self.request_times[oldest_ip]
|
||||
|
||||
def _is_trusted_proxy(self, ip: str) -> bool:
|
||||
"""
|
||||
检查IP是否在信任的代理列表中
|
||||
|
||||
Args:
|
||||
ip: IP地址字符串
|
||||
|
||||
Returns:
|
||||
如果是信任的代理则返回 True
|
||||
"""
|
||||
if not TRUSTED_PROXIES or ip == "unknown":
|
||||
return False
|
||||
|
||||
# 检查代理列表中的每个条目
|
||||
for trusted_entry in TRUSTED_PROXIES:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
if isinstance(trusted_entry, str):
|
||||
try:
|
||||
if re.match(trusted_entry, ip):
|
||||
return True
|
||||
except re.error:
|
||||
continue
|
||||
# CIDR格式(网络对象)
|
||||
elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj in trusted_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
# 精确IP(地址对象)
|
||||
elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj == trusted_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实IP地址(带基本验证和代理信任检查)
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
客户端IP地址
|
||||
"""
|
||||
# 获取直接连接的客户端IP(用于验证代理)
|
||||
direct_client_ip = None
|
||||
if request.client:
|
||||
direct_client_ip = request.client.host
|
||||
|
||||
# 检查是否信任X-Forwarded-For头
|
||||
# TRUST_XFF 只表示"启用代理解析能力",但仍要求直连 IP 在 TRUSTED_PROXIES 中
|
||||
use_xff = False
|
||||
if TRUST_XFF and TRUSTED_PROXIES and direct_client_ip:
|
||||
# 只有在启用 TRUST_XFF 且直连 IP 在信任列表中时,才信任 XFF
|
||||
use_xff = self._is_trusted_proxy(direct_client_ip)
|
||||
|
||||
# 如果信任代理,优先从 X-Forwarded-For 获取
|
||||
if use_xff:
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 可能包含多个IP,取第一个
|
||||
ip = forwarded_for.split(",")[0].strip()
|
||||
# 基本验证IP格式
|
||||
if self._validate_ip(ip):
|
||||
return ip
|
||||
|
||||
# 从 X-Real-IP 获取(如果信任代理)
|
||||
if use_xff:
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
ip = real_ip.strip()
|
||||
if self._validate_ip(ip):
|
||||
return ip
|
||||
|
||||
# 使用直接连接的客户端IP
|
||||
if direct_client_ip and self._validate_ip(direct_client_ip):
|
||||
return direct_client_ip
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _validate_ip(self, ip: str) -> bool:
|
||||
"""
|
||||
验证IP地址格式
|
||||
|
||||
Args:
|
||||
ip: IP地址字符串
|
||||
|
||||
Returns:
|
||||
如果格式有效则返回 True
|
||||
"""
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
def _is_ip_allowed(self, ip: str) -> bool:
|
||||
"""
|
||||
检查IP是否在白名单中(支持精确IP、CIDR格式和通配符)
|
||||
|
||||
Args:
|
||||
ip: 客户端IP地址
|
||||
|
||||
Returns:
|
||||
如果IP在白名单中则返回 True
|
||||
"""
|
||||
if not ALLOWED_IPS or ip == "unknown":
|
||||
return False
|
||||
|
||||
# 检查白名单中的每个条目
|
||||
for allowed_entry in ALLOWED_IPS:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
if isinstance(allowed_entry, str):
|
||||
try:
|
||||
if re.match(allowed_entry, ip):
|
||||
return True
|
||||
except re.error:
|
||||
# 正则表达式错误,跳过
|
||||
continue
|
||||
# CIDR格式(网络对象)
|
||||
elif isinstance(allowed_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj in allowed_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
# IP格式无效,跳过
|
||||
continue
|
||||
# 精确IP(地址对象)
|
||||
elif isinstance(allowed_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj == allowed_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
# IP格式无效,跳过
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
处理请求
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
call_next: 下一个中间件或路由处理函数
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
"""
|
||||
# 如果未启用,直接通过
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# 允许访问 robots.txt(由专门的路由处理)
|
||||
if request.url.path == "/robots.txt":
|
||||
return await call_next(request)
|
||||
|
||||
# 允许访问静态资源(CSS、JS、图片等)
|
||||
# 注意:.json 已移除,避免 API 路径绕过防护
|
||||
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/)
|
||||
static_extensions = {
|
||||
".css",
|
||||
".js",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".svg",
|
||||
".ico",
|
||||
".woff",
|
||||
".woff2",
|
||||
".ttf",
|
||||
".eot",
|
||||
}
|
||||
static_prefixes = {"/static/", "/assets/", "/dist/"}
|
||||
|
||||
# 检查是否是静态资源路径(特定前缀下的静态文件)
|
||||
path = request.url.path
|
||||
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
|
||||
path.endswith(ext) for ext in static_extensions
|
||||
)
|
||||
|
||||
# 也允许根路径下的静态文件(如 /favicon.ico)
|
||||
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
|
||||
|
||||
if is_static_path or is_root_static:
|
||||
return await call_next(request)
|
||||
|
||||
# 获取客户端IP(只获取一次,避免重复调用)
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
# 检查IP白名单(优先检查,白名单IP直接通过)
|
||||
if self._is_ip_allowed(client_ip):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取 User-Agent
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
|
||||
# 检测资产测绘工具(优先检测,因为更危险)
|
||||
if self.check_asset_scanner:
|
||||
is_scanner, scanner_name = self._detect_asset_scanner(request)
|
||||
if is_scanner:
|
||||
logger.warning(
|
||||
f"🚫 检测到资产测绘工具请求 - IP: {client_ip}, 工具: {scanner_name}, "
|
||||
f"User-Agent: {user_agent}, Path: {request.url.path}"
|
||||
)
|
||||
# 根据配置决定是否阻止
|
||||
if self.block_on_detect:
|
||||
return PlainTextResponse(
|
||||
"Access Denied: Asset scanning tools are not allowed",
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
# 检测爬虫 User-Agent
|
||||
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
|
||||
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
# 根据配置决定是否阻止
|
||||
if self.block_on_detect:
|
||||
return PlainTextResponse(
|
||||
"Access Denied: Crawlers are not allowed",
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
# 检查请求频率限制
|
||||
if self.check_rate_limit and self._check_rate_limit(client_ip):
|
||||
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
return PlainTextResponse(
|
||||
"Too Many Requests: Rate limit exceeded",
|
||||
status_code=429,
|
||||
)
|
||||
|
||||
# 正常请求,继续处理
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def create_robots_txt_response() -> PlainTextResponse:
|
||||
"""
|
||||
创建 robots.txt 响应
|
||||
|
||||
Returns:
|
||||
robots.txt 响应对象
|
||||
"""
|
||||
robots_content = """User-agent: *
|
||||
Disallow: /
|
||||
|
||||
# 禁止所有爬虫访问
|
||||
"""
|
||||
return PlainTextResponse(
|
||||
content=robots_content,
|
||||
media_type="text/plain",
|
||||
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
|
||||
)
|
||||
@@ -6,6 +6,7 @@ WebUI 认证模块
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.auth")
|
||||
@@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session"
|
||||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||
|
||||
|
||||
def _is_secure_environment() -> bool:
|
||||
"""
|
||||
检测是否应该启用安全 Cookie(HTTPS)
|
||||
|
||||
Returns:
|
||||
bool: 如果应该使用 secure cookie 则返回 True
|
||||
"""
|
||||
# 从配置读取
|
||||
if global_config.webui.secure_cookie:
|
||||
logger.info("配置中启用了 secure_cookie")
|
||||
return True
|
||||
|
||||
# 检查是否是生产环境
|
||||
if global_config.webui.mode == "production":
|
||||
logger.info("WebUI运行在生产模式,启用 secure cookie")
|
||||
return True
|
||||
|
||||
# 默认:开发环境不启用(因为通常是 HTTP)
|
||||
logger.debug("WebUI运行在开发模式,禁用 secure cookie")
|
||||
return False
|
||||
|
||||
|
||||
def get_current_token(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
@@ -22,69 +45,102 @@ def get_current_token(
|
||||
) -> str:
|
||||
"""
|
||||
获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization Header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str) -> None:
|
||||
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
|
||||
"""
|
||||
设置认证 Cookie
|
||||
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
token: 要设置的 token
|
||||
request: FastAPI Request 对象(可选,用于检测协议)
|
||||
"""
|
||||
# 根据环境和实际请求协议决定安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
# 如果提供了 request,检测实际使用的协议
|
||||
if request:
|
||||
# 检查 X-Forwarded-Proto header(代理/负载均衡器)
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto", "").lower()
|
||||
if forwarded_proto:
|
||||
is_https = forwarded_proto == "https"
|
||||
logger.debug(f"检测到 X-Forwarded-Proto: {forwarded_proto}, is_https={is_https}")
|
||||
else:
|
||||
# 检查 request.url.scheme
|
||||
is_https = request.url.scheme == "https"
|
||||
logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
|
||||
|
||||
# 如果是 HTTP 连接,强制禁用 secure 标志
|
||||
if not is_https and is_secure:
|
||||
logger.warning("=" * 80)
|
||||
logger.warning("检测到 HTTP 连接但环境配置要求 HTTPS (secure cookie)")
|
||||
logger.warning("已自动禁用 secure 标志以允许登录,但建议修改配置:")
|
||||
logger.warning("1. 在配置文件中设置: webui.secure_cookie = false")
|
||||
logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
|
||||
logger.warning("=" * 80)
|
||||
is_secure = False
|
||||
|
||||
# 设置 Cookie
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True, # 防止 JS 读取
|
||||
samesite="lax", # 允许同站导航时发送 Cookie(兼容开发环境代理)
|
||||
secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True
|
||||
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
|
||||
samesite="lax", # 使用 lax 以兼容更多场景(开发和生产)
|
||||
secure=is_secure, # 根据实际协议决定
|
||||
path="/", # 确保 Cookie 在所有路径下可用
|
||||
)
|
||||
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
|
||||
|
||||
logger.info(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})")
|
||||
logger.debug(f"完整 token 前缀: {token[:20]}...")
|
||||
|
||||
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
"""
|
||||
清除认证 Cookie
|
||||
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
"""
|
||||
# 保持与 set_auth_cookie 相同的安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
response.delete_cookie(
|
||||
key=COOKIE_NAME,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
samesite="strict" if is_secure else "lax",
|
||||
secure=is_secure,
|
||||
path="/",
|
||||
)
|
||||
logger.debug("已清除认证 Cookie")
|
||||
@@ -96,32 +152,32 @@ def verify_auth_token_from_cookie_or_header(
|
||||
) -> bool:
|
||||
"""
|
||||
验证认证 Token,支持从 Cookie 或 Header 获取
|
||||
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
验证成功返回 True
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
@@ -8,18 +8,30 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# WebUI 聊天的虚拟群组 ID
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
@@ -63,14 +75,14 @@ class ChatHistoryManager:
|
||||
|
||||
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将数据库消息转换为前端格式
|
||||
|
||||
|
||||
Args:
|
||||
msg: 数据库消息对象
|
||||
group_id: 群 ID,用于判断是否是虚拟群
|
||||
"""
|
||||
# 判断是否是机器人消息
|
||||
user_id = msg.user_id or ""
|
||||
|
||||
|
||||
# 对于虚拟群,通过比较机器人 QQ 账号来判断
|
||||
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
|
||||
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
@@ -256,6 +268,7 @@ async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取聊天历史记录
|
||||
|
||||
@@ -272,7 +285,7 @@ async def get_chat_history(
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms():
|
||||
async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
||||
"""获取可用平台列表
|
||||
|
||||
从 PersonInfo 表中获取所有已知的平台
|
||||
@@ -303,6 +316,7 @@ async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取指定平台的用户列表
|
||||
|
||||
@@ -350,7 +364,7 @@ async def get_persons_by_platform(
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
@@ -372,6 +386,7 @@ async def websocket_chat(
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||
token: Optional[str] = Query(default=None), # 认证 token
|
||||
):
|
||||
"""WebSocket 聊天端点
|
||||
|
||||
@@ -382,9 +397,45 @@ async def websocket_chat(
|
||||
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||
group_name: 虚拟身份模式的群名(可选)
|
||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||
token: 认证 token(可选,也可从 Cookie 获取)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/api/chat/ws?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
@@ -414,7 +465,9 @@ async def websocket_chat(
|
||||
group_id=virtual_group_id,
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
|
||||
@@ -710,7 +763,7 @@ async def websocket_chat(
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info():
|
||||
async def get_chat_info(_auth: bool = Depends(require_auth)):
|
||||
"""获取聊天室信息"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
|
||||
@@ -4,11 +4,12 @@
|
||||
|
||||
import os
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from typing import Any, Annotated
|
||||
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
@@ -29,9 +30,7 @@ from src.config.official_configs import (
|
||||
ToolConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
MoodConfig,
|
||||
VoiceConfig,
|
||||
JargonConfig,
|
||||
)
|
||||
from src.config.api_ada_configs import (
|
||||
ModelTaskConfig,
|
||||
@@ -51,45 +50,19 @@ PathBody = Annotated[dict[str, str], Body()]
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
# ===== 辅助函数 =====
|
||||
|
||||
|
||||
def _update_dict_preserve_comments(target: Any, source: Any) -> None:
|
||||
"""
|
||||
递归合并字典,保留 target 中的注释和格式
|
||||
将 source 的值更新到 target 中(仅更新已存在的键)
|
||||
|
||||
Args:
|
||||
target: 目标字典(tomlkit 对象,包含注释)
|
||||
source: 源字典(普通 dict 或 list)
|
||||
"""
|
||||
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
|
||||
if isinstance(source, list):
|
||||
return # 调用者需要直接赋值
|
||||
|
||||
# 如果都是字典,递归合并
|
||||
if isinstance(source, dict) and isinstance(target, dict):
|
||||
for key, value in source.items():
|
||||
if key == "version":
|
||||
continue # 跳过版本号
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
# 递归处理嵌套字典
|
||||
if isinstance(value, dict) and isinstance(target_value, dict):
|
||||
_update_dict_preserve_comments(target_value, value)
|
||||
else:
|
||||
# 使用 tomlkit.item 保持类型
|
||||
try:
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
target[key] = value
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@router.get("/schema/bot")
|
||||
async def get_bot_config_schema():
|
||||
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置架构"""
|
||||
try:
|
||||
# Config 类包含所有子配置
|
||||
@@ -101,7 +74,7 @@ async def get_bot_config_schema():
|
||||
|
||||
|
||||
@router.get("/schema/model")
|
||||
async def get_model_config_schema():
|
||||
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
||||
@@ -115,7 +88,7 @@ async def get_model_config_schema():
|
||||
|
||||
|
||||
@router.get("/schema/section/{section_name}")
|
||||
async def get_config_section_schema(section_name: str):
|
||||
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取指定配置节的架构
|
||||
|
||||
@@ -138,7 +111,6 @@ async def get_config_section_schema(section_name: str):
|
||||
- tool: ToolConfig
|
||||
- memory: MemoryConfig
|
||||
- debug: DebugConfig
|
||||
- mood: MoodConfig
|
||||
- voice: VoiceConfig
|
||||
- jargon: JargonConfig
|
||||
- model_task_config: ModelTaskConfig
|
||||
@@ -164,9 +136,7 @@ async def get_config_section_schema(section_name: str):
|
||||
"tool": ToolConfig,
|
||||
"memory": MemoryConfig,
|
||||
"debug": DebugConfig,
|
||||
"mood": MoodConfig,
|
||||
"voice": VoiceConfig,
|
||||
"jargon": JargonConfig,
|
||||
"model_task_config": ModelTaskConfig,
|
||||
"api_provider": APIProvider,
|
||||
"model_info": ModelInfo,
|
||||
@@ -188,7 +158,7 @@ async def get_config_section_schema(section_name: str):
|
||||
|
||||
|
||||
@router.get("/bot")
|
||||
async def get_bot_config():
|
||||
async def get_bot_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -207,7 +177,7 @@ async def get_bot_config():
|
||||
|
||||
|
||||
@router.get("/model")
|
||||
async def get_model_config():
|
||||
async def get_model_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
@@ -229,7 +199,7 @@ async def get_model_config():
|
||||
|
||||
|
||||
@router.post("/bot")
|
||||
async def update_bot_config(config_data: ConfigBody):
|
||||
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
@@ -238,7 +208,7 @@ async def update_bot_config(config_data: ConfigBody):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件(格式化数组为多行)
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
@@ -252,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody):
|
||||
|
||||
|
||||
@router.post("/model")
|
||||
async def update_model_config(config_data: ConfigBody):
|
||||
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
@@ -261,7 +231,7 @@ async def update_model_config(config_data: ConfigBody):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件(格式化数组为多行)
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
@@ -278,7 +248,7 @@ async def update_model_config(config_data: ConfigBody):
|
||||
|
||||
|
||||
@router.post("/bot/section/{section_name}")
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -300,7 +270,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_dict_preserve_comments(config_data[section_name], section_data)
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
@@ -327,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
||||
|
||||
|
||||
@router.get("/bot/raw")
|
||||
async def get_bot_config_raw():
|
||||
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -346,7 +316,7 @@ async def get_bot_config_raw():
|
||||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: RawContentBody):
|
||||
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
@@ -376,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody):
|
||||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
||||
async def update_model_config_section(
|
||||
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
|
||||
):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -398,7 +370,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_dict_preserve_comments(config_data[section_name], section_data)
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
@@ -407,6 +379,17 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||
if section_name == "api_providers" and "api_provider" in str(e):
|
||||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||
models = config_data.get("models", [])
|
||||
orphaned_models = [
|
||||
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
||||
]
|
||||
if orphaned_models:
|
||||
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
||||
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
@@ -457,7 +440,7 @@ def _to_relative_path(path: str) -> str:
|
||||
|
||||
|
||||
@router.get("/adapter-config/path")
|
||||
async def get_adapter_config_path():
|
||||
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
|
||||
"""获取保存的适配器配置文件路径"""
|
||||
try:
|
||||
# 从 data/webui.json 读取路径偏好
|
||||
@@ -496,7 +479,7 @@ async def get_adapter_config_path():
|
||||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: PathBody):
|
||||
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
@@ -539,7 +522,7 @@ async def save_adapter_config_path(data: PathBody):
|
||||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
async def get_adapter_config(path: str):
|
||||
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
|
||||
"""从指定路径读取适配器配置文件"""
|
||||
try:
|
||||
if not path:
|
||||
@@ -571,7 +554,7 @@ async def get_adapter_config(path: str):
|
||||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: PathBody):
|
||||
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
|
||||
@@ -296,7 +296,6 @@ class ConfigSchemaGenerator:
|
||||
"plan_style",
|
||||
"visual_style",
|
||||
"private_plan_style",
|
||||
"emotion_style",
|
||||
"reaction",
|
||||
"filtration_prompt",
|
||||
]:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" 表情包管理 API 路由"""
|
||||
"""表情包管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
@@ -48,7 +48,7 @@ def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
|
||||
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
||||
"""
|
||||
后台生成缩略图(在线程池中执行)
|
||||
|
||||
|
||||
生成完成后自动从 generating 集合中移除
|
||||
"""
|
||||
try:
|
||||
@@ -74,14 +74,14 @@ def _get_thumbnail_cache_path(file_hash: str) -> Path:
|
||||
def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
||||
"""
|
||||
生成缩略图并保存到缓存目录
|
||||
|
||||
|
||||
Args:
|
||||
source_path: 原图路径
|
||||
file_hash: 文件哈希值,用作缓存文件名
|
||||
|
||||
|
||||
Returns:
|
||||
缩略图路径
|
||||
|
||||
|
||||
Features:
|
||||
- GIF: 提取第一帧作为缩略图
|
||||
- 所有格式统一转为 WebP
|
||||
@@ -89,63 +89,63 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
||||
"""
|
||||
_ensure_thumbnail_cache_dir()
|
||||
cache_path = _get_thumbnail_cache_path(file_hash)
|
||||
|
||||
|
||||
# 使用锁防止并发生成同一缩略图
|
||||
lock = _get_thumbnail_lock(file_hash)
|
||||
with lock:
|
||||
# 双重检查,可能在等待锁时已被其他线程生成
|
||||
if cache_path.exists():
|
||||
return cache_path
|
||||
|
||||
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
# GIF 处理:提取第一帧
|
||||
if hasattr(img, 'n_frames') and img.n_frames > 1:
|
||||
if hasattr(img, "n_frames") and img.n_frames > 1:
|
||||
img.seek(0) # 确保在第一帧
|
||||
|
||||
|
||||
# 转换为 RGB/RGBA(WebP 支持透明度)
|
||||
if img.mode in ('P', 'PA'):
|
||||
if img.mode in ("P", "PA"):
|
||||
# 调色板模式转换为 RGBA 以保留透明度
|
||||
img = img.convert('RGBA')
|
||||
elif img.mode == 'LA':
|
||||
img = img.convert('RGBA')
|
||||
elif img.mode not in ('RGB', 'RGBA'):
|
||||
img = img.convert('RGB')
|
||||
|
||||
img = img.convert("RGBA")
|
||||
elif img.mode == "LA":
|
||||
img = img.convert("RGBA")
|
||||
elif img.mode not in ("RGB", "RGBA"):
|
||||
img = img.convert("RGB")
|
||||
|
||||
# 创建缩略图(保持宽高比)
|
||||
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
||||
|
||||
|
||||
# 保存为 WebP 格式
|
||||
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6)
|
||||
|
||||
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
|
||||
|
||||
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
|
||||
# 生成失败时不创建缓存文件,下次会重试
|
||||
raise
|
||||
|
||||
|
||||
return cache_path
|
||||
|
||||
|
||||
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
||||
"""
|
||||
清理孤立的缩略图缓存(原图已不存在的缩略图)
|
||||
|
||||
|
||||
Returns:
|
||||
(清理数量, 保留数量)
|
||||
"""
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return 0, 0
|
||||
|
||||
|
||||
# 获取所有表情包的哈希值
|
||||
valid_hashes = set()
|
||||
for emoji in Emoji.select(Emoji.emoji_hash):
|
||||
valid_hashes.add(emoji.emoji_hash)
|
||||
|
||||
|
||||
cleaned = 0
|
||||
kept = 0
|
||||
|
||||
|
||||
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
|
||||
file_hash = cache_file.stem
|
||||
if file_hash not in valid_hashes:
|
||||
@@ -157,12 +157,13 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
||||
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
|
||||
else:
|
||||
kept += 1
|
||||
|
||||
|
||||
if cleaned > 0:
|
||||
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个")
|
||||
|
||||
|
||||
return cleaned, kept
|
||||
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||
@@ -365,7 +366,9 @@ async def get_emoji_list(
|
||||
|
||||
|
||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_detail(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表情包详细信息
|
||||
|
||||
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
|
||||
|
||||
|
||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def update_emoji(
|
||||
emoji_id: int,
|
||||
request: EmojiUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新表情包(只更新提供的字段)
|
||||
|
||||
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
|
||||
|
||||
|
||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def delete_emoji(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除表情包
|
||||
|
||||
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def register_emoji(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
注册表情包(快捷操作)
|
||||
|
||||
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def ban_emoji(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
禁用表情包(快捷操作)
|
||||
|
||||
@@ -633,7 +647,7 @@ async def get_emoji_thumbnail(
|
||||
|
||||
Returns:
|
||||
表情包缩略图(WebP 格式)或原图
|
||||
|
||||
|
||||
Features:
|
||||
- 懒加载:首次请求时生成缩略图
|
||||
- 缓存:后续请求直接返回缓存
|
||||
@@ -643,7 +657,7 @@ async def get_emoji_thumbnail(
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = False
|
||||
|
||||
|
||||
# 1. 优先使用 Cookie
|
||||
if maibot_session and token_manager.verify_token(maibot_session):
|
||||
is_valid = True
|
||||
@@ -655,7 +669,7 @@ async def get_emoji_thumbnail(
|
||||
auth_token = authorization.replace("Bearer ", "")
|
||||
if token_manager.verify_token(auth_token):
|
||||
is_valid = True
|
||||
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
@@ -680,35 +694,27 @@ async def get_emoji_thumbnail(
|
||||
}
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
return FileResponse(
|
||||
path=emoji.full_path,
|
||||
media_type=media_type,
|
||||
filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||
)
|
||||
|
||||
# 尝试获取或生成缩略图
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
|
||||
|
||||
# 检查缓存是否存在
|
||||
if cache_path.exists():
|
||||
# 缓存命中,直接返回
|
||||
return FileResponse(
|
||||
path=str(cache_path),
|
||||
media_type="image/webp",
|
||||
filename=f"{emoji.emoji_hash}_thumb.webp"
|
||||
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
|
||||
)
|
||||
|
||||
|
||||
# 缓存未命中,触发后台生成并返回 202
|
||||
with _generating_lock:
|
||||
if emoji.emoji_hash not in _generating_thumbnails:
|
||||
# 标记为正在生成
|
||||
_generating_thumbnails.add(emoji.emoji_hash)
|
||||
# 提交到线程池后台生成
|
||||
_thumbnail_executor.submit(
|
||||
_background_generate_thumbnail,
|
||||
emoji.full_path,
|
||||
emoji.emoji_hash
|
||||
)
|
||||
|
||||
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||
|
||||
# 返回 202 Accepted,告诉前端缩略图正在生成中
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
|
||||
},
|
||||
headers={
|
||||
"Retry-After": "1", # 建议 1 秒后重试
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_emojis(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除表情包
|
||||
|
||||
@@ -1079,7 +1089,7 @@ async def batch_upload_emoji(
|
||||
|
||||
class ThumbnailCacheStatsResponse(BaseModel):
|
||||
"""缩略图缓存统计响应"""
|
||||
|
||||
|
||||
success: bool
|
||||
cache_dir: str
|
||||
total_count: int
|
||||
@@ -1090,7 +1100,7 @@ class ThumbnailCacheStatsResponse(BaseModel):
|
||||
|
||||
class ThumbnailCleanupResponse(BaseModel):
|
||||
"""缩略图清理响应"""
|
||||
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
cleaned_count: int
|
||||
@@ -1099,7 +1109,7 @@ class ThumbnailCleanupResponse(BaseModel):
|
||||
|
||||
class ThumbnailPreheatResponse(BaseModel):
|
||||
"""缩略图预热响应"""
|
||||
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
generated_count: int
|
||||
@@ -1114,27 +1124,27 @@ async def get_thumbnail_cache_stats(
|
||||
):
|
||||
"""
|
||||
获取缩略图缓存统计信息
|
||||
|
||||
|
||||
Returns:
|
||||
缓存目录、缓存数量、总大小、覆盖率等统计信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
|
||||
_ensure_thumbnail_cache_dir()
|
||||
|
||||
|
||||
# 统计缓存文件
|
||||
cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp"))
|
||||
total_count = len(cache_files)
|
||||
total_size = sum(f.stat().st_size for f in cache_files)
|
||||
total_size_mb = round(total_size / (1024 * 1024), 2)
|
||||
|
||||
|
||||
# 统计表情包总数
|
||||
emoji_count = Emoji.select().count()
|
||||
|
||||
|
||||
# 计算覆盖率
|
||||
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
|
||||
|
||||
|
||||
return ThumbnailCacheStatsResponse(
|
||||
success=True,
|
||||
cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()),
|
||||
@@ -1143,7 +1153,7 @@ async def get_thumbnail_cache_stats(
|
||||
emoji_count=emoji_count,
|
||||
coverage_percent=coverage_percent,
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -1158,22 +1168,22 @@ async def cleanup_thumbnail_cache(
|
||||
):
|
||||
"""
|
||||
清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图)
|
||||
|
||||
|
||||
Returns:
|
||||
清理结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
|
||||
cleaned, kept = cleanup_orphaned_thumbnails()
|
||||
|
||||
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存",
|
||||
cleaned_count=cleaned,
|
||||
kept_count=kept,
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -1189,20 +1199,20 @@ async def preheat_thumbnail_cache(
|
||||
):
|
||||
"""
|
||||
预热缩略图缓存(提前生成未缓存的缩略图)
|
||||
|
||||
|
||||
优先处理使用次数高的表情包
|
||||
|
||||
|
||||
Args:
|
||||
limit: 最多预热数量 (1-1000)
|
||||
|
||||
|
||||
Returns:
|
||||
预热结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
|
||||
_ensure_thumbnail_cache_dir()
|
||||
|
||||
|
||||
# 获取使用次数最高的表情包(未缓存的优先)
|
||||
emojis = (
|
||||
Emoji.select()
|
||||
@@ -1210,41 +1220,36 @@ async def preheat_thumbnail_cache(
|
||||
.order_by(Emoji.usage_count.desc())
|
||||
.limit(limit * 2) # 多查一些,因为有些可能已缓存
|
||||
)
|
||||
|
||||
|
||||
generated = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
|
||||
|
||||
for emoji in emojis:
|
||||
if generated >= limit:
|
||||
break
|
||||
|
||||
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
|
||||
|
||||
# 已缓存,跳过
|
||||
if cache_path.exists():
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
|
||||
# 原文件不存在,跳过
|
||||
if not os.path.exists(emoji.full_path):
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# 使用线程池异步生成缩略图,避免阻塞事件循环
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
_thumbnail_executor,
|
||||
_generate_thumbnail,
|
||||
emoji.full_path,
|
||||
emoji.emoji_hash
|
||||
)
|
||||
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||
generated += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
|
||||
failed += 1
|
||||
|
||||
|
||||
return ThumbnailPreheatResponse(
|
||||
success=True,
|
||||
message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed} 个",
|
||||
@@ -1252,7 +1257,7 @@ async def preheat_thumbnail_cache(
|
||||
skipped_count=skipped,
|
||||
failed_count=failed,
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -1267,13 +1272,13 @@ async def clear_all_thumbnail_cache(
|
||||
):
|
||||
"""
|
||||
清空所有缩略图缓存(下次访问时会重新生成)
|
||||
|
||||
|
||||
Returns:
|
||||
清理结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
@@ -1281,7 +1286,7 @@ async def clear_all_thumbnail_cache(
|
||||
cleaned_count=0,
|
||||
kept_count=0,
|
||||
)
|
||||
|
||||
|
||||
cleaned = 0
|
||||
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
|
||||
try:
|
||||
@@ -1289,16 +1294,16 @@ async def clear_all_thumbnail_cache(
|
||||
cleaned += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}")
|
||||
|
||||
|
||||
logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件")
|
||||
|
||||
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件",
|
||||
cleaned_count=cleaned,
|
||||
kept_count=0,
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, NonNegativeFloat
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
@@ -21,7 +21,6 @@ class ExpressionResponse(BaseModel):
|
||||
situation: str
|
||||
style: str
|
||||
context: Optional[str]
|
||||
up_content: Optional[str]
|
||||
last_active_time: float
|
||||
chat_id: str
|
||||
create_date: Optional[float]
|
||||
@@ -49,8 +48,7 @@ class ExpressionCreateRequest(BaseModel):
|
||||
|
||||
situation: str
|
||||
style: str
|
||||
context: Optional[str] = None
|
||||
up_content: Optional[str] = None
|
||||
context: Optional[str] = NonNegativeFloat
|
||||
chat_id: str
|
||||
|
||||
|
||||
@@ -60,7 +58,6 @@ class ExpressionUpdateRequest(BaseModel):
|
||||
situation: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
context: Optional[str] = None
|
||||
up_content: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
|
||||
|
||||
@@ -102,7 +99,6 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
context=expression.context,
|
||||
up_content=expression.up_content,
|
||||
last_active_time=expression.last_active_time,
|
||||
chat_id=expression.chat_id,
|
||||
create_date=expression.create_date,
|
||||
@@ -260,7 +256,9 @@ async def get_expression_list(
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_detail(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
@@ -289,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
|
||||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
@@ -310,7 +312,6 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
|
||||
situation=request.situation,
|
||||
style=request.style,
|
||||
context=request.context,
|
||||
up_content=request.up_content,
|
||||
chat_id=request.chat_id,
|
||||
last_active_time=current_time,
|
||||
create_date=current_time,
|
||||
@@ -331,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
|
||||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
@@ -381,7 +385,9 @@ async def update_expression(
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def delete_expression(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
@@ -424,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_expressions(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除表达方式
|
||||
|
||||
@@ -465,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_stats(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
"""
|
||||
if not chat_id_str:
|
||||
return []
|
||||
|
||||
|
||||
try:
|
||||
# 尝试解析为 JSON
|
||||
parsed = json.loads(chat_id_str)
|
||||
@@ -49,10 +49,10 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
|
||||
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
|
||||
if not stream_ids:
|
||||
return chat_id_str
|
||||
|
||||
|
||||
# 查询所有 stream_id 对应的名称
|
||||
names = []
|
||||
for stream_id in stream_ids:
|
||||
@@ -62,7 +62,7 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
|
||||
else:
|
||||
# 如果没找到,显示截断的 stream_id
|
||||
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
|
||||
|
||||
|
||||
return ", ".join(names) if names else chat_id_str
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ def jargon_to_dict(jargon: Jargon) -> dict:
|
||||
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
|
||||
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
|
||||
stream_id = stream_ids[0] if stream_ids else None
|
||||
|
||||
|
||||
return {
|
||||
"id": jargon.id,
|
||||
"content": jargon.content,
|
||||
@@ -277,17 +277,13 @@ async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
try:
|
||||
# 获取所有不同的 chat_id
|
||||
chat_ids = (
|
||||
Jargon.select(Jargon.chat_id)
|
||||
.distinct()
|
||||
.where(Jargon.chat_id.is_null(False))
|
||||
)
|
||||
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
|
||||
|
||||
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
||||
|
||||
# 用于按 stream_id 去重
|
||||
seen_stream_ids: set[str] = set()
|
||||
|
||||
|
||||
for chat_id in chat_id_list:
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
@@ -346,12 +342,7 @@ async def get_jargon_stats():
|
||||
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
||||
|
||||
# 关联的聊天数量
|
||||
chat_count = (
|
||||
Jargon.select(Jargon.chat_id)
|
||||
.distinct()
|
||||
.where(Jargon.chat_id.is_null(False))
|
||||
.count()
|
||||
)
|
||||
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
|
||||
|
||||
# 按聊天统计 TOP 5
|
||||
top_chats = (
|
||||
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
try:
|
||||
# 检查是否已存在相同内容的黑话
|
||||
existing = Jargon.get_or_none(
|
||||
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
|
||||
)
|
||||
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||
|
||||
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
|
||||
if not ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
updated_count = (
|
||||
Jargon.update(is_jargon=is_jargon)
|
||||
.where(Jargon.id.in_(ids))
|
||||
.execute()
|
||||
)
|
||||
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
|
||||
|
||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
"""知识库图谱可视化 API 路由"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi import APIRouter, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class KnowledgeNode(BaseModel):
|
||||
"""知识节点"""
|
||||
|
||||
@@ -113,6 +122,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
async def get_knowledge_graph(
|
||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取知识图谱(限制节点数量)
|
||||
|
||||
@@ -199,7 +209,7 @@ async def get_knowledge_graph(
|
||||
|
||||
|
||||
@router.get("/stats", response_model=KnowledgeStats)
|
||||
async def get_knowledge_stats():
|
||||
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
|
||||
"""获取知识库统计信息
|
||||
|
||||
Returns:
|
||||
@@ -248,7 +258,7 @@ async def get_knowledge_stats():
|
||||
|
||||
|
||||
@router.get("/search", response_model=List[KnowledgeNode])
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
|
||||
"""搜索知识节点
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""WebSocket 日志推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Optional
|
||||
import json
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.logs_ws")
|
||||
router = APIRouter()
|
||||
@@ -73,14 +75,48 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
|
||||
|
||||
@router.websocket("/ws/logs")
|
||||
async def websocket_logs(websocket: WebSocket):
|
||||
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 日志推送端点
|
||||
|
||||
客户端连接后会持续接收服务器端的日志消息
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/logs?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
||||
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
# 连接建立后,立即发送历史日志
|
||||
try:
|
||||
|
||||
@@ -6,18 +6,27 @@
|
||||
|
||||
import os
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
|
||||
from typing import Optional
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# 模型获取器配置
|
||||
MODEL_FETCHER_CONFIG = {
|
||||
# OpenAI 兼容格式的提供商
|
||||
@@ -184,6 +193,7 @@ async def get_provider_models(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
@@ -228,6 +238,7 @@ async def get_models_by_url(
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||
@@ -251,6 +262,7 @@ async def get_models_by_url(
|
||||
async def test_provider_connection(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
@@ -337,6 +349,7 @@ async def test_provider_connection(
|
||||
@router.post("/test-connection-by-name")
|
||||
async def test_provider_connection_by_name(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过提供商名称测试连接(从配置文件读取信息)
|
||||
|
||||
@@ -200,7 +200,9 @@ async def get_person_list(
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def get_person_detail(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
|
||||
|
||||
|
||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
|
||||
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def delete_person(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_persons(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除人物信息
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""WebSocket 插件加载进度推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set, Dict, Any
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Dict, Any, Optional
|
||||
import json
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.plugin_progress")
|
||||
|
||||
@@ -89,14 +91,48 @@ async def update_progress(
|
||||
|
||||
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket):
|
||||
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 插件加载进度推送端点
|
||||
|
||||
客户端连接后会立即收到当前进度状态
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/plugin-progress?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
try:
|
||||
# 发送当前进度状态
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
245
src/webui/rate_limiter.py
Normal file
245
src/webui/rate_limiter.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
WebUI 请求频率限制模块
|
||||
防止暴力破解和 API 滥用
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Tuple, Optional
|
||||
from fastapi import Request, HTTPException
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.rate_limiter")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
简单的内存请求频率限制器
|
||||
|
||||
使用滑动窗口算法实现
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 存储格式: {key: [(timestamp, count), ...]}
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||
self._blocked: Dict[str, float] = {}
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""获取客户端 IP 地址"""
|
||||
# 检查代理头
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
# 取第一个 IP(最原始的客户端)
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# 直接连接的客户端
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _cleanup_old_requests(self, key: str, window_seconds: int):
|
||||
"""清理过期的请求记录"""
|
||||
now = time.time()
|
||||
cutoff = now - window_seconds
|
||||
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
|
||||
|
||||
def _cleanup_expired_blocks(self):
|
||||
"""清理过期的封禁"""
|
||||
now = time.time()
|
||||
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
|
||||
for ip in expired:
|
||||
del self._blocked[ip]
|
||||
logger.info(f"🔓 IP {ip} 封禁已解除")
|
||||
|
||||
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
|
||||
"""
|
||||
检查 IP 是否被封禁
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余封禁秒数)
|
||||
"""
|
||||
self._cleanup_expired_blocks()
|
||||
ip = self._get_client_ip(request)
|
||||
|
||||
if ip in self._blocked:
|
||||
remaining = int(self._blocked[ip] - time.time())
|
||||
return True, max(0, remaining)
|
||||
|
||||
return False, None
|
||||
|
||||
def check_rate_limit(
|
||||
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
检查请求是否超过频率限制
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_requests: 窗口期内允许的最大请求数
|
||||
window_seconds: 窗口时间(秒)
|
||||
key_suffix: 键后缀,用于区分不同的限制规则
|
||||
|
||||
Returns:
|
||||
(是否允许, 剩余请求数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:{key_suffix}" if key_suffix else ip
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前窗口内的请求数
|
||||
current_count = sum(count for _, count in self._requests[key])
|
||||
|
||||
if current_count >= max_requests:
|
||||
return False, 0
|
||||
|
||||
# 记录新请求
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
|
||||
remaining = max_requests - current_count - 1
|
||||
return True, remaining
|
||||
|
||||
def block_ip(self, request: Request, duration_seconds: int):
|
||||
"""
|
||||
封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
duration_seconds: 封禁时长(秒)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
self._blocked[ip] = time.time() + duration_seconds
|
||||
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")
|
||||
|
||||
def record_failed_attempt(
|
||||
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
记录失败尝试(如登录失败)
|
||||
|
||||
如果在窗口期内失败次数过多,自动封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_failures: 允许的最大失败次数
|
||||
window_seconds: 统计窗口(秒)
|
||||
block_duration: 封禁时长(秒)
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余尝试次数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前失败次数
|
||||
current_failures = sum(count for _, count in self._requests[key])
|
||||
|
||||
# 记录本次失败
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
current_failures += 1
|
||||
|
||||
remaining = max_failures - current_failures
|
||||
|
||||
# 检查是否需要封禁
|
||||
if current_failures >= max_failures:
|
||||
self.block_ip(request, block_duration)
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
|
||||
return True, 0
|
||||
|
||||
if current_failures >= max_failures - 2:
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
|
||||
|
||||
return False, max(0, remaining)
|
||||
|
||||
def reset_failures(self, request: Request):
|
||||
"""
|
||||
重置失败计数(认证成功后调用)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
if key in self._requests:
|
||||
del self._requests[key]
|
||||
|
||||
|
||||
# 全局单例
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""获取 RateLimiter 单例"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
async def check_auth_rate_limit(request: Request):
|
||||
"""
|
||||
认证接口的频率限制依赖
|
||||
|
||||
规则:
|
||||
- 每个 IP 每分钟最多 10 次认证请求
|
||||
- 连续失败 5 次后封禁 10 分钟
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, remaining = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=10, # 每分钟 10 次
|
||||
window_seconds=60,
|
||||
key_suffix="auth",
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
|
||||
|
||||
async def check_api_rate_limit(request: Request):
|
||||
"""
|
||||
普通 API 的频率限制依赖
|
||||
|
||||
规则:每个 IP 每分钟最多 100 次请求
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, _ = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=100, # 每分钟 100 次
|
||||
window_seconds=60,
|
||||
key_suffix="api",
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
@@ -7,10 +7,12 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
logger = get_logger("webui_system")
|
||||
@@ -19,6 +21,14 @@ logger = get_logger("webui_system")
|
||||
_start_time = time.time()
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class RestartResponse(BaseModel):
|
||||
"""重启响应"""
|
||||
|
||||
@@ -36,7 +46,7 @@ class StatusResponse(BaseModel):
|
||||
|
||||
|
||||
@router.post("/restart", response_model=RestartResponse)
|
||||
async def restart_maibot():
|
||||
async def restart_maibot(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
重启麦麦主程序
|
||||
|
||||
@@ -67,7 +77,7 @@ async def restart_maibot():
|
||||
|
||||
|
||||
@router.get("/status", response_model=StatusResponse)
|
||||
async def get_maibot_status():
|
||||
async def get_maibot_status(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取麦麦运行状态
|
||||
|
||||
@@ -90,7 +100,7 @@ async def get_maibot_status():
|
||||
|
||||
|
||||
@router.post("/reload-config")
|
||||
async def reload_config():
|
||||
async def reload_config(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
热重载配置(不重启进程)
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
|
||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import set_auth_cookie, clear_auth_cookie
|
||||
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
|
||||
from .config_routes import router as config_router
|
||||
from .statistics_routes import router as statistics_router
|
||||
from .person_routes import router as person_router
|
||||
@@ -16,6 +17,7 @@ from .plugin_routes import router as plugin_router
|
||||
from .plugin_progress_ws import get_progress_router
|
||||
from .routers.system import router as system_router
|
||||
from .model_routes import router as model_router
|
||||
from .ws_auth import router as ws_auth_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
@@ -42,6 +44,8 @@ router.include_router(get_progress_router())
|
||||
router.include_router(system_router)
|
||||
# 注册模型列表获取路由
|
||||
router.include_router(model_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
@@ -107,12 +111,18 @@ async def health_check():
|
||||
|
||||
|
||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||
async def verify_token(request: TokenVerifyRequest, response: Response):
|
||||
async def verify_token(
|
||||
request_body: TokenVerifyRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
_rate_limit: None = Depends(check_auth_rate_limit),
|
||||
):
|
||||
"""
|
||||
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
||||
|
||||
Args:
|
||||
request: 包含 token 的验证请求
|
||||
request_body: 包含 token 的验证请求
|
||||
request: FastAPI Request 对象(用于获取客户端 IP)
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
@@ -120,16 +130,37 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
|
||||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(request.token)
|
||||
rate_limiter = get_rate_limiter()
|
||||
|
||||
is_valid = token_manager.verify_token(request_body.token)
|
||||
|
||||
if is_valid:
|
||||
# 设置 HttpOnly Cookie
|
||||
set_auth_cookie(response, request.token)
|
||||
# 认证成功,重置失败计数
|
||||
rate_limiter.reset_failures(request)
|
||||
# 设置 HttpOnly Cookie(传入 request 以检测协议)
|
||||
set_auth_cookie(response, request_body.token, request)
|
||||
# 同时返回首次配置状态,避免额外请求
|
||||
is_first_setup = token_manager.is_first_setup()
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
||||
else:
|
||||
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
||||
# 记录失败尝试
|
||||
blocked, remaining = rate_limiter.record_failed_attempt(
|
||||
request,
|
||||
max_failures=5, # 5 次失败
|
||||
window_seconds=300, # 5 分钟窗口
|
||||
block_duration=600, # 封禁 10 分钟
|
||||
)
|
||||
|
||||
if blocked:
|
||||
raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
|
||||
|
||||
message = "Token 无效或已过期"
|
||||
if remaining <= 2:
|
||||
message += f"(剩余 {remaining} 次尝试机会)"
|
||||
|
||||
return TokenVerifyResponse(valid=False, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 验证失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
@@ -139,10 +170,10 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
|
||||
async def logout(response: Response):
|
||||
"""
|
||||
登出并清除认证 Cookie
|
||||
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
|
||||
|
||||
Returns:
|
||||
登出结果
|
||||
"""
|
||||
@@ -158,29 +189,39 @@ async def check_auth_status(
|
||||
):
|
||||
"""
|
||||
检查当前认证状态(用于前端判断是否已登录)
|
||||
|
||||
|
||||
Returns:
|
||||
认证状态
|
||||
"""
|
||||
try:
|
||||
token = None
|
||||
|
||||
# 记录请求信息用于调试
|
||||
logger.debug(f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}")
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
logger.debug("使用 Cookie 中的 token")
|
||||
# 其次从 Header 获取
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
logger.debug("使用 Header 中的 token")
|
||||
|
||||
if not token:
|
||||
logger.debug("未找到 token,返回未认证")
|
||||
return {"authenticated": False}
|
||||
|
||||
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_valid = token_manager.verify_token(token)
|
||||
logger.debug(f"Token 验证结果: {is_valid}")
|
||||
|
||||
if is_valid:
|
||||
return {"authenticated": True}
|
||||
else:
|
||||
return {"authenticated": False}
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(f"认证检查失败: {e}", exc_info=True)
|
||||
return {"authenticated": False}
|
||||
|
||||
|
||||
@@ -211,7 +252,7 @@ async def update_token(
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
@@ -222,10 +263,10 @@ async def update_token(
|
||||
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
# 如果更新成功,更新 Cookie
|
||||
|
||||
# 如果更新成功,清除 Cookie,要求用户重新登录
|
||||
if success:
|
||||
set_auth_cookie(response, request.new_token)
|
||||
clear_auth_cookie(response)
|
||||
|
||||
return TokenUpdateResponse(success=success, message=message)
|
||||
except HTTPException:
|
||||
@@ -263,7 +304,7 @@ async def regenerate_token(
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -271,9 +312,9 @@ async def regenerate_token(
|
||||
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
# 更新 Cookie
|
||||
set_auth_cookie(response, new_token)
|
||||
|
||||
# 清除 Cookie,要求用户重新登录
|
||||
clear_auth_cookie(response)
|
||||
|
||||
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||
except HTTPException:
|
||||
@@ -306,7 +347,7 @@ async def get_setup_status(
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
@@ -349,7 +390,7 @@ async def complete_setup(
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
@@ -392,7 +433,7 @@ async def reset_setup(
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
@@ -1,19 +1,28 @@
|
||||
"""统计数据 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui.statistics")
|
||||
|
||||
router = APIRouter(prefix="/statistics", tags=["statistics"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class StatisticsSummary(BaseModel):
|
||||
"""统计数据摘要"""
|
||||
|
||||
@@ -58,7 +67,7 @@ class DashboardData(BaseModel):
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardData)
|
||||
async def get_dashboard_data(hours: int = 24):
|
||||
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
@@ -275,7 +284,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_summary(hours: int = 24):
|
||||
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取统计摘要
|
||||
|
||||
@@ -293,7 +302,7 @@ async def get_summary(hours: int = 24):
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def get_model_stats(hours: int = 24):
|
||||
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取模型统计
|
||||
|
||||
|
||||
@@ -160,13 +160,29 @@ class TokenManager:
|
||||
|
||||
def regenerate_token(self) -> str:
|
||||
"""
|
||||
重新生成 token
|
||||
重新生成 token(保留 first_setup_completed 状态)
|
||||
|
||||
Returns:
|
||||
str: 新生成的 token
|
||||
"""
|
||||
logger.info("正在重新生成 WebUI Token...")
|
||||
return self._create_new_token()
|
||||
|
||||
# 生成新的 64 位十六进制字符串
|
||||
new_token = secrets.token_hex(32)
|
||||
|
||||
# 加载现有配置,保留 first_setup_completed 状态
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8] if config.get("access_token") else "无"
|
||||
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True,表示已完成配置
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
return new_token
|
||||
|
||||
def _validate_token_format(self, token: str) -> bool:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""独立的 WebUI 服务器 - 默认运行在 127.0.0.1:8001"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
@@ -22,6 +21,9 @@ class WebUIServer:
|
||||
self.app = FastAPI(title="MaiBot WebUI")
|
||||
self._server = None
|
||||
|
||||
# 配置防爬虫中间件(需要在CORS之前注册)
|
||||
self._setup_anti_crawler()
|
||||
|
||||
# 配置 CORS(支持开发环境跨域请求)
|
||||
self._setup_cors()
|
||||
|
||||
@@ -32,6 +34,9 @@ class WebUIServer:
|
||||
self._register_api_routes()
|
||||
self._setup_static_files()
|
||||
|
||||
# 注册robots.txt路由
|
||||
self._setup_robots_txt()
|
||||
|
||||
def _setup_cors(self):
|
||||
"""配置 CORS 中间件"""
|
||||
# 开发环境需要允许前端开发服务器的跨域请求
|
||||
@@ -40,12 +45,21 @@ class WebUIServer:
|
||||
allow_origins=[
|
||||
"http://localhost:5173", # Vite 开发服务器
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:7999", # 前端开发服务器备用端口
|
||||
"http://127.0.0.1:7999",
|
||||
"http://localhost:8001", # 生产环境
|
||||
"http://127.0.0.1:8001",
|
||||
],
|
||||
allow_credentials=True, # 允许携带 Cookie
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Accept",
|
||||
"Origin",
|
||||
"X-Requested-With",
|
||||
], # 明确指定允许的头
|
||||
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
|
||||
)
|
||||
logger.debug("✅ CORS 中间件已配置")
|
||||
|
||||
@@ -89,43 +103,77 @@ class WebUIServer:
|
||||
"""服务单页应用 - 只处理非 API 请求"""
|
||||
# 如果是根路径,直接返回 index.html
|
||||
if not full_path or full_path == "/":
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
# 检查是否是静态文件
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file() and file_path.exists():
|
||||
# 自动检测 MIME 类型
|
||||
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||
return FileResponse(file_path, media_type=media_type)
|
||||
response = FileResponse(file_path, media_type=media_type)
|
||||
# HTML 文件添加防索引头
|
||||
if str(file_path).endswith(".html"):
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
# 其他路径返回 index.html(SPA 路由)
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
|
||||
|
||||
def _setup_anti_crawler(self):
|
||||
"""配置防爬虫中间件"""
|
||||
try:
|
||||
from src.webui.anti_crawler import AntiCrawlerMiddleware
|
||||
from src.config.config import global_config
|
||||
|
||||
# 从配置读取防爬虫模式
|
||||
anti_crawler_mode = global_config.webui.anti_crawler_mode
|
||||
|
||||
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
|
||||
# 我们需要在CORS之前注册,这样防爬虫检查会在CORS之前执行
|
||||
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
|
||||
|
||||
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
|
||||
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
|
||||
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
|
||||
|
||||
def _setup_robots_txt(self):
|
||||
"""设置robots.txt路由"""
|
||||
try:
|
||||
from src.webui.anti_crawler import create_robots_txt_response
|
||||
|
||||
@self.app.get("/robots.txt", include_in_schema=False)
|
||||
async def robots_txt():
|
||||
"""返回robots.txt,禁止所有爬虫"""
|
||||
return create_robots_txt_response()
|
||||
|
||||
logger.debug("✅ robots.txt 路由已注册")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
|
||||
|
||||
def _register_api_routes(self):
|
||||
"""注册所有 WebUI API 路由"""
|
||||
try:
|
||||
# 导入所有 WebUI 路由
|
||||
from src.webui.routes import router as webui_router
|
||||
from src.webui.logs_ws import router as logs_router
|
||||
|
||||
logger.info("开始导入 knowledge_routes...")
|
||||
from src.webui.knowledge_routes import router as knowledge_router
|
||||
|
||||
logger.info("knowledge_routes 导入成功")
|
||||
|
||||
# 导入本地聊天室路由
|
||||
from src.webui.chat_routes import router as chat_router
|
||||
|
||||
logger.info("chat_routes 导入成功")
|
||||
|
||||
# 注册路由
|
||||
self.app.include_router(webui_router)
|
||||
self.app.include_router(logs_router)
|
||||
self.app.include_router(knowledge_router)
|
||||
self.app.include_router(chat_router)
|
||||
logger.info(f"knowledge_router 路由前缀: {knowledge_router.prefix}")
|
||||
|
||||
logger.info("✅ WebUI API 路由已注册")
|
||||
except Exception as e:
|
||||
@@ -133,6 +181,16 @@ class WebUIServer:
|
||||
|
||||
async def start(self):
|
||||
"""启动服务器"""
|
||||
# 预先检查端口是否可用
|
||||
if not self._check_port_available():
|
||||
error_msg = f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用"
|
||||
logger.error(error_msg)
|
||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
||||
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
|
||||
logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{self.port}")
|
||||
logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{self.port}")
|
||||
raise OSError(f"端口 {self.port} 已被占用,无法启动 WebUI 服务器")
|
||||
|
||||
config = Config(
|
||||
app=self.app,
|
||||
host=self.host,
|
||||
@@ -143,15 +201,59 @@ class WebUIServer:
|
||||
self._server = UvicornServer(config=config)
|
||||
|
||||
logger.info("🌐 WebUI 服务器启动中...")
|
||||
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
||||
if self.host == "0.0.0.0":
|
||||
logger.info(f"本机访问请使用 http://localhost:{self.port}")
|
||||
|
||||
# 根据地址类型显示正确的访问地址
|
||||
if ':' in self.host:
|
||||
# IPv6 地址需要用方括号包裹
|
||||
logger.info(f"🌐 访问地址: http://[{self.host}]:{self.port}")
|
||||
if self.host == "::":
|
||||
logger.info(f"💡 IPv6 本机访问: http://[::1]:{self.port}")
|
||||
logger.info(f"💡 IPv4 本机访问: http://127.0.0.1:{self.port}")
|
||||
elif self.host == "::1":
|
||||
logger.info("💡 仅支持 IPv6 本地访问")
|
||||
else:
|
||||
# IPv4 地址
|
||||
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
||||
if self.host == "0.0.0.0":
|
||||
logger.info(f"💡 本机访问: http://localhost:{self.port} 或 http://127.0.0.1:{self.port}")
|
||||
|
||||
try:
|
||||
await self._server.serve()
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器运行错误: {e}")
|
||||
except OSError as e:
|
||||
# 处理端口绑定相关的错误
|
||||
if "address already in use" in str(e).lower() or e.errno in (98, 10048): # 98: Linux, 10048: Windows
|
||||
logger.error(f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用")
|
||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
||||
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
|
||||
else:
|
||||
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _check_port_available(self) -> bool:
|
||||
"""检查端口是否可用(支持 IPv4 和 IPv6)"""
|
||||
import socket
|
||||
|
||||
# 判断使用 IPv4 还是 IPv6
|
||||
if ':' in self.host:
|
||||
# IPv6 地址
|
||||
family = socket.AF_INET6
|
||||
test_host = self.host if self.host != "::" else "::1"
|
||||
else:
|
||||
# IPv4 地址
|
||||
family = socket.AF_INET
|
||||
test_host = self.host if self.host != "0.0.0.0" else "127.0.0.1"
|
||||
|
||||
try:
|
||||
with socket.socket(family, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
# 尝试绑定端口
|
||||
s.bind((test_host, self.port))
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭服务器"""
|
||||
@@ -177,7 +279,8 @@ def get_webui_server() -> WebUIServer:
|
||||
"""获取全局 WebUI 服务器实例"""
|
||||
global _webui_server
|
||||
if _webui_server is None:
|
||||
# 从环境变量读取配置
|
||||
# 从环境变量读取
|
||||
import os
|
||||
host = os.getenv("WEBUI_HOST", "127.0.0.1")
|
||||
port = int(os.getenv("WEBUI_PORT", "8001"))
|
||||
_webui_server = WebUIServer(host=host, port=port)
|
||||
|
||||
114
src/webui/ws_auth.py
Normal file
114
src/webui/ws_auth.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""WebSocket 认证模块
|
||||
|
||||
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
|
||||
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Cookie, Header
|
||||
from typing import Optional
|
||||
import secrets
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.ws_auth")
|
||||
router = APIRouter()
|
||||
|
||||
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
||||
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
||||
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
|
||||
_WS_TOKEN_EXPIRE_SECONDS = 60
|
||||
|
||||
|
||||
def _cleanup_expired_ws_tokens():
|
||||
"""清理过期的临时 token"""
|
||||
now = time.time()
|
||||
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
|
||||
for t in expired:
|
||||
del _ws_temp_tokens[t]
|
||||
|
||||
|
||||
def generate_ws_token(session_token: str) -> str:
|
||||
"""生成 WebSocket 临时 token
|
||||
|
||||
Args:
|
||||
session_token: 原始的 session token
|
||||
|
||||
Returns:
|
||||
临时 token 字符串
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
temp_token = secrets.token_urlsafe(32)
|
||||
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
|
||||
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
|
||||
return temp_token
|
||||
|
||||
|
||||
def verify_ws_token(temp_token: str) -> bool:
|
||||
"""验证并消费 WebSocket 临时 token(一次性使用)
|
||||
|
||||
Args:
|
||||
temp_token: 临时 token
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
if temp_token not in _ws_temp_tokens:
|
||||
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
|
||||
return False
|
||||
expire_time, session_token = _ws_temp_tokens[temp_token]
|
||||
if time.time() > expire_time:
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
|
||||
return False
|
||||
# 验证原始 session token 仍然有效
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
|
||||
return False
|
||||
# 消费 token(一次性使用)
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/ws-token")
|
||||
async def get_ws_token(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取 WebSocket 连接用的临时 token
|
||||
|
||||
此端点验证当前会话的 Cookie 或 Authorization header,
|
||||
然后返回一个临时 token 用于 WebSocket 握手认证。
|
||||
临时 token 有效期 60 秒,且只能使用一次。
|
||||
|
||||
注意:在未认证时返回 200 状态码但 success=False,避免前端因 401 刷新页面。
|
||||
"""
|
||||
# 获取当前 session token
|
||||
session_token = None
|
||||
if maibot_session:
|
||||
session_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
session_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not session_token:
|
||||
# 返回 200 但 success=False,避免前端因 401 刷新页面
|
||||
# 这在登录页面是正常情况,不应该触发错误处理
|
||||
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
|
||||
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
|
||||
|
||||
# 验证 session token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
# 同样返回 200 但 success=False,避免前端刷新
|
||||
logger.debug("ws-token 请求:认证已过期")
|
||||
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
|
||||
|
||||
# 生成临时 WebSocket token
|
||||
ws_token = generate_ws_token(session_token)
|
||||
|
||||
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}
|
||||
Reference in New Issue
Block a user