From f2b64cc58c049aa44175b1ab195c6763de44c69b Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Tue, 7 Apr 2026 14:30:23 +0800 Subject: [PATCH] =?UTF-8?q?remove=EF=BC=9A=E6=97=A0=E7=94=A8=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=EF=BC=8C=E6=95=B4=E7=90=86=E9=85=8D=E7=BD=AE=E6=96=87?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AGENTS.md | 24 +- src/chat/emoji_system/emoji_manager.py | 2 +- src/chat/image_system/image_manager.py | 2 +- src/chat/replyer/maisaka_replyer_factory.py | 2 +- src/config/config.py | 12 +- src/config/legacy_migration.py | 80 +- src/config/official_configs.py | 177 ++-- src/maisaka/builtin_tool/__init__.py | 2 +- src/maisaka/builtin_tool/query_memory.py | 2 +- src/maisaka/reasoning_engine.py | 2 +- src/maisaka/runtime.py | 4 +- src/memory_system/memory_retrieval.py | 878 ------------------ src/memory_system/retrieval_tools/__init__.py | 32 - .../retrieval_tools/query_long_term_memory.py | 307 ------ .../retrieval_tools/query_words.py | 78 -- .../retrieval_tools/return_information.py | 42 - .../retrieval_tools/tool_registry.py | 167 ---- 17 files changed, 159 insertions(+), 1654 deletions(-) delete mode 100644 src/memory_system/memory_retrieval.py delete mode 100644 src/memory_system/retrieval_tools/__init__.py delete mode 100644 src/memory_system/retrieval_tools/query_long_term_memory.py delete mode 100644 src/memory_system/retrieval_tools/query_words.py delete mode 100644 src/memory_system/retrieval_tools/return_information.py delete mode 100644 src/memory_system/retrieval_tools/tool_registry.py diff --git a/AGENTS.md b/AGENTS.md index c7e66fa5..7efc8f88 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,23 +1,8 @@ -# import 规范 -在从外部库进行导入时候,请遵循以下顺序: -1. 对于标准库和第三方库的导入,请按照如下顺序: - - 需要使用`from ... import ...`语法的导入放在前面。 - - 直接使用`import ...`语法的导入放在后面。 - - 对于使用`from ... import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。 - - 对于使用`import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。 -2. 对于本地模块的导入,请按照如下顺序: - - 对于同一个文件夹下的模块导入,使用相对导入,排列顺序按照**不发生import错误的前提下**,随便排列。 - - 对于不同文件夹下的模块导入,使用绝对导入。这些导入应该以`from src`开头,并且按照**不发生import错误的前提下**,尽量使得第二层的文件夹名称相同的导入放在一起;第二层文件夹名称排列随机。 -3. 标准库和第三方库的导入应该放在本地模块导入的前面。 -4. 各个导入块之间应该使用一个空行进行分隔。 -5. 对于现有的代码,如果导入顺序不符合上述规范,在重构代码时应该调整导入顺序以符合规范。 - # 代码规范 ## 注释规范 1. 尽量保持良好的注释 2. 如果原来的代码中有注释,则重构的时候,除非这部分代码被删除,否则相同功能的代码应该保留注释(可以对注释进行修改以保持准确性,但不应该删除注释)。 3. 如果原来的代码中没有注释,则重构的时候,如果某个功能块的代码较长或者逻辑较为复杂,则应该添加注释来解释这部分代码的功能和逻辑。 -4. 对于类,方法以及模块的注释,首选使用的注释格式为 Google DocStr 格式,但保证语言为简体中文 ## 类型注解规范 1. 重构代码时,如果原来的代码中有类型注解,则相同功能的代码应该保留类型注解(可以对类型注解进行修改以保持准确性,但不应该删除类型注解)。 2. 重构代码时,如果原来的代码中没有类型注解,则重构的时候,如果某个函数的功能较为复杂或者参数较多,则应该添加类型注解来提高代码的可读性和可维护性。(对于简单的变量,可以不添加类型注解) @@ -34,8 +19,13 @@ # 运行/调试/构建/测试/依赖 优先使用uv 依赖项以 pyproject.toml 为准 -不要修改dashboard下的内容,因为这部分内容由另一个仓库build + # 语言规范 +项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都首要以简体中文为首要实现目标 -项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都应该首要以简体中文为首要实现目标 +# 配置文件修改 +如果你需要改动配置文件,不需要修改实际的bot_config.toml或者model_config.toml,只需要修改配置文件模版,并新增一个版本号即可,也不必要为配置改动创建测试文件。 + +# 关于webui修改 +不要修改dashboard下的内容,因为这部分内容由另一个仓库build \ No newline at end of file diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index b03b3505..0dccea06 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -501,7 +501,7 @@ class EmojiManager: existing_record = session.exec(statement).first() if existing_record: if existing_record.is_registered and _is_available_emoji_record(existing_record): - logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}") + # logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}") return "skipped" normalized_description = str(emoji.description or existing_record.description or "").strip() diff --git a/src/chat/image_system/image_manager.py b/src/chat/image_system/image_manager.py index 81b95b38..a45a3c97 100644 --- a/src/chat/image_system/image_manager.py +++ b/src/chat/image_system/image_manager.py @@ -369,7 +369,7 @@ class ImageManager: logger.info(f"Cleaned mistaken registration state on {fixed_counter} image records") async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str: - prompt = global_config.personality.visual_style + prompt = global_config.visual.visual_style image_base64 = base64.b64encode(image_bytes).decode("utf-8") generation_result = await vlm.generate_response_for_image( diff --git a/src/chat/replyer/maisaka_replyer_factory.py b/src/chat/replyer/maisaka_replyer_factory.py index e511143c..e8194cfa 100644 --- a/src/chat/replyer/maisaka_replyer_factory.py +++ b/src/chat/replyer/maisaka_replyer_factory.py @@ -18,4 +18,4 @@ def get_maisaka_replyer_class() -> Type[object]: def get_maisaka_replyer_generator_type() -> str: """返回当前配置的 Maisaka replyer 生成器类型。""" - return global_config.chat.replyer_generator_type + return "multimodal" if global_config.visual.multimodal_replyer else "legacy" diff --git a/src/config/config.py b/src/config/config.py index f8b5f91b..e4412f2d 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -27,14 +27,15 @@ from .official_configs import ( MaiSakaConfig, MaimMessageConfig, MCPConfig, - PluginRuntimeConfig, MemoryConfig, MessageReceiveConfig, PersonalityConfig, + PluginRuntimeConfig, RelationshipConfig, ResponsePostProcessConfig, ResponseSplitterConfig, TelemetryConfig, + VisualConfig, VoiceConfig, WebUIConfig, ) @@ -55,7 +56,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config" BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() MMC_VERSION: str = "1.0.0" -CONFIG_VERSION: str = "8.3.4" +CONFIG_VERSION: str = "8.4.1" MODEL_CONFIG_VERSION: str = "1.13.1" logger = get_logger("config") @@ -73,6 +74,9 @@ class Config(ConfigBase): personality: PersonalityConfig = Field(default_factory=PersonalityConfig) """人格配置类""" + visual: VisualConfig = Field(default_factory=VisualConfig) + """视觉配置类""" + expression: ExpressionConfig = Field(default_factory=ExpressionConfig) """表达配置类""" @@ -474,6 +478,10 @@ def load_config_from_file( if env_migration.migrated: logger.warning(f"检测到旧版环境变量绑定配置,已迁移到主配置: {env_migration.reason}") config_data = env_migration.data + legacy_migration = try_migrate_legacy_bot_config_dict(config_data) + if legacy_migration.migrated: + logger.warning(t("config.legacy_migrated", reason=legacy_migration.reason)) + config_data = legacy_migration.data # 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改 original_data: dict[str, Any] = copy.deepcopy(config_data) try: diff --git a/src/config/legacy_migration.py b/src/config/legacy_migration.py index bdae3254..6ffb75b5 100644 --- a/src/config/legacy_migration.py +++ b/src/config/legacy_migration.py @@ -75,6 +75,18 @@ def _migrate_env_value(section: dict[str, Any], key: str, parsed_env_value: Any, return True +def _move_section_key(source: dict[str, Any], target: dict[str, Any], key: str) -> bool: + """将配置项从旧分组移动到新分组,若新分组已有值则保留新值。""" + + if key not in source: + return False + + if key not in target: + target[key] = source[key] + source.pop(key, None) + return True + + def _parse_triplet_target(s: str) -> Optional[dict[str, str]]: """ 解析 "platform:id:type" -> {platform,item_id,rule_type} @@ -273,6 +285,21 @@ def _migrate_extra_prompt_list(exp: dict[str, Any], key: str) -> bool: return True +def _parse_multimodal_replyer(v: Any) -> Optional[bool]: + """兼容旧 replyer_generator_type 到布尔开关的迁移。""" + if isinstance(v, bool): + return v + if not isinstance(v, str): + return None + + normalized_value = v.strip().lower() + if normalized_value == "multimodal": + return True + if normalized_value == "legacy": + return False + return None + + def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult: """将旧版环境变量中的绑定地址迁移到主配置结构。""" @@ -351,12 +378,61 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult: migrated_any = True reasons.append("chat.private_plan_style_removed") + personality = _as_dict(data.get("personality")) + visual = _as_dict(data.get("visual")) + if visual is None and ( + (personality is not None and "visual_style" in personality) + or "multimodal_planner" in chat + or "replyer_generator_type" in chat + ): + visual = {} + data["visual"] = visual + + if visual is not None and personality is not None and "visual_style" in personality: + if "visual_style" not in visual: + visual["visual_style"] = personality["visual_style"] + personality.pop("visual_style", None) + migrated_any = True + reasons.append("personality.visual_style_moved_to_visual.visual_style") + + if visual is not None and "multimodal_planner" in chat: + if "multimodal_planner" not in visual and isinstance(chat["multimodal_planner"], bool): + visual["multimodal_planner"] = chat["multimodal_planner"] + if "multimodal_planner" in visual: + chat.pop("multimodal_planner", None) + migrated_any = True + reasons.append("chat.multimodal_planner_moved_to_visual.multimodal_planner") + + if visual is not None and "replyer_generator_type" in chat: + multimodal_replyer = _parse_multimodal_replyer(chat["replyer_generator_type"]) + if "multimodal_replyer" not in visual and multimodal_replyer is not None: + visual["multimodal_replyer"] = multimodal_replyer + if "multimodal_replyer" in visual: + chat.pop("replyer_generator_type", None) + migrated_any = True + reasons.append("chat.replyer_generator_type_moved_to_visual.multimodal_replyer") + + maisaka = _as_dict(data.get("maisaka")) mem = _as_dict(data.get("memory")) + if maisaka is not None: + moved_memory_keys = ("enable_memory_query_tool", "memory_query_default_limit") + if any(key in maisaka for key in moved_memory_keys) and mem is None: + mem = {} + data["memory"] = mem + + if mem is not None: + for moved_key in moved_memory_keys: + if _move_section_key(maisaka, mem, moved_key): + migrated_any = True + reasons.append(f"maisaka.{moved_key}_moved_to_memory") + if mem is not None: + if _migrate_target_item_list(mem, "global_memory_blacklist"): + migrated_any = True + reasons.append("memory.global_memory_blacklist") + for removed_key in ( "agent_timeout_seconds", - "global_memory", - "global_memory_blacklist", "max_agent_iterations", ): if removed_key in mem: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 03399994..6a6bc627 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -113,15 +113,6 @@ class PersonalityConfig(ConfigBase): ) """每次构建回复时,从 multiple_reply_style 中随机替换 reply_style 的概率(0.0-1.0)""" - visual_style: str = Field( - default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本", - json_schema_extra={ - "x-widget": "textarea", - "x-icon": "image", - }, - ) - """_wrap_识图提示词,不建议修改""" - states: list[str] = Field( default_factory=lambda: [ "是一个女大学生,喜欢上网聊天,会刷小红书。", @@ -148,6 +139,40 @@ class PersonalityConfig(ConfigBase): """状态概率,每次构建人格时替换personality的概率""" +class VisualConfig(ConfigBase): + """视觉配置类""" + + __ui_label__ = "视觉" + __ui_icon__ = "image" + + multimodal_planner: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "image", + }, + ) + """是否直接输入图片""" + + multimodal_replyer: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "git-branch", + }, + ) + """是否启用 Maisaka 多模态 replyer 生成器""" + + visual_style: str = Field( + default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本", + json_schema_extra={ + "x-widget": "textarea", + "x-icon": "image", + }, + ) + """_wrap_识图提示词,不建议修改""" + + class RelationshipConfig(ConfigBase): """关系配置类""" @@ -253,24 +278,6 @@ class ChatConfig(ConfigBase): }, ) - multimodal_planner: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "image", - }, - ) - """是否直接输入图片""" - - replyer_generator_type: Literal["legacy", "multimodal"] = Field( - default="legacy", - json_schema_extra={ - "x-widget": "select", - "x-icon": "git-branch", - }, - ) - """Maisaka replyer 生成器类型:legacy(旧版单 prompt)/ multimodal(多模态版,适合主循环直接展示图片)""" - enable_talk_value_rules: bool = Field( default=True, json_schema_extra={ @@ -373,24 +380,6 @@ class MemoryConfig(ConfigBase): __ui_parent__ = "emoji" - max_agent_iterations: int = Field( - default=5, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "layers", - }, - ) - """记忆思考深度(最低为1)""" - - agent_timeout_seconds: float = Field( - default=120.0, - json_schema_extra={ - "x-widget": "input", - "x-icon": "clock", - }, - ) - """最长回忆时间(秒)""" global_memory: bool = Field( default=False, @@ -410,6 +399,26 @@ class MemoryConfig(ConfigBase): ) """_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索""" + enable_memory_query_tool: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "database", + }, + ) + """是否启用 Maisaka 内置长期记忆检索工具 query_memory""" + + memory_query_default_limit: int = Field( + default=5, + ge=1, + le=20, + json_schema_extra={ + "x-widget": "input", + "x-icon": "hash", + }, + ) + """Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数""" + long_term_auto_summary_enabled: bool = Field( default=True, json_schema_extra={ @@ -657,16 +666,6 @@ class ExpressionConfig(ConfigBase): ) """是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除""" - enable_jargon_explanation: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "info", - }, - ) - """是否在回复前尝试对上下文中的黑话进行解释(关闭可减少一次LLM调用,仅影响回复前的黑话匹配与解释,不影响黑话学习)""" - - class VoiceConfig(ConfigBase): """语音识别配置类""" @@ -700,7 +699,7 @@ class EmojiConfig(ConfigBase): """一次从多少个表情包中选择发送,最大为 64""" max_reg_num: int = Field( - default=100, + default=64, json_schema_extra={ "x-widget": "input", "x-icon": "hash", @@ -988,23 +987,6 @@ class DebugConfig(ConfigBase): ) """是否显示prompt""" - show_replyer_prompt: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "message-square", - }, - ) - """是否显示回复器prompt""" - - show_replyer_reasoning: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "brain", - }, - ) - show_maisaka_thinking: bool = Field( default=True, json_schema_extra={ @@ -1041,25 +1023,6 @@ class DebugConfig(ConfigBase): ) """是否显示记忆检索相关prompt""" - show_planner_prompt: bool = Field( - default=False, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "map", - }, - ) - """是否显示planner的prompt和原始返回结果""" - - show_lpmm_paragraph: bool = Field( - default=False, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "file-text", - }, - ) - """是否显示lpmm找到的相关文段日志""" - - class ExtraPromptItem(ConfigBase): platform: str = Field( default="", @@ -1511,16 +1474,6 @@ class MaiSakaConfig(ConfigBase): ) """MaiSaka 使用的用户名称""" - max_internal_rounds: int = Field( - default=6, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "repeat", - }, - ) - """每个入站消息的最大内部规划轮数""" - planner_interrupt_max_consecutive_count: int = Field( default=2, ge=0, @@ -1531,26 +1484,6 @@ class MaiSakaConfig(ConfigBase): ) """Planner 连续被新消息打断的最大次数,0 表示不启用打断""" - enable_memory_query_tool: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "database", - }, - ) - """是否启用 Maisaka 内置长期记忆检索工具 query_memory""" - - memory_query_default_limit: int = Field( - default=5, - ge=1, - le=20, - json_schema_extra={ - "x-widget": "input", - "x-icon": "hash", - }, - ) - """Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数""" - tool_filter_task_name: str = Field( default="utils", json_schema_extra={ diff --git a/src/maisaka/builtin_tool/__init__.py b/src/maisaka/builtin_tool/__init__.py index 8af1af20..8d6454ba 100644 --- a/src/maisaka/builtin_tool/__init__.py +++ b/src/maisaka/builtin_tool/__init__.py @@ -47,7 +47,7 @@ def get_action_tool_specs() -> List[ToolSpec]: get_reply_tool_spec(), get_view_complex_message_tool_spec(), get_query_jargon_tool_spec(), - get_query_memory_tool_spec(enabled=bool(global_config.maisaka.enable_memory_query_tool)), + get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool)), get_send_emoji_tool_spec(), ] diff --git a/src/maisaka/builtin_tool/query_memory.py b/src/maisaka/builtin_tool/query_memory.py index 33bd1ace..3bd4b587 100644 --- a/src/maisaka/builtin_tool/query_memory.py +++ b/src/maisaka/builtin_tool/query_memory.py @@ -161,7 +161,7 @@ async def handle_tool( f"不支持的检索模式:{mode}。可选值:search/time/hybrid/episode/aggregate。", ) - default_limit = max(1, int(getattr(global_config.maisaka, "memory_query_default_limit", 5) or 5)) + default_limit = max(1, global_config.memory.memory_query_default_limit) try: limit = int(invocation.arguments.get("limit", default_limit) or default_limit) except (TypeError, ValueError): diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index 583865f6..6269b8ec 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -596,7 +596,7 @@ class MaisakaReasoningEngine: planner_prefix: str, ) -> MessageSequence: message_sequence = build_prefixed_message_sequence(message.raw_message, planner_prefix) - if global_config.chat.multimodal_planner: + if global_config.visual.multimodal_planner: await self._hydrate_visual_components(message_sequence.components) return message_sequence diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index 47bfb9e6..1c2935d2 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -37,6 +37,8 @@ from .tool_provider import MaisakaBuiltinToolProvider logger = get_logger("maisaka_runtime") +MAX_INTERNAL_ROUNDS = 6 + class MaisakaHeartFlowChatting: """会话级别的 Maisaka 运行时。""" @@ -78,7 +80,7 @@ class MaisakaHeartFlowChatting: self._message_debounce_required = False self._last_message_received_at = 0.0 self._wait_timeout_task: Optional[asyncio.Task[None]] = None - self._max_internal_rounds = global_config.maisaka.max_internal_rounds + self._max_internal_rounds = MAX_INTERNAL_ROUNDS self._max_context_size = max(1, int(global_config.chat.max_context_size)) self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP self._wait_until: Optional[float] = None diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py deleted file mode 100644 index 9bbc0267..00000000 --- a/src/memory_system/memory_retrieval.py +++ /dev/null @@ -1,878 +0,0 @@ -import asyncio -import contextlib -import json -import time -from typing import Any, Callable, Dict, List, Optional, Tuple - -from src.common.logger import get_logger -from src.config.config import global_config -from src.prompt.prompt_manager import prompt_manager -from src.services import llm_service as llm_api - -from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools -from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon - -logger = get_logger("memory_retrieval") - - -def init_memory_retrieval_sys(): - """初始化记忆检索相关工具""" - # 注册所有工具 - init_all_tools() - - -def _log_conversation_messages( - conversation_messages: List[Message], - head_prompt: Optional[str] = None, - final_status: Optional[str] = None, -) -> None: - """输出对话消息列表的日志 - - Args: - conversation_messages: 对话消息列表 - head_prompt: 第一条系统消息(head_prompt)的内容,可选 - final_status: 最终结果状态描述(例如:找到答案/未找到答案),可选 - """ - if not global_config.debug.show_memory_prompt: - return - - log_lines: List[str] = [] - - # 如果有head_prompt,先添加为第一条消息 - if head_prompt: - msg_info = "========================================\n[消息 1] 角色: System\n-----------------------------" - msg_info += f"\n{head_prompt}" - log_lines.append(msg_info) - start_idx = 2 - else: - start_idx = 1 - - if not conversation_messages and not head_prompt: - return - - for idx, msg in enumerate(conversation_messages, start_idx): - role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - - # 构建单条消息的日志信息 - # msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------" - msg_info = ( - f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------" - ) - - # if full_content: - # msg_info += f"\n{full_content}" - if msg.content: - msg_info += f"\n{msg.content}" - - if msg.tool_calls: - msg_info += f"\n 工具调用: {len(msg.tool_calls)}个" - for tool_call in msg.tool_calls: - msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}" - - # if msg.tool_call_id: - # msg_info += f"\n 工具调用ID: {msg.tool_call_id}" - - log_lines.append(msg_info) - - total_count = len(conversation_messages) + (1 if head_prompt else 0) - log_text = f"消息列表 (共{total_count}条):{''.join(log_lines)}" - if final_status: - log_text += f"\n\n[最终结果] {final_status}" - logger.info(log_text) - - -async def _react_agent_solve_question( - chat_id: str, - max_iterations: int = 5, - timeout: float = 30.0, - initial_info: str = "", - chat_history: str = "", -) -> Tuple[bool, str, List[Dict[str, Any]], bool]: - """使用ReAct架构的Agent来解决问题 - - Args: - chat_id: 聊天ID - max_iterations: 最大迭代次数 - timeout: 超时时间(秒) - initial_info: 初始信息,将作为collected_info的初始值 - chat_history: 聊天记录,将传递给 ReAct Agent prompt - - Returns: - Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时) - """ - start_time = time.time() - collected_info = initial_info or "" - # 构造日志前缀:[聊天流名称],用于在日志中标识聊天流 - try: - chat_name = _chat_manager.get_session_name(chat_id) or chat_id - except Exception: - chat_name = chat_id - react_log_prefix = f"[{chat_name}] " - thinking_steps = [] - is_timeout = False - conversation_messages: List[Message] = [] - first_head_prompt: Optional[str] = None # 保存第一次使用的head_prompt(用于日志显示) - last_tool_name: Optional[str] = None # 记录最后一次使用的工具名称 - - # 使用 while 循环,支持额外迭代 - iteration = 0 - max_iterations_with_extra = max_iterations - while iteration < max_iterations_with_extra: - # 检查超时 - if time.time() - start_time > timeout: - logger.warning(f"ReAct Agent超时,已迭代{iteration}次") - is_timeout = True - break - - # 获取工具注册器 - tool_registry = get_tool_registry() - - # 获取bot_name - bot_name = global_config.bot.nickname - - # 获取当前时间 - time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - - # 计算剩余迭代次数 - current_iteration = iteration + 1 - remaining_iterations = max_iterations - current_iteration - - # 提取函数调用中参数的值,支持单引号和双引号 - def extract_quoted_content(text, func_name, param_name): - """从文本中提取函数调用中参数的值,支持单引号和双引号 - - Args: - text: 要搜索的文本 - func_name: 函数名,如 'return_information' - param_name: 参数名,如 'information' - - Returns: - 提取的参数值,如果未找到则返回None - """ - if not text: - return None - - # 查找函数调用位置(不区分大小写) - func_pattern = func_name.lower() - text_lower = text.lower() - func_pos = text_lower.find(func_pattern) - if func_pos == -1: - return None - - # 查找参数名和等号 - param_pattern = f"{param_name}=" - param_pos = text_lower.find(param_pattern, func_pos) - if param_pos == -1: - return None - - # 跳过参数名、等号和空白 - start_pos = param_pos + len(param_pattern) - while start_pos < len(text) and text[start_pos] in " \t\n": - start_pos += 1 - - if start_pos >= len(text): - return None - - # 确定引号类型 - quote_char = text[start_pos] - if quote_char not in ['"', "'"]: - return None - - # 查找匹配的结束引号(考虑转义) - end_pos = start_pos + 1 - while end_pos < len(text): - if text[end_pos] == quote_char: - # 检查是否是转义的引号 - if end_pos > start_pos + 1 and text[end_pos - 1] == "\\": - end_pos += 1 - continue - # 找到匹配的引号 - content = text[start_pos + 1 : end_pos] - # 处理转义字符 - content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\") - return content - end_pos += 1 - - return None - - # 正常迭代:使用head_prompt决定调用哪些工具(包含return_information工具) - tool_definitions = tool_registry.get_tool_definitions() - # tool_names = [tool_def["name"] for tool_def in tool_definitions] - # logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具: {', '.join(tool_names)} (共{len(tool_definitions)}个)") - - # head_prompt应该只构建一次,使用初始的collected_info,后续迭代都复用同一个 - if first_head_prompt is None: - # 第一次构建,使用初始的collected_info(即initial_info) - initial_collected_info = initial_info or "" - # 使用统一长期记忆检索 prompt - first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_memory") - first_head_prompt_template.add_context("bot_name", bot_name) - first_head_prompt_template.add_context("time_now", time_now) - first_head_prompt_template.add_context("chat_history", chat_history) - first_head_prompt_template.add_context("collected_info", initial_collected_info) - first_head_prompt_template.add_context("current_iteration", str(current_iteration)) - first_head_prompt_template.add_context("remaining_iterations", str(remaining_iterations)) - first_head_prompt_template.add_context("max_iterations", str(max_iterations)) - first_head_prompt = await prompt_manager.render_prompt(first_head_prompt_template) - - # 后续迭代都复用第一次构建的head_prompt - head_prompt = first_head_prompt - - def _build_messages( - _client, - *, - _head_prompt: str = head_prompt, - _conversation_messages: List[Message] = conversation_messages, - ): - messages: List[Message] = [] - - system_builder = MessageBuilder() - system_builder.set_role(RoleType.System) - system_builder.add_text_content(_head_prompt) - messages.append(system_builder.build()) - - messages.extend(_conversation_messages) - - return messages - - message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues] - generation_result = await llm_api.generate( - llm_api.LLMServiceRequest( - task_name="utils", - request_type="memory.react", - message_factory=message_factory_fn, # type: ignore[arg-type] - tool_options=tool_definitions, - ) - ) - success = generation_result.success - response = generation_result.completion.response - reasoning_content = generation_result.completion.reasoning - tool_calls = generation_result.completion.tool_calls - - # logger.info( - # f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}" - # ) - - if not success: - logger.error(f"ReAct Agent LLM调用失败: {response}") - break - - # 注意:这里会检查return_information工具调用,如果检测到return_information工具,会根据information参数决定返回信息或退出查询 - - assistant_message: Optional[Message] = None - if tool_calls: - assistant_builder = MessageBuilder() - assistant_builder.set_role(RoleType.Assistant) - if response and response.strip(): - assistant_builder.add_text_content(response) - assistant_builder.set_tool_calls(tool_calls) - assistant_message = assistant_builder.build() - elif response and response.strip(): - assistant_builder = MessageBuilder() - assistant_builder.set_role(RoleType.Assistant) - assistant_builder.add_text_content(response) - assistant_message = assistant_builder.build() - - # 记录思考步骤 - step: Dict[str, Any] = { - "iteration": iteration + 1, - "thought": response, - "actions": [], - "observations": [], - } - - if assistant_message: - conversation_messages.append(assistant_message) - - # 记录思考过程到collected_info中 - if reasoning_content or response: - thought_summary = reasoning_content or (response[:200] if response else "") - if thought_summary: - collected_info += f"\n[思考] {thought_summary}\n" - - # 处理工具调用 - if not tool_calls: - # 如果没有工具调用,检查响应文本中是否包含return_information函数调用格式或JSON格式 - if response and response.strip(): - # 首先尝试解析JSON格式的return_information - def parse_json_return_information(text: str): - """从文本中解析JSON格式的return_information,返回information字符串,如果未找到则返回None""" - if not text: - return None, None - - try: - # 尝试提取JSON对象(可能包含在代码块中或直接是JSON) - json_text = text.strip() - - # 如果包含代码块标记,提取JSON部分 - if "```json" in json_text: - start = json_text.find("```json") + 7 - end = json_text.find("```", start) - if end != -1: - json_text = json_text[start:end].strip() - elif "```" in json_text: - start = json_text.find("```") + 3 - end = json_text.find("```", start) - if end != -1: - json_text = json_text[start:end].strip() - - # 尝试解析JSON - data = json.loads(json_text) - - # 检查是否包含return_information字段 - if isinstance(data, dict) and "return_information" in data: - information = data.get("information", "") - return information - except (json.JSONDecodeError, ValueError, TypeError): - # 如果JSON解析失败,尝试在文本中查找JSON对象 - with contextlib.suppress(json.JSONDecodeError, ValueError, TypeError): - # 查找第一个 { 和最后一个 } 之间的内容(更健壮的JSON提取) - first_brace = text.find("{") - if first_brace != -1: - # 从第一个 { 开始,找到匹配的 } - brace_count = 0 - json_end = -1 - for i in range(first_brace, len(text)): - if text[i] == "{": - brace_count += 1 - elif text[i] == "}": - brace_count -= 1 - if brace_count == 0: - json_end = i + 1 - break - - if json_end != -1: - json_text = text[first_brace:json_end] - data = json.loads(json_text) - if isinstance(data, dict) and "return_information" in data: - information = data.get("information", "") - return information - - return None - - # 尝试从文本中解析return_information函数调用 - def parse_return_information_from_text(text: str): - """从文本中解析return_information函数调用,返回information字符串,如果未找到则返回None""" - if not text: - return None - - # 查找return_information函数调用位置(不区分大小写) - func_pattern = "return_information" - text_lower = text.lower() - func_pos = text_lower.find(func_pattern) - if func_pos == -1: - return None - - # 解析information参数(字符串,使用extract_quoted_content) - information = extract_quoted_content(text, "return_information", "information") - - # 如果information存在(即使是空字符串),也返回它 - return information - - # 首先尝试解析JSON格式 - parsed_information_json = parse_json_return_information(response) - is_json_format = parsed_information_json is not None - - # 如果JSON解析成功,使用JSON结果 - if is_json_format: - parsed_information = parsed_information_json - else: - # 如果JSON解析失败,尝试解析函数调用格式 - parsed_information = parse_return_information_from_text(response) - - if parsed_information is not None or is_json_format: - # 检测到return_information格式(可能是JSON格式或函数调用格式) - format_type = "JSON格式" if is_json_format else "函数调用格式" - # 返回信息(即使为空字符串也返回) - step["actions"].append( - { - "action_type": "return_information", - "action_params": {"information": parsed_information or ""}, - } - ) - parsed_info_text = parsed_information if isinstance(parsed_information, str) else "" - if parsed_info_text.strip(): - step["observations"] = [f"检测到return_information{format_type}调用,返回信息"] - thinking_steps.append(step) - logger.info( - f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_info_text[:100]}..." - ) - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status=f"返回信息:{parsed_info_text}", - ) - - return True, parsed_info_text, thinking_steps, False - else: - # 信息为空,直接退出查询 - step["observations"] = [f"检测到return_information{format_type}调用,信息为空"] - thinking_steps.append(step) - logger.info( - f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}判断信息为空" - ) - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="信息为空:通过return_information文本格式判断信息为空", - ) - - return False, "", thinking_steps, False - - # 如果没有检测到return_information格式,记录思考过程,继续下一轮迭代 - step["observations"] = [f"思考完成,但未调用工具。响应: {response}"] - logger.info(f"{react_log_prefix}第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}") - collected_info += f"思考: {response}" - else: - logger.warning(f"{react_log_prefix}第 {iteration + 1} 次迭代 无工具调用且无响应") - step["observations"] = ["无响应且无工具调用"] - thinking_steps.append(step) - iteration += 1 # 在continue之前增加迭代计数,避免跳过iteration += 1 - continue - - # 处理工具调用 - # 首先检查是否有return_information工具调用,如果有则立即返回,不再处理其他工具 - return_information_info = None - for tool_call in tool_calls: - tool_name = tool_call.func_name - tool_args = tool_call.args or {} - - if tool_name == "return_information": - return_information_info = tool_args.get("information", "") - - # 返回信息(即使为空也返回) - step["actions"].append( - { - "action_type": "return_information", - "action_params": {"information": return_information_info}, - } - ) - if return_information_info and return_information_info.strip(): - # 有信息,返回 - step["observations"] = ["检测到return_information工具调用,返回信息"] - thinking_steps.append(step) - logger.info( - f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information工具返回信息: {return_information_info}" - ) - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status=f"返回信息:{return_information_info}", - ) - - return True, return_information_info, thinking_steps, False - else: - # 信息为空,直接退出查询 - step["observations"] = ["检测到return_information工具调用,信息为空"] - thinking_steps.append(step) - logger.info(f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information工具判断信息为空") - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="信息为空:通过return_information工具判断信息为空", - ) - - return False, "", thinking_steps, False - - # 如果没有return_information工具调用,继续处理其他工具 - tool_tasks = [] - for i, tool_call in enumerate(tool_calls): - tool_name = tool_call.func_name - tool_args = tool_call.args or {} - - logger.debug( - f"{react_log_prefix}第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})" - ) - - # 跳过return_information工具调用(已经在上面处理过了) - if tool_name == "return_information": - continue - - # 记录最后一次使用的工具名称(用于判断是否需要额外迭代) - last_tool_name = tool_name - - # 普通工具调用 - tool = tool_registry.get_tool(tool_name) - if tool: - # 准备工具参数(需要添加chat_id如果工具需要) - import inspect - - sig = inspect.signature(tool.execute_func) - tool_params = tool_args.copy() - if "chat_id" in sig.parameters: - tool_params["chat_id"] = chat_id - - # 创建异步任务 - async def execute_single_tool(tool_instance, params, tool_name_str, iter_num): - try: - observation = await tool_instance.execute(**params) - param_str = ", ".join([f"{k}={v}" for k, v in params.items() if k != "chat_id"]) - return f"查询{tool_name_str}({param_str})的结果:{observation}" - except Exception as e: - error_msg = f"工具执行失败: {str(e)}" - logger.error(f"{react_log_prefix}第 {iter_num + 1} 次迭代 工具 {tool_name_str} {error_msg}") - return f"查询{tool_name_str}失败: {error_msg}" - - tool_tasks.append(execute_single_tool(tool, tool_params, tool_name, iteration)) - step["actions"].append({"action_type": tool_name, "action_params": tool_args}) - else: - error_msg = f"未知的工具类型: {tool_name}" - logger.warning( - f"{react_log_prefix}第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}" - ) - tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}"))) - - # 并行执行所有工具 - if tool_tasks: - observations = await asyncio.gather(*tool_tasks, return_exceptions=True) - - # 处理执行结果 - for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)): - if isinstance(observation, Exception): - observation = f"工具执行异常: {str(observation)}" - logger.error(f"{react_log_prefix}第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}") - - observation_text = observation if isinstance(observation, str) else str(observation) - stripped_observation = observation_text.strip() - step["observations"].append(observation_text) - collected_info += f"\n{observation_text}\n" - if stripped_observation: - # 不再自动检测工具输出中的jargon,改为通过 query_words 工具主动查询 - tool_builder = MessageBuilder() - tool_builder.set_role(RoleType.Tool) - tool_builder.add_text_content(observation_text) - tool_builder.add_tool_call(tool_call_item.call_id) - conversation_messages.append(tool_builder.build()) - - thinking_steps.append(step) - - # 检查是否需要额外迭代:如果最后一次使用的工具是 search_chat_history 且达到最大迭代次数,额外增加一回合 - if iteration + 1 >= max_iterations and last_tool_name == "search_chat_history" and not is_timeout: - max_iterations_with_extra = max_iterations + 1 - logger.info( - f"{react_log_prefix}达到最大迭代次数(已迭代{iteration + 1}次),最后一次使用工具为 search_chat_history,额外增加一回合尝试" - ) - - iteration += 1 - - # 正常迭代结束后,如果达到最大迭代次数或超时,执行最终评估 - # 最终评估单独处理,不算在迭代中 - should_do_final_evaluation = False - if is_timeout: - should_do_final_evaluation = True - logger.warning(f"{react_log_prefix}超时,已迭代{iteration}次,进入最终评估") - elif iteration >= max_iterations: - should_do_final_evaluation = True - logger.info(f"{react_log_prefix}达到最大迭代次数(已迭代{iteration}次),进入最终评估") - - if should_do_final_evaluation: - # 获取必要变量用于最终评估 - tool_registry = get_tool_registry() - bot_name = global_config.bot.nickname - time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - current_iteration = iteration + 1 - remaining_iterations = 0 - - # 提取函数调用中参数的值,支持单引号和双引号 - def extract_quoted_content(text, func_name, param_name): - """从文本中提取函数调用中参数的值,支持单引号和双引号 - - Args: - text: 要搜索的文本 - func_name: 函数名,如 'return_information' - param_name: 参数名,如 'information' - - Returns: - 提取的参数值,如果未找到则返回None - """ - if not text: - return None - - # 查找函数调用位置(不区分大小写) - func_pattern = func_name.lower() - text_lower = text.lower() - func_pos = text_lower.find(func_pattern) - if func_pos == -1: - return None - - # 查找参数名和等号 - param_pattern = f"{param_name}=" - param_pos = text_lower.find(param_pattern, func_pos) - if param_pos == -1: - return None - - # 跳过参数名、等号和空白 - start_pos = param_pos + len(param_pattern) - while start_pos < len(text) and text[start_pos] in " \t\n": - start_pos += 1 - - if start_pos >= len(text): - return None - - # 确定引号类型 - quote_char = text[start_pos] - if quote_char not in ['"', "'"]: - return None - - # 查找匹配的结束引号(考虑转义) - end_pos = start_pos + 1 - while end_pos < len(text): - if text[end_pos] == quote_char: - # 检查是否是转义的引号 - if end_pos > start_pos + 1 and text[end_pos - 1] == "\\": - end_pos += 1 - continue - # 找到匹配的引号 - content = text[start_pos + 1 : end_pos] - # 处理转义字符 - content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\") - return content - end_pos += 1 - - return None - - # 执行最终评估 - evaluation_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_final") - evaluation_prompt_template.add_context("bot_name", bot_name) - evaluation_prompt_template.add_context("time_now", time_now) - evaluation_prompt_template.add_context("chat_history", chat_history) - evaluation_prompt_template.add_context("collected_info", collected_info or "暂无信息") - evaluation_prompt_template.add_context("current_iteration", str(current_iteration)) - evaluation_prompt_template.add_context("remaining_iterations", str(remaining_iterations)) - evaluation_prompt_template.add_context("max_iterations", str(max_iterations)) - evaluation_prompt = await prompt_manager.render_prompt(evaluation_prompt_template) - - evaluation_result = await llm_api.generate( - llm_api.LLMServiceRequest( - task_name="utils", - request_type="memory.react.final", - prompt=evaluation_prompt, - tool_options=[], - ) - ) - eval_success = evaluation_result.success - eval_response = evaluation_result.completion.response - - if not eval_success: - logger.error(f"ReAct Agent 最终评估阶段 LLM调用失败: {eval_response}") - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="未找到答案:最终评估阶段LLM调用失败", - ) - return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout - - if global_config.debug.show_memory_prompt: - logger.info(f"{react_log_prefix}最终评估Prompt: {evaluation_prompt}") - logger.info(f"{react_log_prefix}最终评估响应: {eval_response}") - - # 从最终评估响应中提取return_information - return_information_content = None - - if eval_response: - return_information_content = extract_quoted_content(eval_response, "return_information", "information") - - # 如果提取到信息,返回(无论是否超时,都视为成功完成) - if return_information_content is not None: - eval_step = { - "iteration": current_iteration, - "thought": f"[最终评估] {eval_response}", - "actions": [ - {"action_type": "return_information", "action_params": {"information": return_information_content}} - ], - "observations": ["最终评估阶段检测到return_information"], - } - thinking_steps.append(eval_step) - if return_information_content and return_information_content.strip(): - logger.info(f"ReAct Agent 最终评估阶段返回信息: {return_information_content}") - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status=f"返回信息:{return_information_content}", - ) - return True, return_information_content, thinking_steps, False - else: - logger.info("ReAct Agent 最终评估阶段判断信息为空") - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="信息为空:最终评估阶段判断信息为空", - ) - return False, "", thinking_steps, False - - # 如果没有明确判断,视为not_enough_info,返回空字符串(不返回任何信息) - eval_step = { - "iteration": current_iteration, - "thought": f"[最终评估] {eval_response}", - "actions": [{"action_type": "return_information", "action_params": {"information": ""}}], - "observations": ["已到达最大迭代次数,信息为空"], - } - thinking_steps.append(eval_step) - logger.info("ReAct Agent 已到达最大迭代次数,信息为空") - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="未找到答案:已到达最大迭代次数,无法找到答案", - ) - - return False, "", thinking_steps, is_timeout - - # 如果正常迭代过程中提前找到答案返回,不会到达这里 - # 如果正常迭代结束但没有触发最终评估(理论上不应该发生),直接返回 - logger.warning("ReAct Agent正常迭代结束,但未触发最终评估") - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="未找到答案:正常迭代结束", - ) - - return False, "", thinking_steps, is_timeout - -async def _process_memory_retrieval( - chat_id: str, - context: str, - initial_info: str = "", - max_iterations: Optional[int] = None, - chat_history: str = "", -) -> Optional[str]: - """处理记忆检索 - - Args: - chat_id: 聊天ID - context: 上下文信息 - initial_info: 初始信息,将传递给ReAct Agent - max_iterations: 最大迭代次数 - chat_history: 聊天记录,将传递给 ReAct Agent - - Returns: - Optional[str]: 如果找到答案,返回答案内容,否则返回None - """ - question_initial_info = initial_info or "" - - # 直接使用ReAct Agent进行记忆检索 - # 如果未指定max_iterations,使用配置的默认值 - if max_iterations is None: - max_iterations = global_config.memory.max_agent_iterations - - found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question( - chat_id=chat_id, - max_iterations=max_iterations, - timeout=global_config.memory.agent_timeout_seconds, - initial_info=question_initial_info, - chat_history=chat_history, - ) - - # 不再存储到数据库,直接返回答案 - if is_timeout: - logger.info("ReAct Agent超时,不返回结果") - - return answer if found_answer and answer else None - - -async def build_memory_retrieval_prompt( - message: str, - sender: str, - target: str, - chat_stream, - think_level: int = 1, - unknown_words: Optional[List[str]] = None, -) -> str: - """构建记忆检索提示 - Args: - message: 聊天历史记录 - sender: 发送者名称 - target: 目标消息内容 - chat_stream: 聊天流对象 - think_level: 思考深度等级 - unknown_words: Planner 提供的未知词语列表,优先使用此列表而不是从聊天记录匹配 - - Returns: - str: 记忆检索结果字符串 - """ - start_time = time.time() - - # 构造日志前缀:[聊天流名称],用于在日志中标识聊天流(优先群名称/用户昵称) - try: - group_info = chat_stream.group_info - user_info = chat_stream.user_info - # 群聊优先使用群名称 - if group_info is not None and getattr(group_info, "group_name", None): - stream_name = group_info.group_name.strip() or str(group_info.group_id) - # 私聊使用用户昵称 - elif user_info is not None and getattr(user_info, "user_nickname", None): - stream_name = user_info.user_nickname.strip() or str(user_info.user_id) - # 兜底使用 stream_id - else: - stream_name = chat_stream.stream_id - except Exception: - stream_name = chat_stream.stream_id - log_prefix = f"[{stream_name}] " if stream_name else "" - - logger.info(f"{log_prefix}检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}") - try: - chat_id = chat_stream.stream_id - - # 初始阶段:使用 Planner 提供的 unknown_words 进行检索(如果提供) - initial_info = "" - if unknown_words and len(unknown_words) > 0: - # 清理和去重 unknown_words - cleaned_concepts = [] - for word in unknown_words: - if isinstance(word, str): - if cleaned := word.strip(): - cleaned_concepts.append(cleaned) - if cleaned_concepts: - # 对匹配到的概念进行jargon检索,作为初始信息 - concept_info = await retrieve_concepts_with_jargon(cleaned_concepts, chat_id) - if concept_info: - initial_info += concept_info - logger.info( - f"{log_prefix}使用 Planner 提供的 unknown_words,共 {len(cleaned_concepts)} 个概念,检索结果: {concept_info[:100]}..." - ) - else: - logger.debug(f"{log_prefix}unknown_words 检索未找到任何结果") - - # 直接使用 ReAct Agent 进行记忆检索(跳过问题生成步骤) - base_max_iterations = global_config.memory.max_agent_iterations - # 根据think_level调整迭代次数:think_level=1时不变,think_level=0时减半 - if think_level == 0: - max_iterations = max(1, base_max_iterations // 2) # 至少为1 - else: - max_iterations = base_max_iterations - timeout_seconds = global_config.memory.agent_timeout_seconds - logger.debug( - f"{log_prefix}直接使用 ReAct Agent 进行记忆检索,think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒" - ) - - # 直接调用 ReAct Agent 处理记忆检索 - try: - result = await _process_memory_retrieval( - chat_id=chat_id, - context=message, - initial_info=initial_info, - max_iterations=max_iterations, - chat_history=message, - ) - except Exception as e: - logger.error(f"{log_prefix}处理记忆检索时发生异常: {e}") - result = None - - end_time = time.time() - - if result: - logger.info(f"{log_prefix}记忆检索成功,耗时: {(end_time - start_time):.3f}秒") - return f"你回忆起了以下信息:\n{result}\n如果与回复内容相关,可以参考这些回忆的信息。\n" - else: - logger.debug(f"{log_prefix}记忆检索未找到相关信息") - return "" - - except Exception as e: - logger.error(f"{log_prefix}记忆检索时发生异常: {str(e)}") - return "" diff --git a/src/memory_system/retrieval_tools/__init__.py b/src/memory_system/retrieval_tools/__init__.py deleted file mode 100644 index ba5f731f..00000000 --- a/src/memory_system/retrieval_tools/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -记忆检索工具模块 -提供统一的工具注册和管理系统 -""" - -from .tool_registry import ( - MemoryRetrievalTool, - MemoryRetrievalToolRegistry, - register_memory_retrieval_tool, - get_tool_registry, -) - - -def init_all_tools(): - """初始化并注册所有记忆检索工具""" - # 延迟导入,避免在仅使用部分工具或单元测试阶段触发不必要的依赖链。 - from .query_long_term_memory import register_tool as register_long_term_memory - from .query_words import register_tool as register_query_words - from .return_information import register_tool as register_return_information - - register_query_words() - register_return_information() - register_long_term_memory() - - -__all__ = [ - "MemoryRetrievalTool", - "MemoryRetrievalToolRegistry", - "register_memory_retrieval_tool", - "get_tool_registry", - "init_all_tools", -] diff --git a/src/memory_system/retrieval_tools/query_long_term_memory.py b/src/memory_system/retrieval_tools/query_long_term_memory.py deleted file mode 100644 index bf39c0cd..00000000 --- a/src/memory_system/retrieval_tools/query_long_term_memory.py +++ /dev/null @@ -1,307 +0,0 @@ -"""通过统一长期记忆服务查询信息。""" - -from __future__ import annotations - -import re -from calendar import monthrange -from datetime import datetime, timedelta -from typing import Iterable, Literal, Tuple - -from src.common.logger import get_logger -from src.services.memory_service import MemoryHit, MemorySearchResult, memory_service - -from .tool_registry import register_memory_retrieval_tool - -logger = get_logger("memory_retrieval_tools") - -_SUPPORTED_MODES = {"search", "time", "episode", "aggregate"} -_RELATIVE_DAYS_RE = re.compile(r"^最近\s*(\d+)\s*天$") -_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$") -_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}\s+\d{2}:\d{2}$") -_TIME_EXPRESSION_HELP = ( - "请改用更具体的时间表达,例如:今天、昨天、前天、本周、上周、本月、上月、最近7天、" - "2026/03/18、2026/03/18 09:30。" -) - - -def _format_query_datetime(dt: datetime) -> str: - return dt.strftime("%Y/%m/%d %H:%M") - - -def _resolve_time_expression( - expression: str, - *, - now: datetime | None = None, -) -> Tuple[float, float, str, str]: - clean = str(expression or "").strip() - if not clean: - raise ValueError(f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}") - - current = now or datetime.now() - day_start = current.replace(hour=0, minute=0, second=0, microsecond=0) - - if clean == "今天": - start = day_start - end = day_start.replace(hour=23, minute=59) - elif clean == "昨天": - start = day_start - timedelta(days=1) - end = start.replace(hour=23, minute=59) - elif clean == "前天": - start = day_start - timedelta(days=2) - end = start.replace(hour=23, minute=59) - elif clean == "本周": - start = day_start - timedelta(days=day_start.weekday()) - end = start + timedelta(days=6, hours=23, minutes=59) - elif clean == "上周": - this_week_start = day_start - timedelta(days=day_start.weekday()) - start = this_week_start - timedelta(days=7) - end = start + timedelta(days=6, hours=23, minutes=59) - elif clean == "本月": - start = day_start.replace(day=1) - last_day = monthrange(start.year, start.month)[1] - end = start.replace(day=last_day, hour=23, minute=59) - elif clean == "上月": - year = day_start.year - month = day_start.month - 1 - if month == 0: - year -= 1 - month = 12 - start = day_start.replace(year=year, month=month, day=1) - last_day = monthrange(year, month)[1] - end = start.replace(day=last_day, hour=23, minute=59) - else: - relative_match = _RELATIVE_DAYS_RE.fullmatch(clean) - if relative_match: - days = max(1, int(relative_match.group(1))) - start = day_start - timedelta(days=max(0, days - 1)) - end = day_start.replace(hour=23, minute=59) - elif _DATE_RE.fullmatch(clean): - start = datetime.strptime(clean, "%Y/%m/%d") - end = start.replace(hour=23, minute=59) - elif _MINUTE_RE.fullmatch(clean): - start = datetime.strptime(clean, "%Y/%m/%d %H:%M") - end = start - else: - raise ValueError(f"时间表达“{clean}”无法解析。{_TIME_EXPRESSION_HELP}") - - return start.timestamp(), end.timestamp(), _format_query_datetime(start), _format_query_datetime(end) - - -def _extract_time_label(metadata: dict) -> str: - if not isinstance(metadata, dict): - return "" - start = metadata.get("event_time_start") - end = metadata.get("event_time_end") - event_time = metadata.get("event_time") - - def _fmt(value: object) -> str: - if value in {None, ""}: - return "" - try: - return datetime.fromtimestamp(float(value)).strftime("%Y/%m/%d %H:%M") - except Exception: - return str(value) - - start_text = _fmt(start or event_time) - end_text = _fmt(end) - if start_text and end_text: - return f"{start_text} - {end_text}" - return start_text or end_text - - -def _truncate(text: str, limit: int = 160) -> str: - compact = str(text or "").strip().replace("\n", " ") - if len(compact) <= limit: - return compact - return compact[:limit] + "..." - - -def _format_search_lines(hits: Iterable[MemoryHit], *, limit: int, include_time: bool = False) -> str: - lines = [] - for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1): - time_label = _extract_time_label(item.metadata) if include_time else "" - prefix = f"[{time_label}] " if time_label else "" - lines.append(f"{index}. {prefix}{_truncate(item.content)}") - return "\n".join(lines) - - -def _format_episode_lines(hits: Iterable[MemoryHit], *, limit: int) -> str: - lines = [] - for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1): - metadata = item.metadata if isinstance(item.metadata, dict) else {} - title = str(item.title or "").strip() or "未命名事件" - summary = _truncate(item.content, limit=180) - participants = [str(x).strip() for x in (metadata.get("participants") or []) if str(x).strip()] - keywords = [str(x).strip() for x in (metadata.get("keywords") or []) if str(x).strip()] - extras = [] - if participants: - extras.append(f"参与者:{'、'.join(participants[:4])}") - if keywords: - extras.append(f"关键词:{'、'.join(keywords[:6])}") - time_label = _extract_time_label(metadata) - if time_label: - extras.append(f"时间:{time_label}") - suffix = f"({';'.join(extras)})" if extras else "" - lines.append(f"{index}. 事件《{title}》:{summary}{suffix}") - return "\n".join(lines) - - -def _format_aggregate_lines(hits: Iterable[MemoryHit], *, limit: int) -> str: - lines = [] - for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1): - metadata = item.metadata if isinstance(item.metadata, dict) else {} - source_branches = [str(x).strip() for x in (metadata.get("source_branches") or []) if str(x).strip()] - branch_text = f"[{','.join(source_branches)}]" if source_branches else "" - item_type = str(item.hit_type or "").strip().lower() or "memory" - if item_type == "episode": - title = str(item.title or "").strip() or "未命名事件" - lines.append(f"{index}. {branch_text}[episode] 《{title}》:{_truncate(item.content, 160)}") - else: - lines.append(f"{index}. {branch_text}[{item_type}] {_truncate(item.content, 160)}") - return "\n".join(lines) - - -def _format_tool_result( - *, - result: MemorySearchResult, - mode: Literal["search", "time", "episode", "aggregate"], - limit: int, - query: str, - time_range_text: str = "", -) -> str: - if not result.success: - return f"长期记忆查询失败:{result.error or '未知错误'}" - - if not result.hits: - if mode == "time": - return f"在指定时间范围内未找到相关的长期记忆{time_range_text}" - if mode == "episode": - return f"未找到与“{query}”相关的事件或情节记忆" - if mode == "aggregate": - return f"未找到可用于综合回忆的长期记忆线索{f'(query:{query})' if query else ''}" - return f"在长期记忆中未找到与“{query}”相关的信息" - - if mode == "episode": - text = _format_episode_lines(result.hits, limit=limit) - return f"你从长期记忆的事件/情节中找到以下信息:\n{text}" - - if mode == "aggregate": - text = _format_aggregate_lines(result.hits, limit=limit) - return f"你从长期记忆中综合找到了以下线索:\n{text}" - - if mode == "time": - text = _format_search_lines(result.hits, limit=limit, include_time=True) - return f"你从指定时间范围内的长期记忆中找到以下信息{time_range_text}:\n{text}" - - text = _format_search_lines(result.hits, limit=limit) - return f"你从长期记忆中找到以下信息:\n{text}" - - -async def query_long_term_memory( - query: str = "", - limit: int = 5, - chat_id: str = "", - person_id: str = "", - mode: str = "search", - time_expression: str = "", -) -> str: - content = str(query or "").strip() - safe_limit = max(1, int(limit or 5)) - normalized_mode = str(mode or "search").strip().lower() or "search" - if normalized_mode not in _SUPPORTED_MODES: - return f"不支持的长期记忆检索模式:{normalized_mode}。可用模式:search、time、episode、aggregate。" - - if normalized_mode == "search" and not content: - return "查询关键词为空,请提供你想查找的长期记忆内容。" - if normalized_mode == "time" and not str(time_expression or "").strip(): - return f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}" - if normalized_mode in {"episode", "aggregate"} and not content and not str(time_expression or "").strip(): - return f"{normalized_mode} 模式至少需要提供 query 或 time_expression。" - - time_start = None - time_end = None - time_range_text = "" - if str(time_expression or "").strip(): - try: - time_start, time_end, time_start_text, time_end_text = _resolve_time_expression(time_expression) - except ValueError as exc: - return str(exc) - time_range_text = f"(时间范围:{time_start_text} 至 {time_end_text})" - - backend_mode = normalized_mode - - try: - result = await memory_service.search( - content, - limit=safe_limit, - mode=backend_mode, - chat_id=str(chat_id or "").strip(), - person_id=str(person_id or "").strip(), - time_start=time_start, - time_end=time_end, - ) - text = _format_tool_result( - result=result, - mode=normalized_mode, # type: ignore[arg-type] - limit=safe_limit, - query=content, - time_range_text=time_range_text, - ) - logger.debug(f"长期记忆查询结果({normalized_mode}): {text}") - return text - except Exception as exc: - logger.error(f"长期记忆查询失败: {exc}") - return f"长期记忆查询失败:{exc}" - - -def register_tool(): - register_memory_retrieval_tool( - name="search_long_term_memory", - description=( - "从长期记忆中检索信息。支持 search(普通事实检索)、time(按时间范围检索)、" - "episode(按事件/情节检索)、aggregate(综合检索)四种模式。" - ), - parameters=[ - { - "name": "query", - "type": "string", - "description": "需要查询的问题。search 模式建议用自然语言问句;time/episode/aggregate 模式也可用关键词短语。", - "required": False, - }, - { - "name": "mode", - "type": "string", - "description": "检索模式:search(普通长期记忆)、time(按时间窗口)、episode(事件/情节)、aggregate(综合检索)。", - "required": False, - "enum": ["search", "time", "episode", "aggregate"], - }, - { - "name": "limit", - "type": "integer", - "description": "希望返回的相关知识条数,默认为5", - "required": False, - }, - { - "name": "chat_id", - "type": "string", - "description": "当前聊天流ID,可选。提供后优先检索当前聊天上下文相关的长期记忆。", - "required": False, - }, - { - "name": "person_id", - "type": "string", - "description": "相关人物ID,可选。提供后优先检索该人物相关的长期记忆。", - "required": False, - }, - { - "name": "time_expression", - "type": "string", - "description": ( - "时间表达,可选。time 模式必填;episode/aggregate 模式可选。支持:今天、昨天、前天、本周、上周、本月、上月、" - "最近N天,以及 YYYY/MM/DD、YYYY/MM/DD HH:mm。" - ), - "required": False, - }, - ], - execute_func=query_long_term_memory, - ) diff --git a/src/memory_system/retrieval_tools/query_words.py b/src/memory_system/retrieval_tools/query_words.py deleted file mode 100644 index ee28b934..00000000 --- a/src/memory_system/retrieval_tools/query_words.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -查询黑话/概念含义 - 工具实现 -用于在记忆检索过程中主动查询未知词语或黑话的含义 -""" - -from src.common.logger import get_logger -from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon -from .tool_registry import register_memory_retrieval_tool - -logger = get_logger("memory_retrieval_tools") - - -async def query_words(chat_id: str, words: str) -> str: - """查询词语或黑话的含义 - - Args: - chat_id: 聊天ID - words: 要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔) - - Returns: - str: 查询结果,包含词语的含义解释 - """ - try: - if not words or not words.strip(): - return "未提供要查询的词语" - - # 解析词语列表(支持逗号、空格等分隔符) - words_list = [] - for separator in [",", ",", " ", "\n", "\t"]: - if separator in words: - words_list = [w.strip() for w in words.split(separator) if w.strip()] - break - - # 如果没有找到分隔符,整个字符串作为一个词语 - if not words_list: - words_list = [words.strip()] - - # 去重 - unique_words = [] - seen = set() - for word in words_list: - if word and word not in seen: - unique_words.append(word) - seen.add(word) - - if not unique_words: - return "未提供有效的词语" - - logger.info(f"查询词语含义: {unique_words}") - - # 调用检索函数 - result = await retrieve_concepts_with_jargon(unique_words, chat_id) - - if result: - return result - else: - return f"未找到词语 '{', '.join(unique_words)}' 的含义或黑话解释" - - except Exception as e: - logger.error(f"查询词语含义失败: {e}") - return f"查询失败: {str(e)}" - - -def register_tool(): - """注册工具""" - register_memory_retrieval_tool( - name="query_words", - description="查询词语或黑话的含义。当遇到不熟悉的词语、缩写、黑话或网络用语时,可以使用此工具查询其含义。支持查询单个或多个词语(用逗号、空格等分隔)。", - parameters=[ - { - "name": "words", - "type": "string", - "description": "要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔,如:'YYDS' 或 'YYDS,内卷,996')", - "required": True, - }, - ], - execute_func=query_words, - ) diff --git a/src/memory_system/retrieval_tools/return_information.py b/src/memory_system/retrieval_tools/return_information.py deleted file mode 100644 index bf368083..00000000 --- a/src/memory_system/retrieval_tools/return_information.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -return_information工具 - 用于在记忆检索过程中返回总结信息并结束查询 -""" - -from src.common.logger import get_logger -from .tool_registry import register_memory_retrieval_tool - -logger = get_logger("memory_retrieval_tools") - - -async def return_information(information: str) -> str: - """返回总结信息并结束查询 - - Args: - information: 基于已收集信息总结出的相关信息,用于帮助回复。如果收集的信息对当前聊天没有帮助,可以返回空字符串。 - - Returns: - str: 确认信息 - """ - if information and information.strip(): - logger.info(f"返回总结信息: {information}") - return f"已确认返回信息: {information}" - else: - logger.info("未收集到相关信息,结束查询") - return "未收集到相关信息,查询结束" - - -def register_tool(): - """注册return_information工具""" - register_memory_retrieval_tool( - name="return_information", - description="当你决定结束查询时,调用此工具。基于已收集的信息,总结出一段相关信息用于帮助回复。如果收集的信息对当前聊天有帮助,在information参数中提供总结信息;如果信息无关或没有帮助,可以提供空字符串。", - parameters=[ - { - "name": "information", - "type": "string", - "description": "基于已收集信息总结出的相关信息,用于帮助回复。必须基于已收集的信息,不要编造。如果信息对当前聊天没有帮助,可以返回空字符串。", - "required": True, - }, - ], - execute_func=return_information, - ) diff --git a/src/memory_system/retrieval_tools/tool_registry.py b/src/memory_system/retrieval_tools/tool_registry.py deleted file mode 100644 index f2dd1f0d..00000000 --- a/src/memory_system/retrieval_tools/tool_registry.py +++ /dev/null @@ -1,167 +0,0 @@ -"""工具注册系统。 - -提供统一的工具注册和管理接口。 -""" - -from typing import Any, Awaitable, Callable, Dict, List, Optional - -from src.common.logger import get_logger -from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option - -logger = get_logger("memory_retrieval_tools") - - -class MemoryRetrievalTool: - """记忆检索工具基类""" - - def __init__( - self, - name: str, - description: str, - parameters: List[Dict[str, Any]], - execute_func: Callable[..., Awaitable[str]], - ) -> None: - """初始化工具。 - - Args: - name: 工具名称。 - description: 工具描述。 - parameters: 参数定义列表。 - execute_func: 执行函数,必须是异步函数。 - """ - self.name = name - self.description = description - self.parameters = parameters - self.execute_func = execute_func - - def get_tool_description(self) -> str: - """获取工具的文本描述,用于prompt""" - param_descriptions = [] - for param in self.parameters: - param_name = param.get("name", "") - param_type = param.get("type", "string") - param_desc = param.get("description", "") - required = param.get("required", True) - required_str = "必填" if required else "可选" - param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}") - - params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数" - return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}" - - async def execute(self, **kwargs: Any) -> str: - """执行工具。""" - return await self.execute_func(**kwargs) - - def get_tool_definition(self) -> Dict[str, Any]: - """获取规范化的工具定义。 - - Returns: - Dict[str, Any]: 统一工具定义字典。 - """ - legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] - - for param in self.parameters: - param_name = param.get("name", "") - param_type_str = param.get("type", "string").lower() - param_desc = param.get("description", "") - is_required = param.get("required", False) - enum_values = param.get("enum", None) - - # 转换类型字符串到ToolParamType - type_mapping = { - "string": ToolParamType.STRING, - "integer": ToolParamType.INTEGER, - "int": ToolParamType.INTEGER, - "float": ToolParamType.FLOAT, - "boolean": ToolParamType.BOOLEAN, - "bool": ToolParamType.BOOLEAN, - } - param_type = type_mapping.get(param_type_str, ToolParamType.STRING) - - legacy_parameters.append((param_name, param_type, param_desc, is_required, enum_values)) - - normalized_option = normalize_tool_option( - { - "name": self.name, - "description": self.description, - "parameters": legacy_parameters, - } - ) - return { - "name": normalized_option.name, - "description": normalized_option.description, - "parameters_schema": normalized_option.parameters_schema, - } - - -class MemoryRetrievalToolRegistry: - """工具注册器""" - - def __init__(self) -> None: - """初始化工具注册器。""" - self.tools: Dict[str, MemoryRetrievalTool] = {} - - def register_tool(self, tool: MemoryRetrievalTool) -> None: - """注册工具""" - if tool.name in self.tools: - logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册") - return - self.tools[tool.name] = tool - logger.info(f"注册记忆检索工具: {tool.name}") - - def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]: - """获取工具""" - return self.tools.get(name) - - def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]: - """获取所有工具""" - return self.tools.copy() - - def get_tools_description(self) -> str: - """获取所有工具的描述,用于prompt""" - descriptions = [] - for i, tool in enumerate(self.tools.values(), 1): - descriptions.append(f"{i}. {tool.get_tool_description()}") - return "\n".join(descriptions) - - def get_action_types_list(self) -> str: - """获取所有动作类型的列表,用于prompt(已废弃,保留用于兼容)""" - action_types = [tool.name for tool in self.tools.values()] - action_types.append("final_answer") - action_types.append("no_answer") - return " 或 ".join([f'"{at}"' for at in action_types]) - - def get_tool_definitions(self) -> List[Dict[str, Any]]: - """获取所有工具的定义列表,用于LLM function calling - - Returns: - List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典 - """ - return [tool.get_tool_definition() for tool in self.tools.values()] - - -# 全局工具注册器实例 -_tool_registry = MemoryRetrievalToolRegistry() - - -def register_memory_retrieval_tool( - name: str, - description: str, - parameters: List[Dict[str, Any]], - execute_func: Callable[..., Awaitable[str]], -) -> None: - """注册记忆检索工具的便捷函数。 - - Args: - name: 工具名称。 - description: 工具描述。 - parameters: 参数定义列表。 - execute_func: 执行函数。 - """ - tool = MemoryRetrievalTool(name, description, parameters, execute_func) - _tool_registry.register_tool(tool) - - -def get_tool_registry() -> MemoryRetrievalToolRegistry: - """获取工具注册器实例""" - return _tool_registry