From e680a4d1f55fe8a238375c6e42b8f9a5cefce017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sat, 13 Dec 2025 17:14:09 +0800 Subject: [PATCH] Ruff format --- bot.py | 14 +- .../config_converter.py | 54 +- plugins/MaiBot_MCPBridgePlugin/mcp_client.py | 641 +++++++------- plugins/MaiBot_MCPBridgePlugin/plugin.py | 811 +++++++++--------- .../MaiBot_MCPBridgePlugin/test_mcp_client.py | 162 ++-- scripts/replyer_action_stats.py | 65 +- src/bw_learner/expression_learner.py | 83 +- src/bw_learner/expression_reflector.py | 4 +- src/bw_learner/expression_selector.py | 54 +- src/bw_learner/jargon_explainer.py | 9 +- src/bw_learner/jargon_miner.py | 46 +- src/bw_learner/message_recorder.py | 87 +- src/chat/brain_chat/brain_chat.py | 40 +- src/chat/brain_chat/brain_planner.py | 18 +- src/chat/emoji_system/emoji_manager.py | 6 +- src/chat/heart_flow/heartFC_chat.py | 4 +- src/chat/message_receive/bot.py | 6 +- src/chat/message_receive/message.py | 2 +- .../message_receive/uni_message_sender.py | 29 +- src/chat/planner_actions/planner.py | 3 +- src/chat/replyer/group_generator.py | 12 +- src/chat/replyer/private_generator.py | 10 +- .../replyer/prompt/replyer_private_prompt.py | 11 +- src/chat/replyer/prompt/replyer_prompt.py | 3 +- src/chat/utils/chat_message_builder.py | 5 +- src/chat/utils/statistic.py | 16 +- src/chat/utils/utils.py | 10 +- src/chat/utils/utils_image.py | 22 +- src/common/database/database_model.py | 10 +- src/common/toml_utils.py | 8 +- src/config/official_configs.py | 8 +- src/dream/dream_agent.py | 66 +- src/dream/dream_generator.py | 45 +- src/dream/tools/__init__.py | 5 - src/dream/tools/create_chat_history_tool.py | 5 - src/dream/tools/delete_chat_history_tool.py | 5 - src/dream/tools/delete_jargon_tool.py | 5 - src/dream/tools/finish_maintenance_tool.py | 5 - .../tools/get_chat_history_detail_tool.py | 18 +- src/dream/tools/search_chat_history_tool.py | 25 +- src/dream/tools/search_jargon_tool.py | 6 +- src/dream/tools/update_chat_history_tool.py | 5 - src/dream/tools/update_jargon_tool.py | 5 - .../chat_history_summarizer.py | 31 +- src/llm_models/utils_model.py | 18 +- src/main.py | 1 - src/memory_system/memory_retrieval.py | 140 +-- src/memory_system/memory_utils.py | 2 +- .../retrieval_tools/found_answer.py | 1 - .../retrieval_tools/query_chat_history.py | 21 +- src/webui/auth.py | 32 +- src/webui/chat_routes.py | 8 +- src/webui/emoji_routes.py | 185 ++-- src/webui/expression_routes.py | 29 +- src/webui/jargon_routes.py | 35 +- src/webui/person_routes.py | 21 +- src/webui/plugin_routes.py | 67 +- src/webui/routes.py | 26 +- src/webui/token_manager.py | 10 +- src/webui/webui_server.py | 3 + 60 files changed, 1546 insertions(+), 1532 deletions(-) diff --git a/bot.py b/bot.py index 5e2042ff..c63a1286 100644 --- a/bot.py +++ b/bot.py @@ -41,6 +41,7 @@ logger = get_logger("main") # 定义重启退出码 RESTART_EXIT_CODE = 42 + def run_runner_process(): """ Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。 @@ -55,25 +56,25 @@ def run_runner_process(): while True: logger.info(f"正在启动 {script_file}...") - + # 启动子进程 (Worker) # 使用 sys.executable 确保使用相同的 Python 解释器 cmd = [python_executable, script_file] + sys.argv[1:] - + process = subprocess.Popen(cmd, env=env) - + try: # 等待子进程结束 return_code = process.wait() - + if return_code == RESTART_EXIT_CODE: logger.info("检测到重启请求 (退出码 42),正在重启...") - time.sleep(1) # 稍作等待 + time.sleep(1) # 稍作等待 continue else: logger.info(f"程序已退出 (退出码 {return_code})") sys.exit(return_code) - + except KeyboardInterrupt: # 向子进程发送终止信号 if process.poll() is None: @@ -87,6 +88,7 @@ def run_runner_process(): process.kill() sys.exit(0) + # 检查是否是 Worker 进程 # 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本, # 此时应该作为 Runner 运行。 diff --git a/plugins/MaiBot_MCPBridgePlugin/config_converter.py b/plugins/MaiBot_MCPBridgePlugin/config_converter.py index 16f9e028..eb036991 100644 --- a/plugins/MaiBot_MCPBridgePlugin/config_converter.py +++ b/plugins/MaiBot_MCPBridgePlugin/config_converter.py @@ -19,6 +19,7 @@ from typing import Any, Dict, List, Optional, Tuple @dataclass class ConversionResult: """转换结果""" + success: bool servers: List[Dict[str, Any]] = field(default_factory=list) errors: List[str] = field(default_factory=list) @@ -53,7 +54,7 @@ class ConfigConverter: @classmethod def detect_format(cls, config: Dict[str, Any]) -> Optional[str]: """检测配置格式类型 - + Returns: "claude": Claude Desktop 格式 (mcpServers 对象) "kiro": Kiro MCP 格式 (mcpServers 对象,与 Claude 相同) @@ -82,7 +83,7 @@ class ConfigConverter: @classmethod def parse_json_safe(cls, json_str: str) -> Tuple[Optional[Any], Optional[str]]: """安全解析 JSON 字符串 - + Returns: (解析结果, 错误信息) """ @@ -102,11 +103,11 @@ class ConfigConverter: @classmethod def validate_server_config(cls, name: str, config: Dict[str, Any]) -> Tuple[bool, Optional[str], List[str]]: """验证单个服务器配置 - + Args: name: 服务器名称 config: 服务器配置字典 - + Returns: (是否有效, 错误信息, 警告列表) """ @@ -177,11 +178,11 @@ class ConfigConverter: @classmethod def convert_claude_server(cls, name: str, config: Dict[str, Any]) -> Dict[str, Any]: """将单个 Claude 格式服务器配置转换为 MaiBot 格式 - + Args: name: 服务器名称 config: Claude 格式的服务器配置 - + Returns: MaiBot 格式的服务器配置 """ @@ -231,10 +232,10 @@ class ConfigConverter: @classmethod def convert_maibot_server(cls, config: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: """将单个 MaiBot 格式服务器配置转换为 Claude 格式 - + Args: config: MaiBot 格式的服务器配置 - + Returns: (服务器名称, Claude 格式的服务器配置) """ @@ -271,17 +272,13 @@ class ConfigConverter: return name, result @classmethod - def from_claude_format( - cls, - config: Dict[str, Any], - existing_names: Optional[set] = None - ) -> ConversionResult: + def from_claude_format(cls, config: Dict[str, Any], existing_names: Optional[set] = None) -> ConversionResult: """从 Claude Desktop 格式转换为 MaiBot 格式 - + Args: config: Claude Desktop 配置 (包含 mcpServers 字段) existing_names: 已存在的服务器名称集合,用于跳过重复 - + Returns: ConversionResult """ @@ -336,10 +333,10 @@ class ConfigConverter: @classmethod def to_claude_format(cls, servers: List[Dict[str, Any]]) -> Dict[str, Any]: """将 MaiBot 格式转换为 Claude Desktop 格式 - + Args: servers: MaiBot 格式的服务器列表 - + Returns: Claude Desktop 格式的配置 """ @@ -355,19 +352,15 @@ class ConfigConverter: return {"mcpServers": mcp_servers} @classmethod - def import_from_string( - cls, - json_str: str, - existing_names: Optional[set] = None - ) -> ConversionResult: + def import_from_string(cls, json_str: str, existing_names: Optional[set] = None) -> ConversionResult: """从 JSON 字符串导入配置 - + 自动检测格式并转换为 MaiBot 格式 - + Args: json_str: JSON 字符串 existing_names: 已存在的服务器名称集合 - + Returns: ConversionResult """ @@ -422,19 +415,14 @@ class ConfigConverter: return result @classmethod - def export_to_string( - cls, - servers: List[Dict[str, Any]], - format_type: str = "claude", - pretty: bool = True - ) -> str: + def export_to_string(cls, servers: List[Dict[str, Any]], format_type: str = "claude", pretty: bool = True) -> str: """导出配置为 JSON 字符串 - + Args: servers: MaiBot 格式的服务器列表 format_type: 导出格式 ("claude", "kiro", "maibot") pretty: 是否格式化输出 - + Returns: JSON 字符串 """ diff --git a/plugins/MaiBot_MCPBridgePlugin/mcp_client.py b/plugins/MaiBot_MCPBridgePlugin/mcp_client.py index 0d4eebff..d2eed62b 100644 --- a/plugins/MaiBot_MCPBridgePlugin/mcp_client.py +++ b/plugins/MaiBot_MCPBridgePlugin/mcp_client.py @@ -34,28 +34,31 @@ from enum import Enum # 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging try: from src.common.logger import get_logger + logger = get_logger("mcp_client") except ImportError: # Fallback: 使用标准 logging logger = logging.getLogger("mcp_client") if not logger.handlers: handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s')) + handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s")) logger.addHandler(handler) logger.setLevel(logging.INFO) class TransportType(Enum): """MCP 传输类型""" - STDIO = "stdio" # 本地进程通信 - SSE = "sse" # Server-Sent Events (旧版 HTTP) - HTTP = "http" # HTTP Streamable (新版,推荐) + + STDIO = "stdio" # 本地进程通信 + SSE = "sse" # Server-Sent Events (旧版 HTTP) + HTTP = "http" # HTTP Streamable (新版,推荐) STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名 @dataclass class MCPToolInfo: """MCP 工具信息""" + name: str description: str input_schema: Dict[str, Any] @@ -65,6 +68,7 @@ class MCPToolInfo: @dataclass class MCPResourceInfo: """MCP 资源信息""" + uri: str name: str description: str @@ -75,6 +79,7 @@ class MCPResourceInfo: @dataclass class MCPPromptInfo: """MCP 提示模板信息""" + name: str description: str arguments: List[Dict[str, Any]] # [{name, description, required}] @@ -84,6 +89,7 @@ class MCPPromptInfo: @dataclass class MCPServerConfig: """MCP 服务器配置""" + name: str enabled: bool = True transport: TransportType = TransportType.STDIO @@ -99,6 +105,7 @@ class MCPServerConfig: @dataclass class MCPCallResult: """MCP 工具调用结果""" + success: bool content: Any error: Optional[str] = None @@ -108,27 +115,28 @@ class MCPCallResult: class CircuitState(Enum): """断路器状态""" - CLOSED = "closed" # 正常状态,允许请求 - OPEN = "open" # 熔断状态,拒绝请求 + + CLOSED = "closed" # 正常状态,允许请求 + OPEN = "open" # 熔断状态,拒绝请求 HALF_OPEN = "half_open" # 半开状态,允许少量试探请求 @dataclass class CircuitBreaker: """v1.7.0: 断路器 - 防止对故障服务器持续请求 - + 状态转换: - CLOSED -> OPEN: 连续失败次数达到阈值 - OPEN -> HALF_OPEN: 熔断时间到期 - HALF_OPEN -> CLOSED: 试探请求成功 - HALF_OPEN -> OPEN: 试探请求失败 """ - + # 配置 - failure_threshold: int = 5 # 连续失败多少次后熔断 - recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒) - half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用 - + failure_threshold: int = 5 # 连续失败多少次后熔断 + recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒) + half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用 + # 状态 state: CircuitState = field(default=CircuitState.CLOSED) failure_count: int = 0 @@ -136,18 +144,18 @@ class CircuitBreaker: last_failure_time: float = 0.0 last_state_change: float = field(default_factory=time.time) half_open_calls: int = 0 - + def can_execute(self) -> Tuple[bool, Optional[str]]: """检查是否允许执行请求 - + Returns: (是否允许, 拒绝原因) """ current_time = time.time() - + if self.state == CircuitState.CLOSED: return True, None - + if self.state == CircuitState.OPEN: # 检查是否到了恢复时间 time_since_failure = current_time - self.last_failure_time @@ -158,20 +166,20 @@ class CircuitBreaker: else: remaining = self.recovery_timeout - time_since_failure return False, f"断路器熔断中,{remaining:.0f}秒后重试" - + if self.state == CircuitState.HALF_OPEN: # 半开状态,检查是否还有试探配额 if self.half_open_calls < self.half_open_max_calls: return True, None else: return False, "断路器半开状态,等待试探结果" - + return True, None - + def record_success(self) -> None: """记录成功调用""" self.success_count += 1 - + if self.state == CircuitState.HALF_OPEN: # 半开状态下成功,恢复到关闭状态 self._transition_to(CircuitState.CLOSED) @@ -179,12 +187,12 @@ class CircuitBreaker: elif self.state == CircuitState.CLOSED: # 正常状态下成功,重置失败计数 self.failure_count = 0 - + def record_failure(self) -> None: """记录失败调用""" self.failure_count += 1 self.last_failure_time = time.time() - + if self.state == CircuitState.HALF_OPEN: # 半开状态下失败,重新熔断 self._transition_to(CircuitState.OPEN) @@ -194,21 +202,21 @@ class CircuitBreaker: if self.failure_count >= self.failure_threshold: self._transition_to(CircuitState.OPEN) logger.warning(f"断路器熔断(连续失败 {self.failure_count} 次)") - + def _transition_to(self, new_state: CircuitState) -> None: """状态转换""" old_state = self.state self.state = new_state self.last_state_change = time.time() - + if new_state == CircuitState.CLOSED: self.failure_count = 0 self.half_open_calls = 0 elif new_state == CircuitState.HALF_OPEN: self.half_open_calls = 0 - + logger.debug(f"断路器状态: {old_state.value} -> {new_state.value}") - + def reset(self) -> None: """重置断路器""" self.state = CircuitState.CLOSED @@ -216,7 +224,7 @@ class CircuitBreaker: self.success_count = 0 self.half_open_calls = 0 self.last_state_change = time.time() - + def get_status(self) -> Dict[str, Any]: """获取断路器状态""" return { @@ -232,6 +240,7 @@ class CircuitBreaker: @dataclass class ToolCallStats: """工具调用统计""" + tool_key: str total_calls: int = 0 success_calls: int = 0 @@ -239,21 +248,21 @@ class ToolCallStats: total_duration_ms: float = 0.0 last_call_time: Optional[float] = None last_error: Optional[str] = None - + @property def success_rate(self) -> float: """成功率(0-100)""" if self.total_calls == 0: return 0.0 return (self.success_calls / self.total_calls) * 100 - + @property def avg_duration_ms(self) -> float: """平均耗时(毫秒)""" if self.success_calls == 0: return 0.0 return self.total_duration_ms / self.success_calls - + def record_call(self, success: bool, duration_ms: float, error: Optional[str] = None) -> None: """记录一次调用""" self.total_calls += 1 @@ -264,7 +273,7 @@ class ToolCallStats: else: self.failed_calls += 1 self.last_error = error - + def to_dict(self) -> Dict[str, Any]: """转换为字典""" return { @@ -282,6 +291,7 @@ class ToolCallStats: @dataclass class ServerStats: """服务器统计""" + server_name: str connect_count: int = 0 # 连接次数 disconnect_count: int = 0 # 断开次数 @@ -290,26 +300,26 @@ class ServerStats: last_disconnect_time: Optional[float] = None last_heartbeat_time: Optional[float] = None consecutive_failures: int = 0 # 连续失败次数 - + def record_connect(self) -> None: self.connect_count += 1 self.last_connect_time = time.time() self.consecutive_failures = 0 - + def record_disconnect(self) -> None: self.disconnect_count += 1 self.last_disconnect_time = time.time() - + def record_reconnect(self) -> None: self.reconnect_count += 1 self.consecutive_failures = 0 - + def record_failure(self) -> None: self.consecutive_failures += 1 - + def record_heartbeat(self) -> None: self.last_heartbeat_time = time.time() - + def to_dict(self) -> Dict[str, Any]: return { "server_name": self.server_name, @@ -325,7 +335,7 @@ class ServerStats: class MCPClientSession: """MCP 客户端会话,管理与单个 MCP 服务器的连接""" - + def __init__(self, config: MCPServerConfig, call_timeout: float = 60.0): self.config = config self.call_timeout = call_timeout @@ -338,63 +348,63 @@ class MCPClientSession: self._prompts: List[MCPPromptInfo] = [] # v1.2.0: Prompts 支持 self._connected = False self._lock = asyncio.Lock() - + # 功能支持标记(服务器可能不支持某些功能) self._supports_resources: bool = False self._supports_prompts: bool = False - + # 统计信息 self.stats = ServerStats(server_name=config.name) self._tool_stats: Dict[str, ToolCallStats] = {} - + # v1.7.0: 断路器 self._circuit_breaker = CircuitBreaker() - + @property def is_connected(self) -> bool: return self._connected - + @property def tools(self) -> List[MCPToolInfo]: return self._tools.copy() - + @property def resources(self) -> List[MCPResourceInfo]: """v1.2.0: 获取资源列表""" return self._resources.copy() - + @property def prompts(self) -> List[MCPPromptInfo]: """v1.2.0: 获取提示模板列表""" return self._prompts.copy() - + @property def supports_resources(self) -> bool: """v1.2.0: 服务器是否支持 Resources""" return self._supports_resources - + @property def supports_prompts(self) -> bool: """v1.2.0: 服务器是否支持 Prompts""" return self._supports_prompts - + @property def server_name(self) -> str: return self.config.name - + def get_tool_stats(self, tool_name: str) -> Optional[ToolCallStats]: """获取工具统计""" return self._tool_stats.get(tool_name) - + def get_circuit_breaker_status(self) -> Dict[str, Any]: """v1.7.0: 获取断路器状态""" return self._circuit_breaker.get_status() - + def reset_circuit_breaker(self) -> None: """v1.7.0: 重置断路器""" self._circuit_breaker.reset() logger.info(f"[{self.server_name}] 断路器已重置") - + def get_all_tool_stats(self) -> Dict[str, ToolCallStats]: """获取所有工具统计""" return self._tool_stats.copy() @@ -404,7 +414,7 @@ class MCPClientSession: async with self._lock: if self._connected: return True - + try: success = False if self.config.transport == TransportType.STDIO: @@ -416,7 +426,7 @@ class MCPClientSession: else: logger.error(f"[{self.server_name}] 不支持的传输类型: {self.config.transport}") return False - + if success: self.stats.record_connect() # v1.7.0: 连接成功时重置断路器 @@ -424,13 +434,13 @@ class MCPClientSession: else: self.stats.record_failure() return success - + except Exception as e: logger.error(f"[{self.server_name}] 连接失败: {e}") self._connected = False self.stats.record_failure() return False - + async def _connect_stdio(self) -> bool: """通过 stdio 连接 MCP 服务器""" try: @@ -440,31 +450,29 @@ class MCPClientSession: except ImportError: logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") return False - + server_params = StdioServerParameters( - command=self.config.command, - args=self.config.args, - env=self.config.env if self.config.env else None + command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None ) - + self._stdio_context = stdio_client(server_params) self._read_stream, self._write_stream = await self._stdio_context.__aenter__() - + self._session_context = ClientSession(self._read_stream, self._write_stream) self._session = await self._session_context.__aenter__() - + await self._session.initialize() await self._fetch_tools() - + self._connected = True logger.info(f"[{self.server_name}] stdio 连接成功,发现 {len(self._tools)} 个工具") return True - + except Exception as e: logger.error(f"[{self.server_name}] stdio 连接失败: {e}") await self._cleanup() return False - + async def _connect_sse(self) -> bool: """通过 SSE 连接 MCP 服务器""" try: @@ -474,13 +482,13 @@ class MCPClientSession: except ImportError: logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") return False - + if not self.config.url: logger.error(f"[{self.server_name}] SSE 传输需要配置 url") return False - + logger.debug(f"[{self.server_name}] 正在连接 SSE MCP 服务器: {self.config.url}") - + # v1.4.2: 支持 headers 鉴权 sse_kwargs = { "url": self.config.url, @@ -489,23 +497,24 @@ class MCPClientSession: } if self.config.headers: sse_kwargs["headers"] = self.config.headers - + self._sse_context = sse_client(**sse_kwargs) self._read_stream, self._write_stream = await self._sse_context.__aenter__() - + self._session_context = ClientSession(self._read_stream, self._write_stream) self._session = await self._session_context.__aenter__() - + await self._session.initialize() await self._fetch_tools() - + self._connected = True logger.info(f"[{self.server_name}] SSE 连接成功,发现 {len(self._tools)} 个工具") return True - + except Exception as e: logger.error(f"[{self.server_name}] SSE 连接失败: {e}") import traceback + logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") await self._cleanup() return False @@ -519,13 +528,13 @@ class MCPClientSession: except ImportError: logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") return False - + if not self.config.url: logger.error(f"[{self.server_name}] HTTP 传输需要配置 url") return False - + logger.debug(f"[{self.server_name}] 正在连接 HTTP MCP 服务器: {self.config.url}") - + # v1.4.2: 支持 headers 鉴权 http_kwargs = { "url": self.config.url, @@ -534,23 +543,24 @@ class MCPClientSession: } if self.config.headers: http_kwargs["headers"] = self.config.headers - + self._http_context = streamablehttp_client(**http_kwargs) self._read_stream, self._write_stream, self._get_session_id = await self._http_context.__aenter__() - + self._session_context = ClientSession(self._read_stream, self._write_stream) self._session = await self._session_context.__aenter__() - + await self._session.initialize() await self._fetch_tools() - + self._connected = True logger.info(f"[{self.server_name}] HTTP 连接成功,发现 {len(self._tools)} 个工具") return True - + except Exception as e: logger.error(f"[{self.server_name}] HTTP 连接失败: {e}") import traceback + logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") await self._cleanup() return False @@ -559,59 +569,56 @@ class MCPClientSession: """获取 MCP 服务器的工具列表""" if not self._session: return - + try: result = await self._session.list_tools() self._tools = [] - + for tool in result.tools: tool_info = MCPToolInfo( name=tool.name, description=tool.description or f"MCP tool: {tool.name}", - input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {}, - server_name=self.server_name + input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {}, + server_name=self.server_name, ) self._tools.append(tool_info) # 初始化工具统计 if tool.name not in self._tool_stats: self._tool_stats[tool.name] = ToolCallStats(tool_key=tool.name) logger.debug(f"[{self.server_name}] 发现工具: {tool.name}") - + except Exception as e: logger.error(f"[{self.server_name}] 获取工具列表失败: {e}") self._tools = [] async def fetch_resources(self) -> bool: """v1.2.0: 获取 MCP 服务器的资源列表 - + Returns: bool: 是否成功获取(服务器不支持时返回 False) """ if not self._session: return False - + try: - result = await asyncio.wait_for( - self._session.list_resources(), - timeout=self.call_timeout - ) + result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout) self._resources = [] - + for resource in result.resources: resource_info = MCPResourceInfo( uri=str(resource.uri), name=resource.name or str(resource.uri), description=resource.description or "", - mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None, - server_name=self.server_name + mime_type=resource.mimeType if hasattr(resource, "mimeType") else None, + server_name=self.server_name, ) self._resources.append(resource_info) logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}") - + self._supports_resources = True logger.info(f"[{self.server_name}] 获取到 {len(self._resources)} 个资源") return True - + except Exception as e: # 服务器可能不支持 resources,这不是错误 error_str = str(e).lower() @@ -625,44 +632,43 @@ class MCPClientSession: async def fetch_prompts(self) -> bool: """v1.2.0: 获取 MCP 服务器的提示模板列表 - + Returns: bool: 是否成功获取(服务器不支持时返回 False) """ if not self._session: return False - + try: - result = await asyncio.wait_for( - self._session.list_prompts(), - timeout=self.call_timeout - ) + result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout) self._prompts = [] - + for prompt in result.prompts: # 解析参数 arguments = [] - if hasattr(prompt, 'arguments') and prompt.arguments: + if hasattr(prompt, "arguments") and prompt.arguments: for arg in prompt.arguments: - arguments.append({ - "name": arg.name, - "description": arg.description or "", - "required": arg.required if hasattr(arg, 'required') else False, - }) - + arguments.append( + { + "name": arg.name, + "description": arg.description or "", + "required": arg.required if hasattr(arg, "required") else False, + } + ) + prompt_info = MCPPromptInfo( name=prompt.name, description=prompt.description or f"MCP prompt: {prompt.name}", arguments=arguments, - server_name=self.server_name + server_name=self.server_name, ) self._prompts.append(prompt_info) logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}") - + self._supports_prompts = True logger.info(f"[{self.server_name}] 获取到 {len(self._prompts)} 个提示模板") return True - + except Exception as e: # 服务器可能不支持 prompts,这不是错误 error_str = str(e).lower() @@ -676,45 +682,35 @@ class MCPClientSession: async def read_resource(self, uri: str) -> MCPCallResult: """v1.2.0: 读取指定资源的内容 - + Args: uri: 资源 URI - + Returns: MCPCallResult: 包含资源内容的结果 """ start_time = time.time() - + if not self._connected or not self._session: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {self.server_name} 未连接" - ) - + return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接") + if not self._supports_resources: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {self.server_name} 不支持 Resources 功能" - ) - + return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能") + try: - result = await asyncio.wait_for( - self._session.read_resource(uri), - timeout=self.call_timeout - ) - + result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout) + duration_ms = (time.time() - start_time) * 1000 - + # 处理返回内容 content_parts = [] for content in result.contents: - if hasattr(content, 'text'): + if hasattr(content, "text"): content_parts.append(content.text) - elif hasattr(content, 'blob'): + elif hasattr(content, "blob"): # 二进制数据,返回 base64 或提示 import base64 + blob_data = content.blob if len(blob_data) < 10000: # 小于 10KB 返回 base64 content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}") @@ -722,117 +718,85 @@ class MCPClientSession: content_parts.append(f"[二进制数据: {len(blob_data)} bytes]") else: content_parts.append(str(content)) - + return MCPCallResult( - success=True, - content="\n".join(content_parts) if content_parts else "", - duration_ms=duration_ms + success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms ) - + except asyncio.TimeoutError: duration_ms = (time.time() - start_time) * 1000 return MCPCallResult( - success=False, - content=None, - error=f"读取资源超时({self.call_timeout}秒)", - duration_ms=duration_ms + success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms ) except Exception as e: duration_ms = (time.time() - start_time) * 1000 logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}") - return MCPCallResult( - success=False, - content=None, - error=str(e), - duration_ms=duration_ms - ) + return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms) async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult: """v1.2.0: 获取提示模板的内容 - + Args: name: 提示模板名称 arguments: 模板参数 - + Returns: MCPCallResult: 包含提示内容的结果 """ start_time = time.time() - + if not self._connected or not self._session: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {self.server_name} 未连接" - ) - + return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接") + if not self._supports_prompts: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {self.server_name} 不支持 Prompts 功能" - ) - + return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能") + try: result = await asyncio.wait_for( - self._session.get_prompt(name, arguments=arguments or {}), - timeout=self.call_timeout + self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout ) - + duration_ms = (time.time() - start_time) * 1000 - + # 处理返回的消息 messages = [] for msg in result.messages: - role = msg.role if hasattr(msg, 'role') else "unknown" + role = msg.role if hasattr(msg, "role") else "unknown" content_text = "" - if hasattr(msg, 'content'): - if hasattr(msg.content, 'text'): + if hasattr(msg, "content"): + if hasattr(msg.content, "text"): content_text = msg.content.text elif isinstance(msg.content, str): content_text = msg.content else: content_text = str(msg.content) messages.append(f"[{role}]: {content_text}") - + return MCPCallResult( - success=True, - content="\n\n".join(messages) if messages else "", - duration_ms=duration_ms + success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms ) - + except asyncio.TimeoutError: duration_ms = (time.time() - start_time) * 1000 return MCPCallResult( - success=False, - content=None, - error=f"获取提示模板超时({self.call_timeout}秒)", - duration_ms=duration_ms + success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms ) except Exception as e: duration_ms = (time.time() - start_time) * 1000 logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}") - return MCPCallResult( - success=False, - content=None, - error=str(e), - duration_ms=duration_ms - ) - + return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms) + async def check_health(self) -> bool: """检查连接健康状态(心跳检测) - + 通过调用 list_tools 来验证连接是否正常 """ if not self._connected or not self._session: return False - + try: # 使用 list_tools 作为心跳检测 - await asyncio.wait_for( - self._session.list_tools(), - timeout=10.0 - ) + await asyncio.wait_for(self._session.list_tools(), timeout=10.0) self.stats.record_heartbeat() return True except Exception as e: @@ -841,25 +805,20 @@ class MCPClientSession: self._connected = False self.stats.record_disconnect() return False - + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> MCPCallResult: """调用 MCP 工具""" start_time = time.time() - + # v1.7.0: 断路器检查 can_execute, reject_reason = self._circuit_breaker.can_execute() if not can_execute: - return MCPCallResult( - success=False, - content=None, - error=f"⚡ {reject_reason}", - circuit_broken=True - ) - + return MCPCallResult(success=False, content=None, error=f"⚡ {reject_reason}", circuit_broken=True) + # 半开状态下增加试探计数 if self._circuit_breaker.state == CircuitState.HALF_OPEN: self._circuit_breaker.half_open_calls += 1 - + if not self._connected or not self._session: error_msg = f"服务器 {self.server_name} 未连接" # 记录失败 @@ -867,38 +826,37 @@ class MCPClientSession: self._tool_stats[tool_name].record_call(False, 0, error_msg) self._circuit_breaker.record_failure() return MCPCallResult(success=False, content=None, error=error_msg) - + try: result = await asyncio.wait_for( - self._session.call_tool(tool_name, arguments=arguments), - timeout=self.call_timeout + self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout ) - + duration_ms = (time.time() - start_time) * 1000 - + # 处理返回内容 content_parts = [] for content in result.content: - if hasattr(content, 'text'): + if hasattr(content, "text"): content_parts.append(content.text) - elif hasattr(content, 'data'): + elif hasattr(content, "data"): content_parts.append(f"[二进制数据: {len(content.data)} bytes]") else: content_parts.append(str(content)) - + # 记录成功 if tool_name in self._tool_stats: self._tool_stats[tool_name].record_call(True, duration_ms) - + # v1.7.0: 断路器记录成功 self._circuit_breaker.record_success() - + return MCPCallResult( success=True, content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)", - duration_ms=duration_ms + duration_ms=duration_ms, ) - + except asyncio.TimeoutError: duration_ms = (time.time() - start_time) * 1000 error_msg = f"工具调用超时({self.call_timeout}秒)" @@ -907,7 +865,7 @@ class MCPClientSession: # v1.7.0: 断路器记录失败 self._circuit_breaker.record_failure() return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms) - + except Exception as e: duration_ms = (time.time() - start_time) * 1000 error_msg = str(e) @@ -928,7 +886,7 @@ class MCPClientSession: if self._connected: self.stats.record_disconnect() await self._cleanup() - + async def _cleanup(self) -> None: """清理资源""" self._connected = False @@ -937,31 +895,31 @@ class MCPClientSession: self._prompts = [] # v1.2.0 self._supports_resources = False # v1.2.0 self._supports_prompts = False # v1.2.0 - + try: - if hasattr(self, '_session_context') and self._session_context: + if hasattr(self, "_session_context") and self._session_context: await self._session_context.__aexit__(None, None, None) except Exception as e: logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}") - + try: - if hasattr(self, '_stdio_context') and self._stdio_context: + if hasattr(self, "_stdio_context") and self._stdio_context: await self._stdio_context.__aexit__(None, None, None) except Exception as e: logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}") - + try: - if hasattr(self, '_http_context') and self._http_context: + if hasattr(self, "_http_context") and self._http_context: await self._http_context.__aexit__(None, None, None) except Exception as e: logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}") - + try: - if hasattr(self, '_sse_context') and self._sse_context: + if hasattr(self, "_sse_context") and self._sse_context: await self._sse_context.__aexit__(None, None, None) except Exception as e: logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}") - + self._session = None self._session_context = None self._stdio_context = None @@ -969,27 +927,27 @@ class MCPClientSession: self._sse_context = None self._read_stream = None self._write_stream = None - + logger.debug(f"[{self.server_name}] 连接已关闭") class MCPClientManager: """MCP 客户端管理器,管理多个 MCP 服务器连接 - + 功能: - 管理多个 MCP 服务器连接 - 心跳检测和自动重连 - 调用统计 """ - + _instance: Optional["MCPClientManager"] = None - + def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - + def __init__(self): if self._initialized: return @@ -1000,14 +958,14 @@ class MCPClientManager: self._all_prompts: Dict[str, Tuple[MCPPromptInfo, MCPClientSession]] = {} # v1.2.0 self._settings: Dict[str, Any] = {} self._lock = asyncio.Lock() - + # 心跳检测任务 self._heartbeat_task: Optional[asyncio.Task] = None self._heartbeat_running = False - + # 状态变化回调 self._on_status_change: Optional[callable] = None - + # 全局统计 self._global_stats = { "total_tool_calls": 0, @@ -1015,15 +973,15 @@ class MCPClientManager: "failed_calls": 0, "start_time": time.time(), } - + def configure(self, settings: Dict[str, Any]) -> None: """配置管理器""" self._settings = settings - + def set_status_change_callback(self, callback: callable) -> None: """设置状态变化回调函数""" self._on_status_change = callback - + def _notify_status_change(self) -> None: """通知状态变化""" if self._on_status_change: @@ -1031,27 +989,27 @@ class MCPClientManager: self._on_status_change() except Exception as e: logger.debug(f"状态变化回调出错: {e}") - + @property def all_tools(self) -> Dict[str, Tuple[MCPToolInfo, MCPClientSession]]: """获取所有已注册的工具""" return self._all_tools.copy() - + @property def all_resources(self) -> Dict[str, Tuple[MCPResourceInfo, MCPClientSession]]: """v1.2.0: 获取所有已注册的资源""" return self._all_resources.copy() - + @property def all_prompts(self) -> Dict[str, Tuple[MCPPromptInfo, MCPClientSession]]: """v1.2.0: 获取所有已注册的提示模板""" return self._all_prompts.copy() - + @property def connected_servers(self) -> List[str]: """获取已连接的服务器列表""" return [name for name, client in self._clients.items() if client.is_connected] - + @property def disconnected_servers(self) -> List[str]: """获取已断开的服务器列表""" @@ -1063,36 +1021,38 @@ class MCPClientManager: if config.name in self._clients: logger.warning(f"服务器 {config.name} 已存在") return False - + call_timeout = self._settings.get("call_timeout", 60.0) client = MCPClientSession(config, call_timeout) self._clients[config.name] = client - + if not config.enabled: logger.info(f"服务器 {config.name} 已添加但未启用") return True - + # 尝试连接 retry_attempts = self._settings.get("retry_attempts", 3) retry_interval = self._settings.get("retry_interval", 5.0) - + for attempt in range(1, retry_attempts + 1): if await client.connect(): self._register_tools(client) return True - + if attempt < retry_attempts: - logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})") + logger.warning( + f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})" + ) await asyncio.sleep(retry_interval) - + logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})") # 连接失败,但保留在 _clients 中以便后续重连 return False - + def _register_tools(self, client: MCPClientSession) -> None: """注册客户端的工具""" tool_prefix = self._settings.get("tool_prefix", "mcp") - + for tool in client.tools: if tool.name.startswith(f"{tool_prefix}_{client.server_name}_"): tool_key = tool.name @@ -1100,12 +1060,12 @@ class MCPClientManager: tool_key = f"{tool_prefix}_{client.server_name}_{tool.name}" self._all_tools[tool_key] = (tool, client) logger.debug(f"注册 MCP 工具: {tool_key}") - + def _unregister_tools(self, server_name: str) -> List[str]: """注销服务器的工具,返回被注销的工具键列表""" tool_prefix = self._settings.get("tool_prefix", "mcp") prefix = f"{tool_prefix}_{server_name}_" - + keys_to_remove = [k for k in self._all_tools.keys() if k.startswith(prefix)] for key in keys_to_remove: del self._all_tools[key] @@ -1115,7 +1075,7 @@ class MCPClientManager: def _register_resources(self, client: MCPClientSession) -> None: """v1.2.0: 注册客户端的资源""" tool_prefix = self._settings.get("tool_prefix", "mcp") - + for resource in client.resources: # 资源键格式: mcp_{server}_{uri_safe_name} # 将 URI 转换为安全的键名 @@ -1123,12 +1083,12 @@ class MCPClientManager: resource_key = f"{tool_prefix}_{client.server_name}_res_{safe_uri}" self._all_resources[resource_key] = (resource, client) logger.debug(f"注册 MCP 资源: {resource_key}") - + def _unregister_resources(self, server_name: str) -> List[str]: """v1.2.0: 注销服务器的资源""" tool_prefix = self._settings.get("tool_prefix", "mcp") prefix = f"{tool_prefix}_{server_name}_res_" - + keys_to_remove = [k for k in self._all_resources.keys() if k.startswith(prefix)] for key in keys_to_remove: del self._all_resources[key] @@ -1138,56 +1098,56 @@ class MCPClientManager: def _register_prompts(self, client: MCPClientSession) -> None: """v1.2.0: 注册客户端的提示模板""" tool_prefix = self._settings.get("tool_prefix", "mcp") - + for prompt in client.prompts: prompt_key = f"{tool_prefix}_{client.server_name}_prompt_{prompt.name}" self._all_prompts[prompt_key] = (prompt, client) logger.debug(f"注册 MCP 提示模板: {prompt_key}") - + def _unregister_prompts(self, server_name: str) -> List[str]: """v1.2.0: 注销服务器的提示模板""" tool_prefix = self._settings.get("tool_prefix", "mcp") prefix = f"{tool_prefix}_{server_name}_prompt_" - + keys_to_remove = [k for k in self._all_prompts.keys() if k.startswith(prefix)] for key in keys_to_remove: del self._all_prompts[key] logger.debug(f"注销 MCP 提示模板: {key}") return keys_to_remove - + async def remove_server(self, server_name: str) -> bool: """移除 MCP 服务器""" async with self._lock: if server_name not in self._clients: return False - + client = self._clients[server_name] await client.disconnect() self._unregister_tools(server_name) self._unregister_resources(server_name) # v1.2.0 self._unregister_prompts(server_name) # v1.2.0 del self._clients[server_name] - + logger.info(f"服务器 {server_name} 已移除") return True - + async def reconnect_server(self, server_name: str) -> bool: """重新连接服务器""" if server_name not in self._clients: return False - + client = self._clients[server_name] - + async with self._lock: self._unregister_tools(server_name) self._unregister_resources(server_name) # v1.2.0 self._unregister_prompts(server_name) # v1.2.0 await client.disconnect() - + # 尝试重连 retry_attempts = self._settings.get("retry_attempts", 3) retry_interval = self._settings.get("retry_interval", 5.0) - + for attempt in range(1, retry_attempts + 1): if await client.connect(): async with self._lock: @@ -1202,46 +1162,42 @@ class MCPClientManager: client.stats.record_reconnect() logger.info(f"服务器 {server_name} 重连成功") return True - + if attempt < retry_attempts: logger.warning(f"服务器 {server_name} 重连失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})") await asyncio.sleep(retry_interval) - + logger.error(f"服务器 {server_name} 重连失败") return False async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult: """调用 MCP 工具""" if tool_key not in self._all_tools: - return MCPCallResult( - success=False, - content=None, - error=f"工具 {tool_key} 不存在" - ) - + return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在") + tool_info, client = self._all_tools[tool_key] - + # 更新全局统计 self._global_stats["total_tool_calls"] += 1 - + result = await client.call_tool(tool_info.name, arguments) - + if result.success: self._global_stats["successful_calls"] += 1 else: self._global_stats["failed_calls"] += 1 - + return result async def fetch_resources_for_server(self, server_name: str) -> bool: """v1.2.0: 获取指定服务器的资源列表""" if server_name not in self._clients: return False - + client = self._clients[server_name] if not client.is_connected: return False - + success = await client.fetch_resources() if success: async with self._lock: @@ -1252,11 +1208,11 @@ class MCPClientManager: """v1.2.0: 获取指定服务器的提示模板列表""" if server_name not in self._clients: return False - + client = self._clients[server_name] if not client.is_connected: return False - + success = await client.fetch_prompts() if success: async with self._lock: @@ -1265,7 +1221,7 @@ class MCPClientManager: async def read_resource(self, uri: str, server_name: Optional[str] = None) -> MCPCallResult: """v1.2.0: 读取资源内容 - + Args: uri: 资源 URI server_name: 指定服务器名称(可选,不指定则自动查找) @@ -1273,36 +1229,29 @@ class MCPClientManager: # 如果指定了服务器 if server_name: if server_name not in self._clients: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {server_name} 不存在" - ) + return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在") client = self._clients[server_name] return await client.read_resource(uri) - + # 自动查找拥有该资源的服务器 for resource_key, (resource_info, client) in self._all_resources.items(): if resource_info.uri == uri: return await client.read_resource(uri) - + # 尝试在所有支持 resources 的服务器上查找 for client in self._clients.values(): if client.is_connected and client.supports_resources: result = await client.read_resource(uri) if result.success: return result - - return MCPCallResult( - success=False, - content=None, - error=f"未找到资源: {uri}" - ) - async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None, - server_name: Optional[str] = None) -> MCPCallResult: + return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}") + + async def get_prompt( + self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None + ) -> MCPCallResult: """v1.2.0: 获取提示模板内容 - + Args: name: 提示模板名称 arguments: 模板参数 @@ -1311,42 +1260,34 @@ class MCPClientManager: # 如果指定了服务器 if server_name: if server_name not in self._clients: - return MCPCallResult( - success=False, - content=None, - error=f"服务器 {server_name} 不存在" - ) + return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在") client = self._clients[server_name] return await client.get_prompt(name, arguments) - + # 自动查找拥有该提示模板的服务器 for prompt_key, (prompt_info, client) in self._all_prompts.items(): if prompt_info.name == name: return await client.get_prompt(name, arguments) - - return MCPCallResult( - success=False, - content=None, - error=f"未找到提示模板: {name}" - ) - + + return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}") + # ==================== 心跳检测 ==================== - + async def start_heartbeat(self) -> None: """启动心跳检测任务""" if self._heartbeat_running: logger.warning("心跳检测任务已在运行") return - + heartbeat_enabled = self._settings.get("heartbeat_enabled", True) if not heartbeat_enabled: logger.info("心跳检测已禁用") return - + self._heartbeat_running = True self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) logger.info("心跳检测任务已启动") - + async def stop_heartbeat(self) -> None: """停止心跳检测任务""" self._heartbeat_running = False @@ -1358,52 +1299,52 @@ class MCPClientManager: pass self._heartbeat_task = None logger.info("心跳检测任务已停止") - + async def _heartbeat_loop(self) -> None: """心跳检测循环(v1.5.2: 智能心跳间隔)""" base_interval = self._settings.get("heartbeat_interval", 60.0) auto_reconnect = self._settings.get("auto_reconnect", True) max_reconnect_attempts = self._settings.get("max_reconnect_attempts", 3) - + # v1.5.2: 智能心跳配置 adaptive_enabled = self._settings.get("heartbeat_adaptive", True) max_multiplier = self._settings.get("heartbeat_max_multiplier", 3.0) - + # 每个服务器独立的心跳间隔(根据稳定性动态调整) server_intervals: Dict[str, float] = {} min_interval = max(base_interval * 0.5, 30.0) # 最小间隔 max_interval = base_interval * max_multiplier # 最大间隔 - + mode_str = "智能" if adaptive_enabled else "固定" logger.info(f"心跳检测循环启动,{mode_str}模式,基准间隔: {base_interval}秒") - + while self._heartbeat_running: try: # 使用最小的服务器间隔作为循环间隔 current_interval = min(server_intervals.values()) if server_intervals else base_interval current_interval = max(current_interval, min_interval) - + await asyncio.sleep(current_interval) - + if not self._heartbeat_running: break - + current_time = time.time() - + # 检查所有已启用的服务器 for server_name, client in list(self._clients.items()): if not client.config.enabled: continue - + # 初始化服务器间隔 if server_name not in server_intervals: server_intervals[server_name] = base_interval - + # 检查是否到达该服务器的心跳时间 last_heartbeat = client.stats.last_heartbeat_time or 0 if current_time - last_heartbeat < server_intervals[server_name] * 0.9: continue # 还没到心跳时间 - + if client.is_connected: # 检查健康状态 healthy = await client.check_health() @@ -1435,71 +1376,73 @@ class MCPClientManager: # 达到最大重连次数,降低检测频率 server_intervals[server_name] = max_interval logger.debug(f"[{server_name}] 已达最大重连次数,降低检测频率") - + except asyncio.CancelledError: break except Exception as e: logger.error(f"心跳检测循环出错: {e}") await asyncio.sleep(5) - + async def _try_reconnect(self, server_name: str, max_attempts: int) -> bool: """尝试重连服务器""" client = self._clients.get(server_name) if not client: return False - + if client.stats.consecutive_failures >= max_attempts: logger.warning(f"[{server_name}] 连续失败次数已达上限 ({max_attempts}),暂停重连") return False - + logger.info(f"[{server_name}] 尝试重连 (失败次数: {client.stats.consecutive_failures}/{max_attempts})") - + success = await self.reconnect_server(server_name) if not success: client.stats.record_failure() - + self._notify_status_change() # 重连后更新状态 return success # ==================== 统计和状态 ==================== - + def get_tool_stats(self, tool_key: str) -> Optional[Dict[str, Any]]: """获取指定工具的统计信息""" if tool_key not in self._all_tools: return None - + tool_info, client = self._all_tools[tool_key] stats = client.get_tool_stats(tool_info.name) return stats.to_dict() if stats else None - + def get_all_stats(self) -> Dict[str, Any]: """获取所有统计信息""" server_stats = {} tool_stats = {} - + for server_name, client in self._clients.items(): server_stats[server_name] = client.stats.to_dict() for tool_name, stats in client.get_all_tool_stats().items(): full_key = f"{self._settings.get('tool_prefix', 'mcp')}_{server_name}_{tool_name}" tool_stats[full_key] = stats.to_dict() - + uptime = time.time() - self._global_stats["start_time"] - + return { "global": { **self._global_stats, "uptime_seconds": round(uptime, 2), - "calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0, + "calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) + if uptime > 0 + else 0, }, "servers": server_stats, "tools": tool_stats, } - + async def shutdown(self) -> None: """关闭所有连接""" # 停止心跳检测 await self.stop_heartbeat() - + async with self._lock: for client in self._clients.values(): await client.disconnect() @@ -1508,7 +1451,7 @@ class MCPClientManager: self._all_resources.clear() # v1.2.0 self._all_prompts.clear() # v1.2.0 logger.info("MCP 客户端管理器已关闭") - + def get_status(self) -> Dict[str, Any]: """获取状态信息""" return { diff --git a/plugins/MaiBot_MCPBridgePlugin/plugin.py b/plugins/MaiBot_MCPBridgePlugin/plugin.py index aae3f4bd..4ad38d0f 100644 --- a/plugins/MaiBot_MCPBridgePlugin/plugin.py +++ b/plugins/MaiBot_MCPBridgePlugin/plugin.py @@ -84,7 +84,7 @@ from .mcp_client import ( TransportType, mcp_manager, ) -from .config_converter import ConfigConverter, ConversionResult +from .config_converter import ConfigConverter logger = get_logger("mcp_bridge_plugin") @@ -93,9 +93,11 @@ logger = get_logger("mcp_bridge_plugin") # v1.4.0: 调用链路追踪 # ============================================================================ + @dataclass class ToolCallRecord: """工具调用记录""" + call_id: str timestamp: float tool_name: str @@ -115,46 +117,46 @@ class ToolCallRecord: class ToolCallTracer: """工具调用追踪器""" - + def __init__(self, max_records: int = 100): self._records: deque[ToolCallRecord] = deque(maxlen=max_records) self._enabled: bool = True self._log_enabled: bool = False self._log_path: Optional[Path] = None - + def configure(self, enabled: bool, max_records: int, log_enabled: bool, log_path: Optional[Path] = None) -> None: """配置追踪器""" self._enabled = enabled self._records = deque(self._records, maxlen=max_records) self._log_enabled = log_enabled self._log_path = log_path - + def record(self, record: ToolCallRecord) -> None: """添加调用记录""" if not self._enabled: return - + self._records.append(record) - + if self._log_enabled and self._log_path: self._write_to_log(record) - + def get_recent(self, n: int = 10) -> List[ToolCallRecord]: """获取最近 N 条记录""" return list(self._records)[-n:] - + def get_by_tool(self, tool_name: str) -> List[ToolCallRecord]: """按工具名筛选记录""" return [r for r in self._records if r.tool_name == tool_name] - + def get_by_server(self, server_name: str) -> List[ToolCallRecord]: """按服务器名筛选记录""" return [r for r in self._records if r.server_name == server_name] - + def clear(self) -> None: """清空记录""" self._records.clear() - + def _write_to_log(self, record: ToolCallRecord) -> None: """写入 JSONL 日志文件""" try: @@ -164,7 +166,7 @@ class ToolCallTracer: f.write(json.dumps(asdict(record), ensure_ascii=False) + "\n") except Exception as e: logger.warning(f"写入追踪日志失败: {e}") - + @property def total_records(self) -> int: return len(self._records) @@ -178,9 +180,11 @@ tool_call_tracer = ToolCallTracer() # v1.4.0: 工具调用缓存 # ============================================================================ + @dataclass class CacheEntry: """缓存条目""" + tool_name: str args_hash: str result: str @@ -191,7 +195,7 @@ class CacheEntry: class ToolCallCache: """工具调用缓存(LRU)""" - + def __init__(self, max_entries: int = 200, ttl: int = 300): self._cache: OrderedDict[str, CacheEntry] = OrderedDict() self._max_entries = max_entries @@ -199,54 +203,54 @@ class ToolCallCache: self._enabled = False self._exclude_patterns: List[str] = [] self._stats = {"hits": 0, "misses": 0} - + def configure(self, enabled: bool, ttl: int, max_entries: int, exclude_tools: str) -> None: """配置缓存""" self._enabled = enabled self._ttl = ttl self._max_entries = max_entries self._exclude_patterns = [p.strip() for p in exclude_tools.strip().split("\n") if p.strip()] - + def get(self, tool_name: str, args: Dict) -> Optional[str]: """获取缓存""" if not self._enabled: return None - + if self._is_excluded(tool_name): return None - + key = self._generate_key(tool_name, args) - + if key not in self._cache: self._stats["misses"] += 1 return None - + entry = self._cache[key] - + # 检查是否过期 if time.time() > entry.expires_at: del self._cache[key] self._stats["misses"] += 1 return None - + # LRU: 移到末尾 self._cache.move_to_end(key) entry.hit_count += 1 self._stats["hits"] += 1 - + return entry.result - + def set(self, tool_name: str, args: Dict, result: str) -> None: """设置缓存""" if not self._enabled: return - + if self._is_excluded(tool_name): return - + key = self._generate_key(tool_name, args) now = time.time() - + entry = CacheEntry( tool_name=tool_name, args_hash=key, @@ -254,7 +258,7 @@ class ToolCallCache: created_at=now, expires_at=now + self._ttl, ) - + # 如果已存在,更新 if key in self._cache: self._cache[key] = entry @@ -263,25 +267,25 @@ class ToolCallCache: # 检查容量 self._evict_if_needed() self._cache[key] = entry - + def clear(self) -> None: """清空缓存""" self._cache.clear() self._stats = {"hits": 0, "misses": 0} - + def _generate_key(self, tool_name: str, args: Dict) -> str: """生成缓存键""" args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) content = f"{tool_name}:{args_str}" return hashlib.md5(content.encode()).hexdigest() - + def _is_excluded(self, tool_name: str) -> bool: """检查是否在排除列表中""" for pattern in self._exclude_patterns: if fnmatch.fnmatch(tool_name, pattern): return True return False - + def _evict_if_needed(self) -> None: """必要时淘汰条目""" # 先清理过期的 @@ -289,11 +293,11 @@ class ToolCallCache: expired_keys = [k for k, v in self._cache.items() if now > v.expires_at] for k in expired_keys: del self._cache[k] - + # LRU 淘汰 while len(self._cache) >= self._max_entries: self._cache.popitem(last=False) - + def get_stats(self) -> Dict[str, Any]: """获取缓存统计""" total = self._stats["hits"] + self._stats["misses"] @@ -317,16 +321,17 @@ tool_call_cache = ToolCallCache() # v1.4.0: 工具权限控制 # ============================================================================ + class PermissionChecker: """工具权限检查器""" - + def __init__(self): self._enabled = False self._default_mode = "allow_all" # allow_all 或 deny_all self._rules: List[Dict] = [] self._quick_deny_groups: set = set() self._quick_allow_users: set = set() - + def configure( self, enabled: bool, @@ -338,61 +343,61 @@ class PermissionChecker: """配置权限检查器""" self._enabled = enabled self._default_mode = default_mode if default_mode in ("allow_all", "deny_all") else "allow_all" - + # 解析快捷配置 self._quick_deny_groups = {g.strip() for g in quick_deny_groups.strip().split("\n") if g.strip()} self._quick_allow_users = {u.strip() for u in quick_allow_users.strip().split("\n") if u.strip()} - + try: self._rules = json.loads(rules_json) if rules_json.strip() else [] except json.JSONDecodeError as e: logger.warning(f"权限规则 JSON 解析失败: {e}") self._rules = [] - + def check(self, tool_name: str, chat_id: str, user_id: str, is_group: bool) -> bool: """检查权限 - + Args: tool_name: 工具名称 chat_id: 聊天 ID(群号或私聊 ID) user_id: 用户 ID is_group: 是否为群聊 - + Returns: True 表示允许,False 表示拒绝 """ if not self._enabled: return True - + # 快捷配置优先级最高 # 1. 管理员白名单(始终允许) if user_id and user_id in self._quick_allow_users: return True - + # 2. 禁用群列表(始终拒绝) if is_group and chat_id and chat_id in self._quick_deny_groups: return False - + # 查找匹配的规则 for rule in self._rules: tool_pattern = rule.get("tool", "") if not self._match_tool(tool_pattern, tool_name): continue - + # 找到匹配的规则 mode = rule.get("mode", "") allowed = rule.get("allowed", []) denied = rule.get("denied", []) - + # 构建当前上下文的 ID 列表 context_ids = self._build_context_ids(chat_id, user_id, is_group) - + # 检查 denied 列表(优先级最高) if denied: for ctx_id in context_ids: if self._match_id_list(denied, ctx_id): return False - + # 检查 allowed 列表 if allowed: for ctx_id in context_ids: @@ -401,41 +406,41 @@ class PermissionChecker: # 如果是 whitelist 模式且不在 allowed 中,拒绝 if mode == "whitelist": return False - + # 规则匹配但没有明确允许/拒绝,继续检查下一条规则 - + # 没有匹配的规则,使用默认模式 return self._default_mode == "allow_all" - + def _match_tool(self, pattern: str, tool_name: str) -> bool: """工具名通配符匹配""" if not pattern: return False return fnmatch.fnmatch(tool_name, pattern) - + def _build_context_ids(self, chat_id: str, user_id: str, is_group: bool) -> List[str]: """构建上下文 ID 列表""" ids = [] - + # 用户级别(任何场景生效) if user_id: ids.append(f"qq:{user_id}:user") - + # 场景级别 if is_group and chat_id: ids.append(f"qq:{chat_id}:group") elif chat_id: ids.append(f"qq:{chat_id}:private") - + return ids - + def _match_id_list(self, id_list: List[str], context_id: str) -> bool: """检查 ID 是否在列表中""" for rule_id in id_list: if fnmatch.fnmatch(context_id, rule_id): return True return False - + def get_rules_for_tool(self, tool_name: str) -> List[Dict]: """获取特定工具的权限规则""" return [r for r in self._rules if self._match_tool(r.get("tool", ""), tool_name)] @@ -449,6 +454,7 @@ permission_checker = PermissionChecker() # 工具类型转换 # ============================================================================ + def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType: """将 JSON Schema 类型转换为 MaiBot 的 ToolParamType""" type_mapping = { @@ -462,34 +468,36 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType: return type_mapping.get(json_type, ToolParamType.STRING) -def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]: +def parse_mcp_parameters( + input_schema: Dict[str, Any], +) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]: """解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式""" parameters = [] - + if not input_schema: return parameters - + properties = input_schema.get("properties", {}) required = input_schema.get("required", []) - + for param_name, param_info in properties.items(): json_type = param_info.get("type", "string") param_type = convert_json_type_to_tool_param_type(json_type) description = param_info.get("description", f"参数 {param_name}") - + if json_type == "array": description = f"{description} (JSON 数组格式)" elif json_type == "object": description = f"{description} (JSON 对象格式)" - + is_required = param_name in required enum_values = param_info.get("enum") - + if enum_values is not None: enum_values = [str(v) for v in enum_values] - + parameters.append((param_name, param_type, description, is_required, enum_values)) - + return parameters @@ -497,28 +505,29 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa # MCP 工具代理 # ============================================================================ + class MCPToolProxy(BaseTool): """MCP 工具代理基类""" - + name: str = "" description: str = "" parameters: List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]] = [] available_for_llm: bool = True - + _mcp_tool_key: str = "" _mcp_original_name: str = "" _mcp_server_name: str = "" - + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: """执行 MCP 工具调用""" global _plugin_instance - + call_id = str(uuid.uuid4())[:8] start_time = time.time() - + # 移除 MaiBot 内部标记 args = {k: v for k, v in function_args.items() if k != "llm_called"} - + # 解析 JSON 字符串参数 parsed_args = {} for key, value in args.items(): @@ -532,24 +541,21 @@ class MCPToolProxy(BaseTool): parsed_args[key] = value else: parsed_args[key] = value - + # 获取上下文信息 chat_id, user_id, is_group, user_query = self._get_context_info() - + # v1.4.0: 权限检查 if not permission_checker.check(self.name, chat_id, user_id, is_group): logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}") - return { - "name": self.name, - "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用" - } - + return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"} + logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}") - + # v1.4.0: 检查缓存 cache_hit = False cached_result = tool_call_cache.get(self.name, parsed_args) - + if cached_result is not None: cache_hit = True content = cached_result @@ -560,13 +566,13 @@ class MCPToolProxy(BaseTool): else: # 调用 MCP result = await mcp_manager.call_tool(self._mcp_tool_key, parsed_args) - + if result.success: content = result.content raw_result = content success = True error = "" - + # 存入缓存 tool_call_cache.set(self.name, parsed_args, content) else: @@ -575,7 +581,7 @@ class MCPToolProxy(BaseTool): success = False error = result.error logger.warning(f"MCP 工具 {self.name} 调用失败: {result.error}") - + # v1.3.0: 后处理 post_processed = False processed_result = content @@ -585,9 +591,9 @@ class MCPToolProxy(BaseTool): post_processed = True processed_result = processed_content content = processed_content - + duration_ms = (time.time() - start_time) * 1000 - + # v1.4.0: 记录调用追踪 record = ToolCallRecord( call_id=call_id, @@ -607,16 +613,16 @@ class MCPToolProxy(BaseTool): cache_hit=cache_hit, ) tool_call_tracer.record(record) - + return {"name": self.name, "content": content} - + def _get_context_info(self) -> Tuple[str, str, bool, str]: """获取上下文信息""" chat_id = "" user_id = "" is_group = False user_query = "" - + if self.chat_stream and hasattr(self.chat_stream, "context") and self.chat_stream.context: try: ctx = self.chat_stream.context @@ -626,53 +632,53 @@ class MCPToolProxy(BaseTool): user_id = str(ctx.user_id) if ctx.user_id else "" if hasattr(ctx, "is_group"): is_group = bool(ctx.is_group) - + last_message = ctx.get_last_message() if last_message and hasattr(last_message, "processed_plain_text"): user_query = last_message.processed_plain_text or "" except Exception as e: logger.debug(f"获取上下文信息失败: {e}") - + return chat_id, user_id, is_group, user_query async def _post_process_result(self, content: str) -> str: """v1.3.0: 对工具返回结果进行后处理(摘要提炼)""" global _plugin_instance - + if _plugin_instance is None: return content - + settings = _plugin_instance.config.get("settings", {}) - + if not settings.get("post_process_enabled", False): return content - + server_post_config = self._get_server_post_process_config() - + if server_post_config is not None: if not server_post_config.get("enabled", True): return content - + threshold = settings.get("post_process_threshold", 500) if server_post_config and "threshold" in server_post_config: threshold = server_post_config["threshold"] - + content_length = len(content) if content else 0 if content_length <= threshold: return content - + user_query = self._get_context_info()[3] if not user_query: return content - + max_tokens = settings.get("post_process_max_tokens", 500) if server_post_config and "max_tokens" in server_post_config: max_tokens = server_post_config["max_tokens"] - + prompt_template = settings.get("post_process_prompt", "") if server_post_config and "prompt" in server_post_config: prompt_template = server_post_config["prompt"] - + if not prompt_template: prompt_template = """用户问题:{query} @@ -680,13 +686,13 @@ class MCPToolProxy(BaseTool): {result} 请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:""" - + try: prompt = prompt_template.format(query=user_query, result=content) except KeyError as e: logger.warning(f"后处理 prompt 模板格式错误: {e}") return content - + try: processed_content = await self._call_post_process_llm(prompt, max_tokens, settings, server_post_config) if processed_content: @@ -696,14 +702,14 @@ class MCPToolProxy(BaseTool): except Exception as e: logger.error(f"MCP 工具 {self.name} 后处理失败: {e}") return content - + def _get_server_post_process_config(self) -> Optional[Dict[str, Any]]: """获取当前服务器的后处理配置""" global _plugin_instance - + if _plugin_instance is None: return None - + servers_section = _plugin_instance.config.get("servers", {}) if isinstance(servers_section, dict): servers_list = servers_section.get("list", "[]") @@ -718,29 +724,25 @@ class MCPToolProxy(BaseTool): return None else: servers = servers_section if isinstance(servers_section, list) else [] - + for server_conf in servers: if server_conf.get("name") == self._mcp_server_name: return server_conf.get("post_process") - + return None - + async def _call_post_process_llm( - self, - prompt: str, - max_tokens: int, - settings: Dict[str, Any], - server_config: Optional[Dict[str, Any]] + self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]] ) -> Optional[str]: """调用 LLM 进行后处理""" from src.config.config import model_config from src.config.api_ada_configs import TaskConfig from src.llm_models.utils_model import LLMRequest - + model_name = settings.get("post_process_model", "") if server_config and "model" in server_config: model_name = server_config["model"] - + if model_name: task_config = TaskConfig( model_list=[model_name], @@ -750,59 +752,56 @@ class MCPToolProxy(BaseTool): ) else: task_config = model_config.model_task_config.utils - + llm_request = LLMRequest(model_set=task_config, request_type="mcp_post_process") - + response, (reasoning, model_used, _) = await llm_request.generate_response_async( prompt=prompt, max_tokens=max_tokens, temperature=0.3, ) - + return response.strip() if response else None - + def _format_error_message(self, error: str, duration_ms: float) -> str: """格式化友好的错误消息""" if not error: return "工具调用失败(未知错误)" - + error_lower = error.lower() - + if "未连接" in error or "not connected" in error_lower: return f"⚠️ MCP 服务器 [{self._mcp_server_name}] 未连接,请检查服务器状态或等待自动重连" - + if "超时" in error or "timeout" in error_lower: return f"⏱️ 工具调用超时(耗时 {duration_ms:.0f}ms),服务器响应过慢,请稍后重试" - + if "connection" in error_lower and ("closed" in error_lower or "reset" in error_lower): return f"🔌 与 MCP 服务器 [{self._mcp_server_name}] 的连接已断开,正在尝试重连..." - + if "invalid" in error_lower and "argument" in error_lower: return f"❌ 参数错误: {error}" - + return f"❌ 工具调用失败: {error}" - + async def direct_execute(self, **function_args) -> Dict[str, Any]: """直接执行(供其他插件调用)""" return await self.execute(function_args) def create_mcp_tool_class( - tool_key: str, - tool_info: MCPToolInfo, - tool_prefix: str, - disabled: bool = False + tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False ) -> Type[MCPToolProxy]: """根据 MCP 工具信息动态创建 BaseTool 子类""" parameters = parse_mcp_parameters(tool_info.input_schema) - + class_name = f"MCPTool_{tool_info.server_name}_{tool_info.name}".replace("-", "_").replace(".", "_") tool_name = tool_key.replace("-", "_").replace(".", "_") - + description = tool_info.description if not description.endswith(f"[来自 MCP 服务器: {tool_info.server_name}]"): description = f"{description} [来自 MCP 服务器: {tool_info.server_name}]" - + tool_class = type( class_name, (MCPToolProxy,), @@ -814,31 +813,27 @@ def create_mcp_tool_class( "_mcp_tool_key": tool_key, "_mcp_original_name": tool_info.name, "_mcp_server_name": tool_info.server_name, - } + }, ) - + return tool_class class MCPToolRegistry: """MCP 工具注册表""" - + def __init__(self): self._tool_classes: Dict[str, Type[MCPToolProxy]] = {} self._tool_infos: Dict[str, ToolInfo] = {} - + def register_tool( - self, - tool_key: str, - tool_info: MCPToolInfo, - tool_prefix: str, - disabled: bool = False + self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False ) -> Tuple[ToolInfo, Type[MCPToolProxy]]: """注册 MCP 工具""" tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled) - + self._tool_classes[tool_key] = tool_class - + info = ToolInfo( name=tool_class.name, tool_description=tool_class.description, @@ -847,9 +842,9 @@ class MCPToolRegistry: component_type=ComponentType.TOOL, ) self._tool_infos[tool_key] = info - + return info, tool_class - + def unregister_tool(self, tool_key: str) -> bool: """注销工具""" if tool_key in self._tool_classes: @@ -857,11 +852,11 @@ class MCPToolRegistry: del self._tool_infos[tool_key] return True return False - + def get_all_components(self) -> List[Tuple[ComponentInfo, Type]]: """获取所有工具组件""" return [(self._tool_infos[key], self._tool_classes[key]) for key in self._tool_classes.keys()] - + def clear(self) -> None: """清空所有注册""" self._tool_classes.clear() @@ -879,9 +874,10 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None # 内置工具 # ============================================================================ + class MCPReadResourceTool(BaseTool): """v1.2.0: MCP 资源读取工具""" - + name = "mcp_read_resource" description = "读取 MCP 服务器提供的资源内容(如文件、数据库记录等)。使用前请先用 mcp_status 查看可用资源。" parameters = [ @@ -889,28 +885,28 @@ class MCPReadResourceTool(BaseTool): ("server_name", ToolParamType.STRING, "指定服务器名称(可选,不指定则自动查找)", False, None), ] available_for_llm = True - + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: uri = function_args.get("uri", "") server_name = function_args.get("server_name") - + if not uri: return {"name": self.name, "content": "❌ 请提供资源 URI"} - + result = await mcp_manager.read_resource(uri, server_name) - + if result.success: return {"name": self.name, "content": result.content} else: return {"name": self.name, "content": f"❌ 读取资源失败: {result.error}"} - + async def direct_execute(self, **function_args) -> Dict[str, Any]: return await self.execute(function_args) class MCPGetPromptTool(BaseTool): """v1.2.0: MCP 提示模板工具""" - + name = "mcp_get_prompt" description = "获取 MCP 服务器提供的提示模板内容。使用前请先用 mcp_status 查看可用模板。" parameters = [ @@ -919,78 +915,83 @@ class MCPGetPromptTool(BaseTool): ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), ] available_for_llm = True - + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: prompt_name = function_args.get("name", "") arguments_str = function_args.get("arguments", "") server_name = function_args.get("server_name") - + if not prompt_name: return {"name": self.name, "content": "❌ 请提供提示模板名称"} - + arguments = None if arguments_str: try: arguments = json.loads(arguments_str) except json.JSONDecodeError: return {"name": self.name, "content": "❌ 参数格式错误,请使用 JSON 对象格式"} - + result = await mcp_manager.get_prompt(prompt_name, arguments, server_name) - + if result.success: return {"name": self.name, "content": result.content} else: return {"name": self.name, "content": f"❌ 获取提示模板失败: {result.error}"} - + async def direct_execute(self, **function_args) -> Dict[str, Any]: return await self.execute(function_args) class MCPStatusTool(BaseTool): """MCP 状态查询工具""" - + name = "mcp_status" - description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息" + description = ( + "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息" + ) parameters = [ - ("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"]), + ( + "query_type", + ToolParamType.STRING, + "查询类型", + False, + ["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"], + ), ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), ] available_for_llm = True - + async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: query_type = function_args.get("query_type", "status") server_name = function_args.get("server_name") - + result_parts = [] - + if query_type in ("status", "all"): result_parts.append(self._format_status(server_name)) - + if query_type in ("tools", "all"): result_parts.append(self._format_tools(server_name)) - + if query_type in ("resources", "all"): result_parts.append(self._format_resources(server_name)) - + if query_type in ("prompts", "all"): result_parts.append(self._format_prompts(server_name)) - + if query_type in ("stats", "all"): result_parts.append(self._format_stats(server_name)) - + # v1.4.0: 追踪记录 if query_type in ("trace",): result_parts.append(self._format_trace()) - + # v1.4.0: 缓存状态 if query_type in ("cache",): result_parts.append(self._format_cache()) - - return { - "name": self.name, - "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型" - } - + + return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"} + def _format_status(self, server_name: Optional[str] = None) -> str: status = mcp_manager.get_status() lines = ["📊 MCP 桥接插件状态"] @@ -999,24 +1000,24 @@ class MCPStatusTool(BaseTool): lines.append(f" 已断开: {status['disconnected_servers']}") lines.append(f" 可用工具数: {status['total_tools']}") lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}") - + lines.append("\n🔌 服务器详情:") - for name, info in status['servers'].items(): + for name, info in status["servers"].items(): if server_name and name != server_name: continue - status_icon = "✅" if info['connected'] else "❌" - enabled_text = "" if info['enabled'] else " (已禁用)" + status_icon = "✅" if info["connected"] else "❌" + enabled_text = "" if info["enabled"] else " (已禁用)" lines.append(f" {status_icon} {name}{enabled_text}") lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}") - if info['consecutive_failures'] > 0: + if info["consecutive_failures"] > 0: lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次") - + return "\n".join(lines) - + def _format_tools(self, server_name: Optional[str] = None) -> str: tools = mcp_manager.all_tools lines = ["🔧 可用 MCP 工具"] - + by_server: Dict[str, List[str]] = {} for tool_key, (tool_info, _) in tools.items(): if server_name and tool_info.server_name != server_name: @@ -1024,35 +1025,35 @@ class MCPStatusTool(BaseTool): if tool_info.server_name not in by_server: by_server[tool_info.server_name] = [] by_server[tool_info.server_name].append(f" • {tool_key}: {tool_info.description[:50]}...") - + for srv_name, tool_list in by_server.items(): lines.append(f"\n📦 {srv_name} ({len(tool_list)} 个工具):") lines.extend(tool_list) - + if not by_server: lines.append(" (无可用工具)") - + return "\n".join(lines) - + def _format_stats(self, server_name: Optional[str] = None) -> str: stats = mcp_manager.get_all_stats() lines = ["📈 调用统计"] - - g = stats['global'] + + g = stats["global"] lines.append(f" 总调用次数: {g['total_tool_calls']}") lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}") - if g['total_tool_calls'] > 0: - success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100 + if g["total_tool_calls"] > 0: + success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100 lines.append(f" 成功率: {success_rate:.1f}%") lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒") - + return "\n".join(lines) - + def _format_resources(self, server_name: Optional[str] = None) -> str: resources = mcp_manager.all_resources if not resources: return "📦 当前没有可用的 MCP 资源" - + lines = ["📦 可用 MCP 资源"] by_server: Dict[str, List[MCPResourceInfo]] = {} for key, (resource_info, _) in resources.items(): @@ -1061,19 +1062,19 @@ class MCPStatusTool(BaseTool): if resource_info.server_name not in by_server: by_server[resource_info.server_name] = [] by_server[resource_info.server_name].append(resource_info) - + for srv_name, resource_list in by_server.items(): lines.append(f"\n🔌 {srv_name} ({len(resource_list)} 个资源):") for res in resource_list: lines.append(f" • {res.name}: {res.uri}") - + return "\n".join(lines) - + def _format_prompts(self, server_name: Optional[str] = None) -> str: prompts = mcp_manager.all_prompts if not prompts: return "📝 当前没有可用的 MCP 提示模板" - + lines = ["📝 可用 MCP 提示模板"] by_server: Dict[str, List[MCPPromptInfo]] = {} for key, (prompt_info, _) in prompts.items(): @@ -1082,20 +1083,20 @@ class MCPStatusTool(BaseTool): if prompt_info.server_name not in by_server: by_server[prompt_info.server_name] = [] by_server[prompt_info.server_name].append(prompt_info) - + for srv_name, prompt_list in by_server.items(): lines.append(f"\n🔌 {srv_name} ({len(prompt_list)} 个模板):") for prompt in prompt_list: lines.append(f" • {prompt.name}") - + return "\n".join(lines) - + def _format_trace(self) -> str: """v1.4.0: 格式化追踪记录""" records = tool_call_tracer.get_recent(10) if not records: return "🔍 暂无调用追踪记录" - + lines = ["🔍 最近调用追踪记录"] for r in reversed(records): status = "✅" if r.success else "❌" @@ -1104,9 +1105,9 @@ class MCPStatusTool(BaseTool): lines.append(f" {status}{cache}{post} {r.tool_name} ({r.duration_ms:.0f}ms)") if r.error: lines.append(f" 错误: {r.error[:50]}") - + return "\n".join(lines) - + def _format_cache(self) -> str: """v1.4.0: 格式化缓存状态""" stats = tool_call_cache.get_stats() @@ -1117,7 +1118,7 @@ class MCPStatusTool(BaseTool): lines.append(f" 命中: {stats['hits']}, 未命中: {stats['misses']}") lines.append(f" 命中率: {stats['hit_rate']}") return "\n".join(lines) - + async def direct_execute(self, **function_args) -> Dict[str, Any]: return await self.execute(function_args) @@ -1126,6 +1127,7 @@ class MCPStatusTool(BaseTool): # 命令处理 # ============================================================================ + class MCPStatusCommand(BaseCommand): """MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态""" @@ -1140,23 +1142,23 @@ class MCPStatusCommand(BaseCommand): if subcommand == "reconnect": return await self._handle_reconnect(arg) - + # v1.4.0: 追踪命令 if subcommand == "trace": return await self._handle_trace(arg) - + # v1.4.0: 缓存命令 if subcommand == "cache": return await self._handle_cache(arg) - + # v1.4.0: 权限命令 if subcommand == "perm": return await self._handle_perm(arg) - + # v1.6.0: 导出命令 if subcommand == "export": return await self._handle_export(arg) - + # v1.7.0: 工具搜索命令 if subcommand == "search": return await self._handle_search(arg) @@ -1169,7 +1171,7 @@ class MCPStatusCommand(BaseCommand): """查找相似的服务器名称""" name_lower = name.lower() all_servers = list(mcp_manager._clients.keys()) - + # 简单的相似度匹配:包含关系或前缀匹配 similar = [] for srv in all_servers: @@ -1178,7 +1180,7 @@ class MCPStatusCommand(BaseCommand): similar.append(srv) elif srv_lower.startswith(name_lower[:3]) if len(name_lower) >= 3 else False: similar.append(srv) - + return similar[:max_results] async def _handle_reconnect(self, server_name: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: @@ -1212,7 +1214,7 @@ class MCPStatusCommand(BaseCommand): await self.send_text(f"{status} {srv}") return (True, None, True) - + async def _handle_trace(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: """v1.4.0: 处理追踪命令""" if arg and arg.isdigit(): @@ -1225,11 +1227,11 @@ class MCPStatusCommand(BaseCommand): else: # /mcp trace - 最近 10 条 records = tool_call_tracer.get_recent(10) - + if not records: await self.send_text("🔍 暂无调用追踪记录\n\n用法: /mcp trace [数量|工具名]") return (True, None, True) - + lines = [f"🔍 调用追踪记录 ({len(records)} 条)"] lines.append("-" * 30) for i, r in enumerate(reversed(records)): @@ -1243,17 +1245,17 @@ class MCPStatusCommand(BaseCommand): lines.append(f" 错误: {r.error[:50]}") if i < len(records) - 1: lines.append("") - + await self.send_text("\n".join(lines)) return (True, None, True) - + async def _handle_cache(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: """v1.4.0: 处理缓存命令""" if arg == "clear": tool_call_cache.clear() await self.send_text("✅ 缓存已清空") return (True, None, True) - + stats = tool_call_cache.get_stats() lines = ["🗄️ 缓存状态"] lines.append(f"├ 启用: {'是' if stats['enabled'] else '否'}") @@ -1262,22 +1264,22 @@ class MCPStatusCommand(BaseCommand): lines.append(f"├ 命中: {stats['hits']}") lines.append(f"├ 未命中: {stats['misses']}") lines.append(f"└ 命中率: {stats['hit_rate']}") - + await self.send_text("\n".join(lines)) return (True, None, True) - + async def _handle_perm(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: """v1.4.0: 处理权限命令""" global _plugin_instance - + if _plugin_instance is None: await self.send_text("❌ 插件未初始化") return (True, None, True) - + perm_config = _plugin_instance.config.get("permissions", {}) enabled = perm_config.get("perm_enabled", False) default_mode = perm_config.get("perm_default_mode", "allow_all") - + if arg: # 查看特定工具的权限 rules = permission_checker.get_rules_for_tool(arg) @@ -1306,50 +1308,52 @@ class MCPStatusCommand(BaseCommand): lines.append(f"├ 管理员白名单: {allow_count} 人") lines.append(f"└ 高级规则: {len(permission_checker._rules)} 条") await self.send_text("\n".join(lines)) - + return (True, None, True) - + async def _handle_export(self, format_type: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: """v1.6.0: 处理导出命令""" global _plugin_instance - + if _plugin_instance is None: await self.send_text("❌ 插件未初始化") return (True, None, True) - + # 获取当前服务器列表 servers_section = _plugin_instance.config.get("servers", {}) servers_list_str = servers_section.get("list", "[]") if isinstance(servers_section, dict) else "[]" - + try: servers = json.loads(servers_list_str) if servers_list_str.strip() else [] except json.JSONDecodeError: await self.send_text("❌ 当前服务器配置格式错误,无法导出") return (True, None, True) - + if not servers: await self.send_text("📤 当前没有配置任何服务器") return (True, None, True) - + # 确定导出格式 format_type = (format_type or "claude").lower() if format_type not in ("claude", "kiro", "maibot"): format_type = "claude" - + # 导出 try: exported = ConfigConverter.export_to_string(servers, format_type, pretty=True) - - format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get(format_type, format_type) + + format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get( + format_type, format_type + ) lines = [f"📤 导出为 {format_name} 格式 ({len(servers)} 个服务器):"] lines.append("") lines.append(exported) - + await self.send_text("\n".join(lines)) except Exception as e: logger.error(f"导出配置失败: {e}") await self.send_text(f"❌ 导出失败: {str(e)}") - + return (True, None, True) async def _handle_search(self, query: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: @@ -1406,11 +1410,11 @@ class MCPStatusCommand(BaseCommand): for srv_name, tool_list in by_server.items(): lines.append(f"\n📦 {srv_name} ({len(tool_list)} 个):") - + # 单服务器或结果少于 15 个时显示全部 show_all = single_server or len(matched) <= 15 display_limit = len(tool_list) if show_all else 5 - + for tool_key, tool_info in tool_list[:display_limit]: desc = tool_info.description[:40] + "..." if len(tool_info.description) > 40 else tool_info.description lines.append(f" • {tool_key}") @@ -1446,9 +1450,9 @@ class MCPStatusCommand(BaseCommand): cb = info.get("circuit_breaker", {}) cb_state = cb.get("state", "closed") if cb_state == "open": - lines.append(f" ⚡ 断路器熔断中") + lines.append(" ⚡ 断路器熔断中") elif cb_state == "half_open": - lines.append(f" ⚡ 断路器试探中") + lines.append(" ⚡ 断路器试探中") if info["consecutive_failures"] > 0: lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']} 次") @@ -1464,7 +1468,7 @@ class MCPStatusCommand(BaseCommand): # 如果指定了服务器名,显示全部工具;否则折叠显示 show_all = server_name is not None - + for srv, tool_list in by_server.items(): lines.append(f" 📦 {srv} ({len(tool_list)})") if show_all: @@ -1527,13 +1531,13 @@ class MCPImportCommand(BaseCommand): async def execute(self) -> Tuple[bool, Optional[str], bool]: """执行导入命令""" global _plugin_instance - + if _plugin_instance is None: await self.send_text("❌ 插件未初始化") return (True, None, True) - + content = self.matched_groups.get("content", "") - + if not content or not content.strip(): # 显示使用帮助 help_text = """📥 MCP 配置导入 @@ -1551,31 +1555,31 @@ class MCPImportCommand(BaseCommand): /mcp import {"mcpServers":{"api":{"url":"https://example.com/mcp","transport":"sse"}}}""" await self.send_text(help_text) return (True, None, True) - + # 获取现有服务器名称 servers_section = _plugin_instance.config.get("servers", {}) servers_list_str = servers_section.get("list", "[]") if isinstance(servers_section, dict) else "[]" - + try: existing_servers = json.loads(servers_list_str) if servers_list_str.strip() else [] except json.JSONDecodeError: existing_servers = [] - + existing_names = {srv.get("name", "") for srv in existing_servers if isinstance(srv, dict)} - + # 执行导入 result = ConfigConverter.import_from_string(content.strip(), existing_names) - + # 构建响应 lines = [] - + if not result.success: lines.append("❌ 导入失败:") for err in result.errors: lines.append(f" • {err}") await self.send_text("\n".join(lines)) return (True, None, True) - + if not result.servers: lines.append("⚠️ 没有新服务器可导入") if result.skipped: @@ -1588,44 +1592,44 @@ class MCPImportCommand(BaseCommand): lines.append(f" • {w}") await self.send_text("\n".join(lines)) return (True, None, True) - + # 合并到现有列表 new_servers = existing_servers + result.servers new_list_str = json.dumps(new_servers, ensure_ascii=False, indent=2) - + # 更新配置 if "servers" not in _plugin_instance.config: _plugin_instance.config["servers"] = {} _plugin_instance.config["servers"]["list"] = new_list_str - + # 保存到配置文件 _plugin_instance._save_servers_list(new_list_str) - + # 构建成功响应 lines.append(f"✅ 成功导入 {len(result.servers)} 个服务器:") for srv in result.servers: transport = srv.get("transport", "stdio") lines.append(f" • {srv.get('name')} ({transport})") - + if result.skipped: lines.append(f"\n⏭️ 跳过 {len(result.skipped)} 个:") for s in result.skipped[:5]: lines.append(f" • {s}") if len(result.skipped) > 5: lines.append(f" ... 还有 {len(result.skipped) - 5} 个") - + if result.warnings: lines.append("\n⚠️ 警告:") for w in result.warnings[:3]: lines.append(f" • {w}") - + if result.errors: lines.append("\n❌ 部分失败:") for e in result.errors[:3]: lines.append(f" • {e}") - + lines.append("\n💡 发送 /mcp reconnect 使配置生效") - + await self.send_text("\n".join(lines)) return (True, None, True) @@ -1634,56 +1638,57 @@ class MCPImportCommand(BaseCommand): # 事件处理器 # ============================================================================ + class MCPStartupHandler(BaseEventHandler): """MCP 启动事件处理器""" - + event_type = EventType.ON_START handler_name = "mcp_startup_handler" handler_description = "MCP 桥接插件启动处理器" weight = 0 intercept_message = False - + async def execute(self, message: Optional[Any]) -> Tuple[bool, bool, Optional[str], None, None]: """处理启动事件""" global _plugin_instance - + if _plugin_instance is None: logger.warning("MCP 桥接插件实例未初始化") return (False, True, None, None, None) - + logger.info("MCP 桥接插件收到 ON_START 事件,开始连接 MCP 服务器...") await _plugin_instance._async_connect_servers() - + await mcp_manager.start_heartbeat() - + # v1.6.0: 启动配置文件监控(用于 WebUI 导入) await _plugin_instance._start_config_watcher() - + return (True, True, None, None, None) class MCPStopHandler(BaseEventHandler): """MCP 停止事件处理器""" - + event_type = EventType.ON_STOP handler_name = "mcp_stop_handler" handler_description = "MCP 桥接插件停止处理器" weight = 0 intercept_message = False - + async def execute(self, message: Optional[Any]) -> Tuple[bool, bool, Optional[str], None, None]: """处理停止事件""" global _plugin_instance - + logger.info("MCP 桥接插件收到 ON_STOP 事件,正在关闭...") - + # v1.6.0: 停止配置文件监控 if _plugin_instance: await _plugin_instance._stop_config_watcher() - + await mcp_manager.shutdown() mcp_tool_registry.clear() - + logger.info("MCP 桥接插件已关闭所有连接") return (True, True, None, None, None) @@ -1692,16 +1697,17 @@ class MCPStopHandler(BaseEventHandler): # 主插件类 # ============================================================================ + @register_plugin class MCPBridgePlugin(BasePlugin): """MCP 桥接插件 v1.4.0 - 将 MCP 服务器的工具桥接到 MaiBot""" - + plugin_name: str = "mcp_bridge_plugin" enable_plugin: bool = False # 默认禁用,用户需在 WebUI 手动启用 dependencies: List[str] = [] python_dependencies: List[str] = ["mcp"] config_file_name: str = "config.toml" - + config_section_descriptions = { "guide": "📖 快速入门", "plugin": "🔘 插件开关", @@ -1713,7 +1719,7 @@ class MCPBridgePlugin(BasePlugin): "tools": "🔧 工具管理", "permissions": "🔐 权限控制", } - + config_schema: dict = { # 新手引导区(只读) "guide": { @@ -2116,9 +2122,9 @@ class MCPBridgePlugin(BasePlugin): label="📜 高级权限规则(可选)", input_type="textarea", rows=10, - placeholder='''[ + placeholder="""[ {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]} -]''', +]""", hint="格式: qq:ID:group/private/user,工具名支持通配符 *", order=10, ), @@ -2217,43 +2223,43 @@ class MCPBridgePlugin(BasePlugin): ), }, } - + @staticmethod def _fix_config_multiline_strings(config_path: Path) -> bool: """修复配置文件中的多行字符串格式问题 - + 处理两种情况: 1. 带转义 \\n 的单行字符串(json.dumps 生成) 2. 跨越多行但使用普通双引号的字符串(控制字符错误) - + Returns: bool: 是否进行了修复 """ if not config_path.exists(): return False - + try: content = config_path.read_text(encoding="utf-8") - + # 情况1: 修复带转义 \n 的单行字符串 # 匹配: key = "内容包含\n的字符串" pattern1 = r'^(\s*\w+\s*=\s*)"((?:[^"\\]|\\.)*\\n(?:[^"\\]|\\.)*)"(\s*)$' - + # 情况2: 修复跨越多行的普通双引号字符串 # 匹配: key = "第一行 # 第二行 # 第三行" pattern2_start = r'^(\s*\w+\s*=\s*)"([^"]*?)$' # 开始行 pattern2_end = r'^([^"]*)"(\s*)$' # 结束行 - + lines = content.split("\n") fixed_lines = [] modified = False - + i = 0 while i < len(lines): line = lines[i] - + # 情况1: 单行带转义换行符 match1 = re.match(pattern1, line) if match1: @@ -2261,24 +2267,26 @@ class MCPBridgePlugin(BasePlugin): value = match1.group(2) suffix = match1.group(3) # 将转义的换行符还原为实际换行符 - unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\") + unescaped = ( + value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\") + ) fixed_line = f'{prefix}"""{unescaped}"""{suffix}' fixed_lines.append(fixed_line) modified = True i += 1 continue - + # 情况2: 跨越多行的字符串 match2_start = re.match(pattern2_start, line) if match2_start: prefix = match2_start.group(1) first_part = match2_start.group(2) - + # 收集后续行直到找到结束引号 multiline_parts = [first_part] j = i + 1 found_end = False - + while j < len(lines): next_line = lines[j] match2_end = re.match(pattern2_end, next_line) @@ -2291,7 +2299,7 @@ class MCPBridgePlugin(BasePlugin): else: multiline_parts.append(next_line) j += 1 - + if found_end and len(multiline_parts) > 1: # 合并为三引号字符串 full_value = "\n".join(multiline_parts) @@ -2300,35 +2308,35 @@ class MCPBridgePlugin(BasePlugin): modified = True i = j continue - + fixed_lines.append(line) i += 1 - + if modified: config_path.write_text("\n".join(fixed_lines), encoding="utf-8") logger.info("已自动修复配置文件中的多行字符串格式") return True - + return False except Exception as e: logger.warning(f"修复配置文件格式失败: {e}") return False - + def __init__(self, *args, **kwargs): global _plugin_instance - + # 在父类初始化前尝试修复配置文件格式 config_path = Path(__file__).parent / "config.toml" self._fix_config_multiline_strings(config_path) - + super().__init__(*args, **kwargs) self._initialized = False _plugin_instance = self - + # 配置 MCP 管理器 settings = self.config.get("settings", {}) mcp_manager.configure(settings) - + # v1.4.0: 配置追踪器 trace_log_path = Path(__file__).parent / "logs" / "trace.jsonl" tool_call_tracer.configure( @@ -2337,7 +2345,7 @@ class MCPBridgePlugin(BasePlugin): log_enabled=settings.get("trace_log_enabled", False), log_path=trace_log_path, ) - + # v1.4.0: 配置缓存 tool_call_cache.configure( enabled=settings.get("cache_enabled", False), @@ -2345,7 +2353,7 @@ class MCPBridgePlugin(BasePlugin): max_entries=settings.get("cache_max_entries", 200), exclude_tools=settings.get("cache_exclude_tools", ""), ) - + # v1.4.0: 配置权限检查器 perm_config = self.config.get("permissions", {}) permission_checker.configure( @@ -2355,13 +2363,13 @@ class MCPBridgePlugin(BasePlugin): quick_deny_groups=perm_config.get("quick_deny_groups", ""), quick_allow_users=perm_config.get("quick_allow_users", ""), ) - + # 注册状态变化回调 mcp_manager.set_status_change_callback(self._update_status_display) - + # v1.6.0: 处理 WebUI 导入导出 self._process_webui_import_export() - + # v1.5.1: 处理快速添加服务器 self._process_quick_add_server() @@ -2560,7 +2568,7 @@ class MCPBridgePlugin(BasePlugin): continue import_config_str = str(import_config).strip() - logger.info(f"检测到 WebUI 导入请求,开始处理...") + logger.info("检测到 WebUI 导入请求,开始处理...") # 执行导入 await self._execute_webui_import(import_config_str, doc, config_path) @@ -2650,23 +2658,23 @@ class MCPBridgePlugin(BasePlugin): """v1.5.1: 处理快速添加服务器表单,将新服务器合并到列表""" quick_add = self.config.get("quick_add", {}) server_name = quick_add.get("server_name", "").strip() - + if not server_name: return # 没有填写名称,跳过 - + server_type = quick_add.get("server_type", "streamable_http") server_url = quick_add.get("server_url", "").strip() server_command = quick_add.get("server_command", "").strip() server_args_str = quick_add.get("server_args", "").strip() server_headers_str = quick_add.get("server_headers", "").strip() - + # 构建新服务器配置 new_server = { "name": server_name, "enabled": True, "transport": server_type, } - + if server_type == "stdio": if not server_command: logger.warning(f"快速添加: stdio 类型需要填写命令,跳过 {server_name}") @@ -2679,7 +2687,7 @@ class MCPBridgePlugin(BasePlugin): logger.warning(f"快速添加: {server_type} 类型需要填写 URL,跳过 {server_name}") return new_server["url"] = server_url - + # 解析鉴权头 if server_headers_str: try: @@ -2688,39 +2696,39 @@ class MCPBridgePlugin(BasePlugin): new_server["headers"] = headers except json.JSONDecodeError: logger.warning("快速添加: 鉴权头 JSON 格式错误,已忽略") - + # 获取现有服务器列表 servers_section = self.config.get("servers", {}) servers_list_str = servers_section.get("list", "[]") if isinstance(servers_section, dict) else "[]" - + try: servers_list = json.loads(servers_list_str) if servers_list_str.strip() else [] except json.JSONDecodeError: servers_list = [] - + # 检查是否已存在同名服务器 for existing in servers_list: if existing.get("name") == server_name: logger.info(f"快速添加: 服务器 {server_name} 已存在,跳过") self._clear_quick_add_fields() return - + # 添加新服务器 servers_list.append(new_server) logger.info(f"快速添加: 已添加服务器 {server_name} ({server_type})") - + # 更新配置 new_list_str = json.dumps(servers_list, ensure_ascii=False, indent=2) if "servers" not in self.config: self.config["servers"] = {} self.config["servers"]["list"] = new_list_str - + # 清空快速添加字段 self._clear_quick_add_fields() - + # 保存到配置文件 self._save_servers_list(new_list_str) - + def _clear_quick_add_fields(self) -> None: """清空快速添加表单字段""" if "quick_add" not in self.config: @@ -2730,25 +2738,25 @@ class MCPBridgePlugin(BasePlugin): self.config["quick_add"]["server_command"] = "" self.config["quick_add"]["server_args"] = "" self.config["quick_add"]["server_headers"] = "" - + def _save_servers_list(self, servers_json: str) -> None: """保存服务器列表到配置文件""" import tomlkit from tomlkit.items import String, StringType, Trivia - + try: config_path = Path(__file__).parent / "config.toml" if config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: doc = tomlkit.load(f) - + if "servers" not in doc: doc["servers"] = tomlkit.table() - + # 使用多行字符串 ml_string = String(StringType.MLB, servers_json, servers_json, Trivia()) doc["servers"]["list"] = ml_string - + # 清空快速添加字段 if "quick_add" in doc: doc["quick_add"]["server_name"] = "" @@ -2756,27 +2764,28 @@ class MCPBridgePlugin(BasePlugin): doc["quick_add"]["server_command"] = "" doc["quick_add"]["server_args"] = "" doc["quick_add"]["server_headers"] = "" - + with open(config_path, "w", encoding="utf-8") as f: tomlkit.dump(doc, f) - + logger.info("服务器列表已保存到配置文件") except Exception as e: logger.warning(f"保存服务器列表失败: {e}") - + def _get_disabled_tools(self) -> set: """v1.4.0: 获取禁用的工具列表""" tools_config = self.config.get("tools", {}) disabled_str = tools_config.get("disabled_tools", "") return {t.strip() for t in disabled_str.strip().split("\n") if t.strip()} - + async def _async_connect_servers(self) -> None: """异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)""" import asyncio + settings = self.config.get("settings", {}) - + servers_section = self.config.get("servers", []) - + if isinstance(servers_section, dict): servers_list = servers_section.get("list", []) if isinstance(servers_list, str): @@ -2787,45 +2796,45 @@ class MCPBridgePlugin(BasePlugin): servers_config = [] else: servers_config = servers_section - + if not servers_config: logger.warning("未配置任何 MCP 服务器") self._initialized = True return - + auto_connect = settings.get("auto_connect", True) if not auto_connect: logger.info("auto_connect 已禁用,跳过自动连接") self._initialized = True return - + tool_prefix = settings.get("tool_prefix", "mcp") disabled_tools = self._get_disabled_tools() enable_resources = settings.get("enable_resources", False) enable_prompts = settings.get("enable_prompts", False) - + # 解析所有服务器配置 enabled_configs: List[MCPServerConfig] = [] for idx, server_conf in enumerate(servers_config): server_name = server_conf.get("name", f"unknown_{idx}") - + if not server_conf.get("enabled", True): logger.info(f"服务器 {server_name} 已禁用,跳过") continue - + try: config = self._parse_server_config(server_conf) enabled_configs.append(config) except Exception as e: logger.error(f"解析服务器 {server_name} 配置失败: {e}") - + if not enabled_configs: logger.warning("没有已启用的 MCP 服务器") self._initialized = True return - + logger.info(f"准备并行连接 {len(enabled_configs)} 个 MCP 服务器") - + # v1.5.0: 并行连接所有服务器 async def connect_single_server(config: MCPServerConfig) -> Tuple[MCPServerConfig, bool]: """连接单个服务器""" @@ -2851,15 +2860,12 @@ class MCPBridgePlugin(BasePlugin): except Exception as e: logger.error(f"❌ 服务器 {config.name} 连接异常: {e}") return config, False - + # 并行执行所有连接 start_time = time.time() - results = await asyncio.gather( - *[connect_single_server(cfg) for cfg in enabled_configs], - return_exceptions=True - ) + results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True) connect_duration = time.time() - start_time - + # 统计连接结果 success_count = 0 failed_count = 0 @@ -2873,43 +2879,42 @@ class MCPBridgePlugin(BasePlugin): success_count += 1 else: failed_count += 1 - + logger.info(f"并行连接完成: {success_count} 成功, {failed_count} 失败, 耗时 {connect_duration:.2f}s") - + # 注册所有工具 from src.plugin_system.core.component_registry import component_registry + registered_count = 0 - + for tool_key, (tool_info, _) in mcp_manager.all_tools.items(): tool_name = tool_key.replace("-", "_").replace(".", "_") is_disabled = tool_name in disabled_tools - - info, tool_class = mcp_tool_registry.register_tool( - tool_key, tool_info, tool_prefix, disabled=is_disabled - ) + + info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled) info.plugin_name = self.plugin_name - + if component_registry.register_component(info, tool_class): registered_count += 1 status = "🚫" if is_disabled else "✅" logger.info(f"{status} 注册 MCP 工具: {tool_class.name}") else: logger.warning(f"❌ 注册 MCP 工具失败: {tool_class.name}") - + self._initialized = True logger.info(f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具") - + # 更新状态显示 self._update_status_display() self._update_tool_list_display() - + def _parse_servers_json(self, servers_list: str) -> List[Dict]: """解析服务器列表 JSON 字符串""" if not servers_list.strip(): return [] - + content = servers_list.strip() - + try: parsed = json.loads(content) if isinstance(parsed, list): @@ -2922,7 +2927,7 @@ class MCPBridgePlugin(BasePlugin): return [] except json.JSONDecodeError as e: logger.warning(f"JSON 解析失败: {e}") - + if content.startswith("{") and not content.startswith("["): try: fixed_content = f"[{content}]" @@ -2932,14 +2937,14 @@ class MCPBridgePlugin(BasePlugin): return parsed except json.JSONDecodeError: pass - + logger.error("❌ 服务器配置 JSON 格式错误") return [] - + def _parse_server_config(self, conf: Dict) -> MCPServerConfig: """解析服务器配置字典""" transport_str = conf.get("transport", "stdio").lower() - + transport_map = { "stdio": TransportType.STDIO, "sse": TransportType.SSE, @@ -2947,7 +2952,7 @@ class MCPBridgePlugin(BasePlugin): "streamable_http": TransportType.STREAMABLE_HTTP, } transport = transport_map.get(transport_str, TransportType.STDIO) - + return MCPServerConfig( name=conf.get("name", "unnamed"), enabled=conf.get("enabled", True), @@ -2958,68 +2963,69 @@ class MCPBridgePlugin(BasePlugin): url=conf.get("url", ""), headers=conf.get("headers", {}), # v1.4.2: 鉴权头支持 ) - + def _update_tool_list_display(self) -> None: """v1.4.0: 更新工具列表显示""" import tomlkit - + tools = mcp_manager.all_tools disabled_tools = self._get_disabled_tools() - + lines = [] by_server: Dict[str, List[str]] = {} - + for tool_key, (tool_info, _) in tools.items(): tool_name = tool_key.replace("-", "_").replace(".", "_") if tool_info.server_name not in by_server: by_server[tool_info.server_name] = [] - + is_disabled = tool_name in disabled_tools status = " ❌" if is_disabled else "" by_server[tool_info.server_name].append(f" • {tool_name}{status}") - + for srv_name, tool_list in by_server.items(): lines.append(f"📦 {srv_name} ({len(tool_list)}个工具):") lines.extend(tool_list) lines.append("") - + if not by_server: lines.append("(无已注册工具)") - + tool_list_text = "\n".join(lines) - + # 更新内存配置 if "tools" not in self.config: self.config["tools"] = {} self.config["tools"]["tool_list"] = tool_list_text - + # 写入配置文件 try: config_path = Path(__file__).parent / "config.toml" if config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: doc = tomlkit.load(f) - + if "tools" not in doc: doc["tools"] = tomlkit.table() # 使用 tomlkit 多行字符串避免控制字符问题 from tomlkit.items import String, StringType, Trivia + ml_string = String(StringType.MLB, tool_list_text, tool_list_text, Trivia()) doc["tools"]["tool_list"] = ml_string - + with open(config_path, "w", encoding="utf-8") as f: tomlkit.dump(doc, f) except Exception as e: logger.warning(f"更新工具列表显示失败: {e}") - + def _update_status_display(self) -> None: """更新配置文件中的状态显示字段""" import tomlkit - + status = mcp_manager.get_status() settings = self.config.get("settings", {}) lines = [] - + lines.append(f"服务器: {status['connected_servers']}/{status['total_servers']} 已连接") lines.append(f"工具数: {status['total_tools']}") if settings.get("enable_resources", False): @@ -3028,13 +3034,13 @@ class MCPBridgePlugin(BasePlugin): lines.append(f"模板数: {status.get('total_prompts', 0)}") lines.append(f"心跳: {'运行中' if status['heartbeat_running'] else '已停止'}") lines.append("") - + tools = mcp_manager.all_tools - + for name, info in status.get("servers", {}).items(): icon = "✅" if info["connected"] else "❌" lines.append(f"{icon} {name} ({info['transport']})") - + # v1.7.0: 显示断路器状态 cb_status = info.get("circuit_breaker", {}) cb_state = cb_status.get("state", "closed") @@ -3042,53 +3048,54 @@ class MCPBridgePlugin(BasePlugin): lines.append(" ⚡ 断路器: 熔断中") elif cb_state == "half_open": lines.append(" ⚡ 断路器: 试探中") - + server_tools = [t.name for key, (t, _) in tools.items() if t.server_name == name] if server_tools: for tool_name in server_tools: lines.append(f" • {tool_name}") else: lines.append(" (无工具)") - + if not status.get("servers"): lines.append("(无服务器)") - + status_text = "\n".join(lines) - + if "status" not in self.config: self.config["status"] = {} self.config["status"]["connection_status"] = status_text - + try: config_path = Path(__file__).parent / "config.toml" if config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: doc = tomlkit.load(f) - + if "status" not in doc: doc["status"] = tomlkit.table() # 使用 tomlkit 多行字符串避免控制字符问题 from tomlkit.items import String, StringType, Trivia + ml_string = String(StringType.MLB, status_text, status_text, Trivia()) doc["status"]["connection_status"] = ml_string - + with open(config_path, "w", encoding="utf-8") as f: tomlkit.dump(doc, f) except Exception as e: logger.warning(f"更新配置文件状态失败: {e}") - + def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: """返回插件的所有组件""" components: List[Tuple[ComponentInfo, Type]] = [] - + # 事件处理器 components.append((MCPStartupHandler.get_handler_info(), MCPStartupHandler)) components.append((MCPStopHandler.get_handler_info(), MCPStopHandler)) - + # 命令 components.append((MCPStatusCommand.get_command_info(), MCPStatusCommand)) components.append((MCPImportCommand.get_command_info(), MCPImportCommand)) - + # 内置工具 status_tool_info = ToolInfo( name=MCPStatusTool.name, @@ -3098,9 +3105,9 @@ class MCPBridgePlugin(BasePlugin): component_type=ComponentType.TOOL, ) components.append((status_tool_info, MCPStatusTool)) - + settings = self.config.get("settings", {}) - + if settings.get("enable_resources", False): read_resource_info = ToolInfo( name=MCPReadResourceTool.name, @@ -3110,7 +3117,7 @@ class MCPBridgePlugin(BasePlugin): component_type=ComponentType.TOOL, ) components.append((read_resource_info, MCPReadResourceTool)) - + if settings.get("enable_prompts", False): get_prompt_info = ToolInfo( name=MCPGetPromptTool.name, @@ -3120,9 +3127,9 @@ class MCPBridgePlugin(BasePlugin): component_type=ComponentType.TOOL, ) components.append((get_prompt_info, MCPGetPromptTool)) - + return components - + def get_status(self) -> Dict[str, Any]: """获取插件状态""" return { @@ -3132,7 +3139,7 @@ class MCPBridgePlugin(BasePlugin): "trace_records": tool_call_tracer.total_records, "cache_stats": tool_call_cache.get_stats(), } - + def get_stats(self) -> Dict[str, Any]: """获取详细统计信息""" return mcp_manager.get_all_stats() diff --git a/plugins/MaiBot_MCPBridgePlugin/test_mcp_client.py b/plugins/MaiBot_MCPBridgePlugin/test_mcp_client.py index d2264314..fc968ae3 100644 --- a/plugins/MaiBot_MCPBridgePlugin/test_mcp_client.py +++ b/plugins/MaiBot_MCPBridgePlugin/test_mcp_client.py @@ -23,22 +23,22 @@ from mcp_client import ( async def test_stats(): """测试统计类""" print("\n=== 测试统计类 ===") - + # 测试 ToolCallStats stats = ToolCallStats(tool_key="test_tool") stats.record_call(True, 100.0) stats.record_call(True, 200.0) stats.record_call(False, 50.0, "timeout") - + assert stats.total_calls == 3 assert stats.success_calls == 2 assert stats.failed_calls == 1 - assert stats.success_rate == (2/3) * 100 + assert stats.success_rate == (2 / 3) * 100 assert stats.avg_duration_ms == 150.0 assert stats.last_error == "timeout" - + print(f"✅ ToolCallStats: {stats.to_dict()}") - + # 测试 ServerStats server_stats = ServerStats(server_name="test_server") server_stats.record_connect() @@ -46,133 +46,138 @@ async def test_stats(): server_stats.record_disconnect() server_stats.record_failure() server_stats.record_failure() - + assert server_stats.connect_count == 1 assert server_stats.disconnect_count == 1 assert server_stats.consecutive_failures == 2 - + print(f"✅ ServerStats: {server_stats.to_dict()}") - + return True async def test_manager_basic(): """测试管理器基本功能""" print("\n=== 测试管理器基本功能 ===") - + # 创建新的管理器实例(绕过单例) manager = MCPClientManager.__new__(MCPClientManager) manager._initialized = False manager.__init__() - + # 配置 - manager.configure({ - "tool_prefix": "mcp", - "call_timeout": 30.0, - "retry_attempts": 1, - "retry_interval": 1.0, - "heartbeat_enabled": False, - }) - + manager.configure( + { + "tool_prefix": "mcp", + "call_timeout": 30.0, + "retry_attempts": 1, + "retry_interval": 1.0, + "heartbeat_enabled": False, + } + ) + # 测试状态 status = manager.get_status() assert status["total_servers"] == 0 assert status["connected_servers"] == 0 print(f"✅ 初始状态: {status}") - + # 测试添加禁用的服务器 config = MCPServerConfig( - name="disabled_server", - enabled=False, - transport=TransportType.HTTP, - url="https://example.com/mcp" + name="disabled_server", enabled=False, transport=TransportType.HTTP, url="https://example.com/mcp" ) result = await manager.add_server(config) assert result == True assert "disabled_server" in manager._clients assert manager._clients["disabled_server"].is_connected == False print("✅ 添加禁用服务器成功") - + # 测试重复添加 result = await manager.add_server(config) assert result == False print("✅ 重复添加被拒绝") - + # 测试移除 result = await manager.remove_server("disabled_server") assert result == True assert "disabled_server" not in manager._clients print("✅ 移除服务器成功") - + # 清理 await manager.shutdown() print("✅ 管理器关闭成功") - + return True async def test_http_connection(): """测试 HTTP 连接(使用真实的 MCP 服务器)""" print("\n=== 测试 HTTP 连接 ===") - + # 创建新的管理器实例 manager = MCPClientManager.__new__(MCPClientManager) manager._initialized = False manager.__init__() - - manager.configure({ - "tool_prefix": "mcp", - "call_timeout": 30.0, - "retry_attempts": 2, - "retry_interval": 2.0, - "heartbeat_enabled": False, - }) - + + manager.configure( + { + "tool_prefix": "mcp", + "call_timeout": 30.0, + "retry_attempts": 2, + "retry_interval": 2.0, + "heartbeat_enabled": False, + } + ) + # 使用 HowToCook MCP 服务器测试 config = MCPServerConfig( name="howtocook", enabled=True, transport=TransportType.HTTP, - url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp" + url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp", ) - + print(f"正在连接 {config.url} ...") result = await manager.add_server(config) - + if result: - print(f"✅ 连接成功!") - + print("✅ 连接成功!") + # 检查工具 tools = manager.all_tools print(f"✅ 发现 {len(tools)} 个工具:") for tool_key in tools: print(f" - {tool_key}") - + # 测试心跳 client = manager._clients["howtocook"] healthy = await client.check_health() print(f"✅ 心跳检测: {'健康' if healthy else '异常'}") - + # 测试工具调用 if "mcp_howtocook_whatToEat" in tools: print("\n正在调用 whatToEat 工具...") call_result = await manager.call_tool("mcp_howtocook_whatToEat", {}) if call_result.success: print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)") - print(f" 结果: {call_result.content[:200]}..." if len(str(call_result.content)) > 200 else f" 结果: {call_result.content}") + print( + f" 结果: {call_result.content[:200]}..." + if len(str(call_result.content)) > 200 + else f" 结果: {call_result.content}" + ) else: print(f"❌ 工具调用失败: {call_result.error}") - + # 查看统计 stats = manager.get_all_stats() - print(f"\n📊 统计信息:") + print("\n📊 统计信息:") print(f" 全局调用: {stats['global']['total_tool_calls']}") print(f" 成功: {stats['global']['successful_calls']}") print(f" 失败: {stats['global']['failed_calls']}") - + else: - print(f"❌ 连接失败") - + print("❌ 连接失败") + # 清理 await manager.shutdown() return result @@ -181,55 +186,57 @@ async def test_http_connection(): async def test_heartbeat(): """测试心跳检测功能""" print("\n=== 测试心跳检测 ===") - + # 创建新的管理器实例 manager = MCPClientManager.__new__(MCPClientManager) manager._initialized = False manager.__init__() - - manager.configure({ - "tool_prefix": "mcp", - "call_timeout": 30.0, - "retry_attempts": 1, - "retry_interval": 1.0, - "heartbeat_enabled": True, - "heartbeat_interval": 5.0, # 5秒间隔用于测试 - "auto_reconnect": True, - "max_reconnect_attempts": 2, - }) - + + manager.configure( + { + "tool_prefix": "mcp", + "call_timeout": 30.0, + "retry_attempts": 1, + "retry_interval": 1.0, + "heartbeat_enabled": True, + "heartbeat_interval": 5.0, # 5秒间隔用于测试 + "auto_reconnect": True, + "max_reconnect_attempts": 2, + } + ) + # 添加一个测试服务器 config = MCPServerConfig( name="heartbeat_test", enabled=True, transport=TransportType.HTTP, - url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp" + url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp", ) - + print("正在连接服务器...") result = await manager.add_server(config) - + if result: print("✅ 服务器连接成功") - + # 启动心跳检测 await manager.start_heartbeat() print("✅ 心跳检测已启动") - + # 等待一个心跳周期 print("等待心跳检测...") await asyncio.sleep(2) - + # 检查状态 status = manager.get_status() print(f"✅ 心跳运行状态: {status['heartbeat_running']}") - + # 停止心跳 await manager.stop_heartbeat() print("✅ 心跳检测已停止") else: print("❌ 服务器连接失败,跳过心跳测试") - + await manager.shutdown() return True @@ -239,30 +246,31 @@ async def main(): print("=" * 50) print("MCP 客户端测试") print("=" * 50) - + try: # 基础测试 await test_stats() await test_manager_basic() - + # 网络测试 print("\n是否进行网络连接测试? (需要网络) [y/N]: ", end="") # 自动进行网络测试 await test_http_connection() - + # 心跳测试 await test_heartbeat() - + print("\n" + "=" * 50) print("✅ 所有测试通过!") print("=" * 50) - + except Exception as e: print(f"\n❌ 测试失败: {e}") import traceback + traceback.print_exc() return False - + return True diff --git a/scripts/replyer_action_stats.py b/scripts/replyer_action_stats.py index f155d98e..8d8904bf 100644 --- a/scripts/replyer_action_stats.py +++ b/scripts/replyer_action_stats.py @@ -35,13 +35,13 @@ def get_chat_name(chat_id: str) -> str: return f"{chat_stream.group_name}" elif chat_stream.user_nickname: return f"{chat_stream.user_nickname}的私聊" - + if get_chat_manager: chat_manager = get_chat_manager() stream_name = chat_manager.get_stream_name(chat_id) if stream_name: return stream_name - + return f"未知聊天 ({chat_id[:8]}...)" except Exception: return f"查询失败 ({chat_id[:8]}...)" @@ -51,11 +51,11 @@ def load_records(temp_dir: str = "data/temp") -> List[Dict[str, Any]]: """加载所有 replyer 动作记录""" records = [] temp_path = Path(temp_dir) - + if not temp_path.exists(): print(f"目录不存在: {temp_dir}") return records - + # 查找所有 replyer_action_*.json 文件 pattern = "replyer_action_*.json" for file_path in temp_path.glob(pattern): @@ -65,7 +65,7 @@ def load_records(temp_dir: str = "data/temp") -> List[Dict[str, Any]]: records.append(data) except Exception as e: print(f"读取文件失败 {file_path}: {e}") - + # 按时间戳排序 records.sort(key=lambda x: x.get("timestamp", "")) return records @@ -91,7 +91,7 @@ def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int] "30天内": 0, "更早": 0, } - + for record in records: try: ts = record.get("timestamp", "") @@ -99,7 +99,7 @@ def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int] continue dt = datetime.fromisoformat(ts) diff = (now - dt).days - + if diff == 0: distribution["今天"] += 1 elif diff == 1: @@ -114,7 +114,7 @@ def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int] distribution["更早"] += 1 except Exception: pass - + return distribution @@ -123,17 +123,17 @@ def print_statistics(records: List[Dict[str, Any]]): if not records: print("没有找到任何记录") return - + print("=" * 80) print("Replyer 动作选择记录统计") print("=" * 80) print() - + # 总记录数 total_count = len(records) print(f"📊 总记录数: {total_count}") print() - + # 时间范围 timestamps = [r.get("timestamp", "") for r in records if r.get("timestamp")] if timestamps: @@ -141,7 +141,7 @@ def print_statistics(records: List[Dict[str, Any]]): last_time = format_timestamp(max(timestamps)) print(f"📅 时间范围: {first_time} ~ {last_time}") print() - + # 按 think_level 统计 think_levels = [r.get("think_level", 0) for r in records] think_level_counter = Counter(think_levels) @@ -152,7 +152,7 @@ def print_statistics(records: List[Dict[str, Any]]): level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})") print(f" Level {level} ({level_name}): {count} 次 ({percentage:.1f}%)") print() - + # 按 chat_id 统计(总体) chat_counter = Counter([r.get("chat_id", "未知") for r in records]) print(f"💬 聊天分布 (共 {len(chat_counter)} 个聊天):") @@ -164,30 +164,30 @@ def print_statistics(records: List[Dict[str, Any]]): if len(chat_counter) > 10: print(f" ... 还有 {len(chat_counter) - 10} 个聊天") print() - + # 每个 chat_id 的详细统计 print("=" * 80) print("每个聊天的详细统计") print("=" * 80) print() - + # 按 chat_id 分组记录 records_by_chat = defaultdict(list) for record in records: chat_id = record.get("chat_id", "未知") records_by_chat[chat_id].append(record) - + # 按记录数排序 sorted_chats = sorted(records_by_chat.items(), key=lambda x: len(x[1]), reverse=True) - + for chat_id, chat_records in sorted_chats: chat_name = get_chat_name(chat_id) chat_count = len(chat_records) chat_percentage = (chat_count / total_count) * 100 - + print(f"📱 {chat_name} ({chat_id[:8]}...)") print(f" 总记录数: {chat_count} ({chat_percentage:.1f}%)") - + # 该聊天的 think_level 分布 chat_think_levels = [r.get("think_level", 0) for r in chat_records] chat_think_counter = Counter(chat_think_levels) @@ -197,14 +197,14 @@ def print_statistics(records: List[Dict[str, Any]]): level_percentage = (level_count / chat_count) * 100 level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})") print(f" Level {level} ({level_name}): {level_count} 次 ({level_percentage:.1f}%)") - + # 该聊天的时间范围 chat_timestamps = [r.get("timestamp", "") for r in chat_records if r.get("timestamp")] if chat_timestamps: first_time = format_timestamp(min(chat_timestamps)) last_time = format_timestamp(max(chat_timestamps)) print(f" 时间范围: {first_time} ~ {last_time}") - + # 该聊天的时间分布 chat_time_dist = calculate_time_distribution(chat_records) print(" 时间分布:") @@ -212,7 +212,7 @@ def print_statistics(records: List[Dict[str, Any]]): if count > 0: period_percentage = (count / chat_count) * 100 print(f" {period}: {count} 次 ({period_percentage:.1f}%)") - + # 显示该聊天最近的一条理由示例 if chat_records: latest_record = chat_records[-1] @@ -222,9 +222,9 @@ def print_statistics(records: List[Dict[str, Any]]): timestamp = format_timestamp(latest_record.get("timestamp", "")) think_level = latest_record.get("think_level", 0) print(f" 最新记录 [{timestamp}] (Level {think_level}): {reason}") - + print() - + # 时间分布 time_dist = calculate_time_distribution(records) print("⏰ 时间分布:") @@ -233,7 +233,7 @@ def print_statistics(records: List[Dict[str, Any]]): percentage = (count / total_count) * 100 print(f" {period}: {count} 次 ({percentage:.1f}%)") print() - + # 显示一些示例理由 print("📝 示例理由 (最近5条):") recent_records = records[-5:] @@ -243,29 +243,29 @@ def print_statistics(records: List[Dict[str, Any]]): timestamp = format_timestamp(record.get("timestamp", "")) chat_id = record.get("chat_id", "未知") chat_name = get_chat_name(chat_id) - + # 截断过长的理由 if len(reason) > 100: reason = reason[:100] + "..." - + print(f" {i}. [{timestamp}] {chat_name} (Level {think_level})") print(f" {reason}") print() - + # 按 think_level 分组显示理由示例 print("=" * 80) print("按思考深度分类的示例理由") print("=" * 80) print() - + for level in [0, 1, 2]: level_records = [r for r in records if r.get("think_level") == level] if not level_records: continue - + level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})") print(f"Level {level} ({level_name}) - 共 {len(level_records)} 条:") - + # 显示3个示例(选择最近的) examples = level_records[-3:] if len(level_records) >= 3 else level_records for i, record in enumerate(examples, 1): @@ -278,7 +278,7 @@ def print_statistics(records: List[Dict[str, Any]]): print(f" {i}. [{timestamp}] {chat_name}") print(f" {reason}") print() - + # 统计信息汇总 print("=" * 80) print("统计汇总") @@ -301,4 +301,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/src/bw_learner/expression_learner.py b/src/bw_learner/expression_learner.py index 71866dea..4b306fe2 100644 --- a/src/bw_learner/expression_learner.py +++ b/src/bw_learner/expression_learner.py @@ -13,7 +13,12 @@ from src.chat.utils.chat_message_builder import ( ) from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager -from src.bw_learner.learner_utils import filter_message_content, is_bot_message, build_context_paragraph, contains_bot_self_name +from src.bw_learner.learner_utils import ( + filter_message_content, + is_bot_message, + build_context_paragraph, + contains_bot_self_name, +) from src.bw_learner.jargon_miner import miner_manager from json_repair import repair_json @@ -77,8 +82,6 @@ def init_prompt() -> None: Prompt(learn_style_prompt, "learn_style_prompt") - - class ExpressionLearner: def __init__(self, chat_id: str) -> None: self.express_learn_model: LLMRequest = LLMRequest( @@ -95,12 +98,12 @@ class ExpressionLearner: self._learning_lock = asyncio.Lock() async def learn_and_store( - self, + self, messages: List[Any], ) -> List[Tuple[str, str, str]]: """ 学习并存储表达方式 - + Args: messages: 外部传入的消息列表(必需) num: 学习数量 @@ -108,7 +111,7 @@ class ExpressionLearner: """ if not messages: return None - + random_msg = messages # 学习用(开启行编号,便于溯源) @@ -134,26 +137,26 @@ class ExpressionLearner: jargon_entries: List[Tuple[str, str]] # (content, source_id) expressions, jargon_entries = self.parse_expression_response(response) expressions = self._filter_self_reference_styles(expressions) - + # 检查表达方式数量,如果超过10个则放弃本次表达学习 if len(expressions) > 10: logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习") expressions = [] - + # 检查黑话数量,如果超过30个则放弃本次黑话学习 if len(jargon_entries) > 30: logger.info(f"黑话提取数量超过30个(实际{len(jargon_entries)}个),放弃本次黑话学习") jargon_entries = [] - + # 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话) if jargon_entries: await self._process_jargon_entries(jargon_entries, random_msg) - + # 如果没有表达方式,直接返回 if not expressions: logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)") return [] - + logger.info(f"学习的prompt: {prompt}") logger.info(f"学习的expressions: {expressions}") logger.info(f"学习的jargon_entries: {jargon_entries}") @@ -175,18 +178,17 @@ class ExpressionLearner: # 当前行的原始内容 current_msg = random_msg[line_index] - + # 过滤掉从bot自己发言中提取到的表达方式 if is_bot_message(current_msg): continue - + context = filter_message_content(current_msg.processed_plain_text or "") if not context: continue filtered_expressions.append((situation, style, context)) - - + learnt_expressions = filtered_expressions if learnt_expressions is None: @@ -270,37 +272,38 @@ class ExpressionLearner: # 如果解析失败,尝试修复中文引号问题 # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 try: + def fix_chinese_quotes_in_json(text): """使用状态机修复 JSON 字符串值中的中文引号""" result = [] i = 0 in_string = False escape_next = False - + while i < len(text): char = text[i] - + if escape_next: # 当前字符是转义字符后的字符,直接添加 result.append(char) escape_next = False i += 1 continue - - if char == '\\': + + if char == "\\": # 转义字符 result.append(char) escape_next = True i += 1 continue - + if char == '"' and not escape_next: # 遇到英文引号,切换字符串状态 in_string = not in_string result.append(char) i += 1 continue - + if in_string: # 在字符串值内部,将中文引号替换为转义的英文引号 if char == '"': # 中文左引号 U+201C @@ -312,13 +315,13 @@ class ExpressionLearner: else: # 不在字符串内,直接添加 result.append(char) - + i += 1 - - return ''.join(result) - + + return "".join(result) + fixed_raw = fix_chinese_quotes_in_json(raw) - + # 再次尝试解析 if fixed_raw.startswith("[") and fixed_raw.endswith("]"): parsed = json.loads(fixed_raw) @@ -346,12 +349,12 @@ class ExpressionLearner: for item in parsed_list: if not isinstance(item, dict): continue - + # 检查是否是表达方式条目(有 situation 和 style) situation = str(item.get("situation", "")).strip() style = str(item.get("style", "")).strip() source_id = str(item.get("source_id", "")).strip() - + if situation and style and source_id: # 表达方式条目 expressions.append((situation, style, source_id)) @@ -503,59 +506,59 @@ class ExpressionLearner: async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None: """ 处理从 expression learner 提取的黑话条目,路由到 jargon_miner - + Args: jargon_entries: 黑话条目列表,每个元素是 (content, source_id) messages: 消息列表,用于构建上下文 """ if not jargon_entries or not messages: return - + # 获取 jargon_miner 实例 jargon_miner = miner_manager.get_miner(self.chat_id) - + # 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致 entries: List[Dict[str, List[str]]] = [] - + for content, source_id in jargon_entries: content = content.strip() if not content: continue - + # 检查是否包含机器人名称 if contains_bot_self_name(content): logger.info(f"跳过包含机器人昵称/别名的黑话: {content}") continue - + # 解析 source_id source_id_str = (source_id or "").strip() if not source_id_str.isdigit(): logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}") continue - + # build_anonymous_messages 的编号从 1 开始 line_index = int(source_id_str) - 1 if line_index < 0 or line_index >= len(messages): logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}") continue - + # 检查是否是机器人自己的消息 target_msg = messages[line_index] if is_bot_message(target_msg): logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}") continue - + # 构建上下文段落 context_paragraph = build_context_paragraph(messages, line_index) if not context_paragraph: logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}") continue - + entries.append({"content": content, "raw_content": [context_paragraph]}) - + if not entries: return - + # 调用 jargon_miner 处理这些条目 await jargon_miner.process_extracted_entries(entries) diff --git a/src/bw_learner/expression_reflector.py b/src/bw_learner/expression_reflector.py index 5c165b9f..c627b5b7 100644 --- a/src/bw_learner/expression_reflector.py +++ b/src/bw_learner/expression_reflector.py @@ -82,9 +82,7 @@ class ExpressionReflector: # 获取未检查的表达 try: logger.info("[Expression Reflection] 查询未检查且未拒绝的表达") - expressions = ( - Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50) - ) + expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50) expr_list = list(expressions) logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达") diff --git a/src/bw_learner/expression_selector.py b/src/bw_learner/expression_selector.py index 931c5eb5..386d4fdf 100644 --- a/src/bw_learner/expression_selector.py +++ b/src/bw_learner/expression_selector.py @@ -128,9 +128,7 @@ class ExpressionSelector: # 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的 style_query = Expression.select().where( - (Expression.chat_id.in_(related_chat_ids)) - & (~Expression.rejected) - & (Expression.count > 1) + (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1) ) style_exprs = [ @@ -150,12 +148,15 @@ class ExpressionSelector: # 要求至少有10个 count > 1 的表达方式才进行选择 min_required = 10 if len(style_exprs) < min_required: - logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择") + logger.info( + f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择" + ) return [], [] # 固定选择5个 select_count = 5 import random + selected_style = random.sample(style_exprs, select_count) # 更新last_active_time @@ -163,7 +164,9 @@ class ExpressionSelector: self.update_expressions_last_active_time(selected_style) selected_ids = [expr["id"] for expr in selected_style] - logger.debug(f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个") + logger.debug( + f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个" + ) return selected_style, selected_ids except Exception as e: @@ -186,9 +189,7 @@ class ExpressionSelector: related_chat_ids = self.get_related_chat_ids(chat_id) # 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达 - style_query = Expression.select().where( - (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) - ) + style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)) style_exprs = [ { @@ -246,7 +247,9 @@ class ExpressionSelector: # 使用classic模式(随机选择+LLM选择) logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}") - return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason, think_level) + return await self._select_expressions_classic( + chat_id, chat_info, max_num, target_message, reply_reason, think_level + ) async def _select_expressions_classic( self, @@ -275,14 +278,12 @@ class ExpressionSelector: # think_level == 0: 只选择 count > 1 的项目,随机选10个,不进行LLM选择 if think_level == 0: return self._select_expressions_simple(chat_id, max_num) - + # think_level == 1: 先选高count,再从所有表达方式中随机抽样 # 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的 related_chat_ids = self.get_related_chat_ids(chat_id) - style_query = Expression.select().where( - (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) - ) - + style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)) + all_style_exprs = [ { "id": expr.id, @@ -299,29 +300,33 @@ class ExpressionSelector: # 分离 count > 1 和 count <= 1 的表达方式 high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1] - + # 根据 think_level 设置要求(仅支持 0/1,0 已在上方返回) min_high_count = 10 min_total_count = 10 select_high_count = 5 select_random_count = 5 - + # 检查数量要求 if len(high_count_exprs) < min_high_count: - logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择") + logger.info( + f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择" + ) return [], [] - + if len(all_style_exprs) < min_total_count: - logger.info(f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择") + logger.info( + f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择" + ) return [], [] - + # 先选取高count的表达方式 selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count)) - + # 然后从所有表达方式中随机抽样(使用加权抽样) remaining_num = select_random_count selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num)) - + # 合并候选池(去重,避免重复) candidate_exprs = selected_high.copy() candidate_ids = {expr["id"] for expr in candidate_exprs} @@ -329,9 +334,10 @@ class ExpressionSelector: if expr["id"] not in candidate_ids: candidate_exprs.append(expr) candidate_ids.add(expr["id"]) - + # 打乱顺序,避免高count的都在前面 import random + random.shuffle(candidate_exprs) # 2. 构建所有表达方式的索引和情境列表 @@ -351,7 +357,7 @@ class ExpressionSelector: all_situations_str = "\n".join(all_situations) if target_message: - target_message_str = f",现在你想要对这条消息进行回复:\"{target_message}\"" + target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"' target_message_extra_block = "4.考虑你要回复的目标消息" else: target_message_str = "" diff --git a/src/bw_learner/jargon_explainer.py b/src/bw_learner/jargon_explainer.py index ac62fa5f..207a080a 100644 --- a/src/bw_learner/jargon_explainer.py +++ b/src/bw_learner/jargon_explainer.py @@ -8,7 +8,12 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.bw_learner.jargon_miner import search_jargon -from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains +from src.bw_learner.learner_utils import ( + is_bot_message, + contains_bot_self_name, + parse_chat_id_list, + chat_id_list_contains, +) logger = get_logger("jargon") @@ -357,4 +362,4 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st if results: return "【概念检索结果】\n" + "\n".join(results) + "\n" - return "" \ No newline at end of file + return "" diff --git a/src/bw_learner/jargon_miner.py b/src/bw_learner/jargon_miner.py index 275d9cbd..fb3ce6b4 100644 --- a/src/bw_learner/jargon_miner.py +++ b/src/bw_learner/jargon_miner.py @@ -1,4 +1,3 @@ -import time import json import asyncio import random @@ -14,7 +13,6 @@ from src.config.config import model_config, global_config from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( build_readable_messages_with_id, - get_raw_msg_by_timestamp_with_chat_inclusive, ) from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.bw_learner.learner_utils import ( @@ -33,23 +31,23 @@ logger = get_logger("jargon") def _is_single_char_jargon(content: str) -> bool: """ 判断是否是单字黑话(单个汉字、英文或数字) - + Args: content: 词条内容 - + Returns: bool: 如果是单字黑话返回True,否则返回False """ if not content or len(content) != 1: return False - + char = content[0] # 判断是否是单个汉字、单个英文字母或单个数字 return ( - '\u4e00' <= char <= '\u9fff' or # 汉字 - 'a' <= char <= 'z' or # 小写字母 - 'A' <= char <= 'Z' or # 大写字母 - '0' <= char <= '9' # 数字 + "\u4e00" <= char <= "\u9fff" # 汉字 + or "a" <= char <= "z" # 小写字母 + or "A" <= char <= "Z" # 大写字母 + or "0" <= char <= "9" # 数字 ) @@ -195,7 +193,7 @@ class JargonMiner: model_set=model_config.model_task_config.utils, request_type="jargon.extract", ) - + self.llm_inference = LLMRequest( model_set=model_config.model_task_config.utils, request_type="jargon.inference", @@ -207,7 +205,7 @@ class JargonMiner: self.stream_name = stream_name if stream_name else self.chat_id self.cache_limit = 50 self.cache: OrderedDict[str, None] = OrderedDict() - + # 黑话提取锁,防止并发执行 self._extraction_lock = asyncio.Lock() @@ -299,17 +297,19 @@ class JargonMiner: # 获取当前count和上一次的meaning current_count = jargon_obj.count or 0 previous_meaning = jargon_obj.meaning or "" - + # 当count为24, 60时,随机移除一半的raw_content项目 if current_count in [24, 60] and len(raw_content_list) > 1: # 计算要保留的数量(至少保留1个) keep_count = max(1, len(raw_content_list) // 2) raw_content_list = random.sample(raw_content_list, keep_count) - logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目") + logger.info( + f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目" + ) # 步骤1: 基于raw_content和content推断 raw_content_text = "\n".join(raw_content_list) - + # 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考 previous_meaning_section = "" previous_meaning_instruction = "" @@ -318,8 +318,10 @@ class JargonMiner: **上一次推断的含义(仅供参考)** {previous_meaning} """ - previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果" - + previous_meaning_instruction = ( + "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果" + ) + prompt1 = await global_prompt_manager.format_prompt( "jargon_inference_with_context_prompt", content=content, @@ -481,7 +483,7 @@ class JargonMiner: async def run_once(self, messages: List[Any]) -> None: """ 运行一次黑话提取 - + Args: messages: 外部传入的消息列表(必需) """ @@ -650,7 +652,9 @@ class JargonMiner: if obj.raw_content: try: existing_raw_content = ( - json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content + json.loads(obj.raw_content) + if isinstance(obj.raw_content, str) + else obj.raw_content ) if not isinstance(existing_raw_content, list): existing_raw_content = [existing_raw_content] if existing_raw_content else [] @@ -726,13 +730,13 @@ class JargonMiner: async def process_extracted_entries(self, entries: List[Dict[str, List[str]]]) -> None: """ 处理已提取的黑话条目(从 expression_learner 路由过来的) - + Args: entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]} """ if not entries: return - + try: # 去重并合并raw_content(按 content 聚合) merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict() @@ -876,8 +880,6 @@ class JargonMinerManager: miner_manager = JargonMinerManager() - - def search_jargon( keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True ) -> List[Dict[str, str]]: diff --git a/src/bw_learner/message_recorder.py b/src/bw_learner/message_recorder.py index c49da951..b31ab153 100644 --- a/src/bw_learner/message_recorder.py +++ b/src/bw_learner/message_recorder.py @@ -15,25 +15,25 @@ class MessageRecorder: """ 统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner """ - + def __init__(self, chat_id: str) -> None: self.chat_id = chat_id self.chat_stream = get_chat_manager().get_stream(chat_id) self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id - + # 维护每个chat的上次提取时间 self.last_extraction_time: float = time.time() - + # 提取锁,防止并发执行 self._extraction_lock = asyncio.Lock() - + # 获取 expression 和 jargon 的配置参数 self._init_parameters() - + # 获取 expression_learner 和 jargon_miner 实例 self.expression_learner = expression_learner_manager.get_expression_learner(chat_id) self.jargon_miner = miner_manager.get_miner(chat_id) - + def _init_parameters(self) -> None: """初始化提取参数""" # 获取 expression 配置 @@ -42,17 +42,17 @@ class MessageRecorder: ) self.min_messages_for_extraction = 30 self.min_extraction_interval = 60 - + logger.debug( f"MessageRecorder 初始化: chat_id={self.chat_id}, " f"min_messages={self.min_messages_for_extraction}, " f"min_interval={self.min_extraction_interval}" ) - + def should_trigger_extraction(self) -> bool: """ 检查是否应该触发消息提取 - + Returns: bool: 是否应该触发提取 """ @@ -60,19 +60,19 @@ class MessageRecorder: time_diff = time.time() - self.last_extraction_time if time_diff < self.min_extraction_interval: return False - + # 检查消息数量 recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_extraction_time, timestamp_end=time.time(), ) - + if not recent_messages or len(recent_messages) < self.min_messages_for_extraction: return False - + return True - + async def extract_and_distribute(self) -> None: """ 提取消息并分发给 expression_learner 和 jargon_miner @@ -82,41 +82,40 @@ class MessageRecorder: # 在锁内检查,避免并发触发 if not self.should_trigger_extraction(): return - + # 检查 chat_stream 是否存在 if not self.chat_stream: return - + # 记录本次提取的时间窗口,避免重复提取 extraction_start_time = self.last_extraction_time extraction_end_time = time.time() - + # 立即更新提取时间,防止并发触发 self.last_extraction_time = extraction_end_time - + try: logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发") - + # 拉取提取窗口内的消息 messages = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=extraction_start_time, timestamp_end=extraction_end_time, ) - + if not messages: logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取") return - + # 按时间排序,确保顺序一致 messages = sorted(messages, key=lambda msg: msg.time or 0) - + logger.info( f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息," f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}" ) - - + # 分别触发 expression_learner 和 jargon_miner 的处理 # 传递提取的消息,避免它们重复获取 # 触发 expression 学习(如果启用) @@ -124,28 +123,26 @@ class MessageRecorder: asyncio.create_task( self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages) ) - + # 触发 jargon 提取(如果启用),传递消息 # if self.enable_jargon_learning: - # asyncio.create_task( - # self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages) - # ) - + # asyncio.create_task( + # self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages) + # ) + except Exception as e: logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") import traceback + traceback.print_exc() # 即使失败也保持时间戳更新,避免频繁重试 - + async def _trigger_expression_learning( - self, - timestamp_start: float, - timestamp_end: float, - messages: List[Any] + self, timestamp_start: float, timestamp_end: float, messages: List[Any] ) -> None: """ 触发 expression 学习,使用指定的消息列表 - + Args: timestamp_start: 开始时间戳 timestamp_end: 结束时间戳 @@ -154,7 +151,7 @@ class MessageRecorder: try: # 传递消息给 ExpressionLearner(必需参数) learnt_style = await self.expression_learner.learn_and_store(messages=messages) - + if learnt_style: logger.info(f"聊天流 {self.chat_name} 表达学习完成") else: @@ -162,17 +159,15 @@ class MessageRecorder: except Exception as e: logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}") import traceback + traceback.print_exc() - + async def _trigger_jargon_extraction( - self, - timestamp_start: float, - timestamp_end: float, - messages: List[Any] + self, timestamp_start: float, timestamp_end: float, messages: List[Any] ) -> None: """ 触发 jargon 提取,使用指定的消息列表 - + Args: timestamp_start: 开始时间戳 timestamp_end: 结束时间戳 @@ -181,19 +176,20 @@ class MessageRecorder: try: # 传递消息给 JargonMiner,避免它重复获取 await self.jargon_miner.run_once(messages=messages) - + except Exception as e: logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}") import traceback + traceback.print_exc() class MessageRecorderManager: """MessageRecorder 管理器""" - + def __init__(self) -> None: self._recorders: dict[str, MessageRecorder] = {} - + def get_recorder(self, chat_id: str) -> MessageRecorder: """获取或创建指定 chat_id 的 MessageRecorder""" if chat_id not in self._recorders: @@ -208,10 +204,9 @@ recorder_manager = MessageRecorderManager() async def extract_and_distribute_messages(chat_id: str) -> None: """ 统一的消息提取和分发入口函数 - + Args: chat_id: 聊天流ID """ recorder = recorder_manager.get_recorder(chat_id) await recorder.extract_and_distribute() - diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index ae05e5dd..4a22628e 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -176,19 +176,19 @@ class BrainChatting: # 如果有新消息,更新 last_read_time if len(recent_messages_list) >= 1: self.last_read_time = time.time() - + # 总是执行一次思考迭代(不管有没有新消息) # wait 动作会在其内部等待,不需要在这里处理 should_continue = await self._observe(recent_messages_list=recent_messages_list) - + if not should_continue: # 选择了 complete_talk,返回 False 表示需要等待新消息 return False - + # 继续下一次迭代(除非选择了 complete_talk) # 短暂等待后再继续,避免过于频繁的循环 await asyncio.sleep(0.1) - + return True async def _send_and_store_reply( @@ -328,9 +328,7 @@ class BrainChatting: ) # 检查是否有 complete_talk 动作(会停止后续迭代) - has_complete_talk = any( - action.action_type == "complete_talk" for action in action_to_use_info - ) + has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info) # 并行执行所有动作 action_tasks = [ @@ -430,12 +428,12 @@ class BrainChatting: await asyncio.sleep(3) self._loop_task = asyncio.create_task(self._main_chat_loop()) logger.error(f"{self.log_prefix} 结束了当前聊天循环") - + async def _wait_for_new_message(self): """等待新消息到达""" last_check_time = self.last_read_time check_interval = 1.0 # 每秒检查一次 - + while self.running: # 检查是否有新消息 recent_messages_list = message_api.get_messages_by_time_in_chat( @@ -448,13 +446,13 @@ class BrainChatting: filter_command=False, filter_intercept_message_level=1, ) - + # 如果有新消息,更新 last_read_time 并返回 if len(recent_messages_list) >= 1: self.last_read_time = time.time() logger.info(f"{self.log_prefix} 检测到新消息,恢复循环") return - + # 等待一段时间后再次检查 await asyncio.sleep(check_interval) @@ -660,9 +658,9 @@ class BrainChatting: except (ValueError, TypeError): logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒") wait_seconds = 5 - + logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds} 秒") - + # 记录动作信息 await database_api.store_action_info( chat_stream=self.chat_stream, @@ -673,12 +671,12 @@ class BrainChatting: action_data={"reason": reason, "wait_seconds": wait_seconds}, action_name="wait", ) - + # 等待指定时间 await asyncio.sleep(wait_seconds) - + logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考") - + # 这些动作本身不产生文本回复 self._last_successful_reply = False return { @@ -693,9 +691,9 @@ class BrainChatting: logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait,自动转换") # 使用默认等待时间 wait_seconds = 3 - + logger.info(f"{self.log_prefix} 执行 listening(转换为 wait)动作,等待 {wait_seconds} 秒") - + # 记录动作信息 await database_api.store_action_info( chat_stream=self.chat_stream, @@ -706,12 +704,12 @@ class BrainChatting: action_data={"reason": reason, "wait_seconds": wait_seconds}, action_name="listening", ) - + # 等待指定时间 await asyncio.sleep(wait_seconds) - + logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考") - + # 这些动作本身不产生文本回复 self._last_successful_reply = False return { diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py index 7da5dc29..14d28575 100644 --- a/src/chat/brain_chat/brain_planner.py +++ b/src/chat/brain_chat/brain_planner.py @@ -147,7 +147,7 @@ class BrainPlanner: ) # 用于动作规划 self.last_obs_time_mark = 0.0 - + # 计划日志记录 self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = [] @@ -203,9 +203,11 @@ class BrainPlanner: # 内部保留动作(不依赖插件系统) # 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"] - - logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}") - + + logger.debug( + f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}" + ) + # 将 listening 转换为 wait(向后兼容) if action == "listening": logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait,自动转换") @@ -521,7 +523,7 @@ class BrainPlanner: if json_objects: logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象") for i, json_obj in enumerate(json_objects): - logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}") + logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}") filtered_actions_list = list(filtered_actions.items()) for json_obj in json_objects: parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list) @@ -553,7 +555,9 @@ class BrainPlanner: return extracted_reasoning, actions - def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]: + def _create_complete_talk( + self, reasoning: str, available_actions: Dict[str, ActionInfo] + ) -> List[ActionPlannerInfo]: """创建complete_talk""" return [ ActionPlannerInfo( @@ -564,7 +568,7 @@ class BrainPlanner: available_actions=available_actions, ) ] - + def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]): """添加计划日志""" self.plan_log.append((reasoning, time.time(), actions)) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index f7d132a3..af12bb1a 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: emoji.description = emoji_data.description # Deserialize emotion string from DB to list - emoji.emotion = emoji_data.emotion.replace(",",",").split(",") if emoji_data.emotion else [] + emoji.emotion = emoji_data.emotion.replace(",", ",").split(",") if emoji_data.emotion else [] emoji.usage_count = emoji_data.usage_count db_last_used_time = emoji_data.last_used_time @@ -732,7 +732,7 @@ class EmojiManager: emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) if emoji_record and emoji_record.emotion: logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...") - return emoji_record.emotion.replace(",",",").split(",") + return emoji_record.emotion.replace(",", ",").split(",") except Exception as e: logger.error(f"从数据库查询表情包情感标签时出错: {e}") @@ -993,7 +993,7 @@ class EmojiManager: ) # 处理情感列表 - emotions = [e.strip() for e in emotions_text.replace(",",",").split(",") if e.strip()] + emotions = [e.strip() for e in emotions_text.replace(",", ",").split(",") if e.strip()] # 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个 if len(emotions) > 5: diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index f700ff47..5fbddb7c 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -619,13 +619,13 @@ class HeartFChatting: think_level = 0 # 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason planner_reasoning = action_planner_info.action_reasoning or reason - + record_replyer_action_temp( chat_id=self.stream_id, reason=reason, think_level=think_level, ) - + await database_api.store_action_info( chat_stream=self.chat_stream, action_build_into_prompt=False, diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 14c49d04..c45ec105 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -123,7 +123,11 @@ class ChatBot: logger.warning(f"命令执行失败: {command_class.__name__} - {response}") # 根据命令的拦截设置决定是否继续处理消息 - return True, response, not bool(intercept_message_level) # 找到命令,根据intercept_message决定是否继续 + return ( + True, + response, + not bool(intercept_message_level), + ) # 找到命令,根据intercept_message决定是否继续 except Exception as e: logger.error(f"执行命令时出错: {command_class.__name__} - {e}") diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index f06e3d1b..d093e07e 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -263,7 +263,7 @@ class MessageRecv(Message): desc = segment.data.get("desc", "") # 内容描述 source_url = segment.data.get("source_url", "") # 原始链接 url = segment.data.get("url", "") # 小程序链接 - text = f"[小程序分享" + text = "[小程序分享" if title: text += f" - {title}" text += "]" diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 3e33511f..a45ac515 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -42,22 +42,21 @@ def is_webui_virtual_group(group_id: str) -> bool: def parse_message_segments(segment) -> list: """解析消息段,转换为 WebUI 可用的格式 - + 参考 NapCat 适配器的消息解析逻辑 - + Args: segment: Seg 消息段对象 - + Returns: list: 消息段列表,每个元素为 {"type": "...", "data": ...} """ - from maim_message import Seg - + result = [] - + if segment is None: return result - + if segment.type == "seglist": # 处理消息段列表 if segment.data: @@ -112,15 +111,19 @@ def parse_message_segments(segment) -> list: forward_items = [] if segment.data: for item in segment.data: - forward_items.append({ - "content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else [] - }) + forward_items.append( + { + "content": parse_message_segments(item.get("message_segment", {})) + if isinstance(item, dict) + else [] + } + ) result.append({"type": "forward", "data": forward_items}) else: # 未知类型,尝试作为文本处理 if segment.data: result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)}) - + return result @@ -134,7 +137,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息 chat_manager, webui_platform = get_webui_chat_broadcaster() is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id) - + if is_webui_message and chat_manager is not None: # WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播 import time @@ -142,7 +145,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: # 解析消息段,获取富文本内容 message_segments = parse_message_segments(message.message_segment) - + # 判断消息类型 # 如果只有一个文本段,使用简单的 text 类型 # 否则使用 rich 类型,包含完整的消息段 diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 2c58de7d..bc20c552 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -77,8 +77,7 @@ target_message_id为必填,表示触发消息的id ```""", "planner_prompt", ) - - + Prompt( """ {action_name} diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 57f64dba..3c4c5440 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -250,7 +250,12 @@ class DefaultReplyer: # 使用从处理器传来的选中表达方式 # 使用模型预测选择表达方式 selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( - self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level + self.chat_stream.stream_id, + chat_history, + max_num=8, + target_message=target, + reply_reason=reply_reason, + think_level=think_level, ) if selected_expressions: @@ -273,7 +278,6 @@ class DefaultReplyer: return f"{expression_habits_title}\n{expression_habits_block}", selected_ids - async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: """构建工具信息块 @@ -788,7 +792,8 @@ class DefaultReplyer: # 并行执行八个构建任务(包括黑话解释) task_results = await asyncio.gather( self._time_and_run_task( - self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits" + self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), + "expression_habits", ), self._time_and_run_task( self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info" @@ -980,7 +985,6 @@ class DefaultReplyer: else: reply_target_block = "" - chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2") diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 9ff340d9..a9b78c28 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -287,7 +287,6 @@ class PrivateReplyer: return f"{expression_habits_title}\n{expression_habits_block}", selected_ids - async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: """构建工具信息块 @@ -907,16 +906,11 @@ class PrivateReplyer: else: reply_target_block = "" - chat_target_name = "对方" if self.chat_target_info: chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方" - chat_target_1 = await global_prompt_manager.format_prompt( - "chat_target_private1", sender_name=chat_target_name - ) - chat_target_2 = await global_prompt_manager.format_prompt( - "chat_target_private2", sender_name=chat_target_name - ) + chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name) + chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name) template_name = "default_expressor_prompt" diff --git a/src/chat/replyer/prompt/replyer_private_prompt.py b/src/chat/replyer/prompt/replyer_private_prompt.py index c251d1a1..7dfd54e7 100644 --- a/src/chat/replyer/prompt/replyer_private_prompt.py +++ b/src/chat/replyer/prompt/replyer_private_prompt.py @@ -1,8 +1,9 @@ from src.chat.utils.prompt_builder import Prompt + def init_replyer_private_prompt(): Prompt( - """{knowledge_prompt}{tool_info_block}{extra_info_block} + """{knowledge_prompt}{tool_info_block}{extra_info_block} {expression_habits_block}{memory_retrieval}{jargon_explanation} 你正在和{sender_name}聊天,这是你们之前聊的内容: @@ -17,9 +18,9 @@ def init_replyer_private_prompt(): {reply_style} 请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 {moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""", - "private_replyer_prompt", - ) - + "private_replyer_prompt", + ) + Prompt( """{knowledge_prompt}{tool_info_block}{extra_info_block} {expression_habits_block}{memory_retrieval}{jargon_explanation} @@ -37,4 +38,4 @@ def init_replyer_private_prompt(): {moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。 """, "private_replyer_self_prompt", - ) \ No newline at end of file + ) diff --git a/src/chat/replyer/prompt/replyer_prompt.py b/src/chat/replyer/prompt/replyer_prompt.py index e17e07f5..b8649692 100644 --- a/src/chat/replyer/prompt/replyer_prompt.py +++ b/src/chat/replyer/prompt/replyer_prompt.py @@ -23,7 +23,7 @@ def init_replyer_prompt(): 现在,你说:""", "replyer_prompt_0", ) - + Prompt( """{knowledge_prompt}{tool_info_block}{extra_info_block} {expression_habits_block}{memory_retrieval}{jargon_explanation} @@ -44,4 +44,3 @@ def init_replyer_prompt(): 现在,你说:""", "replyer_prompt", ) - diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 6a634fb4..156322ae 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat( filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}} sort_order = [("time", 1)] return find_messages( - message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level + message_filter=filter_query, + sort=sort_order, + limit=limit, + filter_intercept_message_level=filter_intercept_message_level, ) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 4a115e06..20c0843b 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -746,7 +746,7 @@ class StatisticOutputTask(AsyncTask): data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}" total_replies = stats.get(TOTAL_REPLY_CNT, 0) - + output = [ "按模型分类统计:", " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数", @@ -759,11 +759,11 @@ class StatisticOutputTask(AsyncTask): cost = stats[COST_BY_MODEL][model_name] avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name] std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name] - + # 计算每次回复平均值 avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0 avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0 - + # 格式化大数字 formatted_count = _format_large_number(count) formatted_in_tokens = _format_large_number(in_tokens) @@ -771,7 +771,7 @@ class StatisticOutputTask(AsyncTask): formatted_tokens = _format_large_number(tokens) formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A" formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A" - + output.append( data_fmt.format( name, @@ -800,7 +800,7 @@ class StatisticOutputTask(AsyncTask): data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}" total_replies = stats.get(TOTAL_REPLY_CNT, 0) - + output = [ "按模块分类统计:", " 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数", @@ -813,11 +813,11 @@ class StatisticOutputTask(AsyncTask): cost = stats[COST_BY_MODULE][module_name] avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name] std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name] - + # 计算每次回复平均值 avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0 avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0 - + # 格式化大数字 formatted_count = _format_large_number(count) formatted_in_tokens = _format_large_number(in_tokens) @@ -825,7 +825,7 @@ class StatisticOutputTask(AsyncTask): formatted_tokens = _format_large_number(tokens) formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A" formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A" - + output.append( data_fmt.format( name, diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 67089fe6..bb0c01f7 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -646,7 +646,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None: """ 临时记录replyer动作被选择的信息(仅群聊) - + Args: chat_id: 聊天ID reason: 选择理由 @@ -656,7 +656,7 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N # 确保data/temp目录存在 temp_dir = "data/temp" os.makedirs(temp_dir, exist_ok=True) - + # 创建记录数据 record_data = { "chat_id": chat_id, @@ -664,16 +664,16 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N "think_level": think_level, "timestamp": datetime.now().isoformat(), } - + # 生成文件名(使用时间戳避免冲突) timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f") filename = f"replyer_action_{timestamp_str}.json" filepath = os.path.join(temp_dir, filename) - + # 写入文件 with open(filepath, "w", encoding="utf-8") as f: json.dump(record_data, f, ensure_ascii=False, indent=2) - + logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}") except Exception as e: logger.warning(f"记录replyer动作选择失败: {e}") diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 56a3ae0f..1145cc83 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -130,12 +130,10 @@ class ImageManager: try: # 清理Images表中type为emoji的记录 deleted_images = Images.delete().where(Images.type == "emoji").execute() - + # 清理ImageDescriptions表中type为emoji的记录 - deleted_descriptions = ( - ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute() - ) - + deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute() + total_deleted = deleted_images + deleted_descriptions if total_deleted > 0: logger.info( @@ -166,7 +164,7 @@ class ImageManager: async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None: """如果启用了steal_emoji且表情包未注册,保存文件到data/emoji目录 - + Args: image_base64: 图片的base64编码 image_hash: 图片的MD5哈希值 @@ -174,7 +172,7 @@ class ImageManager: """ if not global_config.emoji.steal_emoji: return - + try: from src.chat.emoji_system.emoji_manager import EMOJI_DIR from src.chat.emoji_system.emoji_manager import get_emoji_manager @@ -236,12 +234,16 @@ class ImageManager: # 优先使用情感标签,如果没有则使用详细描述 result_text = "" if cache_record.emotion_tags: - logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...") + logger.info( + f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..." + ) result_text = f"[表情包:{cache_record.emotion_tags}]" elif cache_record.description: - logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...") + logger.info( + f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..." + ) result_text = f"[表情包:{cache_record.description}]" - + # 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件 if result_text: await self._save_emoji_file_if_needed(image_base64, image_hash, image_format) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 02a581cd..4d930f60 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -609,23 +609,23 @@ def _fix_table_constraints(table_name, model, constraints_to_fix): fields = list(model._meta.fields.keys()) # Peewee 默认使用 'id' 作为主键字段名 # 尝试获取主键字段名,如果获取失败则默认使用 'id' - primary_key_name = 'id' # 默认值 + primary_key_name = "id" # 默认值 try: - if hasattr(model._meta, 'primary_key') and model._meta.primary_key: - if hasattr(model._meta.primary_key, 'name'): + if hasattr(model._meta, "primary_key") and model._meta.primary_key: + if hasattr(model._meta.primary_key, "name"): primary_key_name = model._meta.primary_key.name elif isinstance(model._meta.primary_key, str): primary_key_name = model._meta.primary_key except Exception: pass # 如果获取失败,使用默认值 'id' - + # 如果字段列表包含主键,则排除它 if primary_key_name in fields: fields_without_pk = [f for f in fields if f != primary_key_name] logger.info(f"排除主键字段 '{primary_key_name}',让数据库自动生成新的主键") else: fields_without_pk = fields - + fields_str = ", ".join(fields_without_pk) # 检查是否有字段需要从 NULL 改为 NOT NULL diff --git a/src/common/toml_utils.py b/src/common/toml_utils.py index 0a88c458..6a7b8bb9 100644 --- a/src/common/toml_utils.py +++ b/src/common/toml_utils.py @@ -34,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any: return obj # 决定是否多行:仅在顶层且长度超过阈值时 - should_multiline = (depth == 0 and len(obj) > threshold) + should_multiline = depth == 0 and len(obj) > threshold # 如果已经是 tomlkit Array,原地修改以保留注释 if isinstance(obj, Array): @@ -46,7 +46,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any: # 普通 list:转换为 tomlkit 数组 arr = tomlkit.array() arr.multiline(should_multiline) - + for item in obj: arr.append(_format_toml_value(item, threshold, depth + 1)) return arr @@ -112,7 +112,7 @@ def save_toml_with_format( formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data output = tomlkit.dumps(formatted) # 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积 - output = re.sub(r'\n{3,}', '\n\n', output) + output = re.sub(r"\n{3,}", "\n\n", output) with open(file_path, "w", encoding="utf-8") as f: f.write(output) @@ -122,4 +122,4 @@ def format_toml_string(data: Any, multiline_threshold: int = 1) -> str: formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data output = tomlkit.dumps(formatted) # 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积 - return re.sub(r'\n{3,}', '\n\n', output) \ No newline at end of file + return re.sub(r"\n{3,}", "\n\n", output) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 55528d51..bfd6ea5c 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -778,9 +778,9 @@ class DreamConfig(ConfigBase): """ if not self.dream_time_ranges: return True - + now_min = self._now_minutes() - + for time_range in self.dream_time_ranges: if not isinstance(time_range, str): continue @@ -790,7 +790,7 @@ class DreamConfig(ConfigBase): start_min, end_min = parsed if self._in_range(now_min, start_min, end_min): return True - + return False def __post_init__(self): @@ -800,4 +800,4 @@ class DreamConfig(ConfigBase): if self.max_iterations < 1: raise ValueError(f"max_iterations 必须至少为1,当前值: {self.max_iterations}") if self.first_delay_seconds < 0: - raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}") \ No newline at end of file + raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}") diff --git a/src/dream/dream_agent.py b/src/dream/dream_agent.py index c14f8061..b516a88e 100644 --- a/src/dream/dream_agent.py +++ b/src/dream/dream_agent.py @@ -1,14 +1,13 @@ import asyncio import random import time -import json from typing import Any, Dict, List, Optional, Tuple from peewee import fn from src.common.logger import get_logger from src.config.config import global_config, model_config -from src.common.database.database_model import ChatHistory, Jargon +from src.common.database.database_model import ChatHistory from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.plugin_system.apis import llm_api @@ -82,7 +81,6 @@ def init_dream_prompts() -> None: ) - class DreamTool: """dream 模块内部使用的简易工具封装""" @@ -150,7 +148,13 @@ def init_dream_tools(chat_id: str) -> None: "search_chat_history", "根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。", [ - ("keyword", ToolParamType.STRING, "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", False, None), + ( + "keyword", + ToolParamType.STRING, + "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", + False, + None, + ), ("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None), ], search_chat_history, @@ -201,8 +205,20 @@ def init_dream_tools(chat_id: str) -> None: [ ("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None), ("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None), - ("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", True, None), - ("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", True, None), + ( + "keywords", + ToolParamType.STRING, + "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", + True, + None, + ), + ( + "key_point", + ToolParamType.STRING, + "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", + True, + None, + ), ("start_time", ToolParamType.STRING, "起始时间戳(秒,Unix 时间,必填)。", True, None), ("end_time", ToolParamType.STRING, "结束时间戳(秒,Unix 时间,必填)。", True, None), ], @@ -215,7 +231,13 @@ def init_dream_tools(chat_id: str) -> None: "finish_maintenance", "结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。", [ - ("reason", ToolParamType.STRING, "结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。", False, None), + ( + "reason", + ToolParamType.STRING, + "结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。", + False, + None, + ), ], finish_maintenance, ) @@ -246,7 +268,7 @@ async def run_dream_agent_once( """ if max_iterations is None: max_iterations = global_config.dream.max_iterations - + start_ts = time.time() logger.info(f"[dream] 开始对 chat_id={chat_id} 进行 dream 维护,最多迭代 {max_iterations} 轮") @@ -282,9 +304,7 @@ async def run_dream_agent_once( else "未知" ) end_time_str = ( - time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) - if record.end_time - else "未知" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知" ) detail_text = ( f"ID={record.id}\n" @@ -305,8 +325,7 @@ async def run_dream_agent_once( start_detail_builder = MessageBuilder() start_detail_builder.set_role(RoleType.User) start_detail_builder.add_text_content( - "【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" - + detail_text + "【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text ) conversation_messages.append(start_detail_builder.build()) else: @@ -343,13 +362,17 @@ async def run_dream_agent_once( conversation_messages.append(round_info_builder.build()) # 调用 LLM 让其决定是否要使用工具 - success, response, reasoning_content, model_name, tool_calls = ( - await llm_api.generate_with_model_with_tools_by_message_factory( - message_factory, - model_config=model_config.model_task_config.tool_use, - tool_options=tool_defs, - request_type="dream.react", - ) + ( + success, + response, + reasoning_content, + model_name, + tool_calls, + ) = await llm_api.generate_with_model_with_tools_by_message_factory( + message_factory, + model_config=model_config.model_task_config.tool_use, + tool_options=tool_defs, + request_type="dream.react", ) if not success: @@ -522,7 +545,7 @@ async def start_dream_scheduler( if interval_seconds is None: interval_seconds = global_config.dream.interval_minutes * 60 - + logger.info( f"[dream] dream 调度器启动:首次延迟 {first_delay_seconds}s,之后每隔 {interval_seconds}s ({interval_seconds // 60} 分钟) 运行一次 dream agent" ) @@ -555,4 +578,3 @@ async def start_dream_scheduler( # 初始化提示词 init_dream_prompts() - diff --git a/src/dream/dream_generator.py b/src/dream/dream_generator.py index 174d1b69..945aebf9 100644 --- a/src/dream/dream_generator.py +++ b/src/dream/dream_generator.py @@ -86,7 +86,7 @@ async def generate_dream_summary( try: import json from src.chat.utils.prompt_builder import global_prompt_manager - + # 第一步:建立工具调用结果映射 (call_id -> result) tool_results_map: dict[str, str] = {} for msg in conversation_messages: @@ -98,11 +98,11 @@ async def generate_dream_summary( else: content = str(msg.content) tool_results_map[msg.tool_call_id] = content - + # 第二步:详细记录所有工具调用操作和结果到日志 tool_call_count = 0 logger.info(f"[dream][工具调用详情] 开始记录 chat_id={chat_id} 的所有工具调用操作:") - + for msg in conversation_messages: if msg.role == RoleType.Assistant and msg.tool_calls: tool_call_count += 1 @@ -110,34 +110,38 @@ async def generate_dream_summary( thought_content = "" if msg.content: if isinstance(msg.content, list) and msg.content: - thought_content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0]) + thought_content = ( + msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0]) + ) else: thought_content = str(msg.content) - + logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===") if thought_content: - logger.info(f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}") - + logger.info( + f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}" + ) + # 记录每个工具调用的详细信息 for idx, tool_call in enumerate(msg.tool_calls, 1): tool_name = tool_call.func_name tool_args = tool_call.args or {} tool_call_id = tool_call.call_id tool_result = tool_results_map.get(tool_call_id, "未找到执行结果") - + # 格式化参数 try: args_str = json.dumps(tool_args, ensure_ascii=False, indent=2) if tool_args else "无参数" except Exception: args_str = str(tool_args) - + logger.info(f"[dream][工具调用详情] --- 工具 {idx}: {tool_name} ---") logger.info(f"[dream][工具调用详情] 调用参数:\n{args_str}") logger.info(f"[dream][工具调用详情] 执行结果:\n{tool_result}") logger.info(f"[dream][工具调用详情] {'-' * 60}") - + logger.info(f"[dream][工具调用详情] 共记录了 {tool_call_count} 组工具调用操作") - + # 第三步:构建对话历史摘要(用于生成梦境) conversation_summary = [] for msg in conversation_messages: @@ -145,11 +149,11 @@ async def generate_dream_summary( content = "" if msg.content: content = msg.content[0].text if isinstance(msg.content, list) and msg.content else str(msg.content) - + if role == "user" and "轮次信息" in content: # 跳过轮次信息消息 continue - + if role == "assistant": # 只保留思考内容,简化工具调用信息 if content: @@ -162,13 +166,13 @@ async def generate_dream_summary( # 截取前300字符 content_preview = content[:300] + ("..." if len(content) > 300 else "") conversation_summary.append(f"[工具执行] {content_preview}") - + conversation_text = "\n".join(conversation_summary[-20:]) # 只保留最后20条消息 - + # 随机选择2个梦境风格 selected_styles = get_random_dream_styles(2) - dream_styles_text = "\n".join([f"{i+1}. {style}" for i, style in enumerate(selected_styles)]) - + dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)]) + # 使用 Prompt 管理器格式化梦境生成 prompt dream_prompt = await global_prompt_manager.format_prompt( "dream_summary_prompt", @@ -186,13 +190,14 @@ async def generate_dream_summary( max_tokens=512, temperature=0.8, ) - + if dream_content: logger.info(f"[dream][梦境总结] 对 chat_id={chat_id} 的整理过程梦境:\n{dream_content}") else: logger.warning("[dream][梦境总结] 未能生成梦境总结") - + except Exception as e: logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True) -init_dream_summary_prompt() \ No newline at end of file + +init_dream_summary_prompt() diff --git a/src/dream/tools/__init__.py b/src/dream/tools/__init__.py index ef3b8be3..cd784b02 100644 --- a/src/dream/tools/__init__.py +++ b/src/dream/tools/__init__.py @@ -4,8 +4,3 @@ dream agent 工具实现模块。 每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数 生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。 """ - - - - - diff --git a/src/dream/tools/create_chat_history_tool.py b/src/dream/tools/create_chat_history_tool.py index dde4ffaf..551b7ec6 100644 --- a/src/dream/tools/create_chat_history_tool.py +++ b/src/dream/tools/create_chat_history_tool.py @@ -60,8 +60,3 @@ def make_create_chat_history(chat_id: str): return f"create_chat_history 执行失败: {e}" return create_chat_history - - - - - diff --git a/src/dream/tools/delete_chat_history_tool.py b/src/dream/tools/delete_chat_history_tool.py index e7b9755d..18c32f27 100644 --- a/src/dream/tools/delete_chat_history_tool.py +++ b/src/dream/tools/delete_chat_history_tool.py @@ -23,8 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用, return f"delete_chat_history 执行失败: {e}" return delete_chat_history - - - - - diff --git a/src/dream/tools/delete_jargon_tool.py b/src/dream/tools/delete_jargon_tool.py index 0ff1c218..8edd3245 100644 --- a/src/dream/tools/delete_jargon_tool.py +++ b/src/dream/tools/delete_jargon_tool.py @@ -23,8 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留 return f"delete_jargon 执行失败: {e}" return delete_jargon - - - - - diff --git a/src/dream/tools/finish_maintenance_tool.py b/src/dream/tools/finish_maintenance_tool.py index 66f8c99e..403b6c6e 100644 --- a/src/dream/tools/finish_maintenance_tool.py +++ b/src/dream/tools/finish_maintenance_tool.py @@ -14,8 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用, return msg return finish_maintenance - - - - - diff --git a/src/dream/tools/get_chat_history_detail_tool.py b/src/dream/tools/get_chat_history_detail_tool.py index 4f9e16b7..92f1d4d9 100644 --- a/src/dream/tools/get_chat_history_detail_tool.py +++ b/src/dream/tools/get_chat_history_detail_tool.py @@ -1,5 +1,4 @@ import time -from typing import Optional from src.common.logger import get_logger from src.common.database.database_model import ChatHistory @@ -20,14 +19,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用 # 将时间戳转换为可读时间格式 start_time_str = ( - time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) - if record.start_time - else "未知" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知" ) end_time_str = ( - time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) - if record.end_time - else "未知" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知" ) result = ( @@ -40,17 +35,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用 f"概括={record.summary or '无'}\n" f"关键信息={record.key_point or '无'}" ) - logger.debug( - f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}" - ) + logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}") return result except Exception as e: logger.error(f"get_chat_history_detail 失败: {e}") return f"get_chat_history_detail 执行失败: {e}" return get_chat_history_detail - - - - - diff --git a/src/dream/tools/search_chat_history_tool.py b/src/dream/tools/search_chat_history_tool.py index 105f6676..5d216f00 100644 --- a/src/dream/tools/search_chat_history_tool.py +++ b/src/dream/tools/search_chat_history_tool.py @@ -78,9 +78,7 @@ def make_search_chat_history(chat_id: str): if record.keywords: try: keywords_data = ( - json.loads(record.keywords) - if isinstance(record.keywords, str) - else record.keywords + json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords ) if isinstance(keywords_data, list): record_keywords_list = [str(k).lower() for k in keywords_data] @@ -125,9 +123,7 @@ def make_search_chat_history(chat_id: str): keywords_str = "、".join(keywords_list) if len(keywords_list) > 2: required_count = len(keywords_list) - 1 - return ( - f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录" - ) + return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录" else: return f"未找到包含所有关键词'{keywords_str}'的聊天记录" elif participant: @@ -142,9 +138,7 @@ def make_search_chat_history(chat_id: str): if record.keywords: try: keywords_data = ( - json.loads(record.keywords) - if isinstance(record.keywords, str) - else record.keywords + json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords ) if isinstance(keywords_data, list): for k in keywords_data: @@ -160,13 +154,13 @@ def make_search_chat_history(chat_id: str): keywords_str = "、".join(sorted(all_keywords_set)) response_text = ( f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" - f"有关\"{search_label}\"的关键词:\n" + f'有关"{search_label}"的关键词:\n' f"{keywords_str}" ) else: response_text = ( f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" - f"有关\"{search_label}\"的关键词信息为空" + f'有关"{search_label}"的关键词信息为空' ) logger.info( @@ -192,9 +186,7 @@ def make_search_chat_history(chat_id: str): if record.keywords: try: keywords_data = ( - json.loads(record.keywords) - if isinstance(record.keywords, str) - else record.keywords + json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords ) if isinstance(keywords_data, list) and keywords_data: keywords_str = "、".join([str(k) for k in keywords_data]) @@ -220,8 +212,3 @@ def make_search_chat_history(chat_id: str): return f"search_chat_history 执行失败: {e}" return search_chat_history - - - - - diff --git a/src/dream/tools/search_jargon_tool.py b/src/dream/tools/search_jargon_tool.py index 0429a6a7..139536ac 100644 --- a/src/dream/tools/search_jargon_tool.py +++ b/src/dream/tools/search_jargon_tool.py @@ -16,9 +16,7 @@ def make_search_jargon(chat_id: str): if not keyword or not keyword.strip(): return "未指定查询关键词(参数 keyword 为必填,且不能为空)" - logger.info( - f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})" - ) + logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})") # 基础条件:只查 is_jargon=True 的记录 query = Jargon.select().where(Jargon.is_jargon) @@ -102,5 +100,3 @@ def make_search_jargon(chat_id: str): return f"search_jargon 执行失败: {e}" return search_jargon - - diff --git a/src/dream/tools/update_chat_history_tool.py b/src/dream/tools/update_chat_history_tool.py index 8c0caf63..a65e78a7 100644 --- a/src/dream/tools/update_chat_history_tool.py +++ b/src/dream/tools/update_chat_history_tool.py @@ -49,8 +49,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用, return f"update_chat_history 执行失败: {e}" return update_chat_history - - - - - diff --git a/src/dream/tools/update_jargon_tool.py b/src/dream/tools/update_jargon_tool.py index b50504ae..1d559cf6 100644 --- a/src/dream/tools/update_jargon_tool.py +++ b/src/dream/tools/update_jargon_tool.py @@ -49,8 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留 return f"update_jargon 执行失败: {e}" return update_jargon - - - - - diff --git a/src/hippo_memorizer/chat_history_summarizer.py b/src/hippo_memorizer/chat_history_summarizer.py index 456a0f2e..241a2af8 100644 --- a/src/hippo_memorizer/chat_history_summarizer.py +++ b/src/hippo_memorizer/chat_history_summarizer.py @@ -316,7 +316,9 @@ class ChatHistorySummarizer: before_count = len(self.current_batch.messages) self.current_batch.messages.extend(new_messages) self.current_batch.end_time = current_time - logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息") + logger.info( + f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息" + ) # 更新批次后持久化 self._persist_topic_cache() else: @@ -362,9 +364,7 @@ class ChatHistorySummarizer: else: time_str = f"{time_since_last_check / 3600:.1f}小时" - logger.debug( - f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}" - ) + logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}") # 检查“话题检查”触发条件 should_check = False @@ -414,7 +414,7 @@ class ChatHistorySummarizer: # 说明 bot 没有参与这段对话,不应该记录 bot_user_id = str(global_config.bot.qq_account) has_bot_message = False - + for msg in messages: if msg.user_info.user_id == bot_user_id: has_bot_message = True @@ -427,7 +427,9 @@ class ChatHistorySummarizer: return # 2. 构造编号后的消息字符串和参与者信息 - numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages) + numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = ( + self._build_numbered_messages_for_llm(messages) + ) # 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次) existing_topics = list(self.topic_cache.keys()) @@ -456,9 +458,7 @@ class ChatHistorySummarizer: ) if not success or not topic_to_indices: - logger.error( - f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃" - ) + logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃") # 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks(保持原状) return @@ -610,9 +610,7 @@ class ChatHistorySummarizer: if not numbered_lines: return False, {} - history_topics_block = ( - "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)" - ) + history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)" messages_block = "\n".join(numbered_lines) prompt = await global_prompt_manager.format_prompt( @@ -635,17 +633,17 @@ class ChatHistorySummarizer: json_str = None json_pattern = r"```json\s*(.*?)\s*```" matches = re.findall(json_pattern, response, re.DOTALL) - + if matches: # 找到JSON代码块,使用第一个匹配 json_str = matches[0].strip() else: # 如果没有找到代码块,尝试查找JSON数组的开始和结束位置 # 查找第一个 [ 和最后一个 ] - start_idx = response.find('[') - end_idx = response.rfind(']') + start_idx = response.find("[") + end_idx = response.rfind("]") if start_idx != -1 and end_idx != -1 and end_idx > start_idx: - json_str = response[start_idx:end_idx + 1].strip() + json_str = response[start_idx : end_idx + 1].strip() else: # 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记) json_str = response.strip() @@ -942,4 +940,3 @@ class ChatHistorySummarizer: init_prompt() - diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 7fb4cfd0..df22c9cd 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -49,7 +49,7 @@ class LLMRequest: def _check_slow_request(self, time_cost: float, model_name: str) -> None: """检查请求是否过慢并输出警告日志 - + Args: time_cost: 请求耗时(秒) model_name: 使用的模型名称 @@ -323,7 +323,7 @@ class LLMRequest: effective_temperature = (model_info.extra_params or {}).get("temperature") if effective_temperature is None: effective_temperature = self.model_for_task.temperature - + # max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置 effective_max_tokens = max_tokens if effective_max_tokens is None: @@ -332,7 +332,7 @@ class LLMRequest: effective_max_tokens = (model_info.extra_params or {}).get("max_tokens") if effective_max_tokens is None: effective_max_tokens = self.model_for_task.max_tokens - + return await client.get_response( model_info=model_info, message_list=(compressed_messages or message_list), @@ -366,7 +366,9 @@ class LLMRequest: logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}") raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e - logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}") + logger.warning( + f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}" + ) await asyncio.sleep(api_provider.retry_interval) except NetworkConnectionError as e: @@ -394,7 +396,9 @@ class LLMRequest: if e.status_code == 429 or e.status_code >= 500: retry_remain -= 1 if retry_remain <= 0: - logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}") + logger.error( + f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}" + ) raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e logger.warning( @@ -540,7 +544,5 @@ class LLMRequest: if e.__cause__: original_error_type = type(e.__cause__).__name__ original_error_msg = str(e.__cause__) - return ( - f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}" - ) + return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}" return "" diff --git a/src/main.py b/src/main.py index d5426f5c..b6141ad4 100644 --- a/src/main.py +++ b/src/main.py @@ -113,7 +113,6 @@ class MainSystem: get_emoji_manager().initialize() logger.info("表情包管理器初始化成功") - # 初始化聊天管理器 await get_chat_manager()._initialize() asyncio.create_task(get_chat_manager()._auto_save_task()) diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 248e3062..9aafb9ef 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -136,8 +136,6 @@ def init_memory_retrieval_prompt(): ) - - def _log_conversation_messages( conversation_messages: List[Message], head_prompt: Optional[str] = None, @@ -172,7 +170,9 @@ def _log_conversation_messages( # 构建单条消息的日志信息 # msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------" - msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------" + msg_info = ( + f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------" + ) # if full_content: # msg_info += f"\n{full_content}" @@ -185,8 +185,7 @@ def _log_conversation_messages( msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}" # if msg.tool_call_id: - # msg_info += f"\n 工具调用ID: {msg.tool_call_id}" - + # msg_info += f"\n 工具调用ID: {msg.tool_call_id}" log_lines.append(msg_info) @@ -330,7 +329,7 @@ async def _react_agent_solve_question( remaining_iterations=remaining_iterations, max_iterations=max_iterations, ) - + # 后续迭代都复用第一次构建的head_prompt head_prompt = first_head_prompt @@ -365,7 +364,7 @@ async def _react_agent_solve_question( ) # logger.info( - # f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}" + # f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}" # ) if not success: @@ -409,20 +408,20 @@ async def _react_agent_solve_question( """从文本中解析finish_search函数调用,返回(found_answer, answer)元组,如果未找到则返回(None, None)""" if not text: return None, None - + # 查找finish_search函数调用位置(不区分大小写) func_pattern = "finish_search" text_lower = text.lower() func_pos = text_lower.find(func_pattern) if func_pos == -1: return None, None - + # 查找函数调用的开始和结束位置 # 从func_pos开始向后查找左括号 start_pos = text.find("(", func_pos) if start_pos == -1: return None, None - + # 查找匹配的右括号(考虑嵌套) paren_count = 0 end_pos = start_pos @@ -437,10 +436,10 @@ async def _react_agent_solve_question( else: # 没有找到匹配的右括号 return None, None - + # 提取函数参数部分 params_text = text[start_pos + 1 : end_pos] - + # 解析found_answer参数(布尔值,可能是true/false/True/False) found_answer = None found_answer_patterns = [ @@ -454,49 +453,60 @@ async def _react_agent_solve_question( if match: found_answer = "true" in match.group(0).lower() break - + # 解析answer参数(字符串,使用extract_quoted_content) answer = extract_quoted_content(text, "finish_search", "answer") - + return found_answer, answer - + parsed_found_answer, parsed_answer = parse_finish_search_from_text(response) - + if parsed_found_answer is not None: # 检测到finish_search函数调用格式 if parsed_found_answer: # 找到了答案 if parsed_answer: - step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": parsed_answer}}) + step["actions"].append( + { + "action_type": "finish_search", + "action_params": {"found_answer": True, "answer": parsed_answer}, + } + ) step["observations"] = ["检测到finish_search文本格式调用,找到答案"] thinking_steps.append(step) - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}") - + logger.info( + f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}" + ) + _log_conversation_messages( conversation_messages, head_prompt=first_head_prompt, final_status=f"找到答案:{parsed_answer}", ) - + return True, parsed_answer, thinking_steps, False else: # found_answer为True但没有提供answer,视为错误,继续迭代 - logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer") + logger.warning( + f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer" + ) else: # 未找到答案,直接退出查询 - step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}}) + step["actions"].append( + {"action_type": "finish_search", "action_params": {"found_answer": False}} + ) step["observations"] = ["检测到finish_search文本格式调用,未找到答案"] thinking_steps.append(step) logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案") - + _log_conversation_messages( conversation_messages, head_prompt=first_head_prompt, final_status="未找到答案:通过finish_search文本格式判断未找到答案", ) - + return False, "", thinking_steps, False - + # 如果没有检测到finish_search格式,记录思考过程,继续下一轮迭代 step["observations"] = [f"思考完成,但未调用工具。响应: {response}"] logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}") @@ -514,44 +524,53 @@ async def _react_agent_solve_question( for tool_call in tool_calls: tool_name = tool_call.func_name tool_args = tool_call.args or {} - + if tool_name == "finish_search": finish_search_found = tool_args.get("found_answer", False) finish_search_answer = tool_args.get("answer", "") - + if finish_search_found: # 找到了答案 if finish_search_answer: - step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": finish_search_answer}}) + step["actions"].append( + { + "action_type": "finish_search", + "action_params": {"found_answer": True, "answer": finish_search_answer}, + } + ) step["observations"] = ["检测到finish_search工具调用,找到答案"] thinking_steps.append(step) - logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}") - + logger.info( + f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}" + ) + _log_conversation_messages( conversation_messages, head_prompt=first_head_prompt, final_status=f"找到答案:{finish_search_answer}", ) - + return True, finish_search_answer, thinking_steps, False else: # found_answer为True但没有提供answer,视为错误 - logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer") + logger.warning( + f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer" + ) else: # 未找到答案,直接退出查询 step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}}) step["observations"] = ["检测到finish_search工具调用,未找到答案"] thinking_steps.append(step) logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案") - + _log_conversation_messages( conversation_messages, head_prompt=first_head_prompt, final_status="未找到答案:通过finish_search工具判断未找到答案", ) - + return False, "", thinking_steps, False - + # 如果没有finish_search工具调用,继续处理其他工具 tool_tasks = [] for i, tool_call in enumerate(tool_calls): @@ -627,7 +646,7 @@ async def _react_agent_solve_question( observation_text += f"\n\n{jargon_info}" collected_info += f"\n{jargon_info}\n" logger.info(f"工具输出触发黑话解析: {new_concepts}") - + tool_builder = MessageBuilder() tool_builder.set_role(RoleType.Tool) tool_builder.add_text_content(observation_text) @@ -645,7 +664,7 @@ async def _react_agent_solve_question( elif iteration + 1 >= max_iterations: should_do_final_evaluation = True logger.info(f"ReAct Agent达到最大迭代次数(已迭代{iteration + 1}次),进入最终评估") - + if should_do_final_evaluation: # 获取必要变量用于最终评估 tool_registry = get_tool_registry() @@ -653,7 +672,7 @@ async def _react_agent_solve_question( time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) current_iteration = iteration + 1 remaining_iterations = 0 - + # 提取函数调用中参数的值,支持单引号和双引号 def extract_quoted_content(text, func_name, param_name): """从文本中提取函数调用中参数的值,支持单引号和双引号 @@ -724,7 +743,13 @@ async def _react_agent_solve_question( max_iterations=max_iterations, ) - eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools( + ( + eval_success, + eval_response, + eval_reasoning_content, + eval_model_name, + eval_tool_calls, + ) = await llm_api.generate_with_model_with_tools( evaluation_prompt, model_config=model_config.model_task_config.tool_use, tool_options=[], # 最终评估阶段不提供工具 @@ -739,7 +764,7 @@ async def _react_agent_solve_question( final_status="未找到答案:最终评估阶段LLM调用失败", ) return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout - + if global_config.debug.show_memory_prompt: logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}") logger.info(f"ReAct Agent 最终评估响应: {eval_response}") @@ -759,17 +784,17 @@ async def _react_agent_solve_question( "iteration": current_iteration, "thought": f"[最终评估] {eval_response}", "actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}], - "observations": ["最终评估阶段检测到found_answer"] + "observations": ["最终评估阶段检测到found_answer"], } thinking_steps.append(eval_step) logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}") - + _log_conversation_messages( conversation_messages, head_prompt=first_head_prompt, final_status=f"找到答案:{found_answer_content}", ) - + return True, found_answer_content, thinking_steps, False # 如果评估为not_enough_info,返回空字符串(不返回任何信息) @@ -778,35 +803,37 @@ async def _react_agent_solve_question( "iteration": current_iteration, "thought": f"[最终评估] {eval_response}", "actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}], - "observations": ["最终评估阶段检测到not_enough_info"] + "observations": ["最终评估阶段检测到not_enough_info"], } thinking_steps.append(eval_step) logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}") - + _log_conversation_messages( conversation_messages, head_prompt=first_head_prompt, final_status=f"未找到答案:{not_enough_info_reason}", ) - + return False, "", thinking_steps, is_timeout # 如果没有明确判断,视为not_enough_info,返回空字符串(不返回任何信息) eval_step = { "iteration": current_iteration, "thought": f"[最终评估] {eval_response}", - "actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}], - "observations": ["已到达最大迭代次数,无法找到答案"] + "actions": [ + {"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}} + ], + "observations": ["已到达最大迭代次数,无法找到答案"], } thinking_steps.append(eval_step) logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案") - + _log_conversation_messages( conversation_messages, head_prompt=first_head_prompt, final_status="未找到答案:已到达最大迭代次数,无法找到答案", ) - + return False, "", thinking_steps, is_timeout # 如果正常迭代过程中提前找到答案返回,不会到达这里 @@ -817,7 +844,7 @@ async def _react_agent_solve_question( head_prompt=first_head_prompt, final_status="未找到答案:正常迭代结束", ) - + return False, "", thinking_steps, is_timeout @@ -1129,7 +1156,9 @@ async def build_memory_retrieval_prompt( else: max_iterations = base_max_iterations timeout_seconds = global_config.memory.agent_timeout_seconds - logger.debug(f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒") + logger.debug( + f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒" + ) # 并行处理所有问题,将概念检索结果作为初始信息传递 question_tasks = [ @@ -1157,10 +1186,10 @@ async def build_memory_retrieval_prompt( # 获取最近10分钟内已找到答案的缓存记录 cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0) - + # 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果) all_results = [] - + # 先添加当前查询的结果 current_questions = set() for result in question_results: @@ -1170,7 +1199,7 @@ async def build_memory_retrieval_prompt( if question_end != -1: current_questions.add(result[4:question_end]) all_results.append(result) - + # 添加缓存答案(排除当前查询中已存在的问题) for cached_answer in cached_answers: if cached_answer.startswith("问题:"): @@ -1198,4 +1227,3 @@ async def build_memory_retrieval_prompt( except Exception as e: logger.error(f"记忆检索时发生异常: {str(e)}") return "" - diff --git a/src/memory_system/memory_utils.py b/src/memory_system/memory_utils.py index 7aa33a52..9886142c 100644 --- a/src/memory_system/memory_utils.py +++ b/src/memory_system/memory_utils.py @@ -17,7 +17,6 @@ from src.common.logger import get_logger logger = get_logger("memory_utils") - def parse_questions_json(response: str) -> Tuple[List[str], List[str]]: """解析问题JSON,返回概念列表和问题列表 @@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]: logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...") return [], [] + def parse_datetime_to_timestamp(value: str) -> float: """ 接受多种常见格式并转换为时间戳(秒) diff --git a/src/memory_system/retrieval_tools/found_answer.py b/src/memory_system/retrieval_tools/found_answer.py index 148424b9..bbed96b3 100644 --- a/src/memory_system/retrieval_tools/found_answer.py +++ b/src/memory_system/retrieval_tools/found_answer.py @@ -47,4 +47,3 @@ def register_tool(): ], execute_func=finish_search, ) - diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py index 5cf9a8ed..351d0606 100644 --- a/src/memory_system/retrieval_tools/query_chat_history.py +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -16,9 +16,7 @@ from .tool_registry import register_memory_retrieval_tool logger = get_logger("memory_retrieval_tools") -async def search_chat_history( - chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None -) -> str: +async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str: """根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords Args: @@ -117,7 +115,7 @@ async def search_chat_history( ) if kw_matched: matched_count += 1 - + # 计算需要匹配的关键词数量 total_keywords = len(keywords_lower) if total_keywords > 2: @@ -126,7 +124,7 @@ async def search_chat_history( else: # 关键词数量<=2,必须全部匹配 required_matches = total_keywords - + keyword_matched = matched_count >= required_matches # 两者都匹配(如果同时有participant和keyword,需要两者都匹配;如果只有一个条件,只需要该条件匹配) @@ -144,7 +142,9 @@ async def search_chat_history( keywords_list = parse_keywords_string(keyword) if len(keywords_list) > 2: required_count = len(keywords_list) - 1 - return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录" + return ( + f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录" + ) else: return f"未找到包含所有关键词'{keywords_str}'的聊天记录" elif participant: @@ -160,9 +160,7 @@ async def search_chat_history( if record.keywords: try: keywords_data = ( - json.loads(record.keywords) - if isinstance(record.keywords, str) - else record.keywords + json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords ) if isinstance(keywords_data, list): for k in keywords_data: @@ -179,13 +177,12 @@ async def search_chat_history( keywords_str = "、".join(sorted(all_keywords_set)) return ( f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" - f"有关\"{search_label}\"的关键词:\n" + f'有关"{search_label}"的关键词:\n' f"{keywords_str}" ) else: return ( - f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" - f"有关\"{search_label}\"的关键词信息为空" + f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空' ) # 构建结果文本,返回id、theme和keywords(最多20条) diff --git a/src/webui/auth.py b/src/webui/auth.py index 8d52a5e3..804cef55 100644 --- a/src/webui/auth.py +++ b/src/webui/auth.py @@ -22,42 +22,42 @@ def get_current_token( ) -> str: """ 获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取 - + Args: request: FastAPI Request 对象 maibot_session: Cookie 中的 token authorization: Authorization Header (Bearer token) - + Returns: 验证通过的 token - + Raises: HTTPException: 认证失败时抛出 401 错误 """ token = None - + # 优先从 Cookie 获取 if maibot_session: token = maibot_session # 其次从 Header 获取(兼容旧版本) elif authorization and authorization.startswith("Bearer "): token = authorization.replace("Bearer ", "") - + if not token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + # 验证 token token_manager = get_token_manager() if not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="Token 无效或已过期") - + return token def set_auth_cookie(response: Response, token: str) -> None: """ 设置认证 Cookie - + Args: response: FastAPI Response 对象 token: 要设置的 token @@ -77,7 +77,7 @@ def set_auth_cookie(response: Response, token: str) -> None: def clear_auth_cookie(response: Response) -> None: """ 清除认证 Cookie - + Args: response: FastAPI Response 对象 """ @@ -96,32 +96,32 @@ def verify_auth_token_from_cookie_or_header( ) -> bool: """ 验证认证 Token,支持从 Cookie 或 Header 获取 - + Args: maibot_session: Cookie 中的 token authorization: Authorization header (Bearer token) - + Returns: 验证成功返回 True - + Raises: HTTPException: 认证失败时抛出 401 错误 """ token = None - + # 优先从 Cookie 获取 if maibot_session: token = maibot_session # 其次从 Header 获取(兼容旧版本) elif authorization and authorization.startswith("Bearer "): token = authorization.replace("Bearer ", "") - + if not token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + # 验证 token token_manager = get_token_manager() if not token_manager.verify_token(token): raise HTTPException(status_code=401, detail="Token 无效或已过期") - + return True diff --git a/src/webui/chat_routes.py b/src/webui/chat_routes.py index 14d8d9d2..5e492cb2 100644 --- a/src/webui/chat_routes.py +++ b/src/webui/chat_routes.py @@ -63,14 +63,14 @@ class ChatHistoryManager: def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]: """将数据库消息转换为前端格式 - + Args: msg: 数据库消息对象 group_id: 群 ID,用于判断是否是虚拟群 """ # 判断是否是机器人消息 user_id = msg.user_id or "" - + # 对于虚拟群,通过比较机器人 QQ 账号来判断 # 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头 if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX): @@ -414,7 +414,9 @@ async def websocket_chat( group_id=virtual_group_id, group_name=group_name or "WebUI虚拟群聊", ) - logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}") + logger.info( + f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}" + ) except Exception as e: logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}") diff --git a/src/webui/emoji_routes.py b/src/webui/emoji_routes.py index 0784a26a..90b2d60b 100644 --- a/src/webui/emoji_routes.py +++ b/src/webui/emoji_routes.py @@ -1,4 +1,4 @@ -""" 表情包管理 API 路由""" +"""表情包管理 API 路由""" from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie from fastapi.responses import FileResponse, JSONResponse @@ -48,7 +48,7 @@ def _get_thumbnail_lock(file_hash: str) -> threading.Lock: def _background_generate_thumbnail(source_path: str, file_hash: str) -> None: """ 后台生成缩略图(在线程池中执行) - + 生成完成后自动从 generating 集合中移除 """ try: @@ -74,14 +74,14 @@ def _get_thumbnail_cache_path(file_hash: str) -> Path: def _generate_thumbnail(source_path: str, file_hash: str) -> Path: """ 生成缩略图并保存到缓存目录 - + Args: source_path: 原图路径 file_hash: 文件哈希值,用作缓存文件名 - + Returns: 缩略图路径 - + Features: - GIF: 提取第一帧作为缩略图 - 所有格式统一转为 WebP @@ -89,63 +89,63 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path: """ _ensure_thumbnail_cache_dir() cache_path = _get_thumbnail_cache_path(file_hash) - + # 使用锁防止并发生成同一缩略图 lock = _get_thumbnail_lock(file_hash) with lock: # 双重检查,可能在等待锁时已被其他线程生成 if cache_path.exists(): return cache_path - + try: with Image.open(source_path) as img: # GIF 处理:提取第一帧 - if hasattr(img, 'n_frames') and img.n_frames > 1: + if hasattr(img, "n_frames") and img.n_frames > 1: img.seek(0) # 确保在第一帧 - + # 转换为 RGB/RGBA(WebP 支持透明度) - if img.mode in ('P', 'PA'): + if img.mode in ("P", "PA"): # 调色板模式转换为 RGBA 以保留透明度 - img = img.convert('RGBA') - elif img.mode == 'LA': - img = img.convert('RGBA') - elif img.mode not in ('RGB', 'RGBA'): - img = img.convert('RGB') - + img = img.convert("RGBA") + elif img.mode == "LA": + img = img.convert("RGBA") + elif img.mode not in ("RGB", "RGBA"): + img = img.convert("RGB") + # 创建缩略图(保持宽高比) img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS) - + # 保存为 WebP 格式 - img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6) - + img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6) + logger.debug(f"生成缩略图: {file_hash} -> {cache_path}") - + except Exception as e: logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图") # 生成失败时不创建缓存文件,下次会重试 raise - + return cache_path def cleanup_orphaned_thumbnails() -> tuple[int, int]: """ 清理孤立的缩略图缓存(原图已不存在的缩略图) - + Returns: (清理数量, 保留数量) """ if not THUMBNAIL_CACHE_DIR.exists(): return 0, 0 - + # 获取所有表情包的哈希值 valid_hashes = set() for emoji in Emoji.select(Emoji.emoji_hash): valid_hashes.add(emoji.emoji_hash) - + cleaned = 0 kept = 0 - + for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"): file_hash = cache_file.stem if file_hash not in valid_hashes: @@ -157,12 +157,13 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]: logger.warning(f"清理缩略图失败 {cache_file.name}: {e}") else: kept += 1 - + if cleaned > 0: logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个") - + return cleaned, kept + # 模块级别的类型别名(解决 B008 ruff 错误) EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")] EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")] @@ -365,7 +366,9 @@ async def get_emoji_list( @router.get("/{emoji_id}", response_model=EmojiDetailResponse) -async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def get_emoji_detail( + emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +): """ 获取表情包详细信息 @@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie @router.patch("/{emoji_id}", response_model=EmojiUpdateResponse) -async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def update_emoji( + emoji_id: int, + request: EmojiUpdateRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 增量更新表情包(只更新提供的字段) @@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio @router.delete("/{emoji_id}", response_model=EmojiDeleteResponse) -async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def delete_emoji( + emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +): """ 删除表情包 @@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz @router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse) -async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def register_emoji( + emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +): """ 注册表情包(快捷操作) @@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N @router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse) -async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def ban_emoji( + emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +): """ 禁用表情包(快捷操作) @@ -633,7 +647,7 @@ async def get_emoji_thumbnail( Returns: 表情包缩略图(WebP 格式)或原图 - + Features: - 懒加载:首次请求时生成缩略图 - 缓存:后续请求直接返回缓存 @@ -643,7 +657,7 @@ async def get_emoji_thumbnail( try: token_manager = get_token_manager() is_valid = False - + # 1. 优先使用 Cookie if maibot_session and token_manager.verify_token(maibot_session): is_valid = True @@ -655,7 +669,7 @@ async def get_emoji_thumbnail( auth_token = authorization.replace("Bearer ", "") if token_manager.verify_token(auth_token): is_valid = True - + if not is_valid: raise HTTPException(status_code=401, detail="Token 无效或已过期") @@ -680,35 +694,27 @@ async def get_emoji_thumbnail( } media_type = mime_types.get(emoji.format.lower(), "application/octet-stream") return FileResponse( - path=emoji.full_path, - media_type=media_type, - filename=f"{emoji.emoji_hash}.{emoji.format}" + path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}" ) # 尝试获取或生成缩略图 cache_path = _get_thumbnail_cache_path(emoji.emoji_hash) - + # 检查缓存是否存在 if cache_path.exists(): # 缓存命中,直接返回 return FileResponse( - path=str(cache_path), - media_type="image/webp", - filename=f"{emoji.emoji_hash}_thumb.webp" + path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp" ) - + # 缓存未命中,触发后台生成并返回 202 with _generating_lock: if emoji.emoji_hash not in _generating_thumbnails: # 标记为正在生成 _generating_thumbnails.add(emoji.emoji_hash) # 提交到线程池后台生成 - _thumbnail_executor.submit( - _background_generate_thumbnail, - emoji.full_path, - emoji.emoji_hash - ) - + _thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash) + # 返回 202 Accepted,告诉前端缩略图正在生成中 return JSONResponse( status_code=202, @@ -719,7 +725,7 @@ async def get_emoji_thumbnail( }, headers={ "Retry-After": "1", # 建议 1 秒后重试 - } + }, ) except HTTPException: @@ -730,7 +736,11 @@ async def get_emoji_thumbnail( @router.post("/batch/delete", response_model=BatchDeleteResponse) -async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def batch_delete_emojis( + request: BatchDeleteRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 批量删除表情包 @@ -1079,7 +1089,7 @@ async def batch_upload_emoji( class ThumbnailCacheStatsResponse(BaseModel): """缩略图缓存统计响应""" - + success: bool cache_dir: str total_count: int @@ -1090,7 +1100,7 @@ class ThumbnailCacheStatsResponse(BaseModel): class ThumbnailCleanupResponse(BaseModel): """缩略图清理响应""" - + success: bool message: str cleaned_count: int @@ -1099,7 +1109,7 @@ class ThumbnailCleanupResponse(BaseModel): class ThumbnailPreheatResponse(BaseModel): """缩略图预热响应""" - + success: bool message: str generated_count: int @@ -1114,27 +1124,27 @@ async def get_thumbnail_cache_stats( ): """ 获取缩略图缓存统计信息 - + Returns: 缓存目录、缓存数量、总大小、覆盖率等统计信息 """ try: verify_auth_token(maibot_session, authorization) - + _ensure_thumbnail_cache_dir() - + # 统计缓存文件 cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp")) total_count = len(cache_files) total_size = sum(f.stat().st_size for f in cache_files) total_size_mb = round(total_size / (1024 * 1024), 2) - + # 统计表情包总数 emoji_count = Emoji.select().count() - + # 计算覆盖率 coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1) - + return ThumbnailCacheStatsResponse( success=True, cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()), @@ -1143,7 +1153,7 @@ async def get_thumbnail_cache_stats( emoji_count=emoji_count, coverage_percent=coverage_percent, ) - + except HTTPException: raise except Exception as e: @@ -1158,22 +1168,22 @@ async def cleanup_thumbnail_cache( ): """ 清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图) - + Returns: 清理结果 """ try: verify_auth_token(maibot_session, authorization) - + cleaned, kept = cleanup_orphaned_thumbnails() - + return ThumbnailCleanupResponse( success=True, message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存", cleaned_count=cleaned, kept_count=kept, ) - + except HTTPException: raise except Exception as e: @@ -1189,20 +1199,20 @@ async def preheat_thumbnail_cache( ): """ 预热缩略图缓存(提前生成未缓存的缩略图) - + 优先处理使用次数高的表情包 - + Args: limit: 最多预热数量 (1-1000) - + Returns: 预热结果 """ try: verify_auth_token(maibot_session, authorization) - + _ensure_thumbnail_cache_dir() - + # 获取使用次数最高的表情包(未缓存的优先) emojis = ( Emoji.select() @@ -1210,41 +1220,36 @@ async def preheat_thumbnail_cache( .order_by(Emoji.usage_count.desc()) .limit(limit * 2) # 多查一些,因为有些可能已缓存 ) - + generated = 0 skipped = 0 failed = 0 - + for emoji in emojis: if generated >= limit: break - + cache_path = _get_thumbnail_cache_path(emoji.emoji_hash) - + # 已缓存,跳过 if cache_path.exists(): skipped += 1 continue - + # 原文件不存在,跳过 if not os.path.exists(emoji.full_path): failed += 1 continue - + try: # 使用线程池异步生成缩略图,避免阻塞事件循环 loop = asyncio.get_event_loop() - await loop.run_in_executor( - _thumbnail_executor, - _generate_thumbnail, - emoji.full_path, - emoji.emoji_hash - ) + await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash) generated += 1 except Exception as e: logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}") failed += 1 - + return ThumbnailPreheatResponse( success=True, message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed} 个", @@ -1252,7 +1257,7 @@ async def preheat_thumbnail_cache( skipped_count=skipped, failed_count=failed, ) - + except HTTPException: raise except Exception as e: @@ -1267,13 +1272,13 @@ async def clear_all_thumbnail_cache( ): """ 清空所有缩略图缓存(下次访问时会重新生成) - + Returns: 清理结果 """ try: verify_auth_token(maibot_session, authorization) - + if not THUMBNAIL_CACHE_DIR.exists(): return ThumbnailCleanupResponse( success=True, @@ -1281,7 +1286,7 @@ async def clear_all_thumbnail_cache( cleaned_count=0, kept_count=0, ) - + cleaned = 0 for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"): try: @@ -1289,16 +1294,16 @@ async def clear_all_thumbnail_cache( cleaned += 1 except Exception as e: logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}") - + logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件") - + return ThumbnailCleanupResponse( success=True, message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件", cleaned_count=cleaned, kept_count=0, ) - + except HTTPException: raise except Exception as e: diff --git a/src/webui/expression_routes.py b/src/webui/expression_routes.py index 1fa0ce7d..b6ec76b3 100644 --- a/src/webui/expression_routes.py +++ b/src/webui/expression_routes.py @@ -256,7 +256,9 @@ async def get_expression_list( @router.get("/{expression_id}", response_model=ExpressionDetailResponse) -async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def get_expression_detail( + expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +): """ 获取表达方式详细信息 @@ -285,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str @router.post("/", response_model=ExpressionCreateResponse) -async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def create_expression( + request: ExpressionCreateRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 创建新的表达方式 @@ -326,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op @router.patch("/{expression_id}", response_model=ExpressionUpdateResponse) async def update_expression( - expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) + expression_id: int, + request: ExpressionUpdateRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), ): """ 增量更新表达方式(只更新提供的字段) @@ -376,7 +385,9 @@ async def update_expression( @router.delete("/{expression_id}", response_model=ExpressionDeleteResponse) -async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def delete_expression( + expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +): """ 删除表达方式 @@ -419,7 +430,11 @@ class BatchDeleteRequest(BaseModel): @router.post("/batch/delete", response_model=ExpressionDeleteResponse) -async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def batch_delete_expressions( + request: BatchDeleteRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 批量删除表达方式 @@ -460,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: @router.get("/stats/summary") -async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def get_expression_stats( + maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +): """ 获取表达方式统计数据 diff --git a/src/webui/jargon_routes.py b/src/webui/jargon_routes.py index 318912e8..8d372688 100644 --- a/src/webui/jargon_routes.py +++ b/src/webui/jargon_routes.py @@ -24,7 +24,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]: """ if not chat_id_str: return [] - + try: # 尝试解析为 JSON parsed = json.loads(chat_id_str) @@ -49,10 +49,10 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str: 尝试解析 JSON 并查询 ChatStreams 表获取群聊名称 """ stream_ids = parse_chat_id_to_stream_ids(chat_id_str) - + if not stream_ids: return chat_id_str - + # 查询所有 stream_id 对应的名称 names = [] for stream_id in stream_ids: @@ -62,7 +62,7 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str: else: # 如果没找到,显示截断的 stream_id names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id) - + return ", ".join(names) if names else chat_id_str @@ -187,7 +187,7 @@ def jargon_to_dict(jargon: Jargon) -> dict: chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else [] stream_id = stream_ids[0] if stream_ids else None - + return { "id": jargon.id, "content": jargon.content, @@ -277,17 +277,13 @@ async def get_chat_list(): """获取所有有黑话记录的聊天列表""" try: # 获取所有不同的 chat_id - chat_ids = ( - Jargon.select(Jargon.chat_id) - .distinct() - .where(Jargon.chat_id.is_null(False)) - ) + chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)) chat_id_list = [j.chat_id for j in chat_ids if j.chat_id] # 用于按 stream_id 去重 seen_stream_ids: set[str] = set() - + for chat_id in chat_id_list: stream_ids = parse_chat_id_to_stream_ids(chat_id) if stream_ids: @@ -346,12 +342,7 @@ async def get_jargon_stats(): complete_count = Jargon.select().where(Jargon.is_complete).count() # 关联的聊天数量 - chat_count = ( - Jargon.select(Jargon.chat_id) - .distinct() - .where(Jargon.chat_id.is_null(False)) - .count() - ) + chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count() # 按聊天统计 TOP 5 top_chats = ( @@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest): """创建黑话""" try: # 检查是否已存在相同内容的黑话 - existing = Jargon.get_or_none( - (Jargon.content == request.content) & (Jargon.chat_id == request.chat_id) - ) + existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)) if existing: raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话") @@ -527,11 +516,7 @@ async def batch_set_jargon_status( if not ids: raise HTTPException(status_code=400, detail="ID列表不能为空") - updated_count = ( - Jargon.update(is_jargon=is_jargon) - .where(Jargon.id.in_(ids)) - .execute() - ) + updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute() logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}") diff --git a/src/webui/person_routes.py b/src/webui/person_routes.py index 5c039371..9881d44e 100644 --- a/src/webui/person_routes.py +++ b/src/webui/person_routes.py @@ -200,7 +200,9 @@ async def get_person_list( @router.get("/{person_id}", response_model=PersonDetailResponse) -async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def get_person_detail( + person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +): """ 获取人物详细信息 @@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook @router.patch("/{person_id}", response_model=PersonUpdateResponse) -async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def update_person( + person_id: str, + request: PersonUpdateRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 增量更新人物信息(只更新提供的字段) @@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses @router.delete("/{person_id}", response_model=PersonDeleteResponse) -async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def delete_person( + person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +): """ 删除人物信息 @@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori @router.post("/batch/delete", response_model=BatchDeleteResponse) -async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): +async def batch_delete_persons( + request: BatchDeleteRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +): """ 批量删除人物信息 diff --git a/src/webui/plugin_routes.py b/src/webui/plugin_routes.py index a99d36f5..00c960d3 100644 --- a/src/webui/plugin_routes.py +++ b/src/webui/plugin_routes.py @@ -125,6 +125,7 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No """ 根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。 """ + def _is_list_type(tp: Any) -> bool: origin = get_origin(tp) return tp is list or origin is list @@ -313,7 +314,9 @@ async def check_git_status() -> GitStatusResponse: @router.get("/mirrors", response_model=AvailableMirrorsResponse) -async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse: +async def get_available_mirrors( + maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +) -> AvailableMirrorsResponse: """ 获取所有可用的镜像源配置 """ @@ -343,7 +346,9 @@ async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), au @router.post("/mirrors", response_model=MirrorConfigResponse) -async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse: +async def add_mirror( + request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +) -> MirrorConfigResponse: """ 添加新的镜像源 """ @@ -383,7 +388,10 @@ async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = @router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse) async def update_mirror( - mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) + mirror_id: str, + request: UpdateMirrorRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), ) -> MirrorConfigResponse: """ 更新镜像源配置 @@ -426,7 +434,9 @@ async def update_mirror( @router.delete("/mirrors/{mirror_id}") -async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def delete_mirror( + mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +) -> Dict[str, Any]: """ 删除镜像源 """ @@ -449,7 +459,9 @@ async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(N @router.post("/fetch-raw", response_model=FetchRawFileResponse) async def fetch_raw_file( - request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) + request: FetchRawFileRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), ) -> FetchRawFileResponse: """ 获取 GitHub 仓库的 Raw 文件内容 @@ -534,7 +546,9 @@ async def fetch_raw_file( @router.post("/clone", response_model=CloneRepositoryResponse) async def clone_repository( - request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) + request: CloneRepositoryRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), ) -> CloneRepositoryResponse: """ 克隆 GitHub 仓库到本地 @@ -574,7 +588,11 @@ async def clone_repository( @router.post("/install") -async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def install_plugin( + request: InstallPluginRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> Dict[str, Any]: """ 安装插件 @@ -778,7 +796,9 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional @router.post("/uninstall") async def uninstall_plugin( - request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) + request: UninstallPluginRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), ) -> Dict[str, Any]: """ 卸载插件 @@ -913,7 +933,11 @@ async def uninstall_plugin( @router.post("/update") -async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def update_plugin( + request: UpdatePluginRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> Dict[str, Any]: """ 更新插件 @@ -1132,7 +1156,9 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s @router.get("/installed") -async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def get_installed_plugins( + maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +) -> Dict[str, Any]: """ 获取已安装的插件列表 @@ -1272,7 +1298,9 @@ class UpdatePluginConfigRequest(BaseModel): @router.get("/config/{plugin_id}/schema") -async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def get_plugin_config_schema( + plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +) -> Dict[str, Any]: """ 获取插件配置 Schema @@ -1405,7 +1433,9 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] @router.get("/config/{plugin_id}") -async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def get_plugin_config( + plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +) -> Dict[str, Any]: """ 获取插件当前配置值 @@ -1461,7 +1491,10 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook @router.put("/config/{plugin_id}") async def update_plugin_config( - plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) + plugin_id: str, + request: UpdatePluginConfigRequest, + maibot_session: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), ) -> Dict[str, Any]: """ 更新插件配置 @@ -1532,7 +1565,9 @@ async def update_plugin_config( @router.post("/config/{plugin_id}/reset") -async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def reset_plugin_config( + plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +) -> Dict[str, Any]: """ 重置插件配置为默认值 @@ -1592,7 +1627,9 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co @router.post("/config/{plugin_id}/toggle") -async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: +async def toggle_plugin( + plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) +) -> Dict[str, Any]: """ 切换插件启用状态 diff --git a/src/webui/routes.py b/src/webui/routes.py index 96942f1d..36ee8b1f 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -139,10 +139,10 @@ async def verify_token(request: TokenVerifyRequest, response: Response): async def logout(response: Response): """ 登出并清除认证 Cookie - + Args: response: FastAPI Response 对象 - + Returns: 登出结果 """ @@ -158,23 +158,23 @@ async def check_auth_status( ): """ 检查当前认证状态(用于前端判断是否已登录) - + Returns: 认证状态 """ try: token = None - + # 优先从 Cookie 获取 if maibot_session: token = maibot_session # 其次从 Header 获取 elif authorization and authorization.startswith("Bearer "): token = authorization.replace("Bearer ", "") - + if not token: return {"authenticated": False} - + token_manager = get_token_manager() if token_manager.verify_token(token): return {"authenticated": True} @@ -211,7 +211,7 @@ async def update_token( current_token = maibot_session elif authorization and authorization.startswith("Bearer "): current_token = authorization.replace("Bearer ", "") - + if not current_token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") @@ -222,7 +222,7 @@ async def update_token( # 更新 token success, message = token_manager.update_token(request.new_token) - + # 如果更新成功,清除 Cookie,要求用户重新登录 if success: clear_auth_cookie(response) @@ -263,7 +263,7 @@ async def regenerate_token( if not current_token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") - + token_manager = get_token_manager() if not token_manager.verify_token(current_token): @@ -271,7 +271,7 @@ async def regenerate_token( # 重新生成 token new_token = token_manager.regenerate_token() - + # 清除 Cookie,要求用户重新登录 clear_auth_cookie(response) @@ -306,7 +306,7 @@ async def get_setup_status( current_token = maibot_session elif authorization and authorization.startswith("Bearer "): current_token = authorization.replace("Bearer ", "") - + if not current_token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") @@ -349,7 +349,7 @@ async def complete_setup( current_token = maibot_session elif authorization and authorization.startswith("Bearer "): current_token = authorization.replace("Bearer ", "") - + if not current_token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") @@ -392,7 +392,7 @@ async def reset_setup( current_token = maibot_session elif authorization and authorization.startswith("Bearer "): current_token = authorization.replace("Bearer ", "") - + if not current_token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") diff --git a/src/webui/token_manager.py b/src/webui/token_manager.py index 17e3f068..bd1e5fbb 100644 --- a/src/webui/token_manager.py +++ b/src/webui/token_manager.py @@ -166,22 +166,22 @@ class TokenManager: str: 新生成的 token """ logger.info("正在重新生成 WebUI Token...") - + # 生成新的 64 位十六进制字符串 new_token = secrets.token_hex(32) - + # 加载现有配置,保留 first_setup_completed 状态 config = self._load_config() old_token = config.get("access_token", "")[:8] if config.get("access_token") else "无" first_setup_completed = config.get("first_setup_completed", True) # 默认为 True,表示已完成配置 - + config["access_token"] = new_token config["updated_at"] = self._get_current_timestamp() config["first_setup_completed"] = first_setup_completed # 保留原来的状态 - + self._save_config(config) logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...") - + return new_token def _validate_token_format(self, token: str) -> bool: diff --git a/src/webui/webui_server.py b/src/webui/webui_server.py index 21e79565..87b47192 100644 --- a/src/webui/webui_server.py +++ b/src/webui/webui_server.py @@ -110,8 +110,10 @@ class WebUIServer: from src.webui.routes import router as webui_router from src.webui.logs_ws import router as logs_router from src.webui.knowledge_routes import router as knowledge_router + # 导入本地聊天室路由 from src.webui.chat_routes import router as chat_router + # 注册路由 self.app.include_router(webui_router) self.app.include_router(logs_router) @@ -166,6 +168,7 @@ class WebUIServer: def _check_port_available(self) -> bool: """检查端口是否可用""" import socket + try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.settimeout(1)