Merge branch 'dev'

This commit is contained in:
SengokuCola
2025-12-21 16:59:10 +08:00
182 changed files with 26236 additions and 3557 deletions

796
src/webui/anti_crawler.py Normal file
View File

@@ -0,0 +1,796 @@
"""
WebUI 防爬虫模块
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
"""
import os
import time
import ipaddress
import re
from collections import deque
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from src.common.logger import get_logger
logger = get_logger("webui.anti_crawler")
# 常见爬虫 User-Agent 列表(使用更精确的关键词,避免误报)
CRAWLER_USER_AGENTS = {
# 搜索引擎爬虫(精确匹配)
"googlebot",
"bingbot",
"baiduspider",
"yandexbot",
"slurp", # Yahoo
"duckduckbot",
"sogou",
"exabot",
"facebot",
"ia_archiver", # Internet Archive
# 通用爬虫(移除过于宽泛的关键词)
"crawler",
"spider",
"scraper",
"wget", # 保留wget因为通常用于自动化脚本
"scrapy", # 保留scrapy因为这是爬虫框架
# 安全扫描工具(这些是明确的扫描工具)
"masscan",
"nmap",
"nikto",
"sqlmap",
# 注意:移除了以下过于宽泛的关键词以避免误报:
# - "bot" (会误匹配GitHub-Robot等)
# - "curl" (正常工具)
# - "python-requests" (正常库)
# - "httpx" (正常库)
# - "aiohttp" (正常库)
}
# 资产测绘工具 User-Agent 标识
ASSET_SCANNER_USER_AGENTS = {
# 知名资产测绘平台
"shodan",
"censys",
"zoomeye",
"fofa",
"quake",
"hunter",
"binaryedge",
"onyphe",
"securitytrails",
"virustotal",
"passivetotal",
# 安全扫描工具
"acunetix",
"appscan",
"burpsuite",
"nessus",
"openvas",
"qualys",
"rapid7",
"tenable",
"veracode",
"zap",
"awvs", # Acunetix Web Vulnerability Scanner
"netsparker",
"skipfish",
"w3af",
"arachni",
# 其他扫描工具
"masscan",
"zmap",
"nmap",
"whatweb",
"wpscan",
"joomscan",
"dnsenum",
"subfinder",
"amass",
"sublist3r",
"theharvester",
}
# 资产测绘工具常用的HTTP头标识
ASSET_SCANNER_HEADERS = {
# 常见的扫描工具自定义头
"x-scan": {"shodan", "censys", "zoomeye", "fofa"},
"x-scanner": {"nmap", "masscan", "zmap"},
"x-probe": {"masscan", "zmap"},
# 其他可疑头(移除反向代理标准头)
"x-originating-ip": set(),
"x-remote-ip": set(),
"x-remote-addr": set(),
# 注意:移除了以下反向代理标准头以避免误报:
# - "x-forwarded-proto" (反向代理标准头)
# - "x-real-ip" (反向代理标准头已在_get_client_ip中使用)
}
# 仅检查特定HTTP头中的可疑模式收紧匹配范围
# 只检查这些特定头,不检查所有头
SCANNER_SPECIFIC_HEADERS = {
"x-scan",
"x-scanner",
"x-probe",
"x-originating-ip",
"x-remote-ip",
"x-remote-addr",
}
# 防爬虫模式配置
# false: 禁用
# strict: 严格模式(更严格的检测,更低的频率限制)
# loose: 宽松模式(较宽松的检测,较高的频率限制)
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
# IP白名单配置从配置文件读取逗号分隔
# 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100
# - CIDR格式192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
# - 通配符192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
# - IPv6::1, 2001:db8::/32
def _parse_allowed_ips(ip_string: str) -> list:
"""
解析IP白名单字符串支持精确IP、CIDR格式和通配符
Args:
ip_string: 逗号分隔的IP字符串
Returns:
IP白名单列表每个元素可能是
- ipaddress.IPv4Network/IPv6Network对象CIDR格式
- ipaddress.IPv4Address/IPv6Address对象精确IP
- str通配符模式已转换为正则表达式
"""
allowed = []
if not ip_string:
return allowed
for ip_entry in ip_string.split(","):
ip_entry = ip_entry.strip() # 去除空格
if not ip_entry:
continue
# 跳过注释行(以#开头)
if ip_entry.startswith("#"):
continue
# 检查通配符格式(包含*
if "*" in ip_entry:
# 处理通配符
pattern = _convert_wildcard_to_regex(ip_entry)
if pattern:
allowed.append(pattern)
else:
logger.warning(f"无效的通配符IP格式已忽略: {ip_entry}")
continue
try:
# 尝试解析为CIDR格式包含/
if "/" in ip_entry:
allowed.append(ipaddress.ip_network(ip_entry, strict=False))
else:
# 精确IP地址
allowed.append(ipaddress.ip_address(ip_entry))
except (ValueError, AttributeError) as e:
logger.warning(f"无效的IP白名单条目已忽略: {ip_entry} ({e})")
return allowed
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
"""
将通配符IP模式转换为正则表达式
支持的格式:
- 192.168.*.* 或 192.168.*
- 10.*.*.* 或 10.*
- *.*.*.* 或 *
Args:
wildcard_pattern: 通配符模式字符串
Returns:
正则表达式字符串如果格式无效则返回None
"""
# 去除空格
pattern = wildcard_pattern.strip()
# 处理单个*(匹配所有)
if pattern == "*":
return r".*"
# 处理IPv4通配符格式
# 支持192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
parts = pattern.split(".")
if len(parts) > 4:
return None # IPv4最多4段
# 构建正则表达式
regex_parts = []
for part in parts:
part = part.strip()
if part == "*":
regex_parts.append(r"\d+") # 匹配任意数字
elif part.isdigit():
# 验证数字范围0-255
num = int(part)
if 0 <= num <= 255:
regex_parts.append(re.escape(part))
else:
return None # 无效的数字
else:
return None # 无效的格式
# 如果部分少于4段补充.*
while len(regex_parts) < 4:
regex_parts.append(r"\d+")
# 组合成正则表达式
regex = r"^" + r"\.".join(regex_parts) + r"$"
return regex
# 从配置读取防爬虫设置(延迟导入避免循环依赖)
def _get_anti_crawler_config():
"""获取防爬虫配置"""
from src.config.config import global_config
return {
'mode': global_config.webui.anti_crawler_mode,
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
'trust_xff': global_config.webui.trust_xff
}
# 初始化配置(将在模块加载时执行)
_config = _get_anti_crawler_config()
ANTI_CRAWLER_MODE = _config['mode']
ALLOWED_IPS = _config['allowed_ips']
TRUSTED_PROXIES = _config['trusted_proxies']
TRUST_XFF = _config['trust_xff']
def _get_mode_config(mode: str) -> dict:
"""
根据模式获取配置参数
Args:
mode: 防爬虫模式 (false/strict/loose/basic)
Returns:
配置字典,包含所有相关参数
"""
mode = mode.lower()
if mode == "false":
return {
"enabled": False,
"rate_limit_window": 60,
"rate_limit_max_requests": 1000, # 禁用时设置很高的值
"max_tracked_ips": 0,
"check_user_agent": False,
"check_asset_scanner": False,
"check_rate_limit": False,
"block_on_detect": False, # 不阻止
}
elif mode == "strict":
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 15, # 严格模式:更低的请求数
"max_tracked_ips": 20000,
"check_user_agent": True,
"check_asset_scanner": True,
"check_rate_limit": True,
"block_on_detect": True, # 阻止恶意访问
}
elif mode == "loose":
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 60, # 宽松模式:更高的请求数
"max_tracked_ips": 5000,
"check_user_agent": True,
"check_asset_scanner": True,
"check_rate_limit": True,
"block_on_detect": True, # 阻止恶意访问
}
else: # basic (默认模式)
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 1000, # 不限制请求数
"max_tracked_ips": 0, # 不跟踪IP
"check_user_agent": True, # 检测但不阻止
"check_asset_scanner": True, # 检测但不阻止
"check_rate_limit": False, # 不限制请求频率
"block_on_detect": False, # 只记录,不阻止
}
class AntiCrawlerMiddleware(BaseHTTPMiddleware):
"""防爬虫中间件"""
def __init__(self, app, mode: str = "standard"):
"""
初始化防爬虫中间件
Args:
app: FastAPI 应用实例
mode: 防爬虫模式 (false/strict/loose/standard)
"""
super().__init__(app)
self.mode = mode.lower()
# 根据模式获取配置
config = _get_mode_config(self.mode)
self.enabled = config["enabled"]
self.rate_limit_window = config["rate_limit_window"]
self.rate_limit_max_requests = config["rate_limit_max_requests"]
self.max_tracked_ips = config["max_tracked_ips"]
self.check_user_agent = config["check_user_agent"]
self.check_asset_scanner = config["check_asset_scanner"]
self.check_rate_limit = config["check_rate_limit"]
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
# 用于存储每个IP的请求时间戳使用deque提高性能
self.request_times: dict[str, deque] = {}
# 上次清理时间
self.last_cleanup = time.time()
# 将关键词列表转换为集合以提高查找性能
self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
"""
检测是否为爬虫 User-Agent
Args:
user_agent: User-Agent 字符串
Returns:
如果是爬虫则返回 True
"""
if not user_agent:
# 没有 User-Agent 的请求记录日志但不直接阻止
# 改为只记录,让频率限制来处理
logger.debug("请求缺少User-Agent")
return False # 不再直接阻止无User-Agent的请求
user_agent_lower = user_agent.lower()
# 使用集合查找提高性能(检查是否包含爬虫关键词)
for crawler_keyword in self.crawler_keywords_set:
if crawler_keyword in user_agent_lower:
return True
return False
def _is_asset_scanner_header(self, request: Request) -> bool:
"""
检测是否为资产测绘工具的HTTP头只检查特定头收紧匹配
Args:
request: 请求对象
Returns:
如果检测到资产测绘工具头则返回 True
"""
# 只检查特定的扫描工具头,不检查所有头
for header_name, header_value in request.headers.items():
header_name_lower = header_name.lower()
header_value_lower = header_value.lower() if header_value else ""
# 检查已知的扫描工具头
if header_name_lower in ASSET_SCANNER_HEADERS:
# 如果该头有特定的工具集合,检查值是否匹配
expected_tools = ASSET_SCANNER_HEADERS[header_name_lower]
if expected_tools:
for tool in expected_tools:
if tool in header_value_lower:
return True
else:
# 如果没有特定工具集合,只要存在该头就视为可疑
if header_value_lower:
return True
# 只检查特定头中的可疑模式(收紧匹配)
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
# 检查头值中是否包含已知扫描工具名称
for tool in self.scanner_keywords_set:
if tool in header_value_lower:
return True
return False
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
"""
检测资产测绘工具
Args:
request: 请求对象
Returns:
(是否检测到, 检测到的工具名称)
"""
user_agent = request.headers.get("User-Agent")
# 检查 User-Agent使用集合查找提高性能
if user_agent:
user_agent_lower = user_agent.lower()
for scanner_keyword in self.scanner_keywords_set:
if scanner_keyword in user_agent_lower:
return True, scanner_keyword
# 检查HTTP头
if self._is_asset_scanner_header(request):
# 尝试从User-Agent或头中提取工具名称
detected_tool = None
if user_agent:
user_agent_lower = user_agent.lower()
for tool in self.scanner_keywords_set:
if tool in user_agent_lower:
detected_tool = tool
break
# 检查HTTP头中的工具标识只检查特定头
if not detected_tool:
for header_name, header_value in request.headers.items():
header_name_lower = header_name.lower()
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
header_value_lower = (header_value or "").lower()
for tool in self.scanner_keywords_set:
if tool in header_value_lower:
detected_tool = tool
break
if detected_tool:
break
return True, detected_tool or "unknown_scanner"
return False, None
def _check_rate_limit(self, client_ip: str) -> bool:
"""
检查请求频率限制
Args:
client_ip: 客户端IP地址
Returns:
如果超过限制则返回 True需要阻止
"""
# 检查IP白名单
if self._is_ip_allowed(client_ip):
return False
current_time = time.time()
# 定期清理过期的请求记录每5分钟清理一次
if current_time - self.last_cleanup > 300:
self._cleanup_old_requests(current_time)
self.last_cleanup = current_time
# 限制跟踪的IP数量防止内存泄漏
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
# 清理最旧的记录删除最久未访问的IP
self._cleanup_oldest_ips()
# 获取或创建该IP的请求时间deque不使用maxlen避免限流变松
if client_ip not in self.request_times:
self.request_times[client_ip] = deque()
request_times = self.request_times[client_ip]
# 移除时间窗口外的请求记录(从左侧弹出过期记录)
while request_times and current_time - request_times[0] >= self.rate_limit_window:
request_times.popleft()
# 检查是否超过限制
if len(request_times) >= self.rate_limit_max_requests:
return True
# 记录当前请求时间
request_times.append(current_time)
return False
def _cleanup_old_requests(self, current_time: float):
"""清理过期的请求记录只清理当前需要检查的IP不全量遍历"""
# 这个方法现在主要用于定期清理实际清理在_check_rate_limit中按需进行
# 清理最久未访问的IP记录
if len(self.request_times) > self.max_tracked_ips * 0.8:
self._cleanup_oldest_ips()
def _cleanup_oldest_ips(self):
"""清理最久未访问的IP记录全量遍历找真正的oldest"""
if not self.request_times:
return
# 先收集空deque的IP优先删除
empty_ips = []
# 找到最久未访问的IP最旧时间戳
oldest_ip = None
oldest_time = float("inf")
# 全量遍历找真正的oldest超限时性能可接受
for ip, times in self.request_times.items():
if not times:
# 空deque记录待删除
empty_ips.append(ip)
else:
# 找到最旧的时间戳
if times[0] < oldest_time:
oldest_time = times[0]
oldest_ip = ip
# 先删除空deque的IP
for ip in empty_ips:
del self.request_times[ip]
# 如果没有空deque可删除且仍需要清理删除最旧的一个IP
if not empty_ips and oldest_ip:
del self.request_times[oldest_ip]
def _is_trusted_proxy(self, ip: str) -> bool:
"""
检查IP是否在信任的代理列表中
Args:
ip: IP地址字符串
Returns:
如果是信任的代理则返回 True
"""
if not TRUSTED_PROXIES or ip == "unknown":
return False
# 检查代理列表中的每个条目
for trusted_entry in TRUSTED_PROXIES:
# 通配符模式(字符串,正则表达式)
if isinstance(trusted_entry, str):
try:
if re.match(trusted_entry, ip):
return True
except re.error:
continue
# CIDR格式网络对象
elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj in trusted_entry:
return True
except (ValueError, AttributeError):
continue
# 精确IP地址对象
elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj == trusted_entry:
return True
except (ValueError, AttributeError):
continue
return False
def _get_client_ip(self, request: Request) -> str:
"""
获取客户端真实IP地址带基本验证和代理信任检查
Args:
request: 请求对象
Returns:
客户端IP地址
"""
# 获取直接连接的客户端IP用于验证代理
direct_client_ip = None
if request.client:
direct_client_ip = request.client.host
# 检查是否信任X-Forwarded-For头
# TRUST_XFF 只表示"启用代理解析能力",但仍要求直连 IP 在 TRUSTED_PROXIES 中
use_xff = False
if TRUST_XFF and TRUSTED_PROXIES and direct_client_ip:
# 只有在启用 TRUST_XFF 且直连 IP 在信任列表中时,才信任 XFF
use_xff = self._is_trusted_proxy(direct_client_ip)
# 如果信任代理,优先从 X-Forwarded-For 获取
if use_xff:
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For 可能包含多个IP取第一个
ip = forwarded_for.split(",")[0].strip()
# 基本验证IP格式
if self._validate_ip(ip):
return ip
# 从 X-Real-IP 获取(如果信任代理)
if use_xff:
real_ip = request.headers.get("X-Real-IP")
if real_ip:
ip = real_ip.strip()
if self._validate_ip(ip):
return ip
# 使用直接连接的客户端IP
if direct_client_ip and self._validate_ip(direct_client_ip):
return direct_client_ip
return "unknown"
def _validate_ip(self, ip: str) -> bool:
"""
验证IP地址格式
Args:
ip: IP地址字符串
Returns:
如果格式有效则返回 True
"""
try:
ipaddress.ip_address(ip)
return True
except (ValueError, AttributeError):
return False
def _is_ip_allowed(self, ip: str) -> bool:
"""
检查IP是否在白名单中支持精确IP、CIDR格式和通配符
Args:
ip: 客户端IP地址
Returns:
如果IP在白名单中则返回 True
"""
if not ALLOWED_IPS or ip == "unknown":
return False
# 检查白名单中的每个条目
for allowed_entry in ALLOWED_IPS:
# 通配符模式(字符串,正则表达式)
if isinstance(allowed_entry, str):
try:
if re.match(allowed_entry, ip):
return True
except re.error:
# 正则表达式错误,跳过
continue
# CIDR格式网络对象
elif isinstance(allowed_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj in allowed_entry:
return True
except (ValueError, AttributeError):
# IP格式无效跳过
continue
# 精确IP地址对象
elif isinstance(allowed_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj == allowed_entry:
return True
except (ValueError, AttributeError):
# IP格式无效跳过
continue
return False
async def dispatch(self, request: Request, call_next):
"""
处理请求
Args:
request: 请求对象
call_next: 下一个中间件或路由处理函数
Returns:
响应对象
"""
# 如果未启用,直接通过
if not self.enabled:
return await call_next(request)
# 允许访问 robots.txt由专门的路由处理
if request.url.path == "/robots.txt":
return await call_next(request)
# 允许访问静态资源CSS、JS、图片等
# 注意:.json 已移除,避免 API 路径绕过防护
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/
static_extensions = {
".css",
".js",
".png",
".jpg",
".jpeg",
".gif",
".svg",
".ico",
".woff",
".woff2",
".ttf",
".eot",
}
static_prefixes = {"/static/", "/assets/", "/dist/"}
# 检查是否是静态资源路径(特定前缀下的静态文件)
path = request.url.path
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
path.endswith(ext) for ext in static_extensions
)
# 也允许根路径下的静态文件(如 /favicon.ico
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
if is_static_path or is_root_static:
return await call_next(request)
# 获取客户端IP只获取一次避免重复调用
client_ip = self._get_client_ip(request)
# 检查IP白名单优先检查白名单IP直接通过
if self._is_ip_allowed(client_ip):
return await call_next(request)
# 获取 User-Agent
user_agent = request.headers.get("User-Agent")
# 检测资产测绘工具(优先检测,因为更危险)
if self.check_asset_scanner:
is_scanner, scanner_name = self._detect_asset_scanner(request)
if is_scanner:
logger.warning(
f"🚫 检测到资产测绘工具请求 - IP: {client_ip}, 工具: {scanner_name}, "
f"User-Agent: {user_agent}, Path: {request.url.path}"
)
# 根据配置决定是否阻止
if self.block_on_detect:
return PlainTextResponse(
"Access Denied: Asset scanning tools are not allowed",
status_code=403,
)
# 检测爬虫 User-Agent
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
# 根据配置决定是否阻止
if self.block_on_detect:
return PlainTextResponse(
"Access Denied: Crawlers are not allowed",
status_code=403,
)
# 检查请求频率限制
if self.check_rate_limit and self._check_rate_limit(client_ip):
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
return PlainTextResponse(
"Too Many Requests: Rate limit exceeded",
status_code=429,
)
# 正常请求,继续处理
return await call_next(request)
def create_robots_txt_response() -> PlainTextResponse:
"""
创建 robots.txt 响应
Returns:
robots.txt 响应对象
"""
robots_content = """User-agent: *
Disallow: /
# 禁止所有爬虫访问
"""
return PlainTextResponse(
content=robots_content,
media_type="text/plain",
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
)

View File

@@ -6,6 +6,7 @@ WebUI 认证模块
from typing import Optional
from fastapi import HTTPException, Cookie, Header, Response, Request
from src.common.logger import get_logger
from src.config.config import global_config
from .token_manager import get_token_manager
logger = get_logger("webui.auth")
@@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session"
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
def _is_secure_environment() -> bool:
"""
检测是否应该启用安全 CookieHTTPS
Returns:
bool: 如果应该使用 secure cookie 则返回 True
"""
# 从配置读取
if global_config.webui.secure_cookie:
logger.info("配置中启用了 secure_cookie")
return True
# 检查是否是生产环境
if global_config.webui.mode == "production":
logger.info("WebUI运行在生产模式启用 secure cookie")
return True
# 默认:开发环境不启用(因为通常是 HTTP
logger.debug("WebUI运行在开发模式禁用 secure cookie")
return False
def get_current_token(
request: Request,
maibot_session: Optional[str] = Cookie(None),
@@ -22,69 +45,102 @@ def get_current_token(
) -> str:
"""
获取当前请求的 token优先从 Cookie 获取,其次从 Header 获取
Args:
request: FastAPI Request 对象
maibot_session: Cookie 中的 token
authorization: Authorization Header (Bearer token)
Returns:
验证通过的 token
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return token
def set_auth_cookie(response: Response, token: str) -> None:
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
"""
设置认证 Cookie
Args:
response: FastAPI Response 对象
token: 要设置的 token
request: FastAPI Request 对象(可选,用于检测协议)
"""
# 根据环境和实际请求协议决定安全设置
is_secure = _is_secure_environment()
# 如果提供了 request检测实际使用的协议
if request:
# 检查 X-Forwarded-Proto header代理/负载均衡器)
forwarded_proto = request.headers.get("x-forwarded-proto", "").lower()
if forwarded_proto:
is_https = forwarded_proto == "https"
logger.debug(f"检测到 X-Forwarded-Proto: {forwarded_proto}, is_https={is_https}")
else:
# 检查 request.url.scheme
is_https = request.url.scheme == "https"
logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
# 如果是 HTTP 连接,强制禁用 secure 标志
if not is_https and is_secure:
logger.warning("=" * 80)
logger.warning("检测到 HTTP 连接但环境配置要求 HTTPS (secure cookie)")
logger.warning("已自动禁用 secure 标志以允许登录,但建议修改配置:")
logger.warning("1. 在配置文件中设置: webui.secure_cookie = false")
logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
logger.warning("=" * 80)
is_secure = False
# 设置 Cookie
response.set_cookie(
key=COOKIE_NAME,
value=token,
max_age=COOKIE_MAX_AGE,
httponly=True, # 防止 JS 读取
samesite="lax", # 允许同站导航时发送 Cookie兼容开发环境代理
secure=False, # 本地开发不强制 HTTPS生产环境建议设为 True
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
samesite="lax", # 使用 lax 以兼容更多场景(开发和生产
secure=is_secure, # 根据实际协议决定
path="/", # 确保 Cookie 在所有路径下可用
)
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
logger.info(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})")
logger.debug(f"完整 token 前缀: {token[:20]}...")
def clear_auth_cookie(response: Response) -> None:
"""
清除认证 Cookie
Args:
response: FastAPI Response 对象
"""
# 保持与 set_auth_cookie 相同的安全设置
is_secure = _is_secure_environment()
response.delete_cookie(
key=COOKIE_NAME,
httponly=True,
samesite="lax",
samesite="strict" if is_secure else "lax",
secure=is_secure,
path="/",
)
logger.debug("已清除认证 Cookie")
@@ -96,32 +152,32 @@ def verify_auth_token_from_cookie_or_header(
) -> bool:
"""
验证认证 Token支持从 Cookie 或 Header 获取
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
验证成功返回 True
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True

View File

@@ -8,18 +8,30 @@
import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel
from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.chat")
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# WebUI 聊天的虚拟群组 ID
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
@@ -63,14 +75,14 @@ class ChatHistoryManager:
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将数据库消息转换为前端格式
Args:
msg: 数据库消息对象
group_id: 群 ID用于判断是否是虚拟群
"""
# 判断是否是机器人消息
user_id = msg.user_id or ""
# 对于虚拟群,通过比较机器人 QQ 账号来判断
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
@@ -256,6 +268,7 @@ async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
_auth: bool = Depends(require_auth),
):
"""获取聊天历史记录
@@ -272,7 +285,7 @@ async def get_chat_history(
@router.get("/platforms")
async def get_available_platforms():
async def get_available_platforms(_auth: bool = Depends(require_auth)):
"""获取可用平台列表
从 PersonInfo 表中获取所有已知的平台
@@ -303,6 +316,7 @@ async def get_persons_by_platform(
platform: str = Query(..., description="平台名称"),
search: Optional[str] = Query(default=None, description="搜索关键词"),
limit: int = Query(default=50, ge=1, le=200),
_auth: bool = Depends(require_auth),
):
"""获取指定平台的用户列表
@@ -350,7 +364,7 @@ async def get_persons_by_platform(
@router.delete("/history")
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
"""清空聊天历史记录
Args:
@@ -372,6 +386,7 @@ async def websocket_chat(
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
token: Optional[str] = Query(default=None), # 认证 token
):
"""WebSocket 聊天端点
@@ -382,9 +397,45 @@ async def websocket_chat(
person_id: 虚拟身份模式的用户 person_id可选
group_name: 虚拟身份模式的群名(可选)
group_id: 虚拟身份模式的群 ID可选由前端生成并持久化
token: 认证 token可选也可从 Cookie 获取)
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/api/chat/ws?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
# 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4())
@@ -414,7 +465,9 @@ async def websocket_chat(
group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
)
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
@@ -710,7 +763,7 @@ async def websocket_chat(
@router.get("/info")
async def get_chat_info():
async def get_chat_info(_auth: bool = Depends(require_auth)):
"""获取聊天室信息"""
return {
"bot_name": global_config.bot.nickname,

View File

@@ -4,11 +4,12 @@
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body
from typing import Any, Annotated
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
from typing import Any, Annotated, Optional
from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import (
BotConfig,
@@ -29,9 +30,7 @@ from src.config.official_configs import (
ToolConfig,
MemoryConfig,
DebugConfig,
MoodConfig,
VoiceConfig,
JargonConfig,
)
from src.config.api_ada_configs import (
ModelTaskConfig,
@@ -51,45 +50,19 @@ PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"])
# ===== 辅助函数 =====
def _update_dict_preserve_comments(target: Any, source: Any) -> None:
"""
递归合并字典,保留 target 中的注释和格式
将 source 的值更新到 target 中(仅更新已存在的键)
Args:
target: 目标字典tomlkit 对象,包含注释)
source: 源字典(普通 dict 或 list
"""
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
if isinstance(source, list):
return # 调用者需要直接赋值
# 如果都是字典,递归合并
if isinstance(source, dict) and isinstance(target, dict):
for key, value in source.items():
if key == "version":
continue # 跳过版本号
if key in target:
target_value = target[key]
# 递归处理嵌套字典
if isinstance(value, dict) and isinstance(target_value, dict):
_update_dict_preserve_comments(target_value, value)
else:
# 使用 tomlkit.item 保持类型
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ===== 架构获取接口 =====
@router.get("/schema/bot")
async def get_bot_config_schema():
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置架构"""
try:
# Config 类包含所有子配置
@@ -101,7 +74,7 @@ async def get_bot_config_schema():
@router.get("/schema/model")
async def get_model_config_schema():
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
@@ -115,7 +88,7 @@ async def get_model_config_schema():
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str):
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
"""
获取指定配置节的架构
@@ -138,7 +111,6 @@ async def get_config_section_schema(section_name: str):
- tool: ToolConfig
- memory: MemoryConfig
- debug: DebugConfig
- mood: MoodConfig
- voice: VoiceConfig
- jargon: JargonConfig
- model_task_config: ModelTaskConfig
@@ -164,9 +136,7 @@ async def get_config_section_schema(section_name: str):
"tool": ToolConfig,
"memory": MemoryConfig,
"debug": DebugConfig,
"mood": MoodConfig,
"voice": VoiceConfig,
"jargon": JargonConfig,
"model_task_config": ModelTaskConfig,
"api_provider": APIProvider,
"model_info": ModelInfo,
@@ -188,7 +158,7 @@ async def get_config_section_schema(section_name: str):
@router.get("/bot")
async def get_bot_config():
async def get_bot_config(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@@ -207,7 +177,7 @@ async def get_bot_config():
@router.get("/model")
async def get_model_config():
async def get_model_config(_auth: bool = Depends(require_auth)):
"""获取模型配置(包含提供商和模型任务配置)"""
try:
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
@@ -229,7 +199,7 @@ async def get_model_config():
@router.post("/bot")
async def update_bot_config(config_data: ConfigBody):
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
@@ -238,7 +208,7 @@ async def update_bot_config(config_data: ConfigBody):
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(格式化数组为多行
# 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
save_toml_with_format(config_data, config_path)
@@ -252,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody):
@router.post("/model")
async def update_model_config(config_data: ConfigBody):
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新模型配置"""
try:
# 验证配置数据
@@ -261,7 +231,7 @@ async def update_model_config(config_data: ConfigBody):
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(格式化数组为多行
# 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
save_toml_with_format(config_data, config_path)
@@ -278,7 +248,7 @@ async def update_model_config(config_data: ConfigBody):
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: SectionBody):
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@@ -300,7 +270,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data)
_update_toml_doc(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
@@ -327,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
@router.get("/bot/raw")
async def get_bot_config_raw():
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置的原始 TOML 内容"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@@ -346,7 +316,7 @@ async def get_bot_config_raw():
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: RawContentBody):
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
@@ -376,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody):
@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: SectionBody):
async def update_model_config_section(
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@@ -398,7 +370,7 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data)
_update_toml_doc(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
@@ -407,6 +379,17 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
# 特殊处理:如果是更新 api_providers检查是否有模型引用了已删除的provider
if section_name == "api_providers" and "api_provider" in str(e):
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", [])
orphaned_models = [
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
]
if orphaned_models:
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
raise HTTPException(status_code=400, detail=error_msg) from e
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置(格式化数组为多行,保留注释)
@@ -457,7 +440,7 @@ def _to_relative_path(path: str) -> str:
@router.get("/adapter-config/path")
async def get_adapter_config_path():
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
"""获取保存的适配器配置文件路径"""
try:
# 从 data/webui.json 读取路径偏好
@@ -496,7 +479,7 @@ async def get_adapter_config_path():
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: PathBody):
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
@@ -539,7 +522,7 @@ async def save_adapter_config_path(data: PathBody):
@router.get("/adapter-config")
async def get_adapter_config(path: str):
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
"""从指定路径读取适配器配置文件"""
try:
if not path:
@@ -571,7 +554,7 @@ async def get_adapter_config(path: str):
@router.post("/adapter-config")
async def save_adapter_config(data: PathBody):
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")

View File

@@ -296,7 +296,6 @@ class ConfigSchemaGenerator:
"plan_style",
"visual_style",
"private_plan_style",
"emotion_style",
"reaction",
"filtration_prompt",
]:

View File

@@ -1,4 +1,4 @@
""" 表情包管理 API 路由"""
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
from fastapi.responses import FileResponse, JSONResponse
@@ -48,7 +48,7 @@ def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
"""
后台生成缩略图(在线程池中执行)
生成完成后自动从 generating 集合中移除
"""
try:
@@ -74,14 +74,14 @@ def _get_thumbnail_cache_path(file_hash: str) -> Path:
def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
"""
生成缩略图并保存到缓存目录
Args:
source_path: 原图路径
file_hash: 文件哈希值,用作缓存文件名
Returns:
缩略图路径
Features:
- GIF: 提取第一帧作为缩略图
- 所有格式统一转为 WebP
@@ -89,63 +89,63 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
"""
_ensure_thumbnail_cache_dir()
cache_path = _get_thumbnail_cache_path(file_hash)
# 使用锁防止并发生成同一缩略图
lock = _get_thumbnail_lock(file_hash)
with lock:
# 双重检查,可能在等待锁时已被其他线程生成
if cache_path.exists():
return cache_path
try:
with Image.open(source_path) as img:
# GIF 处理:提取第一帧
if hasattr(img, 'n_frames') and img.n_frames > 1:
if hasattr(img, "n_frames") and img.n_frames > 1:
img.seek(0) # 确保在第一帧
# 转换为 RGB/RGBAWebP 支持透明度)
if img.mode in ('P', 'PA'):
if img.mode in ("P", "PA"):
# 调色板模式转换为 RGBA 以保留透明度
img = img.convert('RGBA')
elif img.mode == 'LA':
img = img.convert('RGBA')
elif img.mode not in ('RGB', 'RGBA'):
img = img.convert('RGB')
img = img.convert("RGBA")
elif img.mode == "LA":
img = img.convert("RGBA")
elif img.mode not in ("RGB", "RGBA"):
img = img.convert("RGB")
# 创建缩略图(保持宽高比)
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
# 保存为 WebP 格式
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6)
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
except Exception as e:
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
# 生成失败时不创建缓存文件,下次会重试
raise
return cache_path
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
"""
清理孤立的缩略图缓存(原图已不存在的缩略图)
Returns:
(清理数量, 保留数量)
"""
if not THUMBNAIL_CACHE_DIR.exists():
return 0, 0
# 获取所有表情包的哈希值
valid_hashes = set()
for emoji in Emoji.select(Emoji.emoji_hash):
valid_hashes.add(emoji.emoji_hash)
cleaned = 0
kept = 0
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
file_hash = cache_file.stem
if file_hash not in valid_hashes:
@@ -157,12 +157,13 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
else:
kept += 1
if cleaned > 0:
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept}")
return cleaned, kept
# 模块级别的类型别名(解决 B008 ruff 错误)
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
@@ -365,7 +366,9 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_emoji_detail(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表情包详细信息
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表情包(只更新提供的字段)
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表情包
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def register_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
注册表情包(快捷操作)
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def ban_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
禁用表情包(快捷操作)
@@ -633,7 +647,7 @@ async def get_emoji_thumbnail(
Returns:
表情包缩略图WebP 格式)或原图
Features:
- 懒加载:首次请求时生成缩略图
- 缓存:后续请求直接返回缓存
@@ -643,7 +657,7 @@ async def get_emoji_thumbnail(
try:
token_manager = get_token_manager()
is_valid = False
# 1. 优先使用 Cookie
if maibot_session and token_manager.verify_token(maibot_session):
is_valid = True
@@ -655,7 +669,7 @@ async def get_emoji_thumbnail(
auth_token = authorization.replace("Bearer ", "")
if token_manager.verify_token(auth_token):
is_valid = True
if not is_valid:
raise HTTPException(status_code=401, detail="Token 无效或已过期")
@@ -680,35 +694,27 @@ async def get_emoji_thumbnail(
}
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(
path=emoji.full_path,
media_type=media_type,
filename=f"{emoji.emoji_hash}.{emoji.format}"
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
)
# 尝试获取或生成缩略图
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
# 检查缓存是否存在
if cache_path.exists():
# 缓存命中,直接返回
return FileResponse(
path=str(cache_path),
media_type="image/webp",
filename=f"{emoji.emoji_hash}_thumb.webp"
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
)
# 缓存未命中,触发后台生成并返回 202
with _generating_lock:
if emoji.emoji_hash not in _generating_thumbnails:
# 标记为正在生成
_generating_thumbnails.add(emoji.emoji_hash)
# 提交到线程池后台生成
_thumbnail_executor.submit(
_background_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
# 返回 202 Accepted告诉前端缩略图正在生成中
return JSONResponse(
status_code=202,
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
},
headers={
"Retry-After": "1", # 建议 1 秒后重试
}
},
)
except HTTPException:
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_emojis(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表情包
@@ -1079,7 +1089,7 @@ async def batch_upload_emoji(
class ThumbnailCacheStatsResponse(BaseModel):
"""缩略图缓存统计响应"""
success: bool
cache_dir: str
total_count: int
@@ -1090,7 +1100,7 @@ class ThumbnailCacheStatsResponse(BaseModel):
class ThumbnailCleanupResponse(BaseModel):
"""缩略图清理响应"""
success: bool
message: str
cleaned_count: int
@@ -1099,7 +1109,7 @@ class ThumbnailCleanupResponse(BaseModel):
class ThumbnailPreheatResponse(BaseModel):
"""缩略图预热响应"""
success: bool
message: str
generated_count: int
@@ -1114,27 +1124,27 @@ async def get_thumbnail_cache_stats(
):
"""
获取缩略图缓存统计信息
Returns:
缓存目录、缓存数量、总大小、覆盖率等统计信息
"""
try:
verify_auth_token(maibot_session, authorization)
_ensure_thumbnail_cache_dir()
# 统计缓存文件
cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp"))
total_count = len(cache_files)
total_size = sum(f.stat().st_size for f in cache_files)
total_size_mb = round(total_size / (1024 * 1024), 2)
# 统计表情包总数
emoji_count = Emoji.select().count()
# 计算覆盖率
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
return ThumbnailCacheStatsResponse(
success=True,
cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()),
@@ -1143,7 +1153,7 @@ async def get_thumbnail_cache_stats(
emoji_count=emoji_count,
coverage_percent=coverage_percent,
)
except HTTPException:
raise
except Exception as e:
@@ -1158,22 +1168,22 @@ async def cleanup_thumbnail_cache(
):
"""
清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图)
Returns:
清理结果
"""
try:
verify_auth_token(maibot_session, authorization)
cleaned, kept = cleanup_orphaned_thumbnails()
return ThumbnailCleanupResponse(
success=True,
message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存",
cleaned_count=cleaned,
kept_count=kept,
)
except HTTPException:
raise
except Exception as e:
@@ -1189,20 +1199,20 @@ async def preheat_thumbnail_cache(
):
"""
预热缩略图缓存(提前生成未缓存的缩略图)
优先处理使用次数高的表情包
Args:
limit: 最多预热数量 (1-1000)
Returns:
预热结果
"""
try:
verify_auth_token(maibot_session, authorization)
_ensure_thumbnail_cache_dir()
# 获取使用次数最高的表情包(未缓存的优先)
emojis = (
Emoji.select()
@@ -1210,41 +1220,36 @@ async def preheat_thumbnail_cache(
.order_by(Emoji.usage_count.desc())
.limit(limit * 2) # 多查一些,因为有些可能已缓存
)
generated = 0
skipped = 0
failed = 0
for emoji in emojis:
if generated >= limit:
break
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
# 已缓存,跳过
if cache_path.exists():
skipped += 1
continue
# 原文件不存在,跳过
if not os.path.exists(emoji.full_path):
failed += 1
continue
try:
# 使用线程池异步生成缩略图,避免阻塞事件循环
loop = asyncio.get_event_loop()
await loop.run_in_executor(
_thumbnail_executor,
_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
generated += 1
except Exception as e:
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
failed += 1
return ThumbnailPreheatResponse(
success=True,
message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed}",
@@ -1252,7 +1257,7 @@ async def preheat_thumbnail_cache(
skipped_count=skipped,
failed_count=failed,
)
except HTTPException:
raise
except Exception as e:
@@ -1267,13 +1272,13 @@ async def clear_all_thumbnail_cache(
):
"""
清空所有缩略图缓存(下次访问时会重新生成)
Returns:
清理结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not THUMBNAIL_CACHE_DIR.exists():
return ThumbnailCleanupResponse(
success=True,
@@ -1281,7 +1286,7 @@ async def clear_all_thumbnail_cache(
cleaned_count=0,
kept_count=0,
)
cleaned = 0
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
try:
@@ -1289,16 +1294,16 @@ async def clear_all_thumbnail_cache(
cleaned += 1
except Exception as e:
logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}")
logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件")
return ThumbnailCleanupResponse(
success=True,
message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件",
cleaned_count=cleaned,
kept_count=0,
)
except HTTPException:
raise
except Exception as e:

View File

@@ -1,7 +1,7 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from pydantic import BaseModel, NonNegativeFloat
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams
@@ -21,7 +21,6 @@ class ExpressionResponse(BaseModel):
situation: str
style: str
context: Optional[str]
up_content: Optional[str]
last_active_time: float
chat_id: str
create_date: Optional[float]
@@ -49,8 +48,7 @@ class ExpressionCreateRequest(BaseModel):
situation: str
style: str
context: Optional[str] = None
up_content: Optional[str] = None
context: Optional[str] = NonNegativeFloat
chat_id: str
@@ -60,7 +58,6 @@ class ExpressionUpdateRequest(BaseModel):
situation: Optional[str] = None
style: Optional[str] = None
context: Optional[str] = None
up_content: Optional[str] = None
chat_id: Optional[str] = None
@@ -102,7 +99,6 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
situation=expression.situation,
style=expression.style,
context=expression.context,
up_content=expression.up_content,
last_active_time=expression.last_active_time,
chat_id=expression.chat_id,
create_date=expression.create_date,
@@ -260,7 +256,9 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式详细信息
@@ -289,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
创建新的表达方式
@@ -310,7 +312,6 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
situation=request.situation,
style=request.style,
context=request.context,
up_content=request.up_content,
chat_id=request.chat_id,
last_active_time=current_time,
create_date=current_time,
@@ -331,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表达方式(只更新提供的字段)
@@ -381,7 +385,9 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表达方式
@@ -424,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表达方式
@@ -465,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
@router.get("/stats/summary")
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式统计数据

View File

@@ -24,7 +24,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
"""
if not chat_id_str:
return []
try:
# 尝试解析为 JSON
parsed = json.loads(chat_id_str)
@@ -49,10 +49,10 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
"""
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids:
return chat_id_str
# 查询所有 stream_id 对应的名称
names = []
for stream_id in stream_ids:
@@ -62,7 +62,7 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
else:
# 如果没找到,显示截断的 stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
return ", ".join(names) if names else chat_id_str
@@ -187,7 +187,7 @@ def jargon_to_dict(jargon: Jargon) -> dict:
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None
return {
"id": jargon.id,
"content": jargon.content,
@@ -277,17 +277,13 @@ async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
# 获取所有不同的 chat_id
chat_ids = (
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
)
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
# 用于按 stream_id 去重
seen_stream_ids: set[str] = set()
for chat_id in chat_id_list:
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
@@ -346,12 +342,7 @@ async def get_jargon_stats():
complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量
chat_count = (
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
.count()
)
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
# 按聊天统计 TOP 5
top_chats = (
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
# 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none(
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
)
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = (
Jargon.update(is_jargon=is_jargon)
.where(Jargon.id.in_(ids))
.execute()
)
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")

View File

@@ -1,15 +1,24 @@
"""知识库图谱可视化 API 路由"""
from typing import List, Optional
from fastapi import APIRouter, Query
from fastapi import APIRouter, Query, Depends, Cookie, Header
from pydantic import BaseModel
import logging
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class KnowledgeNode(BaseModel):
"""知识节点"""
@@ -113,6 +122,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
_auth: bool = Depends(require_auth),
):
"""获取知识图谱(限制节点数量)
@@ -199,7 +209,7 @@ async def get_knowledge_graph(
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats():
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
"""获取知识库统计信息
Returns:
@@ -248,7 +258,7 @@ async def get_knowledge_stats():
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1)):
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
"""搜索知识节点
Args:

View File

@@ -1,10 +1,12 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Optional
import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.logs_ws")
router = APIRouter()
@@ -73,14 +75,48 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket):
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:

View File

@@ -6,18 +6,27 @@
import os
import httpx
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
from typing import Optional
import tomlkit
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui")
router = APIRouter(prefix="/models", tags=["models"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# 模型获取器配置
MODEL_FETCHER_CONFIG = {
# OpenAI 兼容格式的提供商
@@ -184,6 +193,7 @@ async def get_provider_models(
provider_name: str = Query(..., description="提供商名称"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
_auth: bool = Depends(require_auth),
):
"""
获取指定提供商的可用模型列表
@@ -228,6 +238,7 @@ async def get_models_by_url(
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
_auth: bool = Depends(require_auth),
):
"""
通过 URL 直接获取模型列表(用于自定义提供商)
@@ -251,6 +262,7 @@ async def get_models_by_url(
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
_auth: bool = Depends(require_auth),
):
"""
测试提供商连接状态
@@ -337,6 +349,7 @@ async def test_provider_connection(
@router.post("/test-connection-by-name")
async def test_provider_connection_by_name(
provider_name: str = Query(..., description="提供商名称"),
_auth: bool = Depends(require_auth),
):
"""
通过提供商名称测试连接(从配置文件读取信息)

View File

@@ -200,7 +200,9 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取人物详细信息
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新人物信息(只更新提供的字段)
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除人物信息
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除人物信息

View File

@@ -1,10 +1,12 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set, Dict, Any
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Dict, Any, Optional
import json
import asyncio
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.plugin_progress")
@@ -89,14 +91,48 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket):
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/plugin-progress?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态

File diff suppressed because it is too large Load Diff

245
src/webui/rate_limiter.py Normal file
View File

@@ -0,0 +1,245 @@
"""
WebUI 请求频率限制模块
防止暴力破解和 API 滥用
"""
import time
from collections import defaultdict
from typing import Dict, Tuple, Optional
from fastapi import Request, HTTPException
from src.common.logger import get_logger
logger = get_logger("webui.rate_limiter")
class RateLimiter:
"""
简单的内存请求频率限制器
使用滑动窗口算法实现
"""
def __init__(self):
# 存储格式: {key: [(timestamp, count), ...]}
self._requests: Dict[str, list] = defaultdict(list)
# 被封禁的 IP: {ip: unblock_timestamp}
self._blocked: Dict[str, float] = {}
def _get_client_ip(self, request: Request) -> str:
"""获取客户端 IP 地址"""
# 检查代理头
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
# 取第一个 IP最原始的客户端
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直接连接的客户端
if request.client:
return request.client.host
return "unknown"
def _cleanup_old_requests(self, key: str, window_seconds: int):
"""清理过期的请求记录"""
now = time.time()
cutoff = now - window_seconds
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
def _cleanup_expired_blocks(self):
"""清理过期的封禁"""
now = time.time()
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
for ip in expired:
del self._blocked[ip]
logger.info(f"🔓 IP {ip} 封禁已解除")
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
"""
检查 IP 是否被封禁
Returns:
(是否被封禁, 剩余封禁秒数)
"""
self._cleanup_expired_blocks()
ip = self._get_client_ip(request)
if ip in self._blocked:
remaining = int(self._blocked[ip] - time.time())
return True, max(0, remaining)
return False, None
def check_rate_limit(
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
) -> Tuple[bool, int]:
"""
检查请求是否超过频率限制
Args:
request: FastAPI Request 对象
max_requests: 窗口期内允许的最大请求数
window_seconds: 窗口时间(秒)
key_suffix: 键后缀,用于区分不同的限制规则
Returns:
(是否允许, 剩余请求数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:{key_suffix}" if key_suffix else ip
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前窗口内的请求数
current_count = sum(count for _, count in self._requests[key])
if current_count >= max_requests:
return False, 0
# 记录新请求
now = time.time()
self._requests[key].append((now, 1))
remaining = max_requests - current_count - 1
return True, remaining
def block_ip(self, request: Request, duration_seconds: int):
"""
封禁 IP
Args:
request: FastAPI Request 对象
duration_seconds: 封禁时长(秒)
"""
ip = self._get_client_ip(request)
self._blocked[ip] = time.time() + duration_seconds
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds}")
def record_failed_attempt(
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
) -> Tuple[bool, int]:
"""
记录失败尝试(如登录失败)
如果在窗口期内失败次数过多,自动封禁 IP
Args:
request: FastAPI Request 对象
max_failures: 允许的最大失败次数
window_seconds: 统计窗口(秒)
block_duration: 封禁时长(秒)
Returns:
(是否被封禁, 剩余尝试次数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前失败次数
current_failures = sum(count for _, count in self._requests[key])
# 记录本次失败
now = time.time()
self._requests[key].append((now, 1))
current_failures += 1
remaining = max_failures - current_failures
# 检查是否需要封禁
if current_failures >= max_failures:
self.block_ip(request, block_duration)
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
return True, 0
if current_failures >= max_failures - 2:
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures}")
return False, max(0, remaining)
def reset_failures(self, request: Request):
"""
重置失败计数(认证成功后调用)
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
if key in self._requests:
del self._requests[key]
# 全局单例
_rate_limiter: Optional[RateLimiter] = None
def get_rate_limiter() -> RateLimiter:
"""获取 RateLimiter 单例"""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter()
return _rate_limiter
async def check_auth_rate_limit(request: Request):
"""
认证接口的频率限制依赖
规则:
- 每个 IP 每分钟最多 10 次认证请求
- 连续失败 5 次后封禁 10 分钟
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, remaining = limiter.check_rate_limit(
request,
max_requests=10, # 每分钟 10 次
window_seconds=60,
key_suffix="auth",
)
if not allowed:
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
async def check_api_rate_limit(request: Request):
"""
普通 API 的频率限制依赖
规则:每个 IP 每分钟最多 100 次请求
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, _ = limiter.check_rate_limit(
request,
max_requests=100, # 每分钟 100 次
window_seconds=60,
key_suffix="api",
)
if not allowed:
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})

View File

@@ -7,10 +7,12 @@
import os
import time
from datetime import datetime
from fastapi import APIRouter, HTTPException
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
router = APIRouter(prefix="/system", tags=["system"])
logger = get_logger("webui_system")
@@ -19,6 +21,14 @@ logger = get_logger("webui_system")
_start_time = time.time()
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class RestartResponse(BaseModel):
"""重启响应"""
@@ -36,7 +46,7 @@ class StatusResponse(BaseModel):
@router.post("/restart", response_model=RestartResponse)
async def restart_maibot():
async def restart_maibot(_auth: bool = Depends(require_auth)):
"""
重启麦麦主程序
@@ -67,7 +77,7 @@ async def restart_maibot():
@router.get("/status", response_model=StatusResponse)
async def get_maibot_status():
async def get_maibot_status(_auth: bool = Depends(require_auth)):
"""
获取麦麦运行状态
@@ -90,7 +100,7 @@ async def get_maibot_status():
@router.post("/reload-config")
async def reload_config():
async def reload_config(_auth: bool = Depends(require_auth)):
"""
热重载配置(不重启进程)

View File

@@ -1,11 +1,12 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
from pydantic import BaseModel, Field
from typing import Optional
from src.common.logger import get_logger
from .token_manager import get_token_manager
from .auth import set_auth_cookie, clear_auth_cookie
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
from .config_routes import router as config_router
from .statistics_routes import router as statistics_router
from .person_routes import router as person_router
@@ -16,6 +17,7 @@ from .plugin_routes import router as plugin_router
from .plugin_progress_ws import get_progress_router
from .routers.system import router as system_router
from .model_routes import router as model_router
from .ws_auth import router as ws_auth_router
logger = get_logger("webui.api")
@@ -42,6 +44,8 @@ router.include_router(get_progress_router())
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
class TokenVerifyRequest(BaseModel):
@@ -107,12 +111,18 @@ async def health_check():
@router.post("/auth/verify", response_model=TokenVerifyResponse)
async def verify_token(request: TokenVerifyRequest, response: Response):
async def verify_token(
request_body: TokenVerifyRequest,
request: Request,
response: Response,
_rate_limit: None = Depends(check_auth_rate_limit),
):
"""
验证访问令牌,验证成功后设置 HttpOnly Cookie
Args:
request: 包含 token 的验证请求
request_body: 包含 token 的验证请求
request: FastAPI Request 对象(用于获取客户端 IP
response: FastAPI Response 对象
Returns:
@@ -120,16 +130,37 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
"""
try:
token_manager = get_token_manager()
is_valid = token_manager.verify_token(request.token)
rate_limiter = get_rate_limiter()
is_valid = token_manager.verify_token(request_body.token)
if is_valid:
# 设置 HttpOnly Cookie
set_auth_cookie(response, request.token)
# 认证成功,重置失败计数
rate_limiter.reset_failures(request)
# 设置 HttpOnly Cookie传入 request 以检测协议)
set_auth_cookie(response, request_body.token, request)
# 同时返回首次配置状态,避免额外请求
is_first_setup = token_manager.is_first_setup()
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
else:
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
# 记录失败尝试
blocked, remaining = rate_limiter.record_failed_attempt(
request,
max_failures=5, # 5 次失败
window_seconds=300, # 5 分钟窗口
block_duration=600, # 封禁 10 分钟
)
if blocked:
raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
message = "Token 无效或已过期"
if remaining <= 2:
message += f"(剩余 {remaining} 次尝试机会)"
return TokenVerifyResponse(valid=False, message=message)
except HTTPException:
raise
except Exception as e:
logger.error(f"Token 验证失败: {e}")
raise HTTPException(status_code=500, detail="Token 验证失败") from e
@@ -139,10 +170,10 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
async def logout(response: Response):
"""
登出并清除认证 Cookie
Args:
response: FastAPI Response 对象
Returns:
登出结果
"""
@@ -158,29 +189,39 @@ async def check_auth_status(
):
"""
检查当前认证状态(用于前端判断是否已登录)
Returns:
认证状态
"""
try:
token = None
# 记录请求信息用于调试
logger.debug(f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}")
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
logger.debug("使用 Cookie 中的 token")
# 其次从 Header 获取
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
logger.debug("使用 Header 中的 token")
if not token:
logger.debug("未找到 token返回未认证")
return {"authenticated": False}
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_valid = token_manager.verify_token(token)
logger.debug(f"Token 验证结果: {is_valid}")
if is_valid:
return {"authenticated": True}
else:
return {"authenticated": False}
except Exception:
except Exception as e:
logger.error(f"认证检查失败: {e}", exc_info=True)
return {"authenticated": False}
@@ -211,7 +252,7 @@ async def update_token(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -222,10 +263,10 @@ async def update_token(
# 更新 token
success, message = token_manager.update_token(request.new_token)
# 如果更新成功,更新 Cookie
# 如果更新成功,清除 Cookie,要求用户重新登录
if success:
set_auth_cookie(response, request.new_token)
clear_auth_cookie(response)
return TokenUpdateResponse(success=success, message=message)
except HTTPException:
@@ -263,7 +304,7 @@ async def regenerate_token(
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
@@ -271,9 +312,9 @@ async def regenerate_token(
# 重新生成 token
new_token = token_manager.regenerate_token()
# 更新 Cookie
set_auth_cookie(response, new_token)
# 清除 Cookie,要求用户重新登录
clear_auth_cookie(response)
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
except HTTPException:
@@ -306,7 +347,7 @@ async def get_setup_status(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -349,7 +390,7 @@ async def complete_setup(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -392,7 +433,7 @@ async def reset_setup(
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")

View File

@@ -1,19 +1,28 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.statistics")
router = APIRouter(prefix="/statistics", tags=["statistics"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
@@ -58,7 +67,7 @@ class DashboardData(BaseModel):
@router.get("/dashboard", response_model=DashboardData)
async def get_dashboard_data(hours: int = 24):
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取仪表盘统计数据
@@ -275,7 +284,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
@router.get("/summary")
async def get_summary(hours: int = 24):
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取统计摘要
@@ -293,7 +302,7 @@ async def get_summary(hours: int = 24):
@router.get("/models")
async def get_model_stats(hours: int = 24):
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取模型统计

View File

@@ -160,13 +160,29 @@ class TokenManager:
def regenerate_token(self) -> str:
"""
重新生成 token
重新生成 token(保留 first_setup_completed 状态)
Returns:
str: 新生成的 token
"""
logger.info("正在重新生成 WebUI Token...")
return self._create_new_token()
# 生成新的 64 位十六进制字符串
new_token = secrets.token_hex(32)
# 加载现有配置,保留 first_setup_completed 状态
config = self._load_config()
old_token = config.get("access_token", "")[:8] if config.get("access_token") else ""
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True表示已完成配置
config["access_token"] = new_token
config["updated_at"] = self._get_current_timestamp()
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
self._save_config(config)
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
return new_token
def _validate_token_format(self, token: str) -> bool:
"""

View File

@@ -1,6 +1,5 @@
"""独立的 WebUI 服务器 - 默认运行在 127.0.0.1:8001"""
import os
import asyncio
import mimetypes
from pathlib import Path
@@ -22,6 +21,9 @@ class WebUIServer:
self.app = FastAPI(title="MaiBot WebUI")
self._server = None
# 配置防爬虫中间件需要在CORS之前注册
self._setup_anti_crawler()
# 配置 CORS支持开发环境跨域请求
self._setup_cors()
@@ -32,6 +34,9 @@ class WebUIServer:
self._register_api_routes()
self._setup_static_files()
# 注册robots.txt路由
self._setup_robots_txt()
def _setup_cors(self):
"""配置 CORS 中间件"""
# 开发环境需要允许前端开发服务器的跨域请求
@@ -40,12 +45,21 @@ class WebUIServer:
allow_origins=[
"http://localhost:5173", # Vite 开发服务器
"http://127.0.0.1:5173",
"http://localhost:7999", # 前端开发服务器备用端口
"http://127.0.0.1:7999",
"http://localhost:8001", # 生产环境
"http://127.0.0.1:8001",
],
allow_credentials=True, # 允许携带 Cookie
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
allow_headers=[
"Content-Type",
"Authorization",
"Accept",
"Origin",
"X-Requested-With",
], # 明确指定允许的头
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
)
logger.debug("✅ CORS 中间件已配置")
@@ -89,43 +103,77 @@ class WebUIServer:
"""服务单页应用 - 只处理非 API 请求"""
# 如果是根路径,直接返回 index.html
if not full_path or full_path == "/":
return FileResponse(static_path / "index.html", media_type="text/html")
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 检查是否是静态文件
file_path = static_path / full_path
if file_path.is_file() and file_path.exists():
# 自动检测 MIME 类型
media_type = mimetypes.guess_type(str(file_path))[0]
return FileResponse(file_path, media_type=media_type)
response = FileResponse(file_path, media_type=media_type)
# HTML 文件添加防索引头
if str(file_path).endswith(".html"):
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 其他路径返回 index.htmlSPA 路由)
return FileResponse(static_path / "index.html", media_type="text/html")
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
def _setup_anti_crawler(self):
"""配置防爬虫中间件"""
try:
from src.webui.anti_crawler import AntiCrawlerMiddleware
from src.config.config import global_config
# 从配置读取防爬虫模式
anti_crawler_mode = global_config.webui.anti_crawler_mode
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
# 我们需要在CORS之前注册这样防爬虫检查会在CORS之前执行
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
except Exception as e:
logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
def _setup_robots_txt(self):
"""设置robots.txt路由"""
try:
from src.webui.anti_crawler import create_robots_txt_response
@self.app.get("/robots.txt", include_in_schema=False)
async def robots_txt():
"""返回robots.txt禁止所有爬虫"""
return create_robots_txt_response()
logger.debug("✅ robots.txt 路由已注册")
except Exception as e:
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
def _register_api_routes(self):
"""注册所有 WebUI API 路由"""
try:
# 导入所有 WebUI 路由
from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router
logger.info("开始导入 knowledge_routes...")
from src.webui.knowledge_routes import router as knowledge_router
logger.info("knowledge_routes 导入成功")
# 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router
logger.info("chat_routes 导入成功")
# 注册路由
self.app.include_router(webui_router)
self.app.include_router(logs_router)
self.app.include_router(knowledge_router)
self.app.include_router(chat_router)
logger.info(f"knowledge_router 路由前缀: {knowledge_router.prefix}")
logger.info("✅ WebUI API 路由已注册")
except Exception as e:
@@ -133,6 +181,16 @@ class WebUIServer:
async def start(self):
"""启动服务器"""
# 预先检查端口是否可用
if not self._check_port_available():
error_msg = f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用"
logger.error(error_msg)
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{self.port}")
logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{self.port}")
raise OSError(f"端口 {self.port} 已被占用,无法启动 WebUI 服务器")
config = Config(
app=self.app,
host=self.host,
@@ -143,15 +201,59 @@ class WebUIServer:
self._server = UvicornServer(config=config)
logger.info("🌐 WebUI 服务器启动中...")
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
if self.host == "0.0.0.0":
logger.info(f"本机访问请使用 http://localhost:{self.port}")
# 根据地址类型显示正确的访问地址
if ':' in self.host:
# IPv6 地址需要用方括号包裹
logger.info(f"🌐 访问地址: http://[{self.host}]:{self.port}")
if self.host == "::":
logger.info(f"💡 IPv6 本机访问: http://[::1]:{self.port}")
logger.info(f"💡 IPv4 本机访问: http://127.0.0.1:{self.port}")
elif self.host == "::1":
logger.info("💡 仅支持 IPv6 本地访问")
else:
# IPv4 地址
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
if self.host == "0.0.0.0":
logger.info(f"💡 本机访问: http://localhost:{self.port} 或 http://127.0.0.1:{self.port}")
try:
await self._server.serve()
except Exception as e:
logger.error(f"❌ WebUI 服务器运行错误: {e}")
except OSError as e:
# 处理端口绑定相关的错误
if "address already in use" in str(e).lower() or e.errno in (98, 10048): # 98: Linux, 10048: Windows
logger.error(f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用")
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
else:
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
raise
except Exception as e:
logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True)
raise
def _check_port_available(self) -> bool:
"""检查端口是否可用(支持 IPv4 和 IPv6"""
import socket
# 判断使用 IPv4 还是 IPv6
if ':' in self.host:
# IPv6 地址
family = socket.AF_INET6
test_host = self.host if self.host != "::" else "::1"
else:
# IPv4 地址
family = socket.AF_INET
test_host = self.host if self.host != "0.0.0.0" else "127.0.0.1"
try:
with socket.socket(family, socket.SOCK_STREAM) as s:
s.settimeout(1)
# 尝试绑定端口
s.bind((test_host, self.port))
return True
except OSError:
return False
async def shutdown(self):
"""关闭服务器"""
@@ -177,7 +279,8 @@ def get_webui_server() -> WebUIServer:
"""获取全局 WebUI 服务器实例"""
global _webui_server
if _webui_server is None:
# 从环境变量读取配置
# 从环境变量读取
import os
host = os.getenv("WEBUI_HOST", "127.0.0.1")
port = int(os.getenv("WEBUI_PORT", "8001"))
_webui_server = WebUIServer(host=host, port=port)

114
src/webui/ws_auth.py Normal file
View File

@@ -0,0 +1,114 @@
"""WebSocket 认证模块
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
"""
from fastapi import APIRouter, Cookie, Header
from typing import Optional
import secrets
import time
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
logger = get_logger("webui.ws_auth")
router = APIRouter()
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
_WS_TOKEN_EXPIRE_SECONDS = 60
def _cleanup_expired_ws_tokens():
"""清理过期的临时 token"""
now = time.time()
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
for t in expired:
del _ws_temp_tokens[t]
def generate_ws_token(session_token: str) -> str:
"""生成 WebSocket 临时 token
Args:
session_token: 原始的 session token
Returns:
临时 token 字符串
"""
_cleanup_expired_ws_tokens()
temp_token = secrets.token_urlsafe(32)
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
return temp_token
def verify_ws_token(temp_token: str) -> bool:
"""验证并消费 WebSocket 临时 token一次性使用
Args:
temp_token: 临时 token
Returns:
验证是否通过
"""
_cleanup_expired_ws_tokens()
if temp_token not in _ws_temp_tokens:
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
return False
expire_time, session_token = _ws_temp_tokens[temp_token]
if time.time() > expire_time:
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
return False
# 验证原始 session token 仍然有效
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
return False
# 消费 token一次性使用
del _ws_temp_tokens[temp_token]
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
return True
@router.get("/ws-token")
async def get_ws_token(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取 WebSocket 连接用的临时 token
此端点验证当前会话的 Cookie 或 Authorization header
然后返回一个临时 token 用于 WebSocket 握手认证。
临时 token 有效期 60 秒,且只能使用一次。
注意:在未认证时返回 200 状态码但 success=False避免前端因 401 刷新页面。
"""
# 获取当前 session token
session_token = None
if maibot_session:
session_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
session_token = authorization.replace("Bearer ", "")
if not session_token:
# 返回 200 但 success=False避免前端因 401 刷新页面
# 这在登录页面是正常情况,不应该触发错误处理
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
# 验证 session token
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
# 同样返回 200 但 success=False避免前端刷新
logger.debug("ws-token 请求:认证已过期")
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
# 生成临时 WebSocket token
ws_token = generate_ws_token(session_token)
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}