diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index 16a9132b..dd806e4b 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -20,7 +20,7 @@ from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo from src.common.logger import get_logger from src.config.config import global_config -from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation +from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec from src.know_u.knowledge_store import get_knowledge_store from src.learners.jargon_explainer import search_jargon from src.llm_models.exceptions import ReqAbortException @@ -417,6 +417,230 @@ class MaisakaReasoningEngine: metadata={"anchor_message": anchor_message}, ) + @staticmethod + def _normalize_tool_record_value(value: Any) -> Any: + """将工具记录中的任意值规范化为可序列化结构。 + + Args: + value: 原始值。 + + Returns: + Any: 适合写入 JSON 的规范化结果。 + """ + + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, dict): + normalized_dict: dict[str, Any] = {} + for key, item in value.items(): + normalized_dict[str(key)] = MaisakaReasoningEngine._normalize_tool_record_value(item) + return normalized_dict + if isinstance(value, (list, tuple, set)): + return [MaisakaReasoningEngine._normalize_tool_record_value(item) for item in value] + if isinstance(value, bytes): + return f"" + if hasattr(value, "model_dump"): + try: + return MaisakaReasoningEngine._normalize_tool_record_value(value.model_dump()) + except Exception: + return str(value) + if hasattr(value, "__dict__"): + try: + return MaisakaReasoningEngine._normalize_tool_record_value(dict(value.__dict__)) + except Exception: + return str(value) + return str(value) + + @staticmethod + def _truncate_tool_record_text(text: str, max_length: int = 180) -> str: + """截断工具记录中的展示文本。 + + Args: + text: 原始文本。 + max_length: 最长保留字符数。 + + Returns: + str: 截断后的文本。 + """ + + normalized_text = text.strip() + if len(normalized_text) <= max_length: + return normalized_text + return f"{normalized_text[: max_length - 1]}…" + + def _build_tool_record_payload( + self, + invocation: ToolInvocation, + result: ToolExecutionResult, + tool_spec: Optional[ToolSpec], + ) -> dict[str, Any]: + """构造统一工具落库数据。 + + Args: + invocation: 工具调用对象。 + result: 工具执行结果。 + tool_spec: 对应的工具声明。 + + Returns: + dict[str, Any]: 可直接写入数据库的工具记录数据。 + """ + + payload: dict[str, Any] = { + "call_id": invocation.call_id, + "session_id": invocation.session_id, + "stream_id": invocation.stream_id, + "arguments": self._normalize_tool_record_value(invocation.arguments), + "success": result.success, + "content": result.content, + "error_message": result.error_message, + "history_content": result.get_history_content(), + "structured_content": self._normalize_tool_record_value(result.structured_content), + "metadata": self._normalize_tool_record_value(result.metadata), + } + if tool_spec is not None: + payload["provider_name"] = tool_spec.provider_name + payload["provider_type"] = tool_spec.provider_type + payload["brief_description"] = tool_spec.brief_description + payload["detailed_description"] = tool_spec.detailed_description + payload["title"] = tool_spec.title + return payload + + def _build_tool_display_prompt( + self, + invocation: ToolInvocation, + result: ToolExecutionResult, + tool_spec: Optional[ToolSpec], + ) -> str: + """构造展示给历史回放与 UI 的工具摘要。 + + Args: + invocation: 工具调用对象。 + result: 工具执行结果。 + tool_spec: 对应的工具声明。 + + Returns: + str: 用于展示的工具摘要文本。 + """ + + custom_display_prompt = result.metadata.get("record_display_prompt") + if isinstance(custom_display_prompt, str) and custom_display_prompt.strip(): + return custom_display_prompt.strip() + + structured_content = ( + result.structured_content + if isinstance(result.structured_content, dict) + else {} + ) + history_content = self._truncate_tool_record_text(result.get_history_content(), max_length=200) + normalized_args = self._normalize_tool_record_value(invocation.arguments) + + if invocation.tool_name == "reply": + target_user_name = str(structured_content.get("target_user_name") or "对方").strip() or "对方" + reply_text = str(structured_content.get("reply_text") or "").strip() + if result.success and reply_text: + return f"你对{target_user_name}进行了回复:{reply_text}" + target_message_id = str(invocation.arguments.get("msg_id") or "").strip() + error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=120) + return f"你尝试回复消息 {target_message_id or 'unknown'},但失败了:{error_text}" + + if invocation.tool_name == "send_emoji": + description = str(structured_content.get("description") or "").strip() + emotion_list = structured_content.get("emotion") + if isinstance(emotion_list, list): + emotion_text = "、".join(str(item).strip() for item in emotion_list if str(item).strip()) + else: + emotion_text = "" + if result.success and description: + if emotion_text: + return f"你发送了表情包:{description}(情绪:{emotion_text})" + return f"你发送了表情包:{description}" + return f"你尝试发送表情包,但失败了:{self._truncate_tool_record_text(result.error_message or history_content, 120)}" + + if invocation.tool_name == "wait": + wait_seconds = invocation.arguments.get("seconds", 30) + return f"你让当前对话先等待 {wait_seconds} 秒。" + + if invocation.tool_name == "stop": + return "你暂停了当前对话循环,等待新的外部消息。" + + if invocation.tool_name == "query_jargon": + words = invocation.arguments.get("words", []) + if isinstance(words, list): + words_text = "、".join(str(item).strip() for item in words if str(item).strip()) + else: + words_text = "" + if words_text: + return f"你查询了这些黑话或词条:{words_text}" + return "你查询了一次黑话或词条信息。" + + if invocation.tool_name == "query_person_info": + person_name = str(invocation.arguments.get("person_name") or "").strip() + if person_name: + return f"你查询了人物信息:{person_name}" + return "你查询了一次人物信息。" + + brief_description = "" + if tool_spec is not None: + brief_description = tool_spec.brief_description.strip() + + if normalized_args: + arguments_text = self._truncate_tool_record_text( + json.dumps(normalized_args, ensure_ascii=False), + max_length=160, + ) + else: + arguments_text = "{}" + + if result.success: + if brief_description: + return f"{brief_description} 参数={arguments_text};结果:{history_content or '执行成功'}" + return f"你调用了工具 {invocation.tool_name},参数={arguments_text};结果:{history_content or '执行成功'}" + + error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=160) + return f"你调用了工具 {invocation.tool_name},参数={arguments_text};执行失败:{error_text}" + + async def _store_tool_execution_record( + self, + invocation: ToolInvocation, + result: ToolExecutionResult, + tool_spec: Optional[ToolSpec], + ) -> None: + """将工具执行结果落库到统一工具记录表。 + + Args: + invocation: 工具调用对象。 + result: 工具执行结果。 + tool_spec: 对应的工具声明。 + """ + + if self._runtime.chat_stream is None: + logger.debug( + f"{self._runtime.log_prefix} 当前没有 chat_stream,跳过工具记录存储: " + f"工具={invocation.tool_name}" + ) + return + + builtin_prompt = "" + if tool_spec is not None: + builtin_prompt = tool_spec.build_llm_description() + + try: + await database_api.store_tool_info( + chat_stream=self._runtime.chat_stream, + builtin_prompt=builtin_prompt, + display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec), + tool_id=invocation.call_id, + tool_data=self._build_tool_record_payload(invocation, result, tool_spec), + tool_name=invocation.tool_name, + tool_reasoning=invocation.reasoning, + ) + except Exception: + logger.exception( + f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}" + ) + def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None: """将统一工具执行结果写回 Maisaka 历史。 @@ -435,6 +659,7 @@ class MaisakaReasoningEngine: timestamp=datetime.now(), tool_call_id=tool_call.call_id, tool_name=tool_call.func_name, + success=result.success, ) ) @@ -614,20 +839,29 @@ class MaisakaReasoningEngine: if self._runtime._tool_registry is None: for tool_call in tool_calls: - self._append_tool_execution_result( - tool_call, - ToolExecutionResult( - tool_name=tool_call.func_name, - success=False, - error_message="统一工具注册表尚未初始化。", - ), + invocation = self._build_tool_invocation(tool_call, latest_thought) + result = ToolExecutionResult( + tool_name=tool_call.func_name, + success=False, + error_message="统一工具注册表尚未初始化。", ) + await self._store_tool_execution_record(invocation, result, None) + self._append_tool_execution_result(tool_call, result) return False execution_context = self._build_tool_execution_context(latest_thought, anchor_message) + tool_spec_map = { + tool_spec.name: tool_spec + for tool_spec in await self._runtime._tool_registry.list_tools() + } for tool_call in tool_calls: invocation = self._build_tool_invocation(tool_call, latest_thought) result = await self._runtime._tool_registry.invoke(invocation, execution_context) + await self._store_tool_execution_record( + invocation, + result, + tool_spec_map.get(invocation.tool_name), + ) self._append_tool_execution_result(tool_call, result) if not result.success and tool_call.func_name == "reply": @@ -1015,19 +1249,6 @@ class MaisakaReasoningEngine: or target_user_info.user_nickname or target_user_info.user_id ) - if self._runtime.chat_stream is not None: - await database_api.store_tool_info( - chat_stream=self._runtime.chat_stream, - display_prompt=f"你对{target_user_name}进行了回复:{combined_reply_text}", - tool_data={ - "msg_id": target_message_id, - "quote": quote_reply, - "reply_text": combined_reply_text, - "reply_segments": reply_segments, - }, - tool_name="reply", - tool_reasoning=latest_thought, - ) bot_name = global_config.bot.nickname.strip() or "MaiSaka" reply_timestamp = datetime.now() diff --git a/src/services/database_service.py b/src/services/database_service.py index 7871981d..5e41f2c6 100644 --- a/src/services/database_service.py +++ b/src/services/database_service.py @@ -4,7 +4,7 @@ import json import time import traceback from datetime import datetime -from typing import Any, Optional +from typing import Any, Optional, cast from sqlalchemy import delete, func, select from sqlmodel import SQLModel @@ -65,7 +65,7 @@ async def db_save( record = None if key_field and key_value is not None: key_column = _get_model_field(model_class, key_field) - record = session.exec(select(model_class).where(key_column == key_value)).first() + record = session.exec(cast(Any, select(model_class).where(key_column == key_value))).first() if record is None: record = model_class(**data) @@ -99,7 +99,7 @@ async def db_get( statement = _apply_order_by(statement, model_class, order_by) if limit: statement = statement.limit(limit) - results = session.exec(statement).all() + results = session.exec(cast(Any, statement)).all() data = [_to_dict(item) for item in results] if single_result: return data[0] if data else None @@ -116,7 +116,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters: statement = select(model_class) if conditions := _build_filters(model_class, filters): statement = statement.where(*conditions) - records = session.exec(statement).all() + records = session.exec(cast(Any, statement)).all() for record in records: for field_name, value in data.items(): _get_model_field(model_class, field_name) @@ -149,7 +149,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any] statement = select(func.count()).select_from(model_class) if conditions := _build_filters(model_class, filters): statement = statement.where(*conditions) - result = session.exec(statement).one() + result = session.exec(cast(Any, statement)).one() return int(result or 0) except Exception as e: logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")