Merge branch 'dev' of https://github.com/Mai-with-u/MaiBot into dev
This commit is contained in:
4
bot.py
4
bot.py
@@ -41,6 +41,7 @@ logger = get_logger("main")
|
|||||||
# 定义重启退出码
|
# 定义重启退出码
|
||||||
RESTART_EXIT_CODE = 42
|
RESTART_EXIT_CODE = 42
|
||||||
|
|
||||||
|
|
||||||
def run_runner_process():
|
def run_runner_process():
|
||||||
"""
|
"""
|
||||||
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
||||||
@@ -68,7 +69,7 @@ def run_runner_process():
|
|||||||
|
|
||||||
if return_code == RESTART_EXIT_CODE:
|
if return_code == RESTART_EXIT_CODE:
|
||||||
logger.info("检测到重启请求 (退出码 42),正在重启...")
|
logger.info("检测到重启请求 (退出码 42),正在重启...")
|
||||||
time.sleep(1) # 稍作等待
|
time.sleep(1) # 稍作等待
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.info(f"程序已退出 (退出码 {return_code})")
|
logger.info(f"程序已退出 (退出码 {return_code})")
|
||||||
@@ -87,6 +88,7 @@ def run_runner_process():
|
|||||||
process.kill()
|
process.kill()
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
# 检查是否是 Worker 进程
|
# 检查是否是 Worker 进程
|
||||||
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
|
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
|
||||||
# 此时应该作为 Runner 运行。
|
# 此时应该作为 Runner 运行。
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ConversionResult:
|
class ConversionResult:
|
||||||
"""转换结果"""
|
"""转换结果"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
servers: List[Dict[str, Any]] = field(default_factory=list)
|
servers: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
errors: List[str] = field(default_factory=list)
|
errors: List[str] = field(default_factory=list)
|
||||||
@@ -271,11 +272,7 @@ class ConfigConverter:
|
|||||||
return name, result
|
return name, result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_claude_format(
|
def from_claude_format(cls, config: Dict[str, Any], existing_names: Optional[set] = None) -> ConversionResult:
|
||||||
cls,
|
|
||||||
config: Dict[str, Any],
|
|
||||||
existing_names: Optional[set] = None
|
|
||||||
) -> ConversionResult:
|
|
||||||
"""从 Claude Desktop 格式转换为 MaiBot 格式
|
"""从 Claude Desktop 格式转换为 MaiBot 格式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -355,11 +352,7 @@ class ConfigConverter:
|
|||||||
return {"mcpServers": mcp_servers}
|
return {"mcpServers": mcp_servers}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def import_from_string(
|
def import_from_string(cls, json_str: str, existing_names: Optional[set] = None) -> ConversionResult:
|
||||||
cls,
|
|
||||||
json_str: str,
|
|
||||||
existing_names: Optional[set] = None
|
|
||||||
) -> ConversionResult:
|
|
||||||
"""从 JSON 字符串导入配置
|
"""从 JSON 字符串导入配置
|
||||||
|
|
||||||
自动检测格式并转换为 MaiBot 格式
|
自动检测格式并转换为 MaiBot 格式
|
||||||
@@ -422,12 +415,7 @@ class ConfigConverter:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def export_to_string(
|
def export_to_string(cls, servers: List[Dict[str, Any]], format_type: str = "claude", pretty: bool = True) -> str:
|
||||||
cls,
|
|
||||||
servers: List[Dict[str, Any]],
|
|
||||||
format_type: str = "claude",
|
|
||||||
pretty: bool = True
|
|
||||||
) -> str:
|
|
||||||
"""导出配置为 JSON 字符串
|
"""导出配置为 JSON 字符串
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -34,28 +34,31 @@ from enum import Enum
|
|||||||
# 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging
|
# 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging
|
||||||
try:
|
try:
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("mcp_client")
|
logger = get_logger("mcp_client")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallback: 使用标准 logging
|
# Fallback: 使用标准 logging
|
||||||
logger = logging.getLogger("mcp_client")
|
logger = logging.getLogger("mcp_client")
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s'))
|
handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
class TransportType(Enum):
|
class TransportType(Enum):
|
||||||
"""MCP 传输类型"""
|
"""MCP 传输类型"""
|
||||||
STDIO = "stdio" # 本地进程通信
|
|
||||||
SSE = "sse" # Server-Sent Events (旧版 HTTP)
|
STDIO = "stdio" # 本地进程通信
|
||||||
HTTP = "http" # HTTP Streamable (新版,推荐)
|
SSE = "sse" # Server-Sent Events (旧版 HTTP)
|
||||||
|
HTTP = "http" # HTTP Streamable (新版,推荐)
|
||||||
STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名
|
STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MCPToolInfo:
|
class MCPToolInfo:
|
||||||
"""MCP 工具信息"""
|
"""MCP 工具信息"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
input_schema: Dict[str, Any]
|
input_schema: Dict[str, Any]
|
||||||
@@ -65,6 +68,7 @@ class MCPToolInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MCPResourceInfo:
|
class MCPResourceInfo:
|
||||||
"""MCP 资源信息"""
|
"""MCP 资源信息"""
|
||||||
|
|
||||||
uri: str
|
uri: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
@@ -75,6 +79,7 @@ class MCPResourceInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MCPPromptInfo:
|
class MCPPromptInfo:
|
||||||
"""MCP 提示模板信息"""
|
"""MCP 提示模板信息"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
arguments: List[Dict[str, Any]] # [{name, description, required}]
|
arguments: List[Dict[str, Any]] # [{name, description, required}]
|
||||||
@@ -84,6 +89,7 @@ class MCPPromptInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MCPServerConfig:
|
class MCPServerConfig:
|
||||||
"""MCP 服务器配置"""
|
"""MCP 服务器配置"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
transport: TransportType = TransportType.STDIO
|
transport: TransportType = TransportType.STDIO
|
||||||
@@ -99,6 +105,7 @@ class MCPServerConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MCPCallResult:
|
class MCPCallResult:
|
||||||
"""MCP 工具调用结果"""
|
"""MCP 工具调用结果"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
content: Any
|
content: Any
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
@@ -108,8 +115,9 @@ class MCPCallResult:
|
|||||||
|
|
||||||
class CircuitState(Enum):
|
class CircuitState(Enum):
|
||||||
"""断路器状态"""
|
"""断路器状态"""
|
||||||
CLOSED = "closed" # 正常状态,允许请求
|
|
||||||
OPEN = "open" # 熔断状态,拒绝请求
|
CLOSED = "closed" # 正常状态,允许请求
|
||||||
|
OPEN = "open" # 熔断状态,拒绝请求
|
||||||
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
|
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
|
||||||
|
|
||||||
|
|
||||||
@@ -125,9 +133,9 @@ class CircuitBreaker:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 配置
|
# 配置
|
||||||
failure_threshold: int = 5 # 连续失败多少次后熔断
|
failure_threshold: int = 5 # 连续失败多少次后熔断
|
||||||
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
|
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
|
||||||
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
|
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
|
||||||
|
|
||||||
# 状态
|
# 状态
|
||||||
state: CircuitState = field(default=CircuitState.CLOSED)
|
state: CircuitState = field(default=CircuitState.CLOSED)
|
||||||
@@ -232,6 +240,7 @@ class CircuitBreaker:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ToolCallStats:
|
class ToolCallStats:
|
||||||
"""工具调用统计"""
|
"""工具调用统计"""
|
||||||
|
|
||||||
tool_key: str
|
tool_key: str
|
||||||
total_calls: int = 0
|
total_calls: int = 0
|
||||||
success_calls: int = 0
|
success_calls: int = 0
|
||||||
@@ -282,6 +291,7 @@ class ToolCallStats:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ServerStats:
|
class ServerStats:
|
||||||
"""服务器统计"""
|
"""服务器统计"""
|
||||||
|
|
||||||
server_name: str
|
server_name: str
|
||||||
connect_count: int = 0 # 连接次数
|
connect_count: int = 0 # 连接次数
|
||||||
disconnect_count: int = 0 # 断开次数
|
disconnect_count: int = 0 # 断开次数
|
||||||
@@ -442,9 +452,7 @@ class MCPClientSession:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
server_params = StdioServerParameters(
|
server_params = StdioServerParameters(
|
||||||
command=self.config.command,
|
command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None
|
||||||
args=self.config.args,
|
|
||||||
env=self.config.env if self.config.env else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._stdio_context = stdio_client(server_params)
|
self._stdio_context = stdio_client(server_params)
|
||||||
@@ -506,6 +514,7 @@ class MCPClientSession:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
|
logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
return False
|
return False
|
||||||
@@ -551,6 +560,7 @@ class MCPClientSession:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
|
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
return False
|
return False
|
||||||
@@ -568,8 +578,8 @@ class MCPClientSession:
|
|||||||
tool_info = MCPToolInfo(
|
tool_info = MCPToolInfo(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
description=tool.description or f"MCP tool: {tool.name}",
|
description=tool.description or f"MCP tool: {tool.name}",
|
||||||
input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {},
|
input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||||
server_name=self.server_name
|
server_name=self.server_name,
|
||||||
)
|
)
|
||||||
self._tools.append(tool_info)
|
self._tools.append(tool_info)
|
||||||
# 初始化工具统计
|
# 初始化工具统计
|
||||||
@@ -591,10 +601,7 @@ class MCPClientSession:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
|
||||||
self._session.list_resources(),
|
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
|
||||||
self._resources = []
|
self._resources = []
|
||||||
|
|
||||||
for resource in result.resources:
|
for resource in result.resources:
|
||||||
@@ -602,8 +609,8 @@ class MCPClientSession:
|
|||||||
uri=str(resource.uri),
|
uri=str(resource.uri),
|
||||||
name=resource.name or str(resource.uri),
|
name=resource.name or str(resource.uri),
|
||||||
description=resource.description or "",
|
description=resource.description or "",
|
||||||
mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None,
|
mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
|
||||||
server_name=self.server_name
|
server_name=self.server_name,
|
||||||
)
|
)
|
||||||
self._resources.append(resource_info)
|
self._resources.append(resource_info)
|
||||||
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
|
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
|
||||||
@@ -633,28 +640,27 @@ class MCPClientSession:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
|
||||||
self._session.list_prompts(),
|
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
|
||||||
self._prompts = []
|
self._prompts = []
|
||||||
|
|
||||||
for prompt in result.prompts:
|
for prompt in result.prompts:
|
||||||
# 解析参数
|
# 解析参数
|
||||||
arguments = []
|
arguments = []
|
||||||
if hasattr(prompt, 'arguments') and prompt.arguments:
|
if hasattr(prompt, "arguments") and prompt.arguments:
|
||||||
for arg in prompt.arguments:
|
for arg in prompt.arguments:
|
||||||
arguments.append({
|
arguments.append(
|
||||||
"name": arg.name,
|
{
|
||||||
"description": arg.description or "",
|
"name": arg.name,
|
||||||
"required": arg.required if hasattr(arg, 'required') else False,
|
"description": arg.description or "",
|
||||||
})
|
"required": arg.required if hasattr(arg, "required") else False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
prompt_info = MCPPromptInfo(
|
prompt_info = MCPPromptInfo(
|
||||||
name=prompt.name,
|
name=prompt.name,
|
||||||
description=prompt.description or f"MCP prompt: {prompt.name}",
|
description=prompt.description or f"MCP prompt: {prompt.name}",
|
||||||
arguments=arguments,
|
arguments=arguments,
|
||||||
server_name=self.server_name
|
server_name=self.server_name,
|
||||||
)
|
)
|
||||||
self._prompts.append(prompt_info)
|
self._prompts.append(prompt_info)
|
||||||
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
|
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
|
||||||
@@ -686,35 +692,25 @@ class MCPClientSession:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if not self._connected or not self._session:
|
if not self._connected or not self._session:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 未连接"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self._supports_resources:
|
if not self._supports_resources:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 不支持 Resources 功能"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
|
||||||
self._session.read_resource(uri),
|
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
# 处理返回内容
|
# 处理返回内容
|
||||||
content_parts = []
|
content_parts = []
|
||||||
for content in result.contents:
|
for content in result.contents:
|
||||||
if hasattr(content, 'text'):
|
if hasattr(content, "text"):
|
||||||
content_parts.append(content.text)
|
content_parts.append(content.text)
|
||||||
elif hasattr(content, 'blob'):
|
elif hasattr(content, "blob"):
|
||||||
# 二进制数据,返回 base64 或提示
|
# 二进制数据,返回 base64 或提示
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
blob_data = content.blob
|
blob_data = content.blob
|
||||||
if len(blob_data) < 10000: # 小于 10KB 返回 base64
|
if len(blob_data) < 10000: # 小于 10KB 返回 base64
|
||||||
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
|
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
|
||||||
@@ -724,28 +720,18 @@ class MCPClientSession:
|
|||||||
content_parts.append(str(content))
|
content_parts.append(str(content))
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=True,
|
success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms
|
||||||
content="\n".join(content_parts) if content_parts else "",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=False,
|
success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
|
||||||
content=None,
|
|
||||||
error=f"读取资源超时({self.call_timeout}秒)",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
|
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=str(e),
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
|
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
|
||||||
"""v1.2.0: 获取提示模板的内容
|
"""v1.2.0: 获取提示模板的内容
|
||||||
@@ -760,23 +746,14 @@ class MCPClientSession:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if not self._connected or not self._session:
|
if not self._connected or not self._session:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 未连接"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self._supports_prompts:
|
if not self._supports_prompts:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 不支持 Prompts 功能"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._session.get_prompt(name, arguments=arguments or {}),
|
self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
@@ -784,10 +761,10 @@ class MCPClientSession:
|
|||||||
# 处理返回的消息
|
# 处理返回的消息
|
||||||
messages = []
|
messages = []
|
||||||
for msg in result.messages:
|
for msg in result.messages:
|
||||||
role = msg.role if hasattr(msg, 'role') else "unknown"
|
role = msg.role if hasattr(msg, "role") else "unknown"
|
||||||
content_text = ""
|
content_text = ""
|
||||||
if hasattr(msg, 'content'):
|
if hasattr(msg, "content"):
|
||||||
if hasattr(msg.content, 'text'):
|
if hasattr(msg.content, "text"):
|
||||||
content_text = msg.content.text
|
content_text = msg.content.text
|
||||||
elif isinstance(msg.content, str):
|
elif isinstance(msg.content, str):
|
||||||
content_text = msg.content
|
content_text = msg.content
|
||||||
@@ -796,28 +773,18 @@ class MCPClientSession:
|
|||||||
messages.append(f"[{role}]: {content_text}")
|
messages.append(f"[{role}]: {content_text}")
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=True,
|
success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms
|
||||||
content="\n\n".join(messages) if messages else "",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=False,
|
success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
|
||||||
content=None,
|
|
||||||
error=f"获取提示模板超时({self.call_timeout}秒)",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
|
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=str(e),
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
|
||||||
|
|
||||||
async def check_health(self) -> bool:
|
async def check_health(self) -> bool:
|
||||||
"""检查连接健康状态(心跳检测)
|
"""检查连接健康状态(心跳检测)
|
||||||
@@ -829,10 +796,7 @@ class MCPClientSession:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 list_tools 作为心跳检测
|
# 使用 list_tools 作为心跳检测
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
|
||||||
self._session.list_tools(),
|
|
||||||
timeout=10.0
|
|
||||||
)
|
|
||||||
self.stats.record_heartbeat()
|
self.stats.record_heartbeat()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -849,12 +813,7 @@ class MCPClientSession:
|
|||||||
# v1.7.0: 断路器检查
|
# v1.7.0: 断路器检查
|
||||||
can_execute, reject_reason = self._circuit_breaker.can_execute()
|
can_execute, reject_reason = self._circuit_breaker.can_execute()
|
||||||
if not can_execute:
|
if not can_execute:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"⚡ {reject_reason}", circuit_broken=True)
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"⚡ {reject_reason}",
|
|
||||||
circuit_broken=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 半开状态下增加试探计数
|
# 半开状态下增加试探计数
|
||||||
if self._circuit_breaker.state == CircuitState.HALF_OPEN:
|
if self._circuit_breaker.state == CircuitState.HALF_OPEN:
|
||||||
@@ -870,8 +829,7 @@ class MCPClientSession:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._session.call_tool(tool_name, arguments=arguments),
|
self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
@@ -879,9 +837,9 @@ class MCPClientSession:
|
|||||||
# 处理返回内容
|
# 处理返回内容
|
||||||
content_parts = []
|
content_parts = []
|
||||||
for content in result.content:
|
for content in result.content:
|
||||||
if hasattr(content, 'text'):
|
if hasattr(content, "text"):
|
||||||
content_parts.append(content.text)
|
content_parts.append(content.text)
|
||||||
elif hasattr(content, 'data'):
|
elif hasattr(content, "data"):
|
||||||
content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
|
content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
|
||||||
else:
|
else:
|
||||||
content_parts.append(str(content))
|
content_parts.append(str(content))
|
||||||
@@ -896,7 +854,7 @@ class MCPClientSession:
|
|||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=True,
|
success=True,
|
||||||
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
|
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
|
||||||
duration_ms=duration_ms
|
duration_ms=duration_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@@ -939,25 +897,25 @@ class MCPClientSession:
|
|||||||
self._supports_prompts = False # v1.2.0
|
self._supports_prompts = False # v1.2.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_session_context') and self._session_context:
|
if hasattr(self, "_session_context") and self._session_context:
|
||||||
await self._session_context.__aexit__(None, None, None)
|
await self._session_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_stdio_context') and self._stdio_context:
|
if hasattr(self, "_stdio_context") and self._stdio_context:
|
||||||
await self._stdio_context.__aexit__(None, None, None)
|
await self._stdio_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_http_context') and self._http_context:
|
if hasattr(self, "_http_context") and self._http_context:
|
||||||
await self._http_context.__aexit__(None, None, None)
|
await self._http_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_sse_context') and self._sse_context:
|
if hasattr(self, "_sse_context") and self._sse_context:
|
||||||
await self._sse_context.__aexit__(None, None, None)
|
await self._sse_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
|
||||||
@@ -1082,7 +1040,9 @@ class MCPClientManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
if attempt < retry_attempts:
|
if attempt < retry_attempts:
|
||||||
logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})")
|
logger.warning(
|
||||||
|
f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
|
||||||
|
)
|
||||||
await asyncio.sleep(retry_interval)
|
await asyncio.sleep(retry_interval)
|
||||||
|
|
||||||
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
|
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
|
||||||
@@ -1213,11 +1173,7 @@ class MCPClientManager:
|
|||||||
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
|
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
|
||||||
"""调用 MCP 工具"""
|
"""调用 MCP 工具"""
|
||||||
if tool_key not in self._all_tools:
|
if tool_key not in self._all_tools:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"工具 {tool_key} 不存在"
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_info, client = self._all_tools[tool_key]
|
tool_info, client = self._all_tools[tool_key]
|
||||||
|
|
||||||
@@ -1273,11 +1229,7 @@ class MCPClientManager:
|
|||||||
# 如果指定了服务器
|
# 如果指定了服务器
|
||||||
if server_name:
|
if server_name:
|
||||||
if server_name not in self._clients:
|
if server_name not in self._clients:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {server_name} 不存在"
|
|
||||||
)
|
|
||||||
client = self._clients[server_name]
|
client = self._clients[server_name]
|
||||||
return await client.read_resource(uri)
|
return await client.read_resource(uri)
|
||||||
|
|
||||||
@@ -1293,14 +1245,11 @@ class MCPClientManager:
|
|||||||
if result.success:
|
if result.success:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"未找到资源: {uri}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None,
|
async def get_prompt(
|
||||||
server_name: Optional[str] = None) -> MCPCallResult:
|
self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
|
||||||
|
) -> MCPCallResult:
|
||||||
"""v1.2.0: 获取提示模板内容
|
"""v1.2.0: 获取提示模板内容
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1311,11 +1260,7 @@ class MCPClientManager:
|
|||||||
# 如果指定了服务器
|
# 如果指定了服务器
|
||||||
if server_name:
|
if server_name:
|
||||||
if server_name not in self._clients:
|
if server_name not in self._clients:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {server_name} 不存在"
|
|
||||||
)
|
|
||||||
client = self._clients[server_name]
|
client = self._clients[server_name]
|
||||||
return await client.get_prompt(name, arguments)
|
return await client.get_prompt(name, arguments)
|
||||||
|
|
||||||
@@ -1324,11 +1269,7 @@ class MCPClientManager:
|
|||||||
if prompt_info.name == name:
|
if prompt_info.name == name:
|
||||||
return await client.get_prompt(name, arguments)
|
return await client.get_prompt(name, arguments)
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"未找到提示模板: {name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ==================== 心跳检测 ====================
|
# ==================== 心跳检测 ====================
|
||||||
|
|
||||||
@@ -1489,7 +1430,9 @@ class MCPClientManager:
|
|||||||
"global": {
|
"global": {
|
||||||
**self._global_stats,
|
**self._global_stats,
|
||||||
"uptime_seconds": round(uptime, 2),
|
"uptime_seconds": round(uptime, 2),
|
||||||
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0,
|
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
|
||||||
|
if uptime > 0
|
||||||
|
else 0,
|
||||||
},
|
},
|
||||||
"servers": server_stats,
|
"servers": server_stats,
|
||||||
"tools": tool_stats,
|
"tools": tool_stats,
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ from .mcp_client import (
|
|||||||
TransportType,
|
TransportType,
|
||||||
mcp_manager,
|
mcp_manager,
|
||||||
)
|
)
|
||||||
from .config_converter import ConfigConverter, ConversionResult
|
from .config_converter import ConfigConverter
|
||||||
|
|
||||||
logger = get_logger("mcp_bridge_plugin")
|
logger = get_logger("mcp_bridge_plugin")
|
||||||
|
|
||||||
@@ -93,9 +93,11 @@ logger = get_logger("mcp_bridge_plugin")
|
|||||||
# v1.4.0: 调用链路追踪
|
# v1.4.0: 调用链路追踪
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolCallRecord:
|
class ToolCallRecord:
|
||||||
"""工具调用记录"""
|
"""工具调用记录"""
|
||||||
|
|
||||||
call_id: str
|
call_id: str
|
||||||
timestamp: float
|
timestamp: float
|
||||||
tool_name: str
|
tool_name: str
|
||||||
@@ -178,9 +180,11 @@ tool_call_tracer = ToolCallTracer()
|
|||||||
# v1.4.0: 工具调用缓存
|
# v1.4.0: 工具调用缓存
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CacheEntry:
|
class CacheEntry:
|
||||||
"""缓存条目"""
|
"""缓存条目"""
|
||||||
|
|
||||||
tool_name: str
|
tool_name: str
|
||||||
args_hash: str
|
args_hash: str
|
||||||
result: str
|
result: str
|
||||||
@@ -317,6 +321,7 @@ tool_call_cache = ToolCallCache()
|
|||||||
# v1.4.0: 工具权限控制
|
# v1.4.0: 工具权限控制
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class PermissionChecker:
|
class PermissionChecker:
|
||||||
"""工具权限检查器"""
|
"""工具权限检查器"""
|
||||||
|
|
||||||
@@ -449,6 +454,7 @@ permission_checker = PermissionChecker()
|
|||||||
# 工具类型转换
|
# 工具类型转换
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
||||||
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
|
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
|
||||||
type_mapping = {
|
type_mapping = {
|
||||||
@@ -462,7 +468,9 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
|||||||
return type_mapping.get(json_type, ToolParamType.STRING)
|
return type_mapping.get(json_type, ToolParamType.STRING)
|
||||||
|
|
||||||
|
|
||||||
def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
|
def parse_mcp_parameters(
|
||||||
|
input_schema: Dict[str, Any],
|
||||||
|
) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
|
||||||
"""解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式"""
|
"""解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式"""
|
||||||
parameters = []
|
parameters = []
|
||||||
|
|
||||||
@@ -497,6 +505,7 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa
|
|||||||
# MCP 工具代理
|
# MCP 工具代理
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPToolProxy(BaseTool):
|
class MCPToolProxy(BaseTool):
|
||||||
"""MCP 工具代理基类"""
|
"""MCP 工具代理基类"""
|
||||||
|
|
||||||
@@ -539,10 +548,7 @@ class MCPToolProxy(BaseTool):
|
|||||||
# v1.4.0: 权限检查
|
# v1.4.0: 权限检查
|
||||||
if not permission_checker.check(self.name, chat_id, user_id, is_group):
|
if not permission_checker.check(self.name, chat_id, user_id, is_group):
|
||||||
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
|
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
|
||||||
return {
|
return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"}
|
||||||
"name": self.name,
|
|
||||||
"content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
|
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
|
||||||
|
|
||||||
@@ -726,11 +732,7 @@ class MCPToolProxy(BaseTool):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def _call_post_process_llm(
|
async def _call_post_process_llm(
|
||||||
self,
|
self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]]
|
||||||
prompt: str,
|
|
||||||
max_tokens: int,
|
|
||||||
settings: Dict[str, Any],
|
|
||||||
server_config: Optional[Dict[str, Any]]
|
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""调用 LLM 进行后处理"""
|
"""调用 LLM 进行后处理"""
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
@@ -788,10 +790,7 @@ class MCPToolProxy(BaseTool):
|
|||||||
|
|
||||||
|
|
||||||
def create_mcp_tool_class(
|
def create_mcp_tool_class(
|
||||||
tool_key: str,
|
tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
|
||||||
tool_info: MCPToolInfo,
|
|
||||||
tool_prefix: str,
|
|
||||||
disabled: bool = False
|
|
||||||
) -> Type[MCPToolProxy]:
|
) -> Type[MCPToolProxy]:
|
||||||
"""根据 MCP 工具信息动态创建 BaseTool 子类"""
|
"""根据 MCP 工具信息动态创建 BaseTool 子类"""
|
||||||
parameters = parse_mcp_parameters(tool_info.input_schema)
|
parameters = parse_mcp_parameters(tool_info.input_schema)
|
||||||
@@ -814,7 +813,7 @@ def create_mcp_tool_class(
|
|||||||
"_mcp_tool_key": tool_key,
|
"_mcp_tool_key": tool_key,
|
||||||
"_mcp_original_name": tool_info.name,
|
"_mcp_original_name": tool_info.name,
|
||||||
"_mcp_server_name": tool_info.server_name,
|
"_mcp_server_name": tool_info.server_name,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return tool_class
|
return tool_class
|
||||||
@@ -828,11 +827,7 @@ class MCPToolRegistry:
|
|||||||
self._tool_infos: Dict[str, ToolInfo] = {}
|
self._tool_infos: Dict[str, ToolInfo] = {}
|
||||||
|
|
||||||
def register_tool(
|
def register_tool(
|
||||||
self,
|
self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
|
||||||
tool_key: str,
|
|
||||||
tool_info: MCPToolInfo,
|
|
||||||
tool_prefix: str,
|
|
||||||
disabled: bool = False
|
|
||||||
) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
|
) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
|
||||||
"""注册 MCP 工具"""
|
"""注册 MCP 工具"""
|
||||||
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
|
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
|
||||||
@@ -879,6 +874,7 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None
|
|||||||
# 内置工具
|
# 内置工具
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPReadResourceTool(BaseTool):
|
class MCPReadResourceTool(BaseTool):
|
||||||
"""v1.2.0: MCP 资源读取工具"""
|
"""v1.2.0: MCP 资源读取工具"""
|
||||||
|
|
||||||
@@ -950,9 +946,17 @@ class MCPStatusTool(BaseTool):
|
|||||||
"""MCP 状态查询工具"""
|
"""MCP 状态查询工具"""
|
||||||
|
|
||||||
name = "mcp_status"
|
name = "mcp_status"
|
||||||
description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
|
description = (
|
||||||
|
"查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
|
||||||
|
)
|
||||||
parameters = [
|
parameters = [
|
||||||
("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"]),
|
(
|
||||||
|
"query_type",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"查询类型",
|
||||||
|
False,
|
||||||
|
["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"],
|
||||||
|
),
|
||||||
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
|
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
|
||||||
]
|
]
|
||||||
available_for_llm = True
|
available_for_llm = True
|
||||||
@@ -986,10 +990,7 @@ class MCPStatusTool(BaseTool):
|
|||||||
if query_type in ("cache",):
|
if query_type in ("cache",):
|
||||||
result_parts.append(self._format_cache())
|
result_parts.append(self._format_cache())
|
||||||
|
|
||||||
return {
|
return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"}
|
||||||
"name": self.name,
|
|
||||||
"content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"
|
|
||||||
}
|
|
||||||
|
|
||||||
def _format_status(self, server_name: Optional[str] = None) -> str:
|
def _format_status(self, server_name: Optional[str] = None) -> str:
|
||||||
status = mcp_manager.get_status()
|
status = mcp_manager.get_status()
|
||||||
@@ -1001,14 +1002,14 @@ class MCPStatusTool(BaseTool):
|
|||||||
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
|
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
|
||||||
|
|
||||||
lines.append("\n🔌 服务器详情:")
|
lines.append("\n🔌 服务器详情:")
|
||||||
for name, info in status['servers'].items():
|
for name, info in status["servers"].items():
|
||||||
if server_name and name != server_name:
|
if server_name and name != server_name:
|
||||||
continue
|
continue
|
||||||
status_icon = "✅" if info['connected'] else "❌"
|
status_icon = "✅" if info["connected"] else "❌"
|
||||||
enabled_text = "" if info['enabled'] else " (已禁用)"
|
enabled_text = "" if info["enabled"] else " (已禁用)"
|
||||||
lines.append(f" {status_icon} {name}{enabled_text}")
|
lines.append(f" {status_icon} {name}{enabled_text}")
|
||||||
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
|
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
|
||||||
if info['consecutive_failures'] > 0:
|
if info["consecutive_failures"] > 0:
|
||||||
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次")
|
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次")
|
||||||
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
@@ -1038,11 +1039,11 @@ class MCPStatusTool(BaseTool):
|
|||||||
stats = mcp_manager.get_all_stats()
|
stats = mcp_manager.get_all_stats()
|
||||||
lines = ["📈 调用统计"]
|
lines = ["📈 调用统计"]
|
||||||
|
|
||||||
g = stats['global']
|
g = stats["global"]
|
||||||
lines.append(f" 总调用次数: {g['total_tool_calls']}")
|
lines.append(f" 总调用次数: {g['total_tool_calls']}")
|
||||||
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
|
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
|
||||||
if g['total_tool_calls'] > 0:
|
if g["total_tool_calls"] > 0:
|
||||||
success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100
|
success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100
|
||||||
lines.append(f" 成功率: {success_rate:.1f}%")
|
lines.append(f" 成功率: {success_rate:.1f}%")
|
||||||
lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒")
|
lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒")
|
||||||
|
|
||||||
@@ -1126,6 +1127,7 @@ class MCPStatusTool(BaseTool):
|
|||||||
# 命令处理
|
# 命令处理
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPStatusCommand(BaseCommand):
|
class MCPStatusCommand(BaseCommand):
|
||||||
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
|
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
|
||||||
|
|
||||||
@@ -1340,7 +1342,9 @@ class MCPStatusCommand(BaseCommand):
|
|||||||
try:
|
try:
|
||||||
exported = ConfigConverter.export_to_string(servers, format_type, pretty=True)
|
exported = ConfigConverter.export_to_string(servers, format_type, pretty=True)
|
||||||
|
|
||||||
format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get(format_type, format_type)
|
format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get(
|
||||||
|
format_type, format_type
|
||||||
|
)
|
||||||
lines = [f"📤 导出为 {format_name} 格式 ({len(servers)} 个服务器):"]
|
lines = [f"📤 导出为 {format_name} 格式 ({len(servers)} 个服务器):"]
|
||||||
lines.append("")
|
lines.append("")
|
||||||
lines.append(exported)
|
lines.append(exported)
|
||||||
@@ -1446,9 +1450,9 @@ class MCPStatusCommand(BaseCommand):
|
|||||||
cb = info.get("circuit_breaker", {})
|
cb = info.get("circuit_breaker", {})
|
||||||
cb_state = cb.get("state", "closed")
|
cb_state = cb.get("state", "closed")
|
||||||
if cb_state == "open":
|
if cb_state == "open":
|
||||||
lines.append(f" ⚡ 断路器熔断中")
|
lines.append(" ⚡ 断路器熔断中")
|
||||||
elif cb_state == "half_open":
|
elif cb_state == "half_open":
|
||||||
lines.append(f" ⚡ 断路器试探中")
|
lines.append(" ⚡ 断路器试探中")
|
||||||
if info["consecutive_failures"] > 0:
|
if info["consecutive_failures"] > 0:
|
||||||
lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']} 次")
|
lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']} 次")
|
||||||
|
|
||||||
@@ -1634,6 +1638,7 @@ class MCPImportCommand(BaseCommand):
|
|||||||
# 事件处理器
|
# 事件处理器
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPStartupHandler(BaseEventHandler):
|
class MCPStartupHandler(BaseEventHandler):
|
||||||
"""MCP 启动事件处理器"""
|
"""MCP 启动事件处理器"""
|
||||||
|
|
||||||
@@ -1692,6 +1697,7 @@ class MCPStopHandler(BaseEventHandler):
|
|||||||
# 主插件类
|
# 主插件类
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
@register_plugin
|
||||||
class MCPBridgePlugin(BasePlugin):
|
class MCPBridgePlugin(BasePlugin):
|
||||||
"""MCP 桥接插件 v1.4.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
|
"""MCP 桥接插件 v1.4.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
|
||||||
@@ -2116,9 +2122,9 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
label="📜 高级权限规则(可选)",
|
label="📜 高级权限规则(可选)",
|
||||||
input_type="textarea",
|
input_type="textarea",
|
||||||
rows=10,
|
rows=10,
|
||||||
placeholder='''[
|
placeholder="""[
|
||||||
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
|
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
|
||||||
]''',
|
]""",
|
||||||
hint="格式: qq:ID:group/private/user,工具名支持通配符 *",
|
hint="格式: qq:ID:group/private/user,工具名支持通配符 *",
|
||||||
order=10,
|
order=10,
|
||||||
),
|
),
|
||||||
@@ -2261,7 +2267,9 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
value = match1.group(2)
|
value = match1.group(2)
|
||||||
suffix = match1.group(3)
|
suffix = match1.group(3)
|
||||||
# 将转义的换行符还原为实际换行符
|
# 将转义的换行符还原为实际换行符
|
||||||
unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
|
unescaped = (
|
||||||
|
value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
|
||||||
|
)
|
||||||
fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
|
fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
|
||||||
fixed_lines.append(fixed_line)
|
fixed_lines.append(fixed_line)
|
||||||
modified = True
|
modified = True
|
||||||
@@ -2560,7 +2568,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
import_config_str = str(import_config).strip()
|
import_config_str = str(import_config).strip()
|
||||||
logger.info(f"检测到 WebUI 导入请求,开始处理...")
|
logger.info("检测到 WebUI 导入请求,开始处理...")
|
||||||
|
|
||||||
# 执行导入
|
# 执行导入
|
||||||
await self._execute_webui_import(import_config_str, doc, config_path)
|
await self._execute_webui_import(import_config_str, doc, config_path)
|
||||||
@@ -2773,6 +2781,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
async def _async_connect_servers(self) -> None:
|
async def _async_connect_servers(self) -> None:
|
||||||
"""异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)"""
|
"""异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
settings = self.config.get("settings", {})
|
settings = self.config.get("settings", {})
|
||||||
|
|
||||||
servers_section = self.config.get("servers", [])
|
servers_section = self.config.get("servers", [])
|
||||||
@@ -2854,10 +2863,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
|
|
||||||
# 并行执行所有连接
|
# 并行执行所有连接
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True)
|
||||||
*[connect_single_server(cfg) for cfg in enabled_configs],
|
|
||||||
return_exceptions=True
|
|
||||||
)
|
|
||||||
connect_duration = time.time() - start_time
|
connect_duration = time.time() - start_time
|
||||||
|
|
||||||
# 统计连接结果
|
# 统计连接结果
|
||||||
@@ -2878,15 +2884,14 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
|
|
||||||
# 注册所有工具
|
# 注册所有工具
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
registered_count = 0
|
registered_count = 0
|
||||||
|
|
||||||
for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
|
for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
|
||||||
tool_name = tool_key.replace("-", "_").replace(".", "_")
|
tool_name = tool_key.replace("-", "_").replace(".", "_")
|
||||||
is_disabled = tool_name in disabled_tools
|
is_disabled = tool_name in disabled_tools
|
||||||
|
|
||||||
info, tool_class = mcp_tool_registry.register_tool(
|
info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled)
|
||||||
tool_key, tool_info, tool_prefix, disabled=is_disabled
|
|
||||||
)
|
|
||||||
info.plugin_name = self.plugin_name
|
info.plugin_name = self.plugin_name
|
||||||
|
|
||||||
if component_registry.register_component(info, tool_class):
|
if component_registry.register_component(info, tool_class):
|
||||||
@@ -3004,6 +3009,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
doc["tools"] = tomlkit.table()
|
doc["tools"] = tomlkit.table()
|
||||||
# 使用 tomlkit 多行字符串避免控制字符问题
|
# 使用 tomlkit 多行字符串避免控制字符问题
|
||||||
from tomlkit.items import String, StringType, Trivia
|
from tomlkit.items import String, StringType, Trivia
|
||||||
|
|
||||||
ml_string = String(StringType.MLB, tool_list_text, tool_list_text, Trivia())
|
ml_string = String(StringType.MLB, tool_list_text, tool_list_text, Trivia())
|
||||||
doc["tools"]["tool_list"] = ml_string
|
doc["tools"]["tool_list"] = ml_string
|
||||||
|
|
||||||
@@ -3069,6 +3075,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
doc["status"] = tomlkit.table()
|
doc["status"] = tomlkit.table()
|
||||||
# 使用 tomlkit 多行字符串避免控制字符问题
|
# 使用 tomlkit 多行字符串避免控制字符问题
|
||||||
from tomlkit.items import String, StringType, Trivia
|
from tomlkit.items import String, StringType, Trivia
|
||||||
|
|
||||||
ml_string = String(StringType.MLB, status_text, status_text, Trivia())
|
ml_string = String(StringType.MLB, status_text, status_text, Trivia())
|
||||||
doc["status"]["connection_status"] = ml_string
|
doc["status"]["connection_status"] = ml_string
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ async def test_stats():
|
|||||||
assert stats.total_calls == 3
|
assert stats.total_calls == 3
|
||||||
assert stats.success_calls == 2
|
assert stats.success_calls == 2
|
||||||
assert stats.failed_calls == 1
|
assert stats.failed_calls == 1
|
||||||
assert stats.success_rate == (2/3) * 100
|
assert stats.success_rate == (2 / 3) * 100
|
||||||
assert stats.avg_duration_ms == 150.0
|
assert stats.avg_duration_ms == 150.0
|
||||||
assert stats.last_error == "timeout"
|
assert stats.last_error == "timeout"
|
||||||
|
|
||||||
@@ -66,13 +66,15 @@ async def test_manager_basic():
|
|||||||
manager.__init__()
|
manager.__init__()
|
||||||
|
|
||||||
# 配置
|
# 配置
|
||||||
manager.configure({
|
manager.configure(
|
||||||
"tool_prefix": "mcp",
|
{
|
||||||
"call_timeout": 30.0,
|
"tool_prefix": "mcp",
|
||||||
"retry_attempts": 1,
|
"call_timeout": 30.0,
|
||||||
"retry_interval": 1.0,
|
"retry_attempts": 1,
|
||||||
"heartbeat_enabled": False,
|
"retry_interval": 1.0,
|
||||||
})
|
"heartbeat_enabled": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 测试状态
|
# 测试状态
|
||||||
status = manager.get_status()
|
status = manager.get_status()
|
||||||
@@ -82,10 +84,7 @@ async def test_manager_basic():
|
|||||||
|
|
||||||
# 测试添加禁用的服务器
|
# 测试添加禁用的服务器
|
||||||
config = MCPServerConfig(
|
config = MCPServerConfig(
|
||||||
name="disabled_server",
|
name="disabled_server", enabled=False, transport=TransportType.HTTP, url="https://example.com/mcp"
|
||||||
enabled=False,
|
|
||||||
transport=TransportType.HTTP,
|
|
||||||
url="https://example.com/mcp"
|
|
||||||
)
|
)
|
||||||
result = await manager.add_server(config)
|
result = await manager.add_server(config)
|
||||||
assert result == True
|
assert result == True
|
||||||
@@ -120,27 +119,29 @@ async def test_http_connection():
|
|||||||
manager._initialized = False
|
manager._initialized = False
|
||||||
manager.__init__()
|
manager.__init__()
|
||||||
|
|
||||||
manager.configure({
|
manager.configure(
|
||||||
"tool_prefix": "mcp",
|
{
|
||||||
"call_timeout": 30.0,
|
"tool_prefix": "mcp",
|
||||||
"retry_attempts": 2,
|
"call_timeout": 30.0,
|
||||||
"retry_interval": 2.0,
|
"retry_attempts": 2,
|
||||||
"heartbeat_enabled": False,
|
"retry_interval": 2.0,
|
||||||
})
|
"heartbeat_enabled": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 使用 HowToCook MCP 服务器测试
|
# 使用 HowToCook MCP 服务器测试
|
||||||
config = MCPServerConfig(
|
config = MCPServerConfig(
|
||||||
name="howtocook",
|
name="howtocook",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
transport=TransportType.HTTP,
|
transport=TransportType.HTTP,
|
||||||
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp"
|
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"正在连接 {config.url} ...")
|
print(f"正在连接 {config.url} ...")
|
||||||
result = await manager.add_server(config)
|
result = await manager.add_server(config)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
print(f"✅ 连接成功!")
|
print("✅ 连接成功!")
|
||||||
|
|
||||||
# 检查工具
|
# 检查工具
|
||||||
tools = manager.all_tools
|
tools = manager.all_tools
|
||||||
@@ -159,19 +160,23 @@ async def test_http_connection():
|
|||||||
call_result = await manager.call_tool("mcp_howtocook_whatToEat", {})
|
call_result = await manager.call_tool("mcp_howtocook_whatToEat", {})
|
||||||
if call_result.success:
|
if call_result.success:
|
||||||
print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)")
|
print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)")
|
||||||
print(f" 结果: {call_result.content[:200]}..." if len(str(call_result.content)) > 200 else f" 结果: {call_result.content}")
|
print(
|
||||||
|
f" 结果: {call_result.content[:200]}..."
|
||||||
|
if len(str(call_result.content)) > 200
|
||||||
|
else f" 结果: {call_result.content}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"❌ 工具调用失败: {call_result.error}")
|
print(f"❌ 工具调用失败: {call_result.error}")
|
||||||
|
|
||||||
# 查看统计
|
# 查看统计
|
||||||
stats = manager.get_all_stats()
|
stats = manager.get_all_stats()
|
||||||
print(f"\n📊 统计信息:")
|
print("\n📊 统计信息:")
|
||||||
print(f" 全局调用: {stats['global']['total_tool_calls']}")
|
print(f" 全局调用: {stats['global']['total_tool_calls']}")
|
||||||
print(f" 成功: {stats['global']['successful_calls']}")
|
print(f" 成功: {stats['global']['successful_calls']}")
|
||||||
print(f" 失败: {stats['global']['failed_calls']}")
|
print(f" 失败: {stats['global']['failed_calls']}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f"❌ 连接失败")
|
print("❌ 连接失败")
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
await manager.shutdown()
|
await manager.shutdown()
|
||||||
@@ -187,23 +192,25 @@ async def test_heartbeat():
|
|||||||
manager._initialized = False
|
manager._initialized = False
|
||||||
manager.__init__()
|
manager.__init__()
|
||||||
|
|
||||||
manager.configure({
|
manager.configure(
|
||||||
"tool_prefix": "mcp",
|
{
|
||||||
"call_timeout": 30.0,
|
"tool_prefix": "mcp",
|
||||||
"retry_attempts": 1,
|
"call_timeout": 30.0,
|
||||||
"retry_interval": 1.0,
|
"retry_attempts": 1,
|
||||||
"heartbeat_enabled": True,
|
"retry_interval": 1.0,
|
||||||
"heartbeat_interval": 5.0, # 5秒间隔用于测试
|
"heartbeat_enabled": True,
|
||||||
"auto_reconnect": True,
|
"heartbeat_interval": 5.0, # 5秒间隔用于测试
|
||||||
"max_reconnect_attempts": 2,
|
"auto_reconnect": True,
|
||||||
})
|
"max_reconnect_attempts": 2,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 添加一个测试服务器
|
# 添加一个测试服务器
|
||||||
config = MCPServerConfig(
|
config = MCPServerConfig(
|
||||||
name="heartbeat_test",
|
name="heartbeat_test",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
transport=TransportType.HTTP,
|
transport=TransportType.HTTP,
|
||||||
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp"
|
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("正在连接服务器...")
|
print("正在连接服务器...")
|
||||||
@@ -260,6 +267,7 @@ async def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n❌ 测试失败: {e}")
|
print(f"\n❌ 测试失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -301,4 +301,3 @@ def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Optional, Tuple, Any, Dict, Callable
|
from typing import List, Optional, Tuple, Any, Dict
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
@@ -13,7 +13,12 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
)
|
)
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.bw_learner.learner_utils import filter_message_content, is_bot_message, build_context_paragraph, contains_bot_self_name
|
from src.bw_learner.learner_utils import (
|
||||||
|
filter_message_content,
|
||||||
|
is_bot_message,
|
||||||
|
build_context_paragraph,
|
||||||
|
contains_bot_self_name,
|
||||||
|
)
|
||||||
from src.bw_learner.jargon_miner import miner_manager
|
from src.bw_learner.jargon_miner import miner_manager
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
@@ -77,8 +82,6 @@ def init_prompt() -> None:
|
|||||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionLearner:
|
class ExpressionLearner:
|
||||||
def __init__(self, chat_id: str) -> None:
|
def __init__(self, chat_id: str) -> None:
|
||||||
self.express_learn_model: LLMRequest = LLMRequest(
|
self.express_learn_model: LLMRequest = LLMRequest(
|
||||||
@@ -97,14 +100,14 @@ class ExpressionLearner:
|
|||||||
async def learn_and_store(
|
async def learn_and_store(
|
||||||
self,
|
self,
|
||||||
messages: List[Any],
|
messages: List[Any],
|
||||||
person_name_filter: Optional[Callable[[str], bool]] = None,
|
|
||||||
) -> List[Tuple[str, str, str]]:
|
) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
学习并存储表达方式
|
学习并存储表达方式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 外部传入的消息列表(必需)
|
messages: 外部传入的消息列表(必需)
|
||||||
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
|
num: 学习数量
|
||||||
|
timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
|
||||||
"""
|
"""
|
||||||
if not messages:
|
if not messages:
|
||||||
return None
|
return None
|
||||||
@@ -135,17 +138,6 @@ class ExpressionLearner:
|
|||||||
expressions, jargon_entries = self.parse_expression_response(response)
|
expressions, jargon_entries = self.parse_expression_response(response)
|
||||||
expressions = self._filter_self_reference_styles(expressions)
|
expressions = self._filter_self_reference_styles(expressions)
|
||||||
|
|
||||||
# 过滤掉包含人物名称的表达方式
|
|
||||||
if person_name_filter:
|
|
||||||
filtered_expressions = []
|
|
||||||
for situation, style, source_id in expressions:
|
|
||||||
# 检查 situation 和 style 是否包含人物名称
|
|
||||||
if person_name_filter(situation) or person_name_filter(style):
|
|
||||||
logger.info(f"跳过包含人物名称的表达方式: situation={situation}, style={style}")
|
|
||||||
continue
|
|
||||||
filtered_expressions.append((situation, style, source_id))
|
|
||||||
expressions = filtered_expressions
|
|
||||||
|
|
||||||
# 检查表达方式数量,如果超过10个则放弃本次表达学习
|
# 检查表达方式数量,如果超过10个则放弃本次表达学习
|
||||||
if len(expressions) > 10:
|
if len(expressions) > 10:
|
||||||
logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习")
|
logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习")
|
||||||
@@ -158,7 +150,7 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||||
if jargon_entries:
|
if jargon_entries:
|
||||||
await self._process_jargon_entries(jargon_entries, random_msg, person_name_filter)
|
await self._process_jargon_entries(jargon_entries, random_msg)
|
||||||
|
|
||||||
# 如果没有表达方式,直接返回
|
# 如果没有表达方式,直接返回
|
||||||
if not expressions:
|
if not expressions:
|
||||||
@@ -197,7 +189,6 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
filtered_expressions.append((situation, style, context))
|
filtered_expressions.append((situation, style, context))
|
||||||
|
|
||||||
|
|
||||||
learnt_expressions = filtered_expressions
|
learnt_expressions = filtered_expressions
|
||||||
|
|
||||||
if learnt_expressions is None:
|
if learnt_expressions is None:
|
||||||
@@ -281,6 +272,7 @@ class ExpressionLearner:
|
|||||||
# 如果解析失败,尝试修复中文引号问题
|
# 如果解析失败,尝试修复中文引号问题
|
||||||
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||||
try:
|
try:
|
||||||
|
|
||||||
def fix_chinese_quotes_in_json(text):
|
def fix_chinese_quotes_in_json(text):
|
||||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||||
result = []
|
result = []
|
||||||
@@ -298,7 +290,7 @@ class ExpressionLearner:
|
|||||||
i += 1
|
i += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if char == '\\':
|
if char == "\\":
|
||||||
# 转义字符
|
# 转义字符
|
||||||
result.append(char)
|
result.append(char)
|
||||||
escape_next = True
|
escape_next = True
|
||||||
@@ -326,7 +318,7 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return ''.join(result)
|
return "".join(result)
|
||||||
|
|
||||||
fixed_raw = fix_chinese_quotes_in_json(raw)
|
fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||||
|
|
||||||
@@ -511,19 +503,13 @@ class ExpressionLearner:
|
|||||||
logger.error(f"概括表达情境失败: {e}")
|
logger.error(f"概括表达情境失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _process_jargon_entries(
|
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
|
||||||
self,
|
|
||||||
jargon_entries: List[Tuple[str, str]],
|
|
||||||
messages: List[Any],
|
|
||||||
person_name_filter: Optional[Callable[[str], bool]] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
|
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
|
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
|
||||||
messages: 消息列表,用于构建上下文
|
messages: 消息列表,用于构建上下文
|
||||||
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
|
|
||||||
"""
|
"""
|
||||||
if not jargon_entries or not messages:
|
if not jargon_entries or not messages:
|
||||||
return
|
return
|
||||||
@@ -544,11 +530,6 @@ class ExpressionLearner:
|
|||||||
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
|
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查是否包含人物名称
|
|
||||||
if person_name_filter and person_name_filter(content):
|
|
||||||
logger.info(f"跳过包含人物名称的黑话: {content}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 解析 source_id
|
# 解析 source_id
|
||||||
source_id_str = (source_id or "").strip()
|
source_id_str = (source_id or "").strip()
|
||||||
if not source_id_str.isdigit():
|
if not source_id_str.isdigit():
|
||||||
@@ -579,7 +560,7 @@ class ExpressionLearner:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 调用 jargon_miner 处理这些条目
|
# 调用 jargon_miner 处理这些条目
|
||||||
await jargon_miner.process_extracted_entries(entries, person_name_filter)
|
await jargon_miner.process_extracted_entries(entries)
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|||||||
@@ -82,9 +82,7 @@ class ExpressionReflector:
|
|||||||
# 获取未检查的表达
|
# 获取未检查的表达
|
||||||
try:
|
try:
|
||||||
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||||
expressions = (
|
expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||||
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
|
||||||
)
|
|
||||||
|
|
||||||
expr_list = list(expressions)
|
expr_list = list(expressions)
|
||||||
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
||||||
@@ -147,7 +145,7 @@ expression_reflector_manager = ExpressionReflectorManager()
|
|||||||
|
|
||||||
async def _check_tracker_exists(operator_config: str) -> bool:
|
async def _check_tracker_exists(operator_config: str) -> bool:
|
||||||
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
||||||
from src.express.reflect_tracker import reflect_tracker_manager
|
from src.bw_learner.reflect_tracker import reflect_tracker_manager
|
||||||
|
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
chat_stream = None
|
chat_stream = None
|
||||||
@@ -242,7 +240,7 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
|
|||||||
stream_id = chat_stream.stream_id
|
stream_id = chat_stream.stream_id
|
||||||
|
|
||||||
# 注册 Tracker
|
# 注册 Tracker
|
||||||
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
from src.bw_learner.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
||||||
|
|
||||||
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
|
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
|
||||||
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
||||||
|
|||||||
@@ -128,9 +128,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的
|
# 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的
|
||||||
style_query = Expression.select().where(
|
style_query = Expression.select().where(
|
||||||
(Expression.chat_id.in_(related_chat_ids))
|
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
|
||||||
& (~Expression.rejected)
|
|
||||||
& (Expression.count > 1)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
style_exprs = [
|
style_exprs = [
|
||||||
@@ -150,12 +148,15 @@ class ExpressionSelector:
|
|||||||
# 要求至少有10个 count > 1 的表达方式才进行选择
|
# 要求至少有10个 count > 1 的表达方式才进行选择
|
||||||
min_required = 10
|
min_required = 10
|
||||||
if len(style_exprs) < min_required:
|
if len(style_exprs) < min_required:
|
||||||
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择")
|
logger.info(
|
||||||
|
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择"
|
||||||
|
)
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
# 固定选择5个
|
# 固定选择5个
|
||||||
select_count = 5
|
select_count = 5
|
||||||
import random
|
import random
|
||||||
|
|
||||||
selected_style = random.sample(style_exprs, select_count)
|
selected_style = random.sample(style_exprs, select_count)
|
||||||
|
|
||||||
# 更新last_active_time
|
# 更新last_active_time
|
||||||
@@ -163,7 +164,9 @@ class ExpressionSelector:
|
|||||||
self.update_expressions_last_active_time(selected_style)
|
self.update_expressions_last_active_time(selected_style)
|
||||||
|
|
||||||
selected_ids = [expr["id"] for expr in selected_style]
|
selected_ids = [expr["id"] for expr in selected_style]
|
||||||
logger.debug(f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个")
|
logger.debug(
|
||||||
|
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个"
|
||||||
|
)
|
||||||
return selected_style, selected_ids
|
return selected_style, selected_ids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -186,9 +189,7 @@ class ExpressionSelector:
|
|||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
|
|
||||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||||
style_query = Expression.select().where(
|
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
|
||||||
)
|
|
||||||
|
|
||||||
style_exprs = [
|
style_exprs = [
|
||||||
{
|
{
|
||||||
@@ -246,7 +247,9 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 使用classic模式(随机选择+LLM选择)
|
# 使用classic模式(随机选择+LLM选择)
|
||||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason, think_level)
|
return await self._select_expressions_classic(
|
||||||
|
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||||
|
)
|
||||||
|
|
||||||
async def _select_expressions_classic(
|
async def _select_expressions_classic(
|
||||||
self,
|
self,
|
||||||
@@ -279,9 +282,7 @@ class ExpressionSelector:
|
|||||||
# think_level == 1: 先选高count,再从所有表达方式中随机抽样
|
# think_level == 1: 先选高count,再从所有表达方式中随机抽样
|
||||||
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
|
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
|
||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
style_query = Expression.select().where(
|
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
|
||||||
)
|
|
||||||
|
|
||||||
all_style_exprs = [
|
all_style_exprs = [
|
||||||
{
|
{
|
||||||
@@ -308,11 +309,15 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 检查数量要求
|
# 检查数量要求
|
||||||
if len(high_count_exprs) < min_high_count:
|
if len(high_count_exprs) < min_high_count:
|
||||||
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择")
|
logger.info(
|
||||||
|
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择"
|
||||||
|
)
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
if len(all_style_exprs) < min_total_count:
|
if len(all_style_exprs) < min_total_count:
|
||||||
logger.info(f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择")
|
logger.info(
|
||||||
|
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
|
||||||
|
)
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
# 先选取高count的表达方式
|
# 先选取高count的表达方式
|
||||||
@@ -332,6 +337,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 打乱顺序,避免高count的都在前面
|
# 打乱顺序,避免高count的都在前面
|
||||||
import random
|
import random
|
||||||
|
|
||||||
random.shuffle(candidate_exprs)
|
random.shuffle(candidate_exprs)
|
||||||
|
|
||||||
# 2. 构建所有表达方式的索引和情境列表
|
# 2. 构建所有表达方式的索引和情境列表
|
||||||
@@ -351,7 +357,7 @@ class ExpressionSelector:
|
|||||||
all_situations_str = "\n".join(all_situations)
|
all_situations_str = "\n".join(all_situations)
|
||||||
|
|
||||||
if target_message:
|
if target_message:
|
||||||
target_message_str = f",现在你想要对这条消息进行回复:\"{target_message}\""
|
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
|
||||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||||
else:
|
else:
|
||||||
target_message_str = ""
|
target_message_str = ""
|
||||||
|
|||||||
@@ -8,7 +8,12 @@ from src.llm_models.utils_model import LLMRequest
|
|||||||
from src.config.config import model_config, global_config
|
from src.config.config import model_config, global_config
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.bw_learner.jargon_miner import search_jargon
|
from src.bw_learner.jargon_miner import search_jargon
|
||||||
from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
|
from src.bw_learner.learner_utils import (
|
||||||
|
is_bot_message,
|
||||||
|
contains_bot_self_name,
|
||||||
|
parse_chat_id_list,
|
||||||
|
chat_id_list_contains,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger("jargon")
|
logger = get_logger("jargon")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import time
|
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
@@ -14,7 +13,6 @@ from src.config.config import model_config, global_config
|
|||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages_with_id,
|
build_readable_messages_with_id,
|
||||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
|
||||||
)
|
)
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.bw_learner.learner_utils import (
|
from src.bw_learner.learner_utils import (
|
||||||
@@ -46,10 +44,10 @@ def _is_single_char_jargon(content: str) -> bool:
|
|||||||
char = content[0]
|
char = content[0]
|
||||||
# 判断是否是单个汉字、单个英文字母或单个数字
|
# 判断是否是单个汉字、单个英文字母或单个数字
|
||||||
return (
|
return (
|
||||||
'\u4e00' <= char <= '\u9fff' or # 汉字
|
"\u4e00" <= char <= "\u9fff" # 汉字
|
||||||
'a' <= char <= 'z' or # 小写字母
|
or "a" <= char <= "z" # 小写字母
|
||||||
'A' <= char <= 'Z' or # 大写字母
|
or "A" <= char <= "Z" # 大写字母
|
||||||
'0' <= char <= '9' # 数字
|
or "0" <= char <= "9" # 数字
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -305,7 +303,9 @@ class JargonMiner:
|
|||||||
# 计算要保留的数量(至少保留1个)
|
# 计算要保留的数量(至少保留1个)
|
||||||
keep_count = max(1, len(raw_content_list) // 2)
|
keep_count = max(1, len(raw_content_list) // 2)
|
||||||
raw_content_list = random.sample(raw_content_list, keep_count)
|
raw_content_list = random.sample(raw_content_list, keep_count)
|
||||||
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目")
|
logger.info(
|
||||||
|
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
|
||||||
|
)
|
||||||
|
|
||||||
# 步骤1: 基于raw_content和content推断
|
# 步骤1: 基于raw_content和content推断
|
||||||
raw_content_text = "\n".join(raw_content_list)
|
raw_content_text = "\n".join(raw_content_list)
|
||||||
@@ -318,7 +318,9 @@ class JargonMiner:
|
|||||||
**上一次推断的含义(仅供参考)**
|
**上一次推断的含义(仅供参考)**
|
||||||
{previous_meaning}
|
{previous_meaning}
|
||||||
"""
|
"""
|
||||||
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
|
previous_meaning_instruction = (
|
||||||
|
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
|
||||||
|
)
|
||||||
|
|
||||||
prompt1 = await global_prompt_manager.format_prompt(
|
prompt1 = await global_prompt_manager.format_prompt(
|
||||||
"jargon_inference_with_context_prompt",
|
"jargon_inference_with_context_prompt",
|
||||||
@@ -660,7 +662,9 @@ class JargonMiner:
|
|||||||
if obj.raw_content:
|
if obj.raw_content:
|
||||||
try:
|
try:
|
||||||
existing_raw_content = (
|
existing_raw_content = (
|
||||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
json.loads(obj.raw_content)
|
||||||
|
if isinstance(obj.raw_content, str)
|
||||||
|
else obj.raw_content
|
||||||
)
|
)
|
||||||
if not isinstance(existing_raw_content, list):
|
if not isinstance(existing_raw_content, list):
|
||||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||||
@@ -899,8 +903,6 @@ class JargonMinerManager:
|
|||||||
miner_manager = JargonMinerManager()
|
miner_manager = JargonMinerManager()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def search_jargon(
|
def search_jargon(
|
||||||
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
|
|||||||
@@ -1,34 +1,16 @@
|
|||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Any, Optional
|
from typing import List, Any
|
||||||
from collections import OrderedDict
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
from src.bw_learner.expression_learner import expression_learner_manager
|
from src.bw_learner.expression_learner import expression_learner_manager
|
||||||
from src.bw_learner.jargon_miner import miner_manager
|
from src.bw_learner.jargon_miner import miner_manager
|
||||||
from src.person_info.person_info import Person
|
|
||||||
|
|
||||||
logger = get_logger("bw_learner")
|
logger = get_logger("bw_learner")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PersonInfo:
|
|
||||||
"""参与聊天的人物信息"""
|
|
||||||
user_id: str
|
|
||||||
user_platform: str
|
|
||||||
user_nickname: str
|
|
||||||
user_cardname: Optional[str]
|
|
||||||
person_name: str
|
|
||||||
last_seen_time: float # 最后发言时间
|
|
||||||
|
|
||||||
def get_unique_key(self) -> str:
|
|
||||||
"""获取唯一标识(用于去重)"""
|
|
||||||
return f"{self.user_platform}:{self.user_id}"
|
|
||||||
|
|
||||||
|
|
||||||
class MessageRecorder:
|
class MessageRecorder:
|
||||||
"""
|
"""
|
||||||
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
|
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
|
||||||
@@ -45,11 +27,6 @@ class MessageRecorder:
|
|||||||
# 提取锁,防止并发执行
|
# 提取锁,防止并发执行
|
||||||
self._extraction_lock = asyncio.Lock()
|
self._extraction_lock = asyncio.Lock()
|
||||||
|
|
||||||
# 维护参与该chat_id的人物列表(最多30个,使用OrderedDict保持插入顺序)
|
|
||||||
# key: f"{platform}:{user_id}", value: PersonInfo
|
|
||||||
self._person_list: OrderedDict[str, PersonInfo] = OrderedDict()
|
|
||||||
self._max_person_count = 30
|
|
||||||
|
|
||||||
# 获取 expression 和 jargon 的配置参数
|
# 获取 expression 和 jargon 的配置参数
|
||||||
self._init_parameters()
|
self._init_parameters()
|
||||||
|
|
||||||
@@ -134,17 +111,11 @@ class MessageRecorder:
|
|||||||
# 按时间排序,确保顺序一致
|
# 按时间排序,确保顺序一致
|
||||||
messages = sorted(messages, key=lambda msg: msg.time or 0)
|
messages = sorted(messages, key=lambda msg: msg.time or 0)
|
||||||
|
|
||||||
# 更新参与聊天的人物列表
|
|
||||||
self._update_person_list(messages)
|
|
||||||
|
|
||||||
logger.info(f"聊天流 {self.chat_name} 的人物列表: {self._person_list}")
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
|
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
|
||||||
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
|
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 分别触发 expression_learner 和 jargon_miner 的处理
|
# 分别触发 expression_learner 和 jargon_miner 的处理
|
||||||
# 传递提取的消息,避免它们重复获取
|
# 传递提取的消息,避免它们重复获取
|
||||||
# 触发 expression 学习(如果启用)
|
# 触发 expression 学习(如果启用)
|
||||||
@@ -155,21 +126,19 @@ class MessageRecorder:
|
|||||||
|
|
||||||
# 触发 jargon 提取(如果启用),传递消息
|
# 触发 jargon 提取(如果启用),传递消息
|
||||||
# if self.enable_jargon_learning:
|
# if self.enable_jargon_learning:
|
||||||
# asyncio.create_task(
|
# asyncio.create_task(
|
||||||
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
|
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
|
||||||
# )
|
# )
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# 即使失败也保持时间戳更新,避免频繁重试
|
# 即使失败也保持时间戳更新,避免频繁重试
|
||||||
|
|
||||||
async def _trigger_expression_learning(
|
async def _trigger_expression_learning(
|
||||||
self,
|
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
messages: List[Any]
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
触发 expression 学习,使用指定的消息列表
|
触发 expression 学习,使用指定的消息列表
|
||||||
@@ -180,11 +149,8 @@ class MessageRecorder:
|
|||||||
messages: 消息列表
|
messages: 消息列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 传递消息和过滤函数给 ExpressionLearner
|
# 传递消息给 ExpressionLearner(必需参数)
|
||||||
learnt_style = await self.expression_learner.learn_and_store(
|
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
|
||||||
messages=messages,
|
|
||||||
person_name_filter=self.contains_person_name
|
|
||||||
)
|
|
||||||
|
|
||||||
if learnt_style:
|
if learnt_style:
|
||||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||||
@@ -193,13 +159,11 @@ class MessageRecorder:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
|
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
async def _trigger_jargon_extraction(
|
async def _trigger_jargon_extraction(
|
||||||
self,
|
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
messages: List[Any]
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
触发 jargon 提取,使用指定的消息列表
|
触发 jargon 提取,使用指定的消息列表
|
||||||
@@ -210,124 +174,15 @@ class MessageRecorder:
|
|||||||
messages: 消息列表
|
messages: 消息列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 传递消息和过滤函数给 JargonMiner
|
# 传递消息给 JargonMiner,避免它重复获取
|
||||||
await self.jargon_miner.run_once(
|
await self.jargon_miner.run_once(messages=messages)
|
||||||
messages=messages,
|
|
||||||
person_name_filter=self.contains_person_name
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
|
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
def _update_person_list(self, messages: List[Any]) -> None:
|
|
||||||
"""
|
|
||||||
从消息中提取人物信息并更新人物列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 消息列表
|
|
||||||
"""
|
|
||||||
for msg in messages:
|
|
||||||
# 获取消息发送者信息
|
|
||||||
# 消息对象可能是 DatabaseMessages,它有 user_info 属性
|
|
||||||
if hasattr(msg, 'user_info'):
|
|
||||||
# DatabaseMessages 类型
|
|
||||||
user_info = msg.user_info
|
|
||||||
user_id = getattr(user_info, 'user_id', None) or ''
|
|
||||||
user_platform = getattr(user_info, 'platform', None) or ''
|
|
||||||
user_nickname = getattr(user_info, 'user_nickname', None) or ''
|
|
||||||
user_cardname = getattr(user_info, 'user_cardname', None)
|
|
||||||
else:
|
|
||||||
# 直接属性访问
|
|
||||||
user_id = getattr(msg, 'user_id', None) or ''
|
|
||||||
user_platform = getattr(msg, 'user_platform', None) or ''
|
|
||||||
user_nickname = getattr(msg, 'user_nickname', None) or ''
|
|
||||||
user_cardname = getattr(msg, 'user_cardname', None)
|
|
||||||
|
|
||||||
msg_time = getattr(msg, 'time', time.time())
|
|
||||||
|
|
||||||
# 检查必要信息
|
|
||||||
if not user_id or not user_platform:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 获取 person_name
|
|
||||||
try:
|
|
||||||
person = Person(platform=user_platform, user_id=str(user_id))
|
|
||||||
person_name = person.person_name or user_nickname or (user_cardname if user_cardname else "未知用户")
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(f"获取person_name失败: {e}, 使用nickname")
|
|
||||||
person_name = user_nickname or (user_cardname if user_cardname else "未知用户")
|
|
||||||
|
|
||||||
# 生成唯一key
|
|
||||||
unique_key = f"{user_platform}:{user_id}"
|
|
||||||
|
|
||||||
# 如果已存在,更新最后发言时间
|
|
||||||
if unique_key in self._person_list:
|
|
||||||
self._person_list[unique_key].last_seen_time = msg_time
|
|
||||||
# 移动到末尾(表示最近活跃)
|
|
||||||
self._person_list.move_to_end(unique_key)
|
|
||||||
else:
|
|
||||||
# 如果超过最大数量,移除最早的(最前面的)
|
|
||||||
if len(self._person_list) >= self._max_person_count:
|
|
||||||
oldest_key = next(iter(self._person_list))
|
|
||||||
del self._person_list[oldest_key]
|
|
||||||
logger.info(f"人物列表已满,移除最早的人物: {oldest_key}")
|
|
||||||
|
|
||||||
# 添加新人物
|
|
||||||
person_info = PersonInfo(
|
|
||||||
user_id=str(user_id),
|
|
||||||
user_platform=user_platform,
|
|
||||||
user_nickname=user_nickname or "",
|
|
||||||
user_cardname=user_cardname,
|
|
||||||
person_name=person_name,
|
|
||||||
last_seen_time=msg_time
|
|
||||||
)
|
|
||||||
self._person_list[unique_key] = person_info
|
|
||||||
logger.info(f"添加新人物到列表: {unique_key}, person_name={person_name}")
|
|
||||||
|
|
||||||
def contains_person_name(self, content: str) -> bool:
|
|
||||||
"""
|
|
||||||
检查内容是否包含任何参与聊天的人物的名称或昵称
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 要检查的内容
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 如果包含任何人物名称或昵称,返回True
|
|
||||||
"""
|
|
||||||
if not content or not self._person_list:
|
|
||||||
return False
|
|
||||||
|
|
||||||
content_lower = content.strip().lower()
|
|
||||||
if not content_lower:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 检查所有人物
|
|
||||||
for person_info in self._person_list.values():
|
|
||||||
# 检查 person_name
|
|
||||||
if person_info.person_name:
|
|
||||||
person_name_lower = person_info.person_name.strip().lower()
|
|
||||||
if person_name_lower and person_name_lower in content_lower:
|
|
||||||
logger.debug(f"内容包含person_name: {person_info.person_name} in {content}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查 user_nickname
|
|
||||||
if person_info.user_nickname:
|
|
||||||
nickname_lower = person_info.user_nickname.strip().lower()
|
|
||||||
if nickname_lower and nickname_lower in content_lower:
|
|
||||||
logger.debug(f"内容包含nickname: {person_info.user_nickname} in {content}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查 user_cardname(群昵称)
|
|
||||||
if person_info.user_cardname:
|
|
||||||
cardname_lower = person_info.user_cardname.strip().lower()
|
|
||||||
if cardname_lower and cardname_lower in content_lower:
|
|
||||||
logger.debug(f"内容包含cardname: {person_info.user_cardname} in {content}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class MessageRecorderManager:
|
class MessageRecorderManager:
|
||||||
"""MessageRecorder 管理器"""
|
"""MessageRecorder 管理器"""
|
||||||
@@ -355,4 +210,3 @@ async def extract_and_distribute_messages(chat_id: str) -> None:
|
|||||||
"""
|
"""
|
||||||
recorder = recorder_manager.get_recorder(chat_id)
|
recorder = recorder_manager.get_recorder(chat_id)
|
||||||
await recorder.extract_and_distribute()
|
await recorder.extract_and_distribute()
|
||||||
|
|
||||||
|
|||||||
@@ -328,9 +328,7 @@ class BrainChatting:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否有 complete_talk 动作(会停止后续迭代)
|
# 检查是否有 complete_talk 动作(会停止后续迭代)
|
||||||
has_complete_talk = any(
|
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
|
||||||
action.action_type == "complete_talk" for action in action_to_use_info
|
|
||||||
)
|
|
||||||
|
|
||||||
# 并行执行所有动作
|
# 并行执行所有动作
|
||||||
action_tasks = [
|
action_tasks = [
|
||||||
|
|||||||
@@ -204,7 +204,9 @@ class BrainPlanner:
|
|||||||
# 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
|
# 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
|
||||||
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
|
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}")
|
logger.debug(
|
||||||
|
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
|
||||||
|
)
|
||||||
|
|
||||||
# 将 listening 转换为 wait(向后兼容)
|
# 将 listening 转换为 wait(向后兼容)
|
||||||
if action == "listening":
|
if action == "listening":
|
||||||
@@ -521,7 +523,7 @@ class BrainPlanner:
|
|||||||
if json_objects:
|
if json_objects:
|
||||||
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||||
for i, json_obj in enumerate(json_objects):
|
for i, json_obj in enumerate(json_objects):
|
||||||
logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}")
|
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
|
||||||
filtered_actions_list = list(filtered_actions.items())
|
filtered_actions_list = list(filtered_actions.items())
|
||||||
for json_obj in json_objects:
|
for json_obj in json_objects:
|
||||||
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
|
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
|
||||||
@@ -553,7 +555,9 @@ class BrainPlanner:
|
|||||||
|
|
||||||
return extracted_reasoning, actions
|
return extracted_reasoning, actions
|
||||||
|
|
||||||
def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
def _create_complete_talk(
|
||||||
|
self, reasoning: str, available_actions: Dict[str, ActionInfo]
|
||||||
|
) -> List[ActionPlannerInfo]:
|
||||||
"""创建complete_talk"""
|
"""创建complete_talk"""
|
||||||
return [
|
return [
|
||||||
ActionPlannerInfo(
|
ActionPlannerInfo(
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
|||||||
|
|
||||||
emoji.description = emoji_data.description
|
emoji.description = emoji_data.description
|
||||||
# Deserialize emotion string from DB to list
|
# Deserialize emotion string from DB to list
|
||||||
emoji.emotion = emoji_data.emotion.replace(",",",").split(",") if emoji_data.emotion else []
|
emoji.emotion = emoji_data.emotion.replace(",", ",").split(",") if emoji_data.emotion else []
|
||||||
emoji.usage_count = emoji_data.usage_count
|
emoji.usage_count = emoji_data.usage_count
|
||||||
|
|
||||||
db_last_used_time = emoji_data.last_used_time
|
db_last_used_time = emoji_data.last_used_time
|
||||||
@@ -732,7 +732,7 @@ class EmojiManager:
|
|||||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||||
if emoji_record and emoji_record.emotion:
|
if emoji_record and emoji_record.emotion:
|
||||||
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
||||||
return emoji_record.emotion.replace(",",",").split(",")
|
return emoji_record.emotion.replace(",", ",").split(",")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
||||||
|
|
||||||
@@ -993,7 +993,7 @@ class EmojiManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 处理情感列表
|
# 处理情感列表
|
||||||
emotions = [e.strip() for e in emotions_text.replace(",",",").split(",") if e.strip()]
|
emotions = [e.strip() for e in emotions_text.replace(",", ",").split(",") if e.strip()]
|
||||||
|
|
||||||
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||||
if len(emotions) > 5:
|
if len(emotions) > 5:
|
||||||
|
|||||||
@@ -123,7 +123,11 @@ class ChatBot:
|
|||||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||||
|
|
||||||
# 根据命令的拦截设置决定是否继续处理消息
|
# 根据命令的拦截设置决定是否继续处理消息
|
||||||
return True, response, not bool(intercept_message_level) # 找到命令,根据intercept_message决定是否继续
|
return (
|
||||||
|
True,
|
||||||
|
response,
|
||||||
|
not bool(intercept_message_level),
|
||||||
|
) # 找到命令,根据intercept_message决定是否继续
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||||
|
|||||||
@@ -213,6 +213,68 @@ class MessageRecv(Message):
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
return ""
|
return ""
|
||||||
|
elif segment.type == "video_card":
|
||||||
|
# 处理视频卡片消息
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_voice = False
|
||||||
|
if isinstance(segment.data, dict):
|
||||||
|
file_name = segment.data.get("file", "未知视频")
|
||||||
|
file_size = segment.data.get("file_size", "")
|
||||||
|
url = segment.data.get("url", "")
|
||||||
|
text = f"[视频: {file_name}"
|
||||||
|
if file_size:
|
||||||
|
text += f", 大小: {file_size}字节"
|
||||||
|
text += "]"
|
||||||
|
if url:
|
||||||
|
text += f" 链接: {url}"
|
||||||
|
return text
|
||||||
|
return "[视频]"
|
||||||
|
elif segment.type == "music_card":
|
||||||
|
# 处理音乐卡片消息
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_voice = False
|
||||||
|
if isinstance(segment.data, dict):
|
||||||
|
title = segment.data.get("title", "未知歌曲")
|
||||||
|
singer = segment.data.get("singer", "")
|
||||||
|
tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐"
|
||||||
|
jump_url = segment.data.get("jump_url", "")
|
||||||
|
music_url = segment.data.get("music_url", "")
|
||||||
|
text = f"[音乐: {title}"
|
||||||
|
if singer:
|
||||||
|
text += f" - {singer}"
|
||||||
|
if tag:
|
||||||
|
text += f" ({tag})"
|
||||||
|
text += "]"
|
||||||
|
if jump_url:
|
||||||
|
text += f" 跳转链接: {jump_url}"
|
||||||
|
if music_url:
|
||||||
|
text += f" 音乐链接: {music_url}"
|
||||||
|
return text
|
||||||
|
return "[音乐]"
|
||||||
|
elif segment.type == "miniapp_card":
|
||||||
|
# 处理小程序分享卡片(如B站视频分享)
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_voice = False
|
||||||
|
if isinstance(segment.data, dict):
|
||||||
|
title = segment.data.get("title", "") # 小程序名称
|
||||||
|
desc = segment.data.get("desc", "") # 内容描述
|
||||||
|
source_url = segment.data.get("source_url", "") # 原始链接
|
||||||
|
url = segment.data.get("url", "") # 小程序链接
|
||||||
|
text = "[小程序分享"
|
||||||
|
if title:
|
||||||
|
text += f" - {title}"
|
||||||
|
text += "]"
|
||||||
|
if desc:
|
||||||
|
text += f" {desc}"
|
||||||
|
if source_url:
|
||||||
|
text += f" 链接: {source_url}"
|
||||||
|
elif url:
|
||||||
|
text += f" 链接: {url}"
|
||||||
|
return text
|
||||||
|
return "[小程序分享]"
|
||||||
else:
|
else:
|
||||||
return ""
|
return ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ def parse_message_segments(segment) -> list:
|
|||||||
Returns:
|
Returns:
|
||||||
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
||||||
"""
|
"""
|
||||||
from maim_message import Seg
|
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
@@ -112,9 +111,13 @@ def parse_message_segments(segment) -> list:
|
|||||||
forward_items = []
|
forward_items = []
|
||||||
if segment.data:
|
if segment.data:
|
||||||
for item in segment.data:
|
for item in segment.data:
|
||||||
forward_items.append({
|
forward_items.append(
|
||||||
"content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else []
|
{
|
||||||
})
|
"content": parse_message_segments(item.get("message_segment", {}))
|
||||||
|
if isinstance(item, dict)
|
||||||
|
else []
|
||||||
|
}
|
||||||
|
)
|
||||||
result.append({"type": "forward", "data": forward_items})
|
result.append({"type": "forward", "data": forward_items})
|
||||||
else:
|
else:
|
||||||
# 未知类型,尝试作为文本处理
|
# 未知类型,尝试作为文本处理
|
||||||
|
|||||||
@@ -78,7 +78,6 @@ target_message_id为必填,表示触发消息的id
|
|||||||
"planner_prompt",
|
"planner_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
{action_name}
|
{action_name}
|
||||||
|
|||||||
@@ -250,7 +250,12 @@ class DefaultReplyer:
|
|||||||
# 使用从处理器传来的选中表达方式
|
# 使用从处理器传来的选中表达方式
|
||||||
# 使用模型预测选择表达方式
|
# 使用模型预测选择表达方式
|
||||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level
|
self.chat_stream.stream_id,
|
||||||
|
chat_history,
|
||||||
|
max_num=8,
|
||||||
|
target_message=target,
|
||||||
|
reply_reason=reply_reason,
|
||||||
|
think_level=think_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
if selected_expressions:
|
if selected_expressions:
|
||||||
@@ -273,7 +278,6 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||||
|
|
||||||
|
|
||||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||||
"""构建工具信息块
|
"""构建工具信息块
|
||||||
|
|
||||||
@@ -788,7 +792,8 @@ class DefaultReplyer:
|
|||||||
# 并行执行八个构建任务(包括黑话解释)
|
# 并行执行八个构建任务(包括黑话解释)
|
||||||
task_results = await asyncio.gather(
|
task_results = await asyncio.gather(
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits"
|
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
|
||||||
|
"expression_habits",
|
||||||
),
|
),
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||||
@@ -980,7 +985,6 @@ class DefaultReplyer:
|
|||||||
else:
|
else:
|
||||||
reply_target_block = ""
|
reply_target_block = ""
|
||||||
|
|
||||||
|
|
||||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||||
|
|
||||||
|
|||||||
@@ -287,7 +287,6 @@ class PrivateReplyer:
|
|||||||
|
|
||||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||||
|
|
||||||
|
|
||||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||||
"""构建工具信息块
|
"""构建工具信息块
|
||||||
|
|
||||||
@@ -907,16 +906,11 @@ class PrivateReplyer:
|
|||||||
else:
|
else:
|
||||||
reply_target_block = ""
|
reply_target_block = ""
|
||||||
|
|
||||||
|
|
||||||
chat_target_name = "对方"
|
chat_target_name = "对方"
|
||||||
if self.chat_target_info:
|
if self.chat_target_info:
|
||||||
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
||||||
chat_target_1 = await global_prompt_manager.format_prompt(
|
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
|
||||||
"chat_target_private1", sender_name=chat_target_name
|
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
|
||||||
)
|
|
||||||
chat_target_2 = await global_prompt_manager.format_prompt(
|
|
||||||
"chat_target_private2", sender_name=chat_target_name
|
|
||||||
)
|
|
||||||
|
|
||||||
template_name = "default_expressor_prompt"
|
template_name = "default_expressor_prompt"
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from src.chat.utils.prompt_builder import Prompt
|
from src.chat.utils.prompt_builder import Prompt
|
||||||
|
|
||||||
|
|
||||||
def init_replyer_private_prompt():
|
def init_replyer_private_prompt():
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||||
|
|
||||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||||
@@ -17,8 +18,8 @@ def init_replyer_private_prompt():
|
|||||||
{reply_style}
|
{reply_style}
|
||||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||||
"private_replyer_prompt",
|
"private_replyer_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||||
|
|||||||
@@ -44,4 +44,3 @@ def init_replyer_prompt():
|
|||||||
现在,你说:""",
|
现在,你说:""",
|
||||||
"replyer_prompt",
|
"replyer_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
|
|||||||
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
||||||
sort_order = [("time", 1)]
|
sort_order = [("time", 1)]
|
||||||
return find_messages(
|
return find_messages(
|
||||||
message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level
|
message_filter=filter_query,
|
||||||
|
sort=sort_order,
|
||||||
|
limit=limit,
|
||||||
|
filter_intercept_message_level=filter_intercept_message_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -132,9 +132,7 @@ class ImageManager:
|
|||||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
||||||
|
|
||||||
# 清理ImageDescriptions表中type为emoji的记录
|
# 清理ImageDescriptions表中type为emoji的记录
|
||||||
deleted_descriptions = (
|
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||||
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
|
||||||
)
|
|
||||||
|
|
||||||
total_deleted = deleted_images + deleted_descriptions
|
total_deleted = deleted_images + deleted_descriptions
|
||||||
if total_deleted > 0:
|
if total_deleted > 0:
|
||||||
@@ -236,10 +234,14 @@ class ImageManager:
|
|||||||
# 优先使用情感标签,如果没有则使用详细描述
|
# 优先使用情感标签,如果没有则使用详细描述
|
||||||
result_text = ""
|
result_text = ""
|
||||||
if cache_record.emotion_tags:
|
if cache_record.emotion_tags:
|
||||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
|
logger.info(
|
||||||
|
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
||||||
|
)
|
||||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
||||||
elif cache_record.description:
|
elif cache_record.description:
|
||||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
|
logger.info(
|
||||||
|
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
||||||
|
)
|
||||||
result_text = f"[表情包:{cache_record.description}]"
|
result_text = f"[表情包:{cache_record.description}]"
|
||||||
|
|
||||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
||||||
|
|||||||
@@ -609,10 +609,10 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
|||||||
fields = list(model._meta.fields.keys())
|
fields = list(model._meta.fields.keys())
|
||||||
# Peewee 默认使用 'id' 作为主键字段名
|
# Peewee 默认使用 'id' 作为主键字段名
|
||||||
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
|
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
|
||||||
primary_key_name = 'id' # 默认值
|
primary_key_name = "id" # 默认值
|
||||||
try:
|
try:
|
||||||
if hasattr(model._meta, 'primary_key') and model._meta.primary_key:
|
if hasattr(model._meta, "primary_key") and model._meta.primary_key:
|
||||||
if hasattr(model._meta.primary_key, 'name'):
|
if hasattr(model._meta.primary_key, "name"):
|
||||||
primary_key_name = model._meta.primary_key.name
|
primary_key_name = model._meta.primary_key.name
|
||||||
elif isinstance(model._meta.primary_key, str):
|
elif isinstance(model._meta.primary_key, str):
|
||||||
primary_key_name = model._meta.primary_key
|
primary_key_name = model._meta.primary_key
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
# 决定是否多行:仅在顶层且长度超过阈值时
|
# 决定是否多行:仅在顶层且长度超过阈值时
|
||||||
should_multiline = (depth == 0 and len(obj) > threshold)
|
should_multiline = depth == 0 and len(obj) > threshold
|
||||||
|
|
||||||
# 如果已经是 tomlkit Array,原地修改以保留注释
|
# 如果已经是 tomlkit Array,原地修改以保留注释
|
||||||
if isinstance(obj, Array):
|
if isinstance(obj, Array):
|
||||||
@@ -112,7 +112,7 @@ def save_toml_with_format(
|
|||||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||||
output = tomlkit.dumps(formatted)
|
output = tomlkit.dumps(formatted)
|
||||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||||
output = re.sub(r'\n{3,}', '\n\n', output)
|
output = re.sub(r"\n{3,}", "\n\n", output)
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(output)
|
f.write(output)
|
||||||
|
|
||||||
@@ -122,4 +122,4 @@ def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
|
|||||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||||
output = tomlkit.dumps(formatted)
|
output = tomlkit.dumps(formatted)
|
||||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||||
return re.sub(r'\n{3,}', '\n\n', output)
|
return re.sub(r"\n{3,}", "\n\n", output)
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from peewee import fn
|
from peewee import fn
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.common.database.database_model import ChatHistory, Jargon
|
from src.common.database.database_model import ChatHistory
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||||
from src.plugin_system.apis import llm_api
|
from src.plugin_system.apis import llm_api
|
||||||
@@ -82,7 +81,6 @@ def init_dream_prompts() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DreamTool:
|
class DreamTool:
|
||||||
"""dream 模块内部使用的简易工具封装"""
|
"""dream 模块内部使用的简易工具封装"""
|
||||||
|
|
||||||
@@ -150,7 +148,13 @@ def init_dream_tools(chat_id: str) -> None:
|
|||||||
"search_chat_history",
|
"search_chat_history",
|
||||||
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
|
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
|
||||||
[
|
[
|
||||||
("keyword", ToolParamType.STRING, "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", False, None),
|
(
|
||||||
|
"keyword",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
),
|
||||||
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
|
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
|
||||||
],
|
],
|
||||||
search_chat_history,
|
search_chat_history,
|
||||||
@@ -201,8 +205,20 @@ def init_dream_tools(chat_id: str) -> None:
|
|||||||
[
|
[
|
||||||
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
|
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
|
||||||
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
|
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
|
||||||
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", True, None),
|
(
|
||||||
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", True, None),
|
"keywords",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
|
||||||
|
True,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"key_point",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
|
||||||
|
True,
|
||||||
|
None,
|
||||||
|
),
|
||||||
("start_time", ToolParamType.STRING, "起始时间戳(秒,Unix 时间,必填)。", True, None),
|
("start_time", ToolParamType.STRING, "起始时间戳(秒,Unix 时间,必填)。", True, None),
|
||||||
("end_time", ToolParamType.STRING, "结束时间戳(秒,Unix 时间,必填)。", True, None),
|
("end_time", ToolParamType.STRING, "结束时间戳(秒,Unix 时间,必填)。", True, None),
|
||||||
],
|
],
|
||||||
@@ -215,7 +231,13 @@ def init_dream_tools(chat_id: str) -> None:
|
|||||||
"finish_maintenance",
|
"finish_maintenance",
|
||||||
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
|
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
|
||||||
[
|
[
|
||||||
("reason", ToolParamType.STRING, "结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。", False, None),
|
(
|
||||||
|
"reason",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。",
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
finish_maintenance,
|
finish_maintenance,
|
||||||
)
|
)
|
||||||
@@ -282,9 +304,7 @@ async def run_dream_agent_once(
|
|||||||
else "未知"
|
else "未知"
|
||||||
)
|
)
|
||||||
end_time_str = (
|
end_time_str = (
|
||||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
|
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
|
||||||
if record.end_time
|
|
||||||
else "未知"
|
|
||||||
)
|
)
|
||||||
detail_text = (
|
detail_text = (
|
||||||
f"ID={record.id}\n"
|
f"ID={record.id}\n"
|
||||||
@@ -305,8 +325,7 @@ async def run_dream_agent_once(
|
|||||||
start_detail_builder = MessageBuilder()
|
start_detail_builder = MessageBuilder()
|
||||||
start_detail_builder.set_role(RoleType.User)
|
start_detail_builder.set_role(RoleType.User)
|
||||||
start_detail_builder.add_text_content(
|
start_detail_builder.add_text_content(
|
||||||
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n"
|
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
|
||||||
+ detail_text
|
|
||||||
)
|
)
|
||||||
conversation_messages.append(start_detail_builder.build())
|
conversation_messages.append(start_detail_builder.build())
|
||||||
else:
|
else:
|
||||||
@@ -343,13 +362,17 @@ async def run_dream_agent_once(
|
|||||||
conversation_messages.append(round_info_builder.build())
|
conversation_messages.append(round_info_builder.build())
|
||||||
|
|
||||||
# 调用 LLM 让其决定是否要使用工具
|
# 调用 LLM 让其决定是否要使用工具
|
||||||
success, response, reasoning_content, model_name, tool_calls = (
|
(
|
||||||
await llm_api.generate_with_model_with_tools_by_message_factory(
|
success,
|
||||||
message_factory,
|
response,
|
||||||
model_config=model_config.model_task_config.tool_use,
|
reasoning_content,
|
||||||
tool_options=tool_defs,
|
model_name,
|
||||||
request_type="dream.react",
|
tool_calls,
|
||||||
)
|
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||||
|
message_factory,
|
||||||
|
model_config=model_config.model_task_config.tool_use,
|
||||||
|
tool_options=tool_defs,
|
||||||
|
request_type="dream.react",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
@@ -555,4 +578,3 @@ async def start_dream_scheduler(
|
|||||||
|
|
||||||
# 初始化提示词
|
# 初始化提示词
|
||||||
init_dream_prompts()
|
init_dream_prompts()
|
||||||
|
|
||||||
|
|||||||
@@ -110,13 +110,17 @@ async def generate_dream_summary(
|
|||||||
thought_content = ""
|
thought_content = ""
|
||||||
if msg.content:
|
if msg.content:
|
||||||
if isinstance(msg.content, list) and msg.content:
|
if isinstance(msg.content, list) and msg.content:
|
||||||
thought_content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
thought_content = (
|
||||||
|
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
thought_content = str(msg.content)
|
thought_content = str(msg.content)
|
||||||
|
|
||||||
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
|
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
|
||||||
if thought_content:
|
if thought_content:
|
||||||
logger.info(f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}")
|
logger.info(
|
||||||
|
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
# 记录每个工具调用的详细信息
|
# 记录每个工具调用的详细信息
|
||||||
for idx, tool_call in enumerate(msg.tool_calls, 1):
|
for idx, tool_call in enumerate(msg.tool_calls, 1):
|
||||||
@@ -167,7 +171,7 @@ async def generate_dream_summary(
|
|||||||
|
|
||||||
# 随机选择2个梦境风格
|
# 随机选择2个梦境风格
|
||||||
selected_styles = get_random_dream_styles(2)
|
selected_styles = get_random_dream_styles(2)
|
||||||
dream_styles_text = "\n".join([f"{i+1}. {style}" for i, style in enumerate(selected_styles)])
|
dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
|
||||||
|
|
||||||
# 使用 Prompt 管理器格式化梦境生成 prompt
|
# 使用 Prompt 管理器格式化梦境生成 prompt
|
||||||
dream_prompt = await global_prompt_manager.format_prompt(
|
dream_prompt = await global_prompt_manager.format_prompt(
|
||||||
@@ -195,4 +199,5 @@ async def generate_dream_summary(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
|
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
init_dream_summary_prompt()
|
init_dream_summary_prompt()
|
||||||
@@ -4,8 +4,3 @@ dream agent 工具实现模块。
|
|||||||
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
||||||
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -60,8 +60,3 @@ def make_create_chat_history(chat_id: str):
|
|||||||
return f"create_chat_history 执行失败: {e}"
|
return f"create_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return create_chat_history
|
return create_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
|||||||
return f"delete_chat_history 执行失败: {e}"
|
return f"delete_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return delete_chat_history
|
return delete_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
|||||||
return f"delete_jargon 执行失败: {e}"
|
return f"delete_jargon 执行失败: {e}"
|
||||||
|
|
||||||
return delete_jargon
|
return delete_jargon
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
return finish_maintenance
|
return finish_maintenance
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import ChatHistory
|
from src.common.database.database_model import ChatHistory
|
||||||
@@ -20,14 +19,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
|||||||
|
|
||||||
# 将时间戳转换为可读时间格式
|
# 将时间戳转换为可读时间格式
|
||||||
start_time_str = (
|
start_time_str = (
|
||||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time))
|
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
|
||||||
if record.start_time
|
|
||||||
else "未知"
|
|
||||||
)
|
)
|
||||||
end_time_str = (
|
end_time_str = (
|
||||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
|
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
|
||||||
if record.end_time
|
|
||||||
else "未知"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = (
|
result = (
|
||||||
@@ -40,17 +35,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
|||||||
f"概括={record.summary or '无'}\n"
|
f"概括={record.summary or '无'}\n"
|
||||||
f"关键信息={record.key_point or '无'}"
|
f"关键信息={record.key_point or '无'}"
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
|
||||||
f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}"
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"get_chat_history_detail 失败: {e}")
|
logger.error(f"get_chat_history_detail 失败: {e}")
|
||||||
return f"get_chat_history_detail 执行失败: {e}"
|
return f"get_chat_history_detail 执行失败: {e}"
|
||||||
|
|
||||||
return get_chat_history_detail
|
return get_chat_history_detail
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -78,9 +78,7 @@ def make_search_chat_history(chat_id: str):
|
|||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = (
|
||||||
json.loads(record.keywords)
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
if isinstance(record.keywords, str)
|
|
||||||
else record.keywords
|
|
||||||
)
|
)
|
||||||
if isinstance(keywords_data, list):
|
if isinstance(keywords_data, list):
|
||||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||||
@@ -125,9 +123,7 @@ def make_search_chat_history(chat_id: str):
|
|||||||
keywords_str = "、".join(keywords_list)
|
keywords_str = "、".join(keywords_list)
|
||||||
if len(keywords_list) > 2:
|
if len(keywords_list) > 2:
|
||||||
required_count = len(keywords_list) - 1
|
required_count = len(keywords_list) - 1
|
||||||
return (
|
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||||
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||||
elif participant:
|
elif participant:
|
||||||
@@ -142,9 +138,7 @@ def make_search_chat_history(chat_id: str):
|
|||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = (
|
||||||
json.loads(record.keywords)
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
if isinstance(record.keywords, str)
|
|
||||||
else record.keywords
|
|
||||||
)
|
)
|
||||||
if isinstance(keywords_data, list):
|
if isinstance(keywords_data, list):
|
||||||
for k in keywords_data:
|
for k in keywords_data:
|
||||||
@@ -160,13 +154,13 @@ def make_search_chat_history(chat_id: str):
|
|||||||
keywords_str = "、".join(sorted(all_keywords_set))
|
keywords_str = "、".join(sorted(all_keywords_set))
|
||||||
response_text = (
|
response_text = (
|
||||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||||
f"有关\"{search_label}\"的关键词:\n"
|
f'有关"{search_label}"的关键词:\n'
|
||||||
f"{keywords_str}"
|
f"{keywords_str}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response_text = (
|
response_text = (
|
||||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||||
f"有关\"{search_label}\"的关键词信息为空"
|
f'有关"{search_label}"的关键词信息为空'
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -192,9 +186,7 @@ def make_search_chat_history(chat_id: str):
|
|||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = (
|
||||||
json.loads(record.keywords)
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
if isinstance(record.keywords, str)
|
|
||||||
else record.keywords
|
|
||||||
)
|
)
|
||||||
if isinstance(keywords_data, list) and keywords_data:
|
if isinstance(keywords_data, list) and keywords_data:
|
||||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||||
@@ -220,8 +212,3 @@ def make_search_chat_history(chat_id: str):
|
|||||||
return f"search_chat_history 执行失败: {e}"
|
return f"search_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return search_chat_history
|
return search_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,7 @@ def make_search_jargon(chat_id: str):
|
|||||||
if not keyword or not keyword.strip():
|
if not keyword or not keyword.strip():
|
||||||
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
|
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
|
||||||
f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 基础条件:只查 is_jargon=True 的记录
|
# 基础条件:只查 is_jargon=True 的记录
|
||||||
query = Jargon.select().where(Jargon.is_jargon)
|
query = Jargon.select().where(Jargon.is_jargon)
|
||||||
@@ -102,5 +100,3 @@ def make_search_jargon(chat_id: str):
|
|||||||
return f"search_jargon 执行失败: {e}"
|
return f"search_jargon 执行失败: {e}"
|
||||||
|
|
||||||
return search_jargon
|
return search_jargon
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,8 +49,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
|||||||
return f"update_chat_history 执行失败: {e}"
|
return f"update_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return update_chat_history
|
return update_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,8 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
|||||||
return f"update_jargon 执行失败: {e}"
|
return f"update_jargon 执行失败: {e}"
|
||||||
|
|
||||||
return update_jargon
|
return update_jargon
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -316,7 +316,9 @@ class ChatHistorySummarizer:
|
|||||||
before_count = len(self.current_batch.messages)
|
before_count = len(self.current_batch.messages)
|
||||||
self.current_batch.messages.extend(new_messages)
|
self.current_batch.messages.extend(new_messages)
|
||||||
self.current_batch.end_time = current_time
|
self.current_batch.end_time = current_time
|
||||||
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
logger.info(
|
||||||
|
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
|
||||||
|
)
|
||||||
# 更新批次后持久化
|
# 更新批次后持久化
|
||||||
self._persist_topic_cache()
|
self._persist_topic_cache()
|
||||||
else:
|
else:
|
||||||
@@ -362,9 +364,7 @@ class ChatHistorySummarizer:
|
|||||||
else:
|
else:
|
||||||
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
|
||||||
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查“话题检查”触发条件
|
# 检查“话题检查”触发条件
|
||||||
should_check = False
|
should_check = False
|
||||||
@@ -427,7 +427,9 @@ class ChatHistorySummarizer:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 2. 构造编号后的消息字符串和参与者信息
|
# 2. 构造编号后的消息字符串和参与者信息
|
||||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
|
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
|
||||||
|
self._build_numbered_messages_for_llm(messages)
|
||||||
|
)
|
||||||
|
|
||||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次)
|
# 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次)
|
||||||
existing_topics = list(self.topic_cache.keys())
|
existing_topics = list(self.topic_cache.keys())
|
||||||
@@ -456,9 +458,7 @@ class ChatHistorySummarizer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not success or not topic_to_indices:
|
if not success or not topic_to_indices:
|
||||||
logger.error(
|
logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
|
||||||
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
|
|
||||||
)
|
|
||||||
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks(保持原状)
|
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks(保持原状)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -610,9 +610,7 @@ class ChatHistorySummarizer:
|
|||||||
if not numbered_lines:
|
if not numbered_lines:
|
||||||
return False, {}
|
return False, {}
|
||||||
|
|
||||||
history_topics_block = (
|
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||||
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
|
||||||
)
|
|
||||||
messages_block = "\n".join(numbered_lines)
|
messages_block = "\n".join(numbered_lines)
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
@@ -642,10 +640,10 @@ class ChatHistorySummarizer:
|
|||||||
else:
|
else:
|
||||||
# 如果没有找到代码块,尝试查找JSON数组的开始和结束位置
|
# 如果没有找到代码块,尝试查找JSON数组的开始和结束位置
|
||||||
# 查找第一个 [ 和最后一个 ]
|
# 查找第一个 [ 和最后一个 ]
|
||||||
start_idx = response.find('[')
|
start_idx = response.find("[")
|
||||||
end_idx = response.rfind(']')
|
end_idx = response.rfind("]")
|
||||||
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
||||||
json_str = response[start_idx:end_idx + 1].strip()
|
json_str = response[start_idx : end_idx + 1].strip()
|
||||||
else:
|
else:
|
||||||
# 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记)
|
# 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记)
|
||||||
json_str = response.strip()
|
json_str = response.strip()
|
||||||
@@ -942,4 +940,3 @@ class ChatHistorySummarizer:
|
|||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,10 @@ def _convert_messages(
|
|||||||
content: List[Part] = []
|
content: List[Part] = []
|
||||||
for item in message.content:
|
for item in message.content:
|
||||||
if isinstance(item, tuple):
|
if isinstance(item, tuple):
|
||||||
image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower()
|
image_format = item[0].lower()
|
||||||
|
# 规范 JPEG MIME 类型后缀,统一使用 image/jpeg
|
||||||
|
if image_format in ("jpg", "jpeg"):
|
||||||
|
image_format = "jpeg"
|
||||||
content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}"))
|
content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}"))
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
content.append(Part.from_text(text=item))
|
content.append(Part.from_text(text=item))
|
||||||
|
|||||||
@@ -61,10 +61,16 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
|
|||||||
content = []
|
content = []
|
||||||
for item in message.content:
|
for item in message.content:
|
||||||
if isinstance(item, tuple):
|
if isinstance(item, tuple):
|
||||||
|
image_format = item[0].lower()
|
||||||
|
# 规范 JPEG MIME 类型后缀,统一使用 image/jpeg
|
||||||
|
if image_format in ("jpg", "jpeg"):
|
||||||
|
mime_suffix = "jpeg"
|
||||||
|
else:
|
||||||
|
mime_suffix = image_format
|
||||||
content.append(
|
content.append(
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"},
|
"image_url": {"url": f"data:image/{mime_suffix};base64,{item[1]}"},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
|
|||||||
@@ -366,7 +366,9 @@ class LLMRequest:
|
|||||||
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
|
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
|
||||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||||
|
|
||||||
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}")
|
logger.warning(
|
||||||
|
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
|
||||||
|
)
|
||||||
await asyncio.sleep(api_provider.retry_interval)
|
await asyncio.sleep(api_provider.retry_interval)
|
||||||
|
|
||||||
except NetworkConnectionError as e:
|
except NetworkConnectionError as e:
|
||||||
@@ -394,7 +396,9 @@ class LLMRequest:
|
|||||||
if e.status_code == 429 or e.status_code >= 500:
|
if e.status_code == 429 or e.status_code >= 500:
|
||||||
retry_remain -= 1
|
retry_remain -= 1
|
||||||
if retry_remain <= 0:
|
if retry_remain <= 0:
|
||||||
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}")
|
logger.error(
|
||||||
|
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
|
||||||
|
)
|
||||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -540,7 +544,5 @@ class LLMRequest:
|
|||||||
if e.__cause__:
|
if e.__cause__:
|
||||||
original_error_type = type(e.__cause__).__name__
|
original_error_type = type(e.__cause__).__name__
|
||||||
original_error_msg = str(e.__cause__)
|
original_error_msg = str(e.__cause__)
|
||||||
return (
|
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||||
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -113,7 +113,6 @@ class MainSystem:
|
|||||||
get_emoji_manager().initialize()
|
get_emoji_manager().initialize()
|
||||||
logger.info("表情包管理器初始化成功")
|
logger.info("表情包管理器初始化成功")
|
||||||
|
|
||||||
|
|
||||||
# 初始化聊天管理器
|
# 初始化聊天管理器
|
||||||
await get_chat_manager()._initialize()
|
await get_chat_manager()._initialize()
|
||||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||||
|
|||||||
@@ -136,8 +136,6 @@ def init_memory_retrieval_prompt():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _log_conversation_messages(
|
def _log_conversation_messages(
|
||||||
conversation_messages: List[Message],
|
conversation_messages: List[Message],
|
||||||
head_prompt: Optional[str] = None,
|
head_prompt: Optional[str] = None,
|
||||||
@@ -172,7 +170,9 @@ def _log_conversation_messages(
|
|||||||
|
|
||||||
# 构建单条消息的日志信息
|
# 构建单条消息的日志信息
|
||||||
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
|
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
|
||||||
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
|
msg_info = (
|
||||||
|
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
|
||||||
|
)
|
||||||
|
|
||||||
# if full_content:
|
# if full_content:
|
||||||
# msg_info += f"\n{full_content}"
|
# msg_info += f"\n{full_content}"
|
||||||
@@ -185,8 +185,7 @@ def _log_conversation_messages(
|
|||||||
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
|
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
|
||||||
|
|
||||||
# if msg.tool_call_id:
|
# if msg.tool_call_id:
|
||||||
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
||||||
|
|
||||||
|
|
||||||
log_lines.append(msg_info)
|
log_lines.append(msg_info)
|
||||||
|
|
||||||
@@ -365,7 +364,7 @@ async def _react_agent_solve_question(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# logger.info(
|
# logger.info(
|
||||||
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||||
# )
|
# )
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
@@ -467,10 +466,17 @@ async def _react_agent_solve_question(
|
|||||||
if parsed_found_answer:
|
if parsed_found_answer:
|
||||||
# 找到了答案
|
# 找到了答案
|
||||||
if parsed_answer:
|
if parsed_answer:
|
||||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": parsed_answer}})
|
step["actions"].append(
|
||||||
|
{
|
||||||
|
"action_type": "finish_search",
|
||||||
|
"action_params": {"found_answer": True, "answer": parsed_answer},
|
||||||
|
}
|
||||||
|
)
|
||||||
step["observations"] = ["检测到finish_search文本格式调用,找到答案"]
|
step["observations"] = ["检测到finish_search文本格式调用,找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
|
||||||
|
)
|
||||||
|
|
||||||
_log_conversation_messages(
|
_log_conversation_messages(
|
||||||
conversation_messages,
|
conversation_messages,
|
||||||
@@ -481,10 +487,14 @@ async def _react_agent_solve_question(
|
|||||||
return True, parsed_answer, thinking_steps, False
|
return True, parsed_answer, thinking_steps, False
|
||||||
else:
|
else:
|
||||||
# found_answer为True但没有提供answer,视为错误,继续迭代
|
# found_answer为True但没有提供answer,视为错误,继续迭代
|
||||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer")
|
logger.warning(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 未找到答案,直接退出查询
|
# 未找到答案,直接退出查询
|
||||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
step["actions"].append(
|
||||||
|
{"action_type": "finish_search", "action_params": {"found_answer": False}}
|
||||||
|
)
|
||||||
step["observations"] = ["检测到finish_search文本格式调用,未找到答案"]
|
step["observations"] = ["检测到finish_search文本格式调用,未找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
|
||||||
@@ -522,10 +532,17 @@ async def _react_agent_solve_question(
|
|||||||
if finish_search_found:
|
if finish_search_found:
|
||||||
# 找到了答案
|
# 找到了答案
|
||||||
if finish_search_answer:
|
if finish_search_answer:
|
||||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": finish_search_answer}})
|
step["actions"].append(
|
||||||
|
{
|
||||||
|
"action_type": "finish_search",
|
||||||
|
"action_params": {"found_answer": True, "answer": finish_search_answer},
|
||||||
|
}
|
||||||
|
)
|
||||||
step["observations"] = ["检测到finish_search工具调用,找到答案"]
|
step["observations"] = ["检测到finish_search工具调用,找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
|
||||||
|
)
|
||||||
|
|
||||||
_log_conversation_messages(
|
_log_conversation_messages(
|
||||||
conversation_messages,
|
conversation_messages,
|
||||||
@@ -536,7 +553,9 @@ async def _react_agent_solve_question(
|
|||||||
return True, finish_search_answer, thinking_steps, False
|
return True, finish_search_answer, thinking_steps, False
|
||||||
else:
|
else:
|
||||||
# found_answer为True但没有提供answer,视为错误
|
# found_answer为True但没有提供answer,视为错误
|
||||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer")
|
logger.warning(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 未找到答案,直接退出查询
|
# 未找到答案,直接退出查询
|
||||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
||||||
@@ -724,7 +743,13 @@ async def _react_agent_solve_question(
|
|||||||
max_iterations=max_iterations,
|
max_iterations=max_iterations,
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
|
(
|
||||||
|
eval_success,
|
||||||
|
eval_response,
|
||||||
|
eval_reasoning_content,
|
||||||
|
eval_model_name,
|
||||||
|
eval_tool_calls,
|
||||||
|
) = await llm_api.generate_with_model_with_tools(
|
||||||
evaluation_prompt,
|
evaluation_prompt,
|
||||||
model_config=model_config.model_task_config.tool_use,
|
model_config=model_config.model_task_config.tool_use,
|
||||||
tool_options=[], # 最终评估阶段不提供工具
|
tool_options=[], # 最终评估阶段不提供工具
|
||||||
@@ -759,7 +784,7 @@ async def _react_agent_solve_question(
|
|||||||
"iteration": current_iteration,
|
"iteration": current_iteration,
|
||||||
"thought": f"[最终评估] {eval_response}",
|
"thought": f"[最终评估] {eval_response}",
|
||||||
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
|
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
|
||||||
"observations": ["最终评估阶段检测到found_answer"]
|
"observations": ["最终评估阶段检测到found_answer"],
|
||||||
}
|
}
|
||||||
thinking_steps.append(eval_step)
|
thinking_steps.append(eval_step)
|
||||||
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
|
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
|
||||||
@@ -778,7 +803,7 @@ async def _react_agent_solve_question(
|
|||||||
"iteration": current_iteration,
|
"iteration": current_iteration,
|
||||||
"thought": f"[最终评估] {eval_response}",
|
"thought": f"[最终评估] {eval_response}",
|
||||||
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
|
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
|
||||||
"observations": ["最终评估阶段检测到not_enough_info"]
|
"observations": ["最终评估阶段检测到not_enough_info"],
|
||||||
}
|
}
|
||||||
thinking_steps.append(eval_step)
|
thinking_steps.append(eval_step)
|
||||||
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
|
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
|
||||||
@@ -795,8 +820,10 @@ async def _react_agent_solve_question(
|
|||||||
eval_step = {
|
eval_step = {
|
||||||
"iteration": current_iteration,
|
"iteration": current_iteration,
|
||||||
"thought": f"[最终评估] {eval_response}",
|
"thought": f"[最终评估] {eval_response}",
|
||||||
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}],
|
"actions": [
|
||||||
"observations": ["已到达最大迭代次数,无法找到答案"]
|
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
|
||||||
|
],
|
||||||
|
"observations": ["已到达最大迭代次数,无法找到答案"],
|
||||||
}
|
}
|
||||||
thinking_steps.append(eval_step)
|
thinking_steps.append(eval_step)
|
||||||
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
|
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
|
||||||
@@ -1129,7 +1156,9 @@ async def build_memory_retrieval_prompt(
|
|||||||
else:
|
else:
|
||||||
max_iterations = base_max_iterations
|
max_iterations = base_max_iterations
|
||||||
timeout_seconds = global_config.memory.agent_timeout_seconds
|
timeout_seconds = global_config.memory.agent_timeout_seconds
|
||||||
logger.debug(f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒")
|
logger.debug(
|
||||||
|
f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒"
|
||||||
|
)
|
||||||
|
|
||||||
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
||||||
question_tasks = [
|
question_tasks = [
|
||||||
@@ -1198,4 +1227,3 @@ async def build_memory_retrieval_prompt(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记忆检索时发生异常: {str(e)}")
|
logger.error(f"记忆检索时发生异常: {str(e)}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from src.common.logger import get_logger
|
|||||||
logger = get_logger("memory_utils")
|
logger = get_logger("memory_utils")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
||||||
"""解析问题JSON,返回概念列表和问题列表
|
"""解析问题JSON,返回概念列表和问题列表
|
||||||
|
|
||||||
@@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
|||||||
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
|
|
||||||
def parse_datetime_to_timestamp(value: str) -> float:
|
def parse_datetime_to_timestamp(value: str) -> float:
|
||||||
"""
|
"""
|
||||||
接受多种常见格式并转换为时间戳(秒)
|
接受多种常见格式并转换为时间戳(秒)
|
||||||
|
|||||||
@@ -47,4 +47,3 @@ def register_tool():
|
|||||||
],
|
],
|
||||||
execute_func=finish_search,
|
execute_func=finish_search,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,7 @@ from .tool_registry import register_memory_retrieval_tool
|
|||||||
logger = get_logger("memory_retrieval_tools")
|
logger = get_logger("memory_retrieval_tools")
|
||||||
|
|
||||||
|
|
||||||
async def search_chat_history(
|
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
|
||||||
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
|
|
||||||
) -> str:
|
|
||||||
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords
|
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -144,7 +142,9 @@ async def search_chat_history(
|
|||||||
keywords_list = parse_keywords_string(keyword)
|
keywords_list = parse_keywords_string(keyword)
|
||||||
if len(keywords_list) > 2:
|
if len(keywords_list) > 2:
|
||||||
required_count = len(keywords_list) - 1
|
required_count = len(keywords_list) - 1
|
||||||
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
return (
|
||||||
|
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||||
elif participant:
|
elif participant:
|
||||||
@@ -160,9 +160,7 @@ async def search_chat_history(
|
|||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = (
|
||||||
json.loads(record.keywords)
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
if isinstance(record.keywords, str)
|
|
||||||
else record.keywords
|
|
||||||
)
|
)
|
||||||
if isinstance(keywords_data, list):
|
if isinstance(keywords_data, list):
|
||||||
for k in keywords_data:
|
for k in keywords_data:
|
||||||
@@ -179,13 +177,12 @@ async def search_chat_history(
|
|||||||
keywords_str = "、".join(sorted(all_keywords_set))
|
keywords_str = "、".join(sorted(all_keywords_set))
|
||||||
return (
|
return (
|
||||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||||
f"有关\"{search_label}\"的关键词:\n"
|
f'有关"{search_label}"的关键词:\n'
|
||||||
f"{keywords_str}"
|
f"{keywords_str}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return (
|
return (
|
||||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空'
|
||||||
f"有关\"{search_label}\"的关键词信息为空"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建结果文本,返回id、theme和keywords(最多20条)
|
# 构建结果文本,返回id、theme和keywords(最多20条)
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ from .base import (
|
|||||||
BaseCommand,
|
BaseCommand,
|
||||||
BaseTool,
|
BaseTool,
|
||||||
ConfigField,
|
ConfigField,
|
||||||
|
ConfigSection,
|
||||||
|
ConfigLayout,
|
||||||
|
ConfigTab,
|
||||||
ComponentType,
|
ComponentType,
|
||||||
ActionActivationType,
|
ActionActivationType,
|
||||||
ChatMode,
|
ChatMode,
|
||||||
@@ -116,6 +119,9 @@ __all__ = [
|
|||||||
# 装饰器
|
# 装饰器
|
||||||
"register_plugin",
|
"register_plugin",
|
||||||
"ConfigField",
|
"ConfigField",
|
||||||
|
"ConfigSection",
|
||||||
|
"ConfigLayout",
|
||||||
|
"ConfigTab",
|
||||||
# 工具函数
|
# 工具函数
|
||||||
"ManifestValidator",
|
"ManifestValidator",
|
||||||
"get_logger",
|
"get_logger",
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from .component_types import (
|
|||||||
ForwardNode,
|
ForwardNode,
|
||||||
ReplySetModel,
|
ReplySetModel,
|
||||||
)
|
)
|
||||||
from .config_types import ConfigField
|
from .config_types import ConfigField, ConfigSection, ConfigLayout, ConfigTab
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BasePlugin",
|
"BasePlugin",
|
||||||
@@ -46,6 +46,9 @@ __all__ = [
|
|||||||
"PluginInfo",
|
"PluginInfo",
|
||||||
"PythonDependency",
|
"PythonDependency",
|
||||||
"ConfigField",
|
"ConfigField",
|
||||||
|
"ConfigSection",
|
||||||
|
"ConfigLayout",
|
||||||
|
"ConfigTab",
|
||||||
"EventHandlerInfo",
|
"EventHandlerInfo",
|
||||||
"EventType",
|
"EventType",
|
||||||
"BaseEventHandler",
|
"BaseEventHandler",
|
||||||
|
|||||||
@@ -70,6 +70,12 @@ class ConfigField:
|
|||||||
depends_on: Optional[str] = None # 依赖的字段路径,如 "section.field"
|
depends_on: Optional[str] = None # 依赖的字段路径,如 "section.field"
|
||||||
depends_value: Any = None # 依赖字段需要的值(当依赖字段等于此值时显示)
|
depends_value: Any = None # 依赖字段需要的值(当依赖字段等于此值时显示)
|
||||||
|
|
||||||
|
# === 列表类型专用 ===
|
||||||
|
item_type: Optional[str] = None # 数组元素类型: "string", "number", "object"
|
||||||
|
item_fields: Optional[Dict[str, Any]] = None # 当 item_type="object" 时,定义对象的字段结构
|
||||||
|
min_items: Optional[int] = None # 数组最小元素数量
|
||||||
|
max_items: Optional[int] = None # 数组最大元素数量
|
||||||
|
|
||||||
def get_ui_type(self) -> str:
|
def get_ui_type(self) -> str:
|
||||||
"""
|
"""
|
||||||
获取 UI 控件类型
|
获取 UI 控件类型
|
||||||
@@ -132,6 +138,10 @@ class ConfigField:
|
|||||||
"group": self.group,
|
"group": self.group,
|
||||||
"depends_on": self.depends_on,
|
"depends_on": self.depends_on,
|
||||||
"depends_value": self.depends_value,
|
"depends_value": self.depends_value,
|
||||||
|
"item_type": self.item_type,
|
||||||
|
"item_fields": self.item_fields,
|
||||||
|
"min_items": self.min_items,
|
||||||
|
"max_items": self.max_items,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
784
src/webui/anti_crawler.py
Normal file
784
src/webui/anti_crawler.py
Normal file
@@ -0,0 +1,784 @@
|
|||||||
|
"""
|
||||||
|
WebUI 防爬虫模块
|
||||||
|
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import ipaddress
|
||||||
|
import re
|
||||||
|
from collections import deque
|
||||||
|
from typing import Optional
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import PlainTextResponse
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("webui.anti_crawler")
|
||||||
|
|
||||||
|
# 常见爬虫 User-Agent 列表(使用更精确的关键词,避免误报)
|
||||||
|
CRAWLER_USER_AGENTS = {
|
||||||
|
# 搜索引擎爬虫(精确匹配)
|
||||||
|
"googlebot",
|
||||||
|
"bingbot",
|
||||||
|
"baiduspider",
|
||||||
|
"yandexbot",
|
||||||
|
"slurp", # Yahoo
|
||||||
|
"duckduckbot",
|
||||||
|
"sogou",
|
||||||
|
"exabot",
|
||||||
|
"facebot",
|
||||||
|
"ia_archiver", # Internet Archive
|
||||||
|
# 通用爬虫(移除过于宽泛的关键词)
|
||||||
|
"crawler",
|
||||||
|
"spider",
|
||||||
|
"scraper",
|
||||||
|
"wget", # 保留wget,因为通常用于自动化脚本
|
||||||
|
"scrapy", # 保留scrapy,因为这是爬虫框架
|
||||||
|
# 安全扫描工具(这些是明确的扫描工具)
|
||||||
|
"masscan",
|
||||||
|
"nmap",
|
||||||
|
"nikto",
|
||||||
|
"sqlmap",
|
||||||
|
# 注意:移除了以下过于宽泛的关键词以避免误报:
|
||||||
|
# - "bot" (会误匹配GitHub-Robot等)
|
||||||
|
# - "curl" (正常工具)
|
||||||
|
# - "python-requests" (正常库)
|
||||||
|
# - "httpx" (正常库)
|
||||||
|
# - "aiohttp" (正常库)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 资产测绘工具 User-Agent 标识
|
||||||
|
ASSET_SCANNER_USER_AGENTS = {
|
||||||
|
# 知名资产测绘平台
|
||||||
|
"shodan",
|
||||||
|
"censys",
|
||||||
|
"zoomeye",
|
||||||
|
"fofa",
|
||||||
|
"quake",
|
||||||
|
"hunter",
|
||||||
|
"binaryedge",
|
||||||
|
"onyphe",
|
||||||
|
"securitytrails",
|
||||||
|
"virustotal",
|
||||||
|
"passivetotal",
|
||||||
|
# 安全扫描工具
|
||||||
|
"acunetix",
|
||||||
|
"appscan",
|
||||||
|
"burpsuite",
|
||||||
|
"nessus",
|
||||||
|
"openvas",
|
||||||
|
"qualys",
|
||||||
|
"rapid7",
|
||||||
|
"tenable",
|
||||||
|
"veracode",
|
||||||
|
"zap",
|
||||||
|
"awvs", # Acunetix Web Vulnerability Scanner
|
||||||
|
"netsparker",
|
||||||
|
"skipfish",
|
||||||
|
"w3af",
|
||||||
|
"arachni",
|
||||||
|
# 其他扫描工具
|
||||||
|
"masscan",
|
||||||
|
"zmap",
|
||||||
|
"nmap",
|
||||||
|
"whatweb",
|
||||||
|
"wpscan",
|
||||||
|
"joomscan",
|
||||||
|
"dnsenum",
|
||||||
|
"subfinder",
|
||||||
|
"amass",
|
||||||
|
"sublist3r",
|
||||||
|
"theharvester",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 资产测绘工具常用的HTTP头标识
|
||||||
|
ASSET_SCANNER_HEADERS = {
|
||||||
|
# 常见的扫描工具自定义头
|
||||||
|
"x-scan": {"shodan", "censys", "zoomeye", "fofa"},
|
||||||
|
"x-scanner": {"nmap", "masscan", "zmap"},
|
||||||
|
"x-probe": {"masscan", "zmap"},
|
||||||
|
# 其他可疑头(移除反向代理标准头)
|
||||||
|
"x-originating-ip": set(),
|
||||||
|
"x-remote-ip": set(),
|
||||||
|
"x-remote-addr": set(),
|
||||||
|
# 注意:移除了以下反向代理标准头以避免误报:
|
||||||
|
# - "x-forwarded-proto" (反向代理标准头)
|
||||||
|
# - "x-real-ip" (反向代理标准头,已在_get_client_ip中使用)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 仅检查特定HTTP头中的可疑模式(收紧匹配范围)
|
||||||
|
# 只检查这些特定头,不检查所有头
|
||||||
|
SCANNER_SPECIFIC_HEADERS = {
|
||||||
|
"x-scan",
|
||||||
|
"x-scanner",
|
||||||
|
"x-probe",
|
||||||
|
"x-originating-ip",
|
||||||
|
"x-remote-ip",
|
||||||
|
"x-remote-addr",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 防爬虫模式配置
|
||||||
|
# false: 禁用
|
||||||
|
# strict: 严格模式(更严格的检测,更低的频率限制)
|
||||||
|
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
||||||
|
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||||
|
ANTI_CRAWLER_MODE = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
|
||||||
|
|
||||||
|
|
||||||
|
# IP白名单配置(从环境变量读取,逗号分隔)
|
||||||
|
# 支持格式:
|
||||||
|
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||||
|
# - CIDR格式:192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
|
||||||
|
# - 通配符:192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
|
||||||
|
# - IPv6:::1, 2001:db8::/32
|
||||||
|
def _parse_allowed_ips(ip_string: str) -> list:
|
||||||
|
"""
|
||||||
|
解析IP白名单字符串,支持精确IP、CIDR格式和通配符
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip_string: 逗号分隔的IP字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IP白名单列表,每个元素可能是:
|
||||||
|
- ipaddress.IPv4Network/IPv6Network对象(CIDR格式)
|
||||||
|
- ipaddress.IPv4Address/IPv6Address对象(精确IP)
|
||||||
|
- str(通配符模式,已转换为正则表达式)
|
||||||
|
"""
|
||||||
|
allowed = []
|
||||||
|
if not ip_string:
|
||||||
|
return allowed
|
||||||
|
|
||||||
|
for ip_entry in ip_string.split(","):
|
||||||
|
ip_entry = ip_entry.strip() # 去除空格
|
||||||
|
if not ip_entry:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查通配符格式(包含*)
|
||||||
|
if "*" in ip_entry:
|
||||||
|
# 处理通配符
|
||||||
|
pattern = _convert_wildcard_to_regex(ip_entry)
|
||||||
|
if pattern:
|
||||||
|
allowed.append(pattern)
|
||||||
|
else:
|
||||||
|
logger.warning(f"无效的通配符IP格式,已忽略: {ip_entry}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 尝试解析为CIDR格式(包含/)
|
||||||
|
if "/" in ip_entry:
|
||||||
|
allowed.append(ipaddress.ip_network(ip_entry, strict=False))
|
||||||
|
else:
|
||||||
|
# 精确IP地址
|
||||||
|
allowed.append(ipaddress.ip_address(ip_entry))
|
||||||
|
except (ValueError, AttributeError) as e:
|
||||||
|
logger.warning(f"无效的IP白名单条目,已忽略: {ip_entry} ({e})")
|
||||||
|
|
||||||
|
return allowed
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
将通配符IP模式转换为正则表达式
|
||||||
|
|
||||||
|
支持的格式:
|
||||||
|
- 192.168.*.* 或 192.168.*
|
||||||
|
- 10.*.*.* 或 10.*
|
||||||
|
- *.*.*.* 或 *
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wildcard_pattern: 通配符模式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
正则表达式字符串,如果格式无效则返回None
|
||||||
|
"""
|
||||||
|
# 去除空格
|
||||||
|
pattern = wildcard_pattern.strip()
|
||||||
|
|
||||||
|
# 处理单个*(匹配所有)
|
||||||
|
if pattern == "*":
|
||||||
|
return r".*"
|
||||||
|
|
||||||
|
# 处理IPv4通配符格式
|
||||||
|
# 支持:192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
|
||||||
|
parts = pattern.split(".")
|
||||||
|
|
||||||
|
if len(parts) > 4:
|
||||||
|
return None # IPv4最多4段
|
||||||
|
|
||||||
|
# 构建正则表达式
|
||||||
|
regex_parts = []
|
||||||
|
for part in parts:
|
||||||
|
part = part.strip()
|
||||||
|
if part == "*":
|
||||||
|
regex_parts.append(r"\d+") # 匹配任意数字
|
||||||
|
elif part.isdigit():
|
||||||
|
# 验证数字范围(0-255)
|
||||||
|
num = int(part)
|
||||||
|
if 0 <= num <= 255:
|
||||||
|
regex_parts.append(re.escape(part))
|
||||||
|
else:
|
||||||
|
return None # 无效的数字
|
||||||
|
else:
|
||||||
|
return None # 无效的格式
|
||||||
|
|
||||||
|
# 如果部分少于4段,补充.*
|
||||||
|
while len(regex_parts) < 4:
|
||||||
|
regex_parts.append(r"\d+")
|
||||||
|
|
||||||
|
# 组合成正则表达式
|
||||||
|
regex = r"^" + r"\.".join(regex_parts) + r"$"
|
||||||
|
return regex
|
||||||
|
|
||||||
|
|
||||||
|
ALLOWED_IPS = _parse_allowed_ips(os.getenv("WEBUI_ALLOWED_IPS", ""))
|
||||||
|
|
||||||
|
# 信任的代理IP配置(从环境变量读取,逗号分隔)
|
||||||
|
# 只有在信任的代理IP下才使用X-Forwarded-For头
|
||||||
|
# 默认关闭(空),不信任任何代理
|
||||||
|
TRUSTED_PROXIES = _parse_allowed_ips(os.getenv("WEBUI_TRUSTED_PROXIES", ""))
|
||||||
|
TRUST_XFF = os.getenv("WEBUI_TRUST_XFF", "false").lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mode_config(mode: str) -> dict:
|
||||||
|
"""
|
||||||
|
根据模式获取配置参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode: 防爬虫模式 (false/strict/loose/basic)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置字典,包含所有相关参数
|
||||||
|
"""
|
||||||
|
mode = mode.lower()
|
||||||
|
|
||||||
|
if mode == "false":
|
||||||
|
return {
|
||||||
|
"enabled": False,
|
||||||
|
"rate_limit_window": 60,
|
||||||
|
"rate_limit_max_requests": 1000, # 禁用时设置很高的值
|
||||||
|
"max_tracked_ips": 0,
|
||||||
|
"check_user_agent": False,
|
||||||
|
"check_asset_scanner": False,
|
||||||
|
"check_rate_limit": False,
|
||||||
|
"block_on_detect": False, # 不阻止
|
||||||
|
}
|
||||||
|
elif mode == "strict":
|
||||||
|
return {
|
||||||
|
"enabled": True,
|
||||||
|
"rate_limit_window": 60,
|
||||||
|
"rate_limit_max_requests": 15, # 严格模式:更低的请求数
|
||||||
|
"max_tracked_ips": 20000,
|
||||||
|
"check_user_agent": True,
|
||||||
|
"check_asset_scanner": True,
|
||||||
|
"check_rate_limit": True,
|
||||||
|
"block_on_detect": True, # 阻止恶意访问
|
||||||
|
}
|
||||||
|
elif mode == "loose":
|
||||||
|
return {
|
||||||
|
"enabled": True,
|
||||||
|
"rate_limit_window": 60,
|
||||||
|
"rate_limit_max_requests": 60, # 宽松模式:更高的请求数
|
||||||
|
"max_tracked_ips": 5000,
|
||||||
|
"check_user_agent": True,
|
||||||
|
"check_asset_scanner": True,
|
||||||
|
"check_rate_limit": True,
|
||||||
|
"block_on_detect": True, # 阻止恶意访问
|
||||||
|
}
|
||||||
|
else: # basic (默认模式)
|
||||||
|
return {
|
||||||
|
"enabled": True,
|
||||||
|
"rate_limit_window": 60,
|
||||||
|
"rate_limit_max_requests": 1000, # 不限制请求数
|
||||||
|
"max_tracked_ips": 0, # 不跟踪IP
|
||||||
|
"check_user_agent": True, # 检测但不阻止
|
||||||
|
"check_asset_scanner": True, # 检测但不阻止
|
||||||
|
"check_rate_limit": False, # 不限制请求频率
|
||||||
|
"block_on_detect": False, # 只记录,不阻止
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""防爬虫中间件"""
|
||||||
|
|
||||||
|
def __init__(self, app, mode: str = "standard"):
|
||||||
|
"""
|
||||||
|
初始化防爬虫中间件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI 应用实例
|
||||||
|
mode: 防爬虫模式 (false/strict/loose/standard)
|
||||||
|
"""
|
||||||
|
super().__init__(app)
|
||||||
|
self.mode = mode.lower()
|
||||||
|
# 根据模式获取配置
|
||||||
|
config = _get_mode_config(self.mode)
|
||||||
|
self.enabled = config["enabled"]
|
||||||
|
self.rate_limit_window = config["rate_limit_window"]
|
||||||
|
self.rate_limit_max_requests = config["rate_limit_max_requests"]
|
||||||
|
self.max_tracked_ips = config["max_tracked_ips"]
|
||||||
|
self.check_user_agent = config["check_user_agent"]
|
||||||
|
self.check_asset_scanner = config["check_asset_scanner"]
|
||||||
|
self.check_rate_limit = config["check_rate_limit"]
|
||||||
|
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
|
||||||
|
|
||||||
|
# 用于存储每个IP的请求时间戳(使用deque提高性能)
|
||||||
|
self.request_times: dict[str, deque] = {}
|
||||||
|
# 上次清理时间
|
||||||
|
self.last_cleanup = time.time()
|
||||||
|
# 将关键词列表转换为集合以提高查找性能
|
||||||
|
self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
|
||||||
|
self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
|
||||||
|
|
||||||
|
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
|
||||||
|
"""
|
||||||
|
检测是否为爬虫 User-Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_agent: User-Agent 字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果是爬虫则返回 True
|
||||||
|
"""
|
||||||
|
if not user_agent:
|
||||||
|
# 没有 User-Agent 的请求记录日志但不直接阻止
|
||||||
|
# 改为只记录,让频率限制来处理
|
||||||
|
logger.debug("请求缺少User-Agent")
|
||||||
|
return False # 不再直接阻止无User-Agent的请求
|
||||||
|
|
||||||
|
user_agent_lower = user_agent.lower()
|
||||||
|
|
||||||
|
# 使用集合查找提高性能(检查是否包含爬虫关键词)
|
||||||
|
for crawler_keyword in self.crawler_keywords_set:
|
||||||
|
if crawler_keyword in user_agent_lower:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _is_asset_scanner_header(self, request: Request) -> bool:
|
||||||
|
"""
|
||||||
|
检测是否为资产测绘工具的HTTP头(只检查特定头,收紧匹配)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果检测到资产测绘工具头则返回 True
|
||||||
|
"""
|
||||||
|
# 只检查特定的扫描工具头,不检查所有头
|
||||||
|
for header_name, header_value in request.headers.items():
|
||||||
|
header_name_lower = header_name.lower()
|
||||||
|
header_value_lower = header_value.lower() if header_value else ""
|
||||||
|
|
||||||
|
# 检查已知的扫描工具头
|
||||||
|
if header_name_lower in ASSET_SCANNER_HEADERS:
|
||||||
|
# 如果该头有特定的工具集合,检查值是否匹配
|
||||||
|
expected_tools = ASSET_SCANNER_HEADERS[header_name_lower]
|
||||||
|
if expected_tools:
|
||||||
|
for tool in expected_tools:
|
||||||
|
if tool in header_value_lower:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# 如果没有特定工具集合,只要存在该头就视为可疑
|
||||||
|
if header_value_lower:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 只检查特定头中的可疑模式(收紧匹配)
|
||||||
|
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||||
|
# 检查头值中是否包含已知扫描工具名称
|
||||||
|
for tool in self.scanner_keywords_set:
|
||||||
|
if tool in header_value_lower:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
检测资产测绘工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否检测到, 检测到的工具名称)
|
||||||
|
"""
|
||||||
|
user_agent = request.headers.get("User-Agent")
|
||||||
|
|
||||||
|
# 检查 User-Agent(使用集合查找提高性能)
|
||||||
|
if user_agent:
|
||||||
|
user_agent_lower = user_agent.lower()
|
||||||
|
for scanner_keyword in self.scanner_keywords_set:
|
||||||
|
if scanner_keyword in user_agent_lower:
|
||||||
|
return True, scanner_keyword
|
||||||
|
|
||||||
|
# 检查HTTP头
|
||||||
|
if self._is_asset_scanner_header(request):
|
||||||
|
# 尝试从User-Agent或头中提取工具名称
|
||||||
|
detected_tool = None
|
||||||
|
if user_agent:
|
||||||
|
user_agent_lower = user_agent.lower()
|
||||||
|
for tool in self.scanner_keywords_set:
|
||||||
|
if tool in user_agent_lower:
|
||||||
|
detected_tool = tool
|
||||||
|
break
|
||||||
|
|
||||||
|
# 检查HTTP头中的工具标识(只检查特定头)
|
||||||
|
if not detected_tool:
|
||||||
|
for header_name, header_value in request.headers.items():
|
||||||
|
header_name_lower = header_name.lower()
|
||||||
|
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||||
|
header_value_lower = (header_value or "").lower()
|
||||||
|
for tool in self.scanner_keywords_set:
|
||||||
|
if tool in header_value_lower:
|
||||||
|
detected_tool = tool
|
||||||
|
break
|
||||||
|
if detected_tool:
|
||||||
|
break
|
||||||
|
|
||||||
|
return True, detected_tool or "unknown_scanner"
|
||||||
|
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def _check_rate_limit(self, client_ip: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查请求频率限制
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_ip: 客户端IP地址
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果超过限制则返回 True(需要阻止)
|
||||||
|
"""
|
||||||
|
# 检查IP白名单
|
||||||
|
if self._is_ip_allowed(client_ip):
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# 定期清理过期的请求记录(每5分钟清理一次)
|
||||||
|
if current_time - self.last_cleanup > 300:
|
||||||
|
self._cleanup_old_requests(current_time)
|
||||||
|
self.last_cleanup = current_time
|
||||||
|
|
||||||
|
# 限制跟踪的IP数量,防止内存泄漏
|
||||||
|
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
|
||||||
|
# 清理最旧的记录(删除最久未访问的IP)
|
||||||
|
self._cleanup_oldest_ips()
|
||||||
|
|
||||||
|
# 获取或创建该IP的请求时间deque(不使用maxlen,避免限流变松)
|
||||||
|
if client_ip not in self.request_times:
|
||||||
|
self.request_times[client_ip] = deque()
|
||||||
|
|
||||||
|
request_times = self.request_times[client_ip]
|
||||||
|
|
||||||
|
# 移除时间窗口外的请求记录(从左侧弹出过期记录)
|
||||||
|
while request_times and current_time - request_times[0] >= self.rate_limit_window:
|
||||||
|
request_times.popleft()
|
||||||
|
|
||||||
|
# 检查是否超过限制
|
||||||
|
if len(request_times) >= self.rate_limit_max_requests:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 记录当前请求时间
|
||||||
|
request_times.append(current_time)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _cleanup_old_requests(self, current_time: float):
|
||||||
|
"""清理过期的请求记录(只清理当前需要检查的IP,不全量遍历)"""
|
||||||
|
# 这个方法现在主要用于定期清理,实际清理在_check_rate_limit中按需进行
|
||||||
|
# 清理最久未访问的IP记录
|
||||||
|
if len(self.request_times) > self.max_tracked_ips * 0.8:
|
||||||
|
self._cleanup_oldest_ips()
|
||||||
|
|
||||||
|
def _cleanup_oldest_ips(self):
|
||||||
|
"""清理最久未访问的IP记录(全量遍历找真正的oldest)"""
|
||||||
|
if not self.request_times:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 先收集空deque的IP(优先删除)
|
||||||
|
empty_ips = []
|
||||||
|
# 找到最久未访问的IP(最旧时间戳)
|
||||||
|
oldest_ip = None
|
||||||
|
oldest_time = float("inf")
|
||||||
|
|
||||||
|
# 全量遍历找真正的oldest(超限时性能可接受)
|
||||||
|
for ip, times in self.request_times.items():
|
||||||
|
if not times:
|
||||||
|
# 空deque,记录待删除
|
||||||
|
empty_ips.append(ip)
|
||||||
|
else:
|
||||||
|
# 找到最旧的时间戳
|
||||||
|
if times[0] < oldest_time:
|
||||||
|
oldest_time = times[0]
|
||||||
|
oldest_ip = ip
|
||||||
|
|
||||||
|
# 先删除空deque的IP
|
||||||
|
for ip in empty_ips:
|
||||||
|
del self.request_times[ip]
|
||||||
|
|
||||||
|
# 如果没有空deque可删除,且仍需要清理,删除最旧的一个IP
|
||||||
|
if not empty_ips and oldest_ip:
|
||||||
|
del self.request_times[oldest_ip]
|
||||||
|
|
||||||
|
def _is_trusted_proxy(self, ip: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查IP是否在信任的代理列表中
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip: IP地址字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果是信任的代理则返回 True
|
||||||
|
"""
|
||||||
|
if not TRUSTED_PROXIES or ip == "unknown":
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查代理列表中的每个条目
|
||||||
|
for trusted_entry in TRUSTED_PROXIES:
|
||||||
|
# 通配符模式(字符串,正则表达式)
|
||||||
|
if isinstance(trusted_entry, str):
|
||||||
|
try:
|
||||||
|
if re.match(trusted_entry, ip):
|
||||||
|
return True
|
||||||
|
except re.error:
|
||||||
|
continue
|
||||||
|
# CIDR格式(网络对象)
|
||||||
|
elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||||
|
try:
|
||||||
|
client_ip_obj = ipaddress.ip_address(ip)
|
||||||
|
if client_ip_obj in trusted_entry:
|
||||||
|
return True
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
continue
|
||||||
|
# 精确IP(地址对象)
|
||||||
|
elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||||
|
try:
|
||||||
|
client_ip_obj = ipaddress.ip_address(ip)
|
||||||
|
if client_ip_obj == trusted_entry:
|
||||||
|
return True
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_client_ip(self, request: Request) -> str:
|
||||||
|
"""
|
||||||
|
获取客户端真实IP地址(带基本验证和代理信任检查)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
客户端IP地址
|
||||||
|
"""
|
||||||
|
# 获取直接连接的客户端IP(用于验证代理)
|
||||||
|
direct_client_ip = None
|
||||||
|
if request.client:
|
||||||
|
direct_client_ip = request.client.host
|
||||||
|
|
||||||
|
# 检查是否信任X-Forwarded-For头
|
||||||
|
# TRUST_XFF 只表示"启用代理解析能力",但仍要求直连 IP 在 TRUSTED_PROXIES 中
|
||||||
|
use_xff = False
|
||||||
|
if TRUST_XFF and TRUSTED_PROXIES and direct_client_ip:
|
||||||
|
# 只有在启用 TRUST_XFF 且直连 IP 在信任列表中时,才信任 XFF
|
||||||
|
use_xff = self._is_trusted_proxy(direct_client_ip)
|
||||||
|
|
||||||
|
# 如果信任代理,优先从 X-Forwarded-For 获取
|
||||||
|
if use_xff:
|
||||||
|
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||||
|
if forwarded_for:
|
||||||
|
# X-Forwarded-For 可能包含多个IP,取第一个
|
||||||
|
ip = forwarded_for.split(",")[0].strip()
|
||||||
|
# 基本验证IP格式
|
||||||
|
if self._validate_ip(ip):
|
||||||
|
return ip
|
||||||
|
|
||||||
|
# 从 X-Real-IP 获取(如果信任代理)
|
||||||
|
if use_xff:
|
||||||
|
real_ip = request.headers.get("X-Real-IP")
|
||||||
|
if real_ip:
|
||||||
|
ip = real_ip.strip()
|
||||||
|
if self._validate_ip(ip):
|
||||||
|
return ip
|
||||||
|
|
||||||
|
# 使用直接连接的客户端IP
|
||||||
|
if direct_client_ip and self._validate_ip(direct_client_ip):
|
||||||
|
return direct_client_ip
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
def _validate_ip(self, ip: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证IP地址格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip: IP地址字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果格式有效则返回 True
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ipaddress.ip_address(ip)
|
||||||
|
return True
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _is_ip_allowed(self, ip: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查IP是否在白名单中(支持精确IP、CIDR格式和通配符)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip: 客户端IP地址
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果IP在白名单中则返回 True
|
||||||
|
"""
|
||||||
|
if not ALLOWED_IPS or ip == "unknown":
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查白名单中的每个条目
|
||||||
|
for allowed_entry in ALLOWED_IPS:
|
||||||
|
# 通配符模式(字符串,正则表达式)
|
||||||
|
if isinstance(allowed_entry, str):
|
||||||
|
try:
|
||||||
|
if re.match(allowed_entry, ip):
|
||||||
|
return True
|
||||||
|
except re.error:
|
||||||
|
# 正则表达式错误,跳过
|
||||||
|
continue
|
||||||
|
# CIDR格式(网络对象)
|
||||||
|
elif isinstance(allowed_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||||
|
try:
|
||||||
|
client_ip_obj = ipaddress.ip_address(ip)
|
||||||
|
if client_ip_obj in allowed_entry:
|
||||||
|
return True
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
# IP格式无效,跳过
|
||||||
|
continue
|
||||||
|
# 精确IP(地址对象)
|
||||||
|
elif isinstance(allowed_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||||
|
try:
|
||||||
|
client_ip_obj = ipaddress.ip_address(ip)
|
||||||
|
if client_ip_obj == allowed_entry:
|
||||||
|
return True
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
# IP格式无效,跳过
|
||||||
|
continue
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
"""
|
||||||
|
处理请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 请求对象
|
||||||
|
call_next: 下一个中间件或路由处理函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
响应对象
|
||||||
|
"""
|
||||||
|
# 如果未启用,直接通过
|
||||||
|
if not self.enabled:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# 允许访问 robots.txt(由专门的路由处理)
|
||||||
|
if request.url.path == "/robots.txt":
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# 允许访问静态资源(CSS、JS、图片等)
|
||||||
|
# 注意:.json 已移除,避免 API 路径绕过防护
|
||||||
|
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/)
|
||||||
|
static_extensions = {
|
||||||
|
".css",
|
||||||
|
".js",
|
||||||
|
".png",
|
||||||
|
".jpg",
|
||||||
|
".jpeg",
|
||||||
|
".gif",
|
||||||
|
".svg",
|
||||||
|
".ico",
|
||||||
|
".woff",
|
||||||
|
".woff2",
|
||||||
|
".ttf",
|
||||||
|
".eot",
|
||||||
|
}
|
||||||
|
static_prefixes = {"/static/", "/assets/", "/dist/"}
|
||||||
|
|
||||||
|
# 检查是否是静态资源路径(特定前缀下的静态文件)
|
||||||
|
path = request.url.path
|
||||||
|
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
|
||||||
|
path.endswith(ext) for ext in static_extensions
|
||||||
|
)
|
||||||
|
|
||||||
|
# 也允许根路径下的静态文件(如 /favicon.ico)
|
||||||
|
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
|
||||||
|
|
||||||
|
if is_static_path or is_root_static:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# 获取客户端IP(只获取一次,避免重复调用)
|
||||||
|
client_ip = self._get_client_ip(request)
|
||||||
|
|
||||||
|
# 检查IP白名单(优先检查,白名单IP直接通过)
|
||||||
|
if self._is_ip_allowed(client_ip):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# 获取 User-Agent
|
||||||
|
user_agent = request.headers.get("User-Agent")
|
||||||
|
|
||||||
|
# 检测资产测绘工具(优先检测,因为更危险)
|
||||||
|
if self.check_asset_scanner:
|
||||||
|
is_scanner, scanner_name = self._detect_asset_scanner(request)
|
||||||
|
if is_scanner:
|
||||||
|
logger.warning(
|
||||||
|
f"🚫 检测到资产测绘工具请求 - IP: {client_ip}, 工具: {scanner_name}, "
|
||||||
|
f"User-Agent: {user_agent}, Path: {request.url.path}"
|
||||||
|
)
|
||||||
|
# 根据配置决定是否阻止
|
||||||
|
if self.block_on_detect:
|
||||||
|
return PlainTextResponse(
|
||||||
|
"Access Denied: Asset scanning tools are not allowed",
|
||||||
|
status_code=403,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检测爬虫 User-Agent
|
||||||
|
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
|
||||||
|
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||||
|
# 根据配置决定是否阻止
|
||||||
|
if self.block_on_detect:
|
||||||
|
return PlainTextResponse(
|
||||||
|
"Access Denied: Crawlers are not allowed",
|
||||||
|
status_code=403,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查请求频率限制
|
||||||
|
if self.check_rate_limit and self._check_rate_limit(client_ip):
|
||||||
|
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||||
|
return PlainTextResponse(
|
||||||
|
"Too Many Requests: Rate limit exceeded",
|
||||||
|
status_code=429,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 正常请求,继续处理
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
def create_robots_txt_response() -> PlainTextResponse:
|
||||||
|
"""
|
||||||
|
创建 robots.txt 响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
robots.txt 响应对象
|
||||||
|
"""
|
||||||
|
robots_content = """User-agent: *
|
||||||
|
Disallow: /
|
||||||
|
|
||||||
|
# 禁止所有爬虫访问
|
||||||
|
"""
|
||||||
|
return PlainTextResponse(
|
||||||
|
content=robots_content,
|
||||||
|
media_type="text/plain",
|
||||||
|
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
|
||||||
|
)
|
||||||
@@ -3,6 +3,7 @@ WebUI 认证模块
|
|||||||
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -15,6 +16,28 @@ COOKIE_NAME = "maibot_session"
|
|||||||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||||
|
|
||||||
|
|
||||||
|
def _is_secure_environment() -> bool:
|
||||||
|
"""
|
||||||
|
检测是否应该启用安全 Cookie(HTTPS)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果应该使用 secure cookie 则返回 True
|
||||||
|
"""
|
||||||
|
# 检查环境变量
|
||||||
|
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("true", "1", "yes"):
|
||||||
|
return True
|
||||||
|
if os.environ.get("WEBUI_SECURE_COOKIE", "").lower() in ("false", "0", "no"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否是生产环境
|
||||||
|
env = os.environ.get("WEBUI_MODE", "").lower()
|
||||||
|
if env in ("production", "prod"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 默认:开发环境不启用(因为通常是 HTTP)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_current_token(
|
def get_current_token(
|
||||||
request: Request,
|
request: Request,
|
||||||
maibot_session: Optional[str] = Cookie(None),
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
@@ -62,16 +85,19 @@ def set_auth_cookie(response: Response, token: str) -> None:
|
|||||||
response: FastAPI Response 对象
|
response: FastAPI Response 对象
|
||||||
token: 要设置的 token
|
token: 要设置的 token
|
||||||
"""
|
"""
|
||||||
|
# 根据环境决定安全设置
|
||||||
|
is_secure = _is_secure_environment()
|
||||||
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key=COOKIE_NAME,
|
key=COOKIE_NAME,
|
||||||
value=token,
|
value=token,
|
||||||
max_age=COOKIE_MAX_AGE,
|
max_age=COOKIE_MAX_AGE,
|
||||||
httponly=True, # 防止 JS 读取
|
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
|
||||||
samesite="lax", # 允许同站导航时发送 Cookie(兼容开发环境代理)
|
samesite="strict" if is_secure else "lax", # 生产环境使用 strict 防止 CSRF
|
||||||
secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True
|
secure=is_secure, # 生产环境强制 HTTPS
|
||||||
path="/", # 确保 Cookie 在所有路径下可用
|
path="/", # 确保 Cookie 在所有路径下可用
|
||||||
)
|
)
|
||||||
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
|
logger.debug(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure})")
|
||||||
|
|
||||||
|
|
||||||
def clear_auth_cookie(response: Response) -> None:
|
def clear_auth_cookie(response: Response) -> None:
|
||||||
@@ -81,10 +107,14 @@ def clear_auth_cookie(response: Response) -> None:
|
|||||||
Args:
|
Args:
|
||||||
response: FastAPI Response 对象
|
response: FastAPI Response 对象
|
||||||
"""
|
"""
|
||||||
|
# 保持与 set_auth_cookie 相同的安全设置
|
||||||
|
is_secure = _is_secure_environment()
|
||||||
|
|
||||||
response.delete_cookie(
|
response.delete_cookie(
|
||||||
key=COOKIE_NAME,
|
key=COOKIE_NAME,
|
||||||
httponly=True,
|
httponly=True,
|
||||||
samesite="lax",
|
samesite="strict" if is_secure else "lax",
|
||||||
|
secure=is_secure,
|
||||||
path="/",
|
path="/",
|
||||||
)
|
)
|
||||||
logger.debug("已清除认证 Cookie")
|
logger.debug("已清除认证 Cookie")
|
||||||
|
|||||||
@@ -8,18 +8,30 @@
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Any, Optional, List
|
from typing import Dict, Any, Optional, List
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Messages, PersonInfo
|
from src.common.database.database_model import Messages, PersonInfo
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.bot import chat_bot
|
from src.chat.message_receive.bot import chat_bot
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
from src.webui.token_manager import get_token_manager
|
||||||
|
from src.webui.ws_auth import verify_ws_token
|
||||||
|
|
||||||
logger = get_logger("webui.chat")
|
logger = get_logger("webui.chat")
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
# WebUI 聊天的虚拟群组 ID
|
# WebUI 聊天的虚拟群组 ID
|
||||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||||
WEBUI_CHAT_PLATFORM = "webui"
|
WEBUI_CHAT_PLATFORM = "webui"
|
||||||
@@ -256,6 +268,7 @@ async def get_chat_history(
|
|||||||
limit: int = Query(default=50, ge=1, le=200),
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""获取聊天历史记录
|
"""获取聊天历史记录
|
||||||
|
|
||||||
@@ -272,7 +285,7 @@ async def get_chat_history(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/platforms")
|
@router.get("/platforms")
|
||||||
async def get_available_platforms():
|
async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
||||||
"""获取可用平台列表
|
"""获取可用平台列表
|
||||||
|
|
||||||
从 PersonInfo 表中获取所有已知的平台
|
从 PersonInfo 表中获取所有已知的平台
|
||||||
@@ -303,6 +316,7 @@ async def get_persons_by_platform(
|
|||||||
platform: str = Query(..., description="平台名称"),
|
platform: str = Query(..., description="平台名称"),
|
||||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||||
limit: int = Query(default=50, ge=1, le=200),
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""获取指定平台的用户列表
|
"""获取指定平台的用户列表
|
||||||
|
|
||||||
@@ -350,7 +364,7 @@ async def get_persons_by_platform(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/history")
|
@router.delete("/history")
|
||||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
|
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
|
||||||
"""清空聊天历史记录
|
"""清空聊天历史记录
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -372,6 +386,7 @@ async def websocket_chat(
|
|||||||
person_id: Optional[str] = Query(default=None),
|
person_id: Optional[str] = Query(default=None),
|
||||||
group_name: Optional[str] = Query(default=None),
|
group_name: Optional[str] = Query(default=None),
|
||||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||||
|
token: Optional[str] = Query(default=None), # 认证 token
|
||||||
):
|
):
|
||||||
"""WebSocket 聊天端点
|
"""WebSocket 聊天端点
|
||||||
|
|
||||||
@@ -382,9 +397,45 @@ async def websocket_chat(
|
|||||||
person_id: 虚拟身份模式的用户 person_id(可选)
|
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||||
group_name: 虚拟身份模式的群名(可选)
|
group_name: 虚拟身份模式的群名(可选)
|
||||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||||
|
token: 认证 token(可选,也可从 Cookie 获取)
|
||||||
|
|
||||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||||
|
|
||||||
|
支持三种认证方式(按优先级):
|
||||||
|
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||||
|
2. Cookie 中的 maibot_session
|
||||||
|
3. 直接使用 session token(兼容)
|
||||||
|
|
||||||
|
示例:ws://host/api/chat/ws?token=xxx
|
||||||
"""
|
"""
|
||||||
|
is_authenticated = False
|
||||||
|
|
||||||
|
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||||
|
if token and verify_ws_token(token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||||
|
|
||||||
|
# 方式 2: 尝试从 Cookie 获取 session token
|
||||||
|
if not is_authenticated:
|
||||||
|
cookie_token = websocket.cookies.get("maibot_session")
|
||||||
|
if cookie_token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(cookie_token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||||
|
|
||||||
|
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||||
|
if not is_authenticated and token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||||
|
|
||||||
|
if not is_authenticated:
|
||||||
|
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||||
|
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||||
|
return
|
||||||
|
|
||||||
# 生成会话 ID(每次连接都是新的)
|
# 生成会话 ID(每次连接都是新的)
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
@@ -414,7 +465,9 @@ async def websocket_chat(
|
|||||||
group_id=virtual_group_id,
|
group_id=virtual_group_id,
|
||||||
group_name=group_name or "WebUI虚拟群聊",
|
group_name=group_name or "WebUI虚拟群聊",
|
||||||
)
|
)
|
||||||
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
|
logger.info(
|
||||||
|
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||||
|
|
||||||
@@ -710,7 +763,7 @@ async def websocket_chat(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/info")
|
@router.get("/info")
|
||||||
async def get_chat_info():
|
async def get_chat_info(_auth: bool = Depends(require_auth)):
|
||||||
"""获取聊天室信息"""
|
"""获取聊天室信息"""
|
||||||
return {
|
return {
|
||||||
"bot_name": global_config.bot.nickname,
|
"bot_name": global_config.bot.nickname,
|
||||||
|
|||||||
@@ -4,10 +4,11 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import tomlkit
|
import tomlkit
|
||||||
from fastapi import APIRouter, HTTPException, Body
|
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
|
||||||
from typing import Any, Annotated
|
from typing import Any, Annotated, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||||
from src.config.official_configs import (
|
from src.config.official_configs import (
|
||||||
@@ -49,11 +50,19 @@ PathBody = Annotated[dict[str, str], Body()]
|
|||||||
router = APIRouter(prefix="/config", tags=["config"])
|
router = APIRouter(prefix="/config", tags=["config"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
# ===== 架构获取接口 =====
|
# ===== 架构获取接口 =====
|
||||||
|
|
||||||
|
|
||||||
@router.get("/schema/bot")
|
@router.get("/schema/bot")
|
||||||
async def get_bot_config_schema():
|
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||||
"""获取麦麦主程序配置架构"""
|
"""获取麦麦主程序配置架构"""
|
||||||
try:
|
try:
|
||||||
# Config 类包含所有子配置
|
# Config 类包含所有子配置
|
||||||
@@ -65,7 +74,7 @@ async def get_bot_config_schema():
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/schema/model")
|
@router.get("/schema/model")
|
||||||
async def get_model_config_schema():
|
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||||
try:
|
try:
|
||||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
||||||
@@ -79,7 +88,7 @@ async def get_model_config_schema():
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/schema/section/{section_name}")
|
@router.get("/schema/section/{section_name}")
|
||||||
async def get_config_section_schema(section_name: str):
|
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取指定配置节的架构
|
获取指定配置节的架构
|
||||||
|
|
||||||
@@ -149,7 +158,7 @@ async def get_config_section_schema(section_name: str):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/bot")
|
@router.get("/bot")
|
||||||
async def get_bot_config():
|
async def get_bot_config(_auth: bool = Depends(require_auth)):
|
||||||
"""获取麦麦主程序配置"""
|
"""获取麦麦主程序配置"""
|
||||||
try:
|
try:
|
||||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
@@ -168,7 +177,7 @@ async def get_bot_config():
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/model")
|
@router.get("/model")
|
||||||
async def get_model_config():
|
async def get_model_config(_auth: bool = Depends(require_auth)):
|
||||||
"""获取模型配置(包含提供商和模型任务配置)"""
|
"""获取模型配置(包含提供商和模型任务配置)"""
|
||||||
try:
|
try:
|
||||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||||
@@ -190,7 +199,7 @@ async def get_model_config():
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/bot")
|
@router.post("/bot")
|
||||||
async def update_bot_config(config_data: ConfigBody):
|
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||||
"""更新麦麦主程序配置"""
|
"""更新麦麦主程序配置"""
|
||||||
try:
|
try:
|
||||||
# 验证配置数据
|
# 验证配置数据
|
||||||
@@ -213,7 +222,7 @@ async def update_bot_config(config_data: ConfigBody):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/model")
|
@router.post("/model")
|
||||||
async def update_model_config(config_data: ConfigBody):
|
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||||
"""更新模型配置"""
|
"""更新模型配置"""
|
||||||
try:
|
try:
|
||||||
# 验证配置数据
|
# 验证配置数据
|
||||||
@@ -239,7 +248,7 @@ async def update_model_config(config_data: ConfigBody):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/bot/section/{section_name}")
|
@router.post("/bot/section/{section_name}")
|
||||||
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||||
try:
|
try:
|
||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
@@ -288,7 +297,7 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/bot/raw")
|
@router.get("/bot/raw")
|
||||||
async def get_bot_config_raw():
|
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
|
||||||
"""获取麦麦主程序配置的原始 TOML 内容"""
|
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||||
try:
|
try:
|
||||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
@@ -307,7 +316,7 @@ async def get_bot_config_raw():
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/bot/raw")
|
@router.post("/bot/raw")
|
||||||
async def update_bot_config_raw(raw_content: RawContentBody):
|
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
|
||||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||||
try:
|
try:
|
||||||
# 验证 TOML 格式
|
# 验证 TOML 格式
|
||||||
@@ -337,7 +346,9 @@ async def update_bot_config_raw(raw_content: RawContentBody):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/model/section/{section_name}")
|
@router.post("/model/section/{section_name}")
|
||||||
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
async def update_model_config_section(
|
||||||
|
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
|
||||||
|
):
|
||||||
"""更新模型配置的指定节(保留注释和格式)"""
|
"""更新模型配置的指定节(保留注释和格式)"""
|
||||||
try:
|
try:
|
||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
@@ -368,6 +379,17 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
|||||||
try:
|
try:
|
||||||
APIAdapterConfig.from_dict(config_data)
|
APIAdapterConfig.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||||
|
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||||
|
if section_name == "api_providers" and "api_provider" in str(e):
|
||||||
|
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||||
|
models = config_data.get("models", [])
|
||||||
|
orphaned_models = [
|
||||||
|
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
||||||
|
]
|
||||||
|
if orphaned_models:
|
||||||
|
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
||||||
|
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置(格式化数组为多行,保留注释)
|
# 保存配置(格式化数组为多行,保留注释)
|
||||||
@@ -418,7 +440,7 @@ def _to_relative_path(path: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/adapter-config/path")
|
@router.get("/adapter-config/path")
|
||||||
async def get_adapter_config_path():
|
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
|
||||||
"""获取保存的适配器配置文件路径"""
|
"""获取保存的适配器配置文件路径"""
|
||||||
try:
|
try:
|
||||||
# 从 data/webui.json 读取路径偏好
|
# 从 data/webui.json 读取路径偏好
|
||||||
@@ -457,7 +479,7 @@ async def get_adapter_config_path():
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/adapter-config/path")
|
@router.post("/adapter-config/path")
|
||||||
async def save_adapter_config_path(data: PathBody):
|
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||||
"""保存适配器配置文件路径偏好"""
|
"""保存适配器配置文件路径偏好"""
|
||||||
try:
|
try:
|
||||||
path = data.get("path")
|
path = data.get("path")
|
||||||
@@ -500,7 +522,7 @@ async def save_adapter_config_path(data: PathBody):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/adapter-config")
|
@router.get("/adapter-config")
|
||||||
async def get_adapter_config(path: str):
|
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
|
||||||
"""从指定路径读取适配器配置文件"""
|
"""从指定路径读取适配器配置文件"""
|
||||||
try:
|
try:
|
||||||
if not path:
|
if not path:
|
||||||
@@ -532,7 +554,7 @@ async def get_adapter_config(path: str):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/adapter-config")
|
@router.post("/adapter-config")
|
||||||
async def save_adapter_config(data: PathBody):
|
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||||
"""保存适配器配置到指定路径"""
|
"""保存适配器配置到指定路径"""
|
||||||
try:
|
try:
|
||||||
path = data.get("path")
|
path = data.get("path")
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
""" 表情包管理 API 路由"""
|
"""表情包管理 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
||||||
from fastapi.responses import FileResponse, JSONResponse
|
from fastapi.responses import FileResponse, JSONResponse
|
||||||
@@ -100,23 +100,23 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
|||||||
try:
|
try:
|
||||||
with Image.open(source_path) as img:
|
with Image.open(source_path) as img:
|
||||||
# GIF 处理:提取第一帧
|
# GIF 处理:提取第一帧
|
||||||
if hasattr(img, 'n_frames') and img.n_frames > 1:
|
if hasattr(img, "n_frames") and img.n_frames > 1:
|
||||||
img.seek(0) # 确保在第一帧
|
img.seek(0) # 确保在第一帧
|
||||||
|
|
||||||
# 转换为 RGB/RGBA(WebP 支持透明度)
|
# 转换为 RGB/RGBA(WebP 支持透明度)
|
||||||
if img.mode in ('P', 'PA'):
|
if img.mode in ("P", "PA"):
|
||||||
# 调色板模式转换为 RGBA 以保留透明度
|
# 调色板模式转换为 RGBA 以保留透明度
|
||||||
img = img.convert('RGBA')
|
img = img.convert("RGBA")
|
||||||
elif img.mode == 'LA':
|
elif img.mode == "LA":
|
||||||
img = img.convert('RGBA')
|
img = img.convert("RGBA")
|
||||||
elif img.mode not in ('RGB', 'RGBA'):
|
elif img.mode not in ("RGB", "RGBA"):
|
||||||
img = img.convert('RGB')
|
img = img.convert("RGB")
|
||||||
|
|
||||||
# 创建缩略图(保持宽高比)
|
# 创建缩略图(保持宽高比)
|
||||||
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
# 保存为 WebP 格式
|
# 保存为 WebP 格式
|
||||||
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6)
|
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
|
||||||
|
|
||||||
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
|
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
|
||||||
|
|
||||||
@@ -163,6 +163,7 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
|||||||
|
|
||||||
return cleaned, kept
|
return cleaned, kept
|
||||||
|
|
||||||
|
|
||||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||||
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||||
@@ -365,7 +366,9 @@ async def get_emoji_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||||
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def get_emoji_detail(
|
||||||
|
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取表情包详细信息
|
获取表情包详细信息
|
||||||
|
|
||||||
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def update_emoji(
|
||||||
|
emoji_id: int,
|
||||||
|
request: EmojiUpdateRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
增量更新表情包(只更新提供的字段)
|
增量更新表情包(只更新提供的字段)
|
||||||
|
|
||||||
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||||
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def delete_emoji(
|
||||||
|
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
删除表情包
|
删除表情包
|
||||||
|
|
||||||
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||||
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def register_emoji(
|
||||||
|
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
注册表情包(快捷操作)
|
注册表情包(快捷操作)
|
||||||
|
|
||||||
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||||
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def ban_emoji(
|
||||||
|
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
禁用表情包(快捷操作)
|
禁用表情包(快捷操作)
|
||||||
|
|
||||||
@@ -680,9 +694,7 @@ async def get_emoji_thumbnail(
|
|||||||
}
|
}
|
||||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
path=emoji.full_path,
|
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||||
media_type=media_type,
|
|
||||||
filename=f"{emoji.emoji_hash}.{emoji.format}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 尝试获取或生成缩略图
|
# 尝试获取或生成缩略图
|
||||||
@@ -692,9 +704,7 @@ async def get_emoji_thumbnail(
|
|||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
# 缓存命中,直接返回
|
# 缓存命中,直接返回
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
path=str(cache_path),
|
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
|
||||||
media_type="image/webp",
|
|
||||||
filename=f"{emoji.emoji_hash}_thumb.webp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 缓存未命中,触发后台生成并返回 202
|
# 缓存未命中,触发后台生成并返回 202
|
||||||
@@ -703,11 +713,7 @@ async def get_emoji_thumbnail(
|
|||||||
# 标记为正在生成
|
# 标记为正在生成
|
||||||
_generating_thumbnails.add(emoji.emoji_hash)
|
_generating_thumbnails.add(emoji.emoji_hash)
|
||||||
# 提交到线程池后台生成
|
# 提交到线程池后台生成
|
||||||
_thumbnail_executor.submit(
|
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||||
_background_generate_thumbnail,
|
|
||||||
emoji.full_path,
|
|
||||||
emoji.emoji_hash
|
|
||||||
)
|
|
||||||
|
|
||||||
# 返回 202 Accepted,告诉前端缩略图正在生成中
|
# 返回 202 Accepted,告诉前端缩略图正在生成中
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
|
|||||||
},
|
},
|
||||||
headers={
|
headers={
|
||||||
"Retry-After": "1", # 建议 1 秒后重试
|
"Retry-After": "1", # 建议 1 秒后重试
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||||
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def batch_delete_emojis(
|
||||||
|
request: BatchDeleteRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
批量删除表情包
|
批量删除表情包
|
||||||
|
|
||||||
@@ -1234,12 +1244,7 @@ async def preheat_thumbnail_cache(
|
|||||||
try:
|
try:
|
||||||
# 使用线程池异步生成缩略图,避免阻塞事件循环
|
# 使用线程池异步生成缩略图,避免阻塞事件循环
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||||
_thumbnail_executor,
|
|
||||||
_generate_thumbnail,
|
|
||||||
emoji.full_path,
|
|
||||||
emoji.emoji_hash
|
|
||||||
)
|
|
||||||
generated += 1
|
generated += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
|
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
|
||||||
|
|||||||
@@ -256,7 +256,9 @@ async def get_expression_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||||
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def get_expression_detail(
|
||||||
|
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取表达方式详细信息
|
获取表达方式详细信息
|
||||||
|
|
||||||
@@ -285,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=ExpressionCreateResponse)
|
@router.post("/", response_model=ExpressionCreateResponse)
|
||||||
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def create_expression(
|
||||||
|
request: ExpressionCreateRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
创建新的表达方式
|
创建新的表达方式
|
||||||
|
|
||||||
@@ -326,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
|
|||||||
|
|
||||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||||
async def update_expression(
|
async def update_expression(
|
||||||
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
expression_id: int,
|
||||||
|
request: ExpressionUpdateRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
增量更新表达方式(只更新提供的字段)
|
增量更新表达方式(只更新提供的字段)
|
||||||
@@ -376,7 +385,9 @@ async def update_expression(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||||
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def delete_expression(
|
||||||
|
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
删除表达方式
|
删除表达方式
|
||||||
|
|
||||||
@@ -419,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||||
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def batch_delete_expressions(
|
||||||
|
request: BatchDeleteRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
批量删除表达方式
|
批量删除表达方式
|
||||||
|
|
||||||
@@ -460,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats/summary")
|
@router.get("/stats/summary")
|
||||||
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def get_expression_stats(
|
||||||
|
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取表达方式统计数据
|
获取表达方式统计数据
|
||||||
|
|
||||||
|
|||||||
@@ -277,11 +277,7 @@ async def get_chat_list():
|
|||||||
"""获取所有有黑话记录的聊天列表"""
|
"""获取所有有黑话记录的聊天列表"""
|
||||||
try:
|
try:
|
||||||
# 获取所有不同的 chat_id
|
# 获取所有不同的 chat_id
|
||||||
chat_ids = (
|
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
|
||||||
Jargon.select(Jargon.chat_id)
|
|
||||||
.distinct()
|
|
||||||
.where(Jargon.chat_id.is_null(False))
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
||||||
|
|
||||||
@@ -346,12 +342,7 @@ async def get_jargon_stats():
|
|||||||
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
||||||
|
|
||||||
# 关联的聊天数量
|
# 关联的聊天数量
|
||||||
chat_count = (
|
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
|
||||||
Jargon.select(Jargon.chat_id)
|
|
||||||
.distinct()
|
|
||||||
.where(Jargon.chat_id.is_null(False))
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 按聊天统计 TOP 5
|
# 按聊天统计 TOP 5
|
||||||
top_chats = (
|
top_chats = (
|
||||||
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
|
|||||||
"""创建黑话"""
|
"""创建黑话"""
|
||||||
try:
|
try:
|
||||||
# 检查是否已存在相同内容的黑话
|
# 检查是否已存在相同内容的黑话
|
||||||
existing = Jargon.get_or_none(
|
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
|
||||||
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
|
|
||||||
)
|
|
||||||
if existing:
|
if existing:
|
||||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||||
|
|
||||||
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
|
|||||||
if not ids:
|
if not ids:
|
||||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||||
|
|
||||||
updated_count = (
|
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
|
||||||
Jargon.update(is_jargon=is_jargon)
|
|
||||||
.where(Jargon.id.in_(ids))
|
|
||||||
.execute()
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,24 @@
|
|||||||
"""知识库图谱可视化 API 路由"""
|
"""知识库图谱可视化 API 路由"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query, Depends, Cookie, Header
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import logging
|
import logging
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeNode(BaseModel):
|
class KnowledgeNode(BaseModel):
|
||||||
"""知识节点"""
|
"""知识节点"""
|
||||||
|
|
||||||
@@ -113,6 +122,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
|||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""获取知识图谱(限制节点数量)
|
"""获取知识图谱(限制节点数量)
|
||||||
|
|
||||||
@@ -199,7 +209,7 @@ async def get_knowledge_graph(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats", response_model=KnowledgeStats)
|
@router.get("/stats", response_model=KnowledgeStats)
|
||||||
async def get_knowledge_stats():
|
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
|
||||||
"""获取知识库统计信息
|
"""获取知识库统计信息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -248,7 +258,7 @@ async def get_knowledge_stats():
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/search", response_model=List[KnowledgeNode])
|
@router.get("/search", response_model=List[KnowledgeNode])
|
||||||
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
|
||||||
"""搜索知识节点
|
"""搜索知识节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
"""WebSocket 日志推送模块"""
|
"""WebSocket 日志推送模块"""
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||||
from typing import Set
|
from typing import Set, Optional
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.token_manager import get_token_manager
|
||||||
|
from src.webui.ws_auth import verify_ws_token
|
||||||
|
|
||||||
logger = get_logger("webui.logs_ws")
|
logger = get_logger("webui.logs_ws")
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -73,14 +75,48 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
|||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws/logs")
|
@router.websocket("/ws/logs")
|
||||||
async def websocket_logs(websocket: WebSocket):
|
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||||
"""WebSocket 日志推送端点
|
"""WebSocket 日志推送端点
|
||||||
|
|
||||||
客户端连接后会持续接收服务器端的日志消息
|
客户端连接后会持续接收服务器端的日志消息
|
||||||
|
支持三种认证方式(按优先级):
|
||||||
|
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||||
|
2. Cookie 中的 maibot_session
|
||||||
|
3. 直接使用 session token(兼容)
|
||||||
|
|
||||||
|
示例:ws://host/ws/logs?token=xxx
|
||||||
"""
|
"""
|
||||||
|
is_authenticated = False
|
||||||
|
|
||||||
|
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||||
|
if token and verify_ws_token(token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("WebSocket 使用临时 token 认证成功")
|
||||||
|
|
||||||
|
# 方式 2: 尝试从 Cookie 获取 session token
|
||||||
|
if not is_authenticated:
|
||||||
|
cookie_token = websocket.cookies.get("maibot_session")
|
||||||
|
if cookie_token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(cookie_token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||||
|
|
||||||
|
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||||
|
if not is_authenticated and token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("WebSocket 使用 session token 认证成功")
|
||||||
|
|
||||||
|
if not is_authenticated:
|
||||||
|
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||||
|
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||||
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
active_connections.add(websocket)
|
active_connections.add(websocket)
|
||||||
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||||
|
|
||||||
# 连接建立后,立即发送历史日志
|
# 连接建立后,立即发送历史日志
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -6,18 +6,27 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import tomlkit
|
import tomlkit
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import CONFIG_DIR
|
from src.config.config import CONFIG_DIR
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
logger = get_logger("webui")
|
logger = get_logger("webui")
|
||||||
|
|
||||||
router = APIRouter(prefix="/models", tags=["models"])
|
router = APIRouter(prefix="/models", tags=["models"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
# 模型获取器配置
|
# 模型获取器配置
|
||||||
MODEL_FETCHER_CONFIG = {
|
MODEL_FETCHER_CONFIG = {
|
||||||
# OpenAI 兼容格式的提供商
|
# OpenAI 兼容格式的提供商
|
||||||
@@ -184,6 +193,7 @@ async def get_provider_models(
|
|||||||
provider_name: str = Query(..., description="提供商名称"),
|
provider_name: str = Query(..., description="提供商名称"),
|
||||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定提供商的可用模型列表
|
获取指定提供商的可用模型列表
|
||||||
@@ -228,6 +238,7 @@ async def get_models_by_url(
|
|||||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||||
@@ -251,6 +262,7 @@ async def get_models_by_url(
|
|||||||
async def test_provider_connection(
|
async def test_provider_connection(
|
||||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
测试提供商连接状态
|
测试提供商连接状态
|
||||||
@@ -337,6 +349,7 @@ async def test_provider_connection(
|
|||||||
@router.post("/test-connection-by-name")
|
@router.post("/test-connection-by-name")
|
||||||
async def test_provider_connection_by_name(
|
async def test_provider_connection_by_name(
|
||||||
provider_name: str = Query(..., description="提供商名称"),
|
provider_name: str = Query(..., description="提供商名称"),
|
||||||
|
_auth: bool = Depends(require_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
通过提供商名称测试连接(从配置文件读取信息)
|
通过提供商名称测试连接(从配置文件读取信息)
|
||||||
|
|||||||
@@ -200,7 +200,9 @@ async def get_person_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||||
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def get_person_detail(
|
||||||
|
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取人物详细信息
|
获取人物详细信息
|
||||||
|
|
||||||
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||||
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def update_person(
|
||||||
|
person_id: str,
|
||||||
|
request: PersonUpdateRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
增量更新人物信息(只更新提供的字段)
|
增量更新人物信息(只更新提供的字段)
|
||||||
|
|
||||||
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||||
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def delete_person(
|
||||||
|
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
删除人物信息
|
删除人物信息
|
||||||
|
|
||||||
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||||
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def batch_delete_persons(
|
||||||
|
request: BatchDeleteRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
批量删除人物信息
|
批量删除人物信息
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
"""WebSocket 插件加载进度推送模块"""
|
"""WebSocket 插件加载进度推送模块"""
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||||
from typing import Set, Dict, Any
|
from typing import Set, Dict, Any, Optional
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.token_manager import get_token_manager
|
||||||
|
from src.webui.ws_auth import verify_ws_token
|
||||||
|
|
||||||
logger = get_logger("webui.plugin_progress")
|
logger = get_logger("webui.plugin_progress")
|
||||||
|
|
||||||
@@ -89,14 +91,48 @@ async def update_progress(
|
|||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws/plugin-progress")
|
@router.websocket("/ws/plugin-progress")
|
||||||
async def websocket_plugin_progress(websocket: WebSocket):
|
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||||
"""WebSocket 插件加载进度推送端点
|
"""WebSocket 插件加载进度推送端点
|
||||||
|
|
||||||
客户端连接后会立即收到当前进度状态
|
客户端连接后会立即收到当前进度状态
|
||||||
|
支持三种认证方式(按优先级):
|
||||||
|
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||||
|
2. Cookie 中的 maibot_session
|
||||||
|
3. 直接使用 session token(兼容)
|
||||||
|
|
||||||
|
示例:ws://host/ws/plugin-progress?token=xxx
|
||||||
"""
|
"""
|
||||||
|
is_authenticated = False
|
||||||
|
|
||||||
|
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||||
|
if token and verify_ws_token(token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
|
||||||
|
|
||||||
|
# 方式 2: 尝试从 Cookie 获取 session token
|
||||||
|
if not is_authenticated:
|
||||||
|
cookie_token = websocket.cookies.get("maibot_session")
|
||||||
|
if cookie_token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(cookie_token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||||
|
|
||||||
|
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||||
|
if not is_authenticated and token:
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if token_manager.verify_token(token):
|
||||||
|
is_authenticated = True
|
||||||
|
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
|
||||||
|
|
||||||
|
if not is_authenticated:
|
||||||
|
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
|
||||||
|
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||||
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
active_connections.add(websocket)
|
active_connections.add(websocket)
|
||||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 发送当前进度状态
|
# 发送当前进度状态
|
||||||
|
|||||||
@@ -34,6 +34,85 @@ def get_token_from_cookie_or_header(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_safe_path(user_path: str, base_path: Path) -> Path:
|
||||||
|
"""
|
||||||
|
验证用户提供的路径是否安全,防止路径遍历攻击
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_path: 用户输入的路径(相对路径)
|
||||||
|
base_path: 允许的基础目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
安全的绝对路径
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 如果检测到路径遍历攻击
|
||||||
|
"""
|
||||||
|
# 规范化基础路径
|
||||||
|
base_resolved = base_path.resolve()
|
||||||
|
|
||||||
|
# 检查用户路径是否包含可疑字符
|
||||||
|
# 禁止: .., 绝对路径开头, 空字节等
|
||||||
|
if any(pattern in user_path for pattern in ["..", "\x00"]):
|
||||||
|
logger.warning(f"检测到可疑路径: {user_path}")
|
||||||
|
raise HTTPException(status_code=400, detail="路径包含非法字符")
|
||||||
|
|
||||||
|
# 检查是否为绝对路径(Windows 和 Unix)
|
||||||
|
if user_path.startswith("/") or user_path.startswith("\\") or (len(user_path) > 1 and user_path[1] == ":"):
|
||||||
|
logger.warning(f"检测到绝对路径: {user_path}")
|
||||||
|
raise HTTPException(status_code=400, detail="不允许使用绝对路径")
|
||||||
|
|
||||||
|
# 构建目标路径并解析
|
||||||
|
target_path = (base_path / user_path).resolve()
|
||||||
|
|
||||||
|
# 验证解析后的路径仍在基础目录内
|
||||||
|
try:
|
||||||
|
target_path.relative_to(base_resolved)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"路径遍历攻击检测: {user_path} -> {target_path}")
|
||||||
|
raise HTTPException(status_code=400, detail="路径超出允许范围") from e
|
||||||
|
|
||||||
|
return target_path
|
||||||
|
|
||||||
|
|
||||||
|
def validate_plugin_id(plugin_id: str) -> str:
|
||||||
|
"""
|
||||||
|
验证插件 ID 格式是否安全
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 插件 ID (支持 author.name 格式,允许中文)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
验证通过的插件 ID
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 如果插件 ID 格式不安全
|
||||||
|
"""
|
||||||
|
# 禁止空字符串
|
||||||
|
if not plugin_id or not plugin_id.strip():
|
||||||
|
logger.warning("非法插件 ID: 空字符串")
|
||||||
|
raise HTTPException(status_code=400, detail="插件 ID 不能为空")
|
||||||
|
|
||||||
|
# 禁止危险字符: 路径分隔符、空字节、控制字符等
|
||||||
|
dangerous_patterns = ["/", "\\", "\x00", "..", "\n", "\r", "\t"]
|
||||||
|
for pattern in dangerous_patterns:
|
||||||
|
if pattern in plugin_id:
|
||||||
|
logger.warning(f"非法插件 ID 格式: {plugin_id} (包含危险字符)")
|
||||||
|
raise HTTPException(status_code=400, detail="插件 ID 包含非法字符")
|
||||||
|
|
||||||
|
# 禁止以点开头或结尾(防止隐藏文件和路径问题)
|
||||||
|
if plugin_id.startswith(".") or plugin_id.endswith("."):
|
||||||
|
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||||
|
raise HTTPException(status_code=400, detail="插件 ID 不能以点开头或结尾")
|
||||||
|
|
||||||
|
# 禁止特殊名称
|
||||||
|
if plugin_id in (".", ".."):
|
||||||
|
logger.warning(f"非法插件 ID: {plugin_id}")
|
||||||
|
raise HTTPException(status_code=400, detail="插件 ID 不能为特殊目录名")
|
||||||
|
|
||||||
|
return plugin_id
|
||||||
|
|
||||||
|
|
||||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
解析版本号字符串
|
解析版本号字符串
|
||||||
@@ -125,6 +204,7 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
|
|||||||
"""
|
"""
|
||||||
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。
|
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _is_list_type(tp: Any) -> bool:
|
def _is_list_type(tp: Any) -> bool:
|
||||||
origin = get_origin(tp)
|
origin = get_origin(tp)
|
||||||
return tp is list or origin is list
|
return tp is list or origin is list
|
||||||
@@ -313,7 +393,9 @@ async def check_git_status() -> GitStatusResponse:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
||||||
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
|
async def get_available_mirrors(
|
||||||
|
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> AvailableMirrorsResponse:
|
||||||
"""
|
"""
|
||||||
获取所有可用的镜像源配置
|
获取所有可用的镜像源配置
|
||||||
"""
|
"""
|
||||||
@@ -343,7 +425,9 @@ async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), au
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
||||||
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
|
async def add_mirror(
|
||||||
|
request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> MirrorConfigResponse:
|
||||||
"""
|
"""
|
||||||
添加新的镜像源
|
添加新的镜像源
|
||||||
"""
|
"""
|
||||||
@@ -383,7 +467,10 @@ async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] =
|
|||||||
|
|
||||||
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
||||||
async def update_mirror(
|
async def update_mirror(
|
||||||
mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
mirror_id: str,
|
||||||
|
request: UpdateMirrorRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> MirrorConfigResponse:
|
) -> MirrorConfigResponse:
|
||||||
"""
|
"""
|
||||||
更新镜像源配置
|
更新镜像源配置
|
||||||
@@ -426,7 +513,9 @@ async def update_mirror(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/mirrors/{mirror_id}")
|
@router.delete("/mirrors/{mirror_id}")
|
||||||
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def delete_mirror(
|
||||||
|
mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
删除镜像源
|
删除镜像源
|
||||||
"""
|
"""
|
||||||
@@ -449,26 +538,24 @@ async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(N
|
|||||||
|
|
||||||
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
||||||
async def fetch_raw_file(
|
async def fetch_raw_file(
|
||||||
request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
request: FetchRawFileRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> FetchRawFileResponse:
|
) -> FetchRawFileResponse:
|
||||||
"""
|
"""
|
||||||
获取 GitHub 仓库的 Raw 文件内容
|
获取 GitHub 仓库的 Raw 文件内容
|
||||||
|
|
||||||
支持多镜像源自动切换和错误重试
|
支持多镜像源自动切换和错误重试
|
||||||
|
|
||||||
注意:此接口可公开访问,用于获取插件仓库等公开资源
|
需要认证才能访问,防止被滥用作为 SSRF 跳板
|
||||||
"""
|
"""
|
||||||
# Token 验证(可选,用于日志记录)
|
# Token 验证(强制)
|
||||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
token_manager = get_token_manager()
|
token_manager = get_token_manager()
|
||||||
is_authenticated = token and token_manager.verify_token(token)
|
if not token or not token_manager.verify_token(token):
|
||||||
|
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||||
|
|
||||||
# 对于公开仓库的访问,不强制要求认证
|
logger.info(f"收到获取 Raw 文件请求: {request.owner}/{request.repo}/{request.branch}/{request.file_path}")
|
||||||
# 只在日志中记录是否认证
|
|
||||||
logger.info(
|
|
||||||
f"收到获取 Raw 文件请求 (认证: {is_authenticated}): "
|
|
||||||
f"{request.owner}/{request.repo}/{request.branch}/{request.file_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 发送开始加载进度
|
# 发送开始加载进度
|
||||||
await update_progress(
|
await update_progress(
|
||||||
@@ -534,7 +621,9 @@ async def fetch_raw_file(
|
|||||||
|
|
||||||
@router.post("/clone", response_model=CloneRepositoryResponse)
|
@router.post("/clone", response_model=CloneRepositoryResponse)
|
||||||
async def clone_repository(
|
async def clone_repository(
|
||||||
request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
request: CloneRepositoryRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> CloneRepositoryResponse:
|
) -> CloneRepositoryResponse:
|
||||||
"""
|
"""
|
||||||
克隆 GitHub 仓库到本地
|
克隆 GitHub 仓库到本地
|
||||||
@@ -550,10 +639,10 @@ async def clone_repository(
|
|||||||
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
# 验证 target_path 的安全性,防止路径遍历攻击
|
||||||
# TODO: 确定实际的插件目录基路径
|
base_plugin_path = Path("./plugins").resolve()
|
||||||
base_plugin_path = Path("./plugins") # 临时路径
|
base_plugin_path.mkdir(exist_ok=True)
|
||||||
target_path = base_plugin_path / request.target_path
|
target_path = validate_safe_path(request.target_path, base_plugin_path)
|
||||||
|
|
||||||
service = get_git_mirror_service()
|
service = get_git_mirror_service()
|
||||||
result = await service.clone_repository(
|
result = await service.clone_repository(
|
||||||
@@ -574,7 +663,11 @@ async def clone_repository(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/install")
|
@router.post("/install")
|
||||||
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def install_plugin(
|
||||||
|
request: InstallPluginRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
安装插件
|
安装插件
|
||||||
|
|
||||||
@@ -589,13 +682,16 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
logger.info(f"收到安装插件请求: {request.plugin_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 验证插件 ID 格式安全性
|
||||||
|
plugin_id = validate_plugin_id(request.plugin_id)
|
||||||
|
|
||||||
# 推送进度:开始安装
|
# 推送进度:开始安装
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading",
|
||||||
progress=5,
|
progress=5,
|
||||||
message=f"开始安装插件: {request.plugin_id}",
|
message=f"开始安装插件: {plugin_id}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 解析仓库 URL
|
# 1. 解析仓库 URL
|
||||||
@@ -616,27 +712,28 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"解析仓库信息: {owner}/{repo}",
|
message=f"解析仓库信息: {owner}/{repo}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 确定插件安装路径
|
# 2. 确定插件安装路径
|
||||||
plugins_dir = Path("plugins")
|
plugins_dir = Path("plugins").resolve()
|
||||||
plugins_dir.mkdir(exist_ok=True)
|
plugins_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
|
# 将插件 ID 中的点替换为下划线作为文件夹名称(避免文件系统问题)
|
||||||
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
|
# 例如: SengokuCola.Mute-Plugin -> SengokuCola_Mute-Plugin
|
||||||
folder_name = request.plugin_id.replace(".", "_")
|
folder_name = plugin_id.replace(".", "_")
|
||||||
target_path = plugins_dir / folder_name
|
# 使用安全路径验证,防止路径遍历
|
||||||
|
target_path = validate_safe_path(folder_name, plugins_dir)
|
||||||
|
|
||||||
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
|
# 检查插件是否已安装(需要检查两种格式:新格式下划线和旧格式点)
|
||||||
old_format_path = plugins_dir / request.plugin_id
|
old_format_path = plugins_dir / plugin_id
|
||||||
if target_path.exists() or old_format_path.exists():
|
if target_path.exists() or old_format_path.exists():
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error",
|
stage="error",
|
||||||
progress=0,
|
progress=0,
|
||||||
message="插件已存在",
|
message="插件已存在",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="插件已安装,请先卸载",
|
error="插件已安装,请先卸载",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="插件已安装")
|
raise HTTPException(status_code=400, detail="插件已安装")
|
||||||
@@ -646,7 +743,7 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
progress=15,
|
progress=15,
|
||||||
message=f"准备克隆到: {target_path}",
|
message=f"准备克隆到: {target_path}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
||||||
@@ -675,14 +772,14 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="克隆仓库失败",
|
message="克隆仓库失败",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=error_msg,
|
error=error_msg,
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
# 4. 验证插件完整性
|
# 4. 验证插件完整性
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
|
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=plugin_id
|
||||||
)
|
)
|
||||||
|
|
||||||
manifest_path = target_path / "_manifest.json"
|
manifest_path = target_path / "_manifest.json"
|
||||||
@@ -697,14 +794,14 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="插件缺少 _manifest.json",
|
message="插件缺少 _manifest.json",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="无效的插件格式",
|
error="无效的插件格式",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
|
|
||||||
# 5. 读取并验证 manifest
|
# 5. 读取并验证 manifest
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
|
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=plugin_id
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -721,7 +818,7 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
|
|
||||||
# 将插件 ID 写入 manifest(用于后续准确识别)
|
# 将插件 ID 写入 manifest(用于后续准确识别)
|
||||||
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
||||||
manifest["id"] = request.plugin_id
|
manifest["id"] = plugin_id
|
||||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
with open(manifest_path, "w", encoding="utf-8") as f:
|
||||||
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
|
json_module.dump(manifest, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
@@ -736,7 +833,7 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="_manifest.json 无效",
|
message="_manifest.json 无效",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
@@ -747,13 +844,13 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "插件安装成功",
|
"message": "插件安装成功",
|
||||||
"plugin_id": request.plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"plugin_name": manifest["name"],
|
"plugin_name": manifest["name"],
|
||||||
"version": manifest["version"],
|
"version": manifest["version"],
|
||||||
"path": str(target_path),
|
"path": str(target_path),
|
||||||
@@ -769,7 +866,7 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="安装失败",
|
message="安装失败",
|
||||||
operation="install",
|
operation="install",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -778,7 +875,9 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
|
|
||||||
@router.post("/uninstall")
|
@router.post("/uninstall")
|
||||||
async def uninstall_plugin(
|
async def uninstall_plugin(
|
||||||
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
request: UninstallPluginRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
卸载插件
|
卸载插件
|
||||||
@@ -794,22 +893,26 @@ async def uninstall_plugin(
|
|||||||
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
logger.info(f"收到卸载插件请求: {request.plugin_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 验证插件 ID 格式安全性
|
||||||
|
plugin_id = validate_plugin_id(request.plugin_id)
|
||||||
|
|
||||||
# 推送进度:开始卸载
|
# 推送进度:开始卸载
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading",
|
||||||
progress=10,
|
progress=10,
|
||||||
message=f"开始卸载插件: {request.plugin_id}",
|
message=f"开始卸载插件: {plugin_id}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 检查插件是否存在(支持新旧两种格式)
|
# 1. 检查插件是否存在(支持新旧两种格式)
|
||||||
plugins_dir = Path("plugins")
|
plugins_dir = Path("plugins").resolve()
|
||||||
# 新格式:下划线
|
# 新格式:下划线
|
||||||
folder_name = request.plugin_id.replace(".", "_")
|
folder_name = plugin_id.replace(".", "_")
|
||||||
plugin_path = plugins_dir / folder_name
|
# 使用安全路径验证
|
||||||
|
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||||
# 旧格式:点
|
# 旧格式:点
|
||||||
old_format_path = plugins_dir / request.plugin_id
|
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||||
|
|
||||||
# 优先使用新格式,如果不存在则尝试旧格式
|
# 优先使用新格式,如果不存在则尝试旧格式
|
||||||
if not plugin_path.exists():
|
if not plugin_path.exists():
|
||||||
@@ -821,7 +924,7 @@ async def uninstall_plugin(
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="插件不存在",
|
message="插件不存在",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="插件未安装或已被删除",
|
error="插件未安装或已被删除",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
@@ -831,12 +934,12 @@ async def uninstall_plugin(
|
|||||||
progress=30,
|
progress=30,
|
||||||
message=f"正在删除插件文件: {plugin_path}",
|
message=f"正在删除插件文件: {plugin_path}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 读取插件信息(用于日志)
|
# 2. 读取插件信息(用于日志)
|
||||||
manifest_path = plugin_path / "_manifest.json"
|
manifest_path = plugin_path / "_manifest.json"
|
||||||
plugin_name = request.plugin_id
|
plugin_name = plugin_id
|
||||||
|
|
||||||
if manifest_path.exists():
|
if manifest_path.exists():
|
||||||
try:
|
try:
|
||||||
@@ -844,7 +947,7 @@ async def uninstall_plugin(
|
|||||||
|
|
||||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
manifest = json_module.load(f)
|
manifest = json_module.load(f)
|
||||||
plugin_name = manifest.get("name", request.plugin_id)
|
plugin_name = manifest.get("name", plugin_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # 如果读取失败,使用插件 ID 作为名称
|
pass # 如果读取失败,使用插件 ID 作为名称
|
||||||
|
|
||||||
@@ -853,7 +956,7 @@ async def uninstall_plugin(
|
|||||||
progress=50,
|
progress=50,
|
||||||
message=f"正在删除 {plugin_name}...",
|
message=f"正在删除 {plugin_name}...",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 删除插件目录
|
# 3. 删除插件目录
|
||||||
@@ -869,7 +972,7 @@ async def uninstall_plugin(
|
|||||||
|
|
||||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||||
|
|
||||||
logger.info(f"成功卸载插件: {request.plugin_id} ({plugin_name})")
|
logger.info(f"成功卸载插件: {plugin_id} ({plugin_name})")
|
||||||
|
|
||||||
# 4. 推送成功状态
|
# 4. 推送成功状态
|
||||||
await update_progress(
|
await update_progress(
|
||||||
@@ -877,10 +980,10 @@ async def uninstall_plugin(
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功卸载插件: {plugin_name}",
|
message=f"成功卸载插件: {plugin_name}",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
|
return {"success": True, "message": "插件卸载成功", "plugin_id": plugin_id, "plugin_name": plugin_name}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -892,7 +995,7 @@ async def uninstall_plugin(
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="卸载失败",
|
message="卸载失败",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="权限不足,无法删除插件文件",
|
error="权限不足,无法删除插件文件",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -905,7 +1008,7 @@ async def uninstall_plugin(
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="卸载失败",
|
message="卸载失败",
|
||||||
operation="uninstall",
|
operation="uninstall",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -913,7 +1016,11 @@ async def uninstall_plugin(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/update")
|
@router.post("/update")
|
||||||
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def update_plugin(
|
||||||
|
request: UpdatePluginRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
更新插件
|
更新插件
|
||||||
|
|
||||||
@@ -928,22 +1035,26 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
logger.info(f"收到更新插件请求: {request.plugin_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 验证插件 ID 格式安全性
|
||||||
|
plugin_id = validate_plugin_id(request.plugin_id)
|
||||||
|
|
||||||
# 推送进度:开始更新
|
# 推送进度:开始更新
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading",
|
stage="loading",
|
||||||
progress=5,
|
progress=5,
|
||||||
message=f"开始更新插件: {request.plugin_id}",
|
message=f"开始更新插件: {plugin_id}",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 检查插件是否已安装(支持新旧两种格式)
|
# 1. 检查插件是否已安装(支持新旧两种格式)
|
||||||
plugins_dir = Path("plugins")
|
plugins_dir = Path("plugins").resolve()
|
||||||
# 新格式:下划线
|
# 新格式:下划线
|
||||||
folder_name = request.plugin_id.replace(".", "_")
|
folder_name = plugin_id.replace(".", "_")
|
||||||
plugin_path = plugins_dir / folder_name
|
# 使用安全路径验证
|
||||||
|
plugin_path = validate_safe_path(folder_name, plugins_dir)
|
||||||
# 旧格式:点
|
# 旧格式:点
|
||||||
old_format_path = plugins_dir / request.plugin_id
|
old_format_path = validate_safe_path(plugin_id, plugins_dir)
|
||||||
|
|
||||||
# 优先使用新格式,如果不存在则尝试旧格式
|
# 优先使用新格式,如果不存在则尝试旧格式
|
||||||
if not plugin_path.exists():
|
if not plugin_path.exists():
|
||||||
@@ -955,7 +1066,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="插件不存在",
|
message="插件不存在",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="插件未安装,请先安装",
|
error="插件未安装,请先安装",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=404, detail="插件未安装")
|
raise HTTPException(status_code=404, detail="插件未安装")
|
||||||
@@ -979,12 +1090,12 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
progress=10,
|
progress=10,
|
||||||
message=f"当前版本: {old_version},准备更新...",
|
message=f"当前版本: {old_version},准备更新...",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 删除旧版本
|
# 3. 删除旧版本
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
|
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=plugin_id
|
||||||
)
|
)
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
@@ -999,7 +1110,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
|
|
||||||
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
shutil.rmtree(plugin_path, onerror=remove_readonly)
|
||||||
|
|
||||||
logger.info(f"已删除旧版本: {request.plugin_id} v{old_version}")
|
logger.info(f"已删除旧版本: {plugin_id} v{old_version}")
|
||||||
|
|
||||||
# 4. 解析仓库 URL
|
# 4. 解析仓库 URL
|
||||||
await update_progress(
|
await update_progress(
|
||||||
@@ -1007,7 +1118,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
progress=30,
|
progress=30,
|
||||||
message="正在准备下载新版本...",
|
message="正在准备下载新版本...",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
repo_url = request.repository_url.rstrip("/")
|
repo_url = request.repository_url.rstrip("/")
|
||||||
@@ -1045,14 +1156,14 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="下载新版本失败",
|
message="下载新版本失败",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=error_msg,
|
error=error_msg,
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
# 6. 验证新版本
|
# 6. 验证新版本
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
|
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=plugin_id
|
||||||
)
|
)
|
||||||
|
|
||||||
new_manifest_path = plugin_path / "_manifest.json"
|
new_manifest_path = plugin_path / "_manifest.json"
|
||||||
@@ -1072,7 +1183,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="新版本缺少 _manifest.json",
|
message="新版本缺少 _manifest.json",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error="无效的插件格式",
|
error="无效的插件格式",
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||||
@@ -1083,9 +1194,9 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
new_manifest = json_module.load(f)
|
new_manifest = json_module.load(f)
|
||||||
|
|
||||||
new_version = new_manifest.get("version", "unknown")
|
new_version = new_manifest.get("version", "unknown")
|
||||||
new_name = new_manifest.get("name", request.plugin_id)
|
new_name = new_manifest.get("name", plugin_id)
|
||||||
|
|
||||||
logger.info(f"成功更新插件: {request.plugin_id} {old_version} → {new_version}")
|
logger.info(f"成功更新插件: {plugin_id} {old_version} → {new_version}")
|
||||||
|
|
||||||
# 8. 推送成功状态
|
# 8. 推送成功状态
|
||||||
await update_progress(
|
await update_progress(
|
||||||
@@ -1093,13 +1204,13 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
progress=100,
|
progress=100,
|
||||||
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "插件更新成功",
|
"message": "插件更新成功",
|
||||||
"plugin_id": request.plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"plugin_name": new_name,
|
"plugin_name": new_name,
|
||||||
"old_version": old_version,
|
"old_version": old_version,
|
||||||
"new_version": new_version,
|
"new_version": new_version,
|
||||||
@@ -1114,7 +1225,7 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
progress=0,
|
progress=0,
|
||||||
message="_manifest.json 无效",
|
message="_manifest.json 无效",
|
||||||
operation="update",
|
operation="update",
|
||||||
plugin_id=request.plugin_id,
|
plugin_id=plugin_id,
|
||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||||
@@ -1125,14 +1236,16 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||||
|
|
||||||
await update_progress(
|
await update_progress(
|
||||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
|
stage="error", progress=0, message="更新失败", operation="update", plugin_id=plugin_id, error=str(e)
|
||||||
)
|
)
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/installed")
|
@router.get("/installed")
|
||||||
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def get_installed_plugins(
|
||||||
|
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取已安装的插件列表
|
获取已安装的插件列表
|
||||||
|
|
||||||
@@ -1272,7 +1385,9 @@ class UpdatePluginConfigRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/config/{plugin_id}/schema")
|
@router.get("/config/{plugin_id}/schema")
|
||||||
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def get_plugin_config_schema(
|
||||||
|
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取插件配置 Schema
|
获取插件配置 Schema
|
||||||
|
|
||||||
@@ -1373,12 +1488,34 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
|||||||
# 推断字段类型
|
# 推断字段类型
|
||||||
field_type = type(field_value).__name__
|
field_type = type(field_value).__name__
|
||||||
ui_type = "text"
|
ui_type = "text"
|
||||||
|
item_type = None
|
||||||
|
item_fields = None
|
||||||
|
|
||||||
if isinstance(field_value, bool):
|
if isinstance(field_value, bool):
|
||||||
ui_type = "switch"
|
ui_type = "switch"
|
||||||
elif isinstance(field_value, (int, float)):
|
elif isinstance(field_value, (int, float)):
|
||||||
ui_type = "number"
|
ui_type = "number"
|
||||||
elif isinstance(field_value, list):
|
elif isinstance(field_value, list):
|
||||||
ui_type = "list"
|
ui_type = "list"
|
||||||
|
# 推断数组元素类型
|
||||||
|
if field_value:
|
||||||
|
first_item = field_value[0]
|
||||||
|
if isinstance(first_item, dict):
|
||||||
|
item_type = "object"
|
||||||
|
# 从第一个元素推断字段结构
|
||||||
|
item_fields = {}
|
||||||
|
for k, v in first_item.items():
|
||||||
|
item_fields[k] = {
|
||||||
|
"type": "number" if isinstance(v, (int, float)) else "string",
|
||||||
|
"label": k,
|
||||||
|
"default": "" if isinstance(v, str) else 0,
|
||||||
|
}
|
||||||
|
elif isinstance(first_item, (int, float)):
|
||||||
|
item_type = "number"
|
||||||
|
else:
|
||||||
|
item_type = "string"
|
||||||
|
else:
|
||||||
|
item_type = "string"
|
||||||
elif isinstance(field_value, dict):
|
elif isinstance(field_value, dict):
|
||||||
ui_type = "json"
|
ui_type = "json"
|
||||||
|
|
||||||
@@ -1393,6 +1530,26 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
|||||||
"hidden": False,
|
"hidden": False,
|
||||||
"disabled": False,
|
"disabled": False,
|
||||||
"order": 0,
|
"order": 0,
|
||||||
|
"item_type": item_type,
|
||||||
|
"item_fields": item_fields,
|
||||||
|
"min_items": None,
|
||||||
|
"max_items": None,
|
||||||
|
# 补充缺失的字段
|
||||||
|
"placeholder": None,
|
||||||
|
"hint": None,
|
||||||
|
"icon": None,
|
||||||
|
"example": None,
|
||||||
|
"choices": None,
|
||||||
|
"min": None,
|
||||||
|
"max": None,
|
||||||
|
"step": None,
|
||||||
|
"pattern": None,
|
||||||
|
"max_length": None,
|
||||||
|
"input_type": None,
|
||||||
|
"rows": 3,
|
||||||
|
"group": None,
|
||||||
|
"depends_on": None,
|
||||||
|
"depends_value": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {"success": True, "schema": schema}
|
return {"success": True, "schema": schema}
|
||||||
@@ -1405,7 +1562,9 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/config/{plugin_id}")
|
@router.get("/config/{plugin_id}")
|
||||||
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def get_plugin_config(
|
||||||
|
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取插件当前配置值
|
获取插件当前配置值
|
||||||
|
|
||||||
@@ -1461,7 +1620,10 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
|
|||||||
|
|
||||||
@router.put("/config/{plugin_id}")
|
@router.put("/config/{plugin_id}")
|
||||||
async def update_plugin_config(
|
async def update_plugin_config(
|
||||||
plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
plugin_id: str,
|
||||||
|
request: UpdatePluginConfigRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
更新插件配置
|
更新插件配置
|
||||||
@@ -1532,7 +1694,9 @@ async def update_plugin_config(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/config/{plugin_id}/reset")
|
@router.post("/config/{plugin_id}/reset")
|
||||||
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def reset_plugin_config(
|
||||||
|
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
重置插件配置为默认值
|
重置插件配置为默认值
|
||||||
|
|
||||||
@@ -1592,7 +1756,9 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/config/{plugin_id}/toggle")
|
@router.post("/config/{plugin_id}/toggle")
|
||||||
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def toggle_plugin(
|
||||||
|
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
切换插件启用状态
|
切换插件启用状态
|
||||||
|
|
||||||
|
|||||||
245
src/webui/rate_limiter.py
Normal file
245
src/webui/rate_limiter.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""
|
||||||
|
WebUI 请求频率限制模块
|
||||||
|
防止暴力破解和 API 滥用
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, Tuple, Optional
|
||||||
|
from fastapi import Request, HTTPException
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("webui.rate_limiter")
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""
|
||||||
|
简单的内存请求频率限制器
|
||||||
|
|
||||||
|
使用滑动窗口算法实现
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 存储格式: {key: [(timestamp, count), ...]}
|
||||||
|
self._requests: Dict[str, list] = defaultdict(list)
|
||||||
|
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||||
|
self._blocked: Dict[str, float] = {}
|
||||||
|
|
||||||
|
def _get_client_ip(self, request: Request) -> str:
|
||||||
|
"""获取客户端 IP 地址"""
|
||||||
|
# 检查代理头
|
||||||
|
forwarded = request.headers.get("X-Forwarded-For")
|
||||||
|
if forwarded:
|
||||||
|
# 取第一个 IP(最原始的客户端)
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
|
||||||
|
real_ip = request.headers.get("X-Real-IP")
|
||||||
|
if real_ip:
|
||||||
|
return real_ip
|
||||||
|
|
||||||
|
# 直接连接的客户端
|
||||||
|
if request.client:
|
||||||
|
return request.client.host
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
def _cleanup_old_requests(self, key: str, window_seconds: int):
|
||||||
|
"""清理过期的请求记录"""
|
||||||
|
now = time.time()
|
||||||
|
cutoff = now - window_seconds
|
||||||
|
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
|
||||||
|
|
||||||
|
def _cleanup_expired_blocks(self):
|
||||||
|
"""清理过期的封禁"""
|
||||||
|
now = time.time()
|
||||||
|
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
|
||||||
|
for ip in expired:
|
||||||
|
del self._blocked[ip]
|
||||||
|
logger.info(f"🔓 IP {ip} 封禁已解除")
|
||||||
|
|
||||||
|
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
|
||||||
|
"""
|
||||||
|
检查 IP 是否被封禁
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否被封禁, 剩余封禁秒数)
|
||||||
|
"""
|
||||||
|
self._cleanup_expired_blocks()
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
|
||||||
|
if ip in self._blocked:
|
||||||
|
remaining = int(self._blocked[ip] - time.time())
|
||||||
|
return True, max(0, remaining)
|
||||||
|
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def check_rate_limit(
|
||||||
|
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
|
||||||
|
) -> Tuple[bool, int]:
|
||||||
|
"""
|
||||||
|
检查请求是否超过频率限制
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI Request 对象
|
||||||
|
max_requests: 窗口期内允许的最大请求数
|
||||||
|
window_seconds: 窗口时间(秒)
|
||||||
|
key_suffix: 键后缀,用于区分不同的限制规则
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否允许, 剩余请求数)
|
||||||
|
"""
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
key = f"{ip}:{key_suffix}" if key_suffix else ip
|
||||||
|
|
||||||
|
# 清理过期记录
|
||||||
|
self._cleanup_old_requests(key, window_seconds)
|
||||||
|
|
||||||
|
# 计算当前窗口内的请求数
|
||||||
|
current_count = sum(count for _, count in self._requests[key])
|
||||||
|
|
||||||
|
if current_count >= max_requests:
|
||||||
|
return False, 0
|
||||||
|
|
||||||
|
# 记录新请求
|
||||||
|
now = time.time()
|
||||||
|
self._requests[key].append((now, 1))
|
||||||
|
|
||||||
|
remaining = max_requests - current_count - 1
|
||||||
|
return True, remaining
|
||||||
|
|
||||||
|
def block_ip(self, request: Request, duration_seconds: int):
|
||||||
|
"""
|
||||||
|
封禁 IP
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI Request 对象
|
||||||
|
duration_seconds: 封禁时长(秒)
|
||||||
|
"""
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
self._blocked[ip] = time.time() + duration_seconds
|
||||||
|
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")
|
||||||
|
|
||||||
|
def record_failed_attempt(
|
||||||
|
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
|
||||||
|
) -> Tuple[bool, int]:
|
||||||
|
"""
|
||||||
|
记录失败尝试(如登录失败)
|
||||||
|
|
||||||
|
如果在窗口期内失败次数过多,自动封禁 IP
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI Request 对象
|
||||||
|
max_failures: 允许的最大失败次数
|
||||||
|
window_seconds: 统计窗口(秒)
|
||||||
|
block_duration: 封禁时长(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否被封禁, 剩余尝试次数)
|
||||||
|
"""
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
key = f"{ip}:auth_failures"
|
||||||
|
|
||||||
|
# 清理过期记录
|
||||||
|
self._cleanup_old_requests(key, window_seconds)
|
||||||
|
|
||||||
|
# 计算当前失败次数
|
||||||
|
current_failures = sum(count for _, count in self._requests[key])
|
||||||
|
|
||||||
|
# 记录本次失败
|
||||||
|
now = time.time()
|
||||||
|
self._requests[key].append((now, 1))
|
||||||
|
current_failures += 1
|
||||||
|
|
||||||
|
remaining = max_failures - current_failures
|
||||||
|
|
||||||
|
# 检查是否需要封禁
|
||||||
|
if current_failures >= max_failures:
|
||||||
|
self.block_ip(request, block_duration)
|
||||||
|
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
|
||||||
|
return True, 0
|
||||||
|
|
||||||
|
if current_failures >= max_failures - 2:
|
||||||
|
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
|
||||||
|
|
||||||
|
return False, max(0, remaining)
|
||||||
|
|
||||||
|
def reset_failures(self, request: Request):
|
||||||
|
"""
|
||||||
|
重置失败计数(认证成功后调用)
|
||||||
|
"""
|
||||||
|
ip = self._get_client_ip(request)
|
||||||
|
key = f"{ip}:auth_failures"
|
||||||
|
if key in self._requests:
|
||||||
|
del self._requests[key]
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_rate_limiter: Optional[RateLimiter] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_rate_limiter() -> RateLimiter:
|
||||||
|
"""获取 RateLimiter 单例"""
|
||||||
|
global _rate_limiter
|
||||||
|
if _rate_limiter is None:
|
||||||
|
_rate_limiter = RateLimiter()
|
||||||
|
return _rate_limiter
|
||||||
|
|
||||||
|
|
||||||
|
async def check_auth_rate_limit(request: Request):
|
||||||
|
"""
|
||||||
|
认证接口的频率限制依赖
|
||||||
|
|
||||||
|
规则:
|
||||||
|
- 每个 IP 每分钟最多 10 次认证请求
|
||||||
|
- 连续失败 5 次后封禁 10 分钟
|
||||||
|
"""
|
||||||
|
limiter = get_rate_limiter()
|
||||||
|
|
||||||
|
# 检查是否被封禁
|
||||||
|
blocked, remaining_block = limiter.is_blocked(request)
|
||||||
|
if blocked:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||||
|
headers={"Retry-After": str(remaining_block)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查频率限制
|
||||||
|
allowed, remaining = limiter.check_rate_limit(
|
||||||
|
request,
|
||||||
|
max_requests=10, # 每分钟 10 次
|
||||||
|
window_seconds=60,
|
||||||
|
key_suffix="auth",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||||
|
|
||||||
|
|
||||||
|
async def check_api_rate_limit(request: Request):
|
||||||
|
"""
|
||||||
|
普通 API 的频率限制依赖
|
||||||
|
|
||||||
|
规则:每个 IP 每分钟最多 100 次请求
|
||||||
|
"""
|
||||||
|
limiter = get_rate_limiter()
|
||||||
|
|
||||||
|
# 检查是否被封禁
|
||||||
|
blocked, remaining_block = limiter.is_blocked(request)
|
||||||
|
if blocked:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||||
|
headers={"Retry-After": str(remaining_block)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查频率限制
|
||||||
|
allowed, _ = limiter.check_rate_limit(
|
||||||
|
request,
|
||||||
|
max_requests=100, # 每分钟 100 次
|
||||||
|
window_seconds=60,
|
||||||
|
key_suffix="api",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||||
@@ -7,10 +7,12 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from fastapi import APIRouter, HTTPException
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from src.config.config import MMC_VERSION
|
from src.config.config import MMC_VERSION
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
router = APIRouter(prefix="/system", tags=["system"])
|
router = APIRouter(prefix="/system", tags=["system"])
|
||||||
logger = get_logger("webui_system")
|
logger = get_logger("webui_system")
|
||||||
@@ -19,6 +21,14 @@ logger = get_logger("webui_system")
|
|||||||
_start_time = time.time()
|
_start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
class RestartResponse(BaseModel):
|
class RestartResponse(BaseModel):
|
||||||
"""重启响应"""
|
"""重启响应"""
|
||||||
|
|
||||||
@@ -36,7 +46,7 @@ class StatusResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/restart", response_model=RestartResponse)
|
@router.post("/restart", response_model=RestartResponse)
|
||||||
async def restart_maibot():
|
async def restart_maibot(_auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
重启麦麦主程序
|
重启麦麦主程序
|
||||||
|
|
||||||
@@ -67,7 +77,7 @@ async def restart_maibot():
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/status", response_model=StatusResponse)
|
@router.get("/status", response_model=StatusResponse)
|
||||||
async def get_maibot_status():
|
async def get_maibot_status(_auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取麦麦运行状态
|
获取麦麦运行状态
|
||||||
|
|
||||||
@@ -90,7 +100,7 @@ async def get_maibot_status():
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/reload-config")
|
@router.post("/reload-config")
|
||||||
async def reload_config():
|
async def reload_config(_auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
热重载配置(不重启进程)
|
热重载配置(不重启进程)
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
"""WebUI API 路由"""
|
"""WebUI API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
|
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .token_manager import get_token_manager
|
from .token_manager import get_token_manager
|
||||||
from .auth import set_auth_cookie, clear_auth_cookie
|
from .auth import set_auth_cookie, clear_auth_cookie
|
||||||
|
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
|
||||||
from .config_routes import router as config_router
|
from .config_routes import router as config_router
|
||||||
from .statistics_routes import router as statistics_router
|
from .statistics_routes import router as statistics_router
|
||||||
from .person_routes import router as person_router
|
from .person_routes import router as person_router
|
||||||
@@ -16,6 +17,7 @@ from .plugin_routes import router as plugin_router
|
|||||||
from .plugin_progress_ws import get_progress_router
|
from .plugin_progress_ws import get_progress_router
|
||||||
from .routers.system import router as system_router
|
from .routers.system import router as system_router
|
||||||
from .model_routes import router as model_router
|
from .model_routes import router as model_router
|
||||||
|
from .ws_auth import router as ws_auth_router
|
||||||
|
|
||||||
logger = get_logger("webui.api")
|
logger = get_logger("webui.api")
|
||||||
|
|
||||||
@@ -42,6 +44,8 @@ router.include_router(get_progress_router())
|
|||||||
router.include_router(system_router)
|
router.include_router(system_router)
|
||||||
# 注册模型列表获取路由
|
# 注册模型列表获取路由
|
||||||
router.include_router(model_router)
|
router.include_router(model_router)
|
||||||
|
# 注册 WebSocket 认证路由
|
||||||
|
router.include_router(ws_auth_router)
|
||||||
|
|
||||||
|
|
||||||
class TokenVerifyRequest(BaseModel):
|
class TokenVerifyRequest(BaseModel):
|
||||||
@@ -107,12 +111,18 @@ async def health_check():
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||||
async def verify_token(request: TokenVerifyRequest, response: Response):
|
async def verify_token(
|
||||||
|
request_body: TokenVerifyRequest,
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
_rate_limit: None = Depends(check_auth_rate_limit),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: 包含 token 的验证请求
|
request_body: 包含 token 的验证请求
|
||||||
|
request: FastAPI Request 对象(用于获取客户端 IP)
|
||||||
response: FastAPI Response 对象
|
response: FastAPI Response 对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -120,16 +130,37 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
token_manager = get_token_manager()
|
token_manager = get_token_manager()
|
||||||
is_valid = token_manager.verify_token(request.token)
|
rate_limiter = get_rate_limiter()
|
||||||
|
|
||||||
|
is_valid = token_manager.verify_token(request_body.token)
|
||||||
|
|
||||||
if is_valid:
|
if is_valid:
|
||||||
|
# 认证成功,重置失败计数
|
||||||
|
rate_limiter.reset_failures(request)
|
||||||
# 设置 HttpOnly Cookie
|
# 设置 HttpOnly Cookie
|
||||||
set_auth_cookie(response, request.token)
|
set_auth_cookie(response, request_body.token)
|
||||||
# 同时返回首次配置状态,避免额外请求
|
# 同时返回首次配置状态,避免额外请求
|
||||||
is_first_setup = token_manager.is_first_setup()
|
is_first_setup = token_manager.is_first_setup()
|
||||||
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
||||||
else:
|
else:
|
||||||
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
# 记录失败尝试
|
||||||
|
blocked, remaining = rate_limiter.record_failed_attempt(
|
||||||
|
request,
|
||||||
|
max_failures=5, # 5 次失败
|
||||||
|
window_seconds=300, # 5 分钟窗口
|
||||||
|
block_duration=600, # 封禁 10 分钟
|
||||||
|
)
|
||||||
|
|
||||||
|
if blocked:
|
||||||
|
raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
|
||||||
|
|
||||||
|
message = "Token 无效或已过期"
|
||||||
|
if remaining <= 2:
|
||||||
|
message += f"(剩余 {remaining} 次尝试机会)"
|
||||||
|
|
||||||
|
return TokenVerifyResponse(valid=False, message=message)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Token 验证失败: {e}")
|
logger.error(f"Token 验证失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||||
|
|||||||
@@ -1,19 +1,28 @@
|
|||||||
"""统计数据 API 路由"""
|
"""统计数据 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List, Optional
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from peewee import fn
|
from peewee import fn
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
||||||
|
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
logger = get_logger("webui.statistics")
|
logger = get_logger("webui.statistics")
|
||||||
|
|
||||||
router = APIRouter(prefix="/statistics", tags=["statistics"])
|
router = APIRouter(prefix="/statistics", tags=["statistics"])
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> bool:
|
||||||
|
"""认证依赖:验证用户是否已登录"""
|
||||||
|
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||||
|
|
||||||
|
|
||||||
class StatisticsSummary(BaseModel):
|
class StatisticsSummary(BaseModel):
|
||||||
"""统计数据摘要"""
|
"""统计数据摘要"""
|
||||||
|
|
||||||
@@ -58,7 +67,7 @@ class DashboardData(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/dashboard", response_model=DashboardData)
|
@router.get("/dashboard", response_model=DashboardData)
|
||||||
async def get_dashboard_data(hours: int = 24):
|
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取仪表盘统计数据
|
获取仪表盘统计数据
|
||||||
|
|
||||||
@@ -275,7 +284,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/summary")
|
@router.get("/summary")
|
||||||
async def get_summary(hours: int = 24):
|
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取统计摘要
|
获取统计摘要
|
||||||
|
|
||||||
@@ -293,7 +302,7 @@ async def get_summary(hours: int = 24):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/models")
|
@router.get("/models")
|
||||||
async def get_model_stats(hours: int = 24):
|
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
获取模型统计
|
获取模型统计
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ class WebUIServer:
|
|||||||
self.app = FastAPI(title="MaiBot WebUI")
|
self.app = FastAPI(title="MaiBot WebUI")
|
||||||
self._server = None
|
self._server = None
|
||||||
|
|
||||||
|
# 配置防爬虫中间件(需要在CORS之前注册)
|
||||||
|
self._setup_anti_crawler()
|
||||||
|
|
||||||
# 配置 CORS(支持开发环境跨域请求)
|
# 配置 CORS(支持开发环境跨域请求)
|
||||||
self._setup_cors()
|
self._setup_cors()
|
||||||
|
|
||||||
@@ -32,6 +35,9 @@ class WebUIServer:
|
|||||||
self._register_api_routes()
|
self._register_api_routes()
|
||||||
self._setup_static_files()
|
self._setup_static_files()
|
||||||
|
|
||||||
|
# 注册robots.txt路由
|
||||||
|
self._setup_robots_txt()
|
||||||
|
|
||||||
def _setup_cors(self):
|
def _setup_cors(self):
|
||||||
"""配置 CORS 中间件"""
|
"""配置 CORS 中间件"""
|
||||||
# 开发环境需要允许前端开发服务器的跨域请求
|
# 开发环境需要允许前端开发服务器的跨域请求
|
||||||
@@ -40,12 +46,21 @@ class WebUIServer:
|
|||||||
allow_origins=[
|
allow_origins=[
|
||||||
"http://localhost:5173", # Vite 开发服务器
|
"http://localhost:5173", # Vite 开发服务器
|
||||||
"http://127.0.0.1:5173",
|
"http://127.0.0.1:5173",
|
||||||
|
"http://localhost:7999", # 前端开发服务器备用端口
|
||||||
|
"http://127.0.0.1:7999",
|
||||||
"http://localhost:8001", # 生产环境
|
"http://localhost:8001", # 生产环境
|
||||||
"http://127.0.0.1:8001",
|
"http://127.0.0.1:8001",
|
||||||
],
|
],
|
||||||
allow_credentials=True, # 允许携带 Cookie
|
allow_credentials=True, # 允许携带 Cookie
|
||||||
allow_methods=["*"],
|
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
|
||||||
allow_headers=["*"],
|
allow_headers=[
|
||||||
|
"Content-Type",
|
||||||
|
"Authorization",
|
||||||
|
"Accept",
|
||||||
|
"Origin",
|
||||||
|
"X-Requested-With",
|
||||||
|
], # 明确指定允许的头
|
||||||
|
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
|
||||||
)
|
)
|
||||||
logger.debug("✅ CORS 中间件已配置")
|
logger.debug("✅ CORS 中间件已配置")
|
||||||
|
|
||||||
@@ -89,20 +104,60 @@ class WebUIServer:
|
|||||||
"""服务单页应用 - 只处理非 API 请求"""
|
"""服务单页应用 - 只处理非 API 请求"""
|
||||||
# 如果是根路径,直接返回 index.html
|
# 如果是根路径,直接返回 index.html
|
||||||
if not full_path or full_path == "/":
|
if not full_path or full_path == "/":
|
||||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||||
|
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||||
|
return response
|
||||||
|
|
||||||
# 检查是否是静态文件
|
# 检查是否是静态文件
|
||||||
file_path = static_path / full_path
|
file_path = static_path / full_path
|
||||||
if file_path.is_file() and file_path.exists():
|
if file_path.is_file() and file_path.exists():
|
||||||
# 自动检测 MIME 类型
|
# 自动检测 MIME 类型
|
||||||
media_type = mimetypes.guess_type(str(file_path))[0]
|
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||||
return FileResponse(file_path, media_type=media_type)
|
response = FileResponse(file_path, media_type=media_type)
|
||||||
|
# HTML 文件添加防索引头
|
||||||
|
if str(file_path).endswith(".html"):
|
||||||
|
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||||
|
return response
|
||||||
|
|
||||||
# 其他路径返回 index.html(SPA 路由)
|
# 其他路径返回 index.html(SPA 路由)
|
||||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||||
|
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||||
|
return response
|
||||||
|
|
||||||
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
|
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
|
||||||
|
|
||||||
|
def _setup_anti_crawler(self):
|
||||||
|
"""配置防爬虫中间件"""
|
||||||
|
try:
|
||||||
|
from src.webui.anti_crawler import AntiCrawlerMiddleware
|
||||||
|
|
||||||
|
# 从环境变量读取防爬虫模式(false/strict/loose/basic)
|
||||||
|
anti_crawler_mode = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "basic").lower()
|
||||||
|
|
||||||
|
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
|
||||||
|
# 我们需要在CORS之前注册,这样防爬虫检查会在CORS之前执行
|
||||||
|
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
|
||||||
|
|
||||||
|
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
|
||||||
|
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
|
||||||
|
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
def _setup_robots_txt(self):
|
||||||
|
"""设置robots.txt路由"""
|
||||||
|
try:
|
||||||
|
from src.webui.anti_crawler import create_robots_txt_response
|
||||||
|
|
||||||
|
@self.app.get("/robots.txt", include_in_schema=False)
|
||||||
|
async def robots_txt():
|
||||||
|
"""返回robots.txt,禁止所有爬虫"""
|
||||||
|
return create_robots_txt_response()
|
||||||
|
|
||||||
|
logger.debug("✅ robots.txt 路由已注册")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
|
||||||
|
|
||||||
def _register_api_routes(self):
|
def _register_api_routes(self):
|
||||||
"""注册所有 WebUI API 路由"""
|
"""注册所有 WebUI API 路由"""
|
||||||
try:
|
try:
|
||||||
@@ -110,8 +165,10 @@ class WebUIServer:
|
|||||||
from src.webui.routes import router as webui_router
|
from src.webui.routes import router as webui_router
|
||||||
from src.webui.logs_ws import router as logs_router
|
from src.webui.logs_ws import router as logs_router
|
||||||
from src.webui.knowledge_routes import router as knowledge_router
|
from src.webui.knowledge_routes import router as knowledge_router
|
||||||
|
|
||||||
# 导入本地聊天室路由
|
# 导入本地聊天室路由
|
||||||
from src.webui.chat_routes import router as chat_router
|
from src.webui.chat_routes import router as chat_router
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
self.app.include_router(webui_router)
|
self.app.include_router(webui_router)
|
||||||
self.app.include_router(logs_router)
|
self.app.include_router(logs_router)
|
||||||
@@ -166,6 +223,7 @@ class WebUIServer:
|
|||||||
def _check_port_available(self) -> bool:
|
def _check_port_available(self) -> bool:
|
||||||
"""检查端口是否可用"""
|
"""检查端口是否可用"""
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
s.settimeout(1)
|
s.settimeout(1)
|
||||||
|
|||||||
114
src/webui/ws_auth.py
Normal file
114
src/webui/ws_auth.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""WebSocket 认证模块
|
||||||
|
|
||||||
|
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
|
||||||
|
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Cookie, Header
|
||||||
|
from typing import Optional
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.webui.token_manager import get_token_manager
|
||||||
|
|
||||||
|
logger = get_logger("webui.ws_auth")
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
||||||
|
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
||||||
|
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
|
||||||
|
_WS_TOKEN_EXPIRE_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_expired_ws_tokens():
|
||||||
|
"""清理过期的临时 token"""
|
||||||
|
now = time.time()
|
||||||
|
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
|
||||||
|
for t in expired:
|
||||||
|
del _ws_temp_tokens[t]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_ws_token(session_token: str) -> str:
|
||||||
|
"""生成 WebSocket 临时 token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_token: 原始的 session token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
临时 token 字符串
|
||||||
|
"""
|
||||||
|
_cleanup_expired_ws_tokens()
|
||||||
|
temp_token = secrets.token_urlsafe(32)
|
||||||
|
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
|
||||||
|
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
|
||||||
|
return temp_token
|
||||||
|
|
||||||
|
|
||||||
|
def verify_ws_token(temp_token: str) -> bool:
|
||||||
|
"""验证并消费 WebSocket 临时 token(一次性使用)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temp_token: 临时 token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
验证是否通过
|
||||||
|
"""
|
||||||
|
_cleanup_expired_ws_tokens()
|
||||||
|
if temp_token not in _ws_temp_tokens:
|
||||||
|
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
|
||||||
|
return False
|
||||||
|
expire_time, session_token = _ws_temp_tokens[temp_token]
|
||||||
|
if time.time() > expire_time:
|
||||||
|
del _ws_temp_tokens[temp_token]
|
||||||
|
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
|
||||||
|
return False
|
||||||
|
# 验证原始 session token 仍然有效
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token_manager.verify_token(session_token):
|
||||||
|
del _ws_temp_tokens[temp_token]
|
||||||
|
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
|
||||||
|
return False
|
||||||
|
# 消费 token(一次性使用)
|
||||||
|
del _ws_temp_tokens[temp_token]
|
||||||
|
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/ws-token")
|
||||||
|
async def get_ws_token(
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取 WebSocket 连接用的临时 token
|
||||||
|
|
||||||
|
此端点验证当前会话的 Cookie 或 Authorization header,
|
||||||
|
然后返回一个临时 token 用于 WebSocket 握手认证。
|
||||||
|
临时 token 有效期 60 秒,且只能使用一次。
|
||||||
|
|
||||||
|
注意:在未认证时返回 200 状态码但 success=False,避免前端因 401 刷新页面。
|
||||||
|
"""
|
||||||
|
# 获取当前 session token
|
||||||
|
session_token = None
|
||||||
|
if maibot_session:
|
||||||
|
session_token = maibot_session
|
||||||
|
elif authorization and authorization.startswith("Bearer "):
|
||||||
|
session_token = authorization.replace("Bearer ", "")
|
||||||
|
|
||||||
|
if not session_token:
|
||||||
|
# 返回 200 但 success=False,避免前端因 401 刷新页面
|
||||||
|
# 这在登录页面是正常情况,不应该触发错误处理
|
||||||
|
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
|
||||||
|
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
|
||||||
|
|
||||||
|
# 验证 session token
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token_manager.verify_token(session_token):
|
||||||
|
# 同样返回 200 但 success=False,避免前端刷新
|
||||||
|
logger.debug("ws-token 请求:认证已过期")
|
||||||
|
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
|
||||||
|
|
||||||
|
# 生成临时 WebSocket token
|
||||||
|
ws_token = generate_ws_token(session_token)
|
||||||
|
|
||||||
|
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}
|
||||||
@@ -7,3 +7,12 @@ WEBUI_ENABLED=true
|
|||||||
WEBUI_MODE=production # 模式: development(开发) 或 production(生产)
|
WEBUI_MODE=production # 模式: development(开发) 或 production(生产)
|
||||||
WEBUI_HOST=0.0.0.0 # WebUI 服务器监听地址
|
WEBUI_HOST=0.0.0.0 # WebUI 服务器监听地址
|
||||||
WEBUI_PORT=8001 # WebUI 服务器端口
|
WEBUI_PORT=8001 # WebUI 服务器端口
|
||||||
|
|
||||||
|
# 防爬虫配置
|
||||||
|
WEBUI_ANTI_CRAWLER_MODE=basic # 防爬虫模式: false(禁用) / strict(严格) / loose(宽松) / basic(基础-只记录不阻止)
|
||||||
|
WEBUI_ALLOWED_IPS=127.0.0.1 # IP白名单(逗号分隔,支持精确IP、CIDR格式和通配符)
|
||||||
|
# 示例: 127.0.0.1,192.168.1.0/24,172.17.0.0/16
|
||||||
|
WEBUI_TRUSTED_PROXIES= # 信任的代理IP列表(逗号分隔),只有来自这些IP的X-Forwarded-For才被信任
|
||||||
|
# 示例: 127.0.0.1,192.168.1.1,172.17.0.1
|
||||||
|
WEBUI_TRUST_XFF=false # 是否启用X-Forwarded-For代理解析(默认false)
|
||||||
|
# 启用后,仍要求直连IP在TRUSTED_PROXIES中才会信任XFF头
|
||||||
5
webui/dist/assets/dnd-B_gmzEl7.js
vendored
Normal file
5
webui/dist/assets/dnd-B_gmzEl7.js
vendored
Normal file
File diff suppressed because one or more lines are too long
5
webui/dist/assets/dnd-CHfCzWUK.js
vendored
5
webui/dist/assets/dnd-CHfCzWUK.js
vendored
File diff suppressed because one or more lines are too long
1
webui/dist/assets/icons-B6qV_tuI.js
vendored
Normal file
1
webui/dist/assets/icons-B6qV_tuI.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
webui/dist/assets/icons-CwAZotQh.js
vendored
1
webui/dist/assets/icons-CwAZotQh.js
vendored
File diff suppressed because one or more lines are too long
1
webui/dist/assets/index-CJYRKPEp.css
vendored
Normal file
1
webui/dist/assets/index-CJYRKPEp.css
vendored
Normal file
File diff suppressed because one or more lines are too long
1
webui/dist/assets/index-CWjV9Ftw.css
vendored
1
webui/dist/assets/index-CWjV9Ftw.css
vendored
File diff suppressed because one or more lines are too long
54
webui/dist/assets/index-CgxOYrz-.js
vendored
54
webui/dist/assets/index-CgxOYrz-.js
vendored
File diff suppressed because one or more lines are too long
86
webui/dist/assets/index-D-S1XZ00.js
vendored
Normal file
86
webui/dist/assets/index-D-S1XZ00.js
vendored
Normal file
File diff suppressed because one or more lines are too long
12
webui/dist/index.html
vendored
12
webui/dist/index.html
vendored
@@ -4,24 +4,28 @@
|
|||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="google" content="notranslate" />
|
<meta name="google" content="notranslate" />
|
||||||
<meta http-equiv="content-language" content="zh-CN" />
|
<meta http-equiv="content-language" content="zh-CN" />
|
||||||
|
<!-- 防止搜索引擎索引 -->
|
||||||
|
<meta name="robots" content="noindex, nofollow, noarchive, nosnippet" />
|
||||||
|
<meta name="googlebot" content="noindex, nofollow" />
|
||||||
|
<meta name="bingbot" content="noindex, nofollow" />
|
||||||
<link rel="icon" type="image/x-icon" href="/maimai.ico" />
|
<link rel="icon" type="image/x-icon" href="/maimai.ico" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>MaiBot Dashboard</title>
|
<title>MaiBot Dashboard</title>
|
||||||
<script type="module" crossorigin src="/assets/index-CgxOYrz-.js"></script>
|
<script type="module" crossorigin src="/assets/index-D-S1XZ00.js"></script>
|
||||||
<link rel="modulepreload" crossorigin href="/assets/react-vendor-BmxF9s7Q.js">
|
<link rel="modulepreload" crossorigin href="/assets/react-vendor-BmxF9s7Q.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/router-Bz250laD.js">
|
<link rel="modulepreload" crossorigin href="/assets/router-Bz250laD.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/utils-BXc2jIuz.js">
|
<link rel="modulepreload" crossorigin href="/assets/utils-BXc2jIuz.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/radix-core-9dEfQl-6.js">
|
<link rel="modulepreload" crossorigin href="/assets/radix-core-9dEfQl-6.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/radix-extra-DDK-u9dm.js">
|
<link rel="modulepreload" crossorigin href="/assets/radix-extra-DDK-u9dm.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/charts-DbiuC1q1.js">
|
<link rel="modulepreload" crossorigin href="/assets/charts-DbiuC1q1.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/icons-CwAZotQh.js">
|
<link rel="modulepreload" crossorigin href="/assets/icons-B6qV_tuI.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/codemirror-BEE0n9kQ.js">
|
<link rel="modulepreload" crossorigin href="/assets/codemirror-BEE0n9kQ.js">
|
||||||
|
<link rel="modulepreload" crossorigin href="/assets/dnd-B_gmzEl7.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/misc-CKjrIrIJ.js">
|
<link rel="modulepreload" crossorigin href="/assets/misc-CKjrIrIJ.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/dnd-CHfCzWUK.js">
|
|
||||||
<link rel="modulepreload" crossorigin href="/assets/reactflow-DLoXAt4c.js">
|
<link rel="modulepreload" crossorigin href="/assets/reactflow-DLoXAt4c.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/uppy-BMZiFQyG.js">
|
<link rel="modulepreload" crossorigin href="/assets/uppy-BMZiFQyG.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/markdown-kUhwkcQP.js">
|
<link rel="modulepreload" crossorigin href="/assets/markdown-kUhwkcQP.js">
|
||||||
<link rel="stylesheet" crossorigin href="/assets/index-CWjV9Ftw.css">
|
<link rel="stylesheet" crossorigin href="/assets/index-CJYRKPEp.css">
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div id="root" class="notranslate"></div>
|
<div id="root" class="notranslate"></div>
|
||||||
|
|||||||
Reference in New Issue
Block a user