Merge branch 'r-dev' of https://github.com/A-Dawn/MaiBot into r-dev
This commit is contained in:
@@ -42,4 +42,7 @@
|
||||
如果你需要改动配置文件,不需要修改实际的bot_config.toml或者model_config.toml,只需要修改配置文件模版,并新增一个版本号即可,也不必要为配置改动创建测试文件。
|
||||
|
||||
# 关于webui修改
|
||||
不要修改dashboard下的内容,因为这部分内容由另一个仓库build
|
||||
不要修改dashboard下的内容,因为这部分内容由另一个仓库build
|
||||
|
||||
# maibot插件开发文档
|
||||
https://github.com/Mai-with-u/maibot-plugin-sdk/blob/main/docs/guide.md
|
||||
23
bot.py
23
bot.py
@@ -1,7 +1,6 @@
|
||||
# raise RuntimeError("System Not Ready")
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from rich.traceback import install
|
||||
|
||||
import asyncio
|
||||
@@ -23,28 +22,6 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
os.chdir(script_dir)
|
||||
set_locale(os.getenv("MAIBOT_LOCALE", "zh-CN"))
|
||||
|
||||
env_path = Path(__file__).parent / ".env"
|
||||
template_env_path = Path(__file__).parent / "template" / "template.env"
|
||||
|
||||
if env_path.exists():
|
||||
load_dotenv(str(env_path), override=True)
|
||||
else:
|
||||
print("[WIP] no .env file found, and templates is not ready yet.")
|
||||
print("[WIP] continue startup, use environment and existing config values.")
|
||||
# try:
|
||||
# if template_env_path.exists():
|
||||
# shutil.copyfile(template_env_path, env_path)
|
||||
# print(t("startup.env_created"))
|
||||
# load_dotenv(str(env_path), override=True)
|
||||
# else:
|
||||
# print(t("startup.env_template_missing"))
|
||||
# raise FileNotFoundError(t("startup.env_file_missing"))
|
||||
# except Exception as e:
|
||||
# print(t("startup.env_auto_create_failed", error=e))
|
||||
# raise
|
||||
|
||||
set_locale(os.getenv("MAIBOT_LOCALE", "zh-CN"))
|
||||
|
||||
# 检查是否是 Worker 进程,只在 Worker 进程中输出详细的初始化信息
|
||||
# Runner 进程只需要基本的日志功能,不需要详细的初始化日志
|
||||
is_worker = os.environ.get("MAIBOT_WORKER_PROCESS") == "1"
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
{
|
||||
"manifest_version": 2,
|
||||
"version": "2.0.0",
|
||||
"name": "BetterEmoji",
|
||||
"description": "更好的表情包管理插件",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/SengokuCola"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"urls": {
|
||||
"repository": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"homepage": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"documentation": "https://github.com/SengokuCola/BetterEmoji",
|
||||
"issues": "https://github.com/SengokuCola/BetterEmoji/issues"
|
||||
},
|
||||
"host_application": {
|
||||
"min_version": "1.0.0",
|
||||
"max_version": "1.0.0"
|
||||
},
|
||||
"sdk": {
|
||||
"min_version": "2.0.0",
|
||||
"max_version": "2.99.99"
|
||||
},
|
||||
"dependencies": [],
|
||||
"capabilities": [
|
||||
"emoji.get_random",
|
||||
"emoji.get_count",
|
||||
"emoji.get_info",
|
||||
"emoji.get_all",
|
||||
"emoji.register_emoji",
|
||||
"emoji.delete_emoji",
|
||||
"send.text",
|
||||
"send.forward"
|
||||
],
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"supported_locales": [
|
||||
"zh-CN"
|
||||
]
|
||||
},
|
||||
"id": "sengokucola.betteremoji"
|
||||
}
|
||||
@@ -1,238 +0,0 @@
|
||||
"""表情包管理插件 — 新 SDK 版本
|
||||
|
||||
通过 /emoji 命令管理表情包的添加、列表和删除。
|
||||
"""
|
||||
|
||||
from maibot_sdk import Command, MaiBotPlugin
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
|
||||
class EmojiManagePlugin(MaiBotPlugin):
|
||||
"""表情包管理插件"""
|
||||
|
||||
async def on_load(self) -> None:
|
||||
"""处理插件加载。"""
|
||||
|
||||
async def on_unload(self) -> None:
|
||||
"""处理插件卸载。"""
|
||||
|
||||
# ===== 工具方法 =====
|
||||
|
||||
@staticmethod
|
||||
def _extract_emoji_base64(segments) -> list[str]:
|
||||
"""从消息 segments 中提取 emoji/image 的 base64 数据。
|
||||
|
||||
segments 可以是 dict 列表或 Seg 对象列表(兼容两种格式)。
|
||||
"""
|
||||
results: list[str] = []
|
||||
if not segments:
|
||||
return results
|
||||
|
||||
if isinstance(segments, dict):
|
||||
seg_type = segments.get("type", "")
|
||||
if seg_type in ("emoji", "image"):
|
||||
data = segments.get("data", "")
|
||||
if data:
|
||||
results.append(data)
|
||||
elif seg_type == "seglist":
|
||||
for child in segments.get("data", []):
|
||||
results.extend(EmojiManagePlugin._extract_emoji_base64(child))
|
||||
return results
|
||||
|
||||
# 如果有 .type 属性(Seg 对象)
|
||||
if hasattr(segments, "type"):
|
||||
seg_type = getattr(segments, "type", "")
|
||||
if seg_type in ("emoji", "image"):
|
||||
results.append(getattr(segments, "data", ""))
|
||||
elif seg_type == "seglist":
|
||||
for child in getattr(segments, "data", []):
|
||||
results.extend(EmojiManagePlugin._extract_emoji_base64(child))
|
||||
return results
|
||||
|
||||
# 列表
|
||||
for seg in segments:
|
||||
results.extend(EmojiManagePlugin._extract_emoji_base64(seg))
|
||||
return results
|
||||
|
||||
# ===== Command 组件 =====
|
||||
|
||||
@Command("add_emoji", description="添加表情包", pattern=r".*/emoji add.*")
|
||||
async def handle_add_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
|
||||
"""添加表情包"""
|
||||
emoji_base64_list = self._extract_emoji_base64(message_segments)
|
||||
if not emoji_base64_list:
|
||||
await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
|
||||
return False, "未在消息中找到表情包或图片", False
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
results = []
|
||||
|
||||
for i, emoji_b64 in enumerate(emoji_base64_list):
|
||||
result = await self.ctx.emoji.register_emoji(emoji_b64)
|
||||
if isinstance(result, dict) and result.get("success"):
|
||||
success_count += 1
|
||||
desc = result.get("description", "未知描述")
|
||||
emotions = result.get("emotions", [])
|
||||
replaced = result.get("replaced", False)
|
||||
msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
|
||||
if desc:
|
||||
msg += f"\n描述: {desc}"
|
||||
if emotions:
|
||||
msg += f"\n情感标签: {', '.join(emotions)}"
|
||||
results.append(msg)
|
||||
else:
|
||||
fail_count += 1
|
||||
err = result.get("message", "注册失败") if isinstance(result, dict) else "注册失败"
|
||||
results.append(f"表情包 {i + 1} 注册失败: {err}")
|
||||
|
||||
total = success_count + fail_count
|
||||
summary = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total} 个"
|
||||
if results:
|
||||
summary += "\n" + "\n".join(results)
|
||||
|
||||
await self.ctx.send.text(summary, stream_id)
|
||||
return success_count > 0, summary, success_count > 0
|
||||
|
||||
@Command("emoji_list", description="列表表情包", pattern=r"^/emoji list(\s+\d+)?$")
|
||||
async def handle_list_emoji(self, stream_id: str = "", raw_message: str = "", **kwargs):
|
||||
"""列出表情包"""
|
||||
max_count = 10
|
||||
match = re.match(r"^/emoji list(?:\s+(\d+))?$", raw_message)
|
||||
if match and match.group(1):
|
||||
max_count = min(int(match.group(1)), 50)
|
||||
|
||||
now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
count_result = await self.ctx.emoji.get_count()
|
||||
emoji_count = count_result if isinstance(count_result, int) else 0
|
||||
|
||||
info_result = await self.ctx.emoji.get_info()
|
||||
max_emoji = info_result.get("max_count", 0) if isinstance(info_result, dict) else 0
|
||||
available = info_result.get("available_emojis", 0) if isinstance(info_result, dict) else 0
|
||||
|
||||
lines = [
|
||||
f"📊 表情包统计信息 ({now})",
|
||||
f"• 总数: {emoji_count} / {max_emoji}",
|
||||
f"• 可用: {available}",
|
||||
]
|
||||
|
||||
if emoji_count == 0:
|
||||
lines.append("\n❌ 暂无表情包")
|
||||
await self.ctx.send.text("\n".join(lines), stream_id)
|
||||
return True, "\n".join(lines), True
|
||||
|
||||
all_result = await self.ctx.emoji.get_all()
|
||||
all_emojis = all_result if isinstance(all_result, list) else []
|
||||
if not all_emojis:
|
||||
lines.append("\n❌ 无法获取表情包列表")
|
||||
await self.ctx.send.text("\n".join(lines), stream_id)
|
||||
return False, "\n".join(lines), True
|
||||
|
||||
display = all_emojis[:max_count]
|
||||
lines.append(f"\n📋 显示前 {len(display)} 个表情包:")
|
||||
for i, emoji in enumerate(display, 1):
|
||||
if isinstance(emoji, (list, tuple)) and len(emoji) >= 3:
|
||||
_, desc, emotion = emoji[0], emoji[1], emoji[2]
|
||||
elif isinstance(emoji, dict):
|
||||
desc = emoji.get("description", "")
|
||||
emotion = emoji.get("emotion", "")
|
||||
else:
|
||||
desc, emotion = str(emoji), ""
|
||||
short_desc = desc[:50] + "..." if len(desc) > 50 else desc
|
||||
lines.append(f"{i}. {short_desc} [{emotion}]")
|
||||
|
||||
if len(all_emojis) > max_count:
|
||||
lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
|
||||
|
||||
final = "\n".join(lines)
|
||||
await self.ctx.send.text(final, stream_id)
|
||||
return True, final, True
|
||||
|
||||
@Command("delete_emoji", description="删除表情包", pattern=r".*/emoji delete.*")
|
||||
async def handle_delete_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
|
||||
"""删除表情包"""
|
||||
emoji_base64_list = self._extract_emoji_base64(message_segments)
|
||||
if not emoji_base64_list:
|
||||
await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
|
||||
return False, "未找到表情包", False
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
results = []
|
||||
|
||||
for i, emoji_b64 in enumerate(emoji_base64_list):
|
||||
# 计算哈希
|
||||
if isinstance(emoji_b64, str):
|
||||
clean = emoji_b64.encode("ascii", errors="ignore").decode("ascii")
|
||||
else:
|
||||
clean = str(emoji_b64)
|
||||
image_bytes = base64.b64decode(clean)
|
||||
emoji_hash = hashlib.md5(image_bytes).hexdigest() # noqa: S324
|
||||
|
||||
result = await self.ctx.emoji.delete_emoji(emoji_hash)
|
||||
if isinstance(result, dict) and result.get("success"):
|
||||
success_count += 1
|
||||
desc = result.get("description", "未知描述")
|
||||
emotions = result.get("emotions", [])
|
||||
before = result.get("count_before", 0)
|
||||
after = result.get("count_after", 0)
|
||||
msg = f"表情包 {i + 1} 删除成功"
|
||||
if desc:
|
||||
msg += f"\n描述: {desc}"
|
||||
if emotions:
|
||||
msg += f"\n情感标签: {', '.join(emotions)}"
|
||||
msg += f"\n表情包数量: {before} → {after}"
|
||||
results.append(msg)
|
||||
else:
|
||||
fail_count += 1
|
||||
err = result.get("message", "删除失败") if isinstance(result, dict) else "删除失败"
|
||||
results.append(f"表情包 {i + 1} 删除失败: {err}")
|
||||
|
||||
total = success_count + fail_count
|
||||
summary = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total} 个"
|
||||
if results:
|
||||
summary += "\n" + "\n".join(results)
|
||||
|
||||
await self.ctx.send.text(summary, stream_id)
|
||||
return success_count > 0, summary, success_count > 0
|
||||
|
||||
@Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$")
|
||||
async def handle_random_emojis(self, stream_id: str = "", **kwargs):
|
||||
"""发送多张随机表情包"""
|
||||
emojis = await self.ctx.emoji.get_random(5)
|
||||
if not emojis:
|
||||
return False, "未找到表情包", False
|
||||
messages = [
|
||||
{"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]}
|
||||
for e in emojis
|
||||
]
|
||||
await self.ctx.send.forward(messages, stream_id)
|
||||
return True, "已发送随机表情包", True
|
||||
|
||||
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||
"""处理配置热重载事件。
|
||||
|
||||
Args:
|
||||
scope: 配置变更范围。
|
||||
config_data: 最新配置数据。
|
||||
version: 配置版本号。
|
||||
"""
|
||||
|
||||
del scope
|
||||
del config_data
|
||||
del version
|
||||
|
||||
|
||||
def create_plugin() -> EmojiManagePlugin:
|
||||
"""创建表情包管理插件实例。
|
||||
|
||||
Returns:
|
||||
EmojiManagePlugin: 新的表情包管理插件实例。
|
||||
"""
|
||||
|
||||
return EmojiManagePlugin()
|
||||
@@ -3,6 +3,7 @@ You need to focus on the dialogue between {bot_name} (AI) and different users so
|
||||
|
||||
[Reference Information]
|
||||
{identity}
|
||||
{time_block}
|
||||
[End of Reference Information]
|
||||
|
||||
You need to analyze based on the provided reference information, the current scenario, and the output rules.
|
||||
|
||||
@@ -8,4 +8,5 @@ You are chatting in the group now. Please read the previous chat records, grasp
|
||||
Try to keep it short. It is best to reply to only one topic at a time, so the reply does not become verbose or messy. Please pay attention to the chat content.
|
||||
{reply_style}
|
||||
You may refer to the information in [Reply Reference], but depending on the situation, you do not have to follow it completely.
|
||||
{group_chat_attention_block}
|
||||
Please do not output any extra content (including unnecessary prefixes or suffixes, colons, brackets, stickers, at, or @). Only output the message content itself.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
【参考情報】
|
||||
{identity}
|
||||
{time_block}
|
||||
【参考情報ここまで】
|
||||
|
||||
提供された参考情報、現在の状況、そして出力ルールに基づいて分析してください。
|
||||
|
||||
@@ -8,4 +8,5 @@
|
||||
できるだけ短くしてください。話題は一度に一つだけに返信したほうが、冗長になったり内容が散らかったりしません。チャット内容をしっかり踏まえてください。
|
||||
{reply_style}
|
||||
【返信情報参考】の情報は参考にしてかまいませんが、状況に応じて完全に従う必要はありません。
|
||||
{group_chat_attention_block}
|
||||
余計な内容(不要な前置きや後置き、コロン、括弧、スタンプ、at や @ など)は出力せず、発言内容だけを出力してください。
|
||||
|
||||
@@ -24,4 +24,4 @@
|
||||
"message_indices": [1, 2, 5]
|
||||
}},
|
||||
...
|
||||
]
|
||||
]
|
||||
|
||||
@@ -19,4 +19,4 @@
|
||||
聊天记录:
|
||||
{original_text}
|
||||
|
||||
请直接返回JSON,不要包含其他内容。
|
||||
请直接返回JSON,不要包含其他内容。
|
||||
|
||||
@@ -1,27 +1,30 @@
|
||||
你的任务是分析聊天和聊天中的互动情况。
|
||||
你的任务是分析聊天和聊天中的互动情况,然后做出下一步动作。
|
||||
你需要关注 {bot_name}(AI) 与不同用户的对话来为选择正确的动作和行为以及搜集信息提供建议
|
||||
|
||||
【参考信息】
|
||||
{bot_name}的人设:{identity}
|
||||
{time_block}
|
||||
【参考信息结束】
|
||||
|
||||
你需要根据提供的参考信息,当前场景和输出规则来进行分析
|
||||
在当前场景中,不同的人正在互动({bot_name}也是一位参与的用户),用户也可能与进行聊天互动,你的任务不是生成对用户可见的发言,而是进行分析来指导AI进行回复。
|
||||
“分析”应该体现你对当前局面的判断、你的建议、你的下一步计划,以及你为什么这样想。
|
||||
你需要先搜集能够帮助{bot_name}进行下一步行动的信息,然后再给出回复意见
|
||||
请你对当前场景和输出规则来进行分析,你可以参考参考信息中的内容,但不用过分遵守,仅供参考。
|
||||
在当前场景中,不同的人正在互动({bot_name}也是一位参与的用户),用户也可能与进行聊天互动,你的任务不是生成对用户可见的发言,而是进行分析来指导AI进行动作。
|
||||
“分析”应该体现你对当前局面的判断、你的建议、你的下一步计划,以及你为什么这样想。默认直接输出你当前的最新分析,不要重复之前的分析内容。最新分析应尽量具体,贴近上下文。
|
||||
你需要先搜集能够帮助{bot_name}进行下一步行动的信息,然后再给出思考
|
||||
|
||||
{group_chat_attention_block}
|
||||
|
||||
你可以使用这些工具:
|
||||
工具说明:
|
||||
- reply():当你判断{bot_name}现在应该正式对用户发出一条可见回复时调用。调用后系统会基于你当前这轮的想法生成一条真正展示给用户的回复。你可以针对某个用户回复,也可以对所有用户回复。
|
||||
- query_jargon():当你认为某些词的含义不明确,或用户询问某些词的含义,需要进行查询
|
||||
- query_memory():如果当前可用工具中存在它,当回复明显依赖历史对话、长期偏好、共同经历、人物长期信息或之前约定时使用
|
||||
- tool_search():当你在deferred tools列表中需要其中某个工具时,先调用它来搜索并发现对应工具;它只负责让工具在后续轮次变为可用,不直接执行业务
|
||||
- finish():当没有更多操作需要做,使用finish结束这次思考
|
||||
- 其他定义的工具,你可以视情况合适使用
|
||||
|
||||
工具使用规则:
|
||||
1. 你当前处于 Action Loop 阶段,节奏控制由独立的 timing gate 负责;如果系统让你继续,就专注于分析、搜集信息和执行真正需要的工具。
|
||||
2. 如果存在用户的疑问,或者对某些概念的不确定,你可以使用工具来搜集信息或者查询含义,你可以使用多个工具。
|
||||
3. 当你判断 {bot_name} 现在应该正式发出可见回复时,调用 reply()。
|
||||
4. 如果需要补充上下文、查看消息、查询黑话、检索记忆或使用其他可用工具,可以按需调用。
|
||||
1. 你可以使用多个工具。
|
||||
2. 如果存在工具可以帮助你执行某些动作,完成某些目标,直接使用该工具来完成任务
|
||||
3. 如果看到 `<system-reminder>` 中列出了 deferred tools,而你需要其中某个工具,先调用 tool_search() 搜索该工具,等它在后续轮次变为可用后再正常调用。
|
||||
|
||||
长期记忆使用建议:
|
||||
1. 仅当历史信息会明显影响当前回复时,才考虑调用 `query_memory()`。
|
||||
@@ -30,11 +33,5 @@
|
||||
4. 模式上:`search` 查事实或偏好,`time` 查某段时间,`episode` 查某次经历,`aggregate` 查整体情况;拿不准时用 `hybrid`。
|
||||
5. 如果无命中、被过滤、或证据不足,就不要编造。
|
||||
|
||||
你的分析规则:
|
||||
1. 默认直接输出你当前的最新分析,不要重复之前的分析内容。最新分析应尽量具体,贴近上下文。
|
||||
2. 你需要先评估是用户之间在互动还是和{bot_name}在互动,不要盲目插话,弄错回复对象
|
||||
3. 你需要评估哪些话是对{bot_name}的发言,哪些是用户之间的交流或者自言自语,不要频繁插入无关的话题。
|
||||
|
||||
{group_chat_attention_block}
|
||||
|
||||
现在,请你输出你对{bot_name}发言的分析,你必须先输出文本内容的分析,然后再进行工具调用,输出json形式的function call:
|
||||
现在,请你输出你对{bot_name}发言的分析,你必须先输出文本内容的分析,然后再进行工具调用,:
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片
|
||||
其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
|
||||
|
||||
在下面的内容中,标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
|
||||
{identity}
|
||||
{time_block}
|
||||
|
||||
{identity}
|
||||
你正在群里聊天,现在请你读读之前的聊天记录,把握当前的话题,然后给出日常且口语化的回复,
|
||||
尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
|
||||
现在请你读读之前的聊天记录,把握当前的话题,然后给出日常且口语化的回复,
|
||||
{reply_style}
|
||||
你可以参考【回复信息参考】中的信息,但是视情况而定,不用完全遵守。
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。
|
||||
{group_chat_attention_block}
|
||||
请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。
|
||||
|
||||
@@ -8,16 +8,16 @@
|
||||
在当前场景中,不同的人正在互动({bot_name} 也是一位参与的用户),用户也可能正在连续发送消息或彼此互动。
|
||||
你的任务不是生成对别人可见的发言,也不是直接使用查询类工具,而是判断当前是否应该:
|
||||
- continue:立刻进入下一轮完整思考、搜集信息、回复与其他工具执行
|
||||
- wait:再等待一段时间,然后重新判断,可选几秒的等待,也可等待数分钟
|
||||
- no_reply:本轮不继续,直接等待新的外部消息
|
||||
- wait:固定再等待一段时间,时间到后再重新判断;
|
||||
- no_reply:本轮不继续,直接等待新的消息
|
||||
|
||||
节奏控制规则:
|
||||
1. 如果 {bot_name} 已经回复,但用户暂时没有新的回复,且没有新信息需要搜集,使用 wait 或者 no_reply 进行等待。
|
||||
2. 如果用户有新发言,但是你评估用户还有后续发言尚未发送,可以适当等待让用户说完。
|
||||
3. 在特定情况下也可以连续回复,例如想要追问,或者补充自己先前的发言,这时应调用 continue,让主流程继续执行。
|
||||
4. 不要每条消息都回复,不要直接因为别的用户发送了表情包就发言。
|
||||
5. 如果你判断现在需要真正回复、查询信息、查看上下文或做进一步分析,不要在这里完成,直接调用 continue,把工作交给主流程。
|
||||
6. 你必须且只能调用一个工具,不要连续调用多个工具,也不要只输出文本不调用工具。
|
||||
3. 你需要先评估是用户之间在互动还是和{bot_name}在互动,不要盲目插话,弄错回复对象
|
||||
4. 你需要评估哪些话是对{bot_name}的发言,哪些是用户之间的交流或者自言自语,不要频繁插入无关的话题。
|
||||
5. 在特定情况下也可以连续回复,例如想要追问,或者补充自己先前的发言,这时应调用 continue,让主流程继续执行。
|
||||
6. 如果你判断现在需要真正回复、查询信息、查看上下文或做进一步分析,不要在这里完成,直接调用 continue,把工作交给主流程。
|
||||
|
||||
{group_chat_attention_block}
|
||||
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
你是一个专门获取长期记忆的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||
群里正在进行的聊天内容:
|
||||
{chat_history}
|
||||
|
||||
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||
请仔细分析聊天内容,考虑以下几点:
|
||||
1. 内容中是否包含需要查询历史知识或长期记忆的问题
|
||||
2. 是否有明确的知识获取指令
|
||||
|
||||
如果需要使用长期记忆工具,请直接调用函数 `search_long_term_memory`;如果不需要任何工具,直接输出 `No tool needed`。
|
||||
|
||||
工具模式说明:
|
||||
- `mode="search"`:普通长期记忆检索,适合查具体事实、偏好、历史对话内容
|
||||
- `mode="time"`:按时间范围检索,必须同时提供 `time_expression`
|
||||
- `mode="episode"`:按事件/情节检索,适合查“那次经历”“那件事的经过”
|
||||
- `mode="aggregate"`:综合检索,适合“整体回忆一下”“把相关线索综合找出来”
|
||||
|
||||
优先规则:
|
||||
- 问“某段时间发生了什么”:优先 `time`
|
||||
- 问“某次事件/某段经历”:优先 `episode`
|
||||
- 问“整体情况/最近发生过什么”:优先 `aggregate`
|
||||
- 问单点事实:优先 `search`
|
||||
|
||||
`time_expression` 可用表达:
|
||||
- `今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
|
||||
- 或绝对时间:`2026/03/18`、`2026/03/18 09:30`
|
||||
@@ -1,19 +0,0 @@
|
||||
你的名字是{bot_name}。现在是{time_now}。
|
||||
你正在参与聊天,你需要根据搜集到的信息总结信息。
|
||||
如果搜集到的信息对于参与聊天,回答问题有帮助,请加入总结,如果无关,请不要加入到总结。
|
||||
|
||||
当前聊天记录:
|
||||
{chat_history}
|
||||
|
||||
已收集的信息:
|
||||
{collected_info}
|
||||
|
||||
|
||||
分析:
|
||||
- 基于已收集的信息,总结出对当前聊天有帮助的相关信息
|
||||
- **如果收集的信息对当前聊天有帮助**,在思考中直接给出总结信息,格式为:return_information(information="你的总结信息")
|
||||
- **如果信息无关或没有帮助**,在思考中给出:return_information(information="")
|
||||
|
||||
**重要规则:**
|
||||
- 必须严格使用检索到的信息回答问题,不要编造信息
|
||||
- 答案必须精简,不要过多解释
|
||||
@@ -1,34 +0,0 @@
|
||||
你的名字是{bot_name}。现在是{time_now}。
|
||||
你正在参与聊天,你需要搜集信息来帮助你进行回复。
|
||||
重要,这是当前聊天记录:
|
||||
{chat_history}
|
||||
聊天记录结束
|
||||
|
||||
已收集的信息:
|
||||
{collected_info}
|
||||
|
||||
- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
|
||||
- 思考完毕后,使用工具
|
||||
|
||||
**工具说明:**
|
||||
- 如果涉及过往事件、历史对话、用户长期偏好或某段时间发生的事件,可以使用长期记忆查询工具
|
||||
- 如果遇到不熟悉的词语、缩写、黑话或网络用语,可以使用query_words工具查询其含义
|
||||
- 你必须使用tool,如果需要查询你必须给出使用什么工具进行查询
|
||||
- 当你决定结束查询时,必须调用return_information工具返回总结信息并结束查询
|
||||
|
||||
长期记忆工具 `search_long_term_memory` 支持以下模式:
|
||||
- `mode="search"`:普通事实/偏好/历史内容检索。适合问“她喜欢什么”“我们之前讨论过什么”。
|
||||
- `mode="time"`:按时间范围检索。适合问“昨天发生了什么”“最近7天有哪些相关记忆”。
|
||||
- `mode="episode"`:按事件/情节检索。适合问“那次灯塔停电的经过是什么”“关于某次经历还有什么”。
|
||||
- `mode="aggregate"`:综合检索。适合问“帮我整体回忆一下这个人最近的情况”“把相关线索综合找出来”。
|
||||
|
||||
模式选择建议:
|
||||
- 问单点事实、偏好、人设、具体信息:优先 `search`
|
||||
- 问某段时间发生了什么:优先 `time`
|
||||
- 问某次事件、某段经历、某个剧情片段:优先 `episode`
|
||||
- 问整体回忆、综合找线索、总结最近发生的事:优先 `aggregate`
|
||||
|
||||
时间模式要求:
|
||||
- 使用 `mode="time"` 时,必须填写 `time_expression`
|
||||
- 可用时间表达包括:`今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
|
||||
- 也可以使用绝对时间:`2026/03/18`、`2026/03/18 09:30`
|
||||
@@ -21,6 +21,7 @@ dependencies = [
|
||||
"maim-message>=0.6.2",
|
||||
"maibot-dashboard==1.0.0.dev2026040439",
|
||||
"maibot-plugin-sdk>=2.3.0",
|
||||
"matplotlib>=3.10.5",
|
||||
"mcp",
|
||||
"msgpack>=1.1.2",
|
||||
"numpy>=2.2.6",
|
||||
|
||||
@@ -33,3 +33,34 @@ def test_legacy_learning_list_with_numeric_fourth_column_is_migrated():
|
||||
"enable_jargon_learning": False,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None:
|
||||
payload = {
|
||||
"visual": {
|
||||
"multimodal_replyer": True,
|
||||
}
|
||||
}
|
||||
|
||||
result = try_migrate_legacy_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is True
|
||||
assert "visual.multimodal_replyer_moved_to_visual.replyer_mode" in result.reason
|
||||
assert result.data["visual"]["replyer_mode"] == "multimodal"
|
||||
assert "multimodal_replyer" not in result.data["visual"]
|
||||
|
||||
|
||||
def test_chat_replyer_generator_type_is_migrated_to_replyer_mode() -> None:
|
||||
payload = {
|
||||
"chat": {
|
||||
"replyer_generator_type": "legacy",
|
||||
},
|
||||
"visual": {},
|
||||
}
|
||||
|
||||
result = try_migrate_legacy_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is True
|
||||
assert "chat.replyer_generator_type_moved_to_visual.replyer_mode" in result.reason
|
||||
assert result.data["visual"]["replyer_mode"] == "text"
|
||||
assert "replyer_generator_type" not in result.data["chat"]
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
"""测试表达方式自动检查任务的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_auto_check_engine")
|
||||
def expression_auto_check_engine_fixture() -> Generator:
|
||||
"""创建用于表达方式自动检查任务测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_expressions_uses_read_only_session(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
expression_auto_check_engine,
|
||||
) -> None:
|
||||
"""选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。"""
|
||||
|
||||
import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module
|
||||
|
||||
with Session(expression_auto_check_engine) as session:
|
||||
session.add(
|
||||
Expression(
|
||||
situation="表达情绪高涨或生理反应",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
auto_commit_calls: list[bool] = []
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
|
||||
auto_commit_calls.append(auto_commit)
|
||||
session = Session(expression_auto_check_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session)
|
||||
monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries))
|
||||
|
||||
task = ExpressionAutoCheckTask()
|
||||
expressions = await task._select_expressions(1)
|
||||
|
||||
assert auto_commit_calls == [False]
|
||||
assert len(expressions) == 1
|
||||
assert expressions[0].id is not None
|
||||
assert expressions[0].situation == "表达情绪高涨或生理反应"
|
||||
assert expressions[0].style == "发送💦表情符号"
|
||||
91
pytests/common_test/test_maisaka_expression_selector.py
Normal file
91
pytests/common_test/test_maisaka_expression_selector.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import src.chat.replyer.maisaka_expression_selector as selector_module
|
||||
from src.chat.replyer.maisaka_expression_selector import MaisakaExpressionSelector
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
|
||||
def _build_target(platform: str, item_id: str, rule_type: str = "group") -> SimpleNamespace:
|
||||
return SimpleNamespace(platform=platform, item_id=item_id, rule_type=rule_type)
|
||||
|
||||
|
||||
def test_resolve_expression_group_scope_returns_related_sessions(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
||||
related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002")
|
||||
|
||||
monkeypatch.setattr(
|
||||
selector_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
expression=SimpleNamespace(
|
||||
expression_groups=[
|
||||
SimpleNamespace(
|
||||
expression_groups=[
|
||||
_build_target("qq", "10001"),
|
||||
_build_target("qq", "10002"),
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
selector = MaisakaExpressionSelector()
|
||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||
|
||||
assert related_session_ids == {current_session_id, related_session_id}
|
||||
assert has_global_share is False
|
||||
|
||||
|
||||
def test_resolve_expression_group_scope_uses_star_as_global_share(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
||||
|
||||
monkeypatch.setattr(
|
||||
selector_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
expression=SimpleNamespace(
|
||||
expression_groups=[
|
||||
SimpleNamespace(
|
||||
expression_groups=[
|
||||
_build_target("*", "*"),
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
selector = MaisakaExpressionSelector()
|
||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||
|
||||
assert related_session_ids == {current_session_id}
|
||||
assert has_global_share is True
|
||||
|
||||
|
||||
def test_resolve_expression_group_scope_does_not_treat_empty_target_as_global(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
||||
|
||||
monkeypatch.setattr(
|
||||
selector_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
expression=SimpleNamespace(
|
||||
expression_groups=[
|
||||
SimpleNamespace(
|
||||
expression_groups=[
|
||||
_build_target("", ""),
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
selector = MaisakaExpressionSelector()
|
||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||
|
||||
assert related_session_ids == {current_session_id}
|
||||
assert has_global_share is False
|
||||
@@ -31,6 +31,23 @@ async def test_calculate_hash_format_updates_runtime_path_metadata(tmp_path: Pat
|
||||
assert emoji.dir_path == tmp_path.resolve()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_format_reuses_existing_target_file(tmp_path: Path) -> None:
|
||||
image_bytes = _build_test_image_bytes("JPEG")
|
||||
tmp_file_path = tmp_path / "emoji.tmp"
|
||||
target_file_path = tmp_path / "emoji.jpeg"
|
||||
tmp_file_path.write_bytes(image_bytes)
|
||||
target_file_path.write_bytes(image_bytes)
|
||||
|
||||
emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes)
|
||||
|
||||
assert await emoji.calculate_hash_format() is True
|
||||
assert emoji.full_path == target_file_path.resolve()
|
||||
assert emoji.file_name == target_file_path.name
|
||||
assert not tmp_file_path.exists()
|
||||
assert target_file_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_cls", "extra_fields"),
|
||||
[
|
||||
|
||||
22
pytests/test_context_message_fallback.py
Normal file
22
pytests/test_context_message_fallback.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
|
||||
from src.llm_models.payload_content.message import RoleType
|
||||
from src.maisaka.context_messages import _build_message_from_sequence
|
||||
|
||||
|
||||
def test_image_only_message_keeps_placeholder_in_text_fallback() -> None:
|
||||
message_sequence = MessageSequence(
|
||||
[
|
||||
TextComponent("[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容]"),
|
||||
ImageComponent(binary_hash="hash", content=None, binary_data=None),
|
||||
]
|
||||
)
|
||||
|
||||
message = _build_message_from_sequence(
|
||||
RoleType.User,
|
||||
message_sequence,
|
||||
"[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容][图片]",
|
||||
)
|
||||
|
||||
assert message is not None
|
||||
assert "[发言内容]" in message.get_text_content()
|
||||
assert "[图片]" in message.get_text_content()
|
||||
@@ -5,8 +5,7 @@ import pytest
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from src.chat.replyer import maisaka_generator as legacy_replyer_module
|
||||
from src.chat.replyer import maisaka_generator_multi as multimodal_replyer_module
|
||||
from src.chat.replyer import maisaka_generator as replyer_module
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
GenerationMetrics,
|
||||
LLMCompletionResult,
|
||||
@@ -37,8 +36,8 @@ class _FakeLegacyLLMServiceClient:
|
||||
del args
|
||||
del kwargs
|
||||
|
||||
async def generate_response(self, prompt: str) -> _FakeLLMResult:
|
||||
assert prompt
|
||||
async def generate_response_with_messages(self, *, message_factory: Callable[[object], list[Any]]) -> _FakeLLMResult:
|
||||
assert message_factory(object())
|
||||
return _FakeLLMResult()
|
||||
|
||||
|
||||
@@ -54,13 +53,21 @@ class _FakeMultimodalLLMServiceClient:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_and_multimodal_replyer_monitor_detail_have_same_shape(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(legacy_replyer_module, "LLMServiceClient", _FakeLegacyLLMServiceClient)
|
||||
monkeypatch.setattr(multimodal_replyer_module, "LLMServiceClient", _FakeMultimodalLLMServiceClient)
|
||||
monkeypatch.setattr(legacy_replyer_module, "load_prompt", lambda *args, **kwargs: "legacy prompt")
|
||||
monkeypatch.setattr(multimodal_replyer_module, "load_prompt", lambda *args, **kwargs: "multi prompt")
|
||||
monkeypatch.setattr(replyer_module, "LLMServiceClient", _FakeLegacyLLMServiceClient)
|
||||
monkeypatch.setattr(replyer_module, "load_prompt", lambda *args, **kwargs: "legacy prompt")
|
||||
|
||||
legacy_generator = legacy_replyer_module.MaisakaReplyGenerator(chat_stream=None, request_type="test_legacy")
|
||||
multimodal_generator = multimodal_replyer_module.MaisakaReplyGenerator(chat_stream=None, request_type="test_multi")
|
||||
legacy_generator = replyer_module.MaisakaReplyGenerator(
|
||||
chat_stream=None,
|
||||
request_type="test_legacy",
|
||||
enable_visual_message=False,
|
||||
)
|
||||
multimodal_generator = replyer_module.MaisakaReplyGenerator(
|
||||
chat_stream=None,
|
||||
request_type="test_multi",
|
||||
llm_client_cls=_FakeMultimodalLLMServiceClient,
|
||||
load_prompt_func=lambda *args, **kwargs: "multi prompt",
|
||||
enable_visual_message=True,
|
||||
)
|
||||
|
||||
legacy_success, legacy_result = await legacy_generator.generate_reply_with_context(
|
||||
stream_id="session-legacy",
|
||||
@@ -84,6 +91,40 @@ async def test_legacy_and_multimodal_replyer_monitor_detail_have_same_shape(monk
|
||||
assert legacy_result.monitor_detail["metrics"]["total_tokens"] == 19
|
||||
|
||||
|
||||
def test_legacy_replyer_builds_message_sequence_like_multimodal() -> None:
|
||||
legacy_generator = replyer_module.MaisakaReplyGenerator(
|
||||
chat_stream=None,
|
||||
request_type="test_legacy",
|
||||
enable_visual_message=False,
|
||||
)
|
||||
legacy_prompt_loader = replyer_module.load_prompt
|
||||
replyer_module.load_prompt = lambda *args, **kwargs: "legacy prompt"
|
||||
|
||||
try:
|
||||
session_message = replyer_module.SessionBackedMessage(
|
||||
raw_message=SimpleNamespace(),
|
||||
visible_text="[Alice]你好\n[Bob]在吗",
|
||||
timestamp=replyer_module.datetime.now(),
|
||||
source_kind="user",
|
||||
)
|
||||
request_messages = legacy_generator._build_request_messages(
|
||||
chat_history=[session_message],
|
||||
reply_message=None,
|
||||
reply_reason="测试原因",
|
||||
stream_id="session-legacy",
|
||||
)
|
||||
finally:
|
||||
replyer_module.load_prompt = legacy_prompt_loader
|
||||
|
||||
assert len(request_messages) == 4
|
||||
assert request_messages[0].role.value == "system"
|
||||
assert request_messages[1].role.value == "user"
|
||||
assert request_messages[1].get_text_content() == "[Alice]你好"
|
||||
assert request_messages[2].role.value == "user"
|
||||
assert request_messages[2].get_text_content() == "[Bob]在吗"
|
||||
assert request_messages[3].role.value == "user"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fake_monitor_detail = {
|
||||
@@ -324,7 +365,7 @@ def test_reasoning_engine_build_tool_monitor_result_keeps_non_reply_tool_without
|
||||
def test_runtime_build_tool_detail_panels_renders_reply_monitor_detail() -> None:
|
||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
||||
runtime.session_id = "session-1"
|
||||
panels = runtime._build_tool_detail_panels(
|
||||
panels = runtime._build_tool_detail_cards(
|
||||
[
|
||||
{
|
||||
"tool_call_id": "call-reply-1",
|
||||
@@ -348,7 +389,8 @@ def test_runtime_build_tool_detail_panels_renders_reply_monitor_detail() -> None
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
],
|
||||
stage_title="工具调用",
|
||||
)
|
||||
|
||||
assert len(panels) == 1
|
||||
@@ -387,7 +429,7 @@ def test_runtime_build_tool_detail_panels_uses_prompt_access_panel(monkeypatch:
|
||||
_fake_build_text_access_panel,
|
||||
)
|
||||
|
||||
panels = runtime._build_tool_detail_panels(
|
||||
panels = runtime._build_tool_detail_cards(
|
||||
[
|
||||
{
|
||||
"tool_call_id": "call-reply-2",
|
||||
@@ -401,7 +443,8 @@ def test_runtime_build_tool_detail_panels_uses_prompt_access_panel(monkeypatch:
|
||||
"output_text": "reply output",
|
||||
},
|
||||
}
|
||||
]
|
||||
],
|
||||
stage_title="工具调用",
|
||||
)
|
||||
|
||||
assert len(panels) == 1
|
||||
@@ -425,7 +468,7 @@ def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monke
|
||||
_fake_build_text_access_panel,
|
||||
)
|
||||
|
||||
panels = runtime._build_tool_detail_panels(
|
||||
panels = runtime._build_tool_detail_cards(
|
||||
[
|
||||
{
|
||||
"tool_call_id": "call-emoji-1",
|
||||
@@ -439,7 +482,8 @@ def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monke
|
||||
"output_text": '{"emoji_index": 1}',
|
||||
},
|
||||
}
|
||||
]
|
||||
],
|
||||
stage_title="工具调用",
|
||||
)
|
||||
|
||||
assert len(panels) == 1
|
||||
@@ -448,6 +492,63 @@ def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monke
|
||||
assert captured["kwargs"]["request_kind"] == "emotion"
|
||||
|
||||
|
||||
def test_runtime_build_tool_detail_cards_uses_structured_prompt_messages_with_images(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
||||
runtime.session_id = "session-image"
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def _fake_build_prompt_access_panel(messages: list[Any], **kwargs: Any) -> str:
|
||||
captured["messages"] = messages
|
||||
captured["kwargs"] = kwargs
|
||||
return "IMAGE_PROMPT_LINK"
|
||||
|
||||
def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str:
|
||||
captured["text_content"] = content
|
||||
captured["text_kwargs"] = kwargs
|
||||
return "TEXT_PROMPT_LINK"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"src.maisaka.runtime.PromptCLIVisualizer.build_prompt_access_panel",
|
||||
_fake_build_prompt_access_panel,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel",
|
||||
_fake_build_text_access_panel,
|
||||
)
|
||||
|
||||
panels = runtime._build_tool_detail_cards(
|
||||
[
|
||||
{
|
||||
"tool_call_id": "call-reply-image-1",
|
||||
"tool_name": "reply",
|
||||
"tool_args": {"msg_id": "m3"},
|
||||
"success": True,
|
||||
"duration_ms": 22.0,
|
||||
"summary": "- reply [成功]: 已回复",
|
||||
"detail": {
|
||||
"prompt_text": "reply prompt image",
|
||||
"request_messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": ["前缀文本", ["png", "ZmFrZQ=="]],
|
||||
}
|
||||
],
|
||||
"output_text": "reply output",
|
||||
},
|
||||
}
|
||||
],
|
||||
stage_title="工具调用",
|
||||
)
|
||||
|
||||
assert len(panels) == 1
|
||||
assert "messages" in captured
|
||||
assert "text_content" not in captured
|
||||
assert captured["kwargs"]["chat_id"] == "session-image"
|
||||
assert captured["kwargs"]["request_kind"] == "replyer"
|
||||
|
||||
|
||||
def test_runtime_render_context_usage_panel_merges_timing_and_planner(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
||||
runtime.session_id = "session-merged"
|
||||
|
||||
23
pytests/test_maisaka_tool_logging.py
Normal file
23
pytests/test_maisaka_tool_logging.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from src.maisaka.chat_loop_service import MaisakaChatLoopService
|
||||
|
||||
|
||||
def test_build_tool_names_log_text_supports_openai_function_schema() -> None:
|
||||
tool_definitions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "mute_user",
|
||||
"description": "禁言指定用户",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "reply",
|
||||
"description": "发送回复",
|
||||
},
|
||||
]
|
||||
|
||||
assert MaisakaChatLoopService._build_tool_names_log_text(tool_definitions) == "mute_user、reply"
|
||||
339
pytests/test_mute_plugin_sdk.py
Normal file
339
pytests/test_mute_plugin_sdk.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""MutePlugin SDK 回归测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from maibot_sdk.context import PluginContext
|
||||
from maibot_sdk.plugin import MaiBotPlugin
|
||||
|
||||
from plugins.MutePlugin.plugin import create_plugin
|
||||
from src.core.tooling import ToolExecutionContext, ToolInvocation
|
||||
from src.plugin_runtime.component_query import ComponentQueryService
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
|
||||
def _build_plugin() -> MaiBotPlugin:
|
||||
"""构造已注入默认配置的插件实例。"""
|
||||
|
||||
plugin = create_plugin()
|
||||
plugin.set_plugin_config(plugin.get_default_config())
|
||||
return plugin
|
||||
|
||||
|
||||
def test_mute_plugin_manifest_is_valid_v2() -> None:
|
||||
"""MutePlugin 的 manifest 应符合当前运行时要求。"""
|
||||
|
||||
validator = ManifestValidator(host_version="1.0.0", sdk_version="2.3.0")
|
||||
manifest = validator.load_from_plugin_path(Path("plugins/MutePlugin"))
|
||||
|
||||
assert manifest is not None
|
||||
assert manifest.id == "sengokucola.mute-plugin"
|
||||
assert manifest.manifest_version == 2
|
||||
|
||||
|
||||
def test_create_plugin_returns_sdk_plugin() -> None:
|
||||
"""插件入口应返回 SDK 插件实例。"""
|
||||
|
||||
plugin = create_plugin()
|
||||
|
||||
assert isinstance(plugin, MaiBotPlugin)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mute_command_calls_napcat_group_ban_api() -> None:
|
||||
"""手动禁言命令应通过 NapCat Adapter 新 API 执行。"""
|
||||
|
||||
plugin = _build_plugin()
|
||||
plugin.set_plugin_config(
|
||||
{
|
||||
**plugin.get_default_config(),
|
||||
"components": {
|
||||
"enable_smart_mute": True,
|
||||
"enable_mute_command": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
capability_calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
|
||||
assert method == "cap.call"
|
||||
assert payload is not None
|
||||
capability_calls.append(payload)
|
||||
|
||||
capability = payload["capability"]
|
||||
if capability == "person.get_id_by_name":
|
||||
return {"success": True, "person_id": "person-1"}
|
||||
if capability == "person.get_value":
|
||||
return {"success": True, "value": "123456"}
|
||||
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
|
||||
return {"success": True, "result": {"role": "member"}}
|
||||
if capability == "api.call":
|
||||
return {"success": True, "result": {"status": "ok", "retcode": 0}}
|
||||
if capability == "send.text":
|
||||
return {"success": True}
|
||||
raise AssertionError(f"unexpected capability: {capability}")
|
||||
|
||||
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
|
||||
|
||||
success, message, intercept = await plugin.handle_mute_command(
|
||||
stream_id="group-10001",
|
||||
group_id="10001",
|
||||
user_id="42",
|
||||
matched_groups={
|
||||
"target": "张三",
|
||||
"duration": "120",
|
||||
"reason": "刷屏",
|
||||
},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert message == "成功禁言 张三"
|
||||
assert intercept is True
|
||||
|
||||
api_call = next(
|
||||
call
|
||||
for call in capability_calls
|
||||
if call["capability"] == "api.call"
|
||||
and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
|
||||
)
|
||||
assert api_call["args"]["version"] == "1"
|
||||
assert api_call["args"]["args"] == {
|
||||
"group_id": "10001",
|
||||
"user_id": "123456",
|
||||
"duration": 120,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mute_tool_requires_target_person_name() -> None:
|
||||
"""禁言工具在缺少目标时应直接失败并提示。"""
|
||||
|
||||
plugin = _build_plugin()
|
||||
capability_calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
|
||||
assert method == "cap.call"
|
||||
assert payload is not None
|
||||
capability_calls.append(payload)
|
||||
return {"success": True}
|
||||
|
||||
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
|
||||
|
||||
success, message = await plugin.handle_mute_tool(
|
||||
stream_id="group-10001",
|
||||
group_id="10001",
|
||||
target="",
|
||||
duration="60",
|
||||
reason="测试",
|
||||
)
|
||||
|
||||
assert success is False
|
||||
assert message == "禁言目标不能为空"
|
||||
assert capability_calls[-1]["capability"] == "send.text"
|
||||
assert capability_calls[-1]["args"]["text"] == "没有指定禁言对象哦"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mute_tool_can_unwrap_nested_person_user_id_response() -> None:
|
||||
"""禁言工具应能兼容解包多层 capability 返回结果。"""
|
||||
|
||||
plugin = _build_plugin()
|
||||
capability_calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
|
||||
assert method == "cap.call"
|
||||
assert payload is not None
|
||||
capability_calls.append(payload)
|
||||
|
||||
capability = payload["capability"]
|
||||
if capability == "person.get_id_by_name":
|
||||
return {"success": True, "result": {"success": True, "person_id": "person-1"}}
|
||||
if capability == "person.get_value":
|
||||
return {"success": True, "result": {"success": True, "value": "123456"}}
|
||||
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
|
||||
return {"success": True, "result": {"role": "member"}}
|
||||
if capability == "api.call":
|
||||
return {"success": True, "result": {"status": "ok"}}
|
||||
if capability == "send.text":
|
||||
return {"success": True}
|
||||
raise AssertionError(f"unexpected capability: {capability}")
|
||||
|
||||
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
|
||||
|
||||
success, message = await plugin.handle_mute_tool(
|
||||
stream_id="group-10001",
|
||||
group_id="10001",
|
||||
target="张三",
|
||||
duration=60,
|
||||
reason="测试",
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert message == "成功禁言 张三"
|
||||
|
||||
api_call = next(
|
||||
call
|
||||
for call in capability_calls
|
||||
if call["capability"] == "api.call"
|
||||
and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
|
||||
)
|
||||
assert api_call["args"]["args"]["user_id"] == "123456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mute_tool_rejects_owner_before_group_ban_call() -> None:
|
||||
"""禁言工具应在检测到群主时提前返回明确提示。"""
|
||||
|
||||
plugin = _build_plugin()
|
||||
capability_calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
|
||||
assert method == "cap.call"
|
||||
assert payload is not None
|
||||
capability_calls.append(payload)
|
||||
|
||||
capability = payload["capability"]
|
||||
if capability == "person.get_id_by_name":
|
||||
return {"success": True, "person_id": "person-1"}
|
||||
if capability == "person.get_value":
|
||||
return {"success": True, "value": "123456"}
|
||||
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
|
||||
return {"success": True, "result": {"role": "owner"}}
|
||||
if capability == "send.text":
|
||||
return {"success": True}
|
||||
raise AssertionError(f"unexpected capability: {capability}")
|
||||
|
||||
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
|
||||
|
||||
success, message = await plugin.handle_mute_tool(
|
||||
stream_id="group-10001",
|
||||
group_id="10001",
|
||||
target="张三",
|
||||
duration=60,
|
||||
reason="测试",
|
||||
)
|
||||
|
||||
assert success is False
|
||||
assert message == "张三 是群主,不能被禁言"
|
||||
assert not any(
|
||||
call["capability"] == "api.call" and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
|
||||
for call in capability_calls
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mute_tool_maps_cannot_ban_owner_error_message() -> None:
|
||||
"""NapCat 返回 cannot ban owner 时应转成明确中文提示。"""
|
||||
|
||||
plugin = _build_plugin()
|
||||
capability_calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
|
||||
assert method == "cap.call"
|
||||
assert payload is not None
|
||||
capability_calls.append(payload)
|
||||
|
||||
capability = payload["capability"]
|
||||
if capability == "person.get_id_by_name":
|
||||
return {"success": True, "person_id": "person-1"}
|
||||
if capability == "person.get_value":
|
||||
return {"success": True, "value": "123456"}
|
||||
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
|
||||
return {"success": True, "result": {"role": "member"}}
|
||||
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
|
||||
return {"success": False, "error": "NapCat 动作返回失败: action=set_group_ban message=cannot ban owner"}
|
||||
if capability == "send.text":
|
||||
return {"success": True}
|
||||
raise AssertionError(f"unexpected capability: {capability}")
|
||||
|
||||
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
|
||||
|
||||
success, message = await plugin.handle_mute_tool(
|
||||
stream_id="group-10001",
|
||||
group_id="10001",
|
||||
target="张三",
|
||||
duration=60,
|
||||
reason="测试",
|
||||
)
|
||||
|
||||
assert success is False
|
||||
assert message == "张三 是群主,不能被禁言"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mute_tool_accepts_nested_ok_api_result() -> None:
|
||||
"""嵌套的 success/result/status=ok 返回值也应判定为成功。"""
|
||||
|
||||
plugin = _build_plugin()
|
||||
|
||||
async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
|
||||
assert method == "cap.call"
|
||||
assert payload is not None
|
||||
|
||||
capability = payload["capability"]
|
||||
if capability == "person.get_id_by_name":
|
||||
return {"success": True, "person_id": "person-1"}
|
||||
if capability == "person.get_value":
|
||||
return {"success": True, "value": "123456"}
|
||||
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
|
||||
return {"success": True, "result": {"role": "member"}}
|
||||
if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
|
||||
return {
|
||||
"success": True,
|
||||
"result": {
|
||||
"status": "ok",
|
||||
"retcode": 0,
|
||||
"data": None,
|
||||
"message": "",
|
||||
"wording": "",
|
||||
},
|
||||
}
|
||||
if capability == "send.text":
|
||||
return {"success": True}
|
||||
raise AssertionError(f"unexpected capability: {capability}")
|
||||
|
||||
plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
|
||||
|
||||
success, message = await plugin.handle_mute_tool(
|
||||
stream_id="group-10001",
|
||||
group_id="10001",
|
||||
target="张三",
|
||||
duration=60,
|
||||
reason="测试",
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert message == "成功禁言 张三"
|
||||
|
||||
|
||||
def test_tool_invocation_payload_injects_group_and_user_context() -> None:
|
||||
"""插件工具执行时应自动补齐群聊上下文字段。"""
|
||||
|
||||
entry = SimpleNamespace(invoke_method="plugin.invoke_tool")
|
||||
anchor_message = SimpleNamespace(
|
||||
message_info=SimpleNamespace(
|
||||
group_info=SimpleNamespace(group_id="10001"),
|
||||
user_info=SimpleNamespace(user_id="20002"),
|
||||
)
|
||||
)
|
||||
invocation = ToolInvocation(tool_name="mute", arguments={"target": "张三"}, stream_id="session-1")
|
||||
context = ToolExecutionContext(
|
||||
session_id="session-1",
|
||||
stream_id="session-1",
|
||||
reasoning="test",
|
||||
metadata={"anchor_message": anchor_message},
|
||||
)
|
||||
|
||||
payload = ComponentQueryService._build_tool_invocation_payload(entry, invocation, context)
|
||||
|
||||
assert payload["target"] == "张三"
|
||||
assert payload["stream_id"] == "session-1"
|
||||
assert payload["chat_id"] == "session-1"
|
||||
assert payload["group_id"] == "10001"
|
||||
assert payload["user_id"] == "20002"
|
||||
27
pytests/test_openai_client_toolless_request.py
Normal file
27
pytests/test_openai_client_toolless_request.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from src.llm_models.model_client.openai_client import _sanitize_messages_for_toolless_request
|
||||
from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
|
||||
def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_without_parts() -> None:
|
||||
messages = [
|
||||
Message(
|
||||
role=RoleType.Assistant,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="call_1",
|
||||
func_name="mute_user",
|
||||
args={"target": "alice"},
|
||||
)
|
||||
],
|
||||
),
|
||||
Message(
|
||||
role=RoleType.User,
|
||||
parts=[TextMessagePart(text="继续")],
|
||||
),
|
||||
]
|
||||
|
||||
sanitized_messages = _sanitize_messages_for_toolless_request(messages)
|
||||
|
||||
assert len(sanitized_messages) == 1
|
||||
assert sanitized_messages[0].role == RoleType.User
|
||||
18
pytests/test_prompt_message_roundtrip.py
Normal file
18
pytests/test_prompt_message_roundtrip.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.plugin_runtime.hook_payloads import deserialize_prompt_messages, serialize_prompt_messages
|
||||
|
||||
|
||||
def test_prompt_messages_roundtrip_preserves_image_parts() -> None:
|
||||
messages = [
|
||||
MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(),
|
||||
]
|
||||
|
||||
serialized_messages = serialize_prompt_messages(messages)
|
||||
restored_messages = deserialize_prompt_messages(serialized_messages)
|
||||
|
||||
assert len(restored_messages) == 1
|
||||
assert restored_messages[0].role == RoleType.User
|
||||
assert restored_messages[0].get_text_content() == "你好"
|
||||
assert len(restored_messages[0].parts) == 2
|
||||
assert restored_messages[0].parts[1].image_format == "png"
|
||||
assert restored_messages[0].parts[1].image_base64 == "ZmFrZQ=="
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,9 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
|
||||
@@ -30,7 +31,7 @@ class MaisakaExpressionSelectionResult:
|
||||
|
||||
|
||||
class MaisakaExpressionSelector:
|
||||
"""负责在 replyer 侧完成表达方式筛选与子代理选择。"""
|
||||
"""负责在 replyer 侧完成表达方式筛选与子代理二次选择。"""
|
||||
|
||||
def _can_use_expressions(self, session_id: str) -> bool:
|
||||
try:
|
||||
@@ -40,18 +41,34 @@ class MaisakaExpressionSelector:
|
||||
logger.error(f"检查表达方式使用开关失败: {exc}")
|
||||
return False
|
||||
|
||||
def _get_related_session_ids(self, session_id: str) -> List[str]:
|
||||
def _can_use_advanced_chosen(self, session_id: str) -> bool:
|
||||
try:
|
||||
return ExpressionConfigUtils.get_expression_advanced_chosen_for_chat(session_id)
|
||||
except Exception as exc:
|
||||
logger.error(f"检查表达方式二次选择开关失败: {exc}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_global_expression_group_marker(platform: str, item_id: str) -> bool:
|
||||
return platform == "*" and item_id == "*"
|
||||
|
||||
def _resolve_expression_group_scope(self, session_id: str) -> tuple[set[str], bool]:
|
||||
related_session_ids = {session_id}
|
||||
has_global_share = False
|
||||
expression_groups = global_config.expression.expression_groups
|
||||
|
||||
for expression_group in expression_groups:
|
||||
target_items = expression_group.expression_groups
|
||||
group_session_ids: set[str] = set()
|
||||
contains_current_session = False
|
||||
contains_global_share_marker = False
|
||||
|
||||
for target_item in target_items:
|
||||
platform = target_item.platform.strip()
|
||||
item_id = target_item.item_id.strip()
|
||||
if self._is_global_expression_group_marker(platform, item_id):
|
||||
contains_global_share_marker = True
|
||||
continue
|
||||
if not platform or not item_id:
|
||||
continue
|
||||
|
||||
@@ -65,19 +82,24 @@ class MaisakaExpressionSelector:
|
||||
if target_session_id == session_id:
|
||||
contains_current_session = True
|
||||
|
||||
if contains_global_share_marker:
|
||||
has_global_share = True
|
||||
if contains_current_session:
|
||||
related_session_ids.update(group_session_ids)
|
||||
|
||||
return list(related_session_ids)
|
||||
return related_session_ids, has_global_share
|
||||
|
||||
def _load_expression_candidates(self, session_id: str) -> List[dict[str, Any]]:
|
||||
related_session_ids = self._get_related_session_ids(session_id)
|
||||
related_session_ids, has_global_share = self._resolve_expression_group_scope(session_id)
|
||||
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
scoped_query = base_query.where(
|
||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
)
|
||||
if has_global_share:
|
||||
scoped_query = base_query
|
||||
else:
|
||||
scoped_query = base_query.where(
|
||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
)
|
||||
if global_config.expression.expression_checked_only:
|
||||
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
expressions = session.exec(scoped_query).all()
|
||||
@@ -87,7 +109,7 @@ class MaisakaExpressionSelector:
|
||||
"id": expression.id,
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
"count": expression.count if getattr(expression, "count", None) is not None else 1,
|
||||
"count": expression.count if expression.count is not None else 1,
|
||||
}
|
||||
for expression in expressions
|
||||
if expression.id is not None and expression.situation and expression.style
|
||||
@@ -171,7 +193,7 @@ class MaisakaExpressionSelector:
|
||||
"你只负责根据最近聊天上下文,为这一次可见回复挑选最合适的表达方式。\n"
|
||||
"请只从下面候选中选择 0 到 3 条最适合当前语境的表达方式。\n"
|
||||
"优先考虑自然、贴合上下文、不生硬、不模板化。\n"
|
||||
"如果没有明显合适的,就返回空列表。\n"
|
||||
"如果没有明显合适的,就返回空数组。\n"
|
||||
'严格只输出 JSON,对象格式为 {"selected_ids":[123,456]}。\n\n'
|
||||
f"最近上下文:\n{history_block}\n\n"
|
||||
f"目标消息:{target_text or '无'}\n"
|
||||
@@ -208,6 +230,32 @@ class MaisakaExpressionSelector:
|
||||
break
|
||||
return selected_ids
|
||||
|
||||
def _build_direct_selection_result(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
candidates: List[dict[str, Any]],
|
||||
) -> MaisakaExpressionSelectionResult:
|
||||
selected_ids = [
|
||||
candidate["id"]
|
||||
for candidate in candidates
|
||||
if isinstance(candidate.get("id"), int)
|
||||
]
|
||||
selected_expressions = [
|
||||
candidate
|
||||
for candidate in candidates
|
||||
if candidate.get("id") in selected_ids
|
||||
]
|
||||
self._update_last_active_time(selected_ids)
|
||||
logger.info(
|
||||
f"表达方式直接注入:session_id={session_id} 已选数={len(selected_ids)} "
|
||||
f"selected_ids={selected_ids!r} 已选预览={self._format_candidate_preview(selected_expressions)}"
|
||||
)
|
||||
return MaisakaExpressionSelectionResult(
|
||||
expression_habits=self._build_expression_habits_block(selected_expressions),
|
||||
selected_expression_ids=selected_ids,
|
||||
)
|
||||
|
||||
def _update_last_active_time(self, selected_ids: List[int]) -> None:
|
||||
if not selected_ids:
|
||||
return
|
||||
@@ -233,15 +281,22 @@ class MaisakaExpressionSelector:
|
||||
if not self._can_use_expressions(session_id):
|
||||
logger.info(f"表达方式选择已跳过:当前会话未启用表达方式,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
if sub_agent_runner is None:
|
||||
logger.info(f"表达方式选择已跳过:缺少 sub_agent_runner,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
candidates = self._load_expression_candidates(session_id)
|
||||
if not candidates:
|
||||
logger.info(f"表达方式选择已跳过:本地候选不足,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
if not self._can_use_advanced_chosen(session_id):
|
||||
return self._build_direct_selection_result(
|
||||
session_id=session_id,
|
||||
candidates=candidates,
|
||||
)
|
||||
|
||||
if sub_agent_runner is None:
|
||||
logger.info(f"表达方式选择已跳过:缺少 sub_agent_runner,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
logger.info(
|
||||
f"表达方式选择开始:session_id={session_id} 候选数={len(candidates)} "
|
||||
f"候选预览={self._format_candidate_preview(candidates)}"
|
||||
@@ -259,10 +314,9 @@ class MaisakaExpressionSelector:
|
||||
logger.exception("表达方式选择子代理执行失败")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
logger.info(f"表达方式子代理原始结果:session_id={session_id} response={raw_response!r}")
|
||||
selected_ids = self._parse_selected_ids(raw_response, candidates)
|
||||
if not selected_ids:
|
||||
logger.info(f"表达方式选择完成但未命中:session_id={session_id}")
|
||||
logger.info(f"表达方式选择完成但未命中,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
selected_expressions = [candidate for candidate in candidates if candidate.get("id") in selected_ids]
|
||||
|
||||
@@ -1,440 +1,29 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
from rich.panel import Panel
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.cli.console import console
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
GenerationMetrics,
|
||||
LLMCompletionResult,
|
||||
ReplyGenerationResult,
|
||||
build_reply_monitor_detail,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from src.maisaka.context_messages import (
|
||||
AssistantMessage,
|
||||
LLMContextMessage,
|
||||
ReferenceMessage,
|
||||
SessionBackedMessage,
|
||||
ToolResultMessage,
|
||||
)
|
||||
from src.maisaka.message_adapter import parse_speaker_content
|
||||
from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer
|
||||
|
||||
from .maisaka_expression_selector import maisaka_expression_selector
|
||||
|
||||
logger = get_logger("replyer")
|
||||
from .maisaka_generator_base import BaseMaisakaReplyGenerator
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaisakaReplyContext:
|
||||
"""Maisaka replyer 使用的回复上下文。"""
|
||||
|
||||
expression_habits: str = ""
|
||||
selected_expression_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class MaisakaReplyGenerator:
|
||||
"""生成 Maisaka 的最终可见回复。"""
|
||||
class MaisakaReplyGenerator(BaseMaisakaReplyGenerator):
|
||||
"""Maisaka replyer。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
request_type: str = "maisaka_replyer",
|
||||
llm_client_cls: Optional[Any] = None,
|
||||
load_prompt_func: Optional[Callable[..., str]] = None,
|
||||
enable_visual_message: Optional[bool] = None,
|
||||
) -> None:
|
||||
self.chat_stream = chat_stream
|
||||
self.request_type = request_type
|
||||
self.express_model = LLMServiceClient(
|
||||
task_name="replyer",
|
||||
super().__init__(
|
||||
chat_stream=chat_stream,
|
||||
request_type=request_type,
|
||||
llm_client_cls=llm_client_cls or LLMServiceClient,
|
||||
load_prompt_func=load_prompt_func or load_prompt,
|
||||
enable_visual_message=enable_visual_message,
|
||||
replyer_mode=global_config.visual.replyer_mode,
|
||||
)
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构建 replyer 使用的人设提示。"""
|
||||
try:
|
||||
bot_name = global_config.bot.nickname
|
||||
alias_names = global_config.bot.alias_names
|
||||
bot_aliases = f",也有人叫你{','.join(alias_names)}" if alias_names else ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
if (
|
||||
hasattr(global_config.personality, "states")
|
||||
and global_config.personality.states
|
||||
and hasattr(global_config.personality, "state_probability")
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
return f"你的名字是{bot_name}{bot_aliases},你{prompt_personality};"
|
||||
except Exception as exc:
|
||||
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
|
||||
return "你的名字是麦麦,你是一个活泼可爱的 AI 助手。"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_content(content: str, limit: int = 500) -> str:
|
||||
normalized = " ".join((content or "").split())
|
||||
if len(normalized) > limit:
|
||||
return normalized[:limit] + "..."
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _format_message_time(message: LLMContextMessage) -> str:
|
||||
return message.timestamp.strftime("%H:%M:%S")
|
||||
|
||||
@staticmethod
|
||||
def _extract_visible_assistant_reply(message: AssistantMessage) -> str:
|
||||
del message
|
||||
return ""
|
||||
|
||||
def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
|
||||
speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
if speaker_name == bot_nickname:
|
||||
return self._normalize_content(body.strip())
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
|
||||
"""按说话人拆分用户消息。"""
|
||||
segments: List[tuple[Optional[str], str]] = []
|
||||
current_speaker: Optional[str] = None
|
||||
current_lines: List[str] = []
|
||||
|
||||
for raw_line in raw_content.splitlines():
|
||||
speaker_name, content_body = parse_speaker_content(raw_line)
|
||||
if speaker_name is not None:
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
current_speaker = speaker_name
|
||||
current_lines = [content_body]
|
||||
continue
|
||||
|
||||
current_lines.append(raw_line)
|
||||
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
|
||||
return segments
|
||||
|
||||
def _format_chat_history(self, messages: List[LLMContextMessage]) -> str:
|
||||
"""格式化 replyer 使用的可见聊天记录。"""
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
parts: List[str] = []
|
||||
|
||||
for message in messages:
|
||||
timestamp = self._format_message_time(message)
|
||||
|
||||
if isinstance(message, (ReferenceMessage, ToolResultMessage)):
|
||||
continue
|
||||
|
||||
if isinstance(message, SessionBackedMessage):
|
||||
guided_reply = self._extract_guided_bot_reply(message)
|
||||
if guided_reply:
|
||||
parts.append(f"{timestamp} {bot_nickname}(you): {guided_reply}")
|
||||
continue
|
||||
|
||||
raw_content = message.processed_plain_text
|
||||
for speaker_name, content_body in self._split_user_message_segments(raw_content):
|
||||
content = self._normalize_content(content_body)
|
||||
if not content:
|
||||
continue
|
||||
visible_speaker = speaker_name or global_config.maisaka.cli_user_name.strip() or "User"
|
||||
parts.append(f"{timestamp} {visible_speaker}: {content}")
|
||||
continue
|
||||
|
||||
if isinstance(message, AssistantMessage):
|
||||
visible_reply = self._extract_visible_assistant_reply(message)
|
||||
if visible_reply:
|
||||
parts.append(f"{timestamp} {bot_nickname}(you): {visible_reply}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
|
||||
"""构建当前需要回复的目标消息摘要。"""
|
||||
if reply_message is None:
|
||||
return ""
|
||||
|
||||
user_info = reply_message.message_info.user_info
|
||||
sender_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
|
||||
target_message_id = reply_message.message_id.strip() if reply_message.message_id else "未知"
|
||||
target_content = self._normalize_content((reply_message.processed_plain_text or "").strip(), limit=300)
|
||||
if not target_content:
|
||||
target_content = "[无可见文本内容]"
|
||||
|
||||
return (
|
||||
"【本次回复目标】\n"
|
||||
f"- 目标消息ID:{target_message_id}\n"
|
||||
f"- 发送者:{sender_name}\n"
|
||||
f"- 消息内容:{target_content}\n"
|
||||
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要误把其他历史消息当成当前回复对象。"
|
||||
)
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> str:
|
||||
"""构建 Maisaka replyer 提示词。"""
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
formatted_history = self._format_chat_history(chat_history)
|
||||
target_message_block = self._build_target_message_block(reply_message)
|
||||
|
||||
try:
|
||||
system_prompt = load_prompt(
|
||||
"maisaka_replyer",
|
||||
bot_name=global_config.bot.nickname,
|
||||
time_block=f"当前时间:{current_time}",
|
||||
identity=self._personality_prompt,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
)
|
||||
except Exception:
|
||||
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
|
||||
|
||||
extra_sections: List[str] = []
|
||||
if expression_habits.strip():
|
||||
extra_sections.append(expression_habits.strip())
|
||||
|
||||
user_sections = [
|
||||
f"当前时间:{current_time}",
|
||||
f"【聊天记录】\n{formatted_history}",
|
||||
]
|
||||
if target_message_block:
|
||||
user_sections.append(target_message_block)
|
||||
if extra_sections:
|
||||
user_sections.append("\n\n".join(extra_sections))
|
||||
user_sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
user_sections.append("现在,你说:")
|
||||
|
||||
user_prompt = "\n\n".join(user_sections)
|
||||
return f"System: {system_prompt}\n\nUser: {user_prompt}"
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
"""解析当前回复使用的会话 ID。"""
|
||||
if stream_id:
|
||||
return stream_id
|
||||
if self.chat_stream is not None:
|
||||
return self.chat_stream.session_id
|
||||
return ""
|
||||
|
||||
async def _build_reply_context(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
stream_id: Optional[str],
|
||||
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]],
|
||||
) -> MaisakaReplyContext:
|
||||
"""构建回复上下文:表达习惯和已选表达 ID。"""
|
||||
session_id = self._resolve_session_id(stream_id)
|
||||
if not session_id:
|
||||
logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
if sub_agent_runner is None:
|
||||
logger.info("表达方式选择跳过:缺少子代理执行器")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
selection_result = await maisaka_expression_selector.select_for_reply(
|
||||
session_id=session_id,
|
||||
chat_history=chat_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
sub_agent_runner=sub_agent_runner,
|
||||
)
|
||||
return MaisakaReplyContext(
|
||||
expression_habits=selection_result.expression_habits,
|
||||
selected_expression_ids=selection_result.selected_expression_ids,
|
||||
)
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List[object]] = None,
|
||||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[SessionMessage] = None,
|
||||
reply_time_point: Optional[float] = None,
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
log_reply: bool = True,
|
||||
chat_history: Optional[List[LLMContextMessage]] = None,
|
||||
expression_habits: str = "",
|
||||
selected_expression_ids: Optional[List[int]] = None,
|
||||
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]] = None,
|
||||
) -> Tuple[bool, ReplyGenerationResult]:
|
||||
"""结合上下文生成 Maisaka 的最终可见回复。"""
|
||||
|
||||
def finalize(success_value: bool) -> Tuple[bool, ReplyGenerationResult]:
|
||||
result.monitor_detail = build_reply_monitor_detail(result)
|
||||
return success_value, result
|
||||
|
||||
del available_actions
|
||||
del chosen_actions
|
||||
del extra_info
|
||||
del from_plugin
|
||||
del log_reply
|
||||
del reply_time_point
|
||||
del think_level
|
||||
del unknown_words
|
||||
|
||||
result = ReplyGenerationResult()
|
||||
overall_started_at = time.perf_counter()
|
||||
if chat_history is None:
|
||||
result.error_message = "聊天历史为空"
|
||||
return finalize(False)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} "
|
||||
f"历史消息数={len(chat_history)} 目标消息编号={reply_message.message_id if reply_message else None}"
|
||||
)
|
||||
|
||||
filtered_history = [
|
||||
message
|
||||
for message in chat_history
|
||||
if not isinstance(message, (ReferenceMessage, ToolResultMessage))
|
||||
]
|
||||
logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}")
|
||||
|
||||
if self.express_model is None:
|
||||
logger.error("Maisaka 回复器的回复模型未初始化")
|
||||
result.error_message = "回复模型尚未初始化"
|
||||
return finalize(False)
|
||||
|
||||
try:
|
||||
reply_context = await self._build_reply_context(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
stream_id=stream_id,
|
||||
sub_agent_runner=sub_agent_runner,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建回复上下文失败: {exc}"
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
|
||||
)
|
||||
return finalize(False)
|
||||
|
||||
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
|
||||
result.selected_expression_ids = (
|
||||
list(selected_expression_ids)
|
||||
if selected_expression_ids is not None
|
||||
else list(reply_context.selected_expression_ids)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
|
||||
prompt_started_at = time.perf_counter()
|
||||
try:
|
||||
prompt = self._build_prompt(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建提示词失败: {exc}"
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
|
||||
)
|
||||
return finalize(False)
|
||||
|
||||
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
|
||||
result.completion.request_prompt = prompt
|
||||
show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
|
||||
show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))
|
||||
preview_chat_id = self._resolve_session_id(stream_id) or "unknown"
|
||||
|
||||
if show_replyer_prompt:
|
||||
console.print(
|
||||
Panel(
|
||||
PromptCLIVisualizer.build_text_access_panel(
|
||||
prompt,
|
||||
category="replyer",
|
||||
chat_id=preview_chat_id,
|
||||
request_kind="replyer",
|
||||
subtitle=f"流ID: {preview_chat_id}",
|
||||
),
|
||||
title="Maisaka 回复器 Prompt",
|
||||
border_style="bright_yellow",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
llm_started_at = time.perf_counter()
|
||||
try:
|
||||
generation_result = await self.express_model.generate_response(prompt)
|
||||
except Exception as exc:
|
||||
logger.exception("Maisaka 回复器调用失败")
|
||||
result.error_message = str(exc)
|
||||
result.metrics = GenerationMetrics(
|
||||
prompt_ms=prompt_ms,
|
||||
llm_ms=round((time.perf_counter() - llm_started_at) * 1000, 2),
|
||||
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
|
||||
)
|
||||
return finalize(False)
|
||||
|
||||
llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
|
||||
response_text = (generation_result.response or "").strip()
|
||||
result.success = bool(response_text)
|
||||
result.completion = LLMCompletionResult(
|
||||
request_prompt=prompt,
|
||||
response_text=response_text,
|
||||
reasoning_text=generation_result.reasoning or "",
|
||||
model_name=generation_result.model_name or "",
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
prompt_tokens=generation_result.prompt_tokens,
|
||||
completion_tokens=generation_result.completion_tokens,
|
||||
total_tokens=generation_result.total_tokens,
|
||||
)
|
||||
result.metrics = GenerationMetrics(
|
||||
prompt_ms=prompt_ms,
|
||||
llm_ms=llm_ms,
|
||||
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
|
||||
stage_logs=[
|
||||
f"prompt: {prompt_ms} ms",
|
||||
f"llm: {llm_ms} ms",
|
||||
],
|
||||
)
|
||||
|
||||
if show_replyer_reasoning and result.completion.reasoning_text:
|
||||
logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}")
|
||||
|
||||
if not result.success:
|
||||
result.error_message = "回复器返回了空内容"
|
||||
logger.warning("Maisaka 回复器返回了空内容")
|
||||
return finalize(False)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器生成成功: 回复文本={response_text!r} "
|
||||
f"总耗时毫秒={result.metrics.overall_ms} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
result.text_fragments = [response_text]
|
||||
return finalize(True)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import random
|
||||
|
||||
from rich.console import Group, RenderableType
|
||||
from rich.panel import Panel
|
||||
@@ -10,6 +11,7 @@ from rich.text import Text
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.cli.console import console
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
@@ -19,18 +21,11 @@ from src.common.data_models.reply_generation_data_models import (
|
||||
build_reply_monitor_detail,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.config.model_configs import ModelInfo
|
||||
from src.core.types import ActionInfo
|
||||
from src.llm_models.payload_content.message import (
|
||||
ImageMessagePart,
|
||||
Message,
|
||||
MessageBuilder,
|
||||
RoleType,
|
||||
TextMessagePart,
|
||||
)
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.maisaka.context_messages import (
|
||||
AssistantMessage,
|
||||
LLMContextMessage,
|
||||
@@ -38,8 +33,9 @@ from src.maisaka.context_messages import (
|
||||
SessionBackedMessage,
|
||||
ToolResultMessage,
|
||||
)
|
||||
from src.maisaka.display.prompt_cli_renderer import PromptCLIVisualizer
|
||||
from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content
|
||||
from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer
|
||||
from src.plugin_runtime.hook_payloads import serialize_prompt_messages
|
||||
|
||||
from .maisaka_expression_selector import maisaka_expression_selector
|
||||
|
||||
@@ -54,17 +50,26 @@ class MaisakaReplyContext:
|
||||
selected_expression_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class MaisakaReplyGenerator:
|
||||
"""生成 Maisaka 的最终可见回复(多模态管线)。"""
|
||||
class BaseMaisakaReplyGenerator:
|
||||
"""Maisaka replyer 的共享实现。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
request_type: str = "maisaka_replyer",
|
||||
llm_client_cls: Any,
|
||||
load_prompt_func: Callable[..., str],
|
||||
enable_visual_message: Optional[bool],
|
||||
replyer_mode: Literal["text", "multimodal", "auto"],
|
||||
) -> None:
|
||||
self.chat_stream = chat_stream
|
||||
self.request_type = request_type
|
||||
self.express_model = LLMServiceClient(
|
||||
self._llm_client_cls = llm_client_cls
|
||||
self._load_prompt = load_prompt_func
|
||||
self._enable_visual_message = enable_visual_message
|
||||
self._replyer_mode = replyer_mode
|
||||
self.express_model = llm_client_cls(
|
||||
task_name="replyer",
|
||||
request_type=request_type,
|
||||
)
|
||||
@@ -111,28 +116,6 @@ class MaisakaReplyGenerator:
|
||||
return self._normalize_content(body.strip())
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
|
||||
segments: List[tuple[Optional[str], str]] = []
|
||||
current_speaker: Optional[str] = None
|
||||
current_lines: List[str] = []
|
||||
|
||||
for raw_line in raw_content.splitlines():
|
||||
speaker_name, content_body = parse_speaker_content(raw_line)
|
||||
if speaker_name is not None:
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
current_speaker = speaker_name
|
||||
current_lines = [content_body]
|
||||
continue
|
||||
|
||||
current_lines.append(raw_line)
|
||||
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
|
||||
return segments
|
||||
|
||||
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
|
||||
if reply_message is None:
|
||||
return ""
|
||||
@@ -152,19 +135,93 @@ class MaisakaReplyGenerator:
|
||||
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要把其他历史消息当成当前回复对象。"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
|
||||
"""根据聊天流 ID 获取匹配的额外 prompt。"""
|
||||
if not global_config.chat.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_item in global_config.chat.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "platform"):
|
||||
platform = str(chat_prompt_item.platform or "").strip()
|
||||
item_id = str(chat_prompt_item.item_id or "").strip()
|
||||
rule_type = str(chat_prompt_item.rule_type or "").strip()
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
elif isinstance(chat_prompt_item, str):
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
platform, item_id, rule_type, prompt_content = parts
|
||||
platform = platform.strip()
|
||||
item_id = item_id.strip()
|
||||
rule_type = rule_type.strip()
|
||||
prompt_content = prompt_content.strip()
|
||||
else:
|
||||
continue
|
||||
|
||||
if not platform or not item_id or not prompt_content:
|
||||
continue
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
|
||||
else:
|
||||
continue
|
||||
|
||||
if config_is_group != is_group_chat:
|
||||
continue
|
||||
if config_chat_id == chat_id:
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
|
||||
def _build_group_chat_attention_block(self, session_id: str) -> str:
|
||||
"""构建当前聊天场景下的额外注意事项块。"""
|
||||
if not session_id:
|
||||
return ""
|
||||
|
||||
try:
|
||||
is_group_chat, _ = get_chat_type_and_target_info(session_id)
|
||||
except Exception:
|
||||
is_group_chat = None
|
||||
|
||||
prompt_lines: List[str] = []
|
||||
|
||||
if is_group_chat is True:
|
||||
if group_chat_prompt := global_config.chat.group_chat_prompt.strip():
|
||||
prompt_lines.append(f"通用注意事项:\n{group_chat_prompt}")
|
||||
elif is_group_chat is False:
|
||||
if private_chat_prompt := global_config.chat.private_chat_prompts.strip():
|
||||
prompt_lines.append(f"通用注意事项:\n{private_chat_prompt}")
|
||||
|
||||
if chat_prompt := self._get_chat_prompt_for_chat(session_id, is_group_chat).strip():
|
||||
prompt_lines.append(f"当前聊天额外注意事项:\n{chat_prompt}")
|
||||
|
||||
if not prompt_lines:
|
||||
return ""
|
||||
|
||||
return "在该聊天中的注意事项:\n" + "\n\n".join(prompt_lines) + "\n"
|
||||
|
||||
def _build_system_prompt(
|
||||
self,
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
stream_id: Optional[str] = None,
|
||||
) -> str:
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
target_message_block = self._build_target_message_block(reply_message)
|
||||
session_id = self._resolve_session_id(stream_id)
|
||||
|
||||
try:
|
||||
system_prompt = load_prompt(
|
||||
system_prompt = self._load_prompt(
|
||||
"maisaka_replyer",
|
||||
bot_name=global_config.bot.nickname,
|
||||
group_chat_attention_block=self._build_group_chat_attention_block(session_id),
|
||||
time_block=f"当前时间:{current_time}",
|
||||
identity=self._personality_prompt,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
@@ -184,38 +241,35 @@ class MaisakaReplyGenerator:
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(sections)
|
||||
|
||||
def _build_reply_instruction(self) -> str:
|
||||
return "请自然地回复。不要输出多余说明、括号、at 或额外标记,只输出实际要发送的内容。"
|
||||
return "请自然地回复。不要输出多余说明、括号、@ 或额外标记,只输出实际要发送的内容。"
|
||||
|
||||
def _build_multimodal_user_message(
|
||||
def _build_visual_user_message(
|
||||
self,
|
||||
message: SessionBackedMessage,
|
||||
default_user_name: str,
|
||||
enable_visual_message: bool,
|
||||
) -> Optional[Message]:
|
||||
speaker_name, _ = parse_speaker_content(message.processed_plain_text.strip())
|
||||
visible_speaker = speaker_name or default_user_name
|
||||
if not enable_visual_message:
|
||||
return None
|
||||
|
||||
raw_message = clone_message_sequence(message.raw_message)
|
||||
if not raw_message.components:
|
||||
raw_message = MessageSequence([TextComponent(f"[{visible_speaker}]")])
|
||||
elif isinstance(raw_message.components[0], TextComponent):
|
||||
first_text = raw_message.components[0].text or ""
|
||||
raw_message.components[0] = TextComponent(f"[{visible_speaker}]{first_text}")
|
||||
else:
|
||||
raw_message.components.insert(0, TextComponent(f"[{visible_speaker}]"))
|
||||
raw_message = MessageSequence([TextComponent(message.processed_plain_text)])
|
||||
|
||||
multimodal_message = SessionBackedMessage(
|
||||
visual_message = SessionBackedMessage(
|
||||
raw_message=raw_message,
|
||||
visible_text=f"[{visible_speaker}]{message.processed_plain_text}",
|
||||
visible_text=message.processed_plain_text,
|
||||
timestamp=message.timestamp,
|
||||
message_id=message.message_id,
|
||||
original_message=message.original_message,
|
||||
source_kind=message.source_kind,
|
||||
)
|
||||
return multimodal_message.to_llm_message()
|
||||
return visual_message.to_llm_message()
|
||||
|
||||
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
default_user_name = global_config.maisaka.cli_user_name.strip() or "User"
|
||||
def _build_history_messages(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
enable_visual_message: bool,
|
||||
) -> List[Message]:
|
||||
messages: List[Message] = []
|
||||
|
||||
for message in chat_history:
|
||||
@@ -230,25 +284,14 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
continue
|
||||
|
||||
multimodal_message = self._build_multimodal_user_message(message, default_user_name)
|
||||
if multimodal_message is not None:
|
||||
messages.append(multimodal_message)
|
||||
visual_message = self._build_visual_user_message(message, enable_visual_message)
|
||||
if visual_message is not None:
|
||||
messages.append(visual_message)
|
||||
continue
|
||||
|
||||
for speaker_name, content_body in self._split_user_message_segments(message.processed_plain_text):
|
||||
content = self._normalize_content(content_body)
|
||||
if not content:
|
||||
continue
|
||||
|
||||
visible_speaker = speaker_name or default_user_name
|
||||
if visible_speaker == bot_nickname:
|
||||
messages.append(
|
||||
MessageBuilder().set_role(RoleType.Assistant).add_text_content(content).build()
|
||||
)
|
||||
continue
|
||||
|
||||
user_content = f"[{visible_speaker}]{content}"
|
||||
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(user_content).build())
|
||||
llm_message = message.to_llm_message()
|
||||
if llm_message is not None:
|
||||
messages.append(llm_message)
|
||||
continue
|
||||
|
||||
if isinstance(message, AssistantMessage):
|
||||
@@ -266,34 +309,33 @@ class MaisakaReplyGenerator:
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
stream_id: Optional[str] = None,
|
||||
enable_visual_message: bool = False,
|
||||
) -> List[Message]:
|
||||
messages: List[Message] = []
|
||||
system_prompt = self._build_system_prompt(
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
expression_habits=expression_habits,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
instruction = self._build_reply_instruction()
|
||||
|
||||
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
|
||||
messages.extend(self._build_history_messages(chat_history))
|
||||
messages.extend(self._build_history_messages(chat_history, enable_visual_message))
|
||||
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _build_request_prompt_preview(messages: List[Message]) -> str:
|
||||
preview_lines: List[str] = []
|
||||
for message in messages:
|
||||
role_name = message.role.value.capitalize()
|
||||
part_previews: List[str] = []
|
||||
for part in message.parts:
|
||||
if isinstance(part, TextMessagePart):
|
||||
part_previews.append(part.text)
|
||||
continue
|
||||
if isinstance(part, ImageMessagePart):
|
||||
part_previews.append(f"[图片:{part.normalized_image_format}]")
|
||||
preview_lines.append(f"{role_name}: {''.join(part_previews)}")
|
||||
return "\n\n".join(preview_lines)
|
||||
def _resolve_enable_visual_message(self, model_info: Optional[ModelInfo] = None) -> bool:
|
||||
if self._enable_visual_message is not None:
|
||||
return self._enable_visual_message
|
||||
if self._replyer_mode == "multimodal":
|
||||
if model_info is not None and not model_info.visual:
|
||||
raise ValueError(f"replyer_mode=multimodal,但模型 '{model_info.name}' 未开启 visual,无法使用多模态 replyer")
|
||||
return True
|
||||
if self._replyer_mode == "text":
|
||||
return False
|
||||
return bool(model_info.visual) if model_info is not None else False
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
if stream_id:
|
||||
@@ -349,7 +391,6 @@ class MaisakaReplyGenerator:
|
||||
selected_expression_ids: Optional[List[int]] = None,
|
||||
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]] = None,
|
||||
) -> Tuple[bool, ReplyGenerationResult]:
|
||||
|
||||
def finalize(success_value: bool) -> Tuple[bool, ReplyGenerationResult]:
|
||||
result.monitor_detail = build_reply_monitor_detail(result)
|
||||
return success_value, result
|
||||
@@ -411,7 +452,7 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"回复上下文完成: 流={stream_id} 已选表达={result.selected_expression_ids!r}"
|
||||
f"回复上下文完成 流={stream_id} 已选表达={result.selected_expression_ids!r}"
|
||||
)
|
||||
|
||||
prompt_started_at = time.perf_counter()
|
||||
@@ -421,6 +462,7 @@ class MaisakaReplyGenerator:
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
@@ -433,24 +475,36 @@ class MaisakaReplyGenerator:
|
||||
return finalize(False)
|
||||
|
||||
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
|
||||
prompt_preview = self._build_request_prompt_preview(request_messages)
|
||||
prompt_preview = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
|
||||
show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
|
||||
show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))
|
||||
|
||||
def message_factory(_client: object) -> List[Message]:
|
||||
def message_factory(_client: object, model_info: Optional[ModelInfo] = None) -> List[Message]:
|
||||
nonlocal prompt_ms, prompt_preview, request_messages
|
||||
prompt_started_at = time.perf_counter()
|
||||
request_messages = self._build_request_messages(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
stream_id=stream_id,
|
||||
enable_visual_message=self._resolve_enable_visual_message(model_info),
|
||||
)
|
||||
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
|
||||
prompt_preview = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
|
||||
return request_messages
|
||||
|
||||
result.completion.request_prompt = prompt_preview
|
||||
preview_chat_id = self._resolve_session_id(stream_id)
|
||||
replyer_prompt_section: RenderableType | None = None
|
||||
if show_replyer_prompt:
|
||||
replyer_prompt_section = Panel(
|
||||
PromptCLIVisualizer.build_text_access_panel(
|
||||
prompt_preview,
|
||||
PromptCLIVisualizer.build_prompt_access_panel(
|
||||
request_messages,
|
||||
category="replyer",
|
||||
chat_id=preview_chat_id,
|
||||
request_kind="replyer",
|
||||
subtitle=f"流ID: {preview_chat_id}",
|
||||
selection_reason=f"ID: {preview_chat_id}",
|
||||
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
|
||||
),
|
||||
title="Reply Prompt",
|
||||
border_style="bright_yellow",
|
||||
@@ -472,6 +526,8 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
return finalize(False)
|
||||
|
||||
result.completion.request_prompt = prompt_preview
|
||||
result.request_messages = serialize_prompt_messages(request_messages)
|
||||
llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
|
||||
response_text = (generation_result.response or "").strip()
|
||||
result.success = bool(response_text)
|
||||
@@ -504,7 +560,7 @@ class MaisakaReplyGenerator:
|
||||
return finalize(False)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器生成成功: 文本={response_text!r} "
|
||||
f"Maisaka 回复器生成成功 文本={response_text!r} "
|
||||
f"总耗时ms={result.metrics.overall_ms} 已选表达={result.selected_expression_ids!r}"
|
||||
)
|
||||
if show_replyer_prompt or show_replyer_reasoning:
|
||||
@@ -1,21 +0,0 @@
|
||||
from typing import Type
|
||||
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def get_maisaka_replyer_class() -> Type[object]:
|
||||
"""根据配置返回 Maisaka replyer 类。"""
|
||||
generator_type = get_maisaka_replyer_generator_type()
|
||||
if generator_type == "multimodal":
|
||||
from .maisaka_generator_multi import MaisakaReplyGenerator
|
||||
|
||||
return MaisakaReplyGenerator
|
||||
|
||||
from .maisaka_generator import MaisakaReplyGenerator
|
||||
|
||||
return MaisakaReplyGenerator
|
||||
|
||||
|
||||
def get_maisaka_replyer_generator_type() -> str:
|
||||
"""返回当前配置的 Maisaka replyer 生成器类型。"""
|
||||
return "multimodal" if global_config.visual.multimodal_replyer else "legacy"
|
||||
@@ -1,15 +1,10 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.replyer.maisaka_replyer_factory import (
|
||||
get_maisaka_replyer_class,
|
||||
get_maisaka_replyer_generator_type,
|
||||
)
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
from .maisaka_generator import MaisakaReplyGenerator
|
||||
|
||||
logger = get_logger("ReplyerManager")
|
||||
|
||||
@@ -20,20 +15,25 @@ class ReplyerManager:
|
||||
def __init__(self) -> None:
|
||||
self._repliers: Dict[str, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_maisaka_generator_type() -> str:
|
||||
"""返回当前配置下 Maisaka replyer 的消息模式。"""
|
||||
return global_config.visual.replyer_mode
|
||||
|
||||
def get_replyer(
|
||||
self,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
replyer_type: str = "default",
|
||||
) -> Optional["DefaultReplyer | PrivateReplyer | Any"]:
|
||||
) -> Optional[MaisakaReplyGenerator]:
|
||||
"""按会话和 replyer 类型获取实例。"""
|
||||
stream_id = chat_stream.session_id if chat_stream else chat_id
|
||||
if not stream_id:
|
||||
logger.warning("[ReplyerManager] 缺少 stream_id,无法获取 replyer")
|
||||
return None
|
||||
|
||||
generator_type = get_maisaka_replyer_generator_type() if replyer_type == "maisaka" else ""
|
||||
generator_type = self._get_maisaka_generator_type() if replyer_type == "maisaka" else ""
|
||||
cache_key = f"{replyer_type}:{generator_type}:{stream_id}"
|
||||
if cache_key in self._repliers:
|
||||
logger.info(f"[ReplyerManager] 命中缓存 replyer: cache_key={cache_key}")
|
||||
@@ -51,29 +51,13 @@ class ReplyerManager:
|
||||
|
||||
try:
|
||||
if replyer_type == "maisaka":
|
||||
logger.info(f"[ReplyerManager] 选择 MaisakaReplyGenerator: generator_type={generator_type}")
|
||||
maisaka_replyer_class = get_maisaka_replyer_class()
|
||||
|
||||
replyer = maisaka_replyer_class(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
elif target_stream.is_group_session:
|
||||
logger.info("[ReplyerManager] importing DefaultReplyer")
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
|
||||
replyer = DefaultReplyer(
|
||||
replyer = MaisakaReplyGenerator(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
else:
|
||||
logger.info("[ReplyerManager] importing PrivateReplyer")
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
|
||||
replyer = PrivateReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
logger.warning(f"[ReplyerManager] 不支持的 replyer_type={replyer_type}")
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception(f"[ReplyerManager] 创建 replyer 失败: cache_key={cache_key}")
|
||||
raise
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("common_utils")
|
||||
|
||||
@@ -10,23 +10,14 @@ class TempMethodsExpression:
|
||||
"""用于临时存放一些方法的类"""
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
"""
|
||||
根据聊天流ID获取表达配置
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||
"""
|
||||
def _find_expression_config_item(chat_stream_id: Optional[str] = None):
|
||||
if not global_config.expression.learning_list:
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
if chat_stream_id:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue # 这是全局的
|
||||
continue
|
||||
stream_id = TempMethodsExpression._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
@@ -34,14 +25,44 @@ class TempMethodsExpression:
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id == chat_stream_id:
|
||||
if stream_id != chat_stream_id:
|
||||
continue
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_expression_advanced_chosen_for_chat(chat_stream_id: Optional[str] = None) -> bool:
|
||||
"""根据聊天流 ID 获取表达方式是否启用二次选择。"""
|
||||
config_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
|
||||
if config_item is None:
|
||||
return False
|
||||
return config_item.advanced_chosen
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
"""
|
||||
根据聊天流 ID 获取表达配置。
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流 ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用 jargon 学习)
|
||||
"""
|
||||
config_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
|
||||
if config_item is None:
|
||||
return True, True, True
|
||||
|
||||
return (
|
||||
config_item.use_expression,
|
||||
config_item.enable_learning,
|
||||
config_item.enable_jargon_learning,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(
|
||||
@@ -50,15 +71,15 @@ class TempMethodsExpression:
|
||||
is_group: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
根据平台、ID字符串和是否为群聊生成聊天流ID
|
||||
根据平台、ID 字符串和是否为群聊生成聊天流 ID。
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
id_str: 用户或群组的原始ID字符串
|
||||
id_str: 用户或群组的原始 ID 字符串
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
str: 生成的聊天流ID(哈希值)
|
||||
str: 生成的聊天流 ID(哈希值)
|
||||
"""
|
||||
try:
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
@@ -68,5 +89,5 @@ class TempMethodsExpression:
|
||||
else:
|
||||
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流ID失败: {e}")
|
||||
logger.error(f"生成聊天流 ID 失败: {e}")
|
||||
return None
|
||||
|
||||
@@ -20,7 +20,7 @@ from src.services.embedding_service import EmbeddingServiceClient
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
from src.common.data_models.chat_target_info_data_model import ChatTargetInfo
|
||||
|
||||
logger = get_logger("chat_utils")
|
||||
_warned_unconfigured_platforms: set[str] = set()
|
||||
@@ -699,7 +699,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
|
||||
|
||||
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetPersonInfo"]]:
|
||||
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["ChatTargetInfo"]]:
|
||||
"""
|
||||
获取聊天类型(是否群聊)和私聊对象信息。
|
||||
|
||||
@@ -734,13 +734,13 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
):
|
||||
user_nickname = chat_stream.context.message.message_info.user_info.user_nickname
|
||||
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题
|
||||
from src.common.data_models.chat_target_info_data_model import ChatTargetInfo # 解决循环导入问题
|
||||
|
||||
# Initialize target_info with basic info
|
||||
target_info = TargetPersonInfo(
|
||||
target_info = ChatTargetInfo(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname, # type: ignore
|
||||
session_nickname=user_nickname or "",
|
||||
person_id=None,
|
||||
person_name=None,
|
||||
)
|
||||
@@ -752,6 +752,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
logger.warning(f"用户 {user_nickname} 尚未认识")
|
||||
# 如果用户尚未认识,则返回False和None
|
||||
return False, None
|
||||
target_info.is_known = True
|
||||
if person.person_id:
|
||||
target_info.person_id = person.person_id
|
||||
target_info.person_name = person.person_name
|
||||
|
||||
@@ -24,125 +24,148 @@ logger = get_logger("emoji")
|
||||
class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
|
||||
if not full_path:
|
||||
# 创建时候即检测文件路径合法性
|
||||
raise ValueError("表情包路径不能为空")
|
||||
raise ValueError("图片路径不能为空")
|
||||
if Path(full_path).is_dir() or not Path(full_path).exists():
|
||||
raise FileNotFoundError(f"表情包路径无效: {full_path}")
|
||||
raise FileNotFoundError(f"图片路径无效: {full_path}")
|
||||
|
||||
resolved_path = Path(full_path).absolute().resolve()
|
||||
self.full_path: Path
|
||||
self.dir_path: Path
|
||||
self.file_name: str
|
||||
self._set_full_path(resolved_path)
|
||||
|
||||
self.file_hash: str = None # type: ignore
|
||||
|
||||
self.image_bytes: Optional[bytes] = image_bytes
|
||||
|
||||
self.image_format: str = "" # 图片格式
|
||||
self.image_format: str = ""
|
||||
|
||||
def _set_full_path(self, full_path: Path) -> None:
|
||||
"""同步更新文件路径相关的运行时元数据。"""
|
||||
"""同步刷新路径、目录和文件名等运行时元数据。"""
|
||||
resolved_path = full_path.absolute().resolve()
|
||||
self.full_path = resolved_path
|
||||
self.dir_path = resolved_path.parent.resolve()
|
||||
self.file_name = resolved_path.name
|
||||
|
||||
def _restore_image_format_from_path(self) -> None:
|
||||
"""根据文件扩展名恢复基础图片格式信息。"""
|
||||
"""根据文件扩展名恢复图片格式信息。"""
|
||||
self.image_format = self.full_path.suffix.removeprefix(".").lower()
|
||||
|
||||
def _build_non_conflicting_path(self, target_path: Path) -> Path:
|
||||
"""在目标路径被占用时,生成一个可用的新路径。"""
|
||||
candidate_path = target_path
|
||||
index = 1
|
||||
while candidate_path.exists():
|
||||
candidate_path = target_path.with_name(
|
||||
f"{target_path.stem}_{self.file_hash[:8]}_{index}{target_path.suffix}"
|
||||
)
|
||||
index += 1
|
||||
return candidate_path
|
||||
|
||||
def _rename_file_to_match_format(self) -> None:
|
||||
"""修正文件扩展名,并处理目标文件已存在的冲突。"""
|
||||
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self.image_format])
|
||||
new_full_path = self.dir_path / new_file_name
|
||||
if new_full_path == self.full_path:
|
||||
return
|
||||
|
||||
if new_full_path.exists():
|
||||
existing_file_hash = hashlib.sha256(self.read_image_bytes(new_full_path)).hexdigest()
|
||||
if existing_file_hash == self.file_hash:
|
||||
logger.info(f"[初始化] {new_full_path.name} 已存在且内容一致,复用已有文件")
|
||||
self.full_path.unlink()
|
||||
self._set_full_path(new_full_path)
|
||||
return
|
||||
|
||||
conflict_free_path = self._build_non_conflicting_path(new_full_path)
|
||||
logger.warning(
|
||||
f"[初始化] {new_full_path.name} 已存在且内容不同,改为保存到 {conflict_free_path.name}"
|
||||
)
|
||||
self.full_path.rename(conflict_free_path)
|
||||
self._set_full_path(conflict_free_path)
|
||||
return
|
||||
|
||||
self.full_path.rename(new_full_path)
|
||||
self._set_full_path(new_full_path)
|
||||
|
||||
def read_image_bytes(self, path: Path) -> bytes:
|
||||
"""
|
||||
同步读取图片文件的字节内容
|
||||
同步读取图片文件的字节内容。
|
||||
|
||||
Args:
|
||||
path (Path): 图片文件的完整路径
|
||||
path: 图片文件的完整路径。
|
||||
|
||||
Returns:
|
||||
return (bytes): 图片文件的字节内容
|
||||
Raises:
|
||||
FileNotFoundError: 如果文件不存在则抛出该异常
|
||||
Exception: 其他读取文件时发生的异常
|
||||
图片文件的字节内容。
|
||||
"""
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError as e:
|
||||
with open(path, "rb") as file:
|
||||
return file.read()
|
||||
except FileNotFoundError as exc:
|
||||
logger.error(f"[读取图片文件] 文件未找到: {path}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"[读取图片文件] 读取文件时发生错误: {e}")
|
||||
raise e
|
||||
raise exc
|
||||
except Exception as exc:
|
||||
logger.error(f"[读取图片文件] 读取文件时发生错误: {exc}")
|
||||
raise exc
|
||||
|
||||
def get_image_format(self, image_bytes: bytes) -> str:
|
||||
"""
|
||||
获取图片的格式
|
||||
获取图片的实际格式。
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): 图片的字节内容
|
||||
image_bytes: 图片的字节内容。
|
||||
|
||||
Returns:
|
||||
return (str): 图片的格式(小写)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果无法识别图片格式
|
||||
Exception: 其他读取图片格式时发生的异常
|
||||
小写格式名,例如 `png`、`jpeg`。
|
||||
"""
|
||||
try:
|
||||
with PILImage.open(io.BytesIO(image_bytes)) as img:
|
||||
if not img.format:
|
||||
raise ValueError("无法识别图片格式")
|
||||
return img.format.lower()
|
||||
except Exception as e:
|
||||
logger.error(f"[获取图片格式] 读取图片格式时发生错误: {e}")
|
||||
raise e
|
||||
except Exception as exc:
|
||||
logger.error(f"[获取图片格式] 读取图片格式时发生错误: {exc}")
|
||||
raise exc
|
||||
|
||||
async def calculate_hash_format(self) -> bool:
|
||||
"""
|
||||
异步计算表情包的哈希值和格式,初始化后应该执行此方法来确保对象的哈希值和格式正确
|
||||
计算图片哈希和实际格式,并在需要时修正扩展名。
|
||||
|
||||
Returns:
|
||||
return (bool): 如果成功计算哈希值和格式则返回True,否则返回False
|
||||
成功返回 `True`,失败返回 `False`。
|
||||
"""
|
||||
try:
|
||||
# 计算哈希值
|
||||
logger.debug(f"[初始化] 计算 {self.file_name} 的哈希值...")
|
||||
if not self.image_bytes:
|
||||
if self.image_bytes is None:
|
||||
logger.debug(f"[初始化] 正在读取文件: {self.full_path}")
|
||||
image_bytes = await asyncio.to_thread(self.read_image_bytes, self.full_path)
|
||||
else:
|
||||
image_bytes = self.image_bytes
|
||||
|
||||
self.image_bytes = image_bytes
|
||||
self.file_hash = hashlib.sha256(image_bytes).hexdigest()
|
||||
logger.debug(f"[初始化] {self.file_name} 计算哈希值成功: {self.file_hash}")
|
||||
|
||||
# 用PIL读取图片格式
|
||||
logger.debug(f"[初始化] 读取 {self.file_name} 的图片格式...")
|
||||
self.image_format = await asyncio.to_thread(self.get_image_format, image_bytes)
|
||||
logger.debug(f"[初始化] {self.file_name} 读取图片格式成功: {self.image_format}")
|
||||
|
||||
# 比对文件扩展名和实际格式
|
||||
file_ext = self.file_name.split(".")[-1].lower()
|
||||
if file_ext != self.image_format:
|
||||
logger.warning(
|
||||
f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self.image_format}`"
|
||||
)
|
||||
# 重命名文件以匹配实际格式
|
||||
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self.image_format])
|
||||
new_full_path = self.dir_path / new_file_name
|
||||
self.full_path.rename(new_full_path)
|
||||
self._set_full_path(new_full_path)
|
||||
self._rename_file_to_match_format()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[初始化] 初始化图片时发生错误: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"[初始化] 初始化图片时发生错误: {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
class MaiEmoji(BaseImageDataModel):
|
||||
"""麦麦的表情包对象,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象"""
|
||||
"""表情包数据模型。"""
|
||||
|
||||
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
|
||||
# self.embedding = []
|
||||
self.description: str = ""
|
||||
self.emotion: List[str] = []
|
||||
self.query_count = 0
|
||||
@@ -152,33 +175,26 @@ class MaiEmoji(BaseImageDataModel):
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: Images):
|
||||
"""从数据库记录创建 MaiEmoji 对象,如果记录标记为文件不存在则**抛出异常**
|
||||
|
||||
调用者应该对数据库记录进行检查,如果 `no_file_flag` 为 True 则不应该调用此方法
|
||||
|
||||
Args:
|
||||
db_record (Images): 数据库中的图片记录
|
||||
Returns:
|
||||
return (MaiEmoji): 包含图片信息的 MaiEmoji 对象
|
||||
Raises:
|
||||
ValueError: 如果数据库记录标记为文件不存在则抛出该异常
|
||||
"""
|
||||
"""从数据库记录构建 `MaiEmoji` 对象。"""
|
||||
if db_record.no_file_flag:
|
||||
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
|
||||
|
||||
obj = cls(db_record.full_path)
|
||||
obj.file_hash = db_record.image_hash
|
||||
obj._restore_image_format_from_path()
|
||||
|
||||
description = db_record.description or ""
|
||||
obj.description = description
|
||||
normalized_tags = [
|
||||
str(item).strip()
|
||||
for item in str(description).replace(",", ",").replace("、", ",").replace(";", ",").split(",")
|
||||
for item in str(description).replace(",", ",").replace("。", ",").replace("、", ",").split(",")
|
||||
if str(item).strip()
|
||||
]
|
||||
deduped_tags: List[str] = []
|
||||
for item in normalized_tags:
|
||||
if item not in deduped_tags:
|
||||
deduped_tags.append(item)
|
||||
|
||||
obj.emotion = deduped_tags
|
||||
obj.query_count = db_record.query_count
|
||||
obj.last_used_time = db_record.last_used_time
|
||||
@@ -198,7 +214,7 @@ class MaiEmoji(BaseImageDataModel):
|
||||
|
||||
|
||||
class MaiImage(BaseImageDataModel):
|
||||
"""麦麦图片数据模型,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象"""
|
||||
"""普通图片数据模型。"""
|
||||
|
||||
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
|
||||
self.description: str = ""
|
||||
@@ -207,19 +223,10 @@ class MaiImage(BaseImageDataModel):
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: Images):
|
||||
"""从数据库记录创建 MaiImage 对象,如果记录标记为文件不存在则**抛出异常**
|
||||
|
||||
调用者应该对数据库记录进行检查,如果 `no_file_flag` 为 True 则不应该调用此方法
|
||||
|
||||
Args:
|
||||
db_record (Images): 数据库中的图片记录
|
||||
Returns:
|
||||
return (MaiImage): 包含图片信息的 MaiImage 对象
|
||||
Raises:
|
||||
ValueError: 如果数据库记录标记为文件不存在则抛出该异常
|
||||
"""
|
||||
"""从数据库记录构建 `MaiImage` 对象。"""
|
||||
if db_record.no_file_flag:
|
||||
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiImage 对象")
|
||||
|
||||
obj = cls(db_record.full_path)
|
||||
obj.file_hash = db_record.image_hash
|
||||
obj._set_full_path(Path(db_record.full_path))
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# from dataclasses import dataclass, field
|
||||
# from typing import Optional, Dict, TYPE_CHECKING
|
||||
# from . import BaseDataModel
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from .database_data_model import DatabaseMessages
|
||||
# from src.core.types import ActionInfo
|
||||
|
||||
|
||||
# # @dataclass
|
||||
# # class TargetPersonInfo(BaseDataModel):
|
||||
# # platform: str = field(default_factory=str)
|
||||
# # user_id: str = field(default_factory=str)
|
||||
# # user_nickname: str = field(default_factory=str)
|
||||
# # person_id: Optional[str] = None
|
||||
# # person_name: Optional[str] = None
|
||||
# 已重构,见src/common/data_models/chat_target_info_data_model.py
|
||||
|
||||
# @dataclass
|
||||
# class ActionPlannerInfo(BaseDataModel):
|
||||
# action_type: str = field(default_factory=str)
|
||||
# reasoning: Optional[str] = None
|
||||
# action_data: Optional[Dict] = None
|
||||
# action_message: Optional["DatabaseMessages"] = None
|
||||
# available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||
# loop_start_time: Optional[float] = None
|
||||
# action_reasoning: Optional[str] = None
|
||||
# 已重构,见src/common/data_models/planned_action_data_models.py
|
||||
@@ -14,7 +14,6 @@ from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message
|
||||
|
||||
|
||||
@@ -24,7 +23,7 @@ PromptMessage: TypeAlias = Dict[str, Any]
|
||||
PromptInput: TypeAlias = str | List[PromptMessage]
|
||||
"""统一的提示输入类型。"""
|
||||
|
||||
MessageFactory: TypeAlias = Callable[["BaseClient"], List["Message"]]
|
||||
MessageFactory: TypeAlias = Callable[..., List["Message"]]
|
||||
"""统一的消息工厂类型。"""
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
from src.common.data_models.llm_service_data_models import PromptMessage
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
|
||||
@@ -121,6 +122,10 @@ class ReplyGenerationResult(BaseDataModel):
|
||||
default=None,
|
||||
metadata={"description": "供监控层直接消费的通用 tool 展示详情。"},
|
||||
)
|
||||
request_messages: List["PromptMessage"] = field(
|
||||
default_factory=list,
|
||||
metadata={"description": "本次 replyer 实际发送给模型的消息列表。"},
|
||||
)
|
||||
|
||||
|
||||
def build_reply_monitor_detail(result: ReplyGenerationResult) -> Dict[str, Any]:
|
||||
@@ -133,6 +138,8 @@ def build_reply_monitor_detail(result: ReplyGenerationResult) -> Dict[str, Any]:
|
||||
|
||||
if prompt_text:
|
||||
detail["prompt_text"] = prompt_text
|
||||
if result.request_messages:
|
||||
detail["request_messages"] = result.request_messages
|
||||
if reasoning_text:
|
||||
detail["reasoning_text"] = reasoning_text
|
||||
if output_text:
|
||||
|
||||
@@ -84,9 +84,9 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode
|
||||
|
||||
def calculate_typing_time(
|
||||
input_string: str,
|
||||
chinese_time: float = 0.3,
|
||||
english_time: float = 0.15,
|
||||
line_break_time: float = 0.1,
|
||||
chinese_time: float = 0.2,
|
||||
english_time: float = 0.1,
|
||||
line_break_time: float = 0.05,
|
||||
is_emoji: bool = False,
|
||||
) -> float:
|
||||
"""
|
||||
|
||||
@@ -10,24 +10,14 @@ logger = get_logger("config_utils")
|
||||
|
||||
class ExpressionConfigUtils:
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(session_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据聊天会话ID获取表达配置
|
||||
|
||||
Args:
|
||||
session_id: 聊天会话ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||
"""
|
||||
def _find_expression_config_item(session_id: Optional[str] = None):
|
||||
if not global_config.expression.learning_list:
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
if session_id:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue # 这是全局的
|
||||
continue
|
||||
stream_id = ExpressionConfigUtils._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
@@ -35,28 +25,59 @@ class ExpressionConfigUtils:
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id == session_id:
|
||||
if stream_id != session_id:
|
||||
continue
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_expression_advanced_chosen_for_chat(session_id: Optional[str] = None) -> bool:
|
||||
"""根据聊天会话 ID 获取表达方式是否启用二次选择。"""
|
||||
config_item = ExpressionConfigUtils._find_expression_config_item(session_id)
|
||||
if config_item is None:
|
||||
return False
|
||||
return config_item.advanced_chosen
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(session_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据聊天会话 ID 获取表达配置。
|
||||
|
||||
Args:
|
||||
session_id: 聊天会话 ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用 jargon 学习)
|
||||
"""
|
||||
config_item = ExpressionConfigUtils._find_expression_config_item(session_id)
|
||||
if config_item is None:
|
||||
return True, True, True
|
||||
|
||||
return (
|
||||
config_item.use_expression,
|
||||
config_item.enable_learning,
|
||||
config_item.enable_jargon_learning,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(platform: str, id_str: str, is_group: bool = False) -> Optional[str]:
|
||||
# sourcery skip: remove-unnecessary-cast
|
||||
"""
|
||||
根据平台、ID字符串和是否为群聊生成聊天流ID
|
||||
根据平台、ID 字符串和是否为群聊生成聊天流 ID。
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
id_str: 用户或群组的原始ID字符串
|
||||
id_str: 用户或群组的原始 ID 字符串
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
str: 生成的聊天流ID(哈希值)
|
||||
str: 生成的聊天流 ID(哈希值)
|
||||
"""
|
||||
try:
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
@@ -66,7 +87,7 @@ class ExpressionConfigUtils:
|
||||
else:
|
||||
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流ID失败: {e}")
|
||||
logger.error(f"生成聊天流 ID 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -91,7 +112,7 @@ class ChatConfigUtils:
|
||||
else:
|
||||
rule_session_id = SessionUtils.calculate_session_id(rule.platform, user_id=str(rule.item_id))
|
||||
if rule_session_id != session_id:
|
||||
continue # 不匹配的会话ID,跳过
|
||||
continue # 不匹配的会话 ID,跳过
|
||||
parsed_range = ChatConfigUtils.parse_range(rule.time)
|
||||
if not parsed_range:
|
||||
continue # 无法解析的时间范围,跳过
|
||||
@@ -102,7 +123,7 @@ class ChatConfigUtils:
|
||||
else: # 跨天的时间范围
|
||||
in_range = now_min >= start_min or now_min <= end_min
|
||||
if in_range:
|
||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回0.0
|
||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回 0.0
|
||||
|
||||
# 没有匹配到会话相关的规则,继续匹配全局规则
|
||||
for rule in global_config.chat.talk_value_rules:
|
||||
@@ -118,7 +139,7 @@ class ChatConfigUtils:
|
||||
else: # 跨天的时间范围
|
||||
in_range = now_min >= start_min or now_min <= end_min
|
||||
if in_range:
|
||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回0.0
|
||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回 0.0
|
||||
return result # 如果没有任何规则生效,返回默认值
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -23,7 +23,6 @@ from .official_configs import (
|
||||
EmojiConfig,
|
||||
ExpressionConfig,
|
||||
KeywordReactionConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MaiSakaConfig,
|
||||
MaimMessageConfig,
|
||||
MCPConfig,
|
||||
@@ -54,9 +53,10 @@ PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.5.2"
|
||||
MODEL_CONFIG_VERSION: str = "1.13.1"
|
||||
CONFIG_VERSION: str = "8.7.1"
|
||||
MODEL_CONFIG_VERSION: str = "1.14.0"
|
||||
|
||||
logger = get_logger("config")
|
||||
|
||||
@@ -454,6 +454,20 @@ def generate_new_config_file(config_class: type[T], config_path: Path, inner_con
|
||||
write_config_to_file(config, config_path, inner_config_version)
|
||||
|
||||
|
||||
def remove_legacy_env_file(env_path: Path) -> None:
|
||||
"""删除已完成迁移的旧版 `.env` 文件。"""
|
||||
|
||||
if not env_path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
env_path.unlink()
|
||||
except OSError as exc:
|
||||
logger.warning(f"旧版 .env 配置文件删除失败,请手动删除: {env_path},原因: {exc}")
|
||||
else:
|
||||
logger.warning(f"检测到旧版环境变量绑定配置迁移成功,已删除旧版 .env 文件: {env_path}")
|
||||
|
||||
|
||||
def load_config_from_file(
|
||||
config_class: type[T], config_path: Path, new_ver: str, override_repr: bool = False
|
||||
) -> tuple[T, bool]:
|
||||
@@ -467,10 +481,12 @@ def load_config_from_file(
|
||||
if not isinstance(inner_version, str):
|
||||
raise TypeError(t("config.invalid_inner_version"))
|
||||
old_ver: str = inner_version
|
||||
env_migration_applied: bool = False
|
||||
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
|
||||
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
|
||||
if config_path.name == "bot_config.toml" and config_class.__name__ == "Config":
|
||||
env_migration = migrate_legacy_bind_env_to_bot_config_dict(config_data)
|
||||
env_migration_applied = env_migration.migrated
|
||||
if env_migration.migrated:
|
||||
logger.warning(f"检测到旧版环境变量绑定配置,已迁移到主配置: {env_migration.reason}")
|
||||
config_data = env_migration.data
|
||||
@@ -497,9 +513,11 @@ def load_config_from_file(
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
if compare_versions(old_ver, new_ver):
|
||||
if compare_versions(old_ver, new_ver) or env_migration_applied:
|
||||
output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
|
||||
write_config_to_file(target_config, config_path, new_ver, override_repr)
|
||||
if env_migration_applied:
|
||||
remove_legacy_env_file(LEGACY_ENV_PATH)
|
||||
updated = True
|
||||
return target_config, updated
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
"""
|
||||
legacy_migration.py
|
||||
|
||||
一个“可随时拔掉”的旧配置兼容层:
|
||||
- 仅在配置解析失败时尝试修复旧格式数据(7.x -> 8.x 这一类结构性变更)
|
||||
- 不依赖 Pydantic / ConfigBase,仅对 dict 做最小转换
|
||||
- 成功则返回(修复后的 dict, True),失败则返回(原 dict, False)
|
||||
|
||||
设计目标:与现有 config 加载逻辑的接触点尽可能小,未来不需要时可一键移除。
|
||||
旧配置兼容层。
|
||||
仅保留当前仍需要的“解析前结构修复”,避免老配置在 `from_dict` 前直接失败。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -16,12 +12,7 @@ from typing import Any, Optional
|
||||
|
||||
import os
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("legacy_migration")
|
||||
|
||||
|
||||
# 方便未来快速关闭/移除
|
||||
ENABLE_LEGACY_MIGRATION: bool = True
|
||||
|
||||
|
||||
@@ -43,6 +34,7 @@ def _as_list(x: Any) -> Optional[list[Any]]:
|
||||
def _parse_host_env(value: Any) -> Optional[str]:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
|
||||
normalized_value = value.strip()
|
||||
return normalized_value or None
|
||||
|
||||
@@ -75,116 +67,73 @@ def _migrate_env_value(section: dict[str, Any], key: str, parsed_env_value: Any,
|
||||
return True
|
||||
|
||||
|
||||
def _move_section_key(source: dict[str, Any], target: dict[str, Any], key: str) -> bool:
|
||||
"""将配置项从旧分组移动到新分组,若新分组已有值则保留新值。"""
|
||||
|
||||
if key not in source:
|
||||
return False
|
||||
|
||||
if key not in target:
|
||||
target[key] = source[key]
|
||||
source.pop(key, None)
|
||||
return True
|
||||
|
||||
|
||||
def _parse_triplet_target(s: str) -> Optional[dict[str, str]]:
|
||||
"""
|
||||
解析 "platform:id:type" -> {platform,item_id,rule_type}
|
||||
返回 None 表示无法解析。
|
||||
解析 "platform:id:type" -> {platform, item_id, rule_type}
|
||||
"""
|
||||
if not isinstance(s, str):
|
||||
return None
|
||||
|
||||
parts = s.split(":", 2)
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
platform, item_id, rule_type = parts
|
||||
if rule_type not in ("group", "private"):
|
||||
return None
|
||||
return {"platform": platform, "item_id": item_id, "rule_type": rule_type}
|
||||
|
||||
|
||||
def _parse_quad_prompt(s: str) -> Optional[dict[str, str]]:
|
||||
"""
|
||||
解析 "platform:id:type:prompt" -> {platform,item_id,rule_type,prompt}
|
||||
prompt 允许包含冒号,因此只切前三个冒号。
|
||||
"""
|
||||
if not isinstance(s, str):
|
||||
return None
|
||||
parts = s.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
return None
|
||||
platform, item_id, rule_type, prompt = parts
|
||||
if rule_type not in ("group", "private"):
|
||||
return None
|
||||
if not prompt:
|
||||
return None
|
||||
return {"platform": platform, "item_id": item_id, "rule_type": rule_type, "prompt": prompt}
|
||||
|
||||
|
||||
def _parse_enable_disable(v: Any) -> Optional[bool]:
|
||||
"""
|
||||
兼容旧值 "enable"/"disable" 以及 bool。
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
|
||||
if isinstance(v, str):
|
||||
vv = v.strip().lower()
|
||||
if vv == "enable":
|
||||
normalized_value = v.strip().lower()
|
||||
if normalized_value == "enable":
|
||||
return True
|
||||
if vv == "disable":
|
||||
if normalized_value == "disable":
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
|
||||
"""
|
||||
旧:
|
||||
learning_list = [
|
||||
["", "enable", "enable", "enable"],
|
||||
["qq:1919810:group", "enable", "enable", "enable"],
|
||||
]
|
||||
兼容旧旧格式:
|
||||
learning_list = [
|
||||
["qq:1919810:group", "enable", "enable", "0.5"],
|
||||
["", "disable", "disable", "0.1"],
|
||||
]
|
||||
新:
|
||||
[[expression.learning_list]]
|
||||
platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true
|
||||
将旧版 expression.learning_list 转成当前结构。
|
||||
"""
|
||||
ll = _as_list(expr.get("learning_list"))
|
||||
if ll is None:
|
||||
learning_list = _as_list(expr.get("learning_list"))
|
||||
if learning_list is None:
|
||||
return False
|
||||
|
||||
# 如果已经是新格式(列表里是 dict),跳过
|
||||
if ll and all(isinstance(i, dict) for i in ll):
|
||||
if learning_list and all(isinstance(item, dict) for item in learning_list):
|
||||
return False
|
||||
|
||||
migrated_items: list[dict[str, Any]] = []
|
||||
for row in ll:
|
||||
r = _as_list(row)
|
||||
if r is None or len(r) < 4:
|
||||
# 行结构不对,无法安全迁移
|
||||
for row in learning_list:
|
||||
row_items = _as_list(row)
|
||||
if row_items is None or len(row_items) < 4:
|
||||
return False
|
||||
|
||||
target_raw = r[0]
|
||||
use_expression = _parse_enable_disable(r[1])
|
||||
enable_learning = _parse_enable_disable(r[2])
|
||||
enable_jargon_learning = _parse_enable_disable(r[3])
|
||||
target_raw = row_items[0]
|
||||
use_expression = _parse_enable_disable(row_items[1])
|
||||
enable_learning = _parse_enable_disable(row_items[2])
|
||||
enable_jargon_learning = _parse_enable_disable(row_items[3])
|
||||
|
||||
if enable_jargon_learning is None:
|
||||
# 更早期的配置在第 4 列记录的是一个已废弃的数值权重/阈值,
|
||||
# 当前 schema 已没有对应字段。这里按保守策略兼容迁移:
|
||||
# 丢弃旧数值,并将 enable_jargon_learning 置为 False。
|
||||
# 更早期版本第 4 列是已废弃的数值阈值,这里仅做保守兼容。
|
||||
try:
|
||||
float(str(r[3]))
|
||||
float(str(row_items[3]))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
enable_jargon_learning = False
|
||||
|
||||
if use_expression is None or enable_learning is None or enable_jargon_learning is None:
|
||||
return False
|
||||
|
||||
# 旧格式中 target 允许为空字符串:表示全局;新结构必须有三元组字段
|
||||
if target_raw == "" or target_raw is None:
|
||||
target = {"platform": "", "item_id": "", "rule_type": "group"}
|
||||
else:
|
||||
@@ -209,99 +158,56 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
|
||||
|
||||
def _migrate_expression_groups(expr: dict[str, Any]) -> bool:
|
||||
"""
|
||||
旧:
|
||||
expression_groups = [
|
||||
["qq:1:group","qq:2:group"],
|
||||
["qq:3:group"],
|
||||
]
|
||||
新:
|
||||
expression_groups = [
|
||||
{ expression_groups = [ {platform="qq", item_id="1", rule_type="group"}, ... ] },
|
||||
{ expression_groups = [ ... ] },
|
||||
]
|
||||
将旧版 expression.expression_groups 转成当前结构。
|
||||
"""
|
||||
eg = _as_list(expr.get("expression_groups"))
|
||||
if eg is None:
|
||||
expression_groups = _as_list(expr.get("expression_groups"))
|
||||
if expression_groups is None:
|
||||
return False
|
||||
if expression_groups and all(isinstance(item, dict) for item in expression_groups):
|
||||
return False
|
||||
|
||||
# 已经是新格式(列表里是 dict 且包含 expression_groups),跳过
|
||||
if eg and all(isinstance(i, dict) for i in eg):
|
||||
return False
|
||||
|
||||
migrated: list[dict[str, Any]] = []
|
||||
for group in eg:
|
||||
g = _as_list(group)
|
||||
if g is None:
|
||||
migrated_groups: list[dict[str, Any]] = []
|
||||
for group in expression_groups:
|
||||
group_items = _as_list(group)
|
||||
if group_items is None:
|
||||
return False
|
||||
|
||||
targets: list[dict[str, str]] = []
|
||||
for item in g:
|
||||
for item in group_items:
|
||||
parsed = _parse_triplet_target(str(item))
|
||||
if parsed is None:
|
||||
return False
|
||||
targets.append(parsed)
|
||||
migrated.append({"expression_groups": targets})
|
||||
|
||||
expr["expression_groups"] = migrated
|
||||
migrated_groups.append({"expression_groups": targets})
|
||||
|
||||
expr["expression_groups"] = migrated_groups
|
||||
return True
|
||||
|
||||
|
||||
def _migrate_target_item_list(parent: dict[str, Any], key: str) -> bool:
|
||||
"""
|
||||
将 list[str] 的 "platform:id:type" 迁移为 list[{platform,item_id,rule_type}]
|
||||
用于:memory.global_memory_blacklist 等。
|
||||
将 list[str] 的 "platform:id:type" 迁移为 list[TargetItem]。
|
||||
"""
|
||||
raw = _as_list(parent.get(key))
|
||||
if raw is None:
|
||||
if raw is None or not raw:
|
||||
return False
|
||||
if raw and all(isinstance(i, dict) for i in raw):
|
||||
if all(isinstance(item, dict) for item in raw):
|
||||
return False
|
||||
|
||||
targets: list[dict[str, str]] = []
|
||||
for item in raw:
|
||||
parsed = _parse_triplet_target(str(item))
|
||||
if parsed is None:
|
||||
return False
|
||||
targets.append(parsed)
|
||||
|
||||
parent[key] = targets
|
||||
return True
|
||||
|
||||
|
||||
def _migrate_extra_prompt_list(exp: dict[str, Any], key: str) -> bool:
|
||||
"""
|
||||
将 list[str] 的 "platform:id:type:prompt" 迁移为 list[{platform,item_id,rule_type,prompt}]
|
||||
用于:experimental.chat_prompts
|
||||
"""
|
||||
raw = _as_list(exp.get(key))
|
||||
if raw is None:
|
||||
return False
|
||||
if raw and all(isinstance(i, dict) for i in raw):
|
||||
return False
|
||||
items: list[dict[str, str]] = []
|
||||
for item in raw:
|
||||
parsed = _parse_quad_prompt(str(item))
|
||||
if parsed is None:
|
||||
return False
|
||||
items.append(parsed)
|
||||
exp[key] = items
|
||||
return True
|
||||
|
||||
|
||||
def _parse_multimodal_replyer(v: Any) -> Optional[bool]:
|
||||
"""兼容旧 replyer_generator_type 到布尔开关的迁移。"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if not isinstance(v, str):
|
||||
return None
|
||||
|
||||
normalized_value = v.strip().lower()
|
||||
if normalized_value == "multimodal":
|
||||
return True
|
||||
if normalized_value == "legacy":
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
"""将旧版环境变量中的绑定地址迁移到主配置结构。"""
|
||||
"""将旧版 `.env` 中的绑定地址迁移到主配置结构。"""
|
||||
|
||||
migrated_any = False
|
||||
reasons: list[str] = []
|
||||
@@ -339,8 +245,7 @@ def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> Migratio
|
||||
|
||||
def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
"""
|
||||
尝试对“总配置 bot_config.toml”的 dict(已 unwrap)进行旧格式修复。
|
||||
仅做我们明确知道的结构性变更;其它字段不动。
|
||||
尝试修复 `bot_config.toml` 的少量旧结构,仅保留当前仍需要的兼容逻辑。
|
||||
"""
|
||||
if not ENABLE_LEGACY_MIGRATION:
|
||||
return MigrationResult(data=data, migrated=False, reason="disabled")
|
||||
@@ -353,41 +258,30 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
if _migrate_expression_learning_list(expr):
|
||||
migrated_any = True
|
||||
reasons.append("expression.learning_list")
|
||||
|
||||
if _migrate_expression_groups(expr):
|
||||
migrated_any = True
|
||||
reasons.append("expression.expression_groups")
|
||||
# allow_reflect: 旧 list[str] -> 新 list[TargetItem]
|
||||
|
||||
if _migrate_target_item_list(expr, "allow_reflect"):
|
||||
migrated_any = True
|
||||
reasons.append("expression.allow_reflect")
|
||||
# manual_reflect_operator_id: 旧 str -> 新 Optional[TargetItem]
|
||||
mroi = expr.get("manual_reflect_operator_id")
|
||||
if isinstance(mroi, str) and mroi.strip():
|
||||
parsed = _parse_triplet_target(mroi.strip())
|
||||
|
||||
manual_reflect_operator_id = expr.get("manual_reflect_operator_id")
|
||||
if isinstance(manual_reflect_operator_id, str) and manual_reflect_operator_id.strip():
|
||||
parsed = _parse_triplet_target(manual_reflect_operator_id.strip())
|
||||
if parsed is not None:
|
||||
expr["manual_reflect_operator_id"] = parsed
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id")
|
||||
|
||||
chat = _as_dict(data.get("chat"))
|
||||
if chat is None:
|
||||
chat = {}
|
||||
data["chat"] = chat
|
||||
elif "private_plan_style" in chat:
|
||||
chat.pop("private_plan_style", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.private_plan_style_removed")
|
||||
if isinstance(manual_reflect_operator_id, str) and not manual_reflect_operator_id.strip():
|
||||
expr.pop("manual_reflect_operator_id", None)
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id_empty")
|
||||
|
||||
personality = _as_dict(data.get("personality"))
|
||||
visual = _as_dict(data.get("visual"))
|
||||
if visual is None and (
|
||||
(personality is not None and "visual_style" in personality)
|
||||
or "multimodal_planner" in chat
|
||||
or "replyer_generator_type" in chat
|
||||
):
|
||||
visual = {}
|
||||
data["visual"] = visual
|
||||
|
||||
if visual is not None and personality is not None and "visual_style" in personality:
|
||||
if "visual_style" not in visual:
|
||||
visual["visual_style"] = personality["visual_style"]
|
||||
@@ -395,108 +289,19 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append("personality.visual_style_moved_to_visual.visual_style")
|
||||
|
||||
if visual is not None and "multimodal_planner" in chat:
|
||||
if "multimodal_planner" not in visual and isinstance(chat["multimodal_planner"], bool):
|
||||
visual["multimodal_planner"] = chat["multimodal_planner"]
|
||||
if "multimodal_planner" in visual:
|
||||
chat.pop("multimodal_planner", None)
|
||||
if visual is not None and "multimodal_planner" in visual and "planner_mode" not in visual:
|
||||
multimodal_planner = visual.pop("multimodal_planner")
|
||||
if isinstance(multimodal_planner, bool):
|
||||
visual["planner_mode"] = "multimodal" if multimodal_planner else "text"
|
||||
migrated_any = True
|
||||
reasons.append("chat.multimodal_planner_moved_to_visual.multimodal_planner")
|
||||
reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode")
|
||||
else:
|
||||
visual["multimodal_planner"] = multimodal_planner
|
||||
|
||||
if visual is not None and "replyer_generator_type" in chat:
|
||||
multimodal_replyer = _parse_multimodal_replyer(chat["replyer_generator_type"])
|
||||
if "multimodal_replyer" not in visual and multimodal_replyer is not None:
|
||||
visual["multimodal_replyer"] = multimodal_replyer
|
||||
if "multimodal_replyer" in visual:
|
||||
chat.pop("replyer_generator_type", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.replyer_generator_type_moved_to_visual.multimodal_replyer")
|
||||
|
||||
maisaka = _as_dict(data.get("maisaka"))
|
||||
mem = _as_dict(data.get("memory"))
|
||||
debug = _as_dict(data.get("debug"))
|
||||
if maisaka is not None:
|
||||
moved_memory_keys = ("enable_memory_query_tool", "memory_query_default_limit")
|
||||
if any(key in maisaka for key in moved_memory_keys) and mem is None:
|
||||
mem = {}
|
||||
data["memory"] = mem
|
||||
|
||||
if mem is not None:
|
||||
for moved_key in moved_memory_keys:
|
||||
if _move_section_key(maisaka, mem, moved_key):
|
||||
migrated_any = True
|
||||
reasons.append(f"maisaka.{moved_key}_moved_to_memory")
|
||||
|
||||
if mem is not None and "show_memory_prompt" in mem and debug is None:
|
||||
debug = {}
|
||||
data["debug"] = debug
|
||||
|
||||
if mem is not None:
|
||||
if _migrate_target_item_list(mem, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
reasons.append("memory.global_memory_blacklist")
|
||||
|
||||
if debug is not None and _move_section_key(mem, debug, "show_memory_prompt"):
|
||||
migrated_any = True
|
||||
reasons.append("memory.show_memory_prompt_moved_to_debug")
|
||||
|
||||
for removed_key in (
|
||||
"agent_timeout_seconds",
|
||||
"max_agent_iterations",
|
||||
):
|
||||
if removed_key in mem:
|
||||
mem.pop(removed_key, None)
|
||||
migrated_any = True
|
||||
reasons.append(f"memory.{removed_key}_removed")
|
||||
|
||||
relationship = _as_dict(data.get("relationship"))
|
||||
if relationship is not None:
|
||||
data.pop("relationship", None)
|
||||
memory = _as_dict(data.get("memory"))
|
||||
if memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
reasons.append("relationship_removed")
|
||||
|
||||
exp = _as_dict(data.get("experimental"))
|
||||
if exp is not None:
|
||||
if _migrate_extra_prompt_list(exp, "chat_prompts"):
|
||||
migrated_any = True
|
||||
reasons.append("experimental.chat_prompts")
|
||||
|
||||
if "private_plan_style" in exp:
|
||||
exp.pop("private_plan_style", None)
|
||||
migrated_any = True
|
||||
reasons.append("experimental.private_plan_style_removed")
|
||||
|
||||
for key in ("group_chat_prompt", "private_chat_prompts", "chat_prompts"):
|
||||
if key in exp and key not in chat:
|
||||
chat[key] = exp[key]
|
||||
migrated_any = True
|
||||
reasons.append(f"experimental.{key}_moved_to_chat")
|
||||
|
||||
data.pop("experimental", None)
|
||||
migrated_any = True
|
||||
reasons.append("experimental_removed")
|
||||
|
||||
if chat is not None and "think_mode" in chat:
|
||||
chat.pop("think_mode", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.think_mode_removed")
|
||||
|
||||
tool = _as_dict(data.get("tool"))
|
||||
if tool is not None:
|
||||
data.pop("tool", None)
|
||||
migrated_any = True
|
||||
reasons.append("tool_section_removed")
|
||||
|
||||
# ExpressionConfig 中的 manual_reflect_operator_id:
|
||||
# 旧版本可能是 ""(字符串),新版本期望 Optional[TargetItem]。
|
||||
# 空字符串视为未配置,转换为 None/删除键以避免校验错误。
|
||||
expr = _as_dict(data.get("expression"))
|
||||
if expr is not None:
|
||||
mroi = expr.get("manual_reflect_operator_id")
|
||||
if isinstance(mroi, str) and not mroi.strip():
|
||||
expr.pop("manual_reflect_operator_id", None)
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id_empty")
|
||||
reasons.append("memory.global_memory_blacklist")
|
||||
|
||||
reason = ",".join(reasons)
|
||||
return MigrationResult(data=data, migrated=migrated_any, reason=reason)
|
||||
|
||||
@@ -307,6 +307,15 @@ class ModelInfo(ConfigBase):
|
||||
)
|
||||
"""强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)"""
|
||||
|
||||
visual: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否为多模态模型。开启后表示该模型支持视觉输入。"""
|
||||
|
||||
extra_params: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
json_schema_extra={
|
||||
@@ -437,4 +446,4 @@ class ModelTaskConfig(ConfigBase):
|
||||
"x-icon": "database",
|
||||
},
|
||||
)
|
||||
"""嵌入模型配置"""
|
||||
"""嵌入模型配置"""
|
||||
|
||||
@@ -145,23 +145,23 @@ class VisualConfig(ConfigBase):
|
||||
__ui_label__ = "视觉"
|
||||
__ui_icon__ = "image"
|
||||
|
||||
multimodal_planner: bool = Field(
|
||||
default=True,
|
||||
planner_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="auto",
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否直接输入图片"""
|
||||
|
||||
multimodal_replyer: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""是否启用 Maisaka 多模态 replyer 生成器"""
|
||||
"""规划器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
|
||||
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="auto",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""回复器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
|
||||
visual_style: str = Field(
|
||||
default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本",
|
||||
@@ -239,16 +239,12 @@ class ChatConfig(ConfigBase):
|
||||
)
|
||||
"""Planner 连续被新消息打断的最大次数,0 表示不启用打断"""
|
||||
|
||||
plan_reply_log_max_per_chat: int = Field(
|
||||
default=1024,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "file-text",
|
||||
},
|
||||
)
|
||||
"""每个聊天流最大保存的Plan/Reply日志数量,超过此数量时会自动删除最老的日志"""
|
||||
group_chat_prompt: str = Field(
|
||||
default="你需要控制自己发言的频率,如果是一对一聊天,可以以较均匀的频率发言;如果用户较多,不要每句都回复,控制回复频率,不要回复的太频繁!控制回复的频率,不要每个人的消息都回复。",
|
||||
default="""
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片。
|
||||
回复尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
|
||||
不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。
|
||||
""",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "users",
|
||||
@@ -257,7 +253,11 @@ class ChatConfig(ConfigBase):
|
||||
"""_wrap_群聊通用注意事项"""
|
||||
|
||||
private_chat_prompts: str = Field(
|
||||
default="你需要控制自己发言的频率,可以以较均匀的频率发言。",
|
||||
default="""
|
||||
你正在聊天,下面是正在聊的内容,其中包含聊天记录和聊天中的图片。
|
||||
回复尽量简短一些。请注意把握聊天内容。
|
||||
请考虑对方的发言频率,想法,思考自己何时回复以及回复内容。
|
||||
""",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user",
|
||||
@@ -740,6 +740,15 @@ class LearningItem(ConfigBase):
|
||||
)
|
||||
"""是否启用jargon学习"""
|
||||
|
||||
advanced_chosen: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "sparkles",
|
||||
},
|
||||
)
|
||||
"""是否启用基于子代理的二次表达方式选择"""
|
||||
|
||||
|
||||
class ExpressionGroup(ConfigBase):
|
||||
"""表达互通组配置类,若列表为空代表全局共享"""
|
||||
@@ -769,6 +778,7 @@ class ExpressionConfig(ConfigBase):
|
||||
use_expression=True,
|
||||
enable_learning=True,
|
||||
enable_jargon_learning=True,
|
||||
advanced_chosen=False,
|
||||
)
|
||||
],
|
||||
json_schema_extra={
|
||||
@@ -1640,35 +1650,6 @@ class MaiSakaConfig(ConfigBase):
|
||||
)
|
||||
"""MaiSaka 使用的用户名称"""
|
||||
|
||||
tool_filter_task_name: str = Field(
|
||||
default="utils",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "sparkles",
|
||||
},
|
||||
)
|
||||
"""工具筛选预判使用的模型任务名"""
|
||||
|
||||
tool_filter_threshold: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "filter",
|
||||
},
|
||||
)
|
||||
"""当可用工具总数超过该阈值时,先进行一轮工具筛选"""
|
||||
|
||||
tool_filter_max_keep: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "list-filter",
|
||||
},
|
||||
)
|
||||
"""工具筛选阶段最多保留的非内置工具数量"""
|
||||
|
||||
show_image_path: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -181,10 +181,7 @@ class ToolSpec:
|
||||
str: 合并后的单段工具描述。
|
||||
"""
|
||||
|
||||
parts = [self.brief_description.strip()]
|
||||
if self.detailed_description.strip():
|
||||
parts.append(self.detailed_description.strip())
|
||||
return "\n\n".join(part for part in parts if part).strip()
|
||||
return self.brief_description.strip()
|
||||
|
||||
def to_llm_definition(self) -> ToolDefinitionInput:
|
||||
"""转换为统一的 LLM 工具定义。
|
||||
@@ -389,7 +386,24 @@ class ToolRegistry:
|
||||
for provider in self._providers:
|
||||
provider_specs = await provider.list_tools()
|
||||
if any(spec.name == invocation.tool_name and spec.enabled for spec in provider_specs):
|
||||
return await provider.invoke(invocation, context)
|
||||
try:
|
||||
return await provider.invoke(invocation, context)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"工具调用异常: tool=%s provider=%s",
|
||||
invocation.tool_name,
|
||||
getattr(provider, "provider_name", ""),
|
||||
)
|
||||
error_message = str(exc).strip()
|
||||
if error_message:
|
||||
error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}: {error_message}"
|
||||
else:
|
||||
error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}"
|
||||
return ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
success=False,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
表达方式自动检查定时任务
|
||||
表达方式自动检查定时任务。
|
||||
|
||||
功能:
|
||||
1. 定期随机选取指定数量的表达方式
|
||||
@@ -9,52 +9,48 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.learners.expression_review_store import get_review_state, set_review_state
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.learners.expression_review_store import get_review_state, set_review_state
|
||||
from src.learners.expression_utils import parse_evaluation_response
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("expression_auto_check_task")
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
"""
|
||||
创建评估提示词
|
||||
创建评估提示词。
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
situation: 情景
|
||||
style: 风格
|
||||
|
||||
Returns:
|
||||
评估提示词
|
||||
"""
|
||||
# 基础评估标准
|
||||
base_criteria = [
|
||||
"表达方式或言语风格 是否与使用条件或使用情景 匹配",
|
||||
"允许部分语法错误或口头化或缺省出现",
|
||||
"表达方式或言语风格是否与使用条件或使用情景匹配",
|
||||
"允许部分语法错误或口语化或缺省出现",
|
||||
"表达方式不能太过特指,需要具有泛用性",
|
||||
"一般不涉及具体的人名或名称",
|
||||
]
|
||||
|
||||
# 从配置中获取额外的自定义标准
|
||||
custom_criteria = global_config.expression.expression_auto_check_custom_criteria
|
||||
|
||||
# 合并所有评估标准
|
||||
all_criteria = base_criteria.copy()
|
||||
if custom_criteria:
|
||||
all_criteria.extend(custom_criteria)
|
||||
|
||||
# 构建评估标准列表字符串
|
||||
criteria_list = "\n".join([f"{i + 1}. {criterion}" for i, criterion in enumerate(all_criteria)])
|
||||
|
||||
prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
|
||||
@@ -64,14 +60,13 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
请从以下方面进行评估:
|
||||
{criteria_list}
|
||||
|
||||
请以JSON格式输出评估结果:
|
||||
请以 JSON 格式输出评估结果:
|
||||
{{
|
||||
"suitable": true/false,
|
||||
"reason": "评估理由(如果不合适,请说明原因)"
|
||||
|
||||
}}
|
||||
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
如果合适,suitable 设为 true;如果不合适,suitable 设为 false,并在 reason 中说明原因。
|
||||
请严格按照 JSON 格式输出,不要包含其他内容。"""
|
||||
|
||||
return prompt
|
||||
|
||||
@@ -81,10 +76,10 @@ judge_llm = LLMServiceClient(task_name="utils", request_type="expression_check")
|
||||
|
||||
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
|
||||
"""
|
||||
执行单次LLM评估
|
||||
执行单次 LLM 评估。
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
situation: 情景
|
||||
style: 风格
|
||||
|
||||
Returns:
|
||||
@@ -101,20 +96,10 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
|
||||
response = generation_result.response
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
evaluation = json.loads(response)
|
||||
except json.JSONDecodeError as e:
|
||||
import re
|
||||
evaluation = parse_evaluation_response(response)
|
||||
|
||||
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
|
||||
if json_match:
|
||||
evaluation = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
|
||||
|
||||
suitable = evaluation.get("suitable", False)
|
||||
reason = evaluation.get("reason", "未提供理由")
|
||||
suitable = bool(evaluation.get("suitable", False))
|
||||
reason = str(evaluation.get("reason", "未提供理由"))
|
||||
|
||||
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
|
||||
return suitable, reason, None
|
||||
@@ -125,20 +110,19 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
|
||||
|
||||
|
||||
class ExpressionAutoCheckTask(AsyncTask):
|
||||
"""表达方式自动检查定时任务"""
|
||||
"""表达方式自动检查定时任务。"""
|
||||
|
||||
def __init__(self):
|
||||
# 从配置中获取检查间隔和一次检查数量
|
||||
check_interval = global_config.expression.expression_auto_check_interval
|
||||
super().__init__(
|
||||
task_name="Expression Auto Check Task",
|
||||
wait_before_start=60, # 启动后等待60秒再开始第一次检查
|
||||
wait_before_start=60,
|
||||
run_interval=check_interval,
|
||||
)
|
||||
|
||||
async def _select_expressions(self, count: int) -> List[Expression]:
|
||||
"""
|
||||
随机选择指定数量的未检查表达方式
|
||||
随机选择指定数量的未检查表达方式。
|
||||
|
||||
Args:
|
||||
count: 需要选择的数量
|
||||
@@ -158,11 +142,12 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
logger.info("没有未检查的表达方式")
|
||||
return []
|
||||
|
||||
# 随机选择指定数量
|
||||
selected_count = min(count, len(unevaluated_expressions))
|
||||
selected = random.sample(unevaluated_expressions, selected_count)
|
||||
|
||||
logger.info(f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条")
|
||||
logger.info(
|
||||
f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条"
|
||||
)
|
||||
return selected
|
||||
|
||||
except Exception as e:
|
||||
@@ -171,35 +156,35 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
|
||||
async def _evaluate_expression(self, expression: Expression) -> bool:
|
||||
"""
|
||||
评估单个表达方式
|
||||
评估单个表达方式。
|
||||
|
||||
Args:
|
||||
expression: 要评估的表达方式
|
||||
|
||||
Returns:
|
||||
True表示通过,False表示不通过
|
||||
True 表示通过,False 表示不通过
|
||||
"""
|
||||
|
||||
suitable, reason, error = await single_expression_check(
|
||||
expression.situation,
|
||||
expression.style,
|
||||
)
|
||||
|
||||
# 更新数据库
|
||||
try:
|
||||
set_review_state(expression.id, True, not suitable, "ai")
|
||||
|
||||
status = "通过" if suitable else "不通过"
|
||||
# 保留这段注释,方便后续需要时恢复更详细的审核日志。
|
||||
# logger.info(
|
||||
# f"表达方式评估完成 [ID: {expression.id}] - {status} | "
|
||||
# f"Situation: {expression.situation}... | "
|
||||
# f"Style: {expression.style}... | "
|
||||
# f"Reason: {reason[:50]}..."
|
||||
# f"表达方式评估完成 [ID: {expression.id}] - {status} | "
|
||||
# f"Situation: {expression.situation}... | "
|
||||
# f"Style: {expression.style}... | "
|
||||
# f"Reason: {reason[:50]}..."
|
||||
# )
|
||||
|
||||
if error:
|
||||
logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}")
|
||||
|
||||
logger.debug(f"表达方式 [ID: {expression.id}] 评估完成: {status}, reason={reason}")
|
||||
return suitable
|
||||
|
||||
except Exception as e:
|
||||
@@ -207,9 +192,8 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
return False
|
||||
|
||||
async def run(self):
|
||||
"""执行检查任务"""
|
||||
"""执行检查任务。"""
|
||||
try:
|
||||
# 检查是否启用自动检查
|
||||
if not global_config.expression.expression_self_reflect:
|
||||
logger.debug("表达方式自动检查未启用,跳过本次执行")
|
||||
return
|
||||
@@ -221,26 +205,22 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
|
||||
logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count} 条")
|
||||
|
||||
# 选择要检查的表达方式
|
||||
expressions = await self._select_expressions(check_count)
|
||||
|
||||
if not expressions:
|
||||
logger.info("没有需要检查的表达方式")
|
||||
return
|
||||
|
||||
# 逐个评估
|
||||
passed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for i, expression in enumerate(expressions, 1):
|
||||
logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")
|
||||
for index, expression in enumerate(expressions, 1):
|
||||
logger.debug(f"正在评估 [{index}/{len(expressions)}]: ID={expression.id}")
|
||||
|
||||
if await self._evaluate_expression(expression):
|
||||
passed_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
# 避免请求过快
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -267,7 +267,7 @@ class ExpressionLearner:
|
||||
return normalized_entries
|
||||
|
||||
def get_pending_count(self, message_cache: List["SessionMessage"]) -> int:
|
||||
"""??????????????"""
|
||||
"""获取待处理消息数量"""
|
||||
return max(0, len(message_cache) - self._last_processed_index)
|
||||
|
||||
async def learn(
|
||||
@@ -275,10 +275,10 @@ class ExpressionLearner:
|
||||
message_cache: List["SessionMessage"],
|
||||
jargon_miner: Optional["JargonMiner"] = None,
|
||||
) -> bool:
|
||||
"""?????????????????????"""
|
||||
"""学习表达方式"""
|
||||
pending_messages = message_cache[self._last_processed_index :]
|
||||
if not pending_messages:
|
||||
logger.debug("??????????????????")
|
||||
logger.debug("没有待处理消息")
|
||||
return False
|
||||
if len(pending_messages) < self.min_messages_for_extraction:
|
||||
return False
|
||||
@@ -304,7 +304,7 @@ class ExpressionLearner:
|
||||
)
|
||||
response = generation_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"????????????????{e}")
|
||||
logger.error(f"学习表达方式失败: {e}")
|
||||
return False
|
||||
|
||||
expressions: List[Tuple[str, str, str]]
|
||||
@@ -319,14 +319,14 @@ class ExpressionLearner:
|
||||
continue
|
||||
jargon_entries.append((content, source_id))
|
||||
existing_contents.add(content)
|
||||
logger.info(f"??????????{content}")
|
||||
logger.info(f"从缓存中找到黑话: {content}")
|
||||
|
||||
if len(expressions) > 20:
|
||||
logger.info(f"?????????? 20 ???????????{len(expressions)}")
|
||||
logger.info(f"表达方式数量超过20: {len(expressions)}")
|
||||
expressions = []
|
||||
|
||||
if len(jargon_entries) > 30:
|
||||
logger.info(f"???????? 30 ???????????{len(jargon_entries)}")
|
||||
logger.info(f"黑话数量超过30: {len(jargon_entries)}")
|
||||
jargon_entries = []
|
||||
|
||||
after_extract_result = await self._get_runtime_manager().invoke_hook(
|
||||
@@ -337,7 +337,7 @@ class ExpressionLearner:
|
||||
jargon_entries=self._serialize_jargon_entries(jargon_entries),
|
||||
)
|
||||
if after_extract_result.aborted:
|
||||
logger.info(f"{self.session_id} ?????????? Hook ??")
|
||||
logger.info(f"{self.session_id} 表达方式选择 Hook 中止")
|
||||
self._last_processed_index = len(message_cache)
|
||||
return False
|
||||
|
||||
@@ -353,21 +353,21 @@ class ExpressionLearner:
|
||||
await self._process_jargon_entries(jargon_entries, pending_messages, jargon_miner)
|
||||
|
||||
if not expressions:
|
||||
logger.info("????????????")
|
||||
logger.info("没有可学习的表达方式")
|
||||
self._last_processed_index = len(message_cache)
|
||||
return False
|
||||
|
||||
logger.info(f"???? expressions: {expressions}")
|
||||
logger.info(f"???? jargon_entries: {jargon_entries}")
|
||||
logger.info(f"可学习的表达方式: {expressions}")
|
||||
logger.info(f"可学习的黑话: {jargon_entries}")
|
||||
|
||||
learnt_expressions = self._filter_expressions(expressions, pending_messages)
|
||||
if not learnt_expressions:
|
||||
logger.info("????????????")
|
||||
logger.info("没有可学习的表达方式通过过滤")
|
||||
self._last_processed_index = len(message_cache)
|
||||
return False
|
||||
|
||||
learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
|
||||
logger.info(f"? {self.session_id} ????????\n{learnt_expressions_str}")
|
||||
logger.info(f"{self.session_id} 可学习的表达方式: \n{learnt_expressions_str}")
|
||||
|
||||
for situation, style in learnt_expressions:
|
||||
before_upsert_result = await self._get_runtime_manager().invoke_hook(
|
||||
@@ -377,14 +377,14 @@ class ExpressionLearner:
|
||||
style=style,
|
||||
)
|
||||
if before_upsert_result.aborted:
|
||||
logger.info(f"{self.session_id} ???????? Hook ??: situation={situation!r}")
|
||||
logger.info(f"{self.session_id} 表达方式写入 Hook 中止: situation={situation!r}")
|
||||
continue
|
||||
|
||||
upsert_kwargs = before_upsert_result.kwargs
|
||||
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
|
||||
style = str(upsert_kwargs.get("style", style) or "").strip()
|
||||
if not situation or not style:
|
||||
logger.info(f"{self.session_id} ???????? Hook ??????")
|
||||
logger.info(f"{self.session_id} 表达方式写入 Hook 中止: situation={situation!r}")
|
||||
continue
|
||||
await self._upsert_expression_to_db(situation, style)
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from json_repair import repair_json
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("expression_utils")
|
||||
|
||||
@@ -16,17 +16,7 @@ judge_llm = LLMServiceClient(task_name="utils", request_type="expression_check")
|
||||
|
||||
|
||||
def _normalize_repair_json_result(repaired_result: Any) -> str:
|
||||
"""将 repair_json 的返回值规范化为 JSON 字符串。
|
||||
|
||||
Args:
|
||||
repaired_result: `repair_json` 的返回值,可能是字符串或带附加信息的元组。
|
||||
|
||||
Returns:
|
||||
str: 可供 `json.loads` 继续解析的 JSON 字符串。
|
||||
|
||||
Raises:
|
||||
TypeError: 当返回值无法规范化为字符串时抛出。
|
||||
"""
|
||||
"""将 `repair_json` 的返回结果统一转换为字符串。"""
|
||||
if isinstance(repaired_result, str):
|
||||
return repaired_result
|
||||
if isinstance(repaired_result, tuple) and repaired_result:
|
||||
@@ -37,22 +27,121 @@ def _normalize_repair_json_result(repaired_result: Any) -> str:
|
||||
raise TypeError(f"repair_json 返回了无法处理的结果类型: {type(repaired_result)}")
|
||||
|
||||
|
||||
def _strip_markdown_code_fence(text: str) -> str:
|
||||
"""移除 LLM 可能附带的 Markdown 代码块包裹。"""
|
||||
raw = text.strip()
|
||||
if match := re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL):
|
||||
return match[1].strip()
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
return raw.strip()
|
||||
|
||||
|
||||
def _extract_json_object_candidate(text: str) -> str:
|
||||
"""尽量从文本中提取首个 JSON 对象片段。"""
|
||||
start_index = text.find("{")
|
||||
end_index = text.rfind("}")
|
||||
if start_index != -1 and end_index != -1 and start_index < end_index:
|
||||
return text[start_index : end_index + 1].strip()
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _extract_reason_from_text(text: str) -> Optional[str]:
|
||||
"""从格式不完整的 JSON 文本中兜底提取 reason 字段。"""
|
||||
reason_key_match = re.search(r'["“”]?reason["“”]?\s*:\s*', text, re.IGNORECASE)
|
||||
if reason_key_match is None:
|
||||
return None
|
||||
|
||||
value_text = text[reason_key_match.end() :].strip()
|
||||
if not value_text:
|
||||
return None
|
||||
|
||||
if value_text.endswith("}"):
|
||||
value_text = value_text[:-1].rstrip()
|
||||
if value_text.endswith(","):
|
||||
value_text = value_text[:-1].rstrip()
|
||||
if not value_text:
|
||||
return None
|
||||
|
||||
if value_text[0] in {'"', "'", "“", "”", "‘", "’"}:
|
||||
value_text = value_text[1:]
|
||||
while value_text and value_text[-1] in {'"', "'", "“", "”", "‘", "’"}:
|
||||
value_text = value_text[:-1].rstrip()
|
||||
|
||||
return value_text.strip() or None
|
||||
|
||||
|
||||
def _normalize_reason_text(reason: Any) -> str:
|
||||
"""清理解析后 reason 中残留的包裹引号。"""
|
||||
normalized_reason = str(reason).strip()
|
||||
|
||||
if len(normalized_reason) >= 2 and normalized_reason[0] == normalized_reason[-1]:
|
||||
if normalized_reason[0] in {'"', "'", "“", "”", "‘", "’"}:
|
||||
normalized_reason = normalized_reason[1:-1].strip()
|
||||
|
||||
if normalized_reason.endswith('"') and normalized_reason.count('"') % 2 == 1:
|
||||
normalized_reason = normalized_reason[:-1].rstrip()
|
||||
if normalized_reason.endswith("'") and normalized_reason.count("'") % 2 == 1:
|
||||
normalized_reason = normalized_reason[:-1].rstrip()
|
||||
if normalized_reason.endswith('"') and not normalized_reason.startswith('"'):
|
||||
normalized_reason = normalized_reason[:-1].rstrip()
|
||||
if normalized_reason.endswith("'") and not normalized_reason.startswith("'"):
|
||||
normalized_reason = normalized_reason[:-1].rstrip()
|
||||
|
||||
return normalized_reason
|
||||
|
||||
|
||||
def parse_evaluation_response(response: str) -> Dict[str, Any]:
|
||||
"""解析表达方式评估结果,兼容不完全合法的 JSON。"""
|
||||
raw = _strip_markdown_code_fence(response)
|
||||
if not raw:
|
||||
raise ValueError("LLM 响应为空")
|
||||
|
||||
parse_candidates = [raw]
|
||||
json_candidate = _extract_json_object_candidate(raw)
|
||||
if json_candidate and json_candidate not in parse_candidates:
|
||||
parse_candidates.append(json_candidate)
|
||||
|
||||
for candidate in parse_candidates:
|
||||
parsed = _try_parse(candidate)
|
||||
if isinstance(parsed, dict):
|
||||
if "reason" in parsed:
|
||||
parsed["reason"] = _normalize_reason_text(parsed["reason"])
|
||||
return parsed
|
||||
|
||||
fixed_candidate = fix_chinese_quotes_in_json(candidate)
|
||||
if fixed_candidate != candidate:
|
||||
parsed = _try_parse(fixed_candidate)
|
||||
if isinstance(parsed, dict):
|
||||
if "reason" in parsed:
|
||||
parsed["reason"] = _normalize_reason_text(parsed["reason"])
|
||||
return parsed
|
||||
|
||||
suitable_match = re.search(r'["“”]?suitable["“”]?\s*:\s*(true|false)', raw, re.IGNORECASE)
|
||||
reason = _extract_reason_from_text(json_candidate or raw)
|
||||
if suitable_match is None or reason is None:
|
||||
raise ValueError(f"无法解析 LLM 响应为评估结果 JSON: {response}")
|
||||
|
||||
return {
|
||||
"suitable": suitable_match.group(1).lower() == "true",
|
||||
"reason": _normalize_reason_text(reason),
|
||||
}
|
||||
|
||||
|
||||
async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
|
||||
"""
|
||||
执行单次LLM评估
|
||||
执行单次 LLM 评估。
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
situation: 情景
|
||||
style: 风格
|
||||
|
||||
Returns:
|
||||
(suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息
|
||||
"""
|
||||
# 构建评估提示词
|
||||
# 基础评估标准
|
||||
base_criteria = [
|
||||
"表达方式或言语风格是否与使用条件或使用情景匹配",
|
||||
"允许部分语法错误或口头化或缺省出现",
|
||||
"允许部分语法错误或口语化或缺省出现",
|
||||
"表达方式不能太过特指,需要具有泛用性",
|
||||
"一般不涉及具体的人名或名称",
|
||||
]
|
||||
@@ -60,7 +149,6 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
|
||||
if custom_criteria := global_config.expression.expression_auto_check_custom_criteria:
|
||||
base_criteria.extend(custom_criteria)
|
||||
|
||||
# 构建评估标准列表字符串
|
||||
criteria_list = "\n".join([f"{i + 1}. {criterion}" for i, criterion in enumerate(base_criteria)])
|
||||
|
||||
prompt_template = prompt_manager.get_prompt("expression_evaluation")
|
||||
@@ -81,18 +169,13 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
|
||||
logger.debug(f"评估结果: {response}")
|
||||
|
||||
try:
|
||||
evaluation = json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
response_repaired = _normalize_repair_json_result(repair_json(response))
|
||||
evaluation = json.loads(response_repaired)
|
||||
except Exception as e:
|
||||
raise ValueError(f"无法解析LLM响应为JSON: {response}") from e
|
||||
evaluation = parse_evaluation_response(response)
|
||||
except Exception as e:
|
||||
return False, f"评估表达方式时发生错误: {e}", str(e)
|
||||
|
||||
try:
|
||||
suitable = evaluation.get("suitable", False)
|
||||
reason = evaluation.get("reason", "未提供理由")
|
||||
suitable = bool(evaluation.get("suitable", False))
|
||||
reason = _normalize_reason_text(evaluation.get("reason", "未提供理由"))
|
||||
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
|
||||
return suitable, reason, None
|
||||
except Exception as e:
|
||||
@@ -100,69 +183,48 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
|
||||
|
||||
|
||||
def fix_chinese_quotes_in_json(text: str) -> str:
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号。"""
|
||||
result: List[str] = []
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
for char in text:
|
||||
if escape_next:
|
||||
# 当前字符是转义字符后的字符,直接添加
|
||||
result.append(char)
|
||||
escape_next = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
# 转义字符
|
||||
result.append(char)
|
||||
escape_next = True
|
||||
i += 1
|
||||
continue
|
||||
if char == '"' and not escape_next:
|
||||
# 遇到英文引号,切换字符串状态
|
||||
|
||||
if char == '"':
|
||||
in_string = not in_string
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_string and char in ["“", "”"]:
|
||||
result.append('\\"')
|
||||
else:
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
result.append(char)
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
"""
|
||||
解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。
|
||||
|
||||
期望的 JSON 结构:
|
||||
[
|
||||
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
|
||||
{"content": "词条", "source_id": "12"}, // 黑话
|
||||
...
|
||||
]
|
||||
解析 LLM 返回的表达方式总结和黑话 JSON,提取两个列表。
|
||||
|
||||
Returns:
|
||||
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
第一个列表是表达方式 (situation, style, source_id)
|
||||
第二个列表是黑话 (content, source_id)
|
||||
第一个列表是表达方式 (situation, style, source_id)
|
||||
第二个列表是黑话 (content, source_id)
|
||||
"""
|
||||
if not response:
|
||||
return [], []
|
||||
|
||||
raw = response.strip()
|
||||
|
||||
if match := re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL):
|
||||
raw = match[1].strip()
|
||||
else:
|
||||
# 去掉可能存在的通用 ``` 包裹
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
raw = raw.strip()
|
||||
raw = _strip_markdown_code_fence(response)
|
||||
|
||||
parsed = _try_parse(raw)
|
||||
if parsed is None:
|
||||
@@ -180,22 +242,21 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
|
||||
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
return [], []
|
||||
|
||||
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
|
||||
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
jargon_entries: List[Tuple[str, str]] = []
|
||||
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否是表达方式条目(有 situation 和 style)
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
|
||||
if situation and style and source_id:
|
||||
# 表达方式条目
|
||||
expressions.append((situation, style, source_id))
|
||||
continue
|
||||
|
||||
content = str(item.get("content", "")).strip()
|
||||
if content and source_id:
|
||||
jargon_entries.append((content, source_id))
|
||||
@@ -204,25 +265,16 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
|
||||
|
||||
|
||||
def is_single_char_jargon(content: str) -> bool:
|
||||
"""
|
||||
判断是否是单字黑话(单个汉字、英文或数字)
|
||||
|
||||
Args:
|
||||
content: 词条内容
|
||||
|
||||
Returns:
|
||||
bool: 如果是单字黑话返回True,否则返回False
|
||||
"""
|
||||
"""判断是否是单字黑话(单个汉字、英文或数字)。"""
|
||||
if not content or len(content) != 1:
|
||||
return False
|
||||
|
||||
char = content[0]
|
||||
# 判断是否是单个汉字、单个英文字母或单个数字
|
||||
return (
|
||||
"\u4e00" <= char <= "\u9fff" # 汉字
|
||||
or "a" <= char <= "z" # 小写字母
|
||||
or "A" <= char <= "Z" # 大写字母
|
||||
or "0" <= char <= "9" # 数字
|
||||
"\u4e00" <= char <= "\u9fff"
|
||||
or "a" <= char <= "z"
|
||||
or "A" <= char <= "Z"
|
||||
or "0" <= char <= "9"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -54,27 +54,30 @@ class RespNotOkException(Exception):
|
||||
return f"未知的异常响应代码:{self.status_code}"
|
||||
|
||||
|
||||
class RespParseException(Exception):
|
||||
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
|
||||
class ResponseContextException(Exception):
|
||||
"""携带原始响应上下文的异常基类。"""
|
||||
|
||||
def __init__(self, ext_info: Any, message: str | None = None):
|
||||
default_message: str = "请求失败"
|
||||
|
||||
def __init__(self, ext_info: Any = None, message: str | None = None):
|
||||
super().__init__(message)
|
||||
self.ext_info = ext_info
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
|
||||
return self.message or self.default_message
|
||||
|
||||
|
||||
class EmptyResponseException(Exception):
|
||||
class RespParseException(ResponseContextException):
|
||||
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
|
||||
|
||||
default_message = "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
|
||||
|
||||
|
||||
class EmptyResponseException(ResponseContextException):
|
||||
"""响应内容为空"""
|
||||
|
||||
def __init__(self, message: str = "响应内容为空,这可能是一个临时性问题"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
default_message = "响应内容为空,这可能是一个临时性问题"
|
||||
|
||||
|
||||
class ModelAttemptFailed(Exception):
|
||||
|
||||
@@ -552,7 +552,7 @@ def _build_stream_api_response(
|
||||
|
||||
_warn_if_max_tokens_truncated(last_response, response.content, response.tool_calls)
|
||||
if not response.content and not response.tool_calls and not response.reasoning_content:
|
||||
raise EmptyResponseException()
|
||||
raise EmptyResponseException(last_response)
|
||||
return response
|
||||
|
||||
|
||||
@@ -627,7 +627,7 @@ def _default_normal_response_parser(
|
||||
usage_record = _extract_usage_record(response)
|
||||
_warn_if_max_tokens_truncated(response, api_response.content, api_response.tool_calls)
|
||||
if not api_response.content and not api_response.tool_calls and not api_response.reasoning_content:
|
||||
raise EmptyResponseException("响应中既无文本内容也无工具调用")
|
||||
raise EmptyResponseException(response, "响应中既无文本内容也无工具调用")
|
||||
return api_response, usage_record
|
||||
|
||||
|
||||
|
||||
@@ -79,6 +79,25 @@ THINK_CONTENT_PATTERN = re.compile(
|
||||
)
|
||||
"""用于解析 `<think>` 推理块的正则表达式。"""
|
||||
|
||||
XML_TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(?P<body>.*?)\s*</tool_call>", re.DOTALL | re.IGNORECASE)
|
||||
"""用于兜底解析模型以 XML 文本返回的工具调用。
|
||||
|
||||
这是一个暂时性兼容方案,专门处理“思维链内容里夹带工具调用”的情况;
|
||||
后续如果上游稳定返回标准 tool_calls 字段,这里可能会调整或移除。
|
||||
"""
|
||||
|
||||
XML_FUNCTION_CALL_PATTERN = re.compile(
|
||||
r"<function=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<arguments>.*?)\s*</function>",
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
"""用于从 XML 风格工具调用块中提取函数名与参数。"""
|
||||
|
||||
XML_PARAMETER_PATTERN = re.compile(
|
||||
r"<parameter=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<value>.*?)\s*</parameter>",
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
"""用于从 XML 风格工具调用块中提取参数列表。"""
|
||||
|
||||
CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS = {
|
||||
"max_tokens",
|
||||
"messages",
|
||||
@@ -346,6 +365,32 @@ def _convert_assistant_tool_calls(tool_calls: List[ToolCall]) -> List[ChatComple
|
||||
return converted_tool_calls
|
||||
|
||||
|
||||
def _sanitize_messages_for_toolless_request(messages: List[Message]) -> List[Message]:
|
||||
"""在无工具请求时清洗历史工具调用链,避免兼容接口拒收消息。"""
|
||||
sanitized_messages: List[Message] = []
|
||||
|
||||
for message in messages:
|
||||
if message.role == RoleType.Tool:
|
||||
continue
|
||||
|
||||
if message.role == RoleType.Assistant and message.tool_calls:
|
||||
if not message.parts:
|
||||
continue
|
||||
assistant_message = Message(
|
||||
role=message.role,
|
||||
parts=list(message.parts),
|
||||
tool_call_id=message.tool_call_id,
|
||||
tool_name=message.tool_name,
|
||||
tool_calls=None,
|
||||
)
|
||||
sanitized_messages.append(assistant_message)
|
||||
continue
|
||||
|
||||
sanitized_messages.append(message)
|
||||
|
||||
return sanitized_messages
|
||||
|
||||
|
||||
def _convert_messages(messages: List[Message]) -> List[ChatCompletionMessageParam]:
|
||||
"""将内部消息列表转换为 OpenAI 兼容消息列表。
|
||||
|
||||
@@ -515,6 +560,66 @@ def _extract_reasoning_and_content(
|
||||
return None, match.group("content_only").strip() or None
|
||||
|
||||
|
||||
def _extract_xml_tool_calls(
|
||||
raw_text: str | None,
|
||||
parse_mode: ToolArgumentParseMode,
|
||||
response: Any,
|
||||
) -> Tuple[str | None, List[ToolCall] | None]:
|
||||
"""从 XML 风格文本中兜底提取工具调用。"""
|
||||
if not isinstance(raw_text, str) or not raw_text.strip():
|
||||
return raw_text, None
|
||||
|
||||
tool_calls: List[ToolCall] = []
|
||||
|
||||
def _coerce_xml_parameter_value(raw_value: str) -> Any:
|
||||
normalized_value = raw_value.strip()
|
||||
if not normalized_value:
|
||||
return ""
|
||||
lowered_value = normalized_value.lower()
|
||||
if lowered_value == "true":
|
||||
return True
|
||||
if lowered_value == "false":
|
||||
return False
|
||||
if lowered_value in {"null", "none"}:
|
||||
return None
|
||||
if normalized_value.startswith(("{", "[")):
|
||||
try:
|
||||
return repair_json(normalized_value, return_objects=True, logging=False)
|
||||
except Exception:
|
||||
return normalized_value
|
||||
return normalized_value
|
||||
|
||||
def _parse_xml_parameters(raw_arguments: str) -> Dict[str, Any] | None:
|
||||
parameters = {
|
||||
match.group("name").strip(): _coerce_xml_parameter_value(match.group("value"))
|
||||
for match in XML_PARAMETER_PATTERN.finditer(raw_arguments)
|
||||
}
|
||||
return parameters or None
|
||||
|
||||
def _replace_tool_call(match: re.Match[str]) -> str:
|
||||
body = match.group("body")
|
||||
function_match = XML_FUNCTION_CALL_PATTERN.search(body)
|
||||
if function_match is None:
|
||||
return match.group(0)
|
||||
|
||||
function_name = function_match.group("name").strip()
|
||||
raw_arguments = function_match.group("arguments").strip()
|
||||
arguments = _parse_xml_parameters(raw_arguments)
|
||||
if arguments is None:
|
||||
arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=f"xml_tool_call_{len(tool_calls) + 1}",
|
||||
func_name=function_name,
|
||||
args=arguments,
|
||||
)
|
||||
)
|
||||
return ""
|
||||
|
||||
cleaned_text = XML_TOOL_CALL_PATTERN.sub(_replace_tool_call, raw_text).strip() or None
|
||||
return cleaned_text, tool_calls or None
|
||||
|
||||
|
||||
def _log_length_truncation(finish_reason: str | None, model_name: str | None) -> None:
|
||||
"""记录因长度截断导致的告警日志。
|
||||
|
||||
@@ -526,6 +631,38 @@ def _log_length_truncation(finish_reason: str | None, model_name: str | None) ->
|
||||
logger.info(f"模型{model_name or ''}因为超过最大 max_token 限制,可能仅输出部分内容,可视情况调整")
|
||||
|
||||
|
||||
def _apply_xml_tool_call_fallback(
|
||||
response: APIResponse,
|
||||
parse_mode: ToolArgumentParseMode,
|
||||
raw_response: Any,
|
||||
) -> None:
|
||||
"""当上游未返回标准 tool_calls 时,尝试从 XML 文本兜底解析。
|
||||
|
||||
这是一个暂时性处理方法,用来兼容思维链中混入工具调用的返回格式,
|
||||
后续可能随着模型或上游接口的规范化而变更。
|
||||
"""
|
||||
if response.tool_calls:
|
||||
return
|
||||
|
||||
reasoning_content, tool_calls = _extract_xml_tool_calls(response.reasoning_content, parse_mode, raw_response)
|
||||
if reasoning_content != response.reasoning_content:
|
||||
response.reasoning_content = reasoning_content
|
||||
if tool_calls:
|
||||
response.tool_calls = tool_calls
|
||||
if not response.content and reasoning_content:
|
||||
response.content = reasoning_content
|
||||
response.reasoning_content = None
|
||||
logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用")
|
||||
return
|
||||
|
||||
cleaned_content, tool_calls = _extract_xml_tool_calls(response.content, parse_mode, raw_response)
|
||||
if cleaned_content != response.content:
|
||||
response.content = cleaned_content
|
||||
if tool_calls:
|
||||
response.tool_calls = tool_calls
|
||||
logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用")
|
||||
|
||||
|
||||
def _coerce_openai_argument(value: Any) -> Any | Omit:
|
||||
"""将可选参数转换为 OpenAI SDK 期望的值。
|
||||
|
||||
@@ -561,7 +698,7 @@ def _build_api_status_message(error: APIStatusError) -> str:
|
||||
message_parts.append(str(error.message))
|
||||
response_text = getattr(getattr(error, "response", None), "text", None)
|
||||
if response_text:
|
||||
message_parts.append(str(response_text)[:300])
|
||||
message_parts.append(str(response_text))
|
||||
if message_parts:
|
||||
return " | ".join(message_parts)
|
||||
return f"上游接口返回状态码 {error.status_code}"
|
||||
@@ -722,9 +859,10 @@ class _OpenAIStreamAccumulator:
|
||||
response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))
|
||||
|
||||
response.raw_data = {"model": self.model_name} if self.model_name else None
|
||||
_apply_xml_tool_call_fallback(response, self.tool_argument_parse_mode, response.raw_data)
|
||||
|
||||
if not response.content and not response.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
raise EmptyResponseException(response.raw_data)
|
||||
|
||||
return response
|
||||
|
||||
@@ -808,7 +946,7 @@ def _default_normal_response_parser(
|
||||
"""
|
||||
choices = getattr(resp, "choices", None)
|
||||
if not choices:
|
||||
raise EmptyResponseException("响应解析失败,choices 为空或缺失")
|
||||
raise EmptyResponseException(resp, "响应解析失败,choices 为空或缺失")
|
||||
|
||||
api_response = APIResponse()
|
||||
message_part = choices[0].message
|
||||
@@ -847,9 +985,10 @@ def _default_normal_response_parser(
|
||||
|
||||
finish_reason = getattr(resp.choices[0], "finish_reason", None)
|
||||
_log_length_truncation(finish_reason, getattr(resp, "model", None))
|
||||
_apply_xml_tool_call_fallback(api_response, tool_argument_parse_mode, resp)
|
||||
|
||||
if not api_response.content and not api_response.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
raise EmptyResponseException(resp)
|
||||
|
||||
return api_response, usage_record
|
||||
|
||||
@@ -965,7 +1104,12 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
model_info = request.model_info
|
||||
|
||||
try:
|
||||
messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request.message_list)
|
||||
request_messages = (
|
||||
list(request.message_list)
|
||||
if request.tool_options
|
||||
else _sanitize_messages_for_toolless_request(request.message_list)
|
||||
)
|
||||
messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request_messages)
|
||||
tools_payload: List[ChatCompletionToolParam] | None = (
|
||||
_convert_tool_options(request.tool_options) if request.tool_options else None
|
||||
)
|
||||
|
||||
@@ -58,6 +58,42 @@ def _json_friendly(value: Any) -> Any:
|
||||
return str(value)
|
||||
|
||||
|
||||
def extract_error_response_body(error: Exception) -> Any | None:
|
||||
"""尽量从异常对象中提取上游返回体,便于排查模型请求失败。"""
|
||||
candidate_errors = [error, getattr(error, "__cause__", None)]
|
||||
|
||||
for candidate in candidate_errors:
|
||||
if candidate is None:
|
||||
continue
|
||||
|
||||
response = getattr(candidate, "response", None)
|
||||
if response is not None:
|
||||
response_json = getattr(response, "json", None)
|
||||
if callable(response_json):
|
||||
try:
|
||||
return _json_friendly(response_json())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
response_text = getattr(response, "text", None)
|
||||
if response_text not in (None, ""):
|
||||
return str(response_text)
|
||||
|
||||
response_content = getattr(response, "content", None)
|
||||
if response_content not in (None, b"", ""):
|
||||
return _json_friendly(response_content)
|
||||
|
||||
response_body = getattr(candidate, "body", None)
|
||||
if response_body not in (None, "", b""):
|
||||
return _json_friendly(response_body)
|
||||
|
||||
ext_info = getattr(candidate, "ext_info", None)
|
||||
if ext_info is not None:
|
||||
return _json_friendly(ext_info)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_filename_component(value: str) -> str:
|
||||
"""将任意字符串转换为适合文件名使用的片段。"""
|
||||
normalized_value = FILENAME_SAFE_PATTERN.sub("-", value.strip())
|
||||
@@ -228,6 +264,7 @@ def serialize_model_info_snapshot(model_info: ModelInfo) -> dict[str, Any]:
|
||||
"model_identifier": model_info.model_identifier,
|
||||
"name": model_info.name,
|
||||
"temperature": model_info.temperature,
|
||||
"visual": model_info.visual,
|
||||
}
|
||||
|
||||
|
||||
@@ -244,6 +281,7 @@ def deserialize_model_info_snapshot(raw_model_info: Any) -> ModelInfo:
|
||||
model_identifier=str(raw_model_info.get("model_identifier") or ""),
|
||||
name=str(raw_model_info.get("name") or ""),
|
||||
temperature=raw_model_info.get("temperature"),
|
||||
visual=bool(raw_model_info.get("visual", False)),
|
||||
)
|
||||
|
||||
|
||||
@@ -386,6 +424,10 @@ def save_failed_request_snapshot(
|
||||
"snapshot_version": SNAPSHOT_VERSION,
|
||||
}
|
||||
|
||||
response_body = extract_error_response_body(error)
|
||||
if response_body is not None:
|
||||
snapshot_payload["error"]["response_body"] = response_body
|
||||
|
||||
snapshot_payload["replay"] = {
|
||||
"command": build_replay_command(snapshot_path),
|
||||
"file_uri": snapshot_path.as_uri(),
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
@@ -397,8 +398,6 @@ class LLMOrchestrator:
|
||||
start_time = time.time()
|
||||
|
||||
tool_built = self._build_tool_options(tools)
|
||||
if self.request_type.startswith("maisaka_"):
|
||||
logger.info(f"LLMOrchestrator[{self.request_type}] 已构建 {len(tool_built or [])} 个内部工具选项")
|
||||
|
||||
execution_result = await self._execute_request(
|
||||
request_type=RequestType.RESPONSE,
|
||||
@@ -912,7 +911,11 @@ class LLMOrchestrator:
|
||||
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
|
||||
message_list = []
|
||||
if message_factory:
|
||||
message_list = message_factory(client)
|
||||
parameter_count = len(inspect.signature(message_factory).parameters)
|
||||
if parameter_count >= 2:
|
||||
message_list = message_factory(client, model_info)
|
||||
else:
|
||||
message_list = message_factory(client)
|
||||
try:
|
||||
request = self._build_client_request(
|
||||
request_type=request_type,
|
||||
|
||||
@@ -18,6 +18,7 @@ from src.common.message_server.server import Server, get_global_server
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.maisaka.display.stage_status_board import disable_stage_status_board, enable_stage_status_board
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.memory_flow_service import memory_automation_service
|
||||
@@ -65,6 +66,7 @@ class MainSystem:
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化系统组件"""
|
||||
enable_stage_status_board()
|
||||
logger.info(t("startup.waking_up", nickname=global_config.bot.nickname))
|
||||
|
||||
# 其他初始化任务
|
||||
@@ -169,6 +171,7 @@ async def main() -> None:
|
||||
system.schedule_tasks(),
|
||||
)
|
||||
finally:
|
||||
disable_stage_status_board()
|
||||
emoji_manager.shutdown()
|
||||
await memory_automation_service.shutdown()
|
||||
await a_memorix_host_service.stop()
|
||||
|
||||
@@ -10,6 +10,8 @@ from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
from .continue_tool import get_tool_spec as get_continue_tool_spec
|
||||
from .continue_tool import handle_tool as handle_continue_tool
|
||||
from .finish import get_tool_spec as get_finish_tool_spec
|
||||
from .finish import handle_tool as handle_finish_tool
|
||||
from .no_reply import get_tool_spec as get_no_reply_tool_spec
|
||||
from .no_reply import handle_tool as handle_no_reply_tool
|
||||
from .query_jargon import get_tool_spec as get_query_jargon_tool_spec
|
||||
@@ -22,6 +24,8 @@ from .reply import get_tool_spec as get_reply_tool_spec
|
||||
from .reply import handle_tool as handle_reply_tool
|
||||
from .send_emoji import get_tool_spec as get_send_emoji_tool_spec
|
||||
from .send_emoji import handle_tool as handle_send_emoji_tool
|
||||
from .tool_search import get_tool_spec as get_tool_search_tool_spec
|
||||
from .tool_search import handle_tool as handle_tool_search_tool
|
||||
from .view_complex_message import get_tool_spec as get_view_complex_message_tool_spec
|
||||
from .view_complex_message import handle_tool as handle_view_complex_message_tool
|
||||
from .wait import get_tool_spec as get_wait_tool_spec
|
||||
@@ -44,11 +48,13 @@ def get_action_tool_specs() -> List[ToolSpec]:
|
||||
"""获取 Action Loop 阶段可用的内置工具声明。"""
|
||||
|
||||
return [
|
||||
get_finish_tool_spec(),
|
||||
get_reply_tool_spec(),
|
||||
get_view_complex_message_tool_spec(),
|
||||
get_query_jargon_tool_spec(),
|
||||
get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool)),
|
||||
get_send_emoji_tool_spec(),
|
||||
get_tool_search_tool_spec(),
|
||||
]
|
||||
|
||||
|
||||
@@ -63,12 +69,14 @@ def get_all_builtin_tool_specs() -> List[ToolSpec]:
|
||||
|
||||
return [
|
||||
*get_timing_tool_specs(),
|
||||
get_finish_tool_spec(),
|
||||
get_reply_tool_spec(),
|
||||
get_view_complex_message_tool_spec(),
|
||||
get_query_jargon_tool_spec(),
|
||||
get_query_memory_tool_spec(enabled=True),
|
||||
get_query_person_info_tool_spec(),
|
||||
get_send_emoji_tool_spec(),
|
||||
get_tool_search_tool_spec(),
|
||||
]
|
||||
|
||||
|
||||
@@ -95,6 +103,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str
|
||||
|
||||
return {
|
||||
"continue": lambda invocation, context=None: handle_continue_tool(tool_ctx, invocation, context),
|
||||
"finish": lambda invocation, context=None: handle_finish_tool(tool_ctx, invocation, context),
|
||||
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
|
||||
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
|
||||
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
|
||||
@@ -106,6 +115,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str
|
||||
),
|
||||
"wait": lambda invocation, context=None: handle_wait_tool(tool_ctx, invocation, context),
|
||||
"send_emoji": lambda invocation, context=None: handle_send_emoji_tool(tool_ctx, invocation, context),
|
||||
"tool_search": lambda invocation, context=None: handle_tool_search_tool(tool_ctx, invocation, context),
|
||||
"view_complex_message": lambda invocation, context=None: handle_view_complex_message_tool(
|
||||
tool_ctx,
|
||||
invocation,
|
||||
|
||||
34
src/maisaka/builtin_tool/finish.py
Normal file
34
src/maisaka/builtin_tool/finish.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""finish 内置工具。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 finish 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="finish",
|
||||
brief_description="结束本轮思考,等待后续新的外部消息再继续。",
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 finish 内置工具。"""
|
||||
|
||||
del context
|
||||
tool_ctx.runtime._enter_stop_state()
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"当前对话循环已结束本轮思考,等待新的消息到来。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
@@ -29,6 +29,6 @@ async def handle_tool(
|
||||
tool_ctx.runtime._enter_stop_state()
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"当前对话循环已暂停,等待新消息到来。",
|
||||
"当前暂时停止思考,等待新消息到来。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
|
||||
@@ -91,10 +91,6 @@ async def handle_tool(
|
||||
f"未找到要回复的目标消息,msg_id={target_message_id}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{tool_ctx.runtime.log_prefix} 已触发回复工具,"
|
||||
f"目标消息编号={target_message_id} 引用回复={set_quote} 最新思考={latest_thought!r}"
|
||||
)
|
||||
try:
|
||||
replyer = replyer_manager.get_replyer(
|
||||
chat_stream=tool_ctx.runtime.chat_stream,
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
import math
|
||||
from random import sample
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from PIL import ImageDraw, ImageFont
|
||||
@@ -20,12 +20,14 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.maisaka.context_messages import (
|
||||
LLMContextMessage,
|
||||
ReferenceMessage,
|
||||
ReferenceMessageType,
|
||||
SessionBackedMessage,
|
||||
)
|
||||
from src.plugin_runtime.hook_payloads import serialize_prompt_messages
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
@@ -242,34 +244,9 @@ def _build_emoji_candidate_summary(emojis: list[MaiEmoji]) -> str:
|
||||
return "\n".join(summary_lines).strip()
|
||||
|
||||
|
||||
def _build_send_emoji_prompt_preview(
|
||||
*,
|
||||
system_prompt: str,
|
||||
requested_emotion: str,
|
||||
grid_rows: int,
|
||||
grid_columns: int,
|
||||
sampled_emojis: list[MaiEmoji],
|
||||
) -> str:
|
||||
"""构建表情选择子代理的文本预览。"""
|
||||
|
||||
task_text = (
|
||||
"[选择任务]\n"
|
||||
f"requested_emotion: {requested_emotion or '未指定'}\n"
|
||||
f"候选总数: {len(sampled_emojis)}\n"
|
||||
f"拼图布局: {grid_rows}x{grid_columns}\n"
|
||||
"请只输出 JSON。"
|
||||
)
|
||||
candidate_summary = _build_emoji_candidate_summary(sampled_emojis)
|
||||
return (
|
||||
f"[System Prompt]\n{system_prompt}\n\n"
|
||||
f"{task_text}\n\n"
|
||||
f"[候选表情摘要]\n{candidate_summary or '无候选表情'}"
|
||||
).strip()
|
||||
|
||||
|
||||
def _build_send_emoji_monitor_detail(
|
||||
*,
|
||||
prompt_text: str = "",
|
||||
request_messages: Optional[list[dict[str, Any]]] = None,
|
||||
reasoning_text: str = "",
|
||||
output_text: str = "",
|
||||
metrics: Optional[Dict[str, Any]] = None,
|
||||
@@ -278,8 +255,8 @@ def _build_send_emoji_monitor_detail(
|
||||
"""构建 emotion tool 统一监控详情。"""
|
||||
|
||||
detail: Dict[str, Any] = {}
|
||||
if prompt_text.strip():
|
||||
detail["prompt_text"] = prompt_text.strip()
|
||||
if isinstance(request_messages, list) and request_messages:
|
||||
detail["request_messages"] = request_messages
|
||||
if reasoning_text.strip():
|
||||
detail["reasoning_text"] = reasoning_text.strip()
|
||||
if output_text.strip():
|
||||
@@ -387,13 +364,16 @@ async def _select_emoji_with_sub_agent(
|
||||
remaining_uses_value=1,
|
||||
display_prefix="[表情包选择任务]",
|
||||
)
|
||||
prompt_preview = _build_send_emoji_prompt_preview(
|
||||
system_prompt=system_prompt,
|
||||
requested_emotion=requested_emotion,
|
||||
grid_rows=grid_rows,
|
||||
grid_columns=grid_columns,
|
||||
sampled_emojis=sampled_emojis,
|
||||
)
|
||||
request_messages = [
|
||||
MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build(),
|
||||
]
|
||||
prompt_llm_message = prompt_message.to_llm_message()
|
||||
if prompt_llm_message is not None:
|
||||
request_messages.append(prompt_llm_message)
|
||||
candidate_llm_message = candidate_message.to_llm_message()
|
||||
if candidate_llm_message is not None:
|
||||
request_messages.append(candidate_llm_message)
|
||||
serialized_request_messages = serialize_prompt_messages(request_messages)
|
||||
|
||||
selection_started_at = datetime.now()
|
||||
response = await tool_ctx.runtime.run_sub_agent(
|
||||
@@ -421,7 +401,7 @@ async def _select_emoji_with_sub_agent(
|
||||
logger.warning(f"{tool_ctx.runtime.log_prefix} 表情包子代理结果解析失败,将回退到候选首项: {exc}")
|
||||
if selection_metadata is not None:
|
||||
selection_metadata["monitor_detail"] = _build_send_emoji_monitor_detail(
|
||||
prompt_text=prompt_preview,
|
||||
request_messages=serialized_request_messages,
|
||||
output_text=response.content or "",
|
||||
metrics=selection_metrics,
|
||||
extra_sections=[{
|
||||
@@ -435,7 +415,7 @@ async def _select_emoji_with_sub_agent(
|
||||
if selection_metadata is not None:
|
||||
selection_metadata["reason"] = selection.reason.strip()
|
||||
selection_metadata["monitor_detail"] = _build_send_emoji_monitor_detail(
|
||||
prompt_text=prompt_preview,
|
||||
request_messages=serialized_request_messages,
|
||||
reasoning_text=selection.reason,
|
||||
output_text=response.content or "",
|
||||
metrics=selection_metrics,
|
||||
|
||||
106
src/maisaka/builtin_tool/tool_search.py
Normal file
106
src/maisaka/builtin_tool/tool_search.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""tool_search 内置工具。"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 tool_search 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="tool_search",
|
||||
brief_description="在 deferred tools 列表中按名称或关键词搜索工具,并将命中的工具加入后续轮次的可用工具列表。",
|
||||
detailed_description=(
|
||||
"参数说明:\n"
|
||||
"- query:String,必填。工具名、前缀或关键词。\n"
|
||||
"- limit:Integer,可选。最多返回多少个匹配工具,默认为 5。"
|
||||
),
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "要搜索的工具名、前缀或关键词。",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "最多返回多少个匹配工具。",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 tool_search 内置工具。"""
|
||||
|
||||
del context
|
||||
raw_query = invocation.arguments.get("query")
|
||||
if not isinstance(raw_query, str) or not raw_query.strip():
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"tool_search 需要提供非空的 `query` 字符串参数。",
|
||||
)
|
||||
|
||||
raw_limit = invocation.arguments.get("limit", 5)
|
||||
try:
|
||||
limit = max(1, int(raw_limit))
|
||||
except (TypeError, ValueError):
|
||||
limit = 5
|
||||
|
||||
matched_tool_specs = tool_ctx.runtime.search_deferred_tool_specs(raw_query, limit=limit)
|
||||
matched_tool_names = [tool_spec.name for tool_spec in matched_tool_specs]
|
||||
newly_discovered_tool_names = tool_ctx.runtime.discover_deferred_tools(matched_tool_names)
|
||||
|
||||
structured_content: Dict[str, Any] = {
|
||||
"query": raw_query.strip(),
|
||||
"matched_tool_names": matched_tool_names,
|
||||
"newly_discovered_tool_names": newly_discovered_tool_names,
|
||||
}
|
||||
|
||||
if not matched_tool_names:
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"未找到匹配的 deferred tools,请尝试更完整的工具名、前缀或其他关键词。",
|
||||
structured_content=structured_content,
|
||||
metadata={"record_display_prompt": "tool_search 未找到匹配工具。"},
|
||||
)
|
||||
|
||||
content_lines: List[str] = [
|
||||
f"已找到 {len(matched_tool_names)} 个 deferred tools,它们会在后续轮次中加入可用工具列表:",
|
||||
*[f"- {tool_name}" for tool_name in matched_tool_names],
|
||||
]
|
||||
if newly_discovered_tool_names:
|
||||
content_lines.extend(
|
||||
[
|
||||
"",
|
||||
"本次新发现的工具:",
|
||||
*[f"- {tool_name}" for tool_name in newly_discovered_tool_names],
|
||||
]
|
||||
)
|
||||
else:
|
||||
content_lines.extend(["", "这些工具此前已经发现过,无需重复展开。"])
|
||||
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"\n".join(content_lines),
|
||||
structured_content=structured_content,
|
||||
metadata={
|
||||
"matched_tool_names": matched_tool_names,
|
||||
"newly_discovered_tool_names": newly_discovered_tool_names,
|
||||
"record_display_prompt": json.dumps(structured_content, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
@@ -12,8 +12,8 @@ def get_tool_spec() -> ToolSpec:
|
||||
|
||||
return ToolSpec(
|
||||
name="wait",
|
||||
brief_description="暂停当前对话并等待用户新的输入。",
|
||||
detailed_description="参数说明:\n- seconds:integer,必填。等待的秒数。",
|
||||
brief_description="暂停当前对话并固定等待一段时间,期间不因新消息提前恢复。",
|
||||
detailed_description="参数说明:\n- seconds:integer,必填。等待的秒数。等待期间收到的新消息只会暂存,直到超时后再继续处理。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -46,6 +46,6 @@ async def handle_tool(
|
||||
tool_ctx.runtime._enter_wait_state(seconds=wait_seconds, tool_call_id=invocation.call_id)
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
f"当前对话循环进入等待状态,最长等待 {wait_seconds} 秒。",
|
||||
f"当前对话循环进入等待状态,将固定等待 {wait_seconds} 秒;期间收到的新消息不会提前打断本次等待。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
|
||||
@@ -5,20 +5,18 @@ from datetime import datetime
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
from rich.console import RenderableType
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry, ToolSpec
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
|
||||
from src.plugin_runtime.hook_payloads import (
|
||||
deserialize_prompt_messages,
|
||||
@@ -32,9 +30,11 @@ from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistr
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .builtin_tool import get_builtin_tools
|
||||
from .context_messages import AssistantMessage, LLMContextMessage
|
||||
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
|
||||
from .history_utils import drop_orphan_tool_results
|
||||
from .prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .display.prompt_cli_renderer import PromptCLIVisualizer
|
||||
|
||||
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -54,13 +54,6 @@ class ChatResponse:
|
||||
prompt_section: Optional[RenderableType] = None
|
||||
|
||||
|
||||
class ToolFilterSelection(BaseModel):
|
||||
"""工具筛选响应。"""
|
||||
|
||||
selected_tool_names: list[str] = PydanticField(default_factory=list)
|
||||
"""经过预筛后保留的候选工具名称列表。"""
|
||||
|
||||
|
||||
logger = get_logger("maisaka_chat_loop")
|
||||
|
||||
|
||||
@@ -217,10 +210,6 @@ class MaisakaChatLoopService:
|
||||
else:
|
||||
self._chat_system_prompt = chat_system_prompt
|
||||
self._llm_chat = LLMServiceClient(task_name="planner", request_type="maisaka_planner")
|
||||
self._tool_filter_llm = LLMServiceClient(
|
||||
task_name=global_config.maisaka.tool_filter_task_name,
|
||||
request_type="maisaka_tool_filter",
|
||||
)
|
||||
|
||||
@property
|
||||
def personality_prompt(self) -> str:
|
||||
@@ -303,8 +292,15 @@ class MaisakaChatLoopService:
|
||||
"file_tools_section": tools_section,
|
||||
"group_chat_attention_block": self._build_group_chat_attention_block(),
|
||||
"identity": self._personality_prompt,
|
||||
"time_block": self._build_time_block(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_time_block() -> str:
|
||||
"""构建当前时间提示块。"""
|
||||
|
||||
return f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
def _build_group_chat_attention_block(self) -> str:
|
||||
"""构建当前聊天场景下的额外注意事项块。"""
|
||||
|
||||
@@ -399,6 +395,7 @@ class MaisakaChatLoopService:
|
||||
self,
|
||||
selected_history: List[LLMContextMessage],
|
||||
*,
|
||||
injected_user_messages: Sequence[str] | None = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> List[Message]:
|
||||
"""构造发给大模型的消息列表。
|
||||
@@ -420,254 +417,49 @@ class MaisakaChatLoopService:
|
||||
if llm_message is not None:
|
||||
messages.append(llm_message)
|
||||
|
||||
normalized_injected_messages: List[Message] = []
|
||||
for injected_message in injected_user_messages or []:
|
||||
normalized_message = str(injected_message or "").strip()
|
||||
if not normalized_message:
|
||||
continue
|
||||
normalized_injected_messages.append(
|
||||
MessageBuilder()
|
||||
.set_role(RoleType.User)
|
||||
.add_text_content(normalized_message)
|
||||
.build()
|
||||
)
|
||||
|
||||
if normalized_injected_messages:
|
||||
insertion_index = self._resolve_injected_user_messages_insertion_index(messages)
|
||||
messages[insertion_index:insertion_index] = normalized_injected_messages
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _is_builtin_tool_spec(tool_spec: ToolSpec) -> bool:
|
||||
"""判断一个工具是否属于默认内置工具。
|
||||
def _resolve_injected_user_messages_insertion_index(messages: Sequence[Message]) -> int:
|
||||
"""计算 injected meta user messages 在请求中的插入位置。
|
||||
|
||||
Args:
|
||||
tool_spec: 待判断的工具声明。
|
||||
|
||||
Returns:
|
||||
bool: 是否为默认内置工具。
|
||||
规则与 deferred attachment 更接近:
|
||||
- 从尾部向前寻找最近的 stopping point;
|
||||
- stopping point 为 assistant 消息或 tool 结果消息;
|
||||
- 找到后插入到其后面;
|
||||
- 若不存在 stopping point,则退回到 system 消息之后。
|
||||
"""
|
||||
|
||||
return tool_spec.provider_type == "builtin" or tool_spec.provider_name == "maisaka_builtin"
|
||||
for index in range(len(messages) - 1, -1, -1):
|
||||
message = messages[index]
|
||||
if message.role in {RoleType.Assistant, RoleType.Tool}:
|
||||
return index + 1
|
||||
|
||||
@classmethod
|
||||
def _split_builtin_and_candidate_tools(
|
||||
cls,
|
||||
tool_specs: List[ToolSpec],
|
||||
) -> tuple[List[ToolSpec], List[ToolSpec]]:
|
||||
"""拆分内置工具与可筛选工具列表。
|
||||
|
||||
Args:
|
||||
tool_specs: 当前全部工具声明。
|
||||
|
||||
Returns:
|
||||
tuple[List[ToolSpec], List[ToolSpec]]: `(内置工具, 可筛选工具)`。
|
||||
"""
|
||||
|
||||
builtin_tool_specs: List[ToolSpec] = []
|
||||
candidate_tool_specs: List[ToolSpec] = []
|
||||
for tool_spec in tool_specs:
|
||||
if cls._is_builtin_tool_spec(tool_spec):
|
||||
builtin_tool_specs.append(tool_spec)
|
||||
else:
|
||||
candidate_tool_specs.append(tool_spec)
|
||||
return builtin_tool_specs, candidate_tool_specs
|
||||
|
||||
@staticmethod
|
||||
def _truncate_tool_filter_text(text: str, max_length: int = 180) -> str:
|
||||
"""截断工具筛选阶段展示的文本。
|
||||
|
||||
Args:
|
||||
text: 原始文本。
|
||||
max_length: 最长保留字符数。
|
||||
|
||||
Returns:
|
||||
str: 截断后的文本。
|
||||
"""
|
||||
|
||||
normalized_text = text.strip()
|
||||
if len(normalized_text) <= max_length:
|
||||
return normalized_text
|
||||
return f"{normalized_text[: max_length - 1]}…"
|
||||
|
||||
def _build_tool_filter_prompt(
|
||||
self,
|
||||
selected_history: List[LLMContextMessage],
|
||||
candidate_tool_specs: List[ToolSpec],
|
||||
max_keep: int,
|
||||
) -> str:
|
||||
"""构造小模型工具预筛选提示词。
|
||||
|
||||
Args:
|
||||
selected_history: 已选中的对话上下文。
|
||||
candidate_tool_specs: 非内置候选工具列表。
|
||||
max_keep: 最多保留的候选工具数量。
|
||||
|
||||
Returns:
|
||||
str: 用于工具预筛的小模型提示词。
|
||||
"""
|
||||
|
||||
history_lines: List[str] = []
|
||||
for message in selected_history[-10:]:
|
||||
plain_text = message.processed_plain_text.strip()
|
||||
if not plain_text:
|
||||
continue
|
||||
history_lines.append(
|
||||
f"- {message.role}: {self._truncate_tool_filter_text(plain_text, max_length=200)}"
|
||||
)
|
||||
|
||||
if history_lines:
|
||||
history_section = "\n".join(history_lines)
|
||||
else:
|
||||
history_section = "- 当前没有可用的对话上下文。"
|
||||
|
||||
tool_lines = [
|
||||
f"- {tool_spec.name}: {tool_spec.brief_description.strip() or '无简要描述'}"
|
||||
for tool_spec in candidate_tool_specs
|
||||
]
|
||||
tool_section = "\n".join(tool_lines) if tool_lines else "- 当前没有候选工具。"
|
||||
|
||||
return (
|
||||
"你是 Maisaka 的工具预筛选器。\n"
|
||||
"你的任务是在正式进入 planner 前,根据当前情景从候选工具中挑出最可能马上会用到的工具。\n"
|
||||
"默认内置工具已经自动保留,不在候选列表中,你不需要再次选择它们。\n"
|
||||
"你只能参考工具的简要描述,不要假设未描述的隐藏能力。\n"
|
||||
f"最多保留 {max_keep} 个候选工具;如果都不合适,可以返回空数组。\n"
|
||||
"请严格返回 JSON 对象,格式为:"
|
||||
'{"selected_tool_names":["工具名1","工具名2"]}\n\n'
|
||||
f"【最近对话】\n{history_section}\n\n"
|
||||
f"【候选工具(仅简要描述)】\n{tool_section}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_filter_response(
|
||||
response_text: str,
|
||||
candidate_tool_specs: List[ToolSpec],
|
||||
max_keep: int,
|
||||
) -> List[ToolSpec] | None:
|
||||
"""解析工具预筛选响应。
|
||||
|
||||
Args:
|
||||
response_text: 小模型返回的原始文本。
|
||||
candidate_tool_specs: 非内置候选工具列表。
|
||||
max_keep: 最多保留的候选工具数量。
|
||||
|
||||
Returns:
|
||||
List[ToolSpec] | None: 成功解析时返回筛选后的工具列表;解析失败时返回 ``None``。
|
||||
"""
|
||||
|
||||
normalized_response = response_text.strip()
|
||||
if not normalized_response:
|
||||
return None
|
||||
|
||||
selected_tool_names: List[str]
|
||||
try:
|
||||
selected_tool_names = ToolFilterSelection.model_validate_json(normalized_response).selected_tool_names
|
||||
except Exception:
|
||||
try:
|
||||
parsed_payload = json.loads(normalized_response)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
if isinstance(parsed_payload, dict):
|
||||
raw_tool_names = parsed_payload.get("selected_tool_names", [])
|
||||
elif isinstance(parsed_payload, list):
|
||||
raw_tool_names = parsed_payload
|
||||
else:
|
||||
return None
|
||||
|
||||
if not isinstance(raw_tool_names, list):
|
||||
return None
|
||||
|
||||
selected_tool_names = []
|
||||
for item in raw_tool_names:
|
||||
normalized_name = str(item).strip()
|
||||
if normalized_name:
|
||||
selected_tool_names.append(normalized_name)
|
||||
|
||||
candidate_map = {tool_spec.name: tool_spec for tool_spec in candidate_tool_specs}
|
||||
filtered_tool_specs: List[ToolSpec] = []
|
||||
seen_names: set[str] = set()
|
||||
for tool_name in selected_tool_names:
|
||||
normalized_name = tool_name.strip()
|
||||
if not normalized_name or normalized_name in seen_names:
|
||||
continue
|
||||
tool_spec = candidate_map.get(normalized_name)
|
||||
if tool_spec is None:
|
||||
continue
|
||||
|
||||
seen_names.add(normalized_name)
|
||||
filtered_tool_specs.append(tool_spec)
|
||||
if len(filtered_tool_specs) >= max_keep:
|
||||
break
|
||||
|
||||
return filtered_tool_specs
|
||||
|
||||
async def _filter_tool_specs_for_planner(
|
||||
self,
|
||||
selected_history: List[LLMContextMessage],
|
||||
tool_specs: List[ToolSpec],
|
||||
) -> List[ToolSpec]:
|
||||
"""在将工具交给 planner 前进行快速预筛选。
|
||||
|
||||
Args:
|
||||
selected_history: 已选中的对话上下文。
|
||||
tool_specs: 当前全部可用工具声明。
|
||||
|
||||
Returns:
|
||||
List[ToolSpec]: 最终交给 planner 的工具声明列表。
|
||||
"""
|
||||
|
||||
threshold = max(1, int(global_config.maisaka.tool_filter_threshold))
|
||||
max_keep = max(1, int(global_config.maisaka.tool_filter_max_keep))
|
||||
if len(tool_specs) <= threshold:
|
||||
return tool_specs
|
||||
|
||||
builtin_tool_specs, candidate_tool_specs = self._split_builtin_and_candidate_tools(tool_specs)
|
||||
if not candidate_tool_specs:
|
||||
return tool_specs
|
||||
if len(candidate_tool_specs) <= max_keep:
|
||||
return [*builtin_tool_specs, *candidate_tool_specs]
|
||||
|
||||
filter_prompt = self._build_tool_filter_prompt(selected_history, candidate_tool_specs, max_keep)
|
||||
logger.info(
|
||||
"工具预筛选开始: "
|
||||
f"总工具数={len(tool_specs)} "
|
||||
f"内置工具数={len(builtin_tool_specs)} "
|
||||
f"候选工具数={len(candidate_tool_specs)} "
|
||||
f"最多保留候选数={max_keep}"
|
||||
)
|
||||
|
||||
try:
|
||||
generation_result = await self._tool_filter_llm.generate_response(
|
||||
prompt=filter_prompt,
|
||||
options=LLMGenerationOptions(
|
||||
temperature=0.0,
|
||||
max_tokens=256,
|
||||
response_format=RespFormat(
|
||||
format_type=RespFormatType.JSON_SCHEMA,
|
||||
schema=ToolFilterSelection,
|
||||
),
|
||||
),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"工具预筛选失败,保留全部工具。错误={exc}")
|
||||
return tool_specs
|
||||
|
||||
filtered_candidate_tool_specs = self._parse_tool_filter_response(
|
||||
generation_result.response or "",
|
||||
candidate_tool_specs,
|
||||
max_keep,
|
||||
)
|
||||
if filtered_candidate_tool_specs is None:
|
||||
logger.warning(
|
||||
"工具预筛选返回结果无法解析,保留全部工具。"
|
||||
f" 原始返回={generation_result.response or ''!r}"
|
||||
)
|
||||
return tool_specs
|
||||
|
||||
filtered_tool_specs = [*builtin_tool_specs, *filtered_candidate_tool_specs]
|
||||
if not filtered_tool_specs:
|
||||
logger.warning("工具预筛选得到空结果,保留全部工具以避免主流程失去工具能力。")
|
||||
return tool_specs
|
||||
|
||||
logger.info(
|
||||
"工具预筛选完成: "
|
||||
f"筛选前总数={len(tool_specs)} "
|
||||
f"筛选后总数={len(filtered_tool_specs)} "
|
||||
f"保留候选工具={[tool_spec.name for tool_spec in filtered_candidate_tool_specs]}"
|
||||
)
|
||||
return filtered_tool_specs
|
||||
if messages and messages[0].role == RoleType.System:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def chat_loop_step(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
*,
|
||||
injected_user_messages: Sequence[str] | None = None,
|
||||
request_kind: str = "planner",
|
||||
response_format: RespFormat | None = None,
|
||||
tool_definitions: Sequence[ToolDefinitionInput] | None = None,
|
||||
@@ -683,8 +475,14 @@ class MaisakaChatLoopService:
|
||||
|
||||
if not self._prompts_loaded:
|
||||
await self.ensure_chat_prompt_loaded()
|
||||
selected_history, selection_reason = self.select_llm_context_messages(chat_history)
|
||||
built_messages = self._build_request_messages(selected_history)
|
||||
selected_history, selection_reason = self.select_llm_context_messages(
|
||||
chat_history,
|
||||
request_kind=request_kind,
|
||||
)
|
||||
built_messages = self._build_request_messages(
|
||||
selected_history,
|
||||
injected_user_messages=injected_user_messages,
|
||||
)
|
||||
|
||||
def message_factory(_client: BaseClient) -> List[Message]:
|
||||
"""返回当前轮次已经构建好的请求消息。
|
||||
@@ -704,8 +502,7 @@ class MaisakaChatLoopService:
|
||||
all_tools = list(tool_definitions)
|
||||
elif self._tool_registry is not None:
|
||||
tool_specs = await self._tool_registry.list_tools()
|
||||
filtered_tool_specs = await self._filter_tool_specs_for_planner(selected_history, tool_specs)
|
||||
all_tools = [tool_spec.to_llm_definition() for tool_spec in filtered_tool_specs]
|
||||
all_tools = [tool_spec.to_llm_definition() for tool_spec in tool_specs]
|
||||
else:
|
||||
all_tools = [*get_builtin_tools(), *self._extra_tools]
|
||||
|
||||
@@ -740,15 +537,9 @@ class MaisakaChatLoopService:
|
||||
selection_reason=selection_reason,
|
||||
image_display_mode=image_display_mode,
|
||||
folded=global_config.debug.fold_maisaka_thinking,
|
||||
tool_definitions=list(all_tools),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"规划器请求开始: "
|
||||
f"已选上下文消息数={len(selected_history)} "
|
||||
f"大模型消息数={len(built_messages)} "
|
||||
f"工具数={len(all_tools)} "
|
||||
f"启用打断={self._interrupt_flag is not None}"
|
||||
)
|
||||
generation_result = await self._llm_chat.generate_response_with_messages(
|
||||
message_factory=message_factory,
|
||||
options=LLMGenerationOptions(
|
||||
@@ -760,15 +551,6 @@ class MaisakaChatLoopService:
|
||||
),
|
||||
)
|
||||
|
||||
prompt_stats_text = PromptCLIVisualizer.build_prompt_stats_text(
|
||||
selected_history_count=len(selected_history),
|
||||
built_message_count=len(built_messages),
|
||||
prompt_tokens=generation_result.prompt_tokens,
|
||||
completion_tokens=generation_result.completion_tokens,
|
||||
total_tokens=generation_result.total_tokens,
|
||||
)
|
||||
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
|
||||
|
||||
final_response = generation_result.response or ""
|
||||
final_tool_calls = list(generation_result.tool_calls or [])
|
||||
after_response_result = await self._get_runtime_manager().invoke_hook(
|
||||
@@ -822,16 +604,21 @@ class MaisakaChatLoopService:
|
||||
def select_llm_context_messages(
|
||||
chat_history: List[LLMContextMessage],
|
||||
*,
|
||||
request_kind: str = "planner",
|
||||
max_context_size: Optional[int] = None,
|
||||
) -> tuple[List[LLMContextMessage], str]:
|
||||
"""??????? LLM ???????"""
|
||||
"""选择LLM上下文消息"""
|
||||
|
||||
filtered_history = MaisakaChatLoopService._filter_history_for_request_kind(
|
||||
chat_history,
|
||||
request_kind=request_kind,
|
||||
)
|
||||
effective_context_size = max(1, int(max_context_size or global_config.chat.max_context_size))
|
||||
selected_indices: List[int] = []
|
||||
counted_message_count = 0
|
||||
|
||||
for index in range(len(chat_history) - 1, -1, -1):
|
||||
message = chat_history[index]
|
||||
for index in range(len(filtered_history) - 1, -1, -1):
|
||||
message = filtered_history[index]
|
||||
if message.to_llm_message() is None:
|
||||
continue
|
||||
|
||||
@@ -842,10 +629,10 @@ class MaisakaChatLoopService:
|
||||
break
|
||||
|
||||
if not selected_indices:
|
||||
return [], f"???????? {effective_context_size} ? user/assistant??? 0 ??"
|
||||
return [], f"没有选择到上下文消息,实际发送 {effective_context_size} 条 user/assistant 消息"
|
||||
|
||||
selected_indices.reverse()
|
||||
selected_history = [chat_history[index] for index in selected_indices]
|
||||
selected_history = [filtered_history[index] for index in selected_indices]
|
||||
selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
|
||||
selected_history, _ = drop_orphan_tool_results(selected_history)
|
||||
selection_reason = (
|
||||
@@ -860,45 +647,43 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _select_llm_context_messages(chat_history: List[LLMContextMessage]) -> tuple[List[LLMContextMessage], str]:
|
||||
"""选择真正发送给 LLM 的上下文消息。
|
||||
def _filter_history_for_request_kind(
|
||||
selected_history: List[LLMContextMessage],
|
||||
*,
|
||||
request_kind: str,
|
||||
) -> List[LLMContextMessage]:
|
||||
"""按请求类型过滤不应暴露的历史工具链。"""
|
||||
|
||||
Args:
|
||||
chat_history: 当前全部对话历史。
|
||||
if request_kind != "planner":
|
||||
return selected_history
|
||||
|
||||
Returns:
|
||||
tuple[List[LLMContextMessage], str]: `(已选上下文, 选择说明)`。
|
||||
"""
|
||||
|
||||
max_context_size = max(1, int(global_config.chat.max_context_size))
|
||||
selected_indices: List[int] = []
|
||||
counted_message_count = 0
|
||||
|
||||
for index in range(len(chat_history) - 1, -1, -1):
|
||||
message = chat_history[index]
|
||||
if message.to_llm_message() is None:
|
||||
filtered_history: List[LLMContextMessage] = []
|
||||
for message in selected_history:
|
||||
if isinstance(message, ToolResultMessage) and message.tool_name in TIMING_GATE_TOOL_NAMES:
|
||||
continue
|
||||
|
||||
selected_indices.append(index)
|
||||
if message.count_in_context:
|
||||
counted_message_count += 1
|
||||
if counted_message_count >= max_context_size:
|
||||
break
|
||||
if isinstance(message, AssistantMessage) and message.tool_calls:
|
||||
kept_tool_calls = [
|
||||
tool_call
|
||||
for tool_call in message.tool_calls
|
||||
if tool_call.func_name not in TIMING_GATE_TOOL_NAMES
|
||||
]
|
||||
if not kept_tool_calls:
|
||||
continue
|
||||
if len(kept_tool_calls) != len(message.tool_calls):
|
||||
filtered_history.append(
|
||||
AssistantMessage(
|
||||
content=message.content,
|
||||
timestamp=message.timestamp,
|
||||
tool_calls=kept_tool_calls,
|
||||
source_kind=message.source_kind,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if not selected_indices:
|
||||
return [], f"上下文判定:最近 {max_context_size} 条 user/assistant(当前 0 条)"
|
||||
filtered_history.append(message)
|
||||
|
||||
selected_indices.reverse()
|
||||
selected_history = [chat_history[index] for index in selected_indices]
|
||||
selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
|
||||
selected_history, _ = drop_orphan_tool_results(selected_history)
|
||||
return (
|
||||
selected_history,
|
||||
(
|
||||
f"上下文判定:最近 {max_context_size} 条 user/assistant;"
|
||||
f"展示并发送窗口内消息 {len(selected_history)} 条"
|
||||
),
|
||||
)
|
||||
return filtered_history
|
||||
|
||||
@staticmethod
|
||||
def _hide_early_assistant_messages(
|
||||
|
||||
@@ -51,7 +51,9 @@ def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent)
|
||||
if component.content:
|
||||
builder.add_text_content(component.content)
|
||||
return True
|
||||
return False
|
||||
|
||||
builder.add_text_content("[表情包]")
|
||||
return True
|
||||
|
||||
|
||||
def _append_image_component(builder: MessageBuilder, component: ImageComponent) -> bool:
|
||||
@@ -65,7 +67,9 @@ def _append_image_component(builder: MessageBuilder, component: ImageComponent)
|
||||
if component.content:
|
||||
builder.add_text_content(component.content)
|
||||
return True
|
||||
return False
|
||||
|
||||
builder.add_text_content("[图片]")
|
||||
return True
|
||||
|
||||
|
||||
def _append_reply_component(builder: MessageBuilder, component: ReplyComponent) -> bool:
|
||||
|
||||
33
src/maisaka/display/__init__.py
Normal file
33
src/maisaka/display/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Maisaka 展示模块。"""
|
||||
|
||||
from .display_utils import (
|
||||
build_tool_call_summary_lines,
|
||||
format_token_count,
|
||||
format_tool_call_for_display,
|
||||
get_request_panel_style,
|
||||
get_role_badge_label,
|
||||
get_role_badge_style,
|
||||
)
|
||||
from .prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .prompt_preview_logger import PromptPreviewLogger
|
||||
from .stage_status_board import (
|
||||
disable_stage_status_board,
|
||||
enable_stage_status_board,
|
||||
remove_stage_status,
|
||||
update_stage_status,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PromptCLIVisualizer",
|
||||
"PromptPreviewLogger",
|
||||
"build_tool_call_summary_lines",
|
||||
"disable_stage_status_board",
|
||||
"enable_stage_status_board",
|
||||
"format_token_count",
|
||||
"format_tool_call_for_display",
|
||||
"get_request_panel_style",
|
||||
"get_role_badge_label",
|
||||
"get_role_badge_style",
|
||||
"remove_stage_status",
|
||||
"update_stage_status",
|
||||
]
|
||||
@@ -4,14 +4,15 @@ from typing import Any
|
||||
|
||||
|
||||
_REQUEST_PANEL_STYLE_MAP: dict[str, tuple[str, str]] = {
|
||||
"timing_gate": ("\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u5927\u6a21\u578b\u8bf7\u6c42 - Timing Gate \u5b50\u4ee3\u7406", "bright_magenta"),
|
||||
"replyer": ("\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u56de\u590d\u5668 Prompt", "bright_yellow"),
|
||||
"planner": ("MaiSaka 大模型请求 - 对话单步", "green"),
|
||||
"timing_gate": ("MaiSaka 大模型请求 - Timing Gate 子代理", "bright_magenta"),
|
||||
"replyer": ("MaiSaka 回复器 Prompt", "bright_yellow"),
|
||||
"emotion": ("MaiSaka Emotion Tool Prompt", "bright_cyan"),
|
||||
"sub_agent": ("\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u5927\u6a21\u578b\u8bf7\u6c42 - \u5b50\u4ee3\u7406", "bright_blue"),
|
||||
"sub_agent": ("MaiSaka 大模型请求 - 子代理", "bright_blue"),
|
||||
}
|
||||
|
||||
_DEFAULT_REQUEST_PANEL_STYLE: tuple[str, str] = (
|
||||
"\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u5927\u6a21\u578b\u8bf7\u6c42 - \u5bf9\u8bdd\u5355\u6b65",
|
||||
"MaiSaka 大模型请求 - 对话单步",
|
||||
"cyan",
|
||||
)
|
||||
|
||||
@@ -23,10 +24,10 @@ _ROLE_BADGE_STYLE_MAP: dict[str, str] = {
|
||||
}
|
||||
|
||||
_ROLE_BADGE_LABEL_MAP: dict[str, str] = {
|
||||
"system": "\u7cfb\u7edf",
|
||||
"user": "\u7528\u6237",
|
||||
"assistant": "\u52a9\u624b",
|
||||
"tool": "\u5de5\u5177",
|
||||
"system": "系统",
|
||||
"user": "用户",
|
||||
"assistant": "助手",
|
||||
"tool": "工具",
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +55,7 @@ def get_role_badge_style(role: str) -> str:
|
||||
def get_role_badge_label(role: str) -> str:
|
||||
"""返回角色标签对应的展示文案。"""
|
||||
|
||||
return _ROLE_BADGE_LABEL_MAP.get(role, "\u672a\u77e5")
|
||||
return _ROLE_BADGE_LABEL_MAP.get(role, "未知")
|
||||
|
||||
|
||||
def format_tool_call_for_display(tool_call: Any) -> dict[str, Any]:
|
||||
58
src/maisaka/display/preview_path_utils.py
Normal file
58
src/maisaka/display/preview_path_utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Maisaka Prompt 预览路径工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
import re
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||
SAFE_NAME_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
|
||||
|
||||
|
||||
def normalize_preview_name(value: str) -> str:
|
||||
normalized_value = SAFE_NAME_PATTERN.sub("_", str(value or "").strip()).strip("._")
|
||||
if normalized_value:
|
||||
return normalized_value
|
||||
return "unknown"
|
||||
|
||||
|
||||
def normalize_platform_name(platform: str) -> str:
|
||||
normalized_platform = str(platform or "").strip().lower()
|
||||
platform_aliases = {
|
||||
"telegram": "tg",
|
||||
}
|
||||
return normalize_preview_name(platform_aliases.get(normalized_platform, normalized_platform))
|
||||
|
||||
|
||||
def build_preview_chat_dir_name(chat_id: str) -> str:
|
||||
session = chat_manager.get_session_by_session_id(chat_id)
|
||||
if session is not None:
|
||||
platform = normalize_platform_name(session.platform)
|
||||
if session.is_group_session and session.group_id:
|
||||
return f"{platform}_group_{normalize_preview_name(session.group_id)}"
|
||||
if session.user_id:
|
||||
return f"{platform}_private_{normalize_preview_name(session.user_id)}"
|
||||
|
||||
normalized_chat_id = normalize_preview_name(chat_id)
|
||||
if normalized_chat_id != "unknown":
|
||||
return normalized_chat_id
|
||||
return "unknown_chat"
|
||||
|
||||
|
||||
def build_display_path(file_path: Path) -> str:
|
||||
"""构造用于展示的路径,项目内文件优先显示相对路径。"""
|
||||
resolved_path = file_path.resolve()
|
||||
try:
|
||||
return resolved_path.relative_to(REPO_ROOT).as_posix()
|
||||
except ValueError:
|
||||
return resolved_path.as_posix()
|
||||
|
||||
|
||||
def build_file_uri(file_path: Path) -> str:
|
||||
normalized = file_path.resolve().as_posix()
|
||||
return f"file:///{quote(normalized, safe='/:')}"
|
||||
@@ -7,7 +7,6 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal
|
||||
from urllib.parse import quote
|
||||
|
||||
import hashlib
|
||||
import html
|
||||
@@ -27,10 +26,10 @@ from .display_utils import (
|
||||
get_role_badge_label as get_shared_role_badge_label,
|
||||
get_role_badge_style as get_shared_role_badge_style,
|
||||
)
|
||||
from .preview_path_utils import build_display_path, build_file_uri, REPO_ROOT
|
||||
from .prompt_preview_logger import PromptPreviewLogger
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
DATA_IMAGE_DIR = PROJECT_ROOT / "data" / "images"
|
||||
DATA_IMAGE_DIR = REPO_ROOT / "data" / "images"
|
||||
|
||||
|
||||
class PromptImageDisplayMode(str, Enum):
|
||||
@@ -115,11 +114,6 @@ class PromptCLIVisualizer:
|
||||
digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest()
|
||||
return root / f"{digest}.{image_format}"
|
||||
|
||||
@staticmethod
|
||||
def _build_file_uri(file_path: Path) -> str:
|
||||
normalized = file_path.resolve().as_posix()
|
||||
return f"file:///{quote(normalized, safe='/:')}"
|
||||
|
||||
@staticmethod
|
||||
def _build_official_image_path(image_format: str, image_base64: str) -> Path | None:
|
||||
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format)
|
||||
@@ -140,7 +134,7 @@ class PromptCLIVisualizer:
|
||||
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format) or "bin"
|
||||
official_path = PromptCLIVisualizer._build_official_image_path(image_format, image_base64)
|
||||
if official_path is not None:
|
||||
return PromptCLIVisualizer._build_file_uri(official_path), official_path
|
||||
return build_file_uri(official_path), official_path
|
||||
|
||||
try:
|
||||
image_bytes = b64decode(image_base64)
|
||||
@@ -153,7 +147,7 @@ class PromptCLIVisualizer:
|
||||
path.write_bytes(image_bytes)
|
||||
except Exception:
|
||||
return None
|
||||
return PromptCLIVisualizer._build_file_uri(path), path
|
||||
return build_file_uri(path), path
|
||||
|
||||
@classmethod
|
||||
def _render_image_item(cls, image_format: str, image_base64: str, settings: PromptImageDisplaySettings) -> Panel:
|
||||
@@ -169,8 +163,9 @@ class PromptCLIVisualizer:
|
||||
path_result = cls._build_image_file_link(image_format, image_base64)
|
||||
if path_result is not None:
|
||||
file_uri, file_path = path_result
|
||||
display_path = build_display_path(file_path)
|
||||
preview_parts: List[RenderableType] = [
|
||||
Text(f"图片格式 image/{normalized_format} {size_text} 路径:{file_path}", style="magenta")
|
||||
Text(f"图片格式 image/{normalized_format} {size_text} 路径:{display_path}", style="magenta")
|
||||
]
|
||||
|
||||
preview_parts.append(Text.from_markup(f"[link={file_uri}]点击打开图片[/link]", style="cyan"))
|
||||
@@ -181,6 +176,16 @@ class PromptCLIVisualizer:
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_pair(item: Any) -> tuple[str, str] | None:
|
||||
"""兼容图片片段被序列化为 tuple 或 list 的两种形式。"""
|
||||
|
||||
if isinstance(item, (tuple, list)) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
return image_format, image_base64
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _render_message_content(cls, content: Any, settings: PromptImageDisplaySettings) -> RenderableType:
|
||||
if isinstance(content, str):
|
||||
@@ -192,11 +197,11 @@ class PromptCLIVisualizer:
|
||||
if isinstance(item, str):
|
||||
parts.append(Text(item))
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
parts.append(cls._render_image_item(image_format, image_base64, settings))
|
||||
continue
|
||||
image_pair = cls._extract_image_pair(item)
|
||||
if image_pair is not None:
|
||||
image_format, image_base64 = image_pair
|
||||
parts.append(cls._render_image_item(image_format, image_base64, settings))
|
||||
continue
|
||||
if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
|
||||
parts.append(Text(item["text"]))
|
||||
else:
|
||||
@@ -218,8 +223,9 @@ class PromptCLIVisualizer:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
image_pair = cls._extract_image_pair(item)
|
||||
if image_pair is not None:
|
||||
image_format, image_base64 = image_pair
|
||||
approx_size = max(0, len(str(image_base64)) * 3 // 4)
|
||||
parts.append(f"[图片 image/{image_format} {approx_size} B]")
|
||||
continue
|
||||
@@ -242,6 +248,85 @@ class PromptCLIVisualizer:
|
||||
def format_tool_call_for_display(cls, tool_call: Any) -> Dict[str, Any]:
|
||||
return normalize_tool_call_for_display(tool_call)
|
||||
|
||||
@classmethod
|
||||
def _build_tool_card_title(cls, tool_call: Any) -> str:
|
||||
"""构建 HTML 中工具卡片的折叠标题。"""
|
||||
|
||||
normalized_tool_call = cls.format_tool_call_for_display(tool_call)
|
||||
tool_name = str(normalized_tool_call.get("name") or "").strip()
|
||||
return tool_name or "unknown"
|
||||
|
||||
@classmethod
|
||||
def _build_tool_call_html(cls, tool_call: Any) -> str:
|
||||
"""将单个工具调用渲染为默认折叠的 HTML 卡片。"""
|
||||
|
||||
normalized_tool_call = cls.format_tool_call_for_display(tool_call)
|
||||
tool_name = cls._build_tool_card_title(tool_call)
|
||||
tool_call_id = str(normalized_tool_call.get("id") or "").strip()
|
||||
tool_arguments = normalized_tool_call.get("arguments")
|
||||
|
||||
tool_meta_html = ""
|
||||
if tool_call_id:
|
||||
tool_meta_html = (
|
||||
"<div class='tool-card-meta'>"
|
||||
"<span class='tool-card-meta-label'>调用 ID</span>"
|
||||
f"<code>{html.escape(tool_call_id)}</code>"
|
||||
"</div>"
|
||||
)
|
||||
|
||||
return (
|
||||
"<details class='tool-card tool-call-card'>"
|
||||
"<summary class='tool-card-summary'>"
|
||||
f"<span class='tool-card-name'>{html.escape(tool_name)}</span>"
|
||||
"</summary>"
|
||||
"<div class='tool-card-body'>"
|
||||
f"{tool_meta_html}"
|
||||
f"<pre>{html.escape(json.dumps(tool_arguments, ensure_ascii=False, indent=2, default=str))}</pre>"
|
||||
"</div>"
|
||||
"</details>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_tool_definition_fields(cls, tool_definition: dict[str, Any]) -> tuple[str, str, Any]:
|
||||
"""提取工具定义中的名称、描述和详情内容。"""
|
||||
|
||||
function_info = tool_definition.get("function")
|
||||
if isinstance(function_info, dict):
|
||||
tool_name = str(function_info.get("name") or "").strip() or "unknown"
|
||||
description = str(function_info.get("description") or "").strip()
|
||||
detail_payload = function_info
|
||||
else:
|
||||
tool_name = str(tool_definition.get("name") or "").strip() or "unknown"
|
||||
description = str(tool_definition.get("description") or "").strip()
|
||||
detail_payload = tool_definition
|
||||
return tool_name, description, detail_payload
|
||||
|
||||
@classmethod
|
||||
def _build_tool_definition_html(cls, tool_definition: dict[str, Any]) -> str:
|
||||
"""将单个传入工具定义渲染为默认折叠的 HTML 卡片。"""
|
||||
|
||||
tool_name, description, detail_payload = cls._extract_tool_definition_fields(tool_definition)
|
||||
description_html = ""
|
||||
if description:
|
||||
description_html = (
|
||||
"<div class='tool-card-meta'>"
|
||||
"<span class='tool-card-meta-label'>说明</span>"
|
||||
f"<span>{html.escape(description)}</span>"
|
||||
"</div>"
|
||||
)
|
||||
|
||||
return (
|
||||
"<details class='tool-card tool-definition-card'>"
|
||||
"<summary class='tool-card-summary'>"
|
||||
f"<span class='tool-card-name'>{html.escape(tool_name)}</span>"
|
||||
"</summary>"
|
||||
"<div class='tool-card-body'>"
|
||||
f"{description_html}"
|
||||
f"<pre>{html.escape(json.dumps(detail_payload, ensure_ascii=False, indent=2, default=str))}</pre>"
|
||||
"</div>"
|
||||
"</details>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _render_tool_call_panel(cls, tool_call: Any, index: int, parent_index: int) -> Panel:
|
||||
title = Text.assemble(
|
||||
@@ -291,6 +376,20 @@ class PromptCLIVisualizer:
|
||||
|
||||
return "\n\n" + ("\n\n" + ("=" * 80) + "\n\n").join(sections) if sections else "[空 Prompt]"
|
||||
|
||||
@classmethod
|
||||
def _build_tool_definition_dump_text(cls, tool_definitions: list[dict[str, Any]] | None) -> str:
|
||||
"""构建传入工具定义的文本备份内容。"""
|
||||
|
||||
if not tool_definitions:
|
||||
return ""
|
||||
|
||||
sections: List[str] = ["[tool_definitions]"]
|
||||
for index, tool_definition in enumerate(tool_definitions, start=1):
|
||||
tool_name, _, detail_payload = cls._extract_tool_definition_fields(tool_definition)
|
||||
sections.append(f"[{index}] name={tool_name}")
|
||||
sections.append(json.dumps(detail_payload, ensure_ascii=False, indent=2, default=str))
|
||||
return "\n\n".join(sections).strip()
|
||||
|
||||
@classmethod
|
||||
def _render_message_content_html(cls, content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
@@ -302,8 +401,9 @@ class PromptCLIVisualizer:
|
||||
if isinstance(item, str):
|
||||
parts.append(f"<pre>{html.escape(item)}</pre>")
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
image_pair = cls._extract_image_pair(item)
|
||||
if image_pair is not None:
|
||||
image_format, image_base64 = image_pair
|
||||
image_html = cls._render_image_item_html(str(image_format), str(image_base64))
|
||||
parts.append(image_html)
|
||||
continue
|
||||
@@ -332,14 +432,44 @@ class PromptCLIVisualizer:
|
||||
)
|
||||
|
||||
file_uri, file_path = path_result
|
||||
display_path = build_display_path(file_path)
|
||||
return (
|
||||
"<div class='image-card'>"
|
||||
f"<div class='image-meta'>图片 image/{html.escape(normalized_format)} {html.escape(size_text)}</div>"
|
||||
f"<div class='image-path'>{html.escape(str(file_path))}</div>"
|
||||
f"<a class='image-preview-link' href='{html.escape(file_uri, quote=True)}'>"
|
||||
f"<img class='image-preview' src='{html.escape(file_uri, quote=True)}' alt='图片预览' />"
|
||||
"</a>"
|
||||
f"<div class='image-path'>{html.escape(display_path)}</div>"
|
||||
f"<a class='image-link' href='{html.escape(file_uri, quote=True)}'>打开图片</a>"
|
||||
"</div>"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_preview_access_body(
|
||||
*,
|
||||
viewer_label: str,
|
||||
viewer_path: Path,
|
||||
viewer_link_text: str,
|
||||
dump_label: str,
|
||||
dump_path: Path,
|
||||
dump_link_text: str,
|
||||
) -> RenderableType:
|
||||
viewer_uri = build_file_uri(viewer_path)
|
||||
dump_uri = build_file_uri(dump_path)
|
||||
viewer_display_path = build_display_path(viewer_path)
|
||||
dump_display_path = build_display_path(dump_path)
|
||||
|
||||
return Group(
|
||||
Text.from_markup(
|
||||
f"[bold green]{viewer_label}:{viewer_display_path}[/bold green] "
|
||||
f"[link={viewer_uri}]{viewer_link_text}[/link]"
|
||||
),
|
||||
Text.from_markup(
|
||||
f"[magenta]{dump_label}:{dump_display_path}[/magenta] "
|
||||
f"[cyan][link={dump_uri}]{dump_link_text}[/link][/cyan]"
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _build_html_role_class(cls, role: str) -> str:
|
||||
return {
|
||||
@@ -356,6 +486,7 @@ class PromptCLIVisualizer:
|
||||
*,
|
||||
request_kind: str,
|
||||
selection_reason: str,
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> str:
|
||||
panel_title, _ = cls.get_request_panel_style(request_kind)
|
||||
message_cards: List[str] = []
|
||||
@@ -378,16 +509,12 @@ class PromptCLIVisualizer:
|
||||
tool_panels = ""
|
||||
raw_tool_calls = message.get("tool_calls") or []
|
||||
if isinstance(raw_tool_calls, list) and raw_tool_calls:
|
||||
tool_items = []
|
||||
for tool_call_index, tool_call in enumerate(raw_tool_calls, start=1):
|
||||
normalized_tool_call = cls.format_tool_call_for_display(tool_call)
|
||||
tool_items.append(
|
||||
"<div class='tool-panel'>"
|
||||
f"<div class='tool-panel-title'>工具调用 #{index}.{tool_call_index}</div>"
|
||||
f"<pre>{html.escape(json.dumps(normalized_tool_call, ensure_ascii=False, indent=2, default=str))}</pre>"
|
||||
"</div>"
|
||||
)
|
||||
tool_panels = "".join(tool_items)
|
||||
tool_panels = (
|
||||
"<div class='tool-list'>"
|
||||
"<div class='tool-list-title'>工具调用</div>"
|
||||
f"{''.join(cls._build_tool_call_html(tool_call) for tool_call in raw_tool_calls)}"
|
||||
"</div>"
|
||||
)
|
||||
|
||||
message_cards.append(
|
||||
"<section class='message-card'>"
|
||||
@@ -405,6 +532,21 @@ class PromptCLIVisualizer:
|
||||
if selection_reason.strip():
|
||||
subtitle_html = f"<div class='subtitle'>{html.escape(selection_reason)}</div>"
|
||||
|
||||
tool_definition_section_html = ""
|
||||
if tool_definitions:
|
||||
tool_definition_section_html = (
|
||||
"<section class='message-card tool-definition-section'>"
|
||||
"<div class='message-head'>"
|
||||
"<span class='role-badge tool'>全部工具</span>"
|
||||
f"<span class='message-index'>{len(tool_definitions)} 个</span>"
|
||||
"</div>"
|
||||
"<div class='tool-list'>"
|
||||
"<div class='tool-list-title'>本次送入模型的工具定义</div>"
|
||||
f"{''.join(cls._build_tool_definition_html(tool_definition) for tool_definition in tool_definitions)}"
|
||||
"</div>"
|
||||
"</section>"
|
||||
)
|
||||
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
@@ -491,7 +633,7 @@ class PromptCLIVisualizer:
|
||||
font-weight: 600;
|
||||
}}
|
||||
.message-content pre,
|
||||
.tool-panel pre {{
|
||||
.tool-card pre {{
|
||||
margin: 0;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
@@ -517,18 +659,81 @@ class PromptCLIVisualizer:
|
||||
border-radius: 8px;
|
||||
padding: 3px 8px;
|
||||
}}
|
||||
.tool-panel {{
|
||||
.tool-list {{
|
||||
margin-top: 14px;
|
||||
}}
|
||||
.tool-list-title {{
|
||||
color: #86198f;
|
||||
font-size: 13px;
|
||||
font-weight: 800;
|
||||
margin-bottom: 10px;
|
||||
}}
|
||||
.tool-card {{
|
||||
margin-top: 12px;
|
||||
background: #fcf4ff;
|
||||
border: 1px solid #f0d7fb;
|
||||
border-radius: 14px;
|
||||
padding: 12px 14px;
|
||||
overflow: hidden;
|
||||
}}
|
||||
.tool-panel-title {{
|
||||
color: #a21caf;
|
||||
.tool-call-card {{
|
||||
border-color: #ff8700;
|
||||
}}
|
||||
.tool-card:first-of-type {{
|
||||
margin-top: 0;
|
||||
}}
|
||||
.tool-card-summary {{
|
||||
list-style: none;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 12px 14px;
|
||||
color: #86198f;
|
||||
font-size: 13px;
|
||||
font-weight: 800;
|
||||
}}
|
||||
.tool-card-summary::-webkit-details-marker {{
|
||||
display: none;
|
||||
}}
|
||||
.tool-card-summary::after {{
|
||||
content: "展开";
|
||||
color: #a21caf;
|
||||
font-size: 12px;
|
||||
font-weight: 700;
|
||||
margin-bottom: 8px;
|
||||
}}
|
||||
.tool-card[open] .tool-card-summary::after {{
|
||||
content: "收起";
|
||||
}}
|
||||
.tool-card-name {{
|
||||
word-break: break-word;
|
||||
}}
|
||||
.tool-card-body {{
|
||||
border-top: 1px solid #f0d7fb;
|
||||
padding: 12px 14px;
|
||||
background: rgba(255, 255, 255, 0.52);
|
||||
}}
|
||||
.tool-call-card .tool-card-body {{
|
||||
border-top-color: #ff8700;
|
||||
}}
|
||||
.tool-card-meta {{
|
||||
margin-bottom: 10px;
|
||||
color: #a21caf;
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}}
|
||||
.tool-card-meta-label {{
|
||||
font-weight: 700;
|
||||
}}
|
||||
.tool-card-meta code {{
|
||||
background: #faf5ff;
|
||||
border: 1px solid #e9d5ff;
|
||||
border-radius: 8px;
|
||||
padding: 3px 8px;
|
||||
}}
|
||||
.tool-card pre {{
|
||||
color: #3b0764;
|
||||
}}
|
||||
.image-card {{
|
||||
background: #f8fafc;
|
||||
@@ -547,6 +752,22 @@ class PromptCLIVisualizer:
|
||||
font-family: "Cascadia Mono", "JetBrains Mono", "Consolas", monospace;
|
||||
word-break: break-all;
|
||||
}}
|
||||
.image-preview-link {{
|
||||
display: block;
|
||||
margin-top: 10px;
|
||||
}}
|
||||
.image-preview {{
|
||||
display: block;
|
||||
max-width: min(100%, 560px);
|
||||
max-height: 420px;
|
||||
width: auto;
|
||||
height: auto;
|
||||
border-radius: 12px;
|
||||
border: 1px solid #dbe4f0;
|
||||
background: #fff;
|
||||
box-shadow: 0 8px 20px rgba(15, 23, 42, 0.08);
|
||||
object-fit: contain;
|
||||
}}
|
||||
.image-link {{
|
||||
display: inline-block;
|
||||
margin-top: 8px;
|
||||
@@ -564,6 +785,7 @@ class PromptCLIVisualizer:
|
||||
{subtitle_html}
|
||||
</header>
|
||||
{''.join(message_cards)}
|
||||
{tool_definition_section_html}
|
||||
</main>
|
||||
</body>
|
||||
</html>"""
|
||||
@@ -578,6 +800,7 @@ class PromptCLIVisualizer:
|
||||
request_kind: str,
|
||||
selection_reason: str,
|
||||
image_display_mode: Literal["legacy", "path_link"],
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> RenderableType:
|
||||
"""构建用于查看完整 prompt 的折叠入口内容。"""
|
||||
|
||||
@@ -603,10 +826,14 @@ class PromptCLIVisualizer:
|
||||
viewer_messages.append(normalized_message)
|
||||
|
||||
prompt_dump_text = cls._build_prompt_dump_text(messages)
|
||||
tool_definition_dump_text = cls._build_tool_definition_dump_text(tool_definitions)
|
||||
if tool_definition_dump_text:
|
||||
prompt_dump_text = f"{prompt_dump_text}\n\n{'=' * 80}\n\n{tool_definition_dump_text}"
|
||||
viewer_html_text = cls._build_prompt_viewer_html(
|
||||
viewer_messages,
|
||||
request_kind=request_kind,
|
||||
selection_reason=selection_reason,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
saved_paths = PromptPreviewLogger.save_preview_files(
|
||||
chat_id,
|
||||
@@ -618,18 +845,13 @@ class PromptCLIVisualizer:
|
||||
)
|
||||
viewer_html_path = saved_paths[".html"]
|
||||
prompt_dump_path = saved_paths[".txt"]
|
||||
viewer_uri = cls._build_file_uri(viewer_html_path)
|
||||
dump_uri = cls._build_file_uri(prompt_dump_path)
|
||||
|
||||
body = Group(
|
||||
Text.from_markup(
|
||||
f"[bold green]富文本预览:{viewer_html_path}[/bold green] "
|
||||
f"[link={viewer_uri}]点击在浏览器打开富文本 Prompt 视图[/link]"
|
||||
),
|
||||
Text.from_markup(
|
||||
f"[magenta]原始文本备份:{prompt_dump_path}[/magenta] "
|
||||
f"[cyan][link={dump_uri}]点击直接打开 Prompt 文本[/link][/cyan]"
|
||||
),
|
||||
body = cls._build_preview_access_body(
|
||||
viewer_label="html预览",
|
||||
viewer_path=viewer_html_path,
|
||||
viewer_link_text="在浏览器打开 Prompt",
|
||||
dump_label="原始文本",
|
||||
dump_path=prompt_dump_path,
|
||||
dump_link_text="点击打开 Prompt 文本",
|
||||
)
|
||||
return body
|
||||
|
||||
@@ -644,6 +866,7 @@ class PromptCLIVisualizer:
|
||||
selection_reason: str,
|
||||
image_display_mode: Literal["legacy", "path_link"],
|
||||
folded: bool,
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> Panel:
|
||||
"""构建用于嵌入结果面板中的 Prompt 区块。"""
|
||||
|
||||
@@ -656,6 +879,7 @@ class PromptCLIVisualizer:
|
||||
request_kind=request_kind,
|
||||
selection_reason=selection_reason,
|
||||
image_display_mode=image_display_mode,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
else:
|
||||
ordered_panels = cls.build_prompt_panels(
|
||||
@@ -782,18 +1006,13 @@ class PromptCLIVisualizer:
|
||||
)
|
||||
viewer_html_path = saved_paths[".html"]
|
||||
text_dump_path = saved_paths[".txt"]
|
||||
viewer_uri = cls._build_file_uri(viewer_html_path)
|
||||
dump_uri = cls._build_file_uri(text_dump_path)
|
||||
|
||||
body = Group(
|
||||
Text.from_markup(
|
||||
f"[bold green]富文本预览:{viewer_html_path}[/bold green] "
|
||||
f"[link={viewer_uri}]点击在浏览器打开富文本 Prompt 视图[/link]"
|
||||
),
|
||||
Text.from_markup(
|
||||
f"[magenta]原始文本备份:{text_dump_path}[/magenta] "
|
||||
f"[cyan][link={dump_uri}]点击直接打开 Prompt 文本[/link][/cyan]"
|
||||
),
|
||||
body = cls._build_preview_access_body(
|
||||
viewer_label="富文本预览",
|
||||
viewer_path=viewer_html_path,
|
||||
viewer_link_text="点击在浏览器打开富文本 Prompt 视图",
|
||||
dump_label="原始文本备份",
|
||||
dump_path=text_dump_path,
|
||||
dump_link_text="点击直接打开 Prompt 文本",
|
||||
)
|
||||
return body
|
||||
|
||||
@@ -2,34 +2,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from src.config.config import global_config
|
||||
from .preview_path_utils import build_preview_chat_dir_name, normalize_preview_name
|
||||
|
||||
|
||||
class PromptPreviewLogger:
|
||||
"""负责保存 Maisaka Prompt 预览文件并控制目录容量。"""
|
||||
|
||||
_BASE_DIR = Path("logs") / "maisaka_prompt"
|
||||
_MAX_PREVIEW_GROUPS_PER_CHAT = 1024
|
||||
_TRIM_COUNT = 100
|
||||
_SAFE_NAME_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
|
||||
|
||||
@classmethod
|
||||
def _get_max_per_chat(cls) -> int:
|
||||
"""从配置中获取每个聊天流最大保存的预览数量。"""
|
||||
|
||||
return getattr(global_config.chat, "plan_reply_log_max_per_chat", 1000)
|
||||
|
||||
@classmethod
|
||||
def _normalize_chat_id(cls, chat_id: str) -> str:
|
||||
normalized_chat_id = cls._SAFE_NAME_PATTERN.sub("_", str(chat_id or "").strip()).strip("._")
|
||||
if normalized_chat_id:
|
||||
return normalized_chat_id
|
||||
return "unknown_chat"
|
||||
def _build_file_stem(cls, chat_dir: Path) -> str:
|
||||
base_stem = str(int(time.time() * 1000))
|
||||
candidate_stem = base_stem
|
||||
suffix_index = 1
|
||||
while any((chat_dir / f"{candidate_stem}{suffix}").exists() for suffix in (".html", ".txt")):
|
||||
candidate_stem = f"{base_stem}_{suffix_index}"
|
||||
suffix_index += 1
|
||||
return candidate_stem
|
||||
|
||||
@classmethod
|
||||
def save_preview_files(
|
||||
@@ -40,10 +35,10 @@ class PromptPreviewLogger:
|
||||
) -> Dict[str, Path]:
|
||||
"""保存同一份 Prompt 预览的多个文件并执行超量清理。"""
|
||||
|
||||
normalized_category = cls._normalize_chat_id(category)
|
||||
chat_dir = (cls._BASE_DIR / normalized_category / cls._normalize_chat_id(chat_id)).resolve()
|
||||
normalized_category = normalize_preview_name(category)
|
||||
chat_dir = (cls._BASE_DIR / normalized_category / build_preview_chat_dir_name(chat_id)).resolve()
|
||||
chat_dir.mkdir(parents=True, exist_ok=True)
|
||||
stem = f"{int(time.time() * 1000)}_{uuid4().hex[:8]}"
|
||||
stem = cls._build_file_stem(chat_dir)
|
||||
saved_paths: Dict[str, Path] = {}
|
||||
try:
|
||||
for suffix, content in files.items():
|
||||
@@ -65,15 +60,14 @@ class PromptPreviewLogger:
|
||||
continue
|
||||
grouped_files.setdefault(file_path.stem, []).append(file_path)
|
||||
|
||||
max_per_chat = cls._get_max_per_chat()
|
||||
if len(grouped_files) <= max_per_chat:
|
||||
if len(grouped_files) <= cls._MAX_PREVIEW_GROUPS_PER_CHAT:
|
||||
return
|
||||
|
||||
sorted_groups = sorted(
|
||||
grouped_files.items(),
|
||||
key=lambda item: min(path.stat().st_mtime for path in item[1]),
|
||||
)
|
||||
overflow_count = len(grouped_files) - max_per_chat
|
||||
overflow_count = len(grouped_files) - cls._MAX_PREVIEW_GROUPS_PER_CHAT
|
||||
trim_count = min(len(sorted_groups), max(cls._TRIM_COUNT, overflow_count))
|
||||
for _, file_group in sorted_groups[:trim_count]:
|
||||
for old_file in file_group:
|
||||
163
src/maisaka/display/stage_status_board.py
Normal file
163
src/maisaka/display/stage_status_board.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Maisaka 阶段状态看板。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class MaisakaStageStatusBoard:
|
||||
"""维护 Maisaka 阶段状态,并在独立终端中展示。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._enabled = False
|
||||
self._entries: dict[str, dict[str, Any]] = {}
|
||||
self._viewer_process: Optional[subprocess.Popen[Any]] = None
|
||||
self._state_file = Path("temp") / "maisaka_stage_status.json"
|
||||
self._state_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def enable(self) -> None:
|
||||
"""启用阶段状态看板。"""
|
||||
|
||||
with self._lock:
|
||||
if self._enabled:
|
||||
return
|
||||
self._enabled = True
|
||||
self._write_state_locked()
|
||||
self._ensure_viewer_process_locked()
|
||||
|
||||
def disable(self) -> None:
|
||||
"""禁用阶段状态看板。"""
|
||||
|
||||
with self._lock:
|
||||
self._enabled = False
|
||||
self._entries.clear()
|
||||
self._write_state_locked()
|
||||
process = self._viewer_process
|
||||
self._viewer_process = None
|
||||
|
||||
if process is not None and process.poll() is None:
|
||||
try:
|
||||
process.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def update(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
session_name: str,
|
||||
stage: str,
|
||||
detail: str = "",
|
||||
round_text: str = "",
|
||||
agent_state: str = "",
|
||||
) -> None:
|
||||
"""更新一个会话的阶段状态。"""
|
||||
|
||||
with self._lock:
|
||||
if not self._enabled:
|
||||
return
|
||||
now = time.time()
|
||||
current = self._entries.get(session_id, {})
|
||||
previous_stage = str(current.get("stage") or "").strip()
|
||||
stage_started_at = float(current.get("stage_started_at") or now)
|
||||
if previous_stage != stage:
|
||||
stage_started_at = now
|
||||
self._entries[session_id] = {
|
||||
"session_id": session_id,
|
||||
"session_name": session_name,
|
||||
"stage": stage,
|
||||
"detail": detail,
|
||||
"round_text": round_text,
|
||||
"agent_state": agent_state,
|
||||
"stage_started_at": stage_started_at,
|
||||
"updated_at": now,
|
||||
}
|
||||
self._write_state_locked()
|
||||
|
||||
def remove(self, session_id: str) -> None:
|
||||
"""移除一个会话的阶段状态。"""
|
||||
|
||||
with self._lock:
|
||||
if not self._enabled:
|
||||
return
|
||||
self._entries.pop(session_id, None)
|
||||
self._write_state_locked()
|
||||
|
||||
def _write_state_locked(self) -> None:
|
||||
payload = {
|
||||
"enabled": self._enabled,
|
||||
"host_pid": os.getpid(),
|
||||
"updated_at": time.time(),
|
||||
"entries": list(self._entries.values()),
|
||||
}
|
||||
tmp_file = self._state_file.with_suffix(".tmp")
|
||||
tmp_file.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
tmp_file.replace(self._state_file)
|
||||
|
||||
def _ensure_viewer_process_locked(self) -> None:
|
||||
if not sys.platform.startswith("win"):
|
||||
return
|
||||
if self._viewer_process is not None and self._viewer_process.poll() is None:
|
||||
return
|
||||
creationflags = getattr(subprocess, "CREATE_NEW_CONSOLE", 0)
|
||||
viewer_script = Path(__file__).resolve().with_name("stage_status_viewer.py")
|
||||
self._viewer_process = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
str(viewer_script),
|
||||
str(self._state_file.resolve()),
|
||||
],
|
||||
creationflags=creationflags,
|
||||
cwd=str(Path.cwd()),
|
||||
)
|
||||
|
||||
|
||||
_stage_board = MaisakaStageStatusBoard()
|
||||
|
||||
|
||||
def enable_stage_status_board() -> None:
|
||||
"""启用控制台阶段状态看板。"""
|
||||
|
||||
_stage_board.enable()
|
||||
|
||||
|
||||
def disable_stage_status_board() -> None:
|
||||
"""禁用控制台阶段状态看板。"""
|
||||
|
||||
_stage_board.disable()
|
||||
|
||||
|
||||
def update_stage_status(
|
||||
*,
|
||||
session_id: str,
|
||||
session_name: str,
|
||||
stage: str,
|
||||
detail: str = "",
|
||||
round_text: str = "",
|
||||
agent_state: str = "",
|
||||
) -> None:
|
||||
"""更新控制台阶段状态。"""
|
||||
|
||||
_stage_board.update(
|
||||
session_id=session_id,
|
||||
session_name=session_name,
|
||||
stage=stage,
|
||||
detail=detail,
|
||||
round_text=round_text,
|
||||
agent_state=agent_state,
|
||||
)
|
||||
|
||||
|
||||
def remove_stage_status(session_id: str) -> None:
|
||||
"""移除控制台阶段状态。"""
|
||||
|
||||
_stage_board.remove(session_id)
|
||||
93
src/maisaka/display/stage_status_viewer.py
Normal file
93
src/maisaka/display/stage_status_viewer.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Maisaka 阶段状态看板查看器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
|
||||
def _clear_screen() -> None:
|
||||
os.system("cls" if sys.platform.startswith("win") else "clear")
|
||||
|
||||
|
||||
def _load_state(state_file: Path) -> dict[str, Any]:
|
||||
if not state_file.exists():
|
||||
return {}
|
||||
try:
|
||||
return json.loads(state_file.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _render(state: dict[str, Any]) -> str:
|
||||
entries = state.get("entries")
|
||||
if not isinstance(entries, list):
|
||||
entries = []
|
||||
|
||||
lines = ["Maisaka 阶段看板", "=" * 72, ""]
|
||||
if not entries:
|
||||
lines.append("当前没有活跃会话。")
|
||||
return "\n".join(lines)
|
||||
|
||||
entries = sorted(
|
||||
[entry for entry in entries if isinstance(entry, dict)],
|
||||
key=lambda item: str(item.get("session_name") or item.get("session_id") or ""),
|
||||
)
|
||||
now = time.time()
|
||||
for entry in entries:
|
||||
session_name = str(entry.get("session_name") or entry.get("session_id") or "").strip() or "unknown"
|
||||
session_id = str(entry.get("session_id") or "").strip()
|
||||
stage = str(entry.get("stage") or "").strip() or "未知"
|
||||
detail = str(entry.get("detail") or "").strip() or "-"
|
||||
round_text = str(entry.get("round_text") or "").strip()
|
||||
agent_state = str(entry.get("agent_state") or "").strip() or "-"
|
||||
stage_started_at = float(entry.get("stage_started_at") or now)
|
||||
elapsed = max(0.0, now - stage_started_at)
|
||||
|
||||
lines.append(f"Chat: {session_name}")
|
||||
if session_id and session_id != session_name:
|
||||
lines.append(f"ID: {session_id}")
|
||||
lines.append(f"阶段: {stage}")
|
||||
if round_text:
|
||||
lines.append(f"轮次: {round_text}")
|
||||
lines.append(f"详情: {detail}")
|
||||
lines.append(f"状态: {agent_state}")
|
||||
lines.append(f"阶段耗时: {elapsed:.1f}s")
|
||||
lines.append("-" * 72)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
if len(sys.argv) < 2:
|
||||
return 1
|
||||
|
||||
state_file = Path(sys.argv[1]).resolve()
|
||||
log_file = state_file.with_name("maisaka_stage_status_viewer.log")
|
||||
last_render = ""
|
||||
while True:
|
||||
try:
|
||||
state = _load_state(state_file)
|
||||
if not state.get("enabled", False):
|
||||
return 0
|
||||
|
||||
rendered = _render(state)
|
||||
if rendered != last_render:
|
||||
_clear_screen()
|
||||
print(rendered, flush=True)
|
||||
last_render = rendered
|
||||
time.sleep(0.5)
|
||||
except Exception:
|
||||
log_file.write_text(traceback.format_exc(), encoding="utf-8")
|
||||
time.sleep(3)
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
125
src/maisaka/history_post_processor.py
Normal file
125
src/maisaka/history_post_processor.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Maisaka 历史消息轮次结束后处理。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
|
||||
from .history_utils import drop_leading_orphan_tool_results, drop_orphan_tool_results
|
||||
|
||||
TIMING_HISTORY_TOOL_NAMES = {"continue", "finish", "no_reply", "wait"}
|
||||
EARLY_TRIM_RATIO = 0.2
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HistoryPostProcessResult:
|
||||
"""历史后处理结果。"""
|
||||
|
||||
history: list[LLMContextMessage]
|
||||
removed_count: int
|
||||
remaining_context_count: int
|
||||
|
||||
|
||||
def process_chat_history_after_cycle(
|
||||
chat_history: list[LLMContextMessage],
|
||||
*,
|
||||
max_context_size: int,
|
||||
) -> HistoryPostProcessResult:
|
||||
"""在每轮结束后统一执行历史裁切与清理。"""
|
||||
|
||||
processed_history = list(chat_history)
|
||||
removed_timing_tool_count = _remove_early_timing_tool_records(processed_history)
|
||||
removed_assistant_thought_count = _remove_early_assistant_thoughts(processed_history)
|
||||
|
||||
processed_history, orphan_removed_count = drop_orphan_tool_results(processed_history)
|
||||
remaining_context_count = sum(1 for message in processed_history if message.count_in_context)
|
||||
removed_overflow_count = 0
|
||||
|
||||
while remaining_context_count > max_context_size and processed_history:
|
||||
removed_message = processed_history.pop(0)
|
||||
removed_overflow_count += 1
|
||||
if removed_message.count_in_context:
|
||||
remaining_context_count -= 1
|
||||
|
||||
processed_history, leading_orphan_removed_count = drop_leading_orphan_tool_results(processed_history)
|
||||
removed_overflow_count += leading_orphan_removed_count
|
||||
remaining_context_count = sum(1 for message in processed_history if message.count_in_context)
|
||||
removed_count = (
|
||||
removed_timing_tool_count
|
||||
+ removed_assistant_thought_count
|
||||
+ orphan_removed_count
|
||||
+ removed_overflow_count
|
||||
)
|
||||
return HistoryPostProcessResult(
|
||||
history=processed_history,
|
||||
removed_count=removed_count,
|
||||
remaining_context_count=remaining_context_count,
|
||||
)
|
||||
|
||||
|
||||
def _remove_early_timing_tool_records(chat_history: list[LLMContextMessage]) -> int:
|
||||
"""移除最早 20% 的门控/结束类工具链记录。"""
|
||||
|
||||
candidate_assistant_indexes = [
|
||||
index
|
||||
for index, message in enumerate(chat_history)
|
||||
if _is_timing_tool_assistant_message(message)
|
||||
]
|
||||
remove_count = int(len(candidate_assistant_indexes) * EARLY_TRIM_RATIO)
|
||||
if remove_count <= 0:
|
||||
return 0
|
||||
|
||||
removed_indexes = set(candidate_assistant_indexes[:remove_count])
|
||||
removed_tool_call_ids = {
|
||||
tool_call.call_id
|
||||
for index in removed_indexes
|
||||
for tool_call in chat_history[index].tool_calls
|
||||
if tool_call.call_id
|
||||
}
|
||||
|
||||
filtered_history: list[LLMContextMessage] = []
|
||||
removed_total = 0
|
||||
for index, message in enumerate(chat_history):
|
||||
if index in removed_indexes:
|
||||
removed_total += 1
|
||||
continue
|
||||
if isinstance(message, ToolResultMessage) and message.tool_call_id in removed_tool_call_ids:
|
||||
removed_total += 1
|
||||
continue
|
||||
filtered_history.append(message)
|
||||
|
||||
chat_history[:] = filtered_history
|
||||
return removed_total
|
||||
|
||||
|
||||
def _remove_early_assistant_thoughts(chat_history: list[LLMContextMessage]) -> int:
|
||||
"""移除最早 20% 的非工具 assistant 思考内容。"""
|
||||
|
||||
candidate_indexes = [
|
||||
index
|
||||
for index, message in enumerate(chat_history)
|
||||
if isinstance(message, AssistantMessage)
|
||||
and not message.tool_calls
|
||||
and message.source_kind != "perception"
|
||||
and bool(message.content.strip())
|
||||
]
|
||||
remove_count = int(len(candidate_indexes) * EARLY_TRIM_RATIO)
|
||||
if remove_count <= 0:
|
||||
return 0
|
||||
|
||||
removed_indexes = set(candidate_indexes[:remove_count])
|
||||
filtered_history: list[LLMContextMessage] = []
|
||||
removed_total = 0
|
||||
for index, message in enumerate(chat_history):
|
||||
if index in removed_indexes:
|
||||
removed_total += 1
|
||||
continue
|
||||
filtered_history.append(message)
|
||||
|
||||
chat_history[:] = filtered_history
|
||||
return removed_total
|
||||
|
||||
|
||||
def _is_timing_tool_assistant_message(message: LLMContextMessage) -> bool:
|
||||
if not isinstance(message, AssistantMessage) or not message.tool_calls:
|
||||
return False
|
||||
|
||||
return all(tool_call.func_name in TIMING_HISTORY_TOOL_NAMES for tool_call in message.tool_calls)
|
||||
@@ -14,7 +14,7 @@ from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import global_config
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.llm_models.exceptions import ReqAbortException
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
@@ -35,7 +35,8 @@ from .context_messages import (
|
||||
ToolResultMessage,
|
||||
contains_complex_message,
|
||||
)
|
||||
from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text, drop_leading_orphan_tool_results
|
||||
from .history_post_processor import process_chat_history_after_cycle
|
||||
from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text
|
||||
from .monitor_events import (
|
||||
emit_cycle_start,
|
||||
emit_message_ingested,
|
||||
@@ -53,7 +54,7 @@ logger = get_logger("maisaka_reasoning_engine")
|
||||
TIMING_GATE_CONTEXT_LIMIT = 24
|
||||
TIMING_GATE_MAX_TOKENS = 384
|
||||
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply"}
|
||||
ACTION_BUILTIN_TOOL_NAMES = {tool_spec.name for tool_spec in get_action_tool_specs()}
|
||||
|
||||
|
||||
@@ -94,6 +95,7 @@ class MaisakaReasoningEngine:
|
||||
async def _run_interruptible_planner(
|
||||
self,
|
||||
*,
|
||||
injected_user_messages: Optional[list[str]] = None,
|
||||
tool_definitions: Optional[list[dict[str, Any]]] = None,
|
||||
) -> Any:
|
||||
"""运行一轮可被新消息打断的主 planner 请求。"""
|
||||
@@ -105,6 +107,7 @@ class MaisakaReasoningEngine:
|
||||
try:
|
||||
return await self._runtime._chat_loop_service.chat_loop_step(
|
||||
self._runtime._chat_history,
|
||||
injected_user_messages=injected_user_messages,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
except ReqAbortException:
|
||||
@@ -117,36 +120,27 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
self._runtime._chat_loop_service.set_interrupt_flag(None)
|
||||
|
||||
async def _run_interruptible_sub_agent(
|
||||
async def _run_timing_gate_sub_agent(
|
||||
self,
|
||||
*,
|
||||
context_message_limit: int,
|
||||
system_prompt: str,
|
||||
tool_definitions: list[dict[str, Any]],
|
||||
) -> Any:
|
||||
"""运行一轮可被新消息打断的临时子代理请求。"""
|
||||
"""运行一轮 Timing Gate 子代理请求。
|
||||
|
||||
interrupt_flag = asyncio.Event()
|
||||
interrupted = False
|
||||
self._runtime._bind_planner_interrupt_flag(interrupt_flag)
|
||||
try:
|
||||
return await self._runtime.run_sub_agent(
|
||||
context_message_limit=context_message_limit,
|
||||
system_prompt=system_prompt,
|
||||
request_kind="timing_gate",
|
||||
interrupt_flag=interrupt_flag,
|
||||
max_tokens=TIMING_GATE_MAX_TOKENS,
|
||||
temperature=0.1,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
except ReqAbortException:
|
||||
interrupted = True
|
||||
raise
|
||||
finally:
|
||||
self._runtime._unbind_planner_interrupt_flag(
|
||||
interrupt_flag,
|
||||
interrupted=interrupted,
|
||||
)
|
||||
Timing Gate 阶段不再响应新的 planner 打断,只有主 planner 阶段允许被打断。
|
||||
"""
|
||||
|
||||
return await self._runtime.run_sub_agent(
|
||||
context_message_limit=context_message_limit,
|
||||
system_prompt=system_prompt,
|
||||
request_kind="timing_gate",
|
||||
interrupt_flag=None,
|
||||
max_tokens=TIMING_GATE_MAX_TOKENS,
|
||||
temperature=0.1,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_timing_gate_fallback_prompt() -> str:
|
||||
@@ -174,22 +168,34 @@ class MaisakaReasoningEngine:
|
||||
except Exception:
|
||||
return self._build_timing_gate_fallback_prompt()
|
||||
|
||||
async def _build_action_tool_definitions(self) -> list[dict[str, Any]]:
|
||||
"""构造 Action Loop 阶段可见的工具定义。"""
|
||||
async def _build_action_tool_definitions(self) -> tuple[list[dict[str, Any]], str]:
|
||||
"""构造 Action Loop 阶段可见的工具定义与 deferred tools 提示。"""
|
||||
|
||||
if self._runtime._tool_registry is None:
|
||||
return []
|
||||
self._runtime.update_deferred_tool_specs([])
|
||||
self._runtime.set_current_action_tool_names([])
|
||||
return [], ""
|
||||
|
||||
tool_specs = await self._runtime._tool_registry.list_tools()
|
||||
return [
|
||||
tool_spec.to_llm_definition()
|
||||
for tool_spec in tool_specs
|
||||
if tool_spec.name not in ACTION_HIDDEN_TOOL_NAMES
|
||||
and (
|
||||
tool_spec.provider_name != "maisaka_builtin"
|
||||
or tool_spec.name in ACTION_BUILTIN_TOOL_NAMES
|
||||
)
|
||||
]
|
||||
visible_builtin_tool_specs: list[ToolSpec] = []
|
||||
deferred_tool_specs: list[ToolSpec] = []
|
||||
for tool_spec in tool_specs:
|
||||
if tool_spec.name in ACTION_HIDDEN_TOOL_NAMES:
|
||||
continue
|
||||
if tool_spec.provider_name == "maisaka_builtin":
|
||||
if tool_spec.name in ACTION_BUILTIN_TOOL_NAMES:
|
||||
visible_builtin_tool_specs.append(tool_spec)
|
||||
continue
|
||||
deferred_tool_specs.append(tool_spec)
|
||||
|
||||
self._runtime.update_deferred_tool_specs(deferred_tool_specs)
|
||||
discovered_deferred_tool_specs = self._runtime.get_discovered_deferred_tool_specs()
|
||||
visible_tool_specs = [*visible_builtin_tool_specs, *discovered_deferred_tool_specs]
|
||||
self._runtime.set_current_action_tool_names([tool_spec.name for tool_spec in visible_tool_specs])
|
||||
return (
|
||||
[tool_spec.to_llm_definition() for tool_spec in visible_tool_specs],
|
||||
self._runtime.build_deferred_tools_reminder(),
|
||||
)
|
||||
|
||||
async def _invoke_tool_call(
|
||||
self,
|
||||
@@ -227,18 +233,19 @@ class MaisakaReasoningEngine:
|
||||
async def _run_timing_gate(
|
||||
self,
|
||||
anchor_message: SessionMessage,
|
||||
) -> tuple[Literal["continue", "no_reply", "wait"], Any, list[str]]:
|
||||
) -> tuple[Literal["continue", "no_reply", "wait"], Any, list[str], list[dict[str, Any]]]:
|
||||
"""运行 Timing Gate 子代理并返回控制决策。"""
|
||||
|
||||
if self._runtime._force_continue_until_reply:
|
||||
if self._runtime._force_next_timing_continue:
|
||||
return self._build_forced_continue_timing_result()
|
||||
|
||||
response = await self._run_interruptible_sub_agent(
|
||||
response = await self._run_timing_gate_sub_agent(
|
||||
context_message_limit=TIMING_GATE_CONTEXT_LIMIT,
|
||||
system_prompt=self._build_timing_gate_system_prompt(),
|
||||
tool_definitions=get_timing_tools(),
|
||||
)
|
||||
tool_result_summaries: list[str] = []
|
||||
tool_monitor_results: list[dict[str, Any]] = []
|
||||
selected_tool_call: Optional[ToolCall] = None
|
||||
for tool_call in response.tool_calls:
|
||||
if tool_call.func_name in TIMING_GATE_TOOL_NAMES:
|
||||
@@ -247,11 +254,11 @@ class MaisakaReasoningEngine:
|
||||
|
||||
if selected_tool_call is None:
|
||||
logger.warning(f"{self._runtime.log_prefix} Timing Gate 未返回有效控制工具,默认继续执行 Action Loop")
|
||||
return "continue", response, tool_result_summaries
|
||||
return "continue", response, tool_result_summaries, tool_monitor_results
|
||||
|
||||
append_history = selected_tool_call.func_name != "continue"
|
||||
append_history = False
|
||||
store_record = selected_tool_call.func_name != "continue"
|
||||
_, result, _ = await self._invoke_tool_call(
|
||||
invocation, result, tool_spec = await self._invoke_tool_call(
|
||||
selected_tool_call,
|
||||
response.content or "",
|
||||
anchor_message,
|
||||
@@ -259,19 +266,31 @@ class MaisakaReasoningEngine:
|
||||
store_record=store_record,
|
||||
)
|
||||
tool_result_summaries.append(self._build_tool_result_summary(selected_tool_call, result))
|
||||
tool_monitor_results.append(
|
||||
self._build_tool_monitor_result(
|
||||
selected_tool_call,
|
||||
invocation,
|
||||
result,
|
||||
duration_ms=0.0,
|
||||
tool_spec=tool_spec,
|
||||
)
|
||||
)
|
||||
self._append_timing_gate_execution_result(response, selected_tool_call, result)
|
||||
|
||||
timing_action = str(result.metadata.get("timing_action") or selected_tool_call.func_name).strip()
|
||||
if timing_action not in TIMING_GATE_TOOL_NAMES:
|
||||
logger.warning(
|
||||
f"{self._runtime.log_prefix} Timing Gate 返回未知动作 {timing_action!r},将按 continue 处理"
|
||||
)
|
||||
return "continue", response, tool_result_summaries
|
||||
return timing_action, response, tool_result_summaries
|
||||
return "continue", response, tool_result_summaries, tool_monitor_results
|
||||
return timing_action, response, tool_result_summaries, tool_monitor_results
|
||||
|
||||
def _build_forced_continue_timing_result(self) -> tuple[Literal["continue"], ChatResponse, list[str]]:
|
||||
def _build_forced_continue_timing_result(
|
||||
self,
|
||||
) -> tuple[Literal["continue"], ChatResponse, list[str], list[dict[str, Any]]]:
|
||||
"""构造跳过 Timing Gate 时使用的伪 continue 结果。"""
|
||||
|
||||
reason = self._runtime._build_force_continue_timing_reason()
|
||||
reason = self._runtime._consume_force_next_timing_continue_reason() or "本轮直接跳过 Timing Gate 并视作 continue。"
|
||||
logger.info(f"{self._runtime.log_prefix} {reason}")
|
||||
return (
|
||||
"continue",
|
||||
@@ -296,8 +315,24 @@ class MaisakaReasoningEngine:
|
||||
prompt_section=None,
|
||||
),
|
||||
[f"- continue [强制跳过]: {reason}"],
|
||||
[],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _mark_timing_gate_completed(timing_action: str) -> bool:
|
||||
"""根据门控动作决定下一轮是否还需要重新执行 timing。"""
|
||||
|
||||
return timing_action != "continue"
|
||||
|
||||
@staticmethod
|
||||
def _should_retry_planner_after_interrupt(
|
||||
*,
|
||||
round_index: int,
|
||||
max_internal_rounds: int,
|
||||
has_pending_messages: bool,
|
||||
) -> bool:
|
||||
return has_pending_messages and round_index + 1 < max_internal_rounds
|
||||
|
||||
async def run_loop(self) -> None:
|
||||
"""独立消费消息批次,并执行对应的内部思考轮次。"""
|
||||
try:
|
||||
@@ -314,13 +349,20 @@ class MaisakaReasoningEngine:
|
||||
if self._runtime._has_pending_messages()
|
||||
else []
|
||||
)
|
||||
if not timeout_triggered and not cached_messages and not message_triggered:
|
||||
if not timeout_triggered and not cached_messages:
|
||||
continue
|
||||
|
||||
self._runtime._agent_state = self._runtime._STATE_RUNNING
|
||||
self._runtime._update_stage_status(
|
||||
"消息整理",
|
||||
f"待处理消息 {len(cached_messages)} 条" if cached_messages else "准备复用超时锚点",
|
||||
)
|
||||
if cached_messages:
|
||||
asyncio.create_task(self._runtime._trigger_batch_learning(cached_messages))
|
||||
self._append_wait_interrupted_message_if_needed()
|
||||
if timeout_triggered:
|
||||
self._runtime._chat_history.append(
|
||||
self._build_wait_completed_message(has_new_messages=True)
|
||||
)
|
||||
await self._ingest_messages(cached_messages)
|
||||
anchor_message = cached_messages[-1]
|
||||
else:
|
||||
@@ -332,13 +374,16 @@ class MaisakaReasoningEngine:
|
||||
continue
|
||||
logger.info(f"{self._runtime.log_prefix} 等待超时后开始新一轮思考")
|
||||
if self._runtime._pending_wait_tool_call_id:
|
||||
self._runtime._chat_history.append(self._build_wait_timeout_message())
|
||||
self._trim_chat_history()
|
||||
|
||||
self._runtime._chat_history.append(
|
||||
self._build_wait_completed_message(has_new_messages=False)
|
||||
)
|
||||
try:
|
||||
timing_gate_required = True
|
||||
for round_index in range(self._runtime._max_internal_rounds):
|
||||
cycle_detail = self._start_cycle()
|
||||
round_text = f"第 {round_index + 1}/{self._runtime._max_internal_rounds} 轮"
|
||||
self._runtime._log_cycle_started(cycle_detail, round_index)
|
||||
self._runtime._update_stage_status("启动循环", f"循环 {cycle_detail.cycle_id}", round_text=round_text)
|
||||
await emit_cycle_start(
|
||||
session_id=self._runtime.session_id,
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
@@ -349,10 +394,14 @@ class MaisakaReasoningEngine:
|
||||
planner_started_at = 0.0
|
||||
planner_duration_ms = 0.0
|
||||
timing_duration_ms = 0.0
|
||||
current_stage_started_at = 0.0
|
||||
timing_action: Optional[str] = None
|
||||
timing_response: Optional[ChatResponse] = None
|
||||
timing_tool_results: Optional[list[str]] = None
|
||||
timing_tool_monitor_results: Optional[list[dict[str, Any]]] = None
|
||||
response: Optional[ChatResponse] = None
|
||||
action_tool_definitions: list[dict[str, Any]] = []
|
||||
planner_extra_lines: list[str] = []
|
||||
tool_result_summaries: list[str] = []
|
||||
tool_monitor_results: list[dict[str, Any]] = []
|
||||
try:
|
||||
@@ -364,30 +413,46 @@ class MaisakaReasoningEngine:
|
||||
f"{self._runtime.log_prefix} 本轮思考前已刷新 {refreshed_message_count} 条视觉占位历史消息"
|
||||
)
|
||||
|
||||
timing_started_at = time.time()
|
||||
timing_action, timing_response, timing_tool_results = await self._run_timing_gate(anchor_message)
|
||||
timing_duration_ms = (time.time() - timing_started_at) * 1000
|
||||
cycle_detail.time_records["timing_gate"] = timing_duration_ms / 1000
|
||||
await emit_timing_gate_result(
|
||||
session_id=self._runtime.session_id,
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
action=timing_action,
|
||||
content=timing_response.content,
|
||||
tool_calls=timing_response.tool_calls,
|
||||
messages=[],
|
||||
prompt_tokens=timing_response.prompt_tokens,
|
||||
selected_history_count=timing_response.selected_history_count,
|
||||
duration_ms=timing_duration_ms,
|
||||
)
|
||||
if timing_action != "continue":
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} Timing Gate 结束当前回合: "
|
||||
f"回合={round_index + 1} 动作={timing_action}"
|
||||
if timing_gate_required:
|
||||
self._runtime._update_stage_status("Timing Gate", "等待门控决策", round_text=round_text)
|
||||
current_stage_started_at = time.time()
|
||||
timing_started_at = time.time()
|
||||
(
|
||||
timing_action,
|
||||
timing_response,
|
||||
timing_tool_results,
|
||||
timing_tool_monitor_results,
|
||||
) = await self._run_timing_gate(anchor_message)
|
||||
timing_duration_ms = (time.time() - timing_started_at) * 1000
|
||||
cycle_detail.time_records["timing_gate"] = timing_duration_ms / 1000
|
||||
await emit_timing_gate_result(
|
||||
session_id=self._runtime.session_id,
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
action=timing_action,
|
||||
content=timing_response.content,
|
||||
tool_calls=timing_response.tool_calls,
|
||||
messages=[],
|
||||
prompt_tokens=timing_response.prompt_tokens,
|
||||
selected_history_count=timing_response.selected_history_count,
|
||||
duration_ms=timing_duration_ms,
|
||||
)
|
||||
timing_gate_required = self._mark_timing_gate_completed(timing_action)
|
||||
if timing_action != "continue":
|
||||
logger.debug(
|
||||
f"{self._runtime.log_prefix} Timing Gate 结束当前回合: "
|
||||
f"回合={round_index + 1} 动作={timing_action}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 跳过 Timing Gate,继续执行 Planner: "
|
||||
f"回合={round_index + 1}"
|
||||
)
|
||||
break
|
||||
|
||||
planner_started_at = time.time()
|
||||
action_tool_definitions = await self._build_action_tool_definitions()
|
||||
current_stage_started_at = planner_started_at
|
||||
self._runtime._update_stage_status("Planner", "组织上下文并请求模型", round_text=round_text)
|
||||
action_tool_definitions, deferred_tools_reminder = await self._build_action_tool_definitions()
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 规划器开始执行: "
|
||||
f"回合={round_index + 1} "
|
||||
@@ -395,6 +460,7 @@ class MaisakaReasoningEngine:
|
||||
f"开始时间={planner_started_at:.3f}"
|
||||
)
|
||||
response = await self._run_interruptible_planner(
|
||||
injected_user_messages=[deferred_tools_reminder] if deferred_tools_reminder else None,
|
||||
tool_definitions=action_tool_definitions,
|
||||
)
|
||||
planner_duration_ms = (time.time() - planner_started_at) * 1000
|
||||
@@ -406,8 +472,8 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
reasoning_content = response.content or ""
|
||||
if self._should_replace_reasoning(reasoning_content):
|
||||
response.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后先输出想法再使用工具"
|
||||
response.raw_message.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后先输出想法再使用工具"
|
||||
response.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后直接输出我的想法"
|
||||
response.raw_message.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后直接输出我的想法"
|
||||
logger.info(f"{self._runtime.log_prefix} 当前思考与上一轮过于相似,已替换为重新思考提示")
|
||||
|
||||
self._last_reasoning_content = reasoning_content
|
||||
@@ -428,20 +494,73 @@ class MaisakaReasoningEngine:
|
||||
|
||||
if not response.content:
|
||||
break
|
||||
except ReqAbortException:
|
||||
interrupted_at = time.time()
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 规划器打断成功: "
|
||||
f"回合={round_index + 1} "
|
||||
f"开始时间={planner_started_at:.3f} "
|
||||
f"打断时间={interrupted_at:.3f} "
|
||||
f"耗时={interrupted_at - planner_started_at:.3f} 秒"
|
||||
except ReqAbortException as exc:
|
||||
self._runtime._update_stage_status(
|
||||
"Planner 已打断",
|
||||
str(exc) or "收到外部中断信号",
|
||||
round_text=round_text,
|
||||
)
|
||||
break
|
||||
interrupted_at = time.time()
|
||||
interrupted_stage_label = "Planner"
|
||||
interrupted_text = "Planner 收到新消息,开始重新决策"
|
||||
interrupted_response = ChatResponse(
|
||||
content=interrupted_text or None,
|
||||
tool_calls=[],
|
||||
request_messages=[],
|
||||
raw_message=AssistantMessage(
|
||||
content=interrupted_text,
|
||||
timestamp=datetime.now(),
|
||||
tool_calls=[],
|
||||
source_kind="perception",
|
||||
),
|
||||
selected_history_count=len(self._runtime._chat_history),
|
||||
tool_count=len(action_tool_definitions),
|
||||
prompt_tokens=0,
|
||||
built_message_count=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
prompt_section=None,
|
||||
)
|
||||
interrupted_extra_lines = [
|
||||
"状态:已被新消息打断",
|
||||
f"打断位置:{interrupted_stage_label} 请求流式响应阶段",
|
||||
f"打断耗时:{interrupted_at - current_stage_started_at:.3f} 秒",
|
||||
]
|
||||
response = interrupted_response
|
||||
planner_extra_lines = interrupted_extra_lines
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} {interrupted_stage_label} 打断成功: "
|
||||
f"回合={round_index + 1} "
|
||||
f"开始时间={current_stage_started_at:.3f} "
|
||||
f"打断时间={interrupted_at:.3f} "
|
||||
f"耗时={interrupted_at - current_stage_started_at:.3f} 秒"
|
||||
)
|
||||
if not self._should_retry_planner_after_interrupt(
|
||||
round_index=round_index,
|
||||
max_internal_rounds=self._runtime._max_internal_rounds,
|
||||
has_pending_messages=self._runtime._has_pending_messages(),
|
||||
):
|
||||
break
|
||||
|
||||
await self._runtime._wait_for_message_quiet_period()
|
||||
self._runtime._message_turn_scheduled = False
|
||||
interrupted_messages = self._runtime._collect_pending_messages()
|
||||
if not interrupted_messages:
|
||||
break
|
||||
|
||||
asyncio.create_task(self._runtime._trigger_batch_learning(interrupted_messages))
|
||||
await self._ingest_messages(interrupted_messages)
|
||||
anchor_message = interrupted_messages[-1]
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 淇濇寔娲昏穬鐘舵€侊紝璺宠繃 Timing Gate 鐩存帴閲嶈瘯 Planner: "
|
||||
f"鍥炲悎={round_index + 2}"
|
||||
)
|
||||
continue
|
||||
finally:
|
||||
completed_cycle = self._end_cycle(cycle_detail)
|
||||
self._runtime._render_context_usage_panel(
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
time_records=dict(completed_cycle.time_records),
|
||||
timing_selected_history_count=(
|
||||
timing_response.selected_history_count if timing_response is not None else None
|
||||
),
|
||||
@@ -452,6 +571,7 @@ class MaisakaReasoningEngine:
|
||||
timing_response=timing_response.content or "" if timing_response is not None else "",
|
||||
timing_tool_calls=timing_response.tool_calls if timing_response is not None else None,
|
||||
timing_tool_results=timing_tool_results,
|
||||
timing_tool_detail_results=timing_tool_monitor_results,
|
||||
timing_prompt_section=(
|
||||
timing_response.prompt_section if timing_response is not None else None
|
||||
),
|
||||
@@ -464,6 +584,7 @@ class MaisakaReasoningEngine:
|
||||
planner_tool_results=tool_result_summaries,
|
||||
planner_tool_detail_results=tool_monitor_results,
|
||||
planner_prompt_section=response.prompt_section if response is not None else None,
|
||||
planner_extra_lines=planner_extra_lines,
|
||||
)
|
||||
await emit_planner_finalized(
|
||||
session_id=self._runtime.session_id,
|
||||
@@ -505,6 +626,8 @@ class MaisakaReasoningEngine:
|
||||
finally:
|
||||
if self._runtime._agent_state == self._runtime._STATE_RUNNING:
|
||||
self._runtime._agent_state = self._runtime._STATE_STOP
|
||||
if self._runtime._running:
|
||||
self._runtime._update_stage_status("等待消息", "本轮处理结束")
|
||||
except asyncio.CancelledError:
|
||||
self._runtime._log_internal_loop_cancelled()
|
||||
raise
|
||||
@@ -543,33 +666,22 @@ class MaisakaReasoningEngine:
|
||||
return self._runtime.message_cache[-1]
|
||||
return None
|
||||
|
||||
def _build_wait_timeout_message(self) -> ToolResultMessage:
|
||||
"""构造 wait 超时后的工具结果消息。"""
|
||||
def _build_wait_completed_message(self, *, has_new_messages: bool) -> ToolResultMessage:
|
||||
"""构造 wait 完成后的工具结果消息。"""
|
||||
tool_call_id = self._runtime._pending_wait_tool_call_id or "wait_timeout"
|
||||
self._runtime._pending_wait_tool_call_id = None
|
||||
content = (
|
||||
"等待已结束,期间收到了新的用户输入。请结合这些新消息继续下一轮思考。"
|
||||
if has_new_messages
|
||||
else "等待已超时,期间没有收到新的用户输入。请基于现有上下文继续下一轮思考。"
|
||||
)
|
||||
return ToolResultMessage(
|
||||
content="等待已超时,期间没有收到新的用户输入。请基于现有上下文继续下一轮思考。",
|
||||
content=content,
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name="wait",
|
||||
)
|
||||
|
||||
def _append_wait_interrupted_message_if_needed(self) -> None:
|
||||
"""如果 wait 被新消息打断,则补一条对应的工具结果消息。"""
|
||||
tool_call_id = self._runtime._pending_wait_tool_call_id
|
||||
if not tool_call_id:
|
||||
return
|
||||
|
||||
self._runtime._pending_wait_tool_call_id = None
|
||||
self._runtime._chat_history.append(
|
||||
ToolResultMessage(
|
||||
content="等待过程被新的用户输入打断,已继续处理最新消息。",
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name="wait",
|
||||
)
|
||||
)
|
||||
|
||||
async def _ingest_messages(self, messages: list[SessionMessage]) -> None:
|
||||
"""处理传入消息列表,将其转换为历史消息并加入聊天历史缓存。"""
|
||||
for message in messages:
|
||||
@@ -578,7 +690,6 @@ class MaisakaReasoningEngine:
|
||||
continue
|
||||
|
||||
self._insert_chat_history_message(history_message)
|
||||
self._trim_chat_history()
|
||||
|
||||
# 向监控前端广播新消息注入事件
|
||||
user_info = message.message_info.user_info
|
||||
@@ -628,10 +739,47 @@ class MaisakaReasoningEngine:
|
||||
planner_prefix: str,
|
||||
) -> MessageSequence:
|
||||
message_sequence = build_prefixed_message_sequence(message.raw_message, planner_prefix)
|
||||
if global_config.visual.multimodal_planner:
|
||||
if self._resolve_enable_visual_planner():
|
||||
await self._hydrate_visual_components(message_sequence.components)
|
||||
return message_sequence
|
||||
|
||||
@staticmethod
|
||||
def _resolve_enable_visual_planner() -> bool:
|
||||
planner_mode = global_config.visual.planner_mode
|
||||
planner_task_config = config_manager.get_model_config().model_task_config.planner
|
||||
models_by_name = {model.name: model for model in config_manager.get_model_config().models}
|
||||
|
||||
if planner_mode == "text":
|
||||
return False
|
||||
|
||||
planner_models: list[str] = list(planner_task_config.model_list)
|
||||
missing_models = [model_name for model_name in planner_models if model_name not in models_by_name]
|
||||
non_visual_models = [
|
||||
model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual
|
||||
]
|
||||
|
||||
if planner_mode == "multimodal":
|
||||
if missing_models:
|
||||
raise ValueError(
|
||||
"planner_mode=multimodal,但 planner 任务存在未定义的模型:"
|
||||
f"{', '.join(missing_models)}"
|
||||
)
|
||||
if non_visual_models:
|
||||
raise ValueError(
|
||||
"planner_mode=multimodal,但 planner 任务存在未开启 visual 的模型:"
|
||||
f"{', '.join(non_visual_models)}"
|
||||
)
|
||||
return True
|
||||
|
||||
if missing_models:
|
||||
logger.warning(
|
||||
"planner_mode=auto 时发现 planner 任务存在未定义模型:"
|
||||
f"{', '.join(missing_models)},将退化为纯文本 planner"
|
||||
)
|
||||
return False
|
||||
|
||||
return bool(planner_models) and not non_visual_models
|
||||
|
||||
async def _hydrate_visual_components(self, planner_components: list[object]) -> None:
|
||||
"""在 Maisaka 真正需要图片或表情时,按需回填二进制数据。"""
|
||||
load_tasks: list[asyncio.Task[None]] = []
|
||||
@@ -681,6 +829,7 @@ class MaisakaReasoningEngine:
|
||||
"""结束并记录一轮 Maisaka 思考循环。"""
|
||||
cycle_detail.end_time = time.time()
|
||||
self._runtime.history_loop.append(cycle_detail)
|
||||
self._post_process_chat_history_after_cycle()
|
||||
|
||||
timer_strings = [
|
||||
f"{name}: {duration:.2f}s"
|
||||
@@ -690,26 +839,20 @@ class MaisakaReasoningEngine:
|
||||
self._runtime._log_cycle_completed(cycle_detail, timer_strings)
|
||||
return cycle_detail
|
||||
|
||||
def _trim_chat_history(self) -> None:
|
||||
def _post_process_chat_history_after_cycle(self) -> None:
|
||||
"""裁剪聊天历史,保证用户消息数量不超过配置限制。"""
|
||||
conversation_message_count = sum(1 for message in self._runtime._chat_history if message.count_in_context)
|
||||
if conversation_message_count <= self._runtime._max_context_size:
|
||||
process_result = process_chat_history_after_cycle(
|
||||
self._runtime._chat_history,
|
||||
max_context_size=self._runtime._max_context_size,
|
||||
)
|
||||
if process_result.removed_count <= 0:
|
||||
return
|
||||
|
||||
trimmed_history = list(self._runtime._chat_history)
|
||||
removed_count = 0
|
||||
|
||||
while conversation_message_count > self._runtime._max_context_size and trimmed_history:
|
||||
removed_message = trimmed_history.pop(0)
|
||||
removed_count += 1
|
||||
if removed_message.count_in_context:
|
||||
conversation_message_count -= 1
|
||||
|
||||
trimmed_history, pruned_orphan_count = drop_leading_orphan_tool_results(trimmed_history)
|
||||
removed_count += pruned_orphan_count
|
||||
|
||||
self._runtime._chat_history = trimmed_history
|
||||
self._runtime._log_history_trimmed(removed_count, conversation_message_count)
|
||||
self._runtime._chat_history = process_result.history
|
||||
self._runtime._log_history_trimmed(
|
||||
process_result.removed_count,
|
||||
process_result.remaining_context_count,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _calculate_similarity(text1: str, text2: str) -> float:
|
||||
@@ -934,6 +1077,9 @@ class MaisakaReasoningEngine:
|
||||
if invocation.tool_name == "no_reply":
|
||||
return "你暂停了当前对话循环,等待新的外部消息。"
|
||||
|
||||
if invocation.tool_name == "finish":
|
||||
return "你结束了本轮思考,等待新的外部消息后再继续。"
|
||||
|
||||
if invocation.tool_name == "continue":
|
||||
return "你允许当前对话继续进入下一轮完整思考与工具执行。"
|
||||
|
||||
@@ -1065,6 +1211,24 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
)
|
||||
|
||||
def _append_timing_gate_execution_result(
|
||||
self,
|
||||
response: ChatResponse,
|
||||
tool_call: ToolCall,
|
||||
result: ToolExecutionResult,
|
||||
) -> None:
|
||||
"""将 Timing Gate 的决策链写入历史,供后续门控复用。"""
|
||||
|
||||
self._runtime._chat_history.append(
|
||||
AssistantMessage(
|
||||
content=response.content or "",
|
||||
timestamp=response.raw_message.timestamp,
|
||||
tool_calls=[tool_call],
|
||||
source_kind="timing_gate",
|
||||
)
|
||||
)
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
|
||||
def _build_tool_result_summary(self, tool_call: ToolCall, result: ToolExecutionResult) -> str:
|
||||
"""构建用于终端展示的工具结果摘要。"""
|
||||
|
||||
@@ -1084,6 +1248,7 @@ class MaisakaReasoningEngine:
|
||||
invocation: ToolInvocation,
|
||||
result: ToolExecutionResult,
|
||||
duration_ms: float,
|
||||
tool_spec: Optional[ToolSpec] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""构建 planner.finalized 中单个工具的监控结果。"""
|
||||
|
||||
@@ -1092,9 +1257,20 @@ class MaisakaReasoningEngine:
|
||||
if monitor_detail is not None:
|
||||
normalized_detail = self._normalize_tool_record_value(monitor_detail)
|
||||
|
||||
monitor_card = result.metadata.get("monitor_card")
|
||||
normalized_card = None
|
||||
if monitor_card is not None:
|
||||
normalized_card = self._normalize_tool_record_value(monitor_card)
|
||||
|
||||
monitor_sub_cards = result.metadata.get("monitor_sub_cards")
|
||||
normalized_sub_cards = None
|
||||
if monitor_sub_cards is not None:
|
||||
normalized_sub_cards = self._normalize_tool_record_value(monitor_sub_cards)
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"tool_name": tool_call.func_name,
|
||||
"tool_title": tool_spec.title.strip() if tool_spec is not None and tool_spec.title.strip() else "",
|
||||
"tool_args": self._normalize_tool_record_value(
|
||||
invocation.arguments if isinstance(invocation.arguments, dict) else {}
|
||||
),
|
||||
@@ -1102,6 +1278,8 @@ class MaisakaReasoningEngine:
|
||||
"duration_ms": round(duration_ms, 2),
|
||||
"summary": self._build_tool_result_summary(tool_call, result),
|
||||
"detail": normalized_detail,
|
||||
"card": normalized_card,
|
||||
"sub_cards": normalized_sub_cards,
|
||||
}
|
||||
|
||||
async def _handle_tool_calls(
|
||||
@@ -1137,7 +1315,7 @@ class MaisakaReasoningEngine:
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
tool_result_summaries.append(self._build_tool_result_summary(tool_call, result))
|
||||
tool_monitor_results.append(
|
||||
self._build_tool_monitor_result(tool_call, invocation, result, duration_ms=0.0)
|
||||
self._build_tool_monitor_result(tool_call, invocation, result, duration_ms=0.0, tool_spec=None)
|
||||
)
|
||||
return False, tool_result_summaries, tool_monitor_results
|
||||
|
||||
@@ -1146,10 +1324,25 @@ class MaisakaReasoningEngine:
|
||||
tool_spec.name: tool_spec
|
||||
for tool_spec in await self._runtime._tool_registry.list_tools()
|
||||
}
|
||||
for tool_call in tool_calls:
|
||||
total_tool_count = len(tool_calls)
|
||||
for tool_index, tool_call in enumerate(tool_calls, start=1):
|
||||
invocation = self._build_tool_invocation(tool_call, latest_thought)
|
||||
self._runtime._update_stage_status(
|
||||
f"工具执行 · {invocation.tool_name}",
|
||||
f"第 {tool_index}/{total_tool_count} 个工具",
|
||||
)
|
||||
tool_started_at = time.time()
|
||||
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
|
||||
if not self._runtime.is_action_tool_currently_available(invocation.tool_name):
|
||||
result = ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
success=False,
|
||||
error_message=(
|
||||
f"工具 {invocation.tool_name} 当前未直接暴露给 planner。"
|
||||
"如果它在 deferred tools 提示中,请先调用 tool_search。"
|
||||
),
|
||||
)
|
||||
else:
|
||||
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
|
||||
tool_duration_ms = (time.time() - tool_started_at) * 1000
|
||||
await self._store_tool_execution_record(
|
||||
invocation,
|
||||
@@ -1159,7 +1352,13 @@ class MaisakaReasoningEngine:
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
tool_result_summaries.append(self._build_tool_result_summary(tool_call, result))
|
||||
tool_monitor_results.append(
|
||||
self._build_tool_monitor_result(tool_call, invocation, result, tool_duration_ms)
|
||||
self._build_tool_monitor_result(
|
||||
tool_call,
|
||||
invocation,
|
||||
result,
|
||||
tool_duration_ms,
|
||||
tool_spec=tool_spec_map.get(invocation.tool_name),
|
||||
)
|
||||
)
|
||||
|
||||
if not result.success and tool_call.func_name == "reply":
|
||||
|
||||
@@ -21,7 +21,7 @@ from src.common.data_models.mai_message_data_model import GroupInfo, UserInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.core.tooling import ToolRegistry, ToolSpec
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
@@ -30,11 +30,13 @@ from src.mcp_module import MCPManager
|
||||
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
|
||||
from src.mcp_module.provider import MCPToolProvider
|
||||
from src.plugin_runtime.tool_provider import PluginToolProvider
|
||||
from src.plugin_runtime.hook_payloads import deserialize_prompt_messages
|
||||
|
||||
from .chat_loop_service import ChatResponse, MaisakaChatLoopService
|
||||
from .context_messages import LLMContextMessage
|
||||
from .display_utils import build_tool_call_summary_lines, format_token_count
|
||||
from .prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .display.display_utils import build_tool_call_summary_lines, format_token_count
|
||||
from .display.prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .display.stage_status_board import remove_stage_status, update_stage_status
|
||||
from .reasoning_engine import MaisakaReasoningEngine
|
||||
from .tool_provider import MaisakaBuiltinToolProvider
|
||||
|
||||
@@ -92,14 +94,16 @@ class MaisakaHeartFlowChatting:
|
||||
self._max_internal_rounds = MAX_INTERNAL_ROUNDS
|
||||
self._max_context_size = max(1, int(global_config.chat.max_context_size))
|
||||
self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP
|
||||
self._wait_until: Optional[float] = None
|
||||
self._pending_wait_tool_call_id: Optional[str] = None
|
||||
self._force_continue_until_reply = False
|
||||
self._force_continue_trigger_message_id = ""
|
||||
self._force_continue_trigger_reason = ""
|
||||
self._force_next_timing_continue = False
|
||||
self._force_next_timing_message_id = ""
|
||||
self._force_next_timing_reason = ""
|
||||
self._planner_interrupt_flag: Optional[asyncio.Event] = None
|
||||
self._planner_interrupt_requested = False
|
||||
self._planner_interrupt_consecutive_count = 0
|
||||
self._current_action_tool_names: set[str] = set()
|
||||
self.discovered_tool_names: set[str] = set()
|
||||
self.deferred_tool_specs_by_name: dict[str, ToolSpec] = {}
|
||||
self._planner_interrupt_max_consecutive_count = max(
|
||||
0,
|
||||
int(global_config.chat.planner_interrupt_max_consecutive_count),
|
||||
@@ -118,6 +122,18 @@ class MaisakaHeartFlowChatting:
|
||||
self._tool_registry = ToolRegistry()
|
||||
self._register_tool_providers()
|
||||
|
||||
def _update_stage_status(self, stage: str, detail: str = "", *, round_text: str = "") -> None:
|
||||
"""更新当前会话的阶段状态。"""
|
||||
|
||||
update_stage_status(
|
||||
session_id=self.session_id,
|
||||
session_name=self.session_name,
|
||||
stage=stage,
|
||||
detail=detail,
|
||||
round_text=round_text,
|
||||
agent_state=self._agent_state,
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动运行时主循环。"""
|
||||
if self._running:
|
||||
@@ -130,6 +146,7 @@ class MaisakaHeartFlowChatting:
|
||||
self._running = True
|
||||
self._ensure_background_tasks_running()
|
||||
self._schedule_message_turn()
|
||||
self._update_stage_status("空闲", "等待消息触发")
|
||||
logger.info(f"{self.log_prefix} Maisaka 运行时已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
@@ -157,6 +174,7 @@ class MaisakaHeartFlowChatting:
|
||||
await self._tool_registry.close()
|
||||
self._mcp_manager = None
|
||||
self._mcp_host_bridge = None
|
||||
remove_stage_status(self.session_id)
|
||||
|
||||
logger.info(f"{self.log_prefix} Maisaka 运行时已停止")
|
||||
|
||||
@@ -175,9 +193,6 @@ class MaisakaHeartFlowChatting:
|
||||
self.message_cache.append(message)
|
||||
self._message_received_at_by_id[message.message_id] = received_at
|
||||
self._source_messages_by_id[message.message_id] = message
|
||||
if self._agent_state == self._STATE_WAIT:
|
||||
self._cancel_wait_timeout_task()
|
||||
self._wait_until = None
|
||||
if self._agent_state == self._STATE_RUNNING:
|
||||
self._message_debounce_required = True
|
||||
if self._agent_state == self._STATE_RUNNING and self._planner_interrupt_flag is not None:
|
||||
@@ -248,7 +263,6 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
def _record_reply_sent(self) -> None:
|
||||
"""在成功发送 reply 后记录本轮消息回复时长。"""
|
||||
self._clear_force_continue_until_reply()
|
||||
if self._reply_latency_measurement_started_at is None:
|
||||
return
|
||||
|
||||
@@ -308,26 +322,26 @@ class MaisakaHeartFlowChatting:
|
||||
if not message.is_at and not message.is_mentioned:
|
||||
return
|
||||
|
||||
self._arm_force_continue_until_reply(
|
||||
self._arm_force_next_timing_continue(
|
||||
message,
|
||||
is_at=message.is_at,
|
||||
is_mentioned=message.is_mentioned,
|
||||
)
|
||||
|
||||
def _arm_force_continue_until_reply(
|
||||
def _arm_force_next_timing_continue(
|
||||
self,
|
||||
message: SessionMessage,
|
||||
*,
|
||||
is_at: bool,
|
||||
is_mentioned: bool,
|
||||
) -> None:
|
||||
"""在检测到 @ 或提及时,要求后续轮次跳过 Timing Gate 直到成功 reply。"""
|
||||
"""在检测到 @ 或提及时,要求下一次 Timing Gate 直接 continue。"""
|
||||
|
||||
trigger_reason = "@消息" if is_at else "提及消息" if is_mentioned else "触发消息"
|
||||
was_armed = self._force_continue_until_reply
|
||||
self._force_continue_until_reply = True
|
||||
self._force_continue_trigger_message_id = message.message_id
|
||||
self._force_continue_trigger_reason = trigger_reason
|
||||
was_armed = self._force_next_timing_continue
|
||||
self._force_next_timing_continue = True
|
||||
self._force_next_timing_message_id = message.message_id
|
||||
self._force_next_timing_reason = trigger_reason
|
||||
|
||||
if was_armed:
|
||||
logger.info(
|
||||
@@ -337,34 +351,31 @@ class MaisakaHeartFlowChatting:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 检测到{trigger_reason},将跳过 Timing Gate 直到成功发送一条 reply;"
|
||||
f"{self.log_prefix} 检测到{trigger_reason},下一次 Timing Gate 将直接视作 continue;"
|
||||
f"消息编号={message.message_id}"
|
||||
)
|
||||
|
||||
def _clear_force_continue_until_reply(self) -> None:
|
||||
"""在成功发送 reply 后清理强制 continue 状态。"""
|
||||
def _consume_force_next_timing_continue_reason(self) -> str | None:
|
||||
"""消费一次性 Timing Gate continue 状态,并返回原因描述。"""
|
||||
|
||||
if not self._force_continue_until_reply:
|
||||
return
|
||||
if not self._force_next_timing_continue:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 已成功发送 reply,恢复 Timing Gate;"
|
||||
f"触发原因={self._force_continue_trigger_reason or '未知'} "
|
||||
f"触发消息编号={self._force_continue_trigger_message_id or 'unknown'}"
|
||||
)
|
||||
self._force_continue_until_reply = False
|
||||
self._force_continue_trigger_message_id = ""
|
||||
self._force_continue_trigger_reason = ""
|
||||
|
||||
def _build_force_continue_timing_reason(self) -> str:
|
||||
"""返回当前强制跳过 Timing Gate 的原因描述。"""
|
||||
|
||||
trigger_reason = self._force_continue_trigger_reason or "@/提及消息"
|
||||
trigger_message_id = self._force_continue_trigger_message_id or "unknown"
|
||||
return (
|
||||
trigger_reason = self._force_next_timing_reason or "@/提及消息"
|
||||
trigger_message_id = self._force_next_timing_message_id or "unknown"
|
||||
reason = (
|
||||
f"检测到新的{trigger_reason}(消息编号={trigger_message_id}),"
|
||||
"本轮直接跳过 Timing Gate 并视作 continue,直到成功发送一条 reply。"
|
||||
"本轮直接跳过 Timing Gate 并视作 continue。"
|
||||
)
|
||||
logger.info(
|
||||
f"{self.log_prefix} 已结束本次强制 continue,恢复 Timing Gate;"
|
||||
f"触发原因={trigger_reason} "
|
||||
f"触发消息编号={trigger_message_id}"
|
||||
)
|
||||
self._force_next_timing_continue = False
|
||||
self._force_next_timing_message_id = ""
|
||||
self._force_next_timing_reason = ""
|
||||
return reason
|
||||
|
||||
def _bind_planner_interrupt_flag(self, interrupt_flag: asyncio.Event) -> None:
|
||||
"""绑定当前可打断请求使用的中断标记。"""
|
||||
@@ -426,6 +437,7 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
selected_history, _ = MaisakaChatLoopService.select_llm_context_messages(
|
||||
self._chat_history,
|
||||
request_kind=request_kind,
|
||||
max_context_size=context_message_limit,
|
||||
)
|
||||
sub_agent_history = list(selected_history)
|
||||
@@ -447,11 +459,133 @@ class MaisakaHeartFlowChatting:
|
||||
tool_definitions=[] if tool_definitions is None else tool_definitions,
|
||||
)
|
||||
|
||||
def set_current_action_tool_names(self, tool_names: Sequence[str]) -> None:
|
||||
"""记录当前 Action Loop 已实际暴露给 planner 的工具名集合。"""
|
||||
|
||||
self._current_action_tool_names = {tool_name for tool_name in tool_names if str(tool_name).strip()}
|
||||
|
||||
def is_action_tool_currently_available(self, tool_name: str) -> bool:
|
||||
"""判断指定工具在当前 Action Loop 轮次中是否真实可用。"""
|
||||
|
||||
normalized_name = str(tool_name).strip()
|
||||
return bool(normalized_name) and normalized_name in self._current_action_tool_names
|
||||
|
||||
def update_deferred_tool_specs(self, deferred_tool_specs: Sequence[ToolSpec]) -> None:
|
||||
"""刷新当前会话的 deferred tools 池,并清理失效的已发现工具。"""
|
||||
|
||||
next_specs_by_name: dict[str, ToolSpec] = {}
|
||||
for tool_spec in deferred_tool_specs:
|
||||
normalized_name = tool_spec.name.strip()
|
||||
if not normalized_name:
|
||||
continue
|
||||
next_specs_by_name[normalized_name] = tool_spec
|
||||
|
||||
self.deferred_tool_specs_by_name = next_specs_by_name
|
||||
self.discovered_tool_names.intersection_update(next_specs_by_name.keys())
|
||||
|
||||
def get_discovered_deferred_tool_specs(self) -> list[ToolSpec]:
|
||||
"""返回当前会话中已发现、且仍然有效的 deferred tools。"""
|
||||
|
||||
return [
|
||||
tool_spec
|
||||
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items()
|
||||
if tool_name in self.discovered_tool_names
|
||||
]
|
||||
|
||||
def build_deferred_tools_reminder(self) -> str:
|
||||
"""构造供 planner 使用的 deferred tools 提示消息。"""
|
||||
|
||||
undiscovered_tool_specs = [
|
||||
tool_spec
|
||||
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items()
|
||||
if tool_name not in self.discovered_tool_names
|
||||
]
|
||||
if not undiscovered_tool_specs:
|
||||
return ""
|
||||
|
||||
tool_lines: list[str] = []
|
||||
for index, tool_spec in enumerate(undiscovered_tool_specs, start=1):
|
||||
tool_name = tool_spec.name.strip()
|
||||
tool_description = tool_spec.brief_description.strip()
|
||||
if tool_description:
|
||||
tool_lines.append(f"{index}. {tool_name}: {tool_description}")
|
||||
else:
|
||||
tool_lines.append(f"{index}. {tool_name}")
|
||||
|
||||
reminder_lines = [
|
||||
"<system-reminder>",
|
||||
"以下工具当前未直接暴露给你,但可以通过 tool_search 工具发现并在后续轮次中使用:",
|
||||
*tool_lines,
|
||||
"",
|
||||
"如需其中某个工具,请先调用 tool_search。tool_search 只负责发现工具,不直接执行业务。",
|
||||
"</system-reminder>",
|
||||
]
|
||||
return "\n".join(reminder_lines)
|
||||
|
||||
def search_deferred_tool_specs(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
) -> list[ToolSpec]:
|
||||
"""按名称或简要描述搜索 deferred tools。"""
|
||||
|
||||
normalized_query = " ".join(query.lower().split()).strip()
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
scored_matches: list[tuple[int, str, ToolSpec]] = []
|
||||
query_terms = [term for term in normalized_query.replace("_", " ").replace("-", " ").split() if term]
|
||||
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items():
|
||||
lower_name = tool_name.lower()
|
||||
lower_description = tool_spec.brief_description.lower()
|
||||
score = 0
|
||||
|
||||
if normalized_query == lower_name:
|
||||
score += 1000
|
||||
if lower_name.startswith(normalized_query):
|
||||
score += 300
|
||||
if normalized_query in lower_name:
|
||||
score += 200
|
||||
if normalized_query in lower_description:
|
||||
score += 100
|
||||
|
||||
for query_term in query_terms:
|
||||
if query_term in lower_name:
|
||||
score += 25
|
||||
if query_term in lower_description:
|
||||
score += 10
|
||||
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
scored_matches.append((score, tool_name, tool_spec))
|
||||
|
||||
scored_matches.sort(key=lambda item: (-item[0], item[1]))
|
||||
return [tool_spec for _, _, tool_spec in scored_matches[: max(1, limit)]]
|
||||
|
||||
def discover_deferred_tools(self, tool_names: Sequence[str]) -> list[str]:
|
||||
"""将指定 deferred tools 标记为已发现,并返回本次新发现的工具名。"""
|
||||
|
||||
newly_discovered_tool_names: list[str] = []
|
||||
for raw_tool_name in tool_names:
|
||||
normalized_name = str(raw_tool_name).strip()
|
||||
if not normalized_name or normalized_name not in self.deferred_tool_specs_by_name:
|
||||
continue
|
||||
if normalized_name in self.discovered_tool_names:
|
||||
continue
|
||||
self.discovered_tool_names.add(normalized_name)
|
||||
newly_discovered_tool_names.append(normalized_name)
|
||||
return newly_discovered_tool_names
|
||||
|
||||
def _has_pending_messages(self) -> bool:
|
||||
return self._last_processed_index < len(self.message_cache)
|
||||
|
||||
def _schedule_message_turn(self) -> None:
|
||||
"""为当前待处理消息安排一次内部 turn。"""
|
||||
if self._agent_state == self._STATE_WAIT:
|
||||
return
|
||||
|
||||
if not self._has_pending_messages() or self._message_turn_scheduled:
|
||||
return
|
||||
|
||||
@@ -531,8 +665,9 @@ class MaisakaHeartFlowChatting:
|
||||
def _enter_wait_state(self, seconds: Optional[float] = None, tool_call_id: Optional[str] = None) -> None:
|
||||
"""切换到等待状态。"""
|
||||
self._agent_state = self._STATE_WAIT
|
||||
self._wait_until = None if seconds is None else time.time() + seconds
|
||||
self._pending_wait_tool_call_id = tool_call_id
|
||||
self._message_turn_scheduled = False
|
||||
self._cancel_deferred_message_turn_task()
|
||||
self._cancel_wait_timeout_task()
|
||||
if seconds is not None:
|
||||
self._wait_timeout_task = asyncio.create_task(
|
||||
@@ -542,7 +677,6 @@ class MaisakaHeartFlowChatting:
|
||||
def _enter_stop_state(self) -> None:
|
||||
"""切换到停止状态。"""
|
||||
self._agent_state = self._STATE_STOP
|
||||
self._wait_until = None
|
||||
self._pending_wait_tool_call_id = None
|
||||
self._cancel_wait_timeout_task()
|
||||
|
||||
@@ -567,7 +701,6 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
logger.info(f"{self.log_prefix} Maisaka 等待已超时")
|
||||
self._agent_state = self._STATE_RUNNING
|
||||
self._wait_until = None
|
||||
await self._internal_turn_queue.put("timeout")
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
@@ -616,7 +749,7 @@ class MaisakaHeartFlowChatting:
|
||||
return True
|
||||
|
||||
async def _trigger_expression_learning(self, messages: list[SessionMessage]) -> None:
|
||||
"""?????????????????"""
|
||||
"""触发表达方式学习"""
|
||||
pending_count = self._expression_learner.get_pending_count(self.message_cache)
|
||||
if not self._should_trigger_learning(
|
||||
enabled=self._enable_expression_learning,
|
||||
@@ -629,21 +762,21 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
self._last_expression_extraction_time = time.time()
|
||||
logger.info(
|
||||
f"{self.log_prefix} ??????: "
|
||||
f"??????={len(messages)} ??????={pending_count} "
|
||||
f"?????={len(self.message_cache)} "
|
||||
f"??????={self._enable_jargon_learning}"
|
||||
f"{self.log_prefix} 触发表达方式学习: "
|
||||
f"消息数量={len(messages)} 待处理消息数量={pending_count} "
|
||||
f"缓存总量={len(self.message_cache)} "
|
||||
f"是否启用黑话学习={self._enable_jargon_learning}"
|
||||
)
|
||||
|
||||
try:
|
||||
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
|
||||
learnt_style = await self._expression_learner.learn(self.message_cache, jargon_miner)
|
||||
if learnt_style:
|
||||
logger.info(f"{self.log_prefix} ???????")
|
||||
logger.info(f"{self.log_prefix} 表达方式学习成功")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} ???????????????")
|
||||
logger.debug(f"{self.log_prefix} 表达方式学习失败")
|
||||
except Exception:
|
||||
logger.exception(f"{self.log_prefix} ??????")
|
||||
logger.exception(f"{self.log_prefix} 表达方式学习异常")
|
||||
|
||||
async def _init_mcp(self) -> None:
|
||||
"""初始化 MCP 工具并注册到统一工具层。"""
|
||||
@@ -655,12 +788,12 @@ class MaisakaHeartFlowChatting:
|
||||
host_callbacks=self._mcp_host_bridge.build_callbacks(),
|
||||
)
|
||||
if self._mcp_manager is None:
|
||||
logger.info(f"{self.log_prefix} MCP 管理器不可用")
|
||||
logger.info(f"{self.log_prefix} Maisaka MCP 管理器不可用")
|
||||
return
|
||||
|
||||
mcp_tool_specs = self._mcp_manager.get_tool_specs()
|
||||
if not mcp_tool_specs:
|
||||
logger.info(f"{self.log_prefix} 没有可供 Maisaka 使用的 MCP 工具")
|
||||
logger.info(f"{self.log_prefix} Maisaka 没有可供使用的 MCP 工具")
|
||||
return
|
||||
|
||||
self._tool_registry.register_provider(MCPToolProvider(self._mcp_manager))
|
||||
@@ -694,6 +827,7 @@ class MaisakaHeartFlowChatting:
|
||||
self,
|
||||
*,
|
||||
cycle_id: Optional[int] = None,
|
||||
time_records: Optional[dict[str, float]] = None,
|
||||
timing_selected_history_count: Optional[int] = None,
|
||||
timing_prompt_tokens: Optional[int] = None,
|
||||
timing_action: str = "",
|
||||
@@ -709,6 +843,7 @@ class MaisakaHeartFlowChatting:
|
||||
planner_tool_results: Optional[list[str]] = None,
|
||||
planner_tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
planner_prompt_section: Optional[RenderableType] = None,
|
||||
planner_extra_lines: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
"""在终端展示当前聊天流本轮 cycle 的最终结果。"""
|
||||
if not global_config.debug.show_maisaka_thinking:
|
||||
@@ -721,6 +856,7 @@ class MaisakaHeartFlowChatting:
|
||||
if cycle_id is not None:
|
||||
body_lines.append(f"循环编号:{cycle_id}")
|
||||
|
||||
panel_subtitle = self._build_cycle_time_records_text(time_records or {})
|
||||
renderables: list[RenderableType] = [Text("\n".join(body_lines))]
|
||||
timing_panel = self._build_cycle_stage_panel(
|
||||
title="Timing Gate",
|
||||
@@ -728,33 +864,49 @@ class MaisakaHeartFlowChatting:
|
||||
selected_history_count=timing_selected_history_count,
|
||||
prompt_tokens=timing_prompt_tokens,
|
||||
response_text=timing_response,
|
||||
tool_calls=timing_tool_calls,
|
||||
tool_results=timing_tool_results,
|
||||
tool_detail_results=timing_tool_detail_results,
|
||||
prompt_section=timing_prompt_section,
|
||||
extra_lines=[f"门控动作:{timing_action}"] if timing_action.strip() else None,
|
||||
)
|
||||
if timing_panel is not None:
|
||||
renderables.append(timing_panel)
|
||||
|
||||
timing_tool_cards = self._build_tool_activity_cards(
|
||||
stage_title="Timing Tool",
|
||||
tool_calls=timing_tool_calls,
|
||||
tool_results=timing_tool_results,
|
||||
tool_detail_results=timing_tool_detail_results,
|
||||
planner_style=False,
|
||||
)
|
||||
if timing_tool_cards:
|
||||
renderables.extend(timing_tool_cards)
|
||||
|
||||
planner_panel = self._build_cycle_stage_panel(
|
||||
title="Planner",
|
||||
border_style="green",
|
||||
selected_history_count=planner_selected_history_count,
|
||||
prompt_tokens=planner_prompt_tokens,
|
||||
response_text=planner_response,
|
||||
tool_calls=planner_tool_calls,
|
||||
tool_results=planner_tool_results,
|
||||
tool_detail_results=planner_tool_detail_results,
|
||||
prompt_section=planner_prompt_section,
|
||||
extra_lines=planner_extra_lines,
|
||||
)
|
||||
if planner_panel is not None:
|
||||
renderables.append(planner_panel)
|
||||
|
||||
planner_tool_cards = self._build_tool_activity_cards(
|
||||
stage_title="Planner Tool",
|
||||
tool_calls=planner_tool_calls,
|
||||
tool_results=planner_tool_results,
|
||||
tool_detail_results=planner_tool_detail_results,
|
||||
planner_style=True,
|
||||
)
|
||||
if planner_tool_cards:
|
||||
renderables.extend(planner_tool_cards)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Group(*renderables),
|
||||
title="MaiSaka 循环",
|
||||
subtitle=panel_subtitle,
|
||||
border_style="bright_blue",
|
||||
padding=(0, 1),
|
||||
)
|
||||
@@ -768,9 +920,6 @@ class MaisakaHeartFlowChatting:
|
||||
selected_history_count: Optional[int],
|
||||
prompt_tokens: Optional[int],
|
||||
response_text: str = "",
|
||||
tool_calls: Optional[list[Any]] = None,
|
||||
tool_results: Optional[list[str]] = None,
|
||||
tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
prompt_section: Optional[RenderableType] = None,
|
||||
extra_lines: Optional[list[str]] = None,
|
||||
) -> Optional[Panel]:
|
||||
@@ -780,9 +929,6 @@ class MaisakaHeartFlowChatting:
|
||||
selected_history_count is not None,
|
||||
prompt_tokens is not None,
|
||||
bool(response_text.strip()),
|
||||
bool(tool_calls),
|
||||
bool(tool_results),
|
||||
bool(tool_detail_results),
|
||||
prompt_section is not None,
|
||||
bool(extra_lines),
|
||||
])
|
||||
@@ -809,40 +955,11 @@ class MaisakaHeartFlowChatting:
|
||||
Panel(
|
||||
Text(normalized_response),
|
||||
title="Maisaka 返回",
|
||||
border_style="green",
|
||||
border_style=border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
normalized_tool_calls = build_tool_call_summary_lines(tool_calls or [])
|
||||
if normalized_tool_calls:
|
||||
renderables.append(
|
||||
Panel(
|
||||
Text("\n".join(normalized_tool_calls)),
|
||||
title="工具调用",
|
||||
border_style="magenta",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
normalized_tool_results = self._filter_redundant_tool_results(
|
||||
tool_results=tool_results or [],
|
||||
tool_detail_results=tool_detail_results or [],
|
||||
)
|
||||
if normalized_tool_results:
|
||||
renderables.append(
|
||||
Panel(
|
||||
Text("\n".join(normalized_tool_results)),
|
||||
title="工具结果",
|
||||
border_style="yellow",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
detail_panels = self._build_tool_detail_panels(tool_detail_results or [])
|
||||
if detail_panels:
|
||||
renderables.extend(detail_panels)
|
||||
|
||||
return Panel(
|
||||
Group(*renderables),
|
||||
title=title,
|
||||
@@ -850,6 +967,75 @@ class MaisakaHeartFlowChatting:
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
def _build_tool_activity_cards(
|
||||
self,
|
||||
*,
|
||||
stage_title: str,
|
||||
tool_calls: Optional[list[Any]] = None,
|
||||
tool_results: Optional[list[str]] = None,
|
||||
tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
planner_style: bool = False,
|
||||
) -> list[RenderableType]:
|
||||
"""构建与阶段同级的工具执行卡片列表。"""
|
||||
|
||||
detail_results = tool_detail_results or []
|
||||
cards = self._build_tool_detail_cards(
|
||||
detail_results,
|
||||
stage_title=stage_title,
|
||||
planner_style=planner_style,
|
||||
)
|
||||
if cards:
|
||||
return cards
|
||||
|
||||
# 兼容旧数据结构:若尚无 detail,则降级为简单文本卡片。
|
||||
fallback_lines = self._filter_redundant_tool_results(
|
||||
tool_results=tool_results or [],
|
||||
tool_detail_results=detail_results,
|
||||
)
|
||||
if not fallback_lines and tool_calls:
|
||||
fallback_lines = build_tool_call_summary_lines(tool_calls)
|
||||
if not fallback_lines:
|
||||
return []
|
||||
|
||||
fallback_border_style = "yellow"
|
||||
return [
|
||||
Panel(
|
||||
Text("\n".join(fallback_lines)),
|
||||
title=stage_title,
|
||||
border_style=fallback_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _build_cycle_time_records_text(time_records: dict[str, float]) -> str:
|
||||
"""构建循环最外层面板展示的阶段耗时文本。"""
|
||||
|
||||
if not time_records:
|
||||
return "流程耗时:无"
|
||||
|
||||
label_map = {
|
||||
"timing_gate": "Timing Gate",
|
||||
"planner": "Planner",
|
||||
"tool_calls": "工具执行",
|
||||
}
|
||||
ordered_keys = ["timing_gate", "planner", "tool_calls"]
|
||||
|
||||
parts: list[str] = []
|
||||
for key in ordered_keys:
|
||||
duration = time_records.get(key)
|
||||
if isinstance(duration, (int, float)):
|
||||
parts.append(f"{label_map.get(key, key)} {float(duration):.2f} s")
|
||||
|
||||
for key, duration in time_records.items():
|
||||
if key in ordered_keys or not isinstance(duration, (int, float)):
|
||||
continue
|
||||
parts.append(f"{label_map.get(key, key)} {float(duration):.2f} s")
|
||||
|
||||
if not parts:
|
||||
return "流程耗时:无"
|
||||
return "流程耗时:" + " | ".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _filter_redundant_tool_results(
|
||||
*,
|
||||
@@ -941,7 +1127,9 @@ class MaisakaHeartFlowChatting:
|
||||
*,
|
||||
tool_name: str,
|
||||
prompt_text: str,
|
||||
request_messages: Optional[list[Any]] = None,
|
||||
tool_call_id: str,
|
||||
border_style: str = "bright_yellow",
|
||||
) -> Panel:
|
||||
"""将工具 prompt 渲染为可点击查看的预览入口。"""
|
||||
|
||||
@@ -950,6 +1138,26 @@ class MaisakaHeartFlowChatting:
|
||||
if tool_call_id:
|
||||
subtitle += f"\n调用ID: {tool_call_id}"
|
||||
|
||||
if isinstance(request_messages, list) and request_messages:
|
||||
try:
|
||||
normalized_messages = deserialize_prompt_messages(request_messages)
|
||||
except Exception as exc:
|
||||
logger.warning(f"工具 {tool_name} 的 request_messages 无法反序列化,已回退为文本预览: {exc}")
|
||||
else:
|
||||
return Panel(
|
||||
PromptCLIVisualizer.build_prompt_access_panel(
|
||||
normalized_messages,
|
||||
category=labels["prompt_category"],
|
||||
chat_id=self.session_id,
|
||||
request_kind=labels["request_kind"],
|
||||
selection_reason=subtitle,
|
||||
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
|
||||
),
|
||||
title=labels["prompt_title"],
|
||||
border_style=border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
return Panel(
|
||||
PromptCLIVisualizer.build_text_access_panel(
|
||||
prompt_text,
|
||||
@@ -959,116 +1167,235 @@ class MaisakaHeartFlowChatting:
|
||||
subtitle=subtitle,
|
||||
),
|
||||
title=labels["prompt_title"],
|
||||
border_style="bright_yellow",
|
||||
border_style=border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
def _build_tool_detail_panels(self, tool_detail_results: list[dict[str, Any]]) -> list[RenderableType]:
|
||||
"""将 tool monitor detail 渲染为 CLI 详情卡片。"""
|
||||
def _normalize_tool_card_body_lines(self, body: Any) -> list[str]:
|
||||
"""将工具卡片正文规范化为行列表。"""
|
||||
|
||||
if isinstance(body, str):
|
||||
return [line for line in body.splitlines() if line.strip()]
|
||||
if isinstance(body, list):
|
||||
return [
|
||||
str(item).strip()
|
||||
for item in body
|
||||
if str(item).strip()
|
||||
]
|
||||
return []
|
||||
|
||||
def _build_custom_tool_sub_cards(
|
||||
self,
|
||||
sub_cards: Any,
|
||||
*,
|
||||
default_border_style: str,
|
||||
) -> list[RenderableType]:
|
||||
"""构建工具自定义子卡片。"""
|
||||
|
||||
if not isinstance(sub_cards, list):
|
||||
return []
|
||||
|
||||
renderables: list[RenderableType] = []
|
||||
for sub_card in sub_cards:
|
||||
if not isinstance(sub_card, dict):
|
||||
continue
|
||||
title = str(sub_card.get("title") or "").strip() or "附加信息"
|
||||
border_style = str(sub_card.get("border_style") or "").strip() or default_border_style
|
||||
body_lines = self._normalize_tool_card_body_lines(
|
||||
sub_card.get("body_lines", sub_card.get("content", ""))
|
||||
)
|
||||
if not body_lines:
|
||||
continue
|
||||
renderables.append(
|
||||
Panel(
|
||||
Text("\n".join(body_lines)),
|
||||
title=title,
|
||||
border_style=border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
return renderables
|
||||
|
||||
def _build_default_tool_detail_parts(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
tool_call_id: str,
|
||||
tool_args: Any,
|
||||
summary: str,
|
||||
duration_ms: Any,
|
||||
detail: dict[str, Any],
|
||||
planner_style: bool,
|
||||
) -> list[RenderableType]:
|
||||
"""构建工具卡片默认内容块。"""
|
||||
|
||||
argument_border_style = "yellow"
|
||||
metrics_border_style = "bright_yellow"
|
||||
prompt_border_style = "bright_yellow"
|
||||
reasoning_border_style = "yellow"
|
||||
output_border_style = "bright_yellow"
|
||||
extra_info_border_style = "yellow"
|
||||
detail_labels = self._get_tool_detail_labels(tool_name)
|
||||
|
||||
parts: list[RenderableType] = []
|
||||
header_lines: list[str] = []
|
||||
if summary:
|
||||
header_lines.append(summary)
|
||||
if tool_call_id:
|
||||
header_lines.append(f"调用ID:{tool_call_id}")
|
||||
if isinstance(duration_ms, (int, float)):
|
||||
header_lines.append(f"执行耗时:{round(float(duration_ms), 2)} ms")
|
||||
if header_lines:
|
||||
parts.append(Text("\n".join(header_lines)))
|
||||
|
||||
if isinstance(tool_args, dict) and tool_args:
|
||||
parts.append(
|
||||
Panel(
|
||||
Pretty(tool_args, expand_all=True),
|
||||
title="工具参数",
|
||||
border_style=argument_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
metrics = detail.get("metrics")
|
||||
if isinstance(metrics, dict):
|
||||
metrics_text = self._build_tool_metrics_text(metrics)
|
||||
if metrics_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(metrics_text),
|
||||
title="执行指标",
|
||||
border_style=metrics_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
prompt_text = str(detail.get("prompt_text") or "").strip()
|
||||
if prompt_text:
|
||||
parts.append(
|
||||
self._build_tool_prompt_access_panel(
|
||||
tool_name=tool_name,
|
||||
prompt_text=prompt_text,
|
||||
request_messages=detail.get("request_messages") if isinstance(detail.get("request_messages"), list) else None,
|
||||
tool_call_id=tool_call_id,
|
||||
border_style=prompt_border_style,
|
||||
)
|
||||
)
|
||||
|
||||
reasoning_text = str(detail.get("reasoning_text") or "").strip()
|
||||
if reasoning_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(reasoning_text),
|
||||
title=detail_labels["reasoning_title"],
|
||||
border_style=reasoning_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
output_text = str(detail.get("output_text") or "").strip()
|
||||
if output_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(output_text),
|
||||
title=detail_labels["output_title"],
|
||||
border_style=output_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
extra_sections = detail.get("extra_sections")
|
||||
if isinstance(extra_sections, list):
|
||||
for section in extra_sections:
|
||||
if not isinstance(section, dict):
|
||||
continue
|
||||
section_title = str(section.get("title") or "").strip() or "附加信息"
|
||||
section_content = str(section.get("content") or "").strip()
|
||||
if not section_content:
|
||||
continue
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(section_content),
|
||||
title=section_title,
|
||||
border_style=extra_info_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
return parts
|
||||
|
||||
def _build_tool_detail_cards(
|
||||
self,
|
||||
tool_detail_results: list[dict[str, Any]],
|
||||
*,
|
||||
stage_title: str,
|
||||
planner_style: bool = False,
|
||||
) -> list[RenderableType]:
|
||||
"""将 tool monitor detail 渲染为与 Planner/Timing 平级的工具卡片。"""
|
||||
|
||||
detail_panel_border_style = "yellow"
|
||||
sub_card_border_style = "bright_yellow"
|
||||
|
||||
panels: list[RenderableType] = []
|
||||
for tool_result in tool_detail_results:
|
||||
detail = tool_result.get("detail")
|
||||
if not isinstance(detail, dict) or not detail:
|
||||
continue
|
||||
|
||||
detail_dict = detail if isinstance(detail, dict) else {}
|
||||
tool_name = str(tool_result.get("tool_name") or "unknown").strip() or "unknown"
|
||||
detail_labels = self._get_tool_detail_labels(tool_name)
|
||||
tool_title = str(tool_result.get("tool_title") or "").strip() or tool_name
|
||||
tool_call_id = str(tool_result.get("tool_call_id") or "").strip()
|
||||
tool_args = tool_result.get("tool_args")
|
||||
summary = str(tool_result.get("summary") or "").strip()
|
||||
duration_ms = tool_result.get("duration_ms")
|
||||
custom_card = tool_result.get("card")
|
||||
|
||||
parts: list[RenderableType] = []
|
||||
header_lines: list[str] = []
|
||||
if summary:
|
||||
header_lines.append(summary)
|
||||
if tool_call_id:
|
||||
header_lines.append(f"调用ID:{tool_call_id}")
|
||||
if isinstance(duration_ms, (int, float)):
|
||||
header_lines.append(f"执行耗时:{round(float(duration_ms), 2)} ms")
|
||||
if header_lines:
|
||||
parts.append(Text("\n".join(header_lines)))
|
||||
|
||||
if isinstance(tool_args, dict) and tool_args:
|
||||
parts.append(
|
||||
Panel(
|
||||
Pretty(tool_args, expand_all=True),
|
||||
title="工具参数",
|
||||
border_style="cyan",
|
||||
padding=(0, 1),
|
||||
)
|
||||
custom_title = ""
|
||||
card_border_style = detail_panel_border_style
|
||||
replace_default_children = False
|
||||
if isinstance(custom_card, dict):
|
||||
custom_title = str(custom_card.get("title") or "").strip()
|
||||
card_border_style = str(custom_card.get("border_style") or "").strip() or detail_panel_border_style
|
||||
replace_default_children = bool(custom_card.get("replace_default_children", False))
|
||||
custom_body_lines = self._normalize_tool_card_body_lines(
|
||||
custom_card.get("body_lines", custom_card.get("content", ""))
|
||||
)
|
||||
if custom_body_lines:
|
||||
parts.append(Text("\n".join(custom_body_lines)))
|
||||
|
||||
metrics = detail.get("metrics")
|
||||
if isinstance(metrics, dict):
|
||||
metrics_text = self._build_tool_metrics_text(metrics)
|
||||
if metrics_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(metrics_text),
|
||||
title="执行指标",
|
||||
border_style="bright_cyan",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
prompt_text = str(detail.get("prompt_text") or "").strip()
|
||||
if prompt_text:
|
||||
parts.append(
|
||||
self._build_tool_prompt_access_panel(
|
||||
if not replace_default_children:
|
||||
parts.extend(
|
||||
self._build_default_tool_detail_parts(
|
||||
tool_name=tool_name,
|
||||
prompt_text=prompt_text,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_args=tool_args,
|
||||
summary=summary,
|
||||
duration_ms=duration_ms,
|
||||
detail=detail_dict,
|
||||
planner_style=planner_style,
|
||||
)
|
||||
)
|
||||
|
||||
reasoning_text = str(detail.get("reasoning_text") or "").strip()
|
||||
if reasoning_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(reasoning_text),
|
||||
title=detail_labels["reasoning_title"],
|
||||
border_style="magenta",
|
||||
padding=(0, 1),
|
||||
if isinstance(custom_card, dict):
|
||||
parts.extend(
|
||||
self._build_custom_tool_sub_cards(
|
||||
custom_card.get("sub_cards"),
|
||||
default_border_style=sub_card_border_style,
|
||||
)
|
||||
)
|
||||
|
||||
output_text = str(detail.get("output_text") or "").strip()
|
||||
if output_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(output_text),
|
||||
title=detail_labels["output_title"],
|
||||
border_style="green",
|
||||
padding=(0, 1),
|
||||
)
|
||||
parts.extend(
|
||||
self._build_custom_tool_sub_cards(
|
||||
tool_result.get("sub_cards"),
|
||||
default_border_style=sub_card_border_style,
|
||||
)
|
||||
|
||||
extra_sections = detail.get("extra_sections")
|
||||
if isinstance(extra_sections, list):
|
||||
for section in extra_sections:
|
||||
if not isinstance(section, dict):
|
||||
continue
|
||||
section_title = str(section.get("title") or "").strip() or "附加信息"
|
||||
section_content = str(section.get("content") or "").strip()
|
||||
if not section_content:
|
||||
continue
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(section_content),
|
||||
title=section_title,
|
||||
border_style="white",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if parts:
|
||||
panels.append(
|
||||
Panel(
|
||||
Group(*parts),
|
||||
title=f"{tool_name} 工具详情",
|
||||
border_style="yellow",
|
||||
title=custom_title or f"{stage_title} · {tool_title}",
|
||||
border_style=card_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -521,9 +521,7 @@ class MCPHostLLMBridge:
|
||||
tool_definitions.append(
|
||||
{
|
||||
"name": tool_name,
|
||||
"description": "\n\n".join(
|
||||
part for part in [brief_description, detailed_description] if part.strip()
|
||||
).strip(),
|
||||
"description": brief_description,
|
||||
"parameters_schema": parameters_schema or {"type": "object", "properties": {}},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -672,6 +672,32 @@ class ComponentQueryService:
|
||||
collected_specs[entry.name] = self._build_tool_spec(entry) # type: ignore[arg-type]
|
||||
return collected_specs
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_context_payload(context: Optional[ToolExecutionContext]) -> Dict[str, Any]:
|
||||
"""提取插件工具可复用的会话上下文字段。"""
|
||||
|
||||
if context is None:
|
||||
return {}
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
stream_id = str(context.stream_id or context.session_id or "").strip()
|
||||
if stream_id:
|
||||
payload["stream_id"] = stream_id
|
||||
payload["chat_id"] = stream_id
|
||||
|
||||
anchor_message = context.metadata.get("anchor_message")
|
||||
message_info = getattr(anchor_message, "message_info", None)
|
||||
group_info = getattr(message_info, "group_info", None)
|
||||
user_info = getattr(message_info, "user_info", None)
|
||||
|
||||
group_id = str(getattr(group_info, "group_id", "") or "").strip()
|
||||
user_id = str(getattr(user_info, "user_id", "") or "").strip()
|
||||
if group_id:
|
||||
payload["group_id"] = group_id
|
||||
if user_id:
|
||||
payload["user_id"] = user_id
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_invocation_payload(
|
||||
entry: "ToolEntry",
|
||||
@@ -690,16 +716,27 @@ class ComponentQueryService:
|
||||
"""
|
||||
|
||||
payload = dict(invocation.arguments)
|
||||
context_payload = ComponentQueryService._build_tool_context_payload(context)
|
||||
if entry.invoke_method == "plugin.invoke_action":
|
||||
stream_id = context.stream_id if context is not None else invocation.stream_id
|
||||
stream_id = str(
|
||||
context_payload.get("stream_id")
|
||||
or (context.stream_id if context is not None else invocation.stream_id)
|
||||
or invocation.stream_id
|
||||
).strip()
|
||||
reasoning = context.reasoning if context is not None else invocation.reasoning
|
||||
payload = {
|
||||
**payload,
|
||||
**{key: value for key, value in context_payload.items() if key not in payload or not payload.get(key)},
|
||||
"stream_id": stream_id,
|
||||
"chat_id": stream_id,
|
||||
"reasoning": reasoning,
|
||||
"action_data": dict(invocation.arguments),
|
||||
}
|
||||
return payload
|
||||
|
||||
for key, value in context_payload.items():
|
||||
if key not in payload or not payload.get(key):
|
||||
payload[key] = value
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -11,8 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
@@ -20,8 +19,8 @@ from src.common.logger import get_logger
|
||||
from src.core.types import ActionInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.common.data_models.planned_action_data_models import PlannedAction
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -38,7 +37,7 @@ def _get_replyer(
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
||||
) -> Optional[MaisakaReplyGenerator]:
|
||||
"""获取回复器对象"""
|
||||
if not chat_id and not chat_stream:
|
||||
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||
@@ -100,7 +99,7 @@ async def generate_reply(
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
|
||||
chosen_actions: Optional[List["PlannedAction"]] = None,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
|
||||
@@ -267,6 +267,46 @@ def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
|
||||
return image_format, image_base64
|
||||
|
||||
|
||||
def _append_image_content(message_builder: MessageBuilder, content_item: Any) -> bool:
|
||||
"""向消息构建器追加图片片段。
|
||||
|
||||
兼容两种输入格式:
|
||||
1. 旧序列化格式中的 `(image_format, image_base64)` 元组。
|
||||
2. 标准字典片段中的 Data URL 或 `image_format`/`image_base64` 字段。
|
||||
"""
|
||||
|
||||
if isinstance(content_item, (tuple, list)) and len(content_item) == 2:
|
||||
image_format, image_base64 = content_item
|
||||
if not isinstance(image_format, str) or not isinstance(image_base64, str):
|
||||
raise ValueError("图片元组片段必须包含字符串类型的 image_format 和 image_base64")
|
||||
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
return True
|
||||
|
||||
if not isinstance(content_item, dict):
|
||||
return False
|
||||
|
||||
part_type = str(content_item.get("type", "text")).strip().lower()
|
||||
if part_type not in {"image", "image_url", "input_image"}:
|
||||
return False
|
||||
|
||||
image_url = content_item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
image_url = image_url.get("url")
|
||||
if isinstance(image_url, str):
|
||||
image_format, image_base64 = _parse_data_url_image(image_url)
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
return True
|
||||
|
||||
image_format = content_item.get("image_format")
|
||||
image_base64 = content_item.get("image_base64")
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
return True
|
||||
|
||||
raise ValueError("图片片段缺少可识别的图片数据")
|
||||
|
||||
|
||||
def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None:
|
||||
"""将原始消息内容追加到内部消息构建器。
|
||||
|
||||
@@ -293,8 +333,10 @@ def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None
|
||||
if isinstance(content_item, str):
|
||||
message_builder.add_text_content(content_item)
|
||||
continue
|
||||
if _append_image_content(message_builder, content_item):
|
||||
continue
|
||||
if not isinstance(content_item, dict):
|
||||
raise ValueError("消息内容列表中仅支持字符串或字典片段")
|
||||
raise ValueError("消息内容列表中仅支持字符串、图片元组或字典片段")
|
||||
|
||||
part_type = str(content_item.get("type", "text")).strip().lower()
|
||||
if part_type == "text":
|
||||
@@ -304,22 +346,6 @@ def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None
|
||||
message_builder.add_text_content(text_content)
|
||||
continue
|
||||
|
||||
if part_type in {"image", "image_url", "input_image"}:
|
||||
image_url = content_item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
image_url = image_url.get("url")
|
||||
if isinstance(image_url, str):
|
||||
image_format, image_base64 = _parse_data_url_image(image_url)
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
|
||||
image_format = content_item.get("image_format")
|
||||
image_base64 = content_item.get("image_base64")
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
raise ValueError("图片片段缺少可识别的图片数据")
|
||||
|
||||
raise ValueError(f"不支持的消息片段类型: {part_type}")
|
||||
|
||||
|
||||
|
||||
@@ -326,7 +326,7 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
if emoji.is_registered:
|
||||
return EmojiUpdateResponse(success=True, message="??????????", data=emoji_to_response(emoji))
|
||||
return EmojiUpdateResponse(success=True, message="表情包已注册", data=emoji_to_response(emoji))
|
||||
|
||||
emoji.is_registered = True
|
||||
emoji.is_banned = False
|
||||
|
||||
25
tests/test_maisaka_active_state_logic.py
Normal file
25
tests/test_maisaka_active_state_logic.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
|
||||
|
||||
|
||||
def test_retry_planner_after_interrupt_only_when_has_new_messages_and_more_rounds() -> None:
|
||||
assert MaisakaReasoningEngine._should_retry_planner_after_interrupt(
|
||||
round_index=0,
|
||||
max_internal_rounds=6,
|
||||
has_pending_messages=True,
|
||||
)
|
||||
|
||||
|
||||
def test_do_not_retry_planner_after_interrupt_without_pending_messages() -> None:
|
||||
assert not MaisakaReasoningEngine._should_retry_planner_after_interrupt(
|
||||
round_index=0,
|
||||
max_internal_rounds=6,
|
||||
has_pending_messages=False,
|
||||
)
|
||||
|
||||
|
||||
def test_do_not_retry_planner_after_interrupt_on_last_round() -> None:
|
||||
assert not MaisakaReasoningEngine._should_retry_planner_after_interrupt(
|
||||
round_index=5,
|
||||
max_internal_rounds=6,
|
||||
has_pending_messages=True,
|
||||
)
|
||||
63
tests/test_maisaka_deferred_tools.py
Normal file
63
tests/test_maisaka_deferred_tools.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from src.core.tooling import ToolSpec
|
||||
from src.llm_models.payload_content.message import RoleType
|
||||
from src.maisaka.chat_loop_service import MaisakaChatLoopService
|
||||
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
||||
|
||||
|
||||
def _build_runtime_stub() -> MaisakaHeartFlowChatting:
|
||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
||||
runtime._current_action_tool_names = set()
|
||||
runtime.deferred_tool_specs_by_name = {}
|
||||
runtime.discovered_tool_names = set()
|
||||
return runtime
|
||||
|
||||
|
||||
def test_deferred_tools_reminder_only_lists_undiscovered_tools() -> None:
|
||||
runtime = _build_runtime_stub()
|
||||
runtime.update_deferred_tool_specs(
|
||||
[
|
||||
ToolSpec(name="plugin_alpha", brief_description="alpha"),
|
||||
ToolSpec(name="plugin_beta", brief_description="beta"),
|
||||
]
|
||||
)
|
||||
runtime.discover_deferred_tools(["plugin_alpha"])
|
||||
|
||||
reminder = runtime.build_deferred_tools_reminder()
|
||||
|
||||
assert "plugin_alpha" not in reminder
|
||||
assert "1. plugin_beta" in reminder
|
||||
assert "<system-reminder>" in reminder
|
||||
assert "tool_search" in reminder
|
||||
|
||||
|
||||
def test_search_and_discover_deferred_tools() -> None:
|
||||
runtime = _build_runtime_stub()
|
||||
runtime.update_deferred_tool_specs(
|
||||
[
|
||||
ToolSpec(name="mcp__slack__send_message", brief_description="向 Slack 发送消息"),
|
||||
ToolSpec(name="mcp__github__create_issue", brief_description="在 GitHub 创建 Issue"),
|
||||
]
|
||||
)
|
||||
|
||||
matched_tool_specs = runtime.search_deferred_tool_specs("slack send", limit=5)
|
||||
newly_discovered_tool_names = runtime.discover_deferred_tools([tool_spec.name for tool_spec in matched_tool_specs])
|
||||
|
||||
assert [tool_spec.name for tool_spec in matched_tool_specs] == ["mcp__slack__send_message"]
|
||||
assert newly_discovered_tool_names == ["mcp__slack__send_message"]
|
||||
assert [tool_spec.name for tool_spec in runtime.get_discovered_deferred_tool_specs()] == [
|
||||
"mcp__slack__send_message"
|
||||
]
|
||||
|
||||
|
||||
def test_build_request_messages_appends_injected_user_message() -> None:
|
||||
chat_loop_service = MaisakaChatLoopService(chat_system_prompt="system prompt")
|
||||
|
||||
messages = chat_loop_service._build_request_messages(
|
||||
[],
|
||||
injected_user_messages=["<system-reminder>\n1. plugin_beta\n</system-reminder>"],
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert messages[0].role == RoleType.System
|
||||
assert messages[1].role == RoleType.User
|
||||
assert messages[1].content == "<system-reminder>\n1. plugin_beta\n</system-reminder>"
|
||||
10
tests/test_maisaka_timing_gate_logic.py
Normal file
10
tests/test_maisaka_timing_gate_logic.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
|
||||
|
||||
|
||||
def test_continue_action_closes_timing_gate_for_following_rounds() -> None:
|
||||
assert MaisakaReasoningEngine._mark_timing_gate_completed("continue") is False
|
||||
|
||||
|
||||
def test_non_continue_actions_require_next_timing_gate() -> None:
|
||||
assert MaisakaReasoningEngine._mark_timing_gate_completed("wait") is True
|
||||
assert MaisakaReasoningEngine._mark_timing_gate_completed("no_reply") is True
|
||||
21
tests/test_maisaka_tool_visibility.py
Normal file
21
tests/test_maisaka_tool_visibility.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from src.maisaka.builtin_tool import get_action_tool_specs, get_timing_tool_specs
|
||||
|
||||
|
||||
def test_wait_tool_available_in_timing_stage() -> None:
|
||||
tool_names = {tool_spec.name for tool_spec in get_timing_tool_specs()}
|
||||
|
||||
assert "wait" in tool_names
|
||||
|
||||
|
||||
def test_wait_tool_not_available_in_action_stage() -> None:
|
||||
tool_names = {tool_spec.name for tool_spec in get_action_tool_specs()}
|
||||
|
||||
assert "wait" not in tool_names
|
||||
assert "finish" in tool_names
|
||||
assert "tool_search" in tool_names
|
||||
|
||||
|
||||
def test_tool_search_not_available_in_timing_stage() -> None:
|
||||
tool_names = {tool_spec.name for tool_spec in get_timing_tool_specs()}
|
||||
|
||||
assert "tool_search" not in tool_names
|
||||
Reference in New Issue
Block a user