remove:无用配置,整理配置文件
This commit is contained in:
24
AGENTS.md
24
AGENTS.md
@@ -1,23 +1,8 @@
|
||||
# import 规范
|
||||
在从外部库进行导入时候,请遵循以下顺序:
|
||||
1. 对于标准库和第三方库的导入,请按照如下顺序:
|
||||
- 需要使用`from ... import ...`语法的导入放在前面。
|
||||
- 直接使用`import ...`语法的导入放在后面。
|
||||
- 对于使用`from ... import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。
|
||||
- 对于使用`import ...`导入的多个项,请**在保证不会引起import错误的前提下**,按照**字母顺序**排列。
|
||||
2. 对于本地模块的导入,请按照如下顺序:
|
||||
- 对于同一个文件夹下的模块导入,使用相对导入,排列顺序按照**不发生import错误的前提下**,随便排列。
|
||||
- 对于不同文件夹下的模块导入,使用绝对导入。这些导入应该以`from src`开头,并且按照**不发生import错误的前提下**,尽量使得第二层的文件夹名称相同的导入放在一起;第二层文件夹名称排列随机。
|
||||
3. 标准库和第三方库的导入应该放在本地模块导入的前面。
|
||||
4. 各个导入块之间应该使用一个空行进行分隔。
|
||||
5. 对于现有的代码,如果导入顺序不符合上述规范,在重构代码时应该调整导入顺序以符合规范。
|
||||
|
||||
# 代码规范
|
||||
## 注释规范
|
||||
1. 尽量保持良好的注释
|
||||
2. 如果原来的代码中有注释,则重构的时候,除非这部分代码被删除,否则相同功能的代码应该保留注释(可以对注释进行修改以保持准确性,但不应该删除注释)。
|
||||
3. 如果原来的代码中没有注释,则重构的时候,如果某个功能块的代码较长或者逻辑较为复杂,则应该添加注释来解释这部分代码的功能和逻辑。
|
||||
4. 对于类,方法以及模块的注释,首选使用的注释格式为 Google DocStr 格式,但保证语言为简体中文
|
||||
## 类型注解规范
|
||||
1. 重构代码时,如果原来的代码中有类型注解,则相同功能的代码应该保留类型注解(可以对类型注解进行修改以保持准确性,但不应该删除类型注解)。
|
||||
2. 重构代码时,如果原来的代码中没有类型注解,则重构的时候,如果某个函数的功能较为复杂或者参数较多,则应该添加类型注解来提高代码的可读性和可维护性。(对于简单的变量,可以不添加类型注解)
|
||||
@@ -34,8 +19,13 @@
|
||||
# 运行/调试/构建/测试/依赖
|
||||
优先使用uv
|
||||
依赖项以 pyproject.toml 为准
|
||||
不要修改dashboard下的内容,因为这部分内容由另一个仓库build
|
||||
|
||||
|
||||
# 语言规范
|
||||
项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都首要以简体中文为首要实现目标
|
||||
|
||||
项目的首选语言为简体中文,无论是注释语言,日志展示语言,还是 WebUI 展示语言都应该首要以简体中文为首要实现目标
|
||||
# 配置文件修改
|
||||
如果你需要改动配置文件,不需要修改实际的bot_config.toml或者model_config.toml,只需要修改配置文件模版,并新增一个版本号即可,也不必要为配置改动创建测试文件。
|
||||
|
||||
# 关于webui修改
|
||||
不要修改dashboard下的内容,因为这部分内容由另一个仓库build
|
||||
@@ -501,7 +501,7 @@ class EmojiManager:
|
||||
existing_record = session.exec(statement).first()
|
||||
if existing_record:
|
||||
if existing_record.is_registered and _is_available_emoji_record(existing_record):
|
||||
logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}")
|
||||
# logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}")
|
||||
return "skipped"
|
||||
|
||||
normalized_description = str(emoji.description or existing_record.description or "").strip()
|
||||
|
||||
@@ -369,7 +369,7 @@ class ImageManager:
|
||||
logger.info(f"Cleaned mistaken registration state on {fixed_counter} image records")
|
||||
|
||||
async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str:
|
||||
prompt = global_config.personality.visual_style
|
||||
prompt = global_config.visual.visual_style
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
generation_result = await vlm.generate_response_for_image(
|
||||
|
||||
@@ -18,4 +18,4 @@ def get_maisaka_replyer_class() -> Type[object]:
|
||||
|
||||
def get_maisaka_replyer_generator_type() -> str:
|
||||
"""返回当前配置的 Maisaka replyer 生成器类型。"""
|
||||
return global_config.chat.replyer_generator_type
|
||||
return "multimodal" if global_config.visual.multimodal_replyer else "legacy"
|
||||
|
||||
@@ -27,14 +27,15 @@ from .official_configs import (
|
||||
MaiSakaConfig,
|
||||
MaimMessageConfig,
|
||||
MCPConfig,
|
||||
PluginRuntimeConfig,
|
||||
MemoryConfig,
|
||||
MessageReceiveConfig,
|
||||
PersonalityConfig,
|
||||
PluginRuntimeConfig,
|
||||
RelationshipConfig,
|
||||
ResponsePostProcessConfig,
|
||||
ResponseSplitterConfig,
|
||||
TelemetryConfig,
|
||||
VisualConfig,
|
||||
VoiceConfig,
|
||||
WebUIConfig,
|
||||
)
|
||||
@@ -55,7 +56,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.3.4"
|
||||
CONFIG_VERSION: str = "8.4.1"
|
||||
MODEL_CONFIG_VERSION: str = "1.13.1"
|
||||
|
||||
logger = get_logger("config")
|
||||
@@ -73,6 +74,9 @@ class Config(ConfigBase):
|
||||
personality: PersonalityConfig = Field(default_factory=PersonalityConfig)
|
||||
"""人格配置类"""
|
||||
|
||||
visual: VisualConfig = Field(default_factory=VisualConfig)
|
||||
"""视觉配置类"""
|
||||
|
||||
expression: ExpressionConfig = Field(default_factory=ExpressionConfig)
|
||||
"""表达配置类"""
|
||||
|
||||
@@ -474,6 +478,10 @@ def load_config_from_file(
|
||||
if env_migration.migrated:
|
||||
logger.warning(f"检测到旧版环境变量绑定配置,已迁移到主配置: {env_migration.reason}")
|
||||
config_data = env_migration.data
|
||||
legacy_migration = try_migrate_legacy_bot_config_dict(config_data)
|
||||
if legacy_migration.migrated:
|
||||
logger.warning(t("config.legacy_migrated", reason=legacy_migration.reason))
|
||||
config_data = legacy_migration.data
|
||||
# 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改
|
||||
original_data: dict[str, Any] = copy.deepcopy(config_data)
|
||||
try:
|
||||
|
||||
@@ -75,6 +75,18 @@ def _migrate_env_value(section: dict[str, Any], key: str, parsed_env_value: Any,
|
||||
return True
|
||||
|
||||
|
||||
def _move_section_key(source: dict[str, Any], target: dict[str, Any], key: str) -> bool:
|
||||
"""将配置项从旧分组移动到新分组,若新分组已有值则保留新值。"""
|
||||
|
||||
if key not in source:
|
||||
return False
|
||||
|
||||
if key not in target:
|
||||
target[key] = source[key]
|
||||
source.pop(key, None)
|
||||
return True
|
||||
|
||||
|
||||
def _parse_triplet_target(s: str) -> Optional[dict[str, str]]:
|
||||
"""
|
||||
解析 "platform:id:type" -> {platform,item_id,rule_type}
|
||||
@@ -273,6 +285,21 @@ def _migrate_extra_prompt_list(exp: dict[str, Any], key: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _parse_multimodal_replyer(v: Any) -> Optional[bool]:
|
||||
"""兼容旧 replyer_generator_type 到布尔开关的迁移。"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if not isinstance(v, str):
|
||||
return None
|
||||
|
||||
normalized_value = v.strip().lower()
|
||||
if normalized_value == "multimodal":
|
||||
return True
|
||||
if normalized_value == "legacy":
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
"""将旧版环境变量中的绑定地址迁移到主配置结构。"""
|
||||
|
||||
@@ -351,12 +378,61 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append("chat.private_plan_style_removed")
|
||||
|
||||
personality = _as_dict(data.get("personality"))
|
||||
visual = _as_dict(data.get("visual"))
|
||||
if visual is None and (
|
||||
(personality is not None and "visual_style" in personality)
|
||||
or "multimodal_planner" in chat
|
||||
or "replyer_generator_type" in chat
|
||||
):
|
||||
visual = {}
|
||||
data["visual"] = visual
|
||||
|
||||
if visual is not None and personality is not None and "visual_style" in personality:
|
||||
if "visual_style" not in visual:
|
||||
visual["visual_style"] = personality["visual_style"]
|
||||
personality.pop("visual_style", None)
|
||||
migrated_any = True
|
||||
reasons.append("personality.visual_style_moved_to_visual.visual_style")
|
||||
|
||||
if visual is not None and "multimodal_planner" in chat:
|
||||
if "multimodal_planner" not in visual and isinstance(chat["multimodal_planner"], bool):
|
||||
visual["multimodal_planner"] = chat["multimodal_planner"]
|
||||
if "multimodal_planner" in visual:
|
||||
chat.pop("multimodal_planner", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.multimodal_planner_moved_to_visual.multimodal_planner")
|
||||
|
||||
if visual is not None and "replyer_generator_type" in chat:
|
||||
multimodal_replyer = _parse_multimodal_replyer(chat["replyer_generator_type"])
|
||||
if "multimodal_replyer" not in visual and multimodal_replyer is not None:
|
||||
visual["multimodal_replyer"] = multimodal_replyer
|
||||
if "multimodal_replyer" in visual:
|
||||
chat.pop("replyer_generator_type", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.replyer_generator_type_moved_to_visual.multimodal_replyer")
|
||||
|
||||
maisaka = _as_dict(data.get("maisaka"))
|
||||
mem = _as_dict(data.get("memory"))
|
||||
if maisaka is not None:
|
||||
moved_memory_keys = ("enable_memory_query_tool", "memory_query_default_limit")
|
||||
if any(key in maisaka for key in moved_memory_keys) and mem is None:
|
||||
mem = {}
|
||||
data["memory"] = mem
|
||||
|
||||
if mem is not None:
|
||||
for moved_key in moved_memory_keys:
|
||||
if _move_section_key(maisaka, mem, moved_key):
|
||||
migrated_any = True
|
||||
reasons.append(f"maisaka.{moved_key}_moved_to_memory")
|
||||
|
||||
if mem is not None:
|
||||
if _migrate_target_item_list(mem, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
reasons.append("memory.global_memory_blacklist")
|
||||
|
||||
for removed_key in (
|
||||
"agent_timeout_seconds",
|
||||
"global_memory",
|
||||
"global_memory_blacklist",
|
||||
"max_agent_iterations",
|
||||
):
|
||||
if removed_key in mem:
|
||||
|
||||
@@ -113,15 +113,6 @@ class PersonalityConfig(ConfigBase):
|
||||
)
|
||||
"""每次构建回复时,从 multiple_reply_style 中随机替换 reply_style 的概率(0.0-1.0)"""
|
||||
|
||||
visual_style: str = Field(
|
||||
default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""_wrap_识图提示词,不建议修改"""
|
||||
|
||||
states: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"是一个女大学生,喜欢上网聊天,会刷小红书。",
|
||||
@@ -148,6 +139,40 @@ class PersonalityConfig(ConfigBase):
|
||||
"""状态概率,每次构建人格时替换personality的概率"""
|
||||
|
||||
|
||||
class VisualConfig(ConfigBase):
|
||||
"""视觉配置类"""
|
||||
|
||||
__ui_label__ = "视觉"
|
||||
__ui_icon__ = "image"
|
||||
|
||||
multimodal_planner: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否直接输入图片"""
|
||||
|
||||
multimodal_replyer: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""是否启用 Maisaka 多模态 replyer 生成器"""
|
||||
|
||||
visual_style: str = Field(
|
||||
default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""_wrap_识图提示词,不建议修改"""
|
||||
|
||||
|
||||
class RelationshipConfig(ConfigBase):
|
||||
"""关系配置类"""
|
||||
|
||||
@@ -253,24 +278,6 @@ class ChatConfig(ConfigBase):
|
||||
},
|
||||
)
|
||||
|
||||
multimodal_planner: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否直接输入图片"""
|
||||
|
||||
replyer_generator_type: Literal["legacy", "multimodal"] = Field(
|
||||
default="legacy",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""Maisaka replyer 生成器类型:legacy(旧版单 prompt)/ multimodal(多模态版,适合主循环直接展示图片)"""
|
||||
|
||||
enable_talk_value_rules: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
@@ -373,24 +380,6 @@ class MemoryConfig(ConfigBase):
|
||||
|
||||
__ui_parent__ = "emoji"
|
||||
|
||||
max_agent_iterations: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "layers",
|
||||
},
|
||||
)
|
||||
"""记忆思考深度(最低为1)"""
|
||||
|
||||
agent_timeout_seconds: float = Field(
|
||||
default=120.0,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "clock",
|
||||
},
|
||||
)
|
||||
"""最长回忆时间(秒)"""
|
||||
|
||||
global_memory: bool = Field(
|
||||
default=False,
|
||||
@@ -410,6 +399,26 @@ class MemoryConfig(ConfigBase):
|
||||
)
|
||||
"""_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索"""
|
||||
|
||||
enable_memory_query_tool: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "database",
|
||||
},
|
||||
)
|
||||
"""是否启用 Maisaka 内置长期记忆检索工具 query_memory"""
|
||||
|
||||
memory_query_default_limit: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=20,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "hash",
|
||||
},
|
||||
)
|
||||
"""Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数"""
|
||||
|
||||
long_term_auto_summary_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
@@ -657,16 +666,6 @@ class ExpressionConfig(ConfigBase):
|
||||
)
|
||||
"""是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除"""
|
||||
|
||||
enable_jargon_explanation: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "info",
|
||||
},
|
||||
)
|
||||
"""是否在回复前尝试对上下文中的黑话进行解释(关闭可减少一次LLM调用,仅影响回复前的黑话匹配与解释,不影响黑话学习)"""
|
||||
|
||||
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
|
||||
@@ -700,7 +699,7 @@ class EmojiConfig(ConfigBase):
|
||||
"""一次从多少个表情包中选择发送,最大为 64"""
|
||||
|
||||
max_reg_num: int = Field(
|
||||
default=100,
|
||||
default=64,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "hash",
|
||||
@@ -988,23 +987,6 @@ class DebugConfig(ConfigBase):
|
||||
)
|
||||
"""是否显示prompt"""
|
||||
|
||||
show_replyer_prompt: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "message-square",
|
||||
},
|
||||
)
|
||||
"""是否显示回复器prompt"""
|
||||
|
||||
show_replyer_reasoning: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "brain",
|
||||
},
|
||||
)
|
||||
|
||||
show_maisaka_thinking: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
@@ -1041,25 +1023,6 @@ class DebugConfig(ConfigBase):
|
||||
)
|
||||
"""是否显示记忆检索相关prompt"""
|
||||
|
||||
show_planner_prompt: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "map",
|
||||
},
|
||||
)
|
||||
"""是否显示planner的prompt和原始返回结果"""
|
||||
|
||||
show_lpmm_paragraph: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "file-text",
|
||||
},
|
||||
)
|
||||
"""是否显示lpmm找到的相关文段日志"""
|
||||
|
||||
|
||||
class ExtraPromptItem(ConfigBase):
|
||||
platform: str = Field(
|
||||
default="",
|
||||
@@ -1511,16 +1474,6 @@ class MaiSakaConfig(ConfigBase):
|
||||
)
|
||||
"""MaiSaka 使用的用户名称"""
|
||||
|
||||
max_internal_rounds: int = Field(
|
||||
default=6,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "repeat",
|
||||
},
|
||||
)
|
||||
"""每个入站消息的最大内部规划轮数"""
|
||||
|
||||
planner_interrupt_max_consecutive_count: int = Field(
|
||||
default=2,
|
||||
ge=0,
|
||||
@@ -1531,26 +1484,6 @@ class MaiSakaConfig(ConfigBase):
|
||||
)
|
||||
"""Planner 连续被新消息打断的最大次数,0 表示不启用打断"""
|
||||
|
||||
enable_memory_query_tool: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "database",
|
||||
},
|
||||
)
|
||||
"""是否启用 Maisaka 内置长期记忆检索工具 query_memory"""
|
||||
|
||||
memory_query_default_limit: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=20,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "hash",
|
||||
},
|
||||
)
|
||||
"""Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数"""
|
||||
|
||||
tool_filter_task_name: str = Field(
|
||||
default="utils",
|
||||
json_schema_extra={
|
||||
|
||||
@@ -47,7 +47,7 @@ def get_action_tool_specs() -> List[ToolSpec]:
|
||||
get_reply_tool_spec(),
|
||||
get_view_complex_message_tool_spec(),
|
||||
get_query_jargon_tool_spec(),
|
||||
get_query_memory_tool_spec(enabled=bool(global_config.maisaka.enable_memory_query_tool)),
|
||||
get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool)),
|
||||
get_send_emoji_tool_spec(),
|
||||
]
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ async def handle_tool(
|
||||
f"不支持的检索模式:{mode}。可选值:search/time/hybrid/episode/aggregate。",
|
||||
)
|
||||
|
||||
default_limit = max(1, int(getattr(global_config.maisaka, "memory_query_default_limit", 5) or 5))
|
||||
default_limit = max(1, global_config.memory.memory_query_default_limit)
|
||||
try:
|
||||
limit = int(invocation.arguments.get("limit", default_limit) or default_limit)
|
||||
except (TypeError, ValueError):
|
||||
|
||||
@@ -596,7 +596,7 @@ class MaisakaReasoningEngine:
|
||||
planner_prefix: str,
|
||||
) -> MessageSequence:
|
||||
message_sequence = build_prefixed_message_sequence(message.raw_message, planner_prefix)
|
||||
if global_config.chat.multimodal_planner:
|
||||
if global_config.visual.multimodal_planner:
|
||||
await self._hydrate_visual_components(message_sequence.components)
|
||||
return message_sequence
|
||||
|
||||
|
||||
@@ -37,6 +37,8 @@ from .tool_provider import MaisakaBuiltinToolProvider
|
||||
|
||||
logger = get_logger("maisaka_runtime")
|
||||
|
||||
MAX_INTERNAL_ROUNDS = 6
|
||||
|
||||
|
||||
class MaisakaHeartFlowChatting:
|
||||
"""会话级别的 Maisaka 运行时。"""
|
||||
@@ -78,7 +80,7 @@ class MaisakaHeartFlowChatting:
|
||||
self._message_debounce_required = False
|
||||
self._last_message_received_at = 0.0
|
||||
self._wait_timeout_task: Optional[asyncio.Task[None]] = None
|
||||
self._max_internal_rounds = global_config.maisaka.max_internal_rounds
|
||||
self._max_internal_rounds = MAX_INTERNAL_ROUNDS
|
||||
self._max_context_size = max(1, int(global_config.chat.max_context_size))
|
||||
self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP
|
||||
self._wait_until: Optional[float] = None
|
||||
|
||||
@@ -1,878 +0,0 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
|
||||
|
||||
logger = get_logger("memory_retrieval")
|
||||
|
||||
|
||||
def init_memory_retrieval_sys():
|
||||
"""初始化记忆检索相关工具"""
|
||||
# 注册所有工具
|
||||
init_all_tools()
|
||||
|
||||
|
||||
def _log_conversation_messages(
|
||||
conversation_messages: List[Message],
|
||||
head_prompt: Optional[str] = None,
|
||||
final_status: Optional[str] = None,
|
||||
) -> None:
|
||||
"""输出对话消息列表的日志
|
||||
|
||||
Args:
|
||||
conversation_messages: 对话消息列表
|
||||
head_prompt: 第一条系统消息(head_prompt)的内容,可选
|
||||
final_status: 最终结果状态描述(例如:找到答案/未找到答案),可选
|
||||
"""
|
||||
if not global_config.debug.show_memory_prompt:
|
||||
return
|
||||
|
||||
log_lines: List[str] = []
|
||||
|
||||
# 如果有head_prompt,先添加为第一条消息
|
||||
if head_prompt:
|
||||
msg_info = "========================================\n[消息 1] 角色: System\n-----------------------------"
|
||||
msg_info += f"\n{head_prompt}"
|
||||
log_lines.append(msg_info)
|
||||
start_idx = 2
|
||||
else:
|
||||
start_idx = 1
|
||||
|
||||
if not conversation_messages and not head_prompt:
|
||||
return
|
||||
|
||||
for idx, msg in enumerate(conversation_messages, start_idx):
|
||||
role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
|
||||
# 构建单条消息的日志信息
|
||||
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
|
||||
msg_info = (
|
||||
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
|
||||
)
|
||||
|
||||
# if full_content:
|
||||
# msg_info += f"\n{full_content}"
|
||||
if msg.content:
|
||||
msg_info += f"\n{msg.content}"
|
||||
|
||||
if msg.tool_calls:
|
||||
msg_info += f"\n 工具调用: {len(msg.tool_calls)}个"
|
||||
for tool_call in msg.tool_calls:
|
||||
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
|
||||
|
||||
# if msg.tool_call_id:
|
||||
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
||||
|
||||
log_lines.append(msg_info)
|
||||
|
||||
total_count = len(conversation_messages) + (1 if head_prompt else 0)
|
||||
log_text = f"消息列表 (共{total_count}条):{''.join(log_lines)}"
|
||||
if final_status:
|
||||
log_text += f"\n\n[最终结果] {final_status}"
|
||||
logger.info(log_text)
|
||||
|
||||
|
||||
async def _react_agent_solve_question(
|
||||
chat_id: str,
|
||||
max_iterations: int = 5,
|
||||
timeout: float = 30.0,
|
||||
initial_info: str = "",
|
||||
chat_history: str = "",
|
||||
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
|
||||
"""使用ReAct架构的Agent来解决问题
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
max_iterations: 最大迭代次数
|
||||
timeout: 超时时间(秒)
|
||||
initial_info: 初始信息,将作为collected_info的初始值
|
||||
chat_history: 聊天记录,将传递给 ReAct Agent prompt
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
|
||||
"""
|
||||
start_time = time.time()
|
||||
collected_info = initial_info or ""
|
||||
# 构造日志前缀:[聊天流名称],用于在日志中标识聊天流
|
||||
try:
|
||||
chat_name = _chat_manager.get_session_name(chat_id) or chat_id
|
||||
except Exception:
|
||||
chat_name = chat_id
|
||||
react_log_prefix = f"[{chat_name}] "
|
||||
thinking_steps = []
|
||||
is_timeout = False
|
||||
conversation_messages: List[Message] = []
|
||||
first_head_prompt: Optional[str] = None # 保存第一次使用的head_prompt(用于日志显示)
|
||||
last_tool_name: Optional[str] = None # 记录最后一次使用的工具名称
|
||||
|
||||
# 使用 while 循环,支持额外迭代
|
||||
iteration = 0
|
||||
max_iterations_with_extra = max_iterations
|
||||
while iteration < max_iterations_with_extra:
|
||||
# 检查超时
|
||||
if time.time() - start_time > timeout:
|
||||
logger.warning(f"ReAct Agent超时,已迭代{iteration}次")
|
||||
is_timeout = True
|
||||
break
|
||||
|
||||
# 获取工具注册器
|
||||
tool_registry = get_tool_registry()
|
||||
|
||||
# 获取bot_name
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
# 获取当前时间
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
# 计算剩余迭代次数
|
||||
current_iteration = iteration + 1
|
||||
remaining_iterations = max_iterations - current_iteration
|
||||
|
||||
# 提取函数调用中参数的值,支持单引号和双引号
|
||||
def extract_quoted_content(text, func_name, param_name):
|
||||
"""从文本中提取函数调用中参数的值,支持单引号和双引号
|
||||
|
||||
Args:
|
||||
text: 要搜索的文本
|
||||
func_name: 函数名,如 'return_information'
|
||||
param_name: 参数名,如 'information'
|
||||
|
||||
Returns:
|
||||
提取的参数值,如果未找到则返回None
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# 查找函数调用位置(不区分大小写)
|
||||
func_pattern = func_name.lower()
|
||||
text_lower = text.lower()
|
||||
func_pos = text_lower.find(func_pattern)
|
||||
if func_pos == -1:
|
||||
return None
|
||||
|
||||
# 查找参数名和等号
|
||||
param_pattern = f"{param_name}="
|
||||
param_pos = text_lower.find(param_pattern, func_pos)
|
||||
if param_pos == -1:
|
||||
return None
|
||||
|
||||
# 跳过参数名、等号和空白
|
||||
start_pos = param_pos + len(param_pattern)
|
||||
while start_pos < len(text) and text[start_pos] in " \t\n":
|
||||
start_pos += 1
|
||||
|
||||
if start_pos >= len(text):
|
||||
return None
|
||||
|
||||
# 确定引号类型
|
||||
quote_char = text[start_pos]
|
||||
if quote_char not in ['"', "'"]:
|
||||
return None
|
||||
|
||||
# 查找匹配的结束引号(考虑转义)
|
||||
end_pos = start_pos + 1
|
||||
while end_pos < len(text):
|
||||
if text[end_pos] == quote_char:
|
||||
# 检查是否是转义的引号
|
||||
if end_pos > start_pos + 1 and text[end_pos - 1] == "\\":
|
||||
end_pos += 1
|
||||
continue
|
||||
# 找到匹配的引号
|
||||
content = text[start_pos + 1 : end_pos]
|
||||
# 处理转义字符
|
||||
content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\")
|
||||
return content
|
||||
end_pos += 1
|
||||
|
||||
return None
|
||||
|
||||
# 正常迭代:使用head_prompt决定调用哪些工具(包含return_information工具)
|
||||
tool_definitions = tool_registry.get_tool_definitions()
|
||||
# tool_names = [tool_def["name"] for tool_def in tool_definitions]
|
||||
# logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具: {', '.join(tool_names)} (共{len(tool_definitions)}个)")
|
||||
|
||||
# head_prompt应该只构建一次,使用初始的collected_info,后续迭代都复用同一个
|
||||
if first_head_prompt is None:
|
||||
# 第一次构建,使用初始的collected_info(即initial_info)
|
||||
initial_collected_info = initial_info or ""
|
||||
# 使用统一长期记忆检索 prompt
|
||||
first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_memory")
|
||||
first_head_prompt_template.add_context("bot_name", bot_name)
|
||||
first_head_prompt_template.add_context("time_now", time_now)
|
||||
first_head_prompt_template.add_context("chat_history", chat_history)
|
||||
first_head_prompt_template.add_context("collected_info", initial_collected_info)
|
||||
first_head_prompt_template.add_context("current_iteration", str(current_iteration))
|
||||
first_head_prompt_template.add_context("remaining_iterations", str(remaining_iterations))
|
||||
first_head_prompt_template.add_context("max_iterations", str(max_iterations))
|
||||
first_head_prompt = await prompt_manager.render_prompt(first_head_prompt_template)
|
||||
|
||||
# 后续迭代都复用第一次构建的head_prompt
|
||||
head_prompt = first_head_prompt
|
||||
|
||||
def _build_messages(
|
||||
_client,
|
||||
*,
|
||||
_head_prompt: str = head_prompt,
|
||||
_conversation_messages: List[Message] = conversation_messages,
|
||||
):
|
||||
messages: List[Message] = []
|
||||
|
||||
system_builder = MessageBuilder()
|
||||
system_builder.set_role(RoleType.System)
|
||||
system_builder.add_text_content(_head_prompt)
|
||||
messages.append(system_builder.build())
|
||||
|
||||
messages.extend(_conversation_messages)
|
||||
|
||||
return messages
|
||||
|
||||
message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues]
|
||||
generation_result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name="utils",
|
||||
request_type="memory.react",
|
||||
message_factory=message_factory_fn, # type: ignore[arg-type]
|
||||
tool_options=tool_definitions,
|
||||
)
|
||||
)
|
||||
success = generation_result.success
|
||||
response = generation_result.completion.response
|
||||
reasoning_content = generation_result.completion.reasoning
|
||||
tool_calls = generation_result.completion.tool_calls
|
||||
|
||||
# logger.info(
|
||||
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||
# )
|
||||
|
||||
if not success:
|
||||
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
||||
break
|
||||
|
||||
# 注意:这里会检查return_information工具调用,如果检测到return_information工具,会根据information参数决定返回信息或退出查询
|
||||
|
||||
assistant_message: Optional[Message] = None
|
||||
if tool_calls:
|
||||
assistant_builder = MessageBuilder()
|
||||
assistant_builder.set_role(RoleType.Assistant)
|
||||
if response and response.strip():
|
||||
assistant_builder.add_text_content(response)
|
||||
assistant_builder.set_tool_calls(tool_calls)
|
||||
assistant_message = assistant_builder.build()
|
||||
elif response and response.strip():
|
||||
assistant_builder = MessageBuilder()
|
||||
assistant_builder.set_role(RoleType.Assistant)
|
||||
assistant_builder.add_text_content(response)
|
||||
assistant_message = assistant_builder.build()
|
||||
|
||||
# 记录思考步骤
|
||||
step: Dict[str, Any] = {
|
||||
"iteration": iteration + 1,
|
||||
"thought": response,
|
||||
"actions": [],
|
||||
"observations": [],
|
||||
}
|
||||
|
||||
if assistant_message:
|
||||
conversation_messages.append(assistant_message)
|
||||
|
||||
# 记录思考过程到collected_info中
|
||||
if reasoning_content or response:
|
||||
thought_summary = reasoning_content or (response[:200] if response else "")
|
||||
if thought_summary:
|
||||
collected_info += f"\n[思考] {thought_summary}\n"
|
||||
|
||||
# 处理工具调用
|
||||
if not tool_calls:
|
||||
# 如果没有工具调用,检查响应文本中是否包含return_information函数调用格式或JSON格式
|
||||
if response and response.strip():
|
||||
# 首先尝试解析JSON格式的return_information
|
||||
def parse_json_return_information(text: str):
|
||||
"""从文本中解析JSON格式的return_information,返回information字符串,如果未找到则返回None"""
|
||||
if not text:
|
||||
return None, None
|
||||
|
||||
try:
|
||||
# 尝试提取JSON对象(可能包含在代码块中或直接是JSON)
|
||||
json_text = text.strip()
|
||||
|
||||
# 如果包含代码块标记,提取JSON部分
|
||||
if "```json" in json_text:
|
||||
start = json_text.find("```json") + 7
|
||||
end = json_text.find("```", start)
|
||||
if end != -1:
|
||||
json_text = json_text[start:end].strip()
|
||||
elif "```" in json_text:
|
||||
start = json_text.find("```") + 3
|
||||
end = json_text.find("```", start)
|
||||
if end != -1:
|
||||
json_text = json_text[start:end].strip()
|
||||
|
||||
# 尝试解析JSON
|
||||
data = json.loads(json_text)
|
||||
|
||||
# 检查是否包含return_information字段
|
||||
if isinstance(data, dict) and "return_information" in data:
|
||||
information = data.get("information", "")
|
||||
return information
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
# 如果JSON解析失败,尝试在文本中查找JSON对象
|
||||
with contextlib.suppress(json.JSONDecodeError, ValueError, TypeError):
|
||||
# 查找第一个 { 和最后一个 } 之间的内容(更健壮的JSON提取)
|
||||
first_brace = text.find("{")
|
||||
if first_brace != -1:
|
||||
# 从第一个 { 开始,找到匹配的 }
|
||||
brace_count = 0
|
||||
json_end = -1
|
||||
for i in range(first_brace, len(text)):
|
||||
if text[i] == "{":
|
||||
brace_count += 1
|
||||
elif text[i] == "}":
|
||||
brace_count -= 1
|
||||
if brace_count == 0:
|
||||
json_end = i + 1
|
||||
break
|
||||
|
||||
if json_end != -1:
|
||||
json_text = text[first_brace:json_end]
|
||||
data = json.loads(json_text)
|
||||
if isinstance(data, dict) and "return_information" in data:
|
||||
information = data.get("information", "")
|
||||
return information
|
||||
|
||||
return None
|
||||
|
||||
# 尝试从文本中解析return_information函数调用
|
||||
def parse_return_information_from_text(text: str):
|
||||
"""从文本中解析return_information函数调用,返回information字符串,如果未找到则返回None"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# 查找return_information函数调用位置(不区分大小写)
|
||||
func_pattern = "return_information"
|
||||
text_lower = text.lower()
|
||||
func_pos = text_lower.find(func_pattern)
|
||||
if func_pos == -1:
|
||||
return None
|
||||
|
||||
# 解析information参数(字符串,使用extract_quoted_content)
|
||||
information = extract_quoted_content(text, "return_information", "information")
|
||||
|
||||
# 如果information存在(即使是空字符串),也返回它
|
||||
return information
|
||||
|
||||
# 首先尝试解析JSON格式
|
||||
parsed_information_json = parse_json_return_information(response)
|
||||
is_json_format = parsed_information_json is not None
|
||||
|
||||
# 如果JSON解析成功,使用JSON结果
|
||||
if is_json_format:
|
||||
parsed_information = parsed_information_json
|
||||
else:
|
||||
# 如果JSON解析失败,尝试解析函数调用格式
|
||||
parsed_information = parse_return_information_from_text(response)
|
||||
|
||||
if parsed_information is not None or is_json_format:
|
||||
# 检测到return_information格式(可能是JSON格式或函数调用格式)
|
||||
format_type = "JSON格式" if is_json_format else "函数调用格式"
|
||||
# 返回信息(即使为空字符串也返回)
|
||||
step["actions"].append(
|
||||
{
|
||||
"action_type": "return_information",
|
||||
"action_params": {"information": parsed_information or ""},
|
||||
}
|
||||
)
|
||||
parsed_info_text = parsed_information if isinstance(parsed_information, str) else ""
|
||||
if parsed_info_text.strip():
|
||||
step["observations"] = [f"检测到return_information{format_type}调用,返回信息"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_info_text[:100]}..."
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"返回信息:{parsed_info_text}",
|
||||
)
|
||||
|
||||
return True, parsed_info_text, thinking_steps, False
|
||||
else:
|
||||
# 信息为空,直接退出查询
|
||||
step["observations"] = [f"检测到return_information{format_type}调用,信息为空"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}判断信息为空"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="信息为空:通过return_information文本格式判断信息为空",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
# 如果没有检测到return_information格式,记录思考过程,继续下一轮迭代
|
||||
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
|
||||
logger.info(f"{react_log_prefix}第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
|
||||
collected_info += f"思考: {response}"
|
||||
else:
|
||||
logger.warning(f"{react_log_prefix}第 {iteration + 1} 次迭代 无工具调用且无响应")
|
||||
step["observations"] = ["无响应且无工具调用"]
|
||||
thinking_steps.append(step)
|
||||
iteration += 1 # 在continue之前增加迭代计数,避免跳过iteration += 1
|
||||
continue
|
||||
|
||||
# 处理工具调用
|
||||
# 首先检查是否有return_information工具调用,如果有则立即返回,不再处理其他工具
|
||||
return_information_info = None
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
|
||||
if tool_name == "return_information":
|
||||
return_information_info = tool_args.get("information", "")
|
||||
|
||||
# 返回信息(即使为空也返回)
|
||||
step["actions"].append(
|
||||
{
|
||||
"action_type": "return_information",
|
||||
"action_params": {"information": return_information_info},
|
||||
}
|
||||
)
|
||||
if return_information_info and return_information_info.strip():
|
||||
# 有信息,返回
|
||||
step["observations"] = ["检测到return_information工具调用,返回信息"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information工具返回信息: {return_information_info}"
|
||||
)
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"返回信息:{return_information_info}",
|
||||
)
|
||||
|
||||
return True, return_information_info, thinking_steps, False
|
||||
else:
|
||||
# 信息为空,直接退出查询
|
||||
step["observations"] = ["检测到return_information工具调用,信息为空"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information工具判断信息为空")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="信息为空:通过return_information工具判断信息为空",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
# 如果没有return_information工具调用,继续处理其他工具
|
||||
tool_tasks = []
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
|
||||
logger.debug(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
|
||||
)
|
||||
|
||||
# 跳过return_information工具调用(已经在上面处理过了)
|
||||
if tool_name == "return_information":
|
||||
continue
|
||||
|
||||
# 记录最后一次使用的工具名称(用于判断是否需要额外迭代)
|
||||
last_tool_name = tool_name
|
||||
|
||||
# 普通工具调用
|
||||
tool = tool_registry.get_tool(tool_name)
|
||||
if tool:
|
||||
# 准备工具参数(需要添加chat_id如果工具需要)
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(tool.execute_func)
|
||||
tool_params = tool_args.copy()
|
||||
if "chat_id" in sig.parameters:
|
||||
tool_params["chat_id"] = chat_id
|
||||
|
||||
# 创建异步任务
|
||||
async def execute_single_tool(tool_instance, params, tool_name_str, iter_num):
|
||||
try:
|
||||
observation = await tool_instance.execute(**params)
|
||||
param_str = ", ".join([f"{k}={v}" for k, v in params.items() if k != "chat_id"])
|
||||
return f"查询{tool_name_str}({param_str})的结果:{observation}"
|
||||
except Exception as e:
|
||||
error_msg = f"工具执行失败: {str(e)}"
|
||||
logger.error(f"{react_log_prefix}第 {iter_num + 1} 次迭代 工具 {tool_name_str} {error_msg}")
|
||||
return f"查询{tool_name_str}失败: {error_msg}"
|
||||
|
||||
tool_tasks.append(execute_single_tool(tool, tool_params, tool_name, iteration))
|
||||
step["actions"].append({"action_type": tool_name, "action_params": tool_args})
|
||||
else:
|
||||
error_msg = f"未知的工具类型: {tool_name}"
|
||||
logger.warning(
|
||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}"
|
||||
)
|
||||
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))
|
||||
|
||||
# 并行执行所有工具
|
||||
if tool_tasks:
|
||||
observations = await asyncio.gather(*tool_tasks, return_exceptions=True)
|
||||
|
||||
# 处理执行结果
|
||||
for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
|
||||
if isinstance(observation, Exception):
|
||||
observation = f"工具执行异常: {str(observation)}"
|
||||
logger.error(f"{react_log_prefix}第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}")
|
||||
|
||||
observation_text = observation if isinstance(observation, str) else str(observation)
|
||||
stripped_observation = observation_text.strip()
|
||||
step["observations"].append(observation_text)
|
||||
collected_info += f"\n{observation_text}\n"
|
||||
if stripped_observation:
|
||||
# 不再自动检测工具输出中的jargon,改为通过 query_words 工具主动查询
|
||||
tool_builder = MessageBuilder()
|
||||
tool_builder.set_role(RoleType.Tool)
|
||||
tool_builder.add_text_content(observation_text)
|
||||
tool_builder.add_tool_call(tool_call_item.call_id)
|
||||
conversation_messages.append(tool_builder.build())
|
||||
|
||||
thinking_steps.append(step)
|
||||
|
||||
# 检查是否需要额外迭代:如果最后一次使用的工具是 search_chat_history 且达到最大迭代次数,额外增加一回合
|
||||
if iteration + 1 >= max_iterations and last_tool_name == "search_chat_history" and not is_timeout:
|
||||
max_iterations_with_extra = max_iterations + 1
|
||||
logger.info(
|
||||
f"{react_log_prefix}达到最大迭代次数(已迭代{iteration + 1}次),最后一次使用工具为 search_chat_history,额外增加一回合尝试"
|
||||
)
|
||||
|
||||
iteration += 1
|
||||
|
||||
# 正常迭代结束后,如果达到最大迭代次数或超时,执行最终评估
|
||||
# 最终评估单独处理,不算在迭代中
|
||||
should_do_final_evaluation = False
|
||||
if is_timeout:
|
||||
should_do_final_evaluation = True
|
||||
logger.warning(f"{react_log_prefix}超时,已迭代{iteration}次,进入最终评估")
|
||||
elif iteration >= max_iterations:
|
||||
should_do_final_evaluation = True
|
||||
logger.info(f"{react_log_prefix}达到最大迭代次数(已迭代{iteration}次),进入最终评估")
|
||||
|
||||
if should_do_final_evaluation:
|
||||
# 获取必要变量用于最终评估
|
||||
tool_registry = get_tool_registry()
|
||||
bot_name = global_config.bot.nickname
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
current_iteration = iteration + 1
|
||||
remaining_iterations = 0
|
||||
|
||||
# 提取函数调用中参数的值,支持单引号和双引号
|
||||
def extract_quoted_content(text, func_name, param_name):
|
||||
"""从文本中提取函数调用中参数的值,支持单引号和双引号
|
||||
|
||||
Args:
|
||||
text: 要搜索的文本
|
||||
func_name: 函数名,如 'return_information'
|
||||
param_name: 参数名,如 'information'
|
||||
|
||||
Returns:
|
||||
提取的参数值,如果未找到则返回None
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# 查找函数调用位置(不区分大小写)
|
||||
func_pattern = func_name.lower()
|
||||
text_lower = text.lower()
|
||||
func_pos = text_lower.find(func_pattern)
|
||||
if func_pos == -1:
|
||||
return None
|
||||
|
||||
# 查找参数名和等号
|
||||
param_pattern = f"{param_name}="
|
||||
param_pos = text_lower.find(param_pattern, func_pos)
|
||||
if param_pos == -1:
|
||||
return None
|
||||
|
||||
# 跳过参数名、等号和空白
|
||||
start_pos = param_pos + len(param_pattern)
|
||||
while start_pos < len(text) and text[start_pos] in " \t\n":
|
||||
start_pos += 1
|
||||
|
||||
if start_pos >= len(text):
|
||||
return None
|
||||
|
||||
# 确定引号类型
|
||||
quote_char = text[start_pos]
|
||||
if quote_char not in ['"', "'"]:
|
||||
return None
|
||||
|
||||
# 查找匹配的结束引号(考虑转义)
|
||||
end_pos = start_pos + 1
|
||||
while end_pos < len(text):
|
||||
if text[end_pos] == quote_char:
|
||||
# 检查是否是转义的引号
|
||||
if end_pos > start_pos + 1 and text[end_pos - 1] == "\\":
|
||||
end_pos += 1
|
||||
continue
|
||||
# 找到匹配的引号
|
||||
content = text[start_pos + 1 : end_pos]
|
||||
# 处理转义字符
|
||||
content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\")
|
||||
return content
|
||||
end_pos += 1
|
||||
|
||||
return None
|
||||
|
||||
# 执行最终评估
|
||||
evaluation_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_final")
|
||||
evaluation_prompt_template.add_context("bot_name", bot_name)
|
||||
evaluation_prompt_template.add_context("time_now", time_now)
|
||||
evaluation_prompt_template.add_context("chat_history", chat_history)
|
||||
evaluation_prompt_template.add_context("collected_info", collected_info or "暂无信息")
|
||||
evaluation_prompt_template.add_context("current_iteration", str(current_iteration))
|
||||
evaluation_prompt_template.add_context("remaining_iterations", str(remaining_iterations))
|
||||
evaluation_prompt_template.add_context("max_iterations", str(max_iterations))
|
||||
evaluation_prompt = await prompt_manager.render_prompt(evaluation_prompt_template)
|
||||
|
||||
evaluation_result = await llm_api.generate(
|
||||
llm_api.LLMServiceRequest(
|
||||
task_name="utils",
|
||||
request_type="memory.react.final",
|
||||
prompt=evaluation_prompt,
|
||||
tool_options=[],
|
||||
)
|
||||
)
|
||||
eval_success = evaluation_result.success
|
||||
eval_response = evaluation_result.completion.response
|
||||
|
||||
if not eval_success:
|
||||
logger.error(f"ReAct Agent 最终评估阶段 LLM调用失败: {eval_response}")
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:最终评估阶段LLM调用失败",
|
||||
)
|
||||
return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout
|
||||
|
||||
if global_config.debug.show_memory_prompt:
|
||||
logger.info(f"{react_log_prefix}最终评估Prompt: {evaluation_prompt}")
|
||||
logger.info(f"{react_log_prefix}最终评估响应: {eval_response}")
|
||||
|
||||
# 从最终评估响应中提取return_information
|
||||
return_information_content = None
|
||||
|
||||
if eval_response:
|
||||
return_information_content = extract_quoted_content(eval_response, "return_information", "information")
|
||||
|
||||
# 如果提取到信息,返回(无论是否超时,都视为成功完成)
|
||||
if return_information_content is not None:
|
||||
eval_step = {
|
||||
"iteration": current_iteration,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [
|
||||
{"action_type": "return_information", "action_params": {"information": return_information_content}}
|
||||
],
|
||||
"observations": ["最终评估阶段检测到return_information"],
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
if return_information_content and return_information_content.strip():
|
||||
logger.info(f"ReAct Agent 最终评估阶段返回信息: {return_information_content}")
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status=f"返回信息:{return_information_content}",
|
||||
)
|
||||
return True, return_information_content, thinking_steps, False
|
||||
else:
|
||||
logger.info("ReAct Agent 最终评估阶段判断信息为空")
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="信息为空:最终评估阶段判断信息为空",
|
||||
)
|
||||
return False, "", thinking_steps, False
|
||||
|
||||
# 如果没有明确判断,视为not_enough_info,返回空字符串(不返回任何信息)
|
||||
eval_step = {
|
||||
"iteration": current_iteration,
|
||||
"thought": f"[最终评估] {eval_response}",
|
||||
"actions": [{"action_type": "return_information", "action_params": {"information": ""}}],
|
||||
"observations": ["已到达最大迭代次数,信息为空"],
|
||||
}
|
||||
thinking_steps.append(eval_step)
|
||||
logger.info("ReAct Agent 已到达最大迭代次数,信息为空")
|
||||
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:已到达最大迭代次数,无法找到答案",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
# 如果正常迭代过程中提前找到答案返回,不会到达这里
|
||||
# 如果正常迭代结束但没有触发最终评估(理论上不应该发生),直接返回
|
||||
logger.warning("ReAct Agent正常迭代结束,但未触发最终评估")
|
||||
_log_conversation_messages(
|
||||
conversation_messages,
|
||||
head_prompt=first_head_prompt,
|
||||
final_status="未找到答案:正常迭代结束",
|
||||
)
|
||||
|
||||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
async def _process_memory_retrieval(
|
||||
chat_id: str,
|
||||
context: str,
|
||||
initial_info: str = "",
|
||||
max_iterations: Optional[int] = None,
|
||||
chat_history: str = "",
|
||||
) -> Optional[str]:
|
||||
"""处理记忆检索
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
context: 上下文信息
|
||||
initial_info: 初始信息,将传递给ReAct Agent
|
||||
max_iterations: 最大迭代次数
|
||||
chat_history: 聊天记录,将传递给 ReAct Agent
|
||||
|
||||
Returns:
|
||||
Optional[str]: 如果找到答案,返回答案内容,否则返回None
|
||||
"""
|
||||
question_initial_info = initial_info or ""
|
||||
|
||||
# 直接使用ReAct Agent进行记忆检索
|
||||
# 如果未指定max_iterations,使用配置的默认值
|
||||
if max_iterations is None:
|
||||
max_iterations = global_config.memory.max_agent_iterations
|
||||
|
||||
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
|
||||
chat_id=chat_id,
|
||||
max_iterations=max_iterations,
|
||||
timeout=global_config.memory.agent_timeout_seconds,
|
||||
initial_info=question_initial_info,
|
||||
chat_history=chat_history,
|
||||
)
|
||||
|
||||
# 不再存储到数据库,直接返回答案
|
||||
if is_timeout:
|
||||
logger.info("ReAct Agent超时,不返回结果")
|
||||
|
||||
return answer if found_answer and answer else None
|
||||
|
||||
|
||||
async def build_memory_retrieval_prompt(
|
||||
message: str,
|
||||
sender: str,
|
||||
target: str,
|
||||
chat_stream,
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""构建记忆检索提示
|
||||
Args:
|
||||
message: 聊天历史记录
|
||||
sender: 发送者名称
|
||||
target: 目标消息内容
|
||||
chat_stream: 聊天流对象
|
||||
think_level: 思考深度等级
|
||||
unknown_words: Planner 提供的未知词语列表,优先使用此列表而不是从聊天记录匹配
|
||||
|
||||
Returns:
|
||||
str: 记忆检索结果字符串
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 构造日志前缀:[聊天流名称],用于在日志中标识聊天流(优先群名称/用户昵称)
|
||||
try:
|
||||
group_info = chat_stream.group_info
|
||||
user_info = chat_stream.user_info
|
||||
# 群聊优先使用群名称
|
||||
if group_info is not None and getattr(group_info, "group_name", None):
|
||||
stream_name = group_info.group_name.strip() or str(group_info.group_id)
|
||||
# 私聊使用用户昵称
|
||||
elif user_info is not None and getattr(user_info, "user_nickname", None):
|
||||
stream_name = user_info.user_nickname.strip() or str(user_info.user_id)
|
||||
# 兜底使用 stream_id
|
||||
else:
|
||||
stream_name = chat_stream.stream_id
|
||||
except Exception:
|
||||
stream_name = chat_stream.stream_id
|
||||
log_prefix = f"[{stream_name}] " if stream_name else ""
|
||||
|
||||
logger.info(f"{log_prefix}检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
try:
|
||||
chat_id = chat_stream.stream_id
|
||||
|
||||
# 初始阶段:使用 Planner 提供的 unknown_words 进行检索(如果提供)
|
||||
initial_info = ""
|
||||
if unknown_words and len(unknown_words) > 0:
|
||||
# 清理和去重 unknown_words
|
||||
cleaned_concepts = []
|
||||
for word in unknown_words:
|
||||
if isinstance(word, str):
|
||||
if cleaned := word.strip():
|
||||
cleaned_concepts.append(cleaned)
|
||||
if cleaned_concepts:
|
||||
# 对匹配到的概念进行jargon检索,作为初始信息
|
||||
concept_info = await retrieve_concepts_with_jargon(cleaned_concepts, chat_id)
|
||||
if concept_info:
|
||||
initial_info += concept_info
|
||||
logger.info(
|
||||
f"{log_prefix}使用 Planner 提供的 unknown_words,共 {len(cleaned_concepts)} 个概念,检索结果: {concept_info[:100]}..."
|
||||
)
|
||||
else:
|
||||
logger.debug(f"{log_prefix}unknown_words 检索未找到任何结果")
|
||||
|
||||
# 直接使用 ReAct Agent 进行记忆检索(跳过问题生成步骤)
|
||||
base_max_iterations = global_config.memory.max_agent_iterations
|
||||
# 根据think_level调整迭代次数:think_level=1时不变,think_level=0时减半
|
||||
if think_level == 0:
|
||||
max_iterations = max(1, base_max_iterations // 2) # 至少为1
|
||||
else:
|
||||
max_iterations = base_max_iterations
|
||||
timeout_seconds = global_config.memory.agent_timeout_seconds
|
||||
logger.debug(
|
||||
f"{log_prefix}直接使用 ReAct Agent 进行记忆检索,think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒"
|
||||
)
|
||||
|
||||
# 直接调用 ReAct Agent 处理记忆检索
|
||||
try:
|
||||
result = await _process_memory_retrieval(
|
||||
chat_id=chat_id,
|
||||
context=message,
|
||||
initial_info=initial_info,
|
||||
max_iterations=max_iterations,
|
||||
chat_history=message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{log_prefix}处理记忆检索时发生异常: {e}")
|
||||
result = None
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
if result:
|
||||
logger.info(f"{log_prefix}记忆检索成功,耗时: {(end_time - start_time):.3f}秒")
|
||||
return f"你回忆起了以下信息:\n{result}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||
else:
|
||||
logger.debug(f"{log_prefix}记忆检索未找到相关信息")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{log_prefix}记忆检索时发生异常: {str(e)}")
|
||||
return ""
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
记忆检索工具模块
|
||||
提供统一的工具注册和管理系统
|
||||
"""
|
||||
|
||||
from .tool_registry import (
|
||||
MemoryRetrievalTool,
|
||||
MemoryRetrievalToolRegistry,
|
||||
register_memory_retrieval_tool,
|
||||
get_tool_registry,
|
||||
)
|
||||
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
# 延迟导入,避免在仅使用部分工具或单元测试阶段触发不必要的依赖链。
|
||||
from .query_long_term_memory import register_tool as register_long_term_memory
|
||||
from .query_words import register_tool as register_query_words
|
||||
from .return_information import register_tool as register_return_information
|
||||
|
||||
register_query_words()
|
||||
register_return_information()
|
||||
register_long_term_memory()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemoryRetrievalTool",
|
||||
"MemoryRetrievalToolRegistry",
|
||||
"register_memory_retrieval_tool",
|
||||
"get_tool_registry",
|
||||
"init_all_tools",
|
||||
]
|
||||
@@ -1,307 +0,0 @@
|
||||
"""通过统一长期记忆服务查询信息。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from calendar import monthrange
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Iterable, Literal, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services.memory_service import MemoryHit, MemorySearchResult, memory_service
|
||||
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
_SUPPORTED_MODES = {"search", "time", "episode", "aggregate"}
|
||||
_RELATIVE_DAYS_RE = re.compile(r"^最近\s*(\d+)\s*天$")
|
||||
_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
|
||||
_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}\s+\d{2}:\d{2}$")
|
||||
_TIME_EXPRESSION_HELP = (
|
||||
"请改用更具体的时间表达,例如:今天、昨天、前天、本周、上周、本月、上月、最近7天、"
|
||||
"2026/03/18、2026/03/18 09:30。"
|
||||
)
|
||||
|
||||
|
||||
def _format_query_datetime(dt: datetime) -> str:
|
||||
return dt.strftime("%Y/%m/%d %H:%M")
|
||||
|
||||
|
||||
def _resolve_time_expression(
|
||||
expression: str,
|
||||
*,
|
||||
now: datetime | None = None,
|
||||
) -> Tuple[float, float, str, str]:
|
||||
clean = str(expression or "").strip()
|
||||
if not clean:
|
||||
raise ValueError(f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}")
|
||||
|
||||
current = now or datetime.now()
|
||||
day_start = current.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if clean == "今天":
|
||||
start = day_start
|
||||
end = day_start.replace(hour=23, minute=59)
|
||||
elif clean == "昨天":
|
||||
start = day_start - timedelta(days=1)
|
||||
end = start.replace(hour=23, minute=59)
|
||||
elif clean == "前天":
|
||||
start = day_start - timedelta(days=2)
|
||||
end = start.replace(hour=23, minute=59)
|
||||
elif clean == "本周":
|
||||
start = day_start - timedelta(days=day_start.weekday())
|
||||
end = start + timedelta(days=6, hours=23, minutes=59)
|
||||
elif clean == "上周":
|
||||
this_week_start = day_start - timedelta(days=day_start.weekday())
|
||||
start = this_week_start - timedelta(days=7)
|
||||
end = start + timedelta(days=6, hours=23, minutes=59)
|
||||
elif clean == "本月":
|
||||
start = day_start.replace(day=1)
|
||||
last_day = monthrange(start.year, start.month)[1]
|
||||
end = start.replace(day=last_day, hour=23, minute=59)
|
||||
elif clean == "上月":
|
||||
year = day_start.year
|
||||
month = day_start.month - 1
|
||||
if month == 0:
|
||||
year -= 1
|
||||
month = 12
|
||||
start = day_start.replace(year=year, month=month, day=1)
|
||||
last_day = monthrange(year, month)[1]
|
||||
end = start.replace(day=last_day, hour=23, minute=59)
|
||||
else:
|
||||
relative_match = _RELATIVE_DAYS_RE.fullmatch(clean)
|
||||
if relative_match:
|
||||
days = max(1, int(relative_match.group(1)))
|
||||
start = day_start - timedelta(days=max(0, days - 1))
|
||||
end = day_start.replace(hour=23, minute=59)
|
||||
elif _DATE_RE.fullmatch(clean):
|
||||
start = datetime.strptime(clean, "%Y/%m/%d")
|
||||
end = start.replace(hour=23, minute=59)
|
||||
elif _MINUTE_RE.fullmatch(clean):
|
||||
start = datetime.strptime(clean, "%Y/%m/%d %H:%M")
|
||||
end = start
|
||||
else:
|
||||
raise ValueError(f"时间表达“{clean}”无法解析。{_TIME_EXPRESSION_HELP}")
|
||||
|
||||
return start.timestamp(), end.timestamp(), _format_query_datetime(start), _format_query_datetime(end)
|
||||
|
||||
|
||||
def _extract_time_label(metadata: dict) -> str:
|
||||
if not isinstance(metadata, dict):
|
||||
return ""
|
||||
start = metadata.get("event_time_start")
|
||||
end = metadata.get("event_time_end")
|
||||
event_time = metadata.get("event_time")
|
||||
|
||||
def _fmt(value: object) -> str:
|
||||
if value in {None, ""}:
|
||||
return ""
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value)).strftime("%Y/%m/%d %H:%M")
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
start_text = _fmt(start or event_time)
|
||||
end_text = _fmt(end)
|
||||
if start_text and end_text:
|
||||
return f"{start_text} - {end_text}"
|
||||
return start_text or end_text
|
||||
|
||||
|
||||
def _truncate(text: str, limit: int = 160) -> str:
|
||||
compact = str(text or "").strip().replace("\n", " ")
|
||||
if len(compact) <= limit:
|
||||
return compact
|
||||
return compact[:limit] + "..."
|
||||
|
||||
|
||||
def _format_search_lines(hits: Iterable[MemoryHit], *, limit: int, include_time: bool = False) -> str:
|
||||
lines = []
|
||||
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||
time_label = _extract_time_label(item.metadata) if include_time else ""
|
||||
prefix = f"[{time_label}] " if time_label else ""
|
||||
lines.append(f"{index}. {prefix}{_truncate(item.content)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_episode_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
|
||||
lines = []
|
||||
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||
metadata = item.metadata if isinstance(item.metadata, dict) else {}
|
||||
title = str(item.title or "").strip() or "未命名事件"
|
||||
summary = _truncate(item.content, limit=180)
|
||||
participants = [str(x).strip() for x in (metadata.get("participants") or []) if str(x).strip()]
|
||||
keywords = [str(x).strip() for x in (metadata.get("keywords") or []) if str(x).strip()]
|
||||
extras = []
|
||||
if participants:
|
||||
extras.append(f"参与者:{'、'.join(participants[:4])}")
|
||||
if keywords:
|
||||
extras.append(f"关键词:{'、'.join(keywords[:6])}")
|
||||
time_label = _extract_time_label(metadata)
|
||||
if time_label:
|
||||
extras.append(f"时间:{time_label}")
|
||||
suffix = f"({';'.join(extras)})" if extras else ""
|
||||
lines.append(f"{index}. 事件《{title}》:{summary}{suffix}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_aggregate_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
|
||||
lines = []
|
||||
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||
metadata = item.metadata if isinstance(item.metadata, dict) else {}
|
||||
source_branches = [str(x).strip() for x in (metadata.get("source_branches") or []) if str(x).strip()]
|
||||
branch_text = f"[{','.join(source_branches)}]" if source_branches else ""
|
||||
item_type = str(item.hit_type or "").strip().lower() or "memory"
|
||||
if item_type == "episode":
|
||||
title = str(item.title or "").strip() or "未命名事件"
|
||||
lines.append(f"{index}. {branch_text}[episode] 《{title}》:{_truncate(item.content, 160)}")
|
||||
else:
|
||||
lines.append(f"{index}. {branch_text}[{item_type}] {_truncate(item.content, 160)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_tool_result(
|
||||
*,
|
||||
result: MemorySearchResult,
|
||||
mode: Literal["search", "time", "episode", "aggregate"],
|
||||
limit: int,
|
||||
query: str,
|
||||
time_range_text: str = "",
|
||||
) -> str:
|
||||
if not result.success:
|
||||
return f"长期记忆查询失败:{result.error or '未知错误'}"
|
||||
|
||||
if not result.hits:
|
||||
if mode == "time":
|
||||
return f"在指定时间范围内未找到相关的长期记忆{time_range_text}"
|
||||
if mode == "episode":
|
||||
return f"未找到与“{query}”相关的事件或情节记忆"
|
||||
if mode == "aggregate":
|
||||
return f"未找到可用于综合回忆的长期记忆线索{f'(query:{query})' if query else ''}"
|
||||
return f"在长期记忆中未找到与“{query}”相关的信息"
|
||||
|
||||
if mode == "episode":
|
||||
text = _format_episode_lines(result.hits, limit=limit)
|
||||
return f"你从长期记忆的事件/情节中找到以下信息:\n{text}"
|
||||
|
||||
if mode == "aggregate":
|
||||
text = _format_aggregate_lines(result.hits, limit=limit)
|
||||
return f"你从长期记忆中综合找到了以下线索:\n{text}"
|
||||
|
||||
if mode == "time":
|
||||
text = _format_search_lines(result.hits, limit=limit, include_time=True)
|
||||
return f"你从指定时间范围内的长期记忆中找到以下信息{time_range_text}:\n{text}"
|
||||
|
||||
text = _format_search_lines(result.hits, limit=limit)
|
||||
return f"你从长期记忆中找到以下信息:\n{text}"
|
||||
|
||||
|
||||
async def query_long_term_memory(
|
||||
query: str = "",
|
||||
limit: int = 5,
|
||||
chat_id: str = "",
|
||||
person_id: str = "",
|
||||
mode: str = "search",
|
||||
time_expression: str = "",
|
||||
) -> str:
|
||||
content = str(query or "").strip()
|
||||
safe_limit = max(1, int(limit or 5))
|
||||
normalized_mode = str(mode or "search").strip().lower() or "search"
|
||||
if normalized_mode not in _SUPPORTED_MODES:
|
||||
return f"不支持的长期记忆检索模式:{normalized_mode}。可用模式:search、time、episode、aggregate。"
|
||||
|
||||
if normalized_mode == "search" and not content:
|
||||
return "查询关键词为空,请提供你想查找的长期记忆内容。"
|
||||
if normalized_mode == "time" and not str(time_expression or "").strip():
|
||||
return f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}"
|
||||
if normalized_mode in {"episode", "aggregate"} and not content and not str(time_expression or "").strip():
|
||||
return f"{normalized_mode} 模式至少需要提供 query 或 time_expression。"
|
||||
|
||||
time_start = None
|
||||
time_end = None
|
||||
time_range_text = ""
|
||||
if str(time_expression or "").strip():
|
||||
try:
|
||||
time_start, time_end, time_start_text, time_end_text = _resolve_time_expression(time_expression)
|
||||
except ValueError as exc:
|
||||
return str(exc)
|
||||
time_range_text = f"(时间范围:{time_start_text} 至 {time_end_text})"
|
||||
|
||||
backend_mode = normalized_mode
|
||||
|
||||
try:
|
||||
result = await memory_service.search(
|
||||
content,
|
||||
limit=safe_limit,
|
||||
mode=backend_mode,
|
||||
chat_id=str(chat_id or "").strip(),
|
||||
person_id=str(person_id or "").strip(),
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
text = _format_tool_result(
|
||||
result=result,
|
||||
mode=normalized_mode, # type: ignore[arg-type]
|
||||
limit=safe_limit,
|
||||
query=content,
|
||||
time_range_text=time_range_text,
|
||||
)
|
||||
logger.debug(f"长期记忆查询结果({normalized_mode}): {text}")
|
||||
return text
|
||||
except Exception as exc:
|
||||
logger.error(f"长期记忆查询失败: {exc}")
|
||||
return f"长期记忆查询失败:{exc}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
register_memory_retrieval_tool(
|
||||
name="search_long_term_memory",
|
||||
description=(
|
||||
"从长期记忆中检索信息。支持 search(普通事实检索)、time(按时间范围检索)、"
|
||||
"episode(按事件/情节检索)、aggregate(综合检索)四种模式。"
|
||||
),
|
||||
parameters=[
|
||||
{
|
||||
"name": "query",
|
||||
"type": "string",
|
||||
"description": "需要查询的问题。search 模式建议用自然语言问句;time/episode/aggregate 模式也可用关键词短语。",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "mode",
|
||||
"type": "string",
|
||||
"description": "检索模式:search(普通长期记忆)、time(按时间窗口)、episode(事件/情节)、aggregate(综合检索)。",
|
||||
"required": False,
|
||||
"enum": ["search", "time", "episode", "aggregate"],
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"type": "integer",
|
||||
"description": "希望返回的相关知识条数,默认为5",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "chat_id",
|
||||
"type": "string",
|
||||
"description": "当前聊天流ID,可选。提供后优先检索当前聊天上下文相关的长期记忆。",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "person_id",
|
||||
"type": "string",
|
||||
"description": "相关人物ID,可选。提供后优先检索该人物相关的长期记忆。",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "time_expression",
|
||||
"type": "string",
|
||||
"description": (
|
||||
"时间表达,可选。time 模式必填;episode/aggregate 模式可选。支持:今天、昨天、前天、本周、上周、本月、上月、"
|
||||
"最近N天,以及 YYYY/MM/DD、YYYY/MM/DD HH:mm。"
|
||||
),
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=query_long_term_memory,
|
||||
)
|
||||
@@ -1,78 +0,0 @@
|
||||
"""
|
||||
查询黑话/概念含义 - 工具实现
|
||||
用于在记忆检索过程中主动查询未知词语或黑话的含义
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_words(chat_id: str, words: str) -> str:
|
||||
"""查询词语或黑话的含义
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
words: 要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔)
|
||||
|
||||
Returns:
|
||||
str: 查询结果,包含词语的含义解释
|
||||
"""
|
||||
try:
|
||||
if not words or not words.strip():
|
||||
return "未提供要查询的词语"
|
||||
|
||||
# 解析词语列表(支持逗号、空格等分隔符)
|
||||
words_list = []
|
||||
for separator in [",", ",", " ", "\n", "\t"]:
|
||||
if separator in words:
|
||||
words_list = [w.strip() for w in words.split(separator) if w.strip()]
|
||||
break
|
||||
|
||||
# 如果没有找到分隔符,整个字符串作为一个词语
|
||||
if not words_list:
|
||||
words_list = [words.strip()]
|
||||
|
||||
# 去重
|
||||
unique_words = []
|
||||
seen = set()
|
||||
for word in words_list:
|
||||
if word and word not in seen:
|
||||
unique_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
if not unique_words:
|
||||
return "未提供有效的词语"
|
||||
|
||||
logger.info(f"查询词语含义: {unique_words}")
|
||||
|
||||
# 调用检索函数
|
||||
result = await retrieve_concepts_with_jargon(unique_words, chat_id)
|
||||
|
||||
if result:
|
||||
return result
|
||||
else:
|
||||
return f"未找到词语 '{', '.join(unique_words)}' 的含义或黑话解释"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询词语含义失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_words",
|
||||
description="查询词语或黑话的含义。当遇到不熟悉的词语、缩写、黑话或网络用语时,可以使用此工具查询其含义。支持查询单个或多个词语(用逗号、空格等分隔)。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "words",
|
||||
"type": "string",
|
||||
"description": "要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔,如:'YYDS' 或 'YYDS,内卷,996')",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
execute_func=query_words,
|
||||
)
|
||||
@@ -1,42 +0,0 @@
|
||||
"""
|
||||
return_information工具 - 用于在记忆检索过程中返回总结信息并结束查询
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def return_information(information: str) -> str:
|
||||
"""返回总结信息并结束查询
|
||||
|
||||
Args:
|
||||
information: 基于已收集信息总结出的相关信息,用于帮助回复。如果收集的信息对当前聊天没有帮助,可以返回空字符串。
|
||||
|
||||
Returns:
|
||||
str: 确认信息
|
||||
"""
|
||||
if information and information.strip():
|
||||
logger.info(f"返回总结信息: {information}")
|
||||
return f"已确认返回信息: {information}"
|
||||
else:
|
||||
logger.info("未收集到相关信息,结束查询")
|
||||
return "未收集到相关信息,查询结束"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册return_information工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="return_information",
|
||||
description="当你决定结束查询时,调用此工具。基于已收集的信息,总结出一段相关信息用于帮助回复。如果收集的信息对当前聊天有帮助,在information参数中提供总结信息;如果信息无关或没有帮助,可以提供空字符串。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "information",
|
||||
"type": "string",
|
||||
"description": "基于已收集信息总结出的相关信息,用于帮助回复。必须基于已收集的信息,不要编造。如果信息对当前聊天没有帮助,可以返回空字符串。",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
execute_func=return_information,
|
||||
)
|
||||
@@ -1,167 +0,0 @@
|
||||
"""工具注册系统。
|
||||
|
||||
提供统一的工具注册和管理接口。
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
class MemoryRetrievalTool:
|
||||
"""记忆检索工具基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]],
|
||||
) -> None:
|
||||
"""初始化工具。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
description: 工具描述。
|
||||
parameters: 参数定义列表。
|
||||
execute_func: 执行函数,必须是异步函数。
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.parameters = parameters
|
||||
self.execute_func = execute_func
|
||||
|
||||
def get_tool_description(self) -> str:
|
||||
"""获取工具的文本描述,用于prompt"""
|
||||
param_descriptions = []
|
||||
for param in self.parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type = param.get("type", "string")
|
||||
param_desc = param.get("description", "")
|
||||
required = param.get("required", True)
|
||||
required_str = "必填" if required else "可选"
|
||||
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
|
||||
|
||||
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
|
||||
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
"""执行工具。"""
|
||||
return await self.execute_func(**kwargs)
|
||||
|
||||
def get_tool_definition(self) -> Dict[str, Any]:
|
||||
"""获取规范化的工具定义。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 统一工具定义字典。
|
||||
"""
|
||||
legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
|
||||
for param in self.parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type_str = param.get("type", "string").lower()
|
||||
param_desc = param.get("description", "")
|
||||
is_required = param.get("required", False)
|
||||
enum_values = param.get("enum", None)
|
||||
|
||||
# 转换类型字符串到ToolParamType
|
||||
type_mapping = {
|
||||
"string": ToolParamType.STRING,
|
||||
"integer": ToolParamType.INTEGER,
|
||||
"int": ToolParamType.INTEGER,
|
||||
"float": ToolParamType.FLOAT,
|
||||
"boolean": ToolParamType.BOOLEAN,
|
||||
"bool": ToolParamType.BOOLEAN,
|
||||
}
|
||||
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
|
||||
|
||||
legacy_parameters.append((param_name, param_type, param_desc, is_required, enum_values))
|
||||
|
||||
normalized_option = normalize_tool_option(
|
||||
{
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": legacy_parameters,
|
||||
}
|
||||
)
|
||||
return {
|
||||
"name": normalized_option.name,
|
||||
"description": normalized_option.description,
|
||||
"parameters_schema": normalized_option.parameters_schema,
|
||||
}
|
||||
|
||||
|
||||
class MemoryRetrievalToolRegistry:
|
||||
"""工具注册器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化工具注册器。"""
|
||||
self.tools: Dict[str, MemoryRetrievalTool] = {}
|
||||
|
||||
def register_tool(self, tool: MemoryRetrievalTool) -> None:
|
||||
"""注册工具"""
|
||||
if tool.name in self.tools:
|
||||
logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册")
|
||||
return
|
||||
self.tools[tool.name] = tool
|
||||
logger.info(f"注册记忆检索工具: {tool.name}")
|
||||
|
||||
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
|
||||
"""获取工具"""
|
||||
return self.tools.get(name)
|
||||
|
||||
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
|
||||
"""获取所有工具"""
|
||||
return self.tools.copy()
|
||||
|
||||
def get_tools_description(self) -> str:
|
||||
"""获取所有工具的描述,用于prompt"""
|
||||
descriptions = []
|
||||
for i, tool in enumerate(self.tools.values(), 1):
|
||||
descriptions.append(f"{i}. {tool.get_tool_description()}")
|
||||
return "\n".join(descriptions)
|
||||
|
||||
def get_action_types_list(self) -> str:
|
||||
"""获取所有动作类型的列表,用于prompt(已废弃,保留用于兼容)"""
|
||||
action_types = [tool.name for tool in self.tools.values()]
|
||||
action_types.append("final_answer")
|
||||
action_types.append("no_answer")
|
||||
return " 或 ".join([f'"{at}"' for at in action_types])
|
||||
|
||||
def get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具的定义列表,用于LLM function calling
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典
|
||||
"""
|
||||
return [tool.get_tool_definition() for tool in self.tools.values()]
|
||||
|
||||
|
||||
# 全局工具注册器实例
|
||||
_tool_registry = MemoryRetrievalToolRegistry()
|
||||
|
||||
|
||||
def register_memory_retrieval_tool(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]],
|
||||
) -> None:
|
||||
"""注册记忆检索工具的便捷函数。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
description: 工具描述。
|
||||
parameters: 参数定义列表。
|
||||
execute_func: 执行函数。
|
||||
"""
|
||||
tool = MemoryRetrievalTool(name, description, parameters, execute_func)
|
||||
_tool_registry.register_tool(tool)
|
||||
|
||||
|
||||
def get_tool_registry() -> MemoryRetrievalToolRegistry:
|
||||
"""获取工具注册器实例"""
|
||||
return _tool_registry
|
||||
Reference in New Issue
Block a user