Merge branch 'r-dev' of https://github.com/A-Dawn/MaiBot into r-dev

This commit is contained in:
DawnARC
2026-04-13 13:10:02 +08:00
83 changed files with 3789 additions and 3693 deletions

View File

@@ -42,4 +42,7 @@
如果你需要改动配置文件不需要修改实际的bot_config.toml或者model_config.toml只需要修改配置文件模版并新增一个版本号即可也不必要为配置改动创建测试文件。
# 关于webui修改
不要修改dashboard下的内容因为这部分内容由另一个仓库build
不要修改dashboard下的内容因为这部分内容由另一个仓库build
# maibot插件开发文档
https://github.com/Mai-with-u/maibot-plugin-sdk/blob/main/docs/guide.md

23
bot.py
View File

@@ -1,7 +1,6 @@
# raise RuntimeError("System Not Ready")
from pathlib import Path
from dotenv import load_dotenv
from rich.traceback import install
import asyncio
@@ -23,28 +22,6 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)
set_locale(os.getenv("MAIBOT_LOCALE", "zh-CN"))
env_path = Path(__file__).parent / ".env"
template_env_path = Path(__file__).parent / "template" / "template.env"
if env_path.exists():
load_dotenv(str(env_path), override=True)
else:
print("[WIP] no .env file found, and templates is not ready yet.")
print("[WIP] continue startup, use environment and existing config values.")
# try:
# if template_env_path.exists():
# shutil.copyfile(template_env_path, env_path)
# print(t("startup.env_created"))
# load_dotenv(str(env_path), override=True)
# else:
# print(t("startup.env_template_missing"))
# raise FileNotFoundError(t("startup.env_file_missing"))
# except Exception as e:
# print(t("startup.env_auto_create_failed", error=e))
# raise
set_locale(os.getenv("MAIBOT_LOCALE", "zh-CN"))
# 检查是否是 Worker 进程,只在 Worker 进程中输出详细的初始化信息
# Runner 进程只需要基本的日志功能,不需要详细的初始化日志
is_worker = os.environ.get("MAIBOT_WORKER_PROCESS") == "1"

View File

@@ -1,44 +0,0 @@
{
"manifest_version": 2,
"version": "2.0.0",
"name": "BetterEmoji",
"description": "更好的表情包管理插件",
"author": {
"name": "SengokuCola",
"url": "https://github.com/SengokuCola"
},
"license": "GPL-v3.0-or-later",
"urls": {
"repository": "https://github.com/SengokuCola/BetterEmoji",
"homepage": "https://github.com/SengokuCola/BetterEmoji",
"documentation": "https://github.com/SengokuCola/BetterEmoji",
"issues": "https://github.com/SengokuCola/BetterEmoji/issues"
},
"host_application": {
"min_version": "1.0.0",
"max_version": "1.0.0"
},
"sdk": {
"min_version": "2.0.0",
"max_version": "2.99.99"
},
"dependencies": [],
"capabilities": [
"emoji.get_random",
"emoji.get_count",
"emoji.get_info",
"emoji.get_all",
"emoji.register_emoji",
"emoji.delete_emoji",
"send.text",
"send.forward"
],
"i18n": {
"default_locale": "zh-CN",
"locales_path": "_locales",
"supported_locales": [
"zh-CN"
]
},
"id": "sengokucola.betteremoji"
}

View File

@@ -1,238 +0,0 @@
"""表情包管理插件 — 新 SDK 版本
通过 /emoji 命令管理表情包的添加、列表和删除。
"""
from maibot_sdk import Command, MaiBotPlugin
import base64
import datetime
import hashlib
import re
class EmojiManagePlugin(MaiBotPlugin):
    """Emoji-pack management plugin (new SDK version).

    Manages adding, listing and deleting emoji packs via ``/emoji`` commands.
    """

    async def on_load(self) -> None:
        """Handle plugin load (no setup needed)."""

    async def on_unload(self) -> None:
        """Handle plugin unload (no teardown needed)."""

    # ===== Utility methods =====
    @staticmethod
    def _extract_emoji_base64(segments) -> list[str]:
        """Extract base64 payloads of emoji/image segments from a message.

        ``segments`` may be a dict, a Seg-like object (has a ``.type``
        attribute), or a (possibly nested) list of either — both formats
        are supported for compatibility.

        Returns:
            list[str]: base64 strings in encounter order; empty when none found.
        """
        results: list[str] = []
        if not segments:
            return results
        if isinstance(segments, dict):
            seg_type = segments.get("type", "")
            if seg_type in ("emoji", "image"):
                data = segments.get("data", "")
                if data:
                    results.append(data)
            elif seg_type == "seglist":
                # A seglist nests further segments under "data"; recurse into each child.
                for child in segments.get("data", []):
                    results.extend(EmojiManagePlugin._extract_emoji_base64(child))
            return results
        # Seg-object form (duck-typed on the .type attribute).
        if hasattr(segments, "type"):
            seg_type = getattr(segments, "type", "")
            if seg_type in ("emoji", "image"):
                # NOTE(review): unlike the dict branch, empty data is not filtered here — confirm intended.
                results.append(getattr(segments, "data", ""))
            elif seg_type == "seglist":
                for child in getattr(segments, "data", []):
                    results.extend(EmojiManagePlugin._extract_emoji_base64(child))
            return results
        # Plain list of segments: recurse into every element.
        for seg in segments:
            results.extend(EmojiManagePlugin._extract_emoji_base64(seg))
        return results

    # ===== Command components =====
    @Command("add_emoji", description="添加表情包", pattern=r".*/emoji add.*")
    async def handle_add_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
        """Register every emoji/image found in the message as an emoji pack.

        Returns:
            tuple: (success, summary text, intercept) — success/intercept are
            True when at least one registration succeeded.
        """
        emoji_base64_list = self._extract_emoji_base64(message_segments)
        if not emoji_base64_list:
            await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
            return False, "未在消息中找到表情包或图片", False
        success_count = 0
        fail_count = 0
        results = []
        for i, emoji_b64 in enumerate(emoji_base64_list):
            result = await self.ctx.emoji.register_emoji(emoji_b64)
            if isinstance(result, dict) and result.get("success"):
                success_count += 1
                desc = result.get("description", "未知描述")
                emotions = result.get("emotions", [])
                replaced = result.get("replaced", False)
                msg = f"表情包 {i + 1} 注册成功{'(替换旧表情包)' if replaced else '(新增表情包)'}"
                if desc:
                    msg += f"\n描述: {desc}"
                if emotions:
                    msg += f"\n情感标签: {', '.join(emotions)}"
                results.append(msg)
            else:
                fail_count += 1
                err = result.get("message", "注册失败") if isinstance(result, dict) else "注册失败"
                results.append(f"表情包 {i + 1} 注册失败: {err}")
        total = success_count + fail_count
        summary = f"表情包注册完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total} 个"
        if results:
            summary += "\n" + "\n".join(results)
        await self.ctx.send.text(summary, stream_id)
        return success_count > 0, summary, success_count > 0

    @Command("emoji_list", description="列表表情包", pattern=r"^/emoji list(\s+\d+)?$")
    async def handle_list_emoji(self, stream_id: str = "", raw_message: str = "", **kwargs):
        """List emoji-pack statistics and up to ``max_count`` entries.

        An optional trailing number in the command overrides the default
        display limit of 10 (hard-capped at 50).
        """
        max_count = 10
        match = re.match(r"^/emoji list(?:\s+(\d+))?$", raw_message)
        if match and match.group(1):
            max_count = min(int(match.group(1)), 50)
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        count_result = await self.ctx.emoji.get_count()
        emoji_count = count_result if isinstance(count_result, int) else 0
        info_result = await self.ctx.emoji.get_info()
        max_emoji = info_result.get("max_count", 0) if isinstance(info_result, dict) else 0
        available = info_result.get("available_emojis", 0) if isinstance(info_result, dict) else 0
        lines = [
            f"📊 表情包统计信息 ({now})",
            f"• 总数: {emoji_count} / {max_emoji}",
            f"• 可用: {available}",
        ]
        if emoji_count == 0:
            lines.append("\n❌ 暂无表情包")
            await self.ctx.send.text("\n".join(lines), stream_id)
            return True, "\n".join(lines), True
        all_result = await self.ctx.emoji.get_all()
        all_emojis = all_result if isinstance(all_result, list) else []
        if not all_emojis:
            lines.append("\n❌ 无法获取表情包列表")
            await self.ctx.send.text("\n".join(lines), stream_id)
            return False, "\n".join(lines), True
        display = all_emojis[:max_count]
        lines.append(f"\n📋 显示前 {len(display)} 个表情包:")
        for i, emoji in enumerate(display, 1):
            # Entries may be (hash, description, emotion) sequences or dicts.
            if isinstance(emoji, (list, tuple)) and len(emoji) >= 3:
                _, desc, emotion = emoji[0], emoji[1], emoji[2]
            elif isinstance(emoji, dict):
                desc = emoji.get("description", "")
                emotion = emoji.get("emotion", "")
            else:
                desc, emotion = str(emoji), ""
            short_desc = desc[:50] + "..." if len(desc) > 50 else desc
            lines.append(f"{i}. {short_desc} [{emotion}]")
        if len(all_emojis) > max_count:
            lines.append(f"\n💡 还有 {len(all_emojis) - max_count} 个表情包未显示")
        final = "\n".join(lines)
        await self.ctx.send.text(final, stream_id)
        return True, final, True

    @Command("delete_emoji", description="删除表情包", pattern=r".*/emoji delete.*")
    async def handle_delete_emoji(self, stream_id: str = "", message_segments=None, **kwargs):
        """Delete every emoji pack matching an image attached to the message.

        Each attached image is hashed (md5 of the decoded bytes) and the hash
        is used as the deletion key.
        """
        emoji_base64_list = self._extract_emoji_base64(message_segments)
        if not emoji_base64_list:
            await self.ctx.send.text("未在消息中找到表情包或图片", stream_id)
            return False, "未找到表情包", False
        success_count = 0
        fail_count = 0
        results = []
        for i, emoji_b64 in enumerate(emoji_base64_list):
            # Compute the md5 hash used as the emoji's identity key.
            if isinstance(emoji_b64, str):
                # Drop any non-ASCII noise before base64-decoding.
                clean = emoji_b64.encode("ascii", errors="ignore").decode("ascii")
            else:
                clean = str(emoji_b64)
            image_bytes = base64.b64decode(clean)
            emoji_hash = hashlib.md5(image_bytes).hexdigest()  # noqa: S324
            result = await self.ctx.emoji.delete_emoji(emoji_hash)
            if isinstance(result, dict) and result.get("success"):
                success_count += 1
                desc = result.get("description", "未知描述")
                emotions = result.get("emotions", [])
                before = result.get("count_before", 0)
                after = result.get("count_after", 0)
                msg = f"表情包 {i + 1} 删除成功"
                if desc:
                    msg += f"\n描述: {desc}"
                if emotions:
                    msg += f"\n情感标签: {', '.join(emotions)}"
                msg += f"\n表情包数量: {before} → {after}"
                results.append(msg)
            else:
                fail_count += 1
                err = result.get("message", "删除失败") if isinstance(result, dict) else "删除失败"
                results.append(f"表情包 {i + 1} 删除失败: {err}")
        total = success_count + fail_count
        summary = f"表情包删除完成: 成功 {success_count} 个,失败 {fail_count} 个,共处理 {total} 个"
        if results:
            summary += "\n" + "\n".join(results)
        await self.ctx.send.text(summary, stream_id)
        return success_count > 0, summary, success_count > 0

    @Command("random_emojis", description="发送多张随机表情包", pattern=r"^/random_emojis$")
    async def handle_random_emojis(self, stream_id: str = "", **kwargs):
        """Send up to five random emoji packs as a forwarded message."""
        emojis = await self.ctx.emoji.get_random(5)
        if not emojis:
            return False, "未找到表情包", False
        messages = [
            {"user_id": "0", "nickname": "神秘用户", "segments": [{"type": "image", "content": e.get("base64", "")}]}
            for e in emojis
        ]
        await self.ctx.send.forward(messages, stream_id)
        return True, "已发送随机表情包", True

    async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
        """Handle configuration hot-reload events (intentionally a no-op).

        Args:
            scope: Scope of the configuration change.
            config_data: The latest configuration data.
            version: Configuration version number.
        """
        del scope
        del config_data
        del version
def create_plugin() -> EmojiManagePlugin:
    """Create an emoji-management plugin instance (plugin entry point).

    Returns:
        EmojiManagePlugin: A new emoji-management plugin instance.
    """
    return EmojiManagePlugin()

View File

@@ -3,6 +3,7 @@ You need to focus on the dialogue between {bot_name} (AI) and different users so
[Reference Information]
{identity}
{time_block}
[End of Reference Information]
You need to analyze based on the provided reference information, the current scenario, and the output rules.

View File

@@ -8,4 +8,5 @@ You are chatting in the group now. Please read the previous chat records, grasp
Try to keep it short. It is best to reply to only one topic at a time, so the reply does not become verbose or messy. Please pay attention to the chat content.
{reply_style}
You may refer to the information in [Reply Reference], but depending on the situation, you do not have to follow it completely.
{group_chat_attention_block}
Please do not output any extra content (including unnecessary prefixes or suffixes, colons, brackets, stickers, at, or @). Only output the message content itself.

View File

@@ -3,6 +3,7 @@
【参考情報】
{identity}
{time_block}
【参考情報ここまで】
提供された参考情報、現在の状況、そして出力ルールに基づいて分析してください。

View File

@@ -8,4 +8,5 @@
できるだけ短くしてください。話題は一度に一つだけに返信したほうが、冗長になったり内容が散らかったりしません。チャット内容をしっかり踏まえてください。
{reply_style}
【返信情報参考】の情報は参考にしてかまいませんが、状況に応じて完全に従う必要はありません。
{group_chat_attention_block}
余計な内容不要な前置きや後置き、コロン、括弧、スタンプ、at や @ など)は出力せず、発言内容だけを出力してください。

View File

@@ -24,4 +24,4 @@
"message_indices": [1, 2, 5]
}},
...
]
]

View File

@@ -19,4 +19,4 @@
聊天记录:
{original_text}
请直接返回JSON不要包含其他内容。
请直接返回JSON不要包含其他内容。

View File

@@ -1,27 +1,30 @@
你的任务是分析聊天和聊天中的互动情况。
你的任务是分析聊天和聊天中的互动情况,然后做出下一步动作
你需要关注 {bot_name}AI 与不同用户的对话来为选择正确的动作和行为以及搜集信息提供建议
【参考信息】
{bot_name}的人设:{identity}
{time_block}
【参考信息结束】
你需要根据提供的参考信息,当前场景和输出规则来进行分析
在当前场景中,不同的人正在互动({bot_name}也是一位参与的用户用户也可能与进行聊天互动你的任务不是生成对用户可见的发言而是进行分析来指导AI进行回复
“分析”应该体现你对当前局面的判断、你的建议、你的下一步计划,以及你为什么这样想。
你需要先搜集能够帮助{bot_name}进行下一步行动的信息,然后再给出回复意见
请你对当前场景和输出规则来进行分析,你可以参考参考信息中的内容,但不用过分遵守,仅供参考。
在当前场景中,不同的人正在互动({bot_name}也是一位参与的用户用户也可能与进行聊天互动你的任务不是生成对用户可见的发言而是进行分析来指导AI进行动作
“分析”应该体现你对当前局面的判断、你的建议、你的下一步计划,以及你为什么这样想。默认直接输出你当前的最新分析,不要重复之前的分析内容。最新分析应尽量具体,贴近上下文。
你需要先搜集能够帮助{bot_name}进行下一步行动的信息,然后再给出思考
{group_chat_attention_block}
你可以使用这些工具:
工具说明
- reply():当你判断{bot_name}现在应该正式对用户发出一条可见回复时调用。调用后系统会基于你当前这轮的想法生成一条真正展示给用户的回复。你可以针对某个用户回复,也可以对所有用户回复。
- query_jargon():当你认为某些词的含义不明确,或用户询问某些词的含义,需要进行查询
- query_memory():如果当前可用工具中存在它,当回复明显依赖历史对话、长期偏好、共同经历、人物长期信息或之前约定时使用
- tool_search()当你在deferred tools列表中需要其中某个工具时先调用它来搜索并发现对应工具它只负责让工具在后续轮次变为可用不直接执行业务
- finish()当没有更多操作需要做使用finish结束这次思考
- 其他定义的工具,你可以视情况合适使用
工具使用规则:
1. 你当前处于 Action Loop 阶段,节奏控制由独立的 timing gate 负责;如果系统让你继续,就专注于分析、搜集信息和执行真正需要的工具。
2. 如果存在用户的疑问,或者对某些概念的不确定,你可以使用工具来搜集信息或者查询含义,你可以使用多个工具。
3. 当你判断 {bot_name} 现在应该正式发出可见回复时,调用 reply()
4. 如果需要补充上下文、查看消息、查询黑话、检索记忆或使用其他可用工具,可以按需调用。
1. 你可以使用多个工具。
2. 如果存在工具可以帮助你执行某些动作,完成某些目标,直接使用工具来完成任务
3. 如果看到 `<system-reminder>` 中列出了 deferred tools而你需要其中某个工具先调用 tool_search() 搜索该工具,等它在后续轮次变为可用后再正常调用
长期记忆使用建议:
1. 仅当历史信息会明显影响当前回复时,才考虑调用 `query_memory()`。
@@ -30,11 +33,5 @@
4. 模式上:`search` 查事实或偏好,`time` 查某段时间,`episode` 查某次经历,`aggregate` 查整体情况;拿不准时用 `hybrid`。
5. 如果无命中、被过滤、或证据不足,就不要编造。
你的分析规则:
1. 默认直接输出你当前的最新分析,不要重复之前的分析内容。最新分析应尽量具体,贴近上下文。
2. 你需要先评估是用户之间在互动还是和{bot_name}在互动,不要盲目插话,弄错回复对象
3. 你需要评估哪些话是对{bot_name}的发言,哪些是用户之间的交流或者自言自语,不要频繁插入无关的话题。
{group_chat_attention_block}
现在,请你输出你对{bot_name}发言的分析你必须先输出文本内容的分析然后再进行工具调用输出json形式的function call
现在,请你输出你对{bot_name}发言的分析,你必须先输出文本内容的分析,然后再进行工具调用,:

View File

@@ -1,11 +1,9 @@
你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片
其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
在下面的内容中,标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
{identity}
{time_block}
{identity}
你正在群里聊天,现在请你读读之前的聊天记录,把握当前的话题,然后给出日常且口语化的回复,
尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
现在请你读读之前的聊天记录,把握当前的话题,然后给出日常且口语化的回复,
{reply_style}
你可以参考【回复信息参考】中的信息,但是视情况而定,不用完全遵守。
请注意不要输出多余内容(包括不必要的前后缀冒号括号表情包at或 @等 ),只输出发言内容就好。
{group_chat_attention_block}
请注意不要输出多余内容(包括不必要的前后缀冒号括号表情包at或 @等 ),只输出发言内容就好。

View File

@@ -8,16 +8,16 @@
在当前场景中,不同的人正在互动({bot_name} 也是一位参与的用户),用户也可能正在连续发送消息或彼此互动。
你的任务不是生成对别人可见的发言,也不是直接使用查询类工具,而是判断当前是否应该:
- continue立刻进入下一轮完整思考、搜集信息、回复与其他工具执行
- wait再等待一段时间然后重新判断,可选几秒的等待,也可等待数分钟
- no_reply本轮不继续直接等待新的外部消息
- wait固定再等待一段时间,时间到后再重新判断;
- no_reply本轮不继续直接等待新的消息
节奏控制规则:
1. 如果 {bot_name} 已经回复,但用户暂时没有新的回复,且没有新信息需要搜集,使用 wait 或者 no_reply 进行等待。
2. 如果用户有新发言,但是你评估用户还有后续发言尚未发送,可以适当等待让用户说完。
3. 在特定情况下也可以连续回复,例如想要追问,或者补充自己先前的发言,这时应调用 continue让主流程继续执行。
4. 不要每条消息都回复,不要直接因为别的用户发送了表情包就发言
5. 如果你判断现在需要真正回复、查询信息、查看上下文或做进一步分析,不要在这里完成,直接调用 continue把工作交给主流程
6. 你必须且只能调用一个工具,不要连续调用多个工具,也不要只输出文本不调用工具
3. 你需要先评估是用户之间在互动还是和{bot_name}在互动,不要盲目插话,弄错回复对象
4. 你需要评估哪些话是对{bot_name}的发言,哪些是用户之间的交流或者自言自语,不要频繁插入无关的话题
5. 在特定情况下也可以连续回复,例如想要追问,或者补充自己先前的发言,这时应调用 continue让主流程继续执行
6. 如果你判断现在需要真正回复、查询信息、查看上下文或做进一步分析,不要在这里完成,直接调用 continue把工作交给主流程
{group_chat_attention_block}

View File

@@ -1,26 +0,0 @@
你是一个专门获取长期记忆的助手。你的名字是{bot_name}。现在是{time_now}。
群里正在进行的聊天内容:
{chat_history}
现在,{sender}发送了内容:{target_message},你想要回复ta。
请仔细分析聊天内容,考虑以下几点:
1. 内容中是否包含需要查询历史知识或长期记忆的问题
2. 是否有明确的知识获取指令
如果需要使用长期记忆工具,请直接调用函数 `search_long_term_memory`;如果不需要任何工具,直接输出 `No tool needed`。
工具模式说明:
- `mode="search"`:普通长期记忆检索,适合查具体事实、偏好、历史对话内容
- `mode="time"`:按时间范围检索,必须同时提供 `time_expression`
- `mode="episode"`:按事件/情节检索,适合查“那次经历”“那件事的经过”
- `mode="aggregate"`:综合检索,适合“整体回忆一下”“把相关线索综合找出来”
优先规则:
- 问“某段时间发生了什么”:优先 `time`
- 问“某次事件/某段经历”:优先 `episode`
- 问“整体情况/最近发生过什么”:优先 `aggregate`
- 问单点事实:优先 `search`
`time_expression` 可用表达:
- `今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
- 或绝对时间:`2026/03/18`、`2026/03/18 09:30`

View File

@@ -1,19 +0,0 @@
你的名字是{bot_name}。现在是{time_now}。
你正在参与聊天,你需要根据搜集到的信息总结信息。
如果搜集到的信息对于参与聊天,回答问题有帮助,请加入总结,如果无关,请不要加入到总结。
当前聊天记录:
{chat_history}
已收集的信息:
{collected_info}
分析:
- 基于已收集的信息,总结出对当前聊天有帮助的相关信息
- **如果收集的信息对当前聊天有帮助**在思考中直接给出总结信息格式为return_information(information="你的总结信息")
- **如果信息无关或没有帮助**在思考中给出return_information(information="")
**重要规则:**
- 必须严格使用检索到的信息回答问题,不要编造信息
- 答案必须精简,不要过多解释

View File

@@ -1,34 +0,0 @@
你的名字是{bot_name}。现在是{time_now}。
你正在参与聊天,你需要搜集信息来帮助你进行回复。
重要,这是当前聊天记录:
{chat_history}
聊天记录结束
已收集的信息:
{collected_info}
- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
- 思考完毕后,使用工具
**工具说明:**
- 如果涉及过往事件、历史对话、用户长期偏好或某段时间发生的事件,可以使用长期记忆查询工具
- 如果遇到不熟悉的词语、缩写、黑话或网络用语可以使用query_words工具查询其含义
- 你必须使用tool如果需要查询你必须给出使用什么工具进行查询
- 当你决定结束查询时必须调用return_information工具返回总结信息并结束查询
长期记忆工具 `search_long_term_memory` 支持以下模式:
- `mode="search"`:普通事实/偏好/历史内容检索。适合问“她喜欢什么”“我们之前讨论过什么”。
- `mode="time"`按时间范围检索。适合问“昨天发生了什么”“最近7天有哪些相关记忆”。
- `mode="episode"`:按事件/情节检索。适合问“那次灯塔停电的经过是什么”“关于某次经历还有什么”。
- `mode="aggregate"`:综合检索。适合问“帮我整体回忆一下这个人最近的情况”“把相关线索综合找出来”。
模式选择建议:
- 问单点事实、偏好、人设、具体信息:优先 `search`
- 问某段时间发生了什么:优先 `time`
- 问某次事件、某段经历、某个剧情片段:优先 `episode`
- 问整体回忆、综合找线索、总结最近发生的事:优先 `aggregate`
时间模式要求:
- 使用 `mode="time"` 时,必须填写 `time_expression`
- 可用时间表达包括:`今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
- 也可以使用绝对时间:`2026/03/18`、`2026/03/18 09:30`

View File

@@ -21,6 +21,7 @@ dependencies = [
"maim-message>=0.6.2",
"maibot-dashboard==1.0.0.dev2026040439",
"maibot-plugin-sdk>=2.3.0",
"matplotlib>=3.10.5",
"mcp",
"msgpack>=1.1.2",
"numpy>=2.2.6",

View File

@@ -33,3 +33,34 @@ def test_legacy_learning_list_with_numeric_fourth_column_is_migrated():
"enable_jargon_learning": False,
},
]
def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None:
    """Legacy ``visual.multimodal_replyer=True`` migrates to ``visual.replyer_mode="multimodal"``."""
    payload = {
        "visual": {
            "multimodal_replyer": True,
        }
    }
    result = try_migrate_legacy_bot_config_dict(payload)
    assert result.migrated is True
    assert "visual.multimodal_replyer_moved_to_visual.replyer_mode" in result.reason
    assert result.data["visual"]["replyer_mode"] == "multimodal"
    # The legacy key must be removed so it cannot be re-migrated.
    assert "multimodal_replyer" not in result.data["visual"]
def test_chat_replyer_generator_type_is_migrated_to_replyer_mode() -> None:
    """Legacy ``chat.replyer_generator_type`` migrates to ``visual.replyer_mode="text"``."""
    payload = {
        "chat": {
            "replyer_generator_type": "legacy",
        },
        "visual": {},
    }
    result = try_migrate_legacy_bot_config_dict(payload)
    assert result.migrated is True
    assert "chat.replyer_generator_type_moved_to_visual.replyer_mode" in result.reason
    assert result.data["visual"]["replyer_mode"] == "text"
    # The legacy key must be removed from the chat section.
    assert "replyer_generator_type" not in result.data["chat"]

View File

@@ -1,89 +0,0 @@
"""测试表达方式自动检查任务的数据库读取行为。"""
from contextlib import contextmanager
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
from src.common.database.database_model import Expression
@pytest.fixture(name="expression_auto_check_engine")
def expression_auto_check_engine_fixture() -> Generator:
    """Create an in-memory database engine for expression auto-check tests.

    Yields:
        Generator: A SQLite in-memory engine for the tests to use.
    """
    engine = create_engine(
        "sqlite://",
        # StaticPool + check_same_thread=False lets async test code share the
        # single in-memory connection across threads.
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(engine)
    yield engine
@pytest.mark.asyncio
async def test_select_expressions_uses_read_only_session(
    monkeypatch: pytest.MonkeyPatch,
    expression_auto_check_engine,
) -> None:
    """Selecting expressions must use a read-only session, with ORM fields still readable after it closes."""
    import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module

    # Seed one expression row the task can pick up.
    with Session(expression_auto_check_engine) as session:
        session.add(
            Expression(
                situation="表达情绪高涨或生理反应",
                style="发送💦表情符号",
                content_list='["表达情绪高涨或生理反应"]',
                count=1,
                session_id="session-a",
                checked=False,
                rejected=False,
            )
        )
        session.commit()

    # Records the auto_commit flag passed on each session acquisition.
    auto_commit_calls: list[bool] = []

    @contextmanager
    def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
        """Test session factory mimicking the auto-commit semantics.

        Args:
            auto_commit: Whether to commit automatically when leaving the context.

        Yields:
            Generator[Session, None, None]: A SQLModel session object.
        """
        auto_commit_calls.append(auto_commit)
        session = Session(expression_auto_check_engine)
        try:
            yield session
            if auto_commit:
                session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session)
    # Make random.sample deterministic: return all entries unchanged.
    monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries))
    task = ExpressionAutoCheckTask()
    expressions = await task._select_expressions(1)
    # Exactly one session was opened, and with auto_commit=False (read-only).
    assert auto_commit_calls == [False]
    assert len(expressions) == 1
    # Fields are accessed after the session closed — must not raise DetachedInstanceError.
    assert expressions[0].id is not None
    assert expressions[0].situation == "表达情绪高涨或生理反应"
    assert expressions[0].style == "发送💦表情符号"

View File

@@ -0,0 +1,91 @@
from types import SimpleNamespace
import pytest
import src.chat.replyer.maisaka_expression_selector as selector_module
from src.chat.replyer.maisaka_expression_selector import MaisakaExpressionSelector
from src.common.utils.utils_session import SessionUtils
def _build_target(platform: str, item_id: str, rule_type: str = "group") -> SimpleNamespace:
return SimpleNamespace(platform=platform, item_id=item_id, rule_type=rule_type)
def test_resolve_expression_group_scope_returns_related_sessions(monkeypatch: pytest.MonkeyPatch) -> None:
    """Sessions listed in the same expression group are resolved as related, with no global share."""
    current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
    related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002")
    monkeypatch.setattr(
        selector_module,
        "global_config",
        SimpleNamespace(
            expression=SimpleNamespace(
                expression_groups=[
                    SimpleNamespace(
                        expression_groups=[
                            _build_target("qq", "10001"),
                            _build_target("qq", "10002"),
                        ]
                    )
                ]
            )
        ),
    )
    selector = MaisakaExpressionSelector()
    related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
    assert related_session_ids == {current_session_id, related_session_id}
    assert has_global_share is False
def test_resolve_expression_group_scope_uses_star_as_global_share(monkeypatch: pytest.MonkeyPatch) -> None:
    """A ``("*", "*")`` target marks the scope as globally shared instead of adding sessions."""
    current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
    monkeypatch.setattr(
        selector_module,
        "global_config",
        SimpleNamespace(
            expression=SimpleNamespace(
                expression_groups=[
                    SimpleNamespace(
                        expression_groups=[
                            _build_target("*", "*"),
                        ]
                    )
                ]
            )
        ),
    )
    selector = MaisakaExpressionSelector()
    related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
    # Only the current session remains; the wildcard sets the global flag.
    assert related_session_ids == {current_session_id}
    assert has_global_share is True
def test_resolve_expression_group_scope_does_not_treat_empty_target_as_global(monkeypatch: pytest.MonkeyPatch) -> None:
    """An empty ``("", "")`` target must neither add sessions nor enable global share."""
    current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
    monkeypatch.setattr(
        selector_module,
        "global_config",
        SimpleNamespace(
            expression=SimpleNamespace(
                expression_groups=[
                    SimpleNamespace(
                        expression_groups=[
                            _build_target("", ""),
                        ]
                    )
                ]
            )
        ),
    )
    selector = MaisakaExpressionSelector()
    related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
    assert related_session_ids == {current_session_id}
    assert has_global_share is False

View File

@@ -31,6 +31,23 @@ async def test_calculate_hash_format_updates_runtime_path_metadata(tmp_path: Pat
assert emoji.dir_path == tmp_path.resolve()
@pytest.mark.asyncio
async def test_calculate_hash_format_reuses_existing_target_file(tmp_path: Path) -> None:
    """When the renamed target file already exists, the tmp file is dropped and the target reused."""
    image_bytes = _build_test_image_bytes("JPEG")
    tmp_file_path = tmp_path / "emoji.tmp"
    # Pre-create the target so calculate_hash_format finds it already present.
    target_file_path = tmp_path / "emoji.jpeg"
    tmp_file_path.write_bytes(image_bytes)
    target_file_path.write_bytes(image_bytes)
    emoji = MaiEmoji(full_path=tmp_file_path, image_bytes=image_bytes)
    assert await emoji.calculate_hash_format() is True
    # Metadata now points at the existing target, not the tmp file.
    assert emoji.full_path == target_file_path.resolve()
    assert emoji.file_name == target_file_path.name
    # The duplicate tmp file was removed; the target survives.
    assert not tmp_file_path.exists()
    assert target_file_path.exists()
@pytest.mark.parametrize(
("model_cls", "extra_fields"),
[

View File

@@ -0,0 +1,22 @@
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
from src.llm_models.payload_content.message import RoleType
from src.maisaka.context_messages import _build_message_from_sequence
def test_image_only_message_keeps_placeholder_in_text_fallback() -> None:
    """A message whose payload is only an image keeps both the header text and the image placeholder."""
    message_sequence = MessageSequence(
        [
            TextComponent("[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容]"),
            # Image with no resolvable content/binary — must fall back to placeholder text.
            ImageComponent(binary_hash="hash", content=None, binary_data=None),
        ]
    )
    message = _build_message_from_sequence(
        RoleType.User,
        message_sequence,
        "[时间]19:21:20\n[用户名]William730\n[用户群昵称]\n[msg_id]1385025976\n[发言内容][图片]",
    )
    assert message is not None
    assert "[发言内容]" in message.get_text_content()
    assert "[图片]" in message.get_text_content()

View File

@@ -5,8 +5,7 @@ import pytest
from rich.panel import Panel
from rich.text import Text
from src.chat.replyer import maisaka_generator as legacy_replyer_module
from src.chat.replyer import maisaka_generator_multi as multimodal_replyer_module
from src.chat.replyer import maisaka_generator as replyer_module
from src.common.data_models.reply_generation_data_models import (
GenerationMetrics,
LLMCompletionResult,
@@ -37,8 +36,8 @@ class _FakeLegacyLLMServiceClient:
del args
del kwargs
async def generate_response(self, prompt: str) -> _FakeLLMResult:
assert prompt
async def generate_response_with_messages(self, *, message_factory: Callable[[object], list[Any]]) -> _FakeLLMResult:
assert message_factory(object())
return _FakeLLMResult()
@@ -54,13 +53,21 @@ class _FakeMultimodalLLMServiceClient:
@pytest.mark.asyncio
async def test_legacy_and_multimodal_replyer_monitor_detail_have_same_shape(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(legacy_replyer_module, "LLMServiceClient", _FakeLegacyLLMServiceClient)
monkeypatch.setattr(multimodal_replyer_module, "LLMServiceClient", _FakeMultimodalLLMServiceClient)
monkeypatch.setattr(legacy_replyer_module, "load_prompt", lambda *args, **kwargs: "legacy prompt")
monkeypatch.setattr(multimodal_replyer_module, "load_prompt", lambda *args, **kwargs: "multi prompt")
monkeypatch.setattr(replyer_module, "LLMServiceClient", _FakeLegacyLLMServiceClient)
monkeypatch.setattr(replyer_module, "load_prompt", lambda *args, **kwargs: "legacy prompt")
legacy_generator = legacy_replyer_module.MaisakaReplyGenerator(chat_stream=None, request_type="test_legacy")
multimodal_generator = multimodal_replyer_module.MaisakaReplyGenerator(chat_stream=None, request_type="test_multi")
legacy_generator = replyer_module.MaisakaReplyGenerator(
chat_stream=None,
request_type="test_legacy",
enable_visual_message=False,
)
multimodal_generator = replyer_module.MaisakaReplyGenerator(
chat_stream=None,
request_type="test_multi",
llm_client_cls=_FakeMultimodalLLMServiceClient,
load_prompt_func=lambda *args, **kwargs: "multi prompt",
enable_visual_message=True,
)
legacy_success, legacy_result = await legacy_generator.generate_reply_with_context(
stream_id="session-legacy",
@@ -84,6 +91,40 @@ async def test_legacy_and_multimodal_replyer_monitor_detail_have_same_shape(monk
assert legacy_result.monitor_detail["metrics"]["total_tokens"] == 19
def test_legacy_replyer_builds_message_sequence_like_multimodal() -> None:
legacy_generator = replyer_module.MaisakaReplyGenerator(
chat_stream=None,
request_type="test_legacy",
enable_visual_message=False,
)
legacy_prompt_loader = replyer_module.load_prompt
replyer_module.load_prompt = lambda *args, **kwargs: "legacy prompt"
try:
session_message = replyer_module.SessionBackedMessage(
raw_message=SimpleNamespace(),
visible_text="[Alice]你好\n[Bob]在吗",
timestamp=replyer_module.datetime.now(),
source_kind="user",
)
request_messages = legacy_generator._build_request_messages(
chat_history=[session_message],
reply_message=None,
reply_reason="测试原因",
stream_id="session-legacy",
)
finally:
replyer_module.load_prompt = legacy_prompt_loader
assert len(request_messages) == 4
assert request_messages[0].role.value == "system"
assert request_messages[1].role.value == "user"
assert request_messages[1].get_text_content() == "[Alice]你好"
assert request_messages[2].role.value == "user"
assert request_messages[2].get_text_content() == "[Bob]在吗"
assert request_messages[3].role.value == "user"
@pytest.mark.asyncio
async def test_reply_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
fake_monitor_detail = {
@@ -324,7 +365,7 @@ def test_reasoning_engine_build_tool_monitor_result_keeps_non_reply_tool_without
def test_runtime_build_tool_detail_panels_renders_reply_monitor_detail() -> None:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.session_id = "session-1"
panels = runtime._build_tool_detail_panels(
panels = runtime._build_tool_detail_cards(
[
{
"tool_call_id": "call-reply-1",
@@ -348,7 +389,8 @@ def test_runtime_build_tool_detail_panels_renders_reply_monitor_detail() -> None
},
},
}
]
],
stage_title="工具调用",
)
assert len(panels) == 1
@@ -387,7 +429,7 @@ def test_runtime_build_tool_detail_panels_uses_prompt_access_panel(monkeypatch:
_fake_build_text_access_panel,
)
panels = runtime._build_tool_detail_panels(
panels = runtime._build_tool_detail_cards(
[
{
"tool_call_id": "call-reply-2",
@@ -401,7 +443,8 @@ def test_runtime_build_tool_detail_panels_uses_prompt_access_panel(monkeypatch:
"output_text": "reply output",
},
}
]
],
stage_title="工具调用",
)
assert len(panels) == 1
@@ -425,7 +468,7 @@ def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monke
_fake_build_text_access_panel,
)
panels = runtime._build_tool_detail_panels(
panels = runtime._build_tool_detail_cards(
[
{
"tool_call_id": "call-emoji-1",
@@ -439,7 +482,8 @@ def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monke
"output_text": '{"emoji_index": 1}',
},
}
]
],
stage_title="工具调用",
)
assert len(panels) == 1
@@ -448,6 +492,63 @@ def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monke
assert captured["kwargs"]["request_kind"] == "emotion"
def test_runtime_build_tool_detail_cards_uses_structured_prompt_messages_with_images(
monkeypatch: pytest.MonkeyPatch,
) -> None:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.session_id = "session-image"
captured: dict[str, Any] = {}
def _fake_build_prompt_access_panel(messages: list[Any], **kwargs: Any) -> str:
captured["messages"] = messages
captured["kwargs"] = kwargs
return "IMAGE_PROMPT_LINK"
def _fake_build_text_access_panel(content: str, **kwargs: Any) -> str:
captured["text_content"] = content
captured["text_kwargs"] = kwargs
return "TEXT_PROMPT_LINK"
monkeypatch.setattr(
"src.maisaka.runtime.PromptCLIVisualizer.build_prompt_access_panel",
_fake_build_prompt_access_panel,
)
monkeypatch.setattr(
"src.maisaka.runtime.PromptCLIVisualizer.build_text_access_panel",
_fake_build_text_access_panel,
)
panels = runtime._build_tool_detail_cards(
[
{
"tool_call_id": "call-reply-image-1",
"tool_name": "reply",
"tool_args": {"msg_id": "m3"},
"success": True,
"duration_ms": 22.0,
"summary": "- reply [成功]: 已回复",
"detail": {
"prompt_text": "reply prompt image",
"request_messages": [
{
"role": "user",
"content": ["前缀文本", ["png", "ZmFrZQ=="]],
}
],
"output_text": "reply output",
},
}
],
stage_title="工具调用",
)
assert len(panels) == 1
assert "messages" in captured
assert "text_content" not in captured
assert captured["kwargs"]["chat_id"] == "session-image"
assert captured["kwargs"]["request_kind"] == "replyer"
def test_runtime_render_context_usage_panel_merges_timing_and_planner(monkeypatch: pytest.MonkeyPatch) -> None:
runtime = object.__new__(MaisakaHeartFlowChatting)
runtime.session_id = "session-merged"

View File

@@ -0,0 +1,23 @@
from src.maisaka.chat_loop_service import MaisakaChatLoopService
def test_build_tool_names_log_text_supports_openai_function_schema() -> None:
    """Tool-name log text handles both OpenAI function-schema and flat tool definitions."""
    tool_definitions = [
        # OpenAI schema: name nested under "function".
        {
            "type": "function",
            "function": {
                "name": "mute_user",
                "description": "禁言指定用户",
                "parameters": {
                    "type": "object",
                    "properties": {},
                },
            },
        },
        # Flat schema: name at the top level.
        {
            "name": "reply",
            "description": "发送回复",
        },
    ]
    assert MaisakaChatLoopService._build_tool_names_log_text(tool_definitions) == "mute_user、reply"

View File

@@ -0,0 +1,339 @@
"""MutePlugin SDK 回归测试。"""
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
from maibot_sdk.context import PluginContext
from maibot_sdk.plugin import MaiBotPlugin
from plugins.MutePlugin.plugin import create_plugin
from src.core.tooling import ToolExecutionContext, ToolInvocation
from src.plugin_runtime.component_query import ComponentQueryService
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
def _build_plugin() -> MaiBotPlugin:
    """Build a MutePlugin instance with its default config injected."""
    plugin = create_plugin()
    plugin.set_plugin_config(plugin.get_default_config())
    return plugin
def test_mute_plugin_manifest_is_valid_v2() -> None:
    """MutePlugin's manifest should satisfy the current runtime requirements."""
    validator = ManifestValidator(host_version="1.0.0", sdk_version="2.3.0")
    manifest = validator.load_from_plugin_path(Path("plugins/MutePlugin"))
    assert manifest is not None
    assert manifest.id == "sengokucola.mute-plugin"
    assert manifest.manifest_version == 2
def test_create_plugin_returns_sdk_plugin() -> None:
    """The plugin entry point should return an SDK plugin instance."""
    plugin = create_plugin()
    assert isinstance(plugin, MaiBotPlugin)
@pytest.mark.asyncio
async def test_mute_command_calls_napcat_group_ban_api() -> None:
    """The manual mute command should execute through the new NapCat Adapter API."""
    plugin = _build_plugin()
    # Explicitly enable both mute components on top of the default config.
    plugin.set_plugin_config(
        {
            **plugin.get_default_config(),
            "components": {
                "enable_smart_mute": True,
                "enable_mute_command": True,
            },
        }
    )
    # Records every capability payload so the call sequence can be asserted on.
    capability_calls: List[Dict[str, Any]] = []
    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Stub RPC channel: dispatch on the requested capability name.
        # NOTE: the specific member-info branch must stay BEFORE the generic
        # "api.call" branch, otherwise it would be shadowed.
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call":
            return {"success": True, "result": {"status": "ok", "retcode": 0}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")
    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message, intercept = await plugin.handle_mute_command(
        stream_id="group-10001",
        group_id="10001",
        user_id="42",
        matched_groups={
            "target": "张三",
            "duration": "120",
            "reason": "刷屏",
        },
    )
    assert success is True
    assert message == "成功禁言 张三"
    assert intercept is True
    # Locate the actual group-ban API call among all recorded capability calls.
    api_call = next(
        call
        for call in capability_calls
        if call["capability"] == "api.call"
        and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
    )
    assert api_call["args"]["version"] == "1"
    # The resolved numeric user id ("123456"), not the display name, must be banned.
    assert api_call["args"]["args"] == {
        "group_id": "10001",
        "user_id": "123456",
        "duration": 120,
    }
@pytest.mark.asyncio
async def test_mute_tool_requires_target_person_name() -> None:
    """The mute tool should fail fast with a user-facing hint when no target is given."""
    plugin = _build_plugin()
    recorded_payloads: List[Dict[str, Any]] = []
    async def recording_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Accept everything; we only need to capture the outgoing payloads.
        assert method == "cap.call"
        assert payload is not None
        recorded_payloads.append(payload)
        return {"success": True}
    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=recording_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="",
        duration="60",
        reason="测试",
    )
    assert success is False
    assert message == "禁言目标不能为空"
    # The last outgoing call should be the visible hint sent back to the chat.
    last_payload = recorded_payloads[-1]
    assert last_payload["capability"] == "send.text"
    assert last_payload["args"]["text"] == "没有指定禁言对象哦"
@pytest.mark.asyncio
async def test_mute_tool_can_unwrap_nested_person_user_id_response() -> None:
    """The mute tool should unwrap multi-layer nested capability responses."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []
    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Person-lookup responses are deliberately double-wrapped
        # ({"success": ..., "result": {"success": ..., ...}}) to exercise unwrapping.
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "result": {"success": True, "person_id": "person-1"}}
        if capability == "person.get_value":
            return {"success": True, "result": {"success": True, "value": "123456"}}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call":
            return {"success": True, "result": {"status": "ok"}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")
    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )
    assert success is True
    assert message == "成功禁言 张三"
    # The ban call must carry the user id extracted from the nested response.
    api_call = next(
        call
        for call in capability_calls
        if call["capability"] == "api.call"
        and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
    )
    assert api_call["args"]["args"]["user_id"] == "123456"
@pytest.mark.asyncio
async def test_mute_tool_rejects_owner_before_group_ban_call() -> None:
    """The mute tool should bail out with a clear message when the target is the group owner."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []
    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Member-info lookup reports the target as "owner", which should
        # short-circuit the flow before any ban API is attempted.
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "owner"}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")
    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )
    assert success is False
    assert message == "张三 是群主,不能被禁言"
    # No set_group_ban call may have been issued for an owner target.
    assert not any(
        call["capability"] == "api.call" and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
        for call in capability_calls
    )
@pytest.mark.asyncio
async def test_mute_tool_maps_cannot_ban_owner_error_message() -> None:
    """A NapCat 'cannot ban owner' failure should be translated into a clear Chinese hint."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []
    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Member-info says "member", so the ban is attempted — but NapCat
        # itself rejects it with a raw "cannot ban owner" error string.
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
            return {"success": False, "error": "NapCat 动作返回失败: action=set_group_ban message=cannot ban owner"}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")
    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )
    assert success is False
    assert message == "张三 是群主,不能被禁言"
@pytest.mark.asyncio
async def test_mute_tool_accepts_nested_ok_api_result() -> None:
    """A nested success/result payload with status=ok should be treated as success."""
    plugin = _build_plugin()
    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # The ban API answers with the full nested OneBot-style envelope
        # (status/retcode/data/message/wording) to exercise result parsing.
        assert method == "cap.call"
        assert payload is not None
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
            return {
                "success": True,
                "result": {
                    "status": "ok",
                    "retcode": 0,
                    "data": None,
                    "message": "",
                    "wording": "",
                },
            }
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")
    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )
    assert success is True
    assert message == "成功禁言 张三"
def test_tool_invocation_payload_injects_group_and_user_context() -> None:
    """Plugin tool execution should auto-fill group-chat context fields into the payload."""
    component_entry = SimpleNamespace(invoke_method="plugin.invoke_tool")
    # Minimal stand-in for an anchor message carrying group and sender info.
    fake_anchor_message = SimpleNamespace(
        message_info=SimpleNamespace(
            group_info=SimpleNamespace(group_id="10001"),
            user_info=SimpleNamespace(user_id="20002"),
        )
    )
    tool_invocation = ToolInvocation(tool_name="mute", arguments={"target": "张三"}, stream_id="session-1")
    execution_context = ToolExecutionContext(
        session_id="session-1",
        stream_id="session-1",
        reasoning="test",
        metadata={"anchor_message": fake_anchor_message},
    )
    payload = ComponentQueryService._build_tool_invocation_payload(
        component_entry, tool_invocation, execution_context
    )
    assert payload["target"] == "张三"
    assert payload["stream_id"] == "session-1"
    assert payload["chat_id"] == "session-1"
    assert payload["group_id"] == "10001"
    assert payload["user_id"] == "20002"

View File

@@ -0,0 +1,27 @@
from src.llm_models.model_client.openai_client import _sanitize_messages_for_toolless_request
from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart
from src.llm_models.payload_content.tool_option import ToolCall
def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_without_parts() -> None:
    """An assistant message that only carries tool calls must be dropped for tool-less requests."""
    # Assistant turn with a tool call but no content parts at all.
    tool_only_assistant_message = Message(
        role=RoleType.Assistant,
        tool_calls=[
            ToolCall(
                call_id="call_1",
                func_name="mute_user",
                args={"target": "alice"},
            )
        ],
    )
    plain_user_message = Message(
        role=RoleType.User,
        parts=[TextMessagePart(text="继续")],
    )
    sanitized_messages = _sanitize_messages_for_toolless_request(
        [tool_only_assistant_message, plain_user_message]
    )
    # Only the plain user message survives sanitization.
    assert len(sanitized_messages) == 1
    assert sanitized_messages[0].role == RoleType.User

View File

@@ -0,0 +1,18 @@
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.plugin_runtime.hook_payloads import deserialize_prompt_messages, serialize_prompt_messages
def test_prompt_messages_roundtrip_preserves_image_parts() -> None:
    """Serializing then deserializing prompt messages must keep the image part intact."""
    # Build one user message carrying both a text part and an image part.
    builder = MessageBuilder()
    builder.set_role(RoleType.User)
    builder.add_text_content("你好")
    builder.add_image_content("png", "ZmFrZQ==")
    original_messages = [builder.build()]

    restored_messages = deserialize_prompt_messages(serialize_prompt_messages(original_messages))

    assert len(restored_messages) == 1
    restored_message = restored_messages[0]
    assert restored_message.role == RoleType.User
    assert restored_message.get_text_content() == "你好"
    # Part 0 is the text; part 1 must still be the png image with its base64 payload.
    assert len(restored_message.parts) == 2
    assert restored_message.parts[1].image_format == "png"
    assert restored_message.parts[1].image_base64 == "ZmFrZQ=="

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,9 @@
from dataclasses import dataclass, field
from datetime import datetime
import json
from typing import Any, Awaitable, Callable, List, Optional
import json
from json_repair import repair_json
from sqlmodel import select
@@ -30,7 +31,7 @@ class MaisakaExpressionSelectionResult:
class MaisakaExpressionSelector:
"""负责在 replyer 侧完成表达方式筛选与子代理选择。"""
"""负责在 replyer 侧完成表达方式筛选与子代理二次选择。"""
def _can_use_expressions(self, session_id: str) -> bool:
try:
@@ -40,18 +41,34 @@ class MaisakaExpressionSelector:
logger.error(f"检查表达方式使用开关失败: {exc}")
return False
def _get_related_session_ids(self, session_id: str) -> List[str]:
def _can_use_advanced_chosen(self, session_id: str) -> bool:
try:
return ExpressionConfigUtils.get_expression_advanced_chosen_for_chat(session_id)
except Exception as exc:
logger.error(f"检查表达方式二次选择开关失败: {exc}")
return False
@staticmethod
def _is_global_expression_group_marker(platform: str, item_id: str) -> bool:
return platform == "*" and item_id == "*"
def _resolve_expression_group_scope(self, session_id: str) -> tuple[set[str], bool]:
related_session_ids = {session_id}
has_global_share = False
expression_groups = global_config.expression.expression_groups
for expression_group in expression_groups:
target_items = expression_group.expression_groups
group_session_ids: set[str] = set()
contains_current_session = False
contains_global_share_marker = False
for target_item in target_items:
platform = target_item.platform.strip()
item_id = target_item.item_id.strip()
if self._is_global_expression_group_marker(platform, item_id):
contains_global_share_marker = True
continue
if not platform or not item_id:
continue
@@ -65,19 +82,24 @@ class MaisakaExpressionSelector:
if target_session_id == session_id:
contains_current_session = True
if contains_global_share_marker:
has_global_share = True
if contains_current_session:
related_session_ids.update(group_session_ids)
return list(related_session_ids)
return related_session_ids, has_global_share
def _load_expression_candidates(self, session_id: str) -> List[dict[str, Any]]:
related_session_ids = self._get_related_session_ids(session_id)
related_session_ids, has_global_share = self._resolve_expression_group_scope(session_id)
with get_db_session(auto_commit=False) as session:
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
scoped_query = base_query.where(
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
)
if has_global_share:
scoped_query = base_query
else:
scoped_query = base_query.where(
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
)
if global_config.expression.expression_checked_only:
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
expressions = session.exec(scoped_query).all()
@@ -87,7 +109,7 @@ class MaisakaExpressionSelector:
"id": expression.id,
"situation": expression.situation,
"style": expression.style,
"count": expression.count if getattr(expression, "count", None) is not None else 1,
"count": expression.count if expression.count is not None else 1,
}
for expression in expressions
if expression.id is not None and expression.situation and expression.style
@@ -171,7 +193,7 @@ class MaisakaExpressionSelector:
"你只负责根据最近聊天上下文,为这一次可见回复挑选最合适的表达方式。\n"
"请只从下面候选中选择 0 到 3 条最适合当前语境的表达方式。\n"
"优先考虑自然、贴合上下文、不生硬、不模板化。\n"
"如果没有明显合适的,就返回空列表\n"
"如果没有明显合适的,就返回空数组\n"
'严格只输出 JSON对象格式为 {"selected_ids":[123,456]}。\n\n'
f"最近上下文:\n{history_block}\n\n"
f"目标消息:{target_text or ''}\n"
@@ -208,6 +230,32 @@ class MaisakaExpressionSelector:
break
return selected_ids
def _build_direct_selection_result(
self,
*,
session_id: str,
candidates: List[dict[str, Any]],
) -> MaisakaExpressionSelectionResult:
selected_ids = [
candidate["id"]
for candidate in candidates
if isinstance(candidate.get("id"), int)
]
selected_expressions = [
candidate
for candidate in candidates
if candidate.get("id") in selected_ids
]
self._update_last_active_time(selected_ids)
logger.info(
f"表达方式直接注入session_id={session_id} 已选数={len(selected_ids)} "
f"selected_ids={selected_ids!r} 已选预览={self._format_candidate_preview(selected_expressions)}"
)
return MaisakaExpressionSelectionResult(
expression_habits=self._build_expression_habits_block(selected_expressions),
selected_expression_ids=selected_ids,
)
def _update_last_active_time(self, selected_ids: List[int]) -> None:
if not selected_ids:
return
@@ -233,15 +281,22 @@ class MaisakaExpressionSelector:
if not self._can_use_expressions(session_id):
logger.info(f"表达方式选择已跳过当前会话未启用表达方式session_id={session_id}")
return MaisakaExpressionSelectionResult()
if sub_agent_runner is None:
logger.info(f"表达方式选择已跳过:缺少 sub_agent_runnersession_id={session_id}")
return MaisakaExpressionSelectionResult()
candidates = self._load_expression_candidates(session_id)
if not candidates:
logger.info(f"表达方式选择已跳过本地候选不足session_id={session_id}")
return MaisakaExpressionSelectionResult()
if not self._can_use_advanced_chosen(session_id):
return self._build_direct_selection_result(
session_id=session_id,
candidates=candidates,
)
if sub_agent_runner is None:
logger.info(f"表达方式选择已跳过:缺少 sub_agent_runnersession_id={session_id}")
return MaisakaExpressionSelectionResult()
logger.info(
f"表达方式选择开始session_id={session_id} 候选数={len(candidates)} "
f"候选预览={self._format_candidate_preview(candidates)}"
@@ -259,10 +314,9 @@ class MaisakaExpressionSelector:
logger.exception("表达方式选择子代理执行失败")
return MaisakaExpressionSelectionResult()
logger.info(f"表达方式子代理原始结果session_id={session_id} response={raw_response!r}")
selected_ids = self._parse_selected_ids(raw_response, candidates)
if not selected_ids:
logger.info(f"表达方式选择完成但未命中session_id={session_id}")
logger.info(f"表达方式选择完成但未命中session_id={session_id}")
return MaisakaExpressionSelectionResult()
selected_expressions = [candidate for candidate in candidates if candidate.get("id") in selected_ids]

View File

@@ -1,440 +1,29 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
import random
import time
from rich.panel import Panel
from typing import Any, Callable, Optional
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.cli.console import console
from src.common.data_models.reply_generation_data_models import (
GenerationMetrics,
LLMCompletionResult,
ReplyGenerationResult,
build_reply_monitor_detail,
)
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import global_config
from src.core.types import ActionInfo
from src.services.llm_service import LLMServiceClient
from src.maisaka.context_messages import (
AssistantMessage,
LLMContextMessage,
ReferenceMessage,
SessionBackedMessage,
ToolResultMessage,
)
from src.maisaka.message_adapter import parse_speaker_content
from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer
from .maisaka_expression_selector import maisaka_expression_selector
logger = get_logger("replyer")
from .maisaka_generator_base import BaseMaisakaReplyGenerator
@dataclass
class MaisakaReplyContext:
"""Maisaka replyer 使用的回复上下文。"""
expression_habits: str = ""
selected_expression_ids: List[int] = field(default_factory=list)
class MaisakaReplyGenerator:
"""生成 Maisaka 的最终可见回复。"""
class MaisakaReplyGenerator(BaseMaisakaReplyGenerator):
"""Maisaka replyer。"""
def __init__(
self,
chat_stream: Optional[BotChatSession] = None,
request_type: str = "maisaka_replyer",
llm_client_cls: Optional[Any] = None,
load_prompt_func: Optional[Callable[..., str]] = None,
enable_visual_message: Optional[bool] = None,
) -> None:
self.chat_stream = chat_stream
self.request_type = request_type
self.express_model = LLMServiceClient(
task_name="replyer",
super().__init__(
chat_stream=chat_stream,
request_type=request_type,
llm_client_cls=llm_client_cls or LLMServiceClient,
load_prompt_func=load_prompt_func or load_prompt,
enable_visual_message=enable_visual_message,
replyer_mode=global_config.visual.replyer_mode,
)
self._personality_prompt = self._build_personality_prompt()
def _build_personality_prompt(self) -> str:
"""构建 replyer 使用的人设提示。"""
try:
bot_name = global_config.bot.nickname
alias_names = global_config.bot.alias_names
bot_aliases = f",也有人叫你{','.join(alias_names)}" if alias_names else ""
prompt_personality = global_config.personality.personality
if (
hasattr(global_config.personality, "states")
and global_config.personality.states
and hasattr(global_config.personality, "state_probability")
and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
return f"你的名字是{bot_name}{bot_aliases},你{prompt_personality};"
except Exception as exc:
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
return "你的名字是麦麦,你是一个活泼可爱的 AI 助手。"
@staticmethod
def _normalize_content(content: str, limit: int = 500) -> str:
normalized = " ".join((content or "").split())
if len(normalized) > limit:
return normalized[:limit] + "..."
return normalized
@staticmethod
def _format_message_time(message: LLMContextMessage) -> str:
return message.timestamp.strftime("%H:%M:%S")
@staticmethod
def _extract_visible_assistant_reply(message: AssistantMessage) -> str:
del message
return ""
def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
bot_nickname = global_config.bot.nickname.strip() or "Bot"
if speaker_name == bot_nickname:
return self._normalize_content(body.strip())
return ""
@staticmethod
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
"""按说话人拆分用户消息。"""
segments: List[tuple[Optional[str], str]] = []
current_speaker: Optional[str] = None
current_lines: List[str] = []
for raw_line in raw_content.splitlines():
speaker_name, content_body = parse_speaker_content(raw_line)
if speaker_name is not None:
if current_lines:
segments.append((current_speaker, "\n".join(current_lines)))
current_speaker = speaker_name
current_lines = [content_body]
continue
current_lines.append(raw_line)
if current_lines:
segments.append((current_speaker, "\n".join(current_lines)))
return segments
def _format_chat_history(self, messages: List[LLMContextMessage]) -> str:
"""格式化 replyer 使用的可见聊天记录。"""
bot_nickname = global_config.bot.nickname.strip() or "Bot"
parts: List[str] = []
for message in messages:
timestamp = self._format_message_time(message)
if isinstance(message, (ReferenceMessage, ToolResultMessage)):
continue
if isinstance(message, SessionBackedMessage):
guided_reply = self._extract_guided_bot_reply(message)
if guided_reply:
parts.append(f"{timestamp} {bot_nickname}(you): {guided_reply}")
continue
raw_content = message.processed_plain_text
for speaker_name, content_body in self._split_user_message_segments(raw_content):
content = self._normalize_content(content_body)
if not content:
continue
visible_speaker = speaker_name or global_config.maisaka.cli_user_name.strip() or "User"
parts.append(f"{timestamp} {visible_speaker}: {content}")
continue
if isinstance(message, AssistantMessage):
visible_reply = self._extract_visible_assistant_reply(message)
if visible_reply:
parts.append(f"{timestamp} {bot_nickname}(you): {visible_reply}")
return "\n".join(parts)
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
"""构建当前需要回复的目标消息摘要。"""
if reply_message is None:
return ""
user_info = reply_message.message_info.user_info
sender_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
target_message_id = reply_message.message_id.strip() if reply_message.message_id else "未知"
target_content = self._normalize_content((reply_message.processed_plain_text or "").strip(), limit=300)
if not target_content:
target_content = "[无可见文本内容]"
return (
"【本次回复目标】\n"
f"- 目标消息ID{target_message_id}\n"
f"- 发送者:{sender_name}\n"
f"- 消息内容:{target_content}\n"
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要误把其他历史消息当成当前回复对象。"
)
def _build_prompt(
self,
chat_history: List[LLMContextMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
expression_habits: str = "",
) -> str:
"""构建 Maisaka replyer 提示词。"""
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
formatted_history = self._format_chat_history(chat_history)
target_message_block = self._build_target_message_block(reply_message)
try:
system_prompt = load_prompt(
"maisaka_replyer",
bot_name=global_config.bot.nickname,
time_block=f"当前时间:{current_time}",
identity=self._personality_prompt,
reply_style=global_config.personality.reply_style,
)
except Exception:
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
extra_sections: List[str] = []
if expression_habits.strip():
extra_sections.append(expression_habits.strip())
user_sections = [
f"当前时间:{current_time}",
f"【聊天记录】\n{formatted_history}",
]
if target_message_block:
user_sections.append(target_message_block)
if extra_sections:
user_sections.append("\n\n".join(extra_sections))
user_sections.append(f"【回复信息参考】\n{reply_reason}")
user_sections.append("现在,你说:")
user_prompt = "\n\n".join(user_sections)
return f"System: {system_prompt}\n\nUser: {user_prompt}"
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
"""解析当前回复使用的会话 ID。"""
if stream_id:
return stream_id
if self.chat_stream is not None:
return self.chat_stream.session_id
return ""
async def _build_reply_context(
self,
chat_history: List[LLMContextMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
stream_id: Optional[str],
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]],
) -> MaisakaReplyContext:
"""构建回复上下文:表达习惯和已选表达 ID。"""
session_id = self._resolve_session_id(stream_id)
if not session_id:
logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识")
return MaisakaReplyContext()
if sub_agent_runner is None:
logger.info("表达方式选择跳过:缺少子代理执行器")
return MaisakaReplyContext()
selection_result = await maisaka_expression_selector.select_for_reply(
session_id=session_id,
chat_history=chat_history,
reply_message=reply_message,
reply_reason=reply_reason,
sub_agent_runner=sub_agent_runner,
)
return MaisakaReplyContext(
expression_habits=selection_result.expression_habits,
selected_expression_ids=selection_result.selected_expression_ids,
)
async def generate_reply_with_context(
self,
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[object]] = None,
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[SessionMessage] = None,
reply_time_point: Optional[float] = None,
think_level: int = 1,
unknown_words: Optional[List[str]] = None,
log_reply: bool = True,
chat_history: Optional[List[LLMContextMessage]] = None,
expression_habits: str = "",
selected_expression_ids: Optional[List[int]] = None,
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]] = None,
) -> Tuple[bool, ReplyGenerationResult]:
"""结合上下文生成 Maisaka 的最终可见回复。"""
def finalize(success_value: bool) -> Tuple[bool, ReplyGenerationResult]:
result.monitor_detail = build_reply_monitor_detail(result)
return success_value, result
del available_actions
del chosen_actions
del extra_info
del from_plugin
del log_reply
del reply_time_point
del think_level
del unknown_words
result = ReplyGenerationResult()
overall_started_at = time.perf_counter()
if chat_history is None:
result.error_message = "聊天历史为空"
return finalize(False)
logger.info(
f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} "
f"历史消息数={len(chat_history)} 目标消息编号={reply_message.message_id if reply_message else None}"
)
filtered_history = [
message
for message in chat_history
if not isinstance(message, (ReferenceMessage, ToolResultMessage))
]
logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}")
if self.express_model is None:
logger.error("Maisaka 回复器的回复模型未初始化")
result.error_message = "回复模型尚未初始化"
return finalize(False)
try:
reply_context = await self._build_reply_context(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
stream_id=stream_id,
sub_agent_runner=sub_agent_runner,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}")
result.error_message = f"构建回复上下文失败: {exc}"
result.metrics = GenerationMetrics(
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
)
return finalize(False)
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
result.selected_expression_ids = (
list(selected_expression_ids)
if selected_expression_ids is not None
else list(reply_context.selected_expression_ids)
)
logger.info(
f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} "
f"已选表达编号={result.selected_expression_ids!r}"
)
prompt_started_at = time.perf_counter()
try:
prompt = self._build_prompt(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
expression_habits=merged_expression_habits,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}")
result.error_message = f"构建提示词失败: {exc}"
result.metrics = GenerationMetrics(
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
)
return finalize(False)
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
result.completion.request_prompt = prompt
show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))
preview_chat_id = self._resolve_session_id(stream_id) or "unknown"
if show_replyer_prompt:
console.print(
Panel(
PromptCLIVisualizer.build_text_access_panel(
prompt,
category="replyer",
chat_id=preview_chat_id,
request_kind="replyer",
subtitle=f"流ID: {preview_chat_id}",
),
title="Maisaka 回复器 Prompt",
border_style="bright_yellow",
padding=(0, 1),
)
)
llm_started_at = time.perf_counter()
try:
generation_result = await self.express_model.generate_response(prompt)
except Exception as exc:
logger.exception("Maisaka 回复器调用失败")
result.error_message = str(exc)
result.metrics = GenerationMetrics(
prompt_ms=prompt_ms,
llm_ms=round((time.perf_counter() - llm_started_at) * 1000, 2),
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
)
return finalize(False)
llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
response_text = (generation_result.response or "").strip()
result.success = bool(response_text)
result.completion = LLMCompletionResult(
request_prompt=prompt,
response_text=response_text,
reasoning_text=generation_result.reasoning or "",
model_name=generation_result.model_name or "",
tool_calls=generation_result.tool_calls or [],
prompt_tokens=generation_result.prompt_tokens,
completion_tokens=generation_result.completion_tokens,
total_tokens=generation_result.total_tokens,
)
result.metrics = GenerationMetrics(
prompt_ms=prompt_ms,
llm_ms=llm_ms,
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
stage_logs=[
f"prompt: {prompt_ms} ms",
f"llm: {llm_ms} ms",
],
)
if show_replyer_reasoning and result.completion.reasoning_text:
logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}")
if not result.success:
result.error_message = "回复器返回了空内容"
logger.warning("Maisaka 回复器返回了空内容")
return finalize(False)
logger.info(
f"Maisaka 回复器生成成功: 回复文本={response_text!r} "
f"总耗时毫秒={result.metrics.overall_ms} "
f"已选表达编号={result.selected_expression_ids!r}"
)
result.text_fragments = [response_text]
return finalize(True)

View File

@@ -1,8 +1,9 @@
import random
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
import random
from rich.console import Group, RenderableType
from rich.panel import Panel
@@ -10,6 +11,7 @@ from rich.text import Text
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.chat.utils.utils import get_chat_type_and_target_info
from src.cli.console import console
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.data_models.reply_generation_data_models import (
@@ -19,18 +21,11 @@ from src.common.data_models.reply_generation_data_models import (
build_reply_monitor_detail,
)
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.config.model_configs import ModelInfo
from src.core.types import ActionInfo
from src.llm_models.payload_content.message import (
ImageMessagePart,
Message,
MessageBuilder,
RoleType,
TextMessagePart,
)
from src.services.llm_service import LLMServiceClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.maisaka.context_messages import (
AssistantMessage,
LLMContextMessage,
@@ -38,8 +33,9 @@ from src.maisaka.context_messages import (
SessionBackedMessage,
ToolResultMessage,
)
from src.maisaka.display.prompt_cli_renderer import PromptCLIVisualizer
from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content
from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer
from src.plugin_runtime.hook_payloads import serialize_prompt_messages
from .maisaka_expression_selector import maisaka_expression_selector
@@ -54,17 +50,26 @@ class MaisakaReplyContext:
selected_expression_ids: List[int] = field(default_factory=list)
class MaisakaReplyGenerator:
"""生成 Maisaka 的最终可见回复(多模态管线)"""
class BaseMaisakaReplyGenerator:
"""Maisaka replyer 的共享实现"""
def __init__(
self,
*,
chat_stream: Optional[BotChatSession] = None,
request_type: str = "maisaka_replyer",
llm_client_cls: Any,
load_prompt_func: Callable[..., str],
enable_visual_message: Optional[bool],
replyer_mode: Literal["text", "multimodal", "auto"],
) -> None:
self.chat_stream = chat_stream
self.request_type = request_type
self.express_model = LLMServiceClient(
self._llm_client_cls = llm_client_cls
self._load_prompt = load_prompt_func
self._enable_visual_message = enable_visual_message
self._replyer_mode = replyer_mode
self.express_model = llm_client_cls(
task_name="replyer",
request_type=request_type,
)
@@ -111,28 +116,6 @@ class MaisakaReplyGenerator:
return self._normalize_content(body.strip())
return ""
@staticmethod
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
segments: List[tuple[Optional[str], str]] = []
current_speaker: Optional[str] = None
current_lines: List[str] = []
for raw_line in raw_content.splitlines():
speaker_name, content_body = parse_speaker_content(raw_line)
if speaker_name is not None:
if current_lines:
segments.append((current_speaker, "\n".join(current_lines)))
current_speaker = speaker_name
current_lines = [content_body]
continue
current_lines.append(raw_line)
if current_lines:
segments.append((current_speaker, "\n".join(current_lines)))
return segments
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
if reply_message is None:
return ""
@@ -152,19 +135,93 @@ class MaisakaReplyGenerator:
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要把其他历史消息当成当前回复对象。"
)
@staticmethod
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
"""根据聊天流 ID 获取匹配的额外 prompt。"""
if not global_config.chat.chat_prompts:
return ""
for chat_prompt_item in global_config.chat.chat_prompts:
if hasattr(chat_prompt_item, "platform"):
platform = str(chat_prompt_item.platform or "").strip()
item_id = str(chat_prompt_item.item_id or "").strip()
rule_type = str(chat_prompt_item.rule_type or "").strip()
prompt_content = str(chat_prompt_item.prompt or "").strip()
elif isinstance(chat_prompt_item, str):
parts = chat_prompt_item.split(":", 3)
if len(parts) != 4:
continue
platform, item_id, rule_type, prompt_content = parts
platform = platform.strip()
item_id = item_id.strip()
rule_type = rule_type.strip()
prompt_content = prompt_content.strip()
else:
continue
if not platform or not item_id or not prompt_content:
continue
if rule_type == "group":
config_is_group = True
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
elif rule_type == "private":
config_is_group = False
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
else:
continue
if config_is_group != is_group_chat:
continue
if config_chat_id == chat_id:
return prompt_content
return ""
def _build_group_chat_attention_block(self, session_id: str) -> str:
"""构建当前聊天场景下的额外注意事项块。"""
if not session_id:
return ""
try:
is_group_chat, _ = get_chat_type_and_target_info(session_id)
except Exception:
is_group_chat = None
prompt_lines: List[str] = []
if is_group_chat is True:
if group_chat_prompt := global_config.chat.group_chat_prompt.strip():
prompt_lines.append(f"通用注意事项:\n{group_chat_prompt}")
elif is_group_chat is False:
if private_chat_prompt := global_config.chat.private_chat_prompts.strip():
prompt_lines.append(f"通用注意事项:\n{private_chat_prompt}")
if chat_prompt := self._get_chat_prompt_for_chat(session_id, is_group_chat).strip():
prompt_lines.append(f"当前聊天额外注意事项:\n{chat_prompt}")
if not prompt_lines:
return ""
return "在该聊天中的注意事项:\n" + "\n\n".join(prompt_lines) + "\n"
def _build_system_prompt(
self,
reply_message: Optional[SessionMessage],
reply_reason: str,
expression_habits: str = "",
stream_id: Optional[str] = None,
) -> str:
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
target_message_block = self._build_target_message_block(reply_message)
session_id = self._resolve_session_id(stream_id)
try:
system_prompt = load_prompt(
system_prompt = self._load_prompt(
"maisaka_replyer",
bot_name=global_config.bot.nickname,
group_chat_attention_block=self._build_group_chat_attention_block(session_id),
time_block=f"当前时间:{current_time}",
identity=self._personality_prompt,
reply_style=global_config.personality.reply_style,
@@ -184,38 +241,35 @@ class MaisakaReplyGenerator:
return f"{system_prompt}\n\n" + "\n\n".join(sections)
def _build_reply_instruction(self) -> str:
return "请自然地回复。不要输出多余说明、括号、at 或额外标记,只输出实际要发送的内容。"
return "请自然地回复。不要输出多余说明、括号、@ 或额外标记,只输出实际要发送的内容。"
def _build_multimodal_user_message(
def _build_visual_user_message(
self,
message: SessionBackedMessage,
default_user_name: str,
enable_visual_message: bool,
) -> Optional[Message]:
speaker_name, _ = parse_speaker_content(message.processed_plain_text.strip())
visible_speaker = speaker_name or default_user_name
if not enable_visual_message:
return None
raw_message = clone_message_sequence(message.raw_message)
if not raw_message.components:
raw_message = MessageSequence([TextComponent(f"[{visible_speaker}]")])
elif isinstance(raw_message.components[0], TextComponent):
first_text = raw_message.components[0].text or ""
raw_message.components[0] = TextComponent(f"[{visible_speaker}]{first_text}")
else:
raw_message.components.insert(0, TextComponent(f"[{visible_speaker}]"))
raw_message = MessageSequence([TextComponent(message.processed_plain_text)])
multimodal_message = SessionBackedMessage(
visual_message = SessionBackedMessage(
raw_message=raw_message,
visible_text=f"[{visible_speaker}]{message.processed_plain_text}",
visible_text=message.processed_plain_text,
timestamp=message.timestamp,
message_id=message.message_id,
original_message=message.original_message,
source_kind=message.source_kind,
)
return multimodal_message.to_llm_message()
return visual_message.to_llm_message()
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
bot_nickname = global_config.bot.nickname.strip() or "Bot"
default_user_name = global_config.maisaka.cli_user_name.strip() or "User"
def _build_history_messages(
self,
chat_history: List[LLMContextMessage],
enable_visual_message: bool,
) -> List[Message]:
messages: List[Message] = []
for message in chat_history:
@@ -230,25 +284,14 @@ class MaisakaReplyGenerator:
)
continue
multimodal_message = self._build_multimodal_user_message(message, default_user_name)
if multimodal_message is not None:
messages.append(multimodal_message)
visual_message = self._build_visual_user_message(message, enable_visual_message)
if visual_message is not None:
messages.append(visual_message)
continue
for speaker_name, content_body in self._split_user_message_segments(message.processed_plain_text):
content = self._normalize_content(content_body)
if not content:
continue
visible_speaker = speaker_name or default_user_name
if visible_speaker == bot_nickname:
messages.append(
MessageBuilder().set_role(RoleType.Assistant).add_text_content(content).build()
)
continue
user_content = f"[{visible_speaker}]{content}"
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(user_content).build())
llm_message = message.to_llm_message()
if llm_message is not None:
messages.append(llm_message)
continue
if isinstance(message, AssistantMessage):
@@ -266,34 +309,33 @@ class MaisakaReplyGenerator:
reply_message: Optional[SessionMessage],
reply_reason: str,
expression_habits: str = "",
stream_id: Optional[str] = None,
enable_visual_message: bool = False,
) -> List[Message]:
messages: List[Message] = []
system_prompt = self._build_system_prompt(
reply_message=reply_message,
reply_reason=reply_reason,
expression_habits=expression_habits,
stream_id=stream_id,
)
instruction = self._build_reply_instruction()
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
messages.extend(self._build_history_messages(chat_history))
messages.extend(self._build_history_messages(chat_history, enable_visual_message))
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
return messages
@staticmethod
def _build_request_prompt_preview(messages: List[Message]) -> str:
preview_lines: List[str] = []
for message in messages:
role_name = message.role.value.capitalize()
part_previews: List[str] = []
for part in message.parts:
if isinstance(part, TextMessagePart):
part_previews.append(part.text)
continue
if isinstance(part, ImageMessagePart):
part_previews.append(f"[图片:{part.normalized_image_format}]")
preview_lines.append(f"{role_name}: {''.join(part_previews)}")
return "\n\n".join(preview_lines)
def _resolve_enable_visual_message(self, model_info: Optional[ModelInfo] = None) -> bool:
if self._enable_visual_message is not None:
return self._enable_visual_message
if self._replyer_mode == "multimodal":
if model_info is not None and not model_info.visual:
raise ValueError(f"replyer_mode=multimodal但模型 '{model_info.name}' 未开启 visual无法使用多模态 replyer")
return True
if self._replyer_mode == "text":
return False
return bool(model_info.visual) if model_info is not None else False
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
if stream_id:
@@ -349,7 +391,6 @@ class MaisakaReplyGenerator:
selected_expression_ids: Optional[List[int]] = None,
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]] = None,
) -> Tuple[bool, ReplyGenerationResult]:
def finalize(success_value: bool) -> Tuple[bool, ReplyGenerationResult]:
result.monitor_detail = build_reply_monitor_detail(result)
return success_value, result
@@ -411,7 +452,7 @@ class MaisakaReplyGenerator:
)
logger.info(
f"回复上下文完成: 流={stream_id} 已选表达={result.selected_expression_ids!r}"
f"回复上下文完成 流={stream_id} 已选表达={result.selected_expression_ids!r}"
)
prompt_started_at = time.perf_counter()
@@ -421,6 +462,7 @@ class MaisakaReplyGenerator:
reply_message=reply_message,
reply_reason=reply_reason or "",
expression_habits=merged_expression_habits,
stream_id=stream_id,
)
except Exception as exc:
import traceback
@@ -433,24 +475,36 @@ class MaisakaReplyGenerator:
return finalize(False)
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
prompt_preview = self._build_request_prompt_preview(request_messages)
prompt_preview = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))
def message_factory(_client: object) -> List[Message]:
def message_factory(_client: object, model_info: Optional[ModelInfo] = None) -> List[Message]:
nonlocal prompt_ms, prompt_preview, request_messages
prompt_started_at = time.perf_counter()
request_messages = self._build_request_messages(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
expression_habits=merged_expression_habits,
stream_id=stream_id,
enable_visual_message=self._resolve_enable_visual_message(model_info),
)
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
prompt_preview = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
return request_messages
result.completion.request_prompt = prompt_preview
preview_chat_id = self._resolve_session_id(stream_id)
replyer_prompt_section: RenderableType | None = None
if show_replyer_prompt:
replyer_prompt_section = Panel(
PromptCLIVisualizer.build_text_access_panel(
prompt_preview,
PromptCLIVisualizer.build_prompt_access_panel(
request_messages,
category="replyer",
chat_id=preview_chat_id,
request_kind="replyer",
subtitle=f"ID: {preview_chat_id}",
selection_reason=f"ID: {preview_chat_id}",
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
),
title="Reply Prompt",
border_style="bright_yellow",
@@ -472,6 +526,8 @@ class MaisakaReplyGenerator:
)
return finalize(False)
result.completion.request_prompt = prompt_preview
result.request_messages = serialize_prompt_messages(request_messages)
llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
response_text = (generation_result.response or "").strip()
result.success = bool(response_text)
@@ -504,7 +560,7 @@ class MaisakaReplyGenerator:
return finalize(False)
logger.info(
f"Maisaka 回复器生成成功: 文本={response_text!r} "
f"Maisaka 回复器生成成功 文本={response_text!r} "
f"总耗时ms={result.metrics.overall_ms} 已选表达={result.selected_expression_ids!r}"
)
if show_replyer_prompt or show_replyer_reasoning:

View File

@@ -1,21 +0,0 @@
from typing import Type
from src.config.config import global_config
def get_maisaka_replyer_class() -> Type[object]:
"""根据配置返回 Maisaka replyer 类。"""
generator_type = get_maisaka_replyer_generator_type()
if generator_type == "multimodal":
from .maisaka_generator_multi import MaisakaReplyGenerator
return MaisakaReplyGenerator
from .maisaka_generator import MaisakaReplyGenerator
return MaisakaReplyGenerator
def get_maisaka_replyer_generator_type() -> str:
"""返回当前配置的 Maisaka replyer 生成器类型。"""
return "multimodal" if global_config.visual.multimodal_replyer else "legacy"

View File

@@ -1,15 +1,10 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import Any, Dict, Optional
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.replyer.maisaka_replyer_factory import (
get_maisaka_replyer_class,
get_maisaka_replyer_generator_type,
)
from src.config.config import global_config
from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer
from .maisaka_generator import MaisakaReplyGenerator
logger = get_logger("ReplyerManager")
@@ -20,20 +15,25 @@ class ReplyerManager:
def __init__(self) -> None:
self._repliers: Dict[str, Any] = {}
@staticmethod
def _get_maisaka_generator_type() -> str:
"""返回当前配置下 Maisaka replyer 的消息模式。"""
return global_config.visual.replyer_mode
def get_replyer(
self,
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
request_type: str = "replyer",
replyer_type: str = "default",
) -> Optional["DefaultReplyer | PrivateReplyer | Any"]:
) -> Optional[MaisakaReplyGenerator]:
"""按会话和 replyer 类型获取实例。"""
stream_id = chat_stream.session_id if chat_stream else chat_id
if not stream_id:
logger.warning("[ReplyerManager] 缺少 stream_id无法获取 replyer")
return None
generator_type = get_maisaka_replyer_generator_type() if replyer_type == "maisaka" else ""
generator_type = self._get_maisaka_generator_type() if replyer_type == "maisaka" else ""
cache_key = f"{replyer_type}:{generator_type}:{stream_id}"
if cache_key in self._repliers:
logger.info(f"[ReplyerManager] 命中缓存 replyer: cache_key={cache_key}")
@@ -51,29 +51,13 @@ class ReplyerManager:
try:
if replyer_type == "maisaka":
logger.info(f"[ReplyerManager] 选择 MaisakaReplyGenerator: generator_type={generator_type}")
maisaka_replyer_class = get_maisaka_replyer_class()
replyer = maisaka_replyer_class(
chat_stream=target_stream,
request_type=request_type,
)
elif target_stream.is_group_session:
logger.info("[ReplyerManager] importing DefaultReplyer")
from src.chat.replyer.group_generator import DefaultReplyer
replyer = DefaultReplyer(
replyer = MaisakaReplyGenerator(
chat_stream=target_stream,
request_type=request_type,
)
else:
logger.info("[ReplyerManager] importing PrivateReplyer")
from src.chat.replyer.private_generator import PrivateReplyer
replyer = PrivateReplyer(
chat_stream=target_stream,
request_type=request_type,
)
logger.warning(f"[ReplyerManager] 不支持的 replyer_type={replyer_type}")
return None
except Exception:
logger.exception(f"[ReplyerManager] 创建 replyer 失败: cache_key={cache_key}")
raise

View File

@@ -1,7 +1,7 @@
from typing import Optional
from src.config.config import global_config
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("common_utils")
@@ -10,23 +10,14 @@ class TempMethodsExpression:
"""用于临时存放一些方法的类"""
@staticmethod
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
"""
根据聊天流ID获取表达配置
Args:
chat_stream_id: 聊天流ID格式为哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
"""
def _find_expression_config_item(chat_stream_id: Optional[str] = None):
if not global_config.expression.learning_list:
return True, True, True
return None
if chat_stream_id:
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
continue # 这是全局的
continue
stream_id = TempMethodsExpression._get_stream_id(
config_item.platform,
str(config_item.item_id),
@@ -34,14 +25,44 @@ class TempMethodsExpression:
)
if stream_id is None:
continue
if stream_id == chat_stream_id:
if stream_id != chat_stream_id:
continue
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
return config_item
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
return config_item
return True, True, True
return None
@staticmethod
def get_expression_advanced_chosen_for_chat(chat_stream_id: Optional[str] = None) -> bool:
"""根据聊天流 ID 获取表达方式是否启用二次选择。"""
config_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
if config_item is None:
return False
return config_item.advanced_chosen
@staticmethod
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
"""
根据聊天流 ID 获取表达配置。
Args:
chat_stream_id: 聊天流 ID格式为哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 是否启用 jargon 学习)
"""
config_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
if config_item is None:
return True, True, True
return (
config_item.use_expression,
config_item.enable_learning,
config_item.enable_jargon_learning,
)
@staticmethod
def _get_stream_id(
@@ -50,15 +71,15 @@ class TempMethodsExpression:
is_group: bool = False,
) -> Optional[str]:
"""
根据平台、ID字符串和是否为群聊生成聊天流ID
根据平台、ID 字符串和是否为群聊生成聊天流 ID
Args:
platform: 平台名称
id_str: 用户或群组的原始ID字符串
id_str: 用户或群组的原始 ID 字符串
is_group: 是否为群聊
Returns:
str: 生成的聊天流ID哈希值
str: 生成的聊天流 ID哈希值
"""
try:
from src.common.utils.utils_session import SessionUtils
@@ -68,5 +89,5 @@ class TempMethodsExpression:
else:
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
except Exception as e:
logger.error(f"生成聊天流ID失败: {e}")
logger.error(f"生成聊天流 ID 失败: {e}")
return None

View File

@@ -20,7 +20,7 @@ from src.services.embedding_service import EmbeddingServiceClient
from .typo_generator import ChineseTypoGenerator
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo
from src.common.data_models.chat_target_info_data_model import ChatTargetInfo
logger = get_logger("chat_utils")
_warned_unconfigured_platforms: set[str] = set()
@@ -699,7 +699,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
return time.strftime("%H:%M:%S", time.localtime(timestamp))
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetPersonInfo"]]:
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["ChatTargetInfo"]]:
"""
获取聊天类型(是否群聊)和私聊对象信息。
@@ -734,13 +734,13 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
):
user_nickname = chat_stream.context.message.message_info.user_info.user_nickname
from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题
from src.common.data_models.chat_target_info_data_model import ChatTargetInfo # 解决循环导入问题
# Initialize target_info with basic info
target_info = TargetPersonInfo(
target_info = ChatTargetInfo(
platform=platform,
user_id=user_id,
user_nickname=user_nickname, # type: ignore
session_nickname=user_nickname or "",
person_id=None,
person_name=None,
)
@@ -752,6 +752,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
logger.warning(f"用户 {user_nickname} 尚未认识")
# 如果用户尚未认识则返回False和None
return False, None
target_info.is_known = True
if person.person_id:
target_info.person_id = person.person_id
target_info.person_name = person.person_name

View File

@@ -24,125 +24,148 @@ logger = get_logger("emoji")
class BaseImageDataModel(BaseDatabaseDataModel[Images]):
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
if not full_path:
# 创建时候即检测文件路径合法性
raise ValueError("表情包路径不能为空")
raise ValueError("图片路径不能为空")
if Path(full_path).is_dir() or not Path(full_path).exists():
raise FileNotFoundError(f"表情包路径无效: {full_path}")
raise FileNotFoundError(f"图片路径无效: {full_path}")
resolved_path = Path(full_path).absolute().resolve()
self.full_path: Path
self.dir_path: Path
self.file_name: str
self._set_full_path(resolved_path)
self.file_hash: str = None # type: ignore
self.image_bytes: Optional[bytes] = image_bytes
self.image_format: str = "" # 图片格式
self.image_format: str = ""
def _set_full_path(self, full_path: Path) -> None:
"""同步更新文件路径相关的运行时元数据。"""
"""同步刷新路径、目录和文件名等运行时元数据。"""
resolved_path = full_path.absolute().resolve()
self.full_path = resolved_path
self.dir_path = resolved_path.parent.resolve()
self.file_name = resolved_path.name
def _restore_image_format_from_path(self) -> None:
"""根据文件扩展名恢复基础图片格式信息。"""
"""根据文件扩展名恢复图片格式信息。"""
self.image_format = self.full_path.suffix.removeprefix(".").lower()
def _build_non_conflicting_path(self, target_path: Path) -> Path:
"""在目标路径被占用时,生成一个可用的新路径。"""
candidate_path = target_path
index = 1
while candidate_path.exists():
candidate_path = target_path.with_name(
f"{target_path.stem}_{self.file_hash[:8]}_{index}{target_path.suffix}"
)
index += 1
return candidate_path
def _rename_file_to_match_format(self) -> None:
"""修正文件扩展名,并处理目标文件已存在的冲突。"""
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self.image_format])
new_full_path = self.dir_path / new_file_name
if new_full_path == self.full_path:
return
if new_full_path.exists():
existing_file_hash = hashlib.sha256(self.read_image_bytes(new_full_path)).hexdigest()
if existing_file_hash == self.file_hash:
logger.info(f"[初始化] {new_full_path.name} 已存在且内容一致,复用已有文件")
self.full_path.unlink()
self._set_full_path(new_full_path)
return
conflict_free_path = self._build_non_conflicting_path(new_full_path)
logger.warning(
f"[初始化] {new_full_path.name} 已存在且内容不同,改为保存到 {conflict_free_path.name}"
)
self.full_path.rename(conflict_free_path)
self._set_full_path(conflict_free_path)
return
self.full_path.rename(new_full_path)
self._set_full_path(new_full_path)
def read_image_bytes(self, path: Path) -> bytes:
"""
同步读取图片文件的字节内容
同步读取图片文件的字节内容
Args:
path (Path): 图片文件的完整路径
path: 图片文件的完整路径
Returns:
return (bytes): 图片文件的字节内容
Raises:
FileNotFoundError: 如果文件不存在则抛出该异常
Exception: 其他读取文件时发生的异常
图片文件的字节内容
"""
try:
with open(path, "rb") as f:
return f.read()
except FileNotFoundError as e:
with open(path, "rb") as file:
return file.read()
except FileNotFoundError as exc:
logger.error(f"[读取图片文件] 文件未找到: {path}")
raise e
except Exception as e:
logger.error(f"[读取图片文件] 读取文件时发生错误: {e}")
raise e
raise exc
except Exception as exc:
logger.error(f"[读取图片文件] 读取文件时发生错误: {exc}")
raise exc
def get_image_format(self, image_bytes: bytes) -> str:
"""
获取图片的格式
获取图片的实际格式
Args:
image_bytes (bytes): 图片的字节内容
image_bytes: 图片的字节内容
Returns:
return (str): 图片的格式(小写)
Raises:
ValueError: 如果无法识别图片格式
Exception: 其他读取图片格式时发生的异常
小写格式名,例如 `png`、`jpeg`。
"""
try:
with PILImage.open(io.BytesIO(image_bytes)) as img:
if not img.format:
raise ValueError("无法识别图片格式")
return img.format.lower()
except Exception as e:
logger.error(f"[获取图片格式] 读取图片格式时发生错误: {e}")
raise e
except Exception as exc:
logger.error(f"[获取图片格式] 读取图片格式时发生错误: {exc}")
raise exc
async def calculate_hash_format(self) -> bool:
"""
异步计算表情包的哈希值和格式,初始化后应该执行此方法来确保对象的哈希值和格式正确
计算图片哈希和实际格式,并在需要时修正扩展名。
Returns:
return (bool): 如果成功计算哈希值和格式则返回True否则返回False
成功返回 `True`,失败返回 `False`。
"""
try:
# 计算哈希值
logger.debug(f"[初始化] 计算 {self.file_name} 的哈希值...")
if not self.image_bytes:
if self.image_bytes is None:
logger.debug(f"[初始化] 正在读取文件: {self.full_path}")
image_bytes = await asyncio.to_thread(self.read_image_bytes, self.full_path)
else:
image_bytes = self.image_bytes
self.image_bytes = image_bytes
self.file_hash = hashlib.sha256(image_bytes).hexdigest()
logger.debug(f"[初始化] {self.file_name} 计算哈希值成功: {self.file_hash}")
# 用PIL读取图片格式
logger.debug(f"[初始化] 读取 {self.file_name} 的图片格式...")
self.image_format = await asyncio.to_thread(self.get_image_format, image_bytes)
logger.debug(f"[初始化] {self.file_name} 读取图片格式成功: {self.image_format}")
# 比对文件扩展名和实际格式
file_ext = self.file_name.split(".")[-1].lower()
if file_ext != self.image_format:
logger.warning(
f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self.image_format}`"
)
# 重命名文件以匹配实际格式
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self.image_format])
new_full_path = self.dir_path / new_file_name
self.full_path.rename(new_full_path)
self._set_full_path(new_full_path)
self._rename_file_to_match_format()
return True
except Exception as e:
logger.error(f"[初始化] 初始化图片时发生错误: {e}")
except Exception as exc:
logger.error(f"[初始化] 初始化图片时发生错误: {exc}")
logger.error(traceback.format_exc())
return False
class MaiEmoji(BaseImageDataModel):
"""麦麦的表情包对象,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象"""
"""表情包数据模型。"""
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
# self.embedding = []
self.description: str = ""
self.emotion: List[str] = []
self.query_count = 0
@@ -152,33 +175,26 @@ class MaiEmoji(BaseImageDataModel):
@classmethod
def from_db_instance(cls, db_record: Images):
"""从数据库记录建 MaiEmoji 对象,如果记录标记为文件不存在则**抛出异常**
调用者应该对数据库记录进行检查,如果 `no_file_flag` 为 True 则不应该调用此方法
Args:
db_record (Images): 数据库中的图片记录
Returns:
return (MaiEmoji): 包含图片信息的 MaiEmoji 对象
Raises:
ValueError: 如果数据库记录标记为文件不存在则抛出该异常
"""
"""从数据库记录`MaiEmoji` 对象"""
if db_record.no_file_flag:
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
obj = cls(db_record.full_path)
obj.file_hash = db_record.image_hash
obj._restore_image_format_from_path()
description = db_record.description or ""
obj.description = description
normalized_tags = [
str(item).strip()
for item in str(description).replace("", ",").replace("", ",").replace("", ",").split(",")
for item in str(description).replace("", ",").replace("", ",").replace("", ",").split(",")
if str(item).strip()
]
deduped_tags: List[str] = []
for item in normalized_tags:
if item not in deduped_tags:
deduped_tags.append(item)
obj.emotion = deduped_tags
obj.query_count = db_record.query_count
obj.last_used_time = db_record.last_used_time
@@ -198,7 +214,7 @@ class MaiEmoji(BaseImageDataModel):
class MaiImage(BaseImageDataModel):
"""麦麦图片数据模型,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象"""
"""普通图片数据模型"""
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
self.description: str = ""
@@ -207,19 +223,10 @@ class MaiImage(BaseImageDataModel):
@classmethod
def from_db_instance(cls, db_record: Images):
"""从数据库记录建 MaiImage 对象,如果记录标记为文件不存在则**抛出异常**
调用者应该对数据库记录进行检查,如果 `no_file_flag` 为 True 则不应该调用此方法
Args:
db_record (Images): 数据库中的图片记录
Returns:
return (MaiImage): 包含图片信息的 MaiImage 对象
Raises:
ValueError: 如果数据库记录标记为文件不存在则抛出该异常
"""
"""从数据库记录`MaiImage` 对象"""
if db_record.no_file_flag:
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiImage 对象")
obj = cls(db_record.full_path)
obj.file_hash = db_record.image_hash
obj._set_full_path(Path(db_record.full_path))

View File

@@ -1,28 +0,0 @@
# from dataclasses import dataclass, field
# from typing import Optional, Dict, TYPE_CHECKING
# from . import BaseDataModel
# if TYPE_CHECKING:
# from .database_data_model import DatabaseMessages
# from src.core.types import ActionInfo
# # @dataclass
# # class TargetPersonInfo(BaseDataModel):
# # platform: str = field(default_factory=str)
# # user_id: str = field(default_factory=str)
# # user_nickname: str = field(default_factory=str)
# # person_id: Optional[str] = None
# # person_name: Optional[str] = None
# 已重构见src/common/data_models/chat_target_info_data_model.py
# @dataclass
# class ActionPlannerInfo(BaseDataModel):
# action_type: str = field(default_factory=str)
# reasoning: Optional[str] = None
# action_data: Optional[Dict] = None
# action_message: Optional["DatabaseMessages"] = None
# available_actions: Optional[Dict[str, "ActionInfo"]] = None
# loop_start_time: Optional[float] = None
# action_reasoning: Optional[str] = None
# 已重构见src/common/data_models/planned_action_data_models.py

View File

@@ -14,7 +14,6 @@ from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
if TYPE_CHECKING:
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message
@@ -24,7 +23,7 @@ PromptMessage: TypeAlias = Dict[str, Any]
PromptInput: TypeAlias = str | List[PromptMessage]
"""统一的提示输入类型。"""
MessageFactory: TypeAlias = Callable[["BaseClient"], List["Message"]]
MessageFactory: TypeAlias = Callable[..., List["Message"]]
"""统一的消息工厂类型。"""

View File

@@ -14,6 +14,7 @@ from . import BaseDataModel
if TYPE_CHECKING:
from src.common.data_models.message_component_data_model import MessageSequence
from src.common.data_models.llm_service_data_models import PromptMessage
from src.llm_models.payload_content.tool_option import ToolCall
@@ -121,6 +122,10 @@ class ReplyGenerationResult(BaseDataModel):
default=None,
metadata={"description": "供监控层直接消费的通用 tool 展示详情。"},
)
request_messages: List["PromptMessage"] = field(
default_factory=list,
metadata={"description": "本次 replyer 实际发送给模型的消息列表。"},
)
def build_reply_monitor_detail(result: ReplyGenerationResult) -> Dict[str, Any]:
@@ -133,6 +138,8 @@ def build_reply_monitor_detail(result: ReplyGenerationResult) -> Dict[str, Any]:
if prompt_text:
detail["prompt_text"] = prompt_text
if result.request_messages:
detail["request_messages"] = result.request_messages
if reasoning_text:
detail["reasoning_text"] = reasoning_text
if output_text:

View File

@@ -84,9 +84,9 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode
def calculate_typing_time(
input_string: str,
chinese_time: float = 0.3,
english_time: float = 0.15,
line_break_time: float = 0.1,
chinese_time: float = 0.2,
english_time: float = 0.1,
line_break_time: float = 0.05,
is_emoji: bool = False,
) -> float:
"""

View File

@@ -10,24 +10,14 @@ logger = get_logger("config_utils")
class ExpressionConfigUtils:
@staticmethod
def get_expression_config_for_chat(session_id: Optional[str] = None) -> tuple[bool, bool, bool]:
# sourcery skip: use-next
"""
根据聊天会话ID获取表达配置
Args:
session_id: 聊天会话ID格式为哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
"""
def _find_expression_config_item(session_id: Optional[str] = None):
if not global_config.expression.learning_list:
return True, True, True
return None
if session_id:
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
continue # 这是全局的
continue
stream_id = ExpressionConfigUtils._get_stream_id(
config_item.platform,
str(config_item.item_id),
@@ -35,28 +25,59 @@ class ExpressionConfigUtils:
)
if stream_id is None:
continue
if stream_id == session_id:
if stream_id != session_id:
continue
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
return config_item
for config_item in global_config.expression.learning_list:
if not config_item.platform and not config_item.item_id:
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
return config_item
return True, True, True
return None
@staticmethod
def get_expression_advanced_chosen_for_chat(session_id: Optional[str] = None) -> bool:
"""根据聊天会话 ID 获取表达方式是否启用二次选择。"""
config_item = ExpressionConfigUtils._find_expression_config_item(session_id)
if config_item is None:
return False
return config_item.advanced_chosen
@staticmethod
def get_expression_config_for_chat(session_id: Optional[str] = None) -> tuple[bool, bool, bool]:
# sourcery skip: use-next
"""
根据聊天会话 ID 获取表达配置。
Args:
session_id: 聊天会话 ID格式为哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 是否启用 jargon 学习)
"""
config_item = ExpressionConfigUtils._find_expression_config_item(session_id)
if config_item is None:
return True, True, True
return (
config_item.use_expression,
config_item.enable_learning,
config_item.enable_jargon_learning,
)
@staticmethod
def _get_stream_id(platform: str, id_str: str, is_group: bool = False) -> Optional[str]:
# sourcery skip: remove-unnecessary-cast
"""
根据平台、ID字符串和是否为群聊生成聊天流ID
根据平台、ID 字符串和是否为群聊生成聊天流 ID
Args:
platform: 平台名称
id_str: 用户或群组的原始ID字符串
id_str: 用户或群组的原始 ID 字符串
is_group: 是否为群聊
Returns:
str: 生成的聊天流ID哈希值
str: 生成的聊天流 ID哈希值
"""
try:
from src.common.utils.utils_session import SessionUtils
@@ -66,7 +87,7 @@ class ExpressionConfigUtils:
else:
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
except Exception as e:
logger.error(f"生成聊天流ID失败: {e}")
logger.error(f"生成聊天流 ID 失败: {e}")
return None
@@ -91,7 +112,7 @@ class ChatConfigUtils:
else:
rule_session_id = SessionUtils.calculate_session_id(rule.platform, user_id=str(rule.item_id))
if rule_session_id != session_id:
continue # 不匹配的会话ID跳过
continue # 不匹配的会话 ID跳过
parsed_range = ChatConfigUtils.parse_range(rule.time)
if not parsed_range:
continue # 无法解析的时间范围,跳过
@@ -102,7 +123,7 @@ class ChatConfigUtils:
else: # 跨天的时间范围
in_range = now_min >= start_min or now_min <= end_min
if in_range:
return rule.value or 0.0 # 如果规则生效但没有设置值返回0.0
return rule.value or 0.0 # 如果规则生效但没有设置值,返回 0.0
# 没有匹配到会话相关的规则,继续匹配全局规则
for rule in global_config.chat.talk_value_rules:
@@ -118,7 +139,7 @@ class ChatConfigUtils:
else: # 跨天的时间范围
in_range = now_min >= start_min or now_min <= end_min
if in_range:
return rule.value or 0.0 # 如果规则生效但没有设置值返回0.0
return rule.value or 0.0 # 如果规则生效但没有设置值,返回 0.0
return result # 如果没有任何规则生效,返回默认值
@staticmethod

View File

@@ -23,7 +23,6 @@ from .official_configs import (
EmojiConfig,
ExpressionConfig,
KeywordReactionConfig,
LPMMKnowledgeConfig,
MaiSakaConfig,
MaimMessageConfig,
MCPConfig,
@@ -54,9 +53,10 @@ PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
MMC_VERSION: str = "1.0.0"
CONFIG_VERSION: str = "8.5.2"
MODEL_CONFIG_VERSION: str = "1.13.1"
CONFIG_VERSION: str = "8.7.1"
MODEL_CONFIG_VERSION: str = "1.14.0"
logger = get_logger("config")
@@ -454,6 +454,20 @@ def generate_new_config_file(config_class: type[T], config_path: Path, inner_con
write_config_to_file(config, config_path, inner_config_version)
def remove_legacy_env_file(env_path: Path) -> None:
"""删除已完成迁移的旧版 `.env` 文件。"""
if not env_path.exists():
return
try:
env_path.unlink()
except OSError as exc:
logger.warning(f"旧版 .env 配置文件删除失败,请手动删除: {env_path},原因: {exc}")
else:
logger.warning(f"检测到旧版环境变量绑定配置迁移成功,已删除旧版 .env 文件: {env_path}")
def load_config_from_file(
config_class: type[T], config_path: Path, new_ver: str, override_repr: bool = False
) -> tuple[T, bool]:
@@ -467,10 +481,12 @@ def load_config_from_file(
if not isinstance(inner_version, str):
raise TypeError(t("config.invalid_inner_version"))
old_ver: str = inner_version
env_migration_applied: bool = False
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
if config_path.name == "bot_config.toml" and config_class.__name__ == "Config":
env_migration = migrate_legacy_bind_env_to_bot_config_dict(config_data)
env_migration_applied = env_migration.migrated
if env_migration.migrated:
logger.warning(f"检测到旧版环境变量绑定配置,已迁移到主配置: {env_migration.reason}")
config_data = env_migration.data
@@ -497,9 +513,11 @@ def load_config_from_file(
raise e
else:
raise e
if compare_versions(old_ver, new_ver):
if compare_versions(old_ver, new_ver) or env_migration_applied:
output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
write_config_to_file(target_config, config_path, new_ver, override_repr)
if env_migration_applied:
remove_legacy_env_file(LEGACY_ENV_PATH)
updated = True
return target_config, updated
except Exception as e:

View File

@@ -1,12 +1,8 @@
"""
legacy_migration.py
一个“可随时拔掉”的旧配置兼容层
- 仅在配置解析失败时尝试修复旧格式数据7.x -> 8.x 这一类结构性变更)
- 不依赖 Pydantic / ConfigBase仅对 dict 做最小转换
- 成功则返回(修复后的 dict, True),失败则返回(原 dict, False)
设计目标:与现有 config 加载逻辑的接触点尽可能小,未来不需要时可一键移除。
旧配置兼容层
仅保留当前仍需要的“解析前结构修复”,避免老配置在 `from_dict` 前直接失败。
"""
from __future__ import annotations
@@ -16,12 +12,7 @@ from typing import Any, Optional
import os
from src.common.logger import get_logger
logger = get_logger("legacy_migration")
# 方便未来快速关闭/移除
ENABLE_LEGACY_MIGRATION: bool = True
@@ -43,6 +34,7 @@ def _as_list(x: Any) -> Optional[list[Any]]:
def _parse_host_env(value: Any) -> Optional[str]:
if not isinstance(value, str):
return None
normalized_value = value.strip()
return normalized_value or None
@@ -75,116 +67,73 @@ def _migrate_env_value(section: dict[str, Any], key: str, parsed_env_value: Any,
return True
def _move_section_key(source: dict[str, Any], target: dict[str, Any], key: str) -> bool:
"""将配置项从旧分组移动到新分组,若新分组已有值则保留新值。"""
if key not in source:
return False
if key not in target:
target[key] = source[key]
source.pop(key, None)
return True
def _parse_triplet_target(s: str) -> Optional[dict[str, str]]:
"""
解析 "platform:id:type" -> {platform,item_id,rule_type}
返回 None 表示无法解析。
解析 "platform:id:type" -> {platform, item_id, rule_type}
"""
if not isinstance(s, str):
return None
parts = s.split(":", 2)
if len(parts) != 3:
return None
platform, item_id, rule_type = parts
if rule_type not in ("group", "private"):
return None
return {"platform": platform, "item_id": item_id, "rule_type": rule_type}
def _parse_quad_prompt(s: str) -> Optional[dict[str, str]]:
"""
解析 "platform:id:type:prompt" -> {platform,item_id,rule_type,prompt}
prompt 允许包含冒号,因此只切前三个冒号。
"""
if not isinstance(s, str):
return None
parts = s.split(":", 3)
if len(parts) != 4:
return None
platform, item_id, rule_type, prompt = parts
if rule_type not in ("group", "private"):
return None
if not prompt:
return None
return {"platform": platform, "item_id": item_id, "rule_type": rule_type, "prompt": prompt}
def _parse_enable_disable(v: Any) -> Optional[bool]:
"""
兼容旧值 "enable"/"disable" 以及 bool。
"""
if isinstance(v, bool):
return v
if isinstance(v, str):
vv = v.strip().lower()
if vv == "enable":
normalized_value = v.strip().lower()
if normalized_value == "enable":
return True
if vv == "disable":
if normalized_value == "disable":
return False
return None
def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
"""
旧:
learning_list = [
["", "enable", "enable", "enable"],
["qq:1919810:group", "enable", "enable", "enable"],
]
兼容旧旧格式:
learning_list = [
["qq:1919810:group", "enable", "enable", "0.5"],
["", "disable", "disable", "0.1"],
]
新:
[[expression.learning_list]]
platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true
将旧版 expression.learning_list 转成当前结构。
"""
ll = _as_list(expr.get("learning_list"))
if ll is None:
learning_list = _as_list(expr.get("learning_list"))
if learning_list is None:
return False
# 如果已经是新格式(列表里是 dict跳过
if ll and all(isinstance(i, dict) for i in ll):
if learning_list and all(isinstance(item, dict) for item in learning_list):
return False
migrated_items: list[dict[str, Any]] = []
for row in ll:
r = _as_list(row)
if r is None or len(r) < 4:
# 行结构不对,无法安全迁移
for row in learning_list:
row_items = _as_list(row)
if row_items is None or len(row_items) < 4:
return False
target_raw = r[0]
use_expression = _parse_enable_disable(r[1])
enable_learning = _parse_enable_disable(r[2])
enable_jargon_learning = _parse_enable_disable(r[3])
target_raw = row_items[0]
use_expression = _parse_enable_disable(row_items[1])
enable_learning = _parse_enable_disable(row_items[2])
enable_jargon_learning = _parse_enable_disable(row_items[3])
if enable_jargon_learning is None:
# 更早期的配置在第 4 列记录的是一个已废弃的数值权重/阈值,
# 当前 schema 已没有对应字段。这里按保守策略兼容迁移:
# 丢弃旧数值,并将 enable_jargon_learning 置为 False。
# 更早期版本第 4 列是已废弃的数值阈值,这里仅做保守兼容。
try:
float(str(r[3]))
float(str(row_items[3]))
except (TypeError, ValueError):
pass
else:
enable_jargon_learning = False
if use_expression is None or enable_learning is None or enable_jargon_learning is None:
return False
# 旧格式中 target 允许为空字符串:表示全局;新结构必须有三元组字段
if target_raw == "" or target_raw is None:
target = {"platform": "", "item_id": "", "rule_type": "group"}
else:
@@ -209,99 +158,56 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
def _migrate_expression_groups(expr: dict[str, Any]) -> bool:
"""
旧:
expression_groups = [
["qq:1:group","qq:2:group"],
["qq:3:group"],
]
新:
expression_groups = [
{ expression_groups = [ {platform="qq", item_id="1", rule_type="group"}, ... ] },
{ expression_groups = [ ... ] },
]
将旧版 expression.expression_groups 转成当前结构。
"""
eg = _as_list(expr.get("expression_groups"))
if eg is None:
expression_groups = _as_list(expr.get("expression_groups"))
if expression_groups is None:
return False
if expression_groups and all(isinstance(item, dict) for item in expression_groups):
return False
# 已经是新格式(列表里是 dict 且包含 expression_groups跳过
if eg and all(isinstance(i, dict) for i in eg):
return False
migrated: list[dict[str, Any]] = []
for group in eg:
g = _as_list(group)
if g is None:
migrated_groups: list[dict[str, Any]] = []
for group in expression_groups:
group_items = _as_list(group)
if group_items is None:
return False
targets: list[dict[str, str]] = []
for item in g:
for item in group_items:
parsed = _parse_triplet_target(str(item))
if parsed is None:
return False
targets.append(parsed)
migrated.append({"expression_groups": targets})
expr["expression_groups"] = migrated
migrated_groups.append({"expression_groups": targets})
expr["expression_groups"] = migrated_groups
return True
def _migrate_target_item_list(parent: dict[str, Any], key: str) -> bool:
"""
将 list[str] 的 "platform:id:type" 迁移为 list[{platform,item_id,rule_type}]
用于memory.global_memory_blacklist 等。
将 list[str] 的 "platform:id:type" 迁移为 list[TargetItem]。
"""
raw = _as_list(parent.get(key))
if raw is None:
if raw is None or not raw:
return False
if raw and all(isinstance(i, dict) for i in raw):
if all(isinstance(item, dict) for item in raw):
return False
targets: list[dict[str, str]] = []
for item in raw:
parsed = _parse_triplet_target(str(item))
if parsed is None:
return False
targets.append(parsed)
parent[key] = targets
return True
def _migrate_extra_prompt_list(exp: dict[str, Any], key: str) -> bool:
"""
将 list[str] 的 "platform:id:type:prompt" 迁移为 list[{platform,item_id,rule_type,prompt}]
用于experimental.chat_prompts
"""
raw = _as_list(exp.get(key))
if raw is None:
return False
if raw and all(isinstance(i, dict) for i in raw):
return False
items: list[dict[str, str]] = []
for item in raw:
parsed = _parse_quad_prompt(str(item))
if parsed is None:
return False
items.append(parsed)
exp[key] = items
return True
def _parse_multimodal_replyer(v: Any) -> Optional[bool]:
"""兼容旧 replyer_generator_type 到布尔开关的迁移。"""
if isinstance(v, bool):
return v
if not isinstance(v, str):
return None
normalized_value = v.strip().lower()
if normalized_value == "multimodal":
return True
if normalized_value == "legacy":
return False
return None
def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
"""将旧版环境变量中的绑定地址迁移到主配置结构。"""
"""将旧版 `.env` 中的绑定地址迁移到主配置结构。"""
migrated_any = False
reasons: list[str] = []
@@ -339,8 +245,7 @@ def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> Migratio
def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
"""
尝试对“总配置 bot_config.toml”的 dict已 unwrap进行旧格式修复
仅做我们明确知道的结构性变更;其它字段不动。
尝试修复 `bot_config.toml` 的少量旧结构,仅保留当前仍需要的兼容逻辑
"""
if not ENABLE_LEGACY_MIGRATION:
return MigrationResult(data=data, migrated=False, reason="disabled")
@@ -353,41 +258,30 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
if _migrate_expression_learning_list(expr):
migrated_any = True
reasons.append("expression.learning_list")
if _migrate_expression_groups(expr):
migrated_any = True
reasons.append("expression.expression_groups")
# allow_reflect: 旧 list[str] -> 新 list[TargetItem]
if _migrate_target_item_list(expr, "allow_reflect"):
migrated_any = True
reasons.append("expression.allow_reflect")
# manual_reflect_operator_id: 旧 str -> 新 Optional[TargetItem]
mroi = expr.get("manual_reflect_operator_id")
if isinstance(mroi, str) and mroi.strip():
parsed = _parse_triplet_target(mroi.strip())
manual_reflect_operator_id = expr.get("manual_reflect_operator_id")
if isinstance(manual_reflect_operator_id, str) and manual_reflect_operator_id.strip():
parsed = _parse_triplet_target(manual_reflect_operator_id.strip())
if parsed is not None:
expr["manual_reflect_operator_id"] = parsed
migrated_any = True
reasons.append("expression.manual_reflect_operator_id")
chat = _as_dict(data.get("chat"))
if chat is None:
chat = {}
data["chat"] = chat
elif "private_plan_style" in chat:
chat.pop("private_plan_style", None)
migrated_any = True
reasons.append("chat.private_plan_style_removed")
if isinstance(manual_reflect_operator_id, str) and not manual_reflect_operator_id.strip():
expr.pop("manual_reflect_operator_id", None)
migrated_any = True
reasons.append("expression.manual_reflect_operator_id_empty")
personality = _as_dict(data.get("personality"))
visual = _as_dict(data.get("visual"))
if visual is None and (
(personality is not None and "visual_style" in personality)
or "multimodal_planner" in chat
or "replyer_generator_type" in chat
):
visual = {}
data["visual"] = visual
if visual is not None and personality is not None and "visual_style" in personality:
if "visual_style" not in visual:
visual["visual_style"] = personality["visual_style"]
@@ -395,108 +289,19 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
migrated_any = True
reasons.append("personality.visual_style_moved_to_visual.visual_style")
if visual is not None and "multimodal_planner" in chat:
if "multimodal_planner" not in visual and isinstance(chat["multimodal_planner"], bool):
visual["multimodal_planner"] = chat["multimodal_planner"]
if "multimodal_planner" in visual:
chat.pop("multimodal_planner", None)
if visual is not None and "multimodal_planner" in visual and "planner_mode" not in visual:
multimodal_planner = visual.pop("multimodal_planner")
if isinstance(multimodal_planner, bool):
visual["planner_mode"] = "multimodal" if multimodal_planner else "text"
migrated_any = True
reasons.append("chat.multimodal_planner_moved_to_visual.multimodal_planner")
reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode")
else:
visual["multimodal_planner"] = multimodal_planner
if visual is not None and "replyer_generator_type" in chat:
multimodal_replyer = _parse_multimodal_replyer(chat["replyer_generator_type"])
if "multimodal_replyer" not in visual and multimodal_replyer is not None:
visual["multimodal_replyer"] = multimodal_replyer
if "multimodal_replyer" in visual:
chat.pop("replyer_generator_type", None)
migrated_any = True
reasons.append("chat.replyer_generator_type_moved_to_visual.multimodal_replyer")
maisaka = _as_dict(data.get("maisaka"))
mem = _as_dict(data.get("memory"))
debug = _as_dict(data.get("debug"))
if maisaka is not None:
moved_memory_keys = ("enable_memory_query_tool", "memory_query_default_limit")
if any(key in maisaka for key in moved_memory_keys) and mem is None:
mem = {}
data["memory"] = mem
if mem is not None:
for moved_key in moved_memory_keys:
if _move_section_key(maisaka, mem, moved_key):
migrated_any = True
reasons.append(f"maisaka.{moved_key}_moved_to_memory")
if mem is not None and "show_memory_prompt" in mem and debug is None:
debug = {}
data["debug"] = debug
if mem is not None:
if _migrate_target_item_list(mem, "global_memory_blacklist"):
migrated_any = True
reasons.append("memory.global_memory_blacklist")
if debug is not None and _move_section_key(mem, debug, "show_memory_prompt"):
migrated_any = True
reasons.append("memory.show_memory_prompt_moved_to_debug")
for removed_key in (
"agent_timeout_seconds",
"max_agent_iterations",
):
if removed_key in mem:
mem.pop(removed_key, None)
migrated_any = True
reasons.append(f"memory.{removed_key}_removed")
relationship = _as_dict(data.get("relationship"))
if relationship is not None:
data.pop("relationship", None)
memory = _as_dict(data.get("memory"))
if memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"):
migrated_any = True
reasons.append("relationship_removed")
exp = _as_dict(data.get("experimental"))
if exp is not None:
if _migrate_extra_prompt_list(exp, "chat_prompts"):
migrated_any = True
reasons.append("experimental.chat_prompts")
if "private_plan_style" in exp:
exp.pop("private_plan_style", None)
migrated_any = True
reasons.append("experimental.private_plan_style_removed")
for key in ("group_chat_prompt", "private_chat_prompts", "chat_prompts"):
if key in exp and key not in chat:
chat[key] = exp[key]
migrated_any = True
reasons.append(f"experimental.{key}_moved_to_chat")
data.pop("experimental", None)
migrated_any = True
reasons.append("experimental_removed")
if chat is not None and "think_mode" in chat:
chat.pop("think_mode", None)
migrated_any = True
reasons.append("chat.think_mode_removed")
tool = _as_dict(data.get("tool"))
if tool is not None:
data.pop("tool", None)
migrated_any = True
reasons.append("tool_section_removed")
# ExpressionConfig 中的 manual_reflect_operator_id:
# 旧版本可能是 ""(字符串),新版本期望 Optional[TargetItem]。
# 空字符串视为未配置,转换为 None/删除键以避免校验错误。
expr = _as_dict(data.get("expression"))
if expr is not None:
mroi = expr.get("manual_reflect_operator_id")
if isinstance(mroi, str) and not mroi.strip():
expr.pop("manual_reflect_operator_id", None)
migrated_any = True
reasons.append("expression.manual_reflect_operator_id_empty")
reasons.append("memory.global_memory_blacklist")
reason = ",".join(reasons)
return MigrationResult(data=data, migrated=migrated_any, reason=reason)

View File

@@ -307,6 +307,15 @@ class ModelInfo(ConfigBase):
)
"""强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)"""
visual: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "image",
},
)
"""是否为多模态模型。开启后表示该模型支持视觉输入。"""
extra_params: dict[str, Any] = Field(
default_factory=dict,
json_schema_extra={
@@ -437,4 +446,4 @@ class ModelTaskConfig(ConfigBase):
"x-icon": "database",
},
)
"""嵌入模型配置"""
"""嵌入模型配置"""

View File

@@ -145,23 +145,23 @@ class VisualConfig(ConfigBase):
__ui_label__ = "视觉"
__ui_icon__ = "image"
multimodal_planner: bool = Field(
default=True,
planner_mode: Literal["text", "multimodal", "auto"] = Field(
default="auto",
json_schema_extra={
"x-widget": "switch",
"x-icon": "image",
},
)
"""是否直接输入图片"""
multimodal_replyer: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-widget": "select",
"x-icon": "git-branch",
},
)
"""是否启用 Maisaka 多模态 replyer 生成器"""
"""规划器模式auto根据模型信息自动选择text为纯文本模式multimodal为多模态模式"""
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
default="auto",
json_schema_extra={
"x-widget": "select",
"x-icon": "git-branch",
},
)
"""回复器模式auto根据模型信息自动选择text为纯文本模式multimodal为多模态模式"""
visual_style: str = Field(
default="请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本",
@@ -239,16 +239,12 @@ class ChatConfig(ConfigBase):
)
"""Planner 连续被新消息打断的最大次数0 表示不启用打断"""
plan_reply_log_max_per_chat: int = Field(
default=1024,
json_schema_extra={
"x-widget": "input",
"x-icon": "file-text",
},
)
"""每个聊天流最大保存的Plan/Reply日志数量超过此数量时会自动删除最老的日志"""
group_chat_prompt: str = Field(
default="你需要控制自己发言的频率,如果是一对一聊天,可以以较均匀的频率发言;如果用户较多,不要每句都回复,控制回复频率,不要回复的太频繁!控制回复的频率,不要每个人的消息都回复。",
default="""
你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片。
回复尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。
""",
json_schema_extra={
"x-widget": "textarea",
"x-icon": "users",
@@ -257,7 +253,11 @@ class ChatConfig(ConfigBase):
"""_wrap_群聊通用注意事项"""
private_chat_prompts: str = Field(
default="你需要控制自己发言的频率,可以以较均匀的频率发言。",
default="""
你正在聊天,下面是正在聊的内容,其中包含聊天记录和聊天中的图片。
回复尽量简短一些。请注意把握聊天内容。
请考虑对方的发言频率,想法,思考自己何时回复以及回复内容。
""",
json_schema_extra={
"x-widget": "textarea",
"x-icon": "user",
@@ -740,6 +740,15 @@ class LearningItem(ConfigBase):
)
"""是否启用jargon学习"""
advanced_chosen: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "sparkles",
},
)
"""是否启用基于子代理的二次表达方式选择"""
class ExpressionGroup(ConfigBase):
"""表达互通组配置类,若列表为空代表全局共享"""
@@ -769,6 +778,7 @@ class ExpressionConfig(ConfigBase):
use_expression=True,
enable_learning=True,
enable_jargon_learning=True,
advanced_chosen=False,
)
],
json_schema_extra={
@@ -1640,35 +1650,6 @@ class MaiSakaConfig(ConfigBase):
)
"""MaiSaka 使用的用户名称"""
tool_filter_task_name: str = Field(
default="utils",
json_schema_extra={
"x-widget": "input",
"x-icon": "sparkles",
},
)
"""工具筛选预判使用的模型任务名"""
tool_filter_threshold: int = Field(
default=20,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "filter",
},
)
"""当可用工具总数超过该阈值时,先进行一轮工具筛选"""
tool_filter_max_keep: int = Field(
default=5,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "list-filter",
},
)
"""工具筛选阶段最多保留的非内置工具数量"""
show_image_path: bool = Field(
default=True,
json_schema_extra={

View File

@@ -181,10 +181,7 @@ class ToolSpec:
str: 合并后的单段工具描述。
"""
parts = [self.brief_description.strip()]
if self.detailed_description.strip():
parts.append(self.detailed_description.strip())
return "\n\n".join(part for part in parts if part).strip()
return self.brief_description.strip()
def to_llm_definition(self) -> ToolDefinitionInput:
"""转换为统一的 LLM 工具定义。
@@ -389,7 +386,24 @@ class ToolRegistry:
for provider in self._providers:
provider_specs = await provider.list_tools()
if any(spec.name == invocation.tool_name and spec.enabled for spec in provider_specs):
return await provider.invoke(invocation, context)
try:
return await provider.invoke(invocation, context)
except Exception as exc:
logger.exception(
"工具调用异常: tool=%s provider=%s",
invocation.tool_name,
getattr(provider, "provider_name", ""),
)
error_message = str(exc).strip()
if error_message:
error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}: {error_message}"
else:
error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}"
return ToolExecutionResult(
tool_name=invocation.tool_name,
success=False,
error_message=error_message,
)
return ToolExecutionResult(
tool_name=invocation.tool_name,

View File

@@ -1,5 +1,5 @@
"""
表达方式自动检查定时任务
表达方式自动检查定时任务
功能:
1. 定期随机选取指定数量的表达方式
@@ -9,52 +9,48 @@
"""
import asyncio
import json
import random
from typing import List
from sqlmodel import select
from src.learners.expression_review_store import get_review_state, set_review_state
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.learners.expression_review_store import get_review_state, set_review_state
from src.learners.expression_utils import parse_evaluation_response
from src.manager.async_task_manager import AsyncTask
from src.services.llm_service import LLMServiceClient
logger = get_logger("expression_auto_check_task")
logger = get_logger("expressor")
def create_evaluation_prompt(situation: str, style: str) -> str:
"""
创建评估提示词
创建评估提示词
Args:
situation: 情
situation: 情
style: 风格
Returns:
评估提示词
"""
# 基础评估标准
base_criteria = [
"表达方式或言语风格 是否与使用条件或使用情景 匹配",
"允许部分语法错误或口化或缺省出现",
"表达方式或言语风格是否与使用条件或使用情景匹配",
"允许部分语法错误或口化或缺省出现",
"表达方式不能太过特指,需要具有泛用性",
"一般不涉及具体的人名或名称",
]
# 从配置中获取额外的自定义标准
custom_criteria = global_config.expression.expression_auto_check_custom_criteria
# 合并所有评估标准
all_criteria = base_criteria.copy()
if custom_criteria:
all_criteria.extend(custom_criteria)
# 构建评估标准列表字符串
criteria_list = "\n".join([f"{i + 1}. {criterion}" for i, criterion in enumerate(all_criteria)])
prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
@@ -64,14 +60,13 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
请从以下方面进行评估:
{criteria_list}
请以JSON格式输出评估结果
请以 JSON 格式输出评估结果:
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因。
请严格按照JSON格式输出不要包含其他内容。"""
如果合适suitable 设为 true如果不合适suitable 设为 false并在 reason 中说明原因。
请严格按照 JSON 格式输出,不要包含其他内容。"""
return prompt
@@ -81,10 +76,10 @@ judge_llm = LLMServiceClient(task_name="utils", request_type="expression_check")
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
"""
执行单次LLM评估
执行单次 LLM 评估
Args:
situation: 情
situation: 情
style: 风格
Returns:
@@ -101,20 +96,10 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
response = generation_result.response
logger.debug(f"LLM响应: {response}")
# 解析JSON响应
try:
evaluation = json.loads(response)
except json.JSONDecodeError as e:
import re
evaluation = parse_evaluation_response(response)
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
if json_match:
evaluation = json.loads(json_match.group())
else:
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
suitable = evaluation.get("suitable", False)
reason = evaluation.get("reason", "未提供理由")
suitable = bool(evaluation.get("suitable", False))
reason = str(evaluation.get("reason", "未提供理由"))
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
return suitable, reason, None
@@ -125,20 +110,19 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
class ExpressionAutoCheckTask(AsyncTask):
"""表达方式自动检查定时任务"""
"""表达方式自动检查定时任务"""
def __init__(self):
# 从配置中获取检查间隔和一次检查数量
check_interval = global_config.expression.expression_auto_check_interval
super().__init__(
task_name="Expression Auto Check Task",
wait_before_start=60, # 启动后等待60秒再开始第一次检查
wait_before_start=60,
run_interval=check_interval,
)
async def _select_expressions(self, count: int) -> List[Expression]:
"""
随机选择指定数量的未检查表达方式
随机选择指定数量的未检查表达方式
Args:
count: 需要选择的数量
@@ -158,11 +142,12 @@ class ExpressionAutoCheckTask(AsyncTask):
logger.info("没有未检查的表达方式")
return []
# 随机选择指定数量
selected_count = min(count, len(unevaluated_expressions))
selected = random.sample(unevaluated_expressions, selected_count)
logger.info(f"{len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count}")
logger.info(
f"{len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count}"
)
return selected
except Exception as e:
@@ -171,35 +156,35 @@ class ExpressionAutoCheckTask(AsyncTask):
async def _evaluate_expression(self, expression: Expression) -> bool:
"""
评估单个表达方式
评估单个表达方式
Args:
expression: 要评估的表达方式
Returns:
True表示通过False表示不通过
True 表示通过False 表示不通过
"""
suitable, reason, error = await single_expression_check(
expression.situation,
expression.style,
)
# 更新数据库
try:
set_review_state(expression.id, True, not suitable, "ai")
status = "通过" if suitable else "不通过"
# 保留这段注释,方便后续需要时恢复更详细的审核日志。
# logger.info(
# f"表达方式评估完成 [ID: {expression.id}] - {status} | "
# f"Situation: {expression.situation}... | "
# f"Style: {expression.style}... | "
# f"Reason: {reason[:50]}..."
# f"表达方式评估完成 [ID: {expression.id}] - {status} | "
# f"Situation: {expression.situation}... | "
# f"Style: {expression.style}... | "
# f"Reason: {reason[:50]}..."
# )
if error:
logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}")
logger.debug(f"表达方式 [ID: {expression.id}] 评估完成: {status}, reason={reason}")
return suitable
except Exception as e:
@@ -207,9 +192,8 @@ class ExpressionAutoCheckTask(AsyncTask):
return False
async def run(self):
"""执行检查任务"""
"""执行检查任务"""
try:
# 检查是否启用自动检查
if not global_config.expression.expression_self_reflect:
logger.debug("表达方式自动检查未启用,跳过本次执行")
return
@@ -221,26 +205,22 @@ class ExpressionAutoCheckTask(AsyncTask):
logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count}")
# 选择要检查的表达方式
expressions = await self._select_expressions(check_count)
if not expressions:
logger.info("没有需要检查的表达方式")
return
# 逐个评估
passed_count = 0
failed_count = 0
for i, expression in enumerate(expressions, 1):
logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")
for index, expression in enumerate(expressions, 1):
logger.debug(f"正在评估 [{index}/{len(expressions)}]: ID={expression.id}")
if await self._evaluate_expression(expression):
passed_count += 1
else:
failed_count += 1
# 避免请求过快
await asyncio.sleep(0.3)
logger.info(

View File

@@ -267,7 +267,7 @@ class ExpressionLearner:
return normalized_entries
def get_pending_count(self, message_cache: List["SessionMessage"]) -> int:
"""??????????????"""
"""获取待处理消息数量"""
return max(0, len(message_cache) - self._last_processed_index)
async def learn(
@@ -275,10 +275,10 @@ class ExpressionLearner:
message_cache: List["SessionMessage"],
jargon_miner: Optional["JargonMiner"] = None,
) -> bool:
"""?????????????????????"""
"""学习表达方式"""
pending_messages = message_cache[self._last_processed_index :]
if not pending_messages:
logger.debug("??????????????????")
logger.debug("没有待处理消息")
return False
if len(pending_messages) < self.min_messages_for_extraction:
return False
@@ -304,7 +304,7 @@ class ExpressionLearner:
)
response = generation_result.response
except Exception as e:
logger.error(f"????????????????{e}")
logger.error(f"学习表达方式失败: {e}")
return False
expressions: List[Tuple[str, str, str]]
@@ -319,14 +319,14 @@ class ExpressionLearner:
continue
jargon_entries.append((content, source_id))
existing_contents.add(content)
logger.info(f"??????????{content}")
logger.info(f"从缓存中找到黑话: {content}")
if len(expressions) > 20:
logger.info(f"?????????? 20 ???????????{len(expressions)}")
logger.info(f"表达方式数量超过20: {len(expressions)}")
expressions = []
if len(jargon_entries) > 30:
logger.info(f"???????? 30 ???????????{len(jargon_entries)}")
logger.info(f"黑话数量超过30: {len(jargon_entries)}")
jargon_entries = []
after_extract_result = await self._get_runtime_manager().invoke_hook(
@@ -337,7 +337,7 @@ class ExpressionLearner:
jargon_entries=self._serialize_jargon_entries(jargon_entries),
)
if after_extract_result.aborted:
logger.info(f"{self.session_id} ?????????? Hook ??")
logger.info(f"{self.session_id} 表达方式选择 Hook 中止")
self._last_processed_index = len(message_cache)
return False
@@ -353,21 +353,21 @@ class ExpressionLearner:
await self._process_jargon_entries(jargon_entries, pending_messages, jargon_miner)
if not expressions:
logger.info("????????????")
logger.info("没有可学习的表达方式")
self._last_processed_index = len(message_cache)
return False
logger.info(f"???? expressions: {expressions}")
logger.info(f"???? jargon_entries: {jargon_entries}")
logger.info(f"可学习的表达方式: {expressions}")
logger.info(f"可学习的黑话: {jargon_entries}")
learnt_expressions = self._filter_expressions(expressions, pending_messages)
if not learnt_expressions:
logger.info("????????????")
logger.info("没有可学习的表达方式通过过滤")
self._last_processed_index = len(message_cache)
return False
learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
logger.info(f"? {self.session_id} ????????\n{learnt_expressions_str}")
logger.info(f"{self.session_id} 可学习的表达方式: \n{learnt_expressions_str}")
for situation, style in learnt_expressions:
before_upsert_result = await self._get_runtime_manager().invoke_hook(
@@ -377,14 +377,14 @@ class ExpressionLearner:
style=style,
)
if before_upsert_result.aborted:
logger.info(f"{self.session_id} ???????? Hook ??: situation={situation!r}")
logger.info(f"{self.session_id} 表达方式写入 Hook 中止: situation={situation!r}")
continue
upsert_kwargs = before_upsert_result.kwargs
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
style = str(upsert_kwargs.get("style", style) or "").strip()
if not situation or not style:
logger.info(f"{self.session_id} ???????? Hook ??????")
logger.info(f"{self.session_id} 表达方式写入 Hook 中止: situation={situation!r}")
continue
await self._upsert_expression_to_db(situation, style)

View File

@@ -1,14 +1,14 @@
from json_repair import repair_json
from typing import Any, List, Optional, Tuple
import json
import re
from typing import Any, Dict, List, Optional, Tuple
from json_repair import repair_json
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("expression_utils")
@@ -16,17 +16,7 @@ judge_llm = LLMServiceClient(task_name="utils", request_type="expression_check")
def _normalize_repair_json_result(repaired_result: Any) -> str:
"""将 repair_json 的返回值规范化为 JSON 字符串。
Args:
repaired_result: `repair_json` 的返回值,可能是字符串或带附加信息的元组。
Returns:
str: 可供 `json.loads` 继续解析的 JSON 字符串。
Raises:
TypeError: 当返回值无法规范化为字符串时抛出。
"""
"""`repair_json` 的返回结果统一转换为字符串。"""
if isinstance(repaired_result, str):
return repaired_result
if isinstance(repaired_result, tuple) and repaired_result:
@@ -37,22 +27,121 @@ def _normalize_repair_json_result(repaired_result: Any) -> str:
raise TypeError(f"repair_json 返回了无法处理的结果类型: {type(repaired_result)}")
def _strip_markdown_code_fence(text: str) -> str:
"""移除 LLM 可能附带的 Markdown 代码块包裹。"""
raw = text.strip()
if match := re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL):
return match[1].strip()
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
return raw.strip()
def _extract_json_object_candidate(text: str) -> str:
"""尽量从文本中提取首个 JSON 对象片段。"""
start_index = text.find("{")
end_index = text.rfind("}")
if start_index != -1 and end_index != -1 and start_index < end_index:
return text[start_index : end_index + 1].strip()
return text.strip()
def _extract_reason_from_text(text: str) -> Optional[str]:
"""从格式不完整的 JSON 文本中兜底提取 reason 字段。"""
reason_key_match = re.search(r'["“”]?reason["“”]?\s*:\s*', text, re.IGNORECASE)
if reason_key_match is None:
return None
value_text = text[reason_key_match.end() :].strip()
if not value_text:
return None
if value_text.endswith("}"):
value_text = value_text[:-1].rstrip()
if value_text.endswith(","):
value_text = value_text[:-1].rstrip()
if not value_text:
return None
if value_text[0] in {'"', "'", "", "", "", ""}:
value_text = value_text[1:]
while value_text and value_text[-1] in {'"', "'", "", "", "", ""}:
value_text = value_text[:-1].rstrip()
return value_text.strip() or None
def _normalize_reason_text(reason: Any) -> str:
"""清理解析后 reason 中残留的包裹引号。"""
normalized_reason = str(reason).strip()
if len(normalized_reason) >= 2 and normalized_reason[0] == normalized_reason[-1]:
if normalized_reason[0] in {'"', "'", "", "", "", ""}:
normalized_reason = normalized_reason[1:-1].strip()
if normalized_reason.endswith('"') and normalized_reason.count('"') % 2 == 1:
normalized_reason = normalized_reason[:-1].rstrip()
if normalized_reason.endswith("'") and normalized_reason.count("'") % 2 == 1:
normalized_reason = normalized_reason[:-1].rstrip()
if normalized_reason.endswith('"') and not normalized_reason.startswith('"'):
normalized_reason = normalized_reason[:-1].rstrip()
if normalized_reason.endswith("'") and not normalized_reason.startswith("'"):
normalized_reason = normalized_reason[:-1].rstrip()
return normalized_reason
def parse_evaluation_response(response: str) -> Dict[str, Any]:
"""解析表达方式评估结果,兼容不完全合法的 JSON。"""
raw = _strip_markdown_code_fence(response)
if not raw:
raise ValueError("LLM 响应为空")
parse_candidates = [raw]
json_candidate = _extract_json_object_candidate(raw)
if json_candidate and json_candidate not in parse_candidates:
parse_candidates.append(json_candidate)
for candidate in parse_candidates:
parsed = _try_parse(candidate)
if isinstance(parsed, dict):
if "reason" in parsed:
parsed["reason"] = _normalize_reason_text(parsed["reason"])
return parsed
fixed_candidate = fix_chinese_quotes_in_json(candidate)
if fixed_candidate != candidate:
parsed = _try_parse(fixed_candidate)
if isinstance(parsed, dict):
if "reason" in parsed:
parsed["reason"] = _normalize_reason_text(parsed["reason"])
return parsed
suitable_match = re.search(r'["“”]?suitable["“”]?\s*:\s*(true|false)', raw, re.IGNORECASE)
reason = _extract_reason_from_text(json_candidate or raw)
if suitable_match is None or reason is None:
raise ValueError(f"无法解析 LLM 响应为评估结果 JSON: {response}")
return {
"suitable": suitable_match.group(1).lower() == "true",
"reason": _normalize_reason_text(reason),
}
async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
"""
执行单次LLM评估
执行单次 LLM 评估
Args:
situation: 情
situation: 情
style: 风格
Returns:
(suitable, reason, error) 元组,如果出错则 suitable 为 Falseerror 包含错误信息
"""
# 构建评估提示词
# 基础评估标准
base_criteria = [
"表达方式或言语风格是否与使用条件或使用情景匹配",
"允许部分语法错误或口化或缺省出现",
"允许部分语法错误或口化或缺省出现",
"表达方式不能太过特指,需要具有泛用性",
"一般不涉及具体的人名或名称",
]
@@ -60,7 +149,6 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
if custom_criteria := global_config.expression.expression_auto_check_custom_criteria:
base_criteria.extend(custom_criteria)
# 构建评估标准列表字符串
criteria_list = "\n".join([f"{i + 1}. {criterion}" for i, criterion in enumerate(base_criteria)])
prompt_template = prompt_manager.get_prompt("expression_evaluation")
@@ -81,18 +169,13 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
logger.debug(f"评估结果: {response}")
try:
evaluation = json.loads(response)
except json.JSONDecodeError:
try:
response_repaired = _normalize_repair_json_result(repair_json(response))
evaluation = json.loads(response_repaired)
except Exception as e:
raise ValueError(f"无法解析LLM响应为JSON: {response}") from e
evaluation = parse_evaluation_response(response)
except Exception as e:
return False, f"评估表达方式时发生错误: {e}", str(e)
try:
suitable = evaluation.get("suitable", False)
reason = evaluation.get("reason", "未提供理由")
suitable = bool(evaluation.get("suitable", False))
reason = _normalize_reason_text(evaluation.get("reason", "未提供理由"))
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
return suitable, reason, None
except Exception as e:
@@ -100,69 +183,48 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
def fix_chinese_quotes_in_json(text: str) -> str:
"""使用状态机修复 JSON 字符串值中的中文引号"""
result = []
i = 0
"""使用状态机修复 JSON 字符串值中的中文引号"""
result: List[str] = []
in_string = False
escape_next = False
while i < len(text):
char = text[i]
for char in text:
if escape_next:
# 当前字符是转义字符后的字符,直接添加
result.append(char)
escape_next = False
i += 1
continue
if char == "\\":
# 转义字符
result.append(char)
escape_next = True
i += 1
continue
if char == '"' and not escape_next:
# 遇到英文引号,切换字符串状态
if char == '"':
in_string = not in_string
result.append(char)
i += 1
continue
if in_string and char in ["", ""]:
result.append('\\"')
else:
result.append(char)
i += 1
continue
result.append(char)
return "".join(result)
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
"""
解析 LLM 返回的表达风格总结和黑话 JSON提取两个列表。
期望的 JSON 结构:
[
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
{"content": "词条", "source_id": "12"}, // 黑话
...
]
解析 LLM 返回的表达方式总结和黑话 JSON提取两个列表。
Returns:
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
第一个列表是表达方式 (situation, style, source_id)
第二个列表是黑话 (content, source_id)
第一个列表是表达方式 (situation, style, source_id)
第二个列表是黑话 (content, source_id)
"""
if not response:
return [], []
raw = response.strip()
if match := re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL):
raw = match[1].strip()
else:
# 去掉可能存在的通用 ``` 包裹
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
raw = raw.strip()
raw = _strip_markdown_code_fence(response)
parsed = _try_parse(raw)
if parsed is None:
@@ -180,22 +242,21 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
return [], []
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
expressions: List[Tuple[str, str, str]] = []
jargon_entries: List[Tuple[str, str]] = []
for item in parsed_list:
if not isinstance(item, dict):
continue
# 检查是否是表达方式条目(有 situation 和 style
situation = str(item.get("situation", "")).strip()
style = str(item.get("style", "")).strip()
source_id = str(item.get("source_id", "")).strip()
if situation and style and source_id:
# 表达方式条目
expressions.append((situation, style, source_id))
continue
content = str(item.get("content", "")).strip()
if content and source_id:
jargon_entries.append((content, source_id))
@@ -204,25 +265,16 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
def is_single_char_jargon(content: str) -> bool:
"""
判断是否是单字黑话(单个汉字、英文或数字)
Args:
content: 词条内容
Returns:
bool: 如果是单字黑话返回True否则返回False
"""
"""判断是否是单字黑话(单个汉字、英文或数字)。"""
if not content or len(content) != 1:
return False
char = content[0]
# 判断是否是单个汉字、单个英文字母或单个数字
return (
"\u4e00" <= char <= "\u9fff" # 汉字
or "a" <= char <= "z" # 小写字母
or "A" <= char <= "Z" # 大写字母
or "0" <= char <= "9" # 数字
"\u4e00" <= char <= "\u9fff"
or "a" <= char <= "z"
or "A" <= char <= "Z"
or "0" <= char <= "9"
)

View File

@@ -54,27 +54,30 @@ class RespNotOkException(Exception):
return f"未知的异常响应代码:{self.status_code}"
class RespParseException(Exception):
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
class ResponseContextException(Exception):
"""携带原始响应上下文的异常基类。"""
def __init__(self, ext_info: Any, message: str | None = None):
default_message: str = "请求失败"
def __init__(self, ext_info: Any = None, message: str | None = None):
super().__init__(message)
self.ext_info = ext_info
self.message = message
def __str__(self):
return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
return self.message or self.default_message
class EmptyResponseException(Exception):
class RespParseException(ResponseContextException):
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
default_message = "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
class EmptyResponseException(ResponseContextException):
"""响应内容为空"""
def __init__(self, message: str = "响应内容为空,这可能是一个临时性问题"):
super().__init__(message)
self.message = message
def __str__(self):
return self.message
default_message = "响应内容为空,这可能是一个临时性问题"
class ModelAttemptFailed(Exception):

View File

@@ -552,7 +552,7 @@ def _build_stream_api_response(
_warn_if_max_tokens_truncated(last_response, response.content, response.tool_calls)
if not response.content and not response.tool_calls and not response.reasoning_content:
raise EmptyResponseException()
raise EmptyResponseException(last_response)
return response
@@ -627,7 +627,7 @@ def _default_normal_response_parser(
usage_record = _extract_usage_record(response)
_warn_if_max_tokens_truncated(response, api_response.content, api_response.tool_calls)
if not api_response.content and not api_response.tool_calls and not api_response.reasoning_content:
raise EmptyResponseException("响应中既无文本内容也无工具调用")
raise EmptyResponseException(response, "响应中既无文本内容也无工具调用")
return api_response, usage_record

View File

@@ -79,6 +79,25 @@ THINK_CONTENT_PATTERN = re.compile(
)
"""用于解析 `<think>` 推理块的正则表达式。"""
XML_TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(?P<body>.*?)\s*</tool_call>", re.DOTALL | re.IGNORECASE)
"""用于兜底解析模型以 XML 文本返回的工具调用。
这是一个暂时性兼容方案,专门处理“思维链内容里夹带工具调用”的情况;
后续如果上游稳定返回标准 tool_calls 字段,这里可能会调整或移除。
"""
XML_FUNCTION_CALL_PATTERN = re.compile(
r"<function=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<arguments>.*?)\s*</function>",
re.DOTALL | re.IGNORECASE,
)
"""用于从 XML 风格工具调用块中提取函数名与参数。"""
XML_PARAMETER_PATTERN = re.compile(
r"<parameter=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<value>.*?)\s*</parameter>",
re.DOTALL | re.IGNORECASE,
)
"""用于从 XML 风格工具调用块中提取参数列表。"""
CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS = {
"max_tokens",
"messages",
@@ -346,6 +365,32 @@ def _convert_assistant_tool_calls(tool_calls: List[ToolCall]) -> List[ChatComple
return converted_tool_calls
def _sanitize_messages_for_toolless_request(messages: List[Message]) -> List[Message]:
"""在无工具请求时清洗历史工具调用链,避免兼容接口拒收消息。"""
sanitized_messages: List[Message] = []
for message in messages:
if message.role == RoleType.Tool:
continue
if message.role == RoleType.Assistant and message.tool_calls:
if not message.parts:
continue
assistant_message = Message(
role=message.role,
parts=list(message.parts),
tool_call_id=message.tool_call_id,
tool_name=message.tool_name,
tool_calls=None,
)
sanitized_messages.append(assistant_message)
continue
sanitized_messages.append(message)
return sanitized_messages
def _convert_messages(messages: List[Message]) -> List[ChatCompletionMessageParam]:
"""将内部消息列表转换为 OpenAI 兼容消息列表。
@@ -515,6 +560,66 @@ def _extract_reasoning_and_content(
return None, match.group("content_only").strip() or None
def _extract_xml_tool_calls(
raw_text: str | None,
parse_mode: ToolArgumentParseMode,
response: Any,
) -> Tuple[str | None, List[ToolCall] | None]:
"""从 XML 风格文本中兜底提取工具调用。"""
if not isinstance(raw_text, str) or not raw_text.strip():
return raw_text, None
tool_calls: List[ToolCall] = []
def _coerce_xml_parameter_value(raw_value: str) -> Any:
normalized_value = raw_value.strip()
if not normalized_value:
return ""
lowered_value = normalized_value.lower()
if lowered_value == "true":
return True
if lowered_value == "false":
return False
if lowered_value in {"null", "none"}:
return None
if normalized_value.startswith(("{", "[")):
try:
return repair_json(normalized_value, return_objects=True, logging=False)
except Exception:
return normalized_value
return normalized_value
def _parse_xml_parameters(raw_arguments: str) -> Dict[str, Any] | None:
parameters = {
match.group("name").strip(): _coerce_xml_parameter_value(match.group("value"))
for match in XML_PARAMETER_PATTERN.finditer(raw_arguments)
}
return parameters or None
def _replace_tool_call(match: re.Match[str]) -> str:
body = match.group("body")
function_match = XML_FUNCTION_CALL_PATTERN.search(body)
if function_match is None:
return match.group(0)
function_name = function_match.group("name").strip()
raw_arguments = function_match.group("arguments").strip()
arguments = _parse_xml_parameters(raw_arguments)
if arguments is None:
arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
tool_calls.append(
ToolCall(
call_id=f"xml_tool_call_{len(tool_calls) + 1}",
func_name=function_name,
args=arguments,
)
)
return ""
cleaned_text = XML_TOOL_CALL_PATTERN.sub(_replace_tool_call, raw_text).strip() or None
return cleaned_text, tool_calls or None
def _log_length_truncation(finish_reason: str | None, model_name: str | None) -> None:
"""记录因长度截断导致的告警日志。
@@ -526,6 +631,38 @@ def _log_length_truncation(finish_reason: str | None, model_name: str | None) ->
logger.info(f"模型{model_name or ''}因为超过最大 max_token 限制,可能仅输出部分内容,可视情况调整")
def _apply_xml_tool_call_fallback(
response: APIResponse,
parse_mode: ToolArgumentParseMode,
raw_response: Any,
) -> None:
"""当上游未返回标准 tool_calls 时,尝试从 XML 文本兜底解析。
这是一个暂时性处理方法,用来兼容思维链中混入工具调用的返回格式,
后续可能随着模型或上游接口的规范化而变更。
"""
if response.tool_calls:
return
reasoning_content, tool_calls = _extract_xml_tool_calls(response.reasoning_content, parse_mode, raw_response)
if reasoning_content != response.reasoning_content:
response.reasoning_content = reasoning_content
if tool_calls:
response.tool_calls = tool_calls
if not response.content and reasoning_content:
response.content = reasoning_content
response.reasoning_content = None
logger.warning("OpenAI 兼容响应未返回标准 tool_calls已从 XML 文本兜底解析工具调用")
return
cleaned_content, tool_calls = _extract_xml_tool_calls(response.content, parse_mode, raw_response)
if cleaned_content != response.content:
response.content = cleaned_content
if tool_calls:
response.tool_calls = tool_calls
logger.warning("OpenAI 兼容响应未返回标准 tool_calls已从 XML 文本兜底解析工具调用")
def _coerce_openai_argument(value: Any) -> Any | Omit:
"""将可选参数转换为 OpenAI SDK 期望的值。
@@ -561,7 +698,7 @@ def _build_api_status_message(error: APIStatusError) -> str:
message_parts.append(str(error.message))
response_text = getattr(getattr(error, "response", None), "text", None)
if response_text:
message_parts.append(str(response_text)[:300])
message_parts.append(str(response_text))
if message_parts:
return " | ".join(message_parts)
return f"上游接口返回状态码 {error.status_code}"
@@ -722,9 +859,10 @@ class _OpenAIStreamAccumulator:
response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))
response.raw_data = {"model": self.model_name} if self.model_name else None
_apply_xml_tool_call_fallback(response, self.tool_argument_parse_mode, response.raw_data)
if not response.content and not response.tool_calls:
raise EmptyResponseException()
raise EmptyResponseException(response.raw_data)
return response
@@ -808,7 +946,7 @@ def _default_normal_response_parser(
"""
choices = getattr(resp, "choices", None)
if not choices:
raise EmptyResponseException("响应解析失败choices 为空或缺失")
raise EmptyResponseException(resp, "响应解析失败choices 为空或缺失")
api_response = APIResponse()
message_part = choices[0].message
@@ -847,9 +985,10 @@ def _default_normal_response_parser(
finish_reason = getattr(resp.choices[0], "finish_reason", None)
_log_length_truncation(finish_reason, getattr(resp, "model", None))
_apply_xml_tool_call_fallback(api_response, tool_argument_parse_mode, resp)
if not api_response.content and not api_response.tool_calls:
raise EmptyResponseException()
raise EmptyResponseException(resp)
return api_response, usage_record
@@ -965,7 +1104,12 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
model_info = request.model_info
try:
messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request.message_list)
request_messages = (
list(request.message_list)
if request.tool_options
else _sanitize_messages_for_toolless_request(request.message_list)
)
messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request_messages)
tools_payload: List[ChatCompletionToolParam] | None = (
_convert_tool_options(request.tool_options) if request.tool_options else None
)

View File

@@ -58,6 +58,42 @@ def _json_friendly(value: Any) -> Any:
return str(value)
def extract_error_response_body(error: Exception) -> Any | None:
"""尽量从异常对象中提取上游返回体,便于排查模型请求失败。"""
candidate_errors = [error, getattr(error, "__cause__", None)]
for candidate in candidate_errors:
if candidate is None:
continue
response = getattr(candidate, "response", None)
if response is not None:
response_json = getattr(response, "json", None)
if callable(response_json):
try:
return _json_friendly(response_json())
except Exception:
pass
response_text = getattr(response, "text", None)
if response_text not in (None, ""):
return str(response_text)
response_content = getattr(response, "content", None)
if response_content not in (None, b"", ""):
return _json_friendly(response_content)
response_body = getattr(candidate, "body", None)
if response_body not in (None, "", b""):
return _json_friendly(response_body)
ext_info = getattr(candidate, "ext_info", None)
if ext_info is not None:
return _json_friendly(ext_info)
return None
def _sanitize_filename_component(value: str) -> str:
"""将任意字符串转换为适合文件名使用的片段。"""
normalized_value = FILENAME_SAFE_PATTERN.sub("-", value.strip())
@@ -228,6 +264,7 @@ def serialize_model_info_snapshot(model_info: ModelInfo) -> dict[str, Any]:
"model_identifier": model_info.model_identifier,
"name": model_info.name,
"temperature": model_info.temperature,
"visual": model_info.visual,
}
@@ -244,6 +281,7 @@ def deserialize_model_info_snapshot(raw_model_info: Any) -> ModelInfo:
model_identifier=str(raw_model_info.get("model_identifier") or ""),
name=str(raw_model_info.get("name") or ""),
temperature=raw_model_info.get("temperature"),
visual=bool(raw_model_info.get("visual", False)),
)
@@ -386,6 +424,10 @@ def save_failed_request_snapshot(
"snapshot_version": SNAPSHOT_VERSION,
}
response_body = extract_error_response_body(error)
if response_body is not None:
snapshot_payload["error"]["response_body"] = response_body
snapshot_payload["replay"] = {
"command": build_replay_command(snapshot_path),
"file_uri": snapshot_path.as_uri(),

View File

@@ -3,6 +3,7 @@ from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
import asyncio
import inspect
import random
import re
import time
@@ -397,8 +398,6 @@ class LLMOrchestrator:
start_time = time.time()
tool_built = self._build_tool_options(tools)
if self.request_type.startswith("maisaka_"):
logger.info(f"LLMOrchestrator[{self.request_type}] 已构建 {len(tool_built or [])} 个内部工具选项")
execution_result = await self._execute_request(
request_type=RequestType.RESPONSE,
@@ -912,7 +911,11 @@ class LLMOrchestrator:
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
message_list = []
if message_factory:
message_list = message_factory(client)
parameter_count = len(inspect.signature(message_factory).parameters)
if parameter_count >= 2:
message_list = message_factory(client, model_info)
else:
message_list = message_factory(client)
try:
request = self._build_client_request(
request_type=request_type,

View File

@@ -18,6 +18,7 @@ from src.common.message_server.server import Server, get_global_server
from src.common.remote import TelemetryHeartBeatTask
from src.config.config import config_manager, global_config
from src.manager.async_task_manager import async_task_manager
from src.maisaka.display.stage_status_board import disable_stage_status_board, enable_stage_status_board
from src.plugin_runtime.integration import get_plugin_runtime_manager
from src.prompt.prompt_manager import prompt_manager
from src.services.memory_flow_service import memory_automation_service
@@ -65,6 +66,7 @@ class MainSystem:
async def initialize(self) -> None:
"""初始化系统组件"""
enable_stage_status_board()
logger.info(t("startup.waking_up", nickname=global_config.bot.nickname))
# 其他初始化任务
@@ -169,6 +171,7 @@ async def main() -> None:
system.schedule_tasks(),
)
finally:
disable_stage_status_board()
emoji_manager.shutdown()
await memory_automation_service.shutdown()
await a_memorix_host_service.stop()

View File

@@ -10,6 +10,8 @@ from src.llm_models.payload_content.tool_option import ToolDefinitionInput
from .context import BuiltinToolRuntimeContext
from .continue_tool import get_tool_spec as get_continue_tool_spec
from .continue_tool import handle_tool as handle_continue_tool
from .finish import get_tool_spec as get_finish_tool_spec
from .finish import handle_tool as handle_finish_tool
from .no_reply import get_tool_spec as get_no_reply_tool_spec
from .no_reply import handle_tool as handle_no_reply_tool
from .query_jargon import get_tool_spec as get_query_jargon_tool_spec
@@ -22,6 +24,8 @@ from .reply import get_tool_spec as get_reply_tool_spec
from .reply import handle_tool as handle_reply_tool
from .send_emoji import get_tool_spec as get_send_emoji_tool_spec
from .send_emoji import handle_tool as handle_send_emoji_tool
from .tool_search import get_tool_spec as get_tool_search_tool_spec
from .tool_search import handle_tool as handle_tool_search_tool
from .view_complex_message import get_tool_spec as get_view_complex_message_tool_spec
from .view_complex_message import handle_tool as handle_view_complex_message_tool
from .wait import get_tool_spec as get_wait_tool_spec
@@ -44,11 +48,13 @@ def get_action_tool_specs() -> List[ToolSpec]:
"""获取 Action Loop 阶段可用的内置工具声明。"""
return [
get_finish_tool_spec(),
get_reply_tool_spec(),
get_view_complex_message_tool_spec(),
get_query_jargon_tool_spec(),
get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool)),
get_send_emoji_tool_spec(),
get_tool_search_tool_spec(),
]
@@ -63,12 +69,14 @@ def get_all_builtin_tool_specs() -> List[ToolSpec]:
return [
*get_timing_tool_specs(),
get_finish_tool_spec(),
get_reply_tool_spec(),
get_view_complex_message_tool_spec(),
get_query_jargon_tool_spec(),
get_query_memory_tool_spec(enabled=True),
get_query_person_info_tool_spec(),
get_send_emoji_tool_spec(),
get_tool_search_tool_spec(),
]
@@ -95,6 +103,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str
return {
"continue": lambda invocation, context=None: handle_continue_tool(tool_ctx, invocation, context),
"finish": lambda invocation, context=None: handle_finish_tool(tool_ctx, invocation, context),
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
@@ -106,6 +115,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str
),
"wait": lambda invocation, context=None: handle_wait_tool(tool_ctx, invocation, context),
"send_emoji": lambda invocation, context=None: handle_send_emoji_tool(tool_ctx, invocation, context),
"tool_search": lambda invocation, context=None: handle_tool_search_tool(tool_ctx, invocation, context),
"view_complex_message": lambda invocation, context=None: handle_view_complex_message_tool(
tool_ctx,
invocation,

View File

@@ -0,0 +1,34 @@
"""finish 内置工具。"""
from typing import Optional
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from .context import BuiltinToolRuntimeContext
def get_tool_spec() -> ToolSpec:
"""获取 finish 工具声明。"""
return ToolSpec(
name="finish",
brief_description="结束本轮思考,等待后续新的外部消息再继续。",
provider_name="maisaka_builtin",
provider_type="builtin",
)
async def handle_tool(
tool_ctx: BuiltinToolRuntimeContext,
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 finish 内置工具。"""
del context
tool_ctx.runtime._enter_stop_state()
return tool_ctx.build_success_result(
invocation.tool_name,
"当前对话循环已结束本轮思考,等待新的消息到来。",
metadata={"pause_execution": True},
)

View File

@@ -29,6 +29,6 @@ async def handle_tool(
tool_ctx.runtime._enter_stop_state()
return tool_ctx.build_success_result(
invocation.tool_name,
"当前对话循环已暂停,等待新消息到来。",
"当前暂时停止思考,等待新消息到来。",
metadata={"pause_execution": True},
)

View File

@@ -91,10 +91,6 @@ async def handle_tool(
f"未找到要回复的目标消息msg_id={target_message_id}",
)
logger.info(
f"{tool_ctx.runtime.log_prefix} 已触发回复工具,"
f"目标消息编号={target_message_id} 引用回复={set_quote} 最新思考={latest_thought!r}"
)
try:
replyer = replyer_manager.get_replyer(
chat_stream=tool_ctx.runtime.chat_stream,

View File

@@ -2,11 +2,11 @@
from datetime import datetime
from io import BytesIO
import math
from random import sample
from typing import Any, Dict, Optional
import asyncio
import math
from PIL import Image as PILImage
from PIL import ImageDraw, ImageFont
@@ -20,12 +20,14 @@ from src.common.logger import get_logger
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.maisaka.context_messages import (
LLMContextMessage,
ReferenceMessage,
ReferenceMessageType,
SessionBackedMessage,
)
from src.plugin_runtime.hook_payloads import serialize_prompt_messages
from .context import BuiltinToolRuntimeContext
@@ -242,34 +244,9 @@ def _build_emoji_candidate_summary(emojis: list[MaiEmoji]) -> str:
return "\n".join(summary_lines).strip()
def _build_send_emoji_prompt_preview(
*,
system_prompt: str,
requested_emotion: str,
grid_rows: int,
grid_columns: int,
sampled_emojis: list[MaiEmoji],
) -> str:
"""构建表情选择子代理的文本预览。"""
task_text = (
"[选择任务]\n"
f"requested_emotion: {requested_emotion or '未指定'}\n"
f"候选总数: {len(sampled_emojis)}\n"
f"拼图布局: {grid_rows}x{grid_columns}\n"
"请只输出 JSON。"
)
candidate_summary = _build_emoji_candidate_summary(sampled_emojis)
return (
f"[System Prompt]\n{system_prompt}\n\n"
f"{task_text}\n\n"
f"[候选表情摘要]\n{candidate_summary or '无候选表情'}"
).strip()
def _build_send_emoji_monitor_detail(
*,
prompt_text: str = "",
request_messages: Optional[list[dict[str, Any]]] = None,
reasoning_text: str = "",
output_text: str = "",
metrics: Optional[Dict[str, Any]] = None,
@@ -278,8 +255,8 @@ def _build_send_emoji_monitor_detail(
"""构建 emotion tool 统一监控详情。"""
detail: Dict[str, Any] = {}
if prompt_text.strip():
detail["prompt_text"] = prompt_text.strip()
if isinstance(request_messages, list) and request_messages:
detail["request_messages"] = request_messages
if reasoning_text.strip():
detail["reasoning_text"] = reasoning_text.strip()
if output_text.strip():
@@ -387,13 +364,16 @@ async def _select_emoji_with_sub_agent(
remaining_uses_value=1,
display_prefix="[表情包选择任务]",
)
prompt_preview = _build_send_emoji_prompt_preview(
system_prompt=system_prompt,
requested_emotion=requested_emotion,
grid_rows=grid_rows,
grid_columns=grid_columns,
sampled_emojis=sampled_emojis,
)
request_messages = [
MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build(),
]
prompt_llm_message = prompt_message.to_llm_message()
if prompt_llm_message is not None:
request_messages.append(prompt_llm_message)
candidate_llm_message = candidate_message.to_llm_message()
if candidate_llm_message is not None:
request_messages.append(candidate_llm_message)
serialized_request_messages = serialize_prompt_messages(request_messages)
selection_started_at = datetime.now()
response = await tool_ctx.runtime.run_sub_agent(
@@ -421,7 +401,7 @@ async def _select_emoji_with_sub_agent(
logger.warning(f"{tool_ctx.runtime.log_prefix} 表情包子代理结果解析失败,将回退到候选首项: {exc}")
if selection_metadata is not None:
selection_metadata["monitor_detail"] = _build_send_emoji_monitor_detail(
prompt_text=prompt_preview,
request_messages=serialized_request_messages,
output_text=response.content or "",
metrics=selection_metrics,
extra_sections=[{
@@ -435,7 +415,7 @@ async def _select_emoji_with_sub_agent(
if selection_metadata is not None:
selection_metadata["reason"] = selection.reason.strip()
selection_metadata["monitor_detail"] = _build_send_emoji_monitor_detail(
prompt_text=prompt_preview,
request_messages=serialized_request_messages,
reasoning_text=selection.reason,
output_text=response.content or "",
metrics=selection_metrics,

View File

@@ -0,0 +1,106 @@
"""tool_search 内置工具。"""
from typing import Any, Dict, List, Optional
import json
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from .context import BuiltinToolRuntimeContext
def get_tool_spec() -> ToolSpec:
"""获取 tool_search 工具声明。"""
return ToolSpec(
name="tool_search",
brief_description="在 deferred tools 列表中按名称或关键词搜索工具,并将命中的工具加入后续轮次的可用工具列表。",
detailed_description=(
"参数说明:\n"
"- queryString必填。工具名、前缀或关键词。\n"
"- limitInteger可选。最多返回多少个匹配工具默认为 5。"
),
parameters_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "要搜索的工具名、前缀或关键词。",
},
"limit": {
"type": "integer",
"description": "最多返回多少个匹配工具。",
"minimum": 1,
},
},
"required": ["query"],
},
provider_name="maisaka_builtin",
provider_type="builtin",
)
async def handle_tool(
    tool_ctx: BuiltinToolRuntimeContext,
    invocation: ToolInvocation,
    context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
    """Execute the tool_search builtin tool.

    Searches the deferred-tools list by name/keyword and promotes any hits
    into the available tool list for subsequent turns.
    """
    del context  # execution context is not used by this tool

    query_argument = invocation.arguments.get("query")
    if not isinstance(query_argument, str) or not query_argument.strip():
        return tool_ctx.build_failure_result(
            invocation.tool_name,
            "tool_search 需要提供非空的 `query` 字符串参数。",
        )

    # Coerce the optional limit; anything non-numeric falls back to 5.
    limit_argument = invocation.arguments.get("limit", 5)
    try:
        effective_limit = max(1, int(limit_argument))
    except (TypeError, ValueError):
        effective_limit = 5

    hit_specs = tool_ctx.runtime.search_deferred_tool_specs(query_argument, limit=effective_limit)
    hit_names = [spec.name for spec in hit_specs]
    fresh_names = tool_ctx.runtime.discover_deferred_tools(hit_names)

    structured_content: Dict[str, Any] = {
        "query": query_argument.strip(),
        "matched_tool_names": hit_names,
        "newly_discovered_tool_names": fresh_names,
    }

    if not hit_names:
        return tool_ctx.build_success_result(
            invocation.tool_name,
            "未找到匹配的 deferred tools请尝试更完整的工具名、前缀或其他关键词。",
            structured_content=structured_content,
            metadata={"record_display_prompt": "tool_search 未找到匹配工具。"},
        )

    summary_lines: List[str] = [
        f"已找到 {len(hit_names)} 个 deferred tools它们会在后续轮次中加入可用工具列表",
    ]
    summary_lines.extend(f"- {tool_name}" for tool_name in hit_names)
    if fresh_names:
        summary_lines.append("")
        summary_lines.append("本次新发现的工具:")
        summary_lines.extend(f"- {tool_name}" for tool_name in fresh_names)
    else:
        summary_lines.append("")
        summary_lines.append("这些工具此前已经发现过,无需重复展开。")

    return tool_ctx.build_success_result(
        invocation.tool_name,
        "\n".join(summary_lines),
        structured_content=structured_content,
        metadata={
            "matched_tool_names": hit_names,
            "newly_discovered_tool_names": fresh_names,
            "record_display_prompt": json.dumps(structured_content, ensure_ascii=False),
        },
    )

View File

@@ -12,8 +12,8 @@ def get_tool_spec() -> ToolSpec:
return ToolSpec(
name="wait",
brief_description="暂停当前对话并等待用户新的输入",
detailed_description="参数说明:\n- secondsinteger必填。等待的秒数。",
brief_description="暂停当前对话并固定等待一段时间,期间不因新消息提前恢复",
detailed_description="参数说明:\n- secondsinteger必填。等待的秒数。等待期间收到的新消息只会暂存,直到超时后再继续处理。",
parameters_schema={
"type": "object",
"properties": {
@@ -46,6 +46,6 @@ async def handle_tool(
tool_ctx.runtime._enter_wait_state(seconds=wait_seconds, tool_call_id=invocation.call_id)
return tool_ctx.build_success_result(
invocation.tool_name,
f"当前对话循环进入等待状态,最长等待 {wait_seconds} 秒。",
f"当前对话循环进入等待状态,将固定等待 {wait_seconds};期间收到的新消息不会提前打断本次等待",
metadata={"pause_execution": True},
)

View File

@@ -5,20 +5,18 @@ from datetime import datetime
from typing import Any, List, Optional, Sequence
import asyncio
import json
import random
from pydantic import BaseModel, Field as PydanticField
from rich.console import RenderableType
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.core.tooling import ToolRegistry, ToolSpec
from src.core.tooling import ToolRegistry
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
from src.plugin_runtime.hook_payloads import (
deserialize_prompt_messages,
@@ -32,9 +30,11 @@ from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistr
from src.services.llm_service import LLMServiceClient
from .builtin_tool import get_builtin_tools
from .context_messages import AssistantMessage, LLMContextMessage
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
from .history_utils import drop_orphan_tool_results
from .prompt_cli_renderer import PromptCLIVisualizer
from .display.prompt_cli_renderer import PromptCLIVisualizer
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
@dataclass(slots=True)
@@ -54,13 +54,6 @@ class ChatResponse:
prompt_section: Optional[RenderableType] = None
class ToolFilterSelection(BaseModel):
"""工具筛选响应。"""
selected_tool_names: list[str] = PydanticField(default_factory=list)
"""经过预筛后保留的候选工具名称列表。"""
logger = get_logger("maisaka_chat_loop")
@@ -217,10 +210,6 @@ class MaisakaChatLoopService:
else:
self._chat_system_prompt = chat_system_prompt
self._llm_chat = LLMServiceClient(task_name="planner", request_type="maisaka_planner")
self._tool_filter_llm = LLMServiceClient(
task_name=global_config.maisaka.tool_filter_task_name,
request_type="maisaka_tool_filter",
)
@property
def personality_prompt(self) -> str:
@@ -303,8 +292,15 @@ class MaisakaChatLoopService:
"file_tools_section": tools_section,
"group_chat_attention_block": self._build_group_chat_attention_block(),
"identity": self._personality_prompt,
"time_block": self._build_time_block(),
}
@staticmethod
def _build_time_block() -> str:
"""构建当前时间提示块。"""
return f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
def _build_group_chat_attention_block(self) -> str:
"""构建当前聊天场景下的额外注意事项块。"""
@@ -399,6 +395,7 @@ class MaisakaChatLoopService:
self,
selected_history: List[LLMContextMessage],
*,
injected_user_messages: Sequence[str] | None = None,
system_prompt: Optional[str] = None,
) -> List[Message]:
"""构造发给大模型的消息列表。
@@ -420,254 +417,49 @@ class MaisakaChatLoopService:
if llm_message is not None:
messages.append(llm_message)
normalized_injected_messages: List[Message] = []
for injected_message in injected_user_messages or []:
normalized_message = str(injected_message or "").strip()
if not normalized_message:
continue
normalized_injected_messages.append(
MessageBuilder()
.set_role(RoleType.User)
.add_text_content(normalized_message)
.build()
)
if normalized_injected_messages:
insertion_index = self._resolve_injected_user_messages_insertion_index(messages)
messages[insertion_index:insertion_index] = normalized_injected_messages
return messages
@staticmethod
def _is_builtin_tool_spec(tool_spec: ToolSpec) -> bool:
"""判断一个工具是否属于默认内置工具
def _resolve_injected_user_messages_insertion_index(messages: Sequence[Message]) -> int:
"""计算 injected meta user messages 在请求中的插入位置
Args:
tool_spec: 待判断的工具声明。
Returns:
bool: 是否为默认内置工具
规则与 deferred attachment 更接近:
- 从尾部向前寻找最近的 stopping point
- stopping point 为 assistant 消息或 tool 结果消息;
- 找到后插入到其后面;
- 若不存在 stopping point则退回到 system 消息之后
"""
return tool_spec.provider_type == "builtin" or tool_spec.provider_name == "maisaka_builtin"
for index in range(len(messages) - 1, -1, -1):
message = messages[index]
if message.role in {RoleType.Assistant, RoleType.Tool}:
return index + 1
@classmethod
def _split_builtin_and_candidate_tools(
cls,
tool_specs: List[ToolSpec],
) -> tuple[List[ToolSpec], List[ToolSpec]]:
"""拆分内置工具与可筛选工具列表。
Args:
tool_specs: 当前全部工具声明。
Returns:
tuple[List[ToolSpec], List[ToolSpec]]: `(内置工具, 可筛选工具)`。
"""
builtin_tool_specs: List[ToolSpec] = []
candidate_tool_specs: List[ToolSpec] = []
for tool_spec in tool_specs:
if cls._is_builtin_tool_spec(tool_spec):
builtin_tool_specs.append(tool_spec)
else:
candidate_tool_specs.append(tool_spec)
return builtin_tool_specs, candidate_tool_specs
@staticmethod
def _truncate_tool_filter_text(text: str, max_length: int = 180) -> str:
"""截断工具筛选阶段展示的文本。
Args:
text: 原始文本。
max_length: 最长保留字符数。
Returns:
str: 截断后的文本。
"""
normalized_text = text.strip()
if len(normalized_text) <= max_length:
return normalized_text
return f"{normalized_text[: max_length - 1]}"
def _build_tool_filter_prompt(
self,
selected_history: List[LLMContextMessage],
candidate_tool_specs: List[ToolSpec],
max_keep: int,
) -> str:
"""构造小模型工具预筛选提示词。
Args:
selected_history: 已选中的对话上下文。
candidate_tool_specs: 非内置候选工具列表。
max_keep: 最多保留的候选工具数量。
Returns:
str: 用于工具预筛的小模型提示词。
"""
history_lines: List[str] = []
for message in selected_history[-10:]:
plain_text = message.processed_plain_text.strip()
if not plain_text:
continue
history_lines.append(
f"- {message.role}: {self._truncate_tool_filter_text(plain_text, max_length=200)}"
)
if history_lines:
history_section = "\n".join(history_lines)
else:
history_section = "- 当前没有可用的对话上下文。"
tool_lines = [
f"- {tool_spec.name}: {tool_spec.brief_description.strip() or '无简要描述'}"
for tool_spec in candidate_tool_specs
]
tool_section = "\n".join(tool_lines) if tool_lines else "- 当前没有候选工具。"
return (
"你是 Maisaka 的工具预筛选器。\n"
"你的任务是在正式进入 planner 前,根据当前情景从候选工具中挑出最可能马上会用到的工具。\n"
"默认内置工具已经自动保留,不在候选列表中,你不需要再次选择它们。\n"
"你只能参考工具的简要描述,不要假设未描述的隐藏能力。\n"
f"最多保留 {max_keep} 个候选工具;如果都不合适,可以返回空数组。\n"
"请严格返回 JSON 对象,格式为:"
'{"selected_tool_names":["工具名1","工具名2"]}\n\n'
f"【最近对话】\n{history_section}\n\n"
f"【候选工具(仅简要描述)】\n{tool_section}"
)
@staticmethod
def _parse_tool_filter_response(
response_text: str,
candidate_tool_specs: List[ToolSpec],
max_keep: int,
) -> List[ToolSpec] | None:
"""解析工具预筛选响应。
Args:
response_text: 小模型返回的原始文本。
candidate_tool_specs: 非内置候选工具列表。
max_keep: 最多保留的候选工具数量。
Returns:
List[ToolSpec] | None: 成功解析时返回筛选后的工具列表;解析失败时返回 ``None``。
"""
normalized_response = response_text.strip()
if not normalized_response:
return None
selected_tool_names: List[str]
try:
selected_tool_names = ToolFilterSelection.model_validate_json(normalized_response).selected_tool_names
except Exception:
try:
parsed_payload = json.loads(normalized_response)
except json.JSONDecodeError:
return None
if isinstance(parsed_payload, dict):
raw_tool_names = parsed_payload.get("selected_tool_names", [])
elif isinstance(parsed_payload, list):
raw_tool_names = parsed_payload
else:
return None
if not isinstance(raw_tool_names, list):
return None
selected_tool_names = []
for item in raw_tool_names:
normalized_name = str(item).strip()
if normalized_name:
selected_tool_names.append(normalized_name)
candidate_map = {tool_spec.name: tool_spec for tool_spec in candidate_tool_specs}
filtered_tool_specs: List[ToolSpec] = []
seen_names: set[str] = set()
for tool_name in selected_tool_names:
normalized_name = tool_name.strip()
if not normalized_name or normalized_name in seen_names:
continue
tool_spec = candidate_map.get(normalized_name)
if tool_spec is None:
continue
seen_names.add(normalized_name)
filtered_tool_specs.append(tool_spec)
if len(filtered_tool_specs) >= max_keep:
break
return filtered_tool_specs
async def _filter_tool_specs_for_planner(
self,
selected_history: List[LLMContextMessage],
tool_specs: List[ToolSpec],
) -> List[ToolSpec]:
"""在将工具交给 planner 前进行快速预筛选。
Args:
selected_history: 已选中的对话上下文。
tool_specs: 当前全部可用工具声明。
Returns:
List[ToolSpec]: 最终交给 planner 的工具声明列表。
"""
threshold = max(1, int(global_config.maisaka.tool_filter_threshold))
max_keep = max(1, int(global_config.maisaka.tool_filter_max_keep))
if len(tool_specs) <= threshold:
return tool_specs
builtin_tool_specs, candidate_tool_specs = self._split_builtin_and_candidate_tools(tool_specs)
if not candidate_tool_specs:
return tool_specs
if len(candidate_tool_specs) <= max_keep:
return [*builtin_tool_specs, *candidate_tool_specs]
filter_prompt = self._build_tool_filter_prompt(selected_history, candidate_tool_specs, max_keep)
logger.info(
"工具预筛选开始: "
f"总工具数={len(tool_specs)} "
f"内置工具数={len(builtin_tool_specs)} "
f"候选工具数={len(candidate_tool_specs)} "
f"最多保留候选数={max_keep}"
)
try:
generation_result = await self._tool_filter_llm.generate_response(
prompt=filter_prompt,
options=LLMGenerationOptions(
temperature=0.0,
max_tokens=256,
response_format=RespFormat(
format_type=RespFormatType.JSON_SCHEMA,
schema=ToolFilterSelection,
),
),
)
except Exception as exc:
logger.warning(f"工具预筛选失败,保留全部工具。错误={exc}")
return tool_specs
filtered_candidate_tool_specs = self._parse_tool_filter_response(
generation_result.response or "",
candidate_tool_specs,
max_keep,
)
if filtered_candidate_tool_specs is None:
logger.warning(
"工具预筛选返回结果无法解析,保留全部工具。"
f" 原始返回={generation_result.response or ''!r}"
)
return tool_specs
filtered_tool_specs = [*builtin_tool_specs, *filtered_candidate_tool_specs]
if not filtered_tool_specs:
logger.warning("工具预筛选得到空结果,保留全部工具以避免主流程失去工具能力。")
return tool_specs
logger.info(
"工具预筛选完成: "
f"筛选前总数={len(tool_specs)} "
f"筛选后总数={len(filtered_tool_specs)} "
f"保留候选工具={[tool_spec.name for tool_spec in filtered_candidate_tool_specs]}"
)
return filtered_tool_specs
if messages and messages[0].role == RoleType.System:
return 1
return 0
async def chat_loop_step(
self,
chat_history: List[LLMContextMessage],
*,
injected_user_messages: Sequence[str] | None = None,
request_kind: str = "planner",
response_format: RespFormat | None = None,
tool_definitions: Sequence[ToolDefinitionInput] | None = None,
@@ -683,8 +475,14 @@ class MaisakaChatLoopService:
if not self._prompts_loaded:
await self.ensure_chat_prompt_loaded()
selected_history, selection_reason = self.select_llm_context_messages(chat_history)
built_messages = self._build_request_messages(selected_history)
selected_history, selection_reason = self.select_llm_context_messages(
chat_history,
request_kind=request_kind,
)
built_messages = self._build_request_messages(
selected_history,
injected_user_messages=injected_user_messages,
)
def message_factory(_client: BaseClient) -> List[Message]:
"""返回当前轮次已经构建好的请求消息。
@@ -704,8 +502,7 @@ class MaisakaChatLoopService:
all_tools = list(tool_definitions)
elif self._tool_registry is not None:
tool_specs = await self._tool_registry.list_tools()
filtered_tool_specs = await self._filter_tool_specs_for_planner(selected_history, tool_specs)
all_tools = [tool_spec.to_llm_definition() for tool_spec in filtered_tool_specs]
all_tools = [tool_spec.to_llm_definition() for tool_spec in tool_specs]
else:
all_tools = [*get_builtin_tools(), *self._extra_tools]
@@ -740,15 +537,9 @@ class MaisakaChatLoopService:
selection_reason=selection_reason,
image_display_mode=image_display_mode,
folded=global_config.debug.fold_maisaka_thinking,
tool_definitions=list(all_tools),
)
logger.info(
"规划器请求开始: "
f"已选上下文消息数={len(selected_history)} "
f"大模型消息数={len(built_messages)} "
f"工具数={len(all_tools)} "
f"启用打断={self._interrupt_flag is not None}"
)
generation_result = await self._llm_chat.generate_response_with_messages(
message_factory=message_factory,
options=LLMGenerationOptions(
@@ -760,15 +551,6 @@ class MaisakaChatLoopService:
),
)
prompt_stats_text = PromptCLIVisualizer.build_prompt_stats_text(
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
prompt_tokens=generation_result.prompt_tokens,
completion_tokens=generation_result.completion_tokens,
total_tokens=generation_result.total_tokens,
)
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
final_response = generation_result.response or ""
final_tool_calls = list(generation_result.tool_calls or [])
after_response_result = await self._get_runtime_manager().invoke_hook(
@@ -822,16 +604,21 @@ class MaisakaChatLoopService:
def select_llm_context_messages(
chat_history: List[LLMContextMessage],
*,
request_kind: str = "planner",
max_context_size: Optional[int] = None,
) -> tuple[List[LLMContextMessage], str]:
"""??????? LLM ???????"""
"""选择LLM上下文消息"""
filtered_history = MaisakaChatLoopService._filter_history_for_request_kind(
chat_history,
request_kind=request_kind,
)
effective_context_size = max(1, int(max_context_size or global_config.chat.max_context_size))
selected_indices: List[int] = []
counted_message_count = 0
for index in range(len(chat_history) - 1, -1, -1):
message = chat_history[index]
for index in range(len(filtered_history) - 1, -1, -1):
message = filtered_history[index]
if message.to_llm_message() is None:
continue
@@ -842,10 +629,10 @@ class MaisakaChatLoopService:
break
if not selected_indices:
return [], f"???????? {effective_context_size} ? user/assistant??? 0 ??"
return [], f"没有选择到上下文消息,实际发送 {effective_context_size} user/assistant 消息"
selected_indices.reverse()
selected_history = [chat_history[index] for index in selected_indices]
selected_history = [filtered_history[index] for index in selected_indices]
selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
selected_history, _ = drop_orphan_tool_results(selected_history)
selection_reason = (
@@ -860,45 +647,43 @@ class MaisakaChatLoopService:
)
@staticmethod
def _select_llm_context_messages(chat_history: List[LLMContextMessage]) -> tuple[List[LLMContextMessage], str]:
"""选择真正发送给 LLM 的上下文消息。
def _filter_history_for_request_kind(
selected_history: List[LLMContextMessage],
*,
request_kind: str,
) -> List[LLMContextMessage]:
"""按请求类型过滤不应暴露的历史工具链。"""
Args:
chat_history: 当前全部对话历史。
if request_kind != "planner":
return selected_history
Returns:
tuple[List[LLMContextMessage], str]: `(已选上下文, 选择说明)`。
"""
max_context_size = max(1, int(global_config.chat.max_context_size))
selected_indices: List[int] = []
counted_message_count = 0
for index in range(len(chat_history) - 1, -1, -1):
message = chat_history[index]
if message.to_llm_message() is None:
filtered_history: List[LLMContextMessage] = []
for message in selected_history:
if isinstance(message, ToolResultMessage) and message.tool_name in TIMING_GATE_TOOL_NAMES:
continue
selected_indices.append(index)
if message.count_in_context:
counted_message_count += 1
if counted_message_count >= max_context_size:
break
if isinstance(message, AssistantMessage) and message.tool_calls:
kept_tool_calls = [
tool_call
for tool_call in message.tool_calls
if tool_call.func_name not in TIMING_GATE_TOOL_NAMES
]
if not kept_tool_calls:
continue
if len(kept_tool_calls) != len(message.tool_calls):
filtered_history.append(
AssistantMessage(
content=message.content,
timestamp=message.timestamp,
tool_calls=kept_tool_calls,
source_kind=message.source_kind,
)
)
continue
if not selected_indices:
return [], f"上下文判定:最近 {max_context_size} 条 user/assistant当前 0 条)"
filtered_history.append(message)
selected_indices.reverse()
selected_history = [chat_history[index] for index in selected_indices]
selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
selected_history, _ = drop_orphan_tool_results(selected_history)
return (
selected_history,
(
f"上下文判定:最近 {max_context_size} 条 user/assistant"
f"展示并发送窗口内消息 {len(selected_history)}"
),
)
return filtered_history
@staticmethod
def _hide_early_assistant_messages(

View File

@@ -51,7 +51,9 @@ def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent)
if component.content:
builder.add_text_content(component.content)
return True
return False
builder.add_text_content("[表情包]")
return True
def _append_image_component(builder: MessageBuilder, component: ImageComponent) -> bool:
@@ -65,7 +67,9 @@ def _append_image_component(builder: MessageBuilder, component: ImageComponent)
if component.content:
builder.add_text_content(component.content)
return True
return False
builder.add_text_content("[图片]")
return True
def _append_reply_component(builder: MessageBuilder, component: ReplyComponent) -> bool:

View File

@@ -0,0 +1,33 @@
"""Maisaka 展示模块。"""
from .display_utils import (
build_tool_call_summary_lines,
format_token_count,
format_tool_call_for_display,
get_request_panel_style,
get_role_badge_label,
get_role_badge_style,
)
from .prompt_cli_renderer import PromptCLIVisualizer
from .prompt_preview_logger import PromptPreviewLogger
from .stage_status_board import (
disable_stage_status_board,
enable_stage_status_board,
remove_stage_status,
update_stage_status,
)
# Public names re-exported by the display package (kept in ASCII sort order).
__all__ = [
    "PromptCLIVisualizer",
    "PromptPreviewLogger",
    "build_tool_call_summary_lines",
    "disable_stage_status_board",
    "enable_stage_status_board",
    "format_token_count",
    "format_tool_call_for_display",
    "get_request_panel_style",
    "get_role_badge_label",
    "get_role_badge_style",
    "remove_stage_status",
    "update_stage_status",
]

View File

@@ -4,14 +4,15 @@ from typing import Any
_REQUEST_PANEL_STYLE_MAP: dict[str, tuple[str, str]] = {
"timing_gate": ("\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u5927\u6a21\u578b\u8bf7\u6c42 - Timing Gate \u5b50\u4ee3\u7406", "bright_magenta"),
"replyer": ("\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u56de\u590d\u5668 Prompt", "bright_yellow"),
"planner": ("MaiSaka 大模型请求 - 对话单步", "green"),
"timing_gate": ("MaiSaka 大模型请求 - Timing Gate 子代理", "bright_magenta"),
"replyer": ("MaiSaka 回复器 Prompt", "bright_yellow"),
"emotion": ("MaiSaka Emotion Tool Prompt", "bright_cyan"),
"sub_agent": ("\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u5927\u6a21\u578b\u8bf7\u6c42 - \u5b50\u4ee3\u7406", "bright_blue"),
"sub_agent": ("MaiSaka 大模型请求 - 子代理", "bright_blue"),
}
_DEFAULT_REQUEST_PANEL_STYLE: tuple[str, str] = (
"\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u5927\u6a21\u578b\u8bf7\u6c42 - \u5bf9\u8bdd\u5355\u6b65",
"MaiSaka 大模型请求 - 对话单步",
"cyan",
)
@@ -23,10 +24,10 @@ _ROLE_BADGE_STYLE_MAP: dict[str, str] = {
}
_ROLE_BADGE_LABEL_MAP: dict[str, str] = {
"system": "\u7cfb\u7edf",
"user": "\u7528\u6237",
"assistant": "\u52a9\u624b",
"tool": "\u5de5\u5177",
"system": "系统",
"user": "用户",
"assistant": "助手",
"tool": "工具",
}
@@ -54,7 +55,7 @@ def get_role_badge_style(role: str) -> str:
def get_role_badge_label(role: str) -> str:
"""返回角色标签对应的展示文案。"""
return _ROLE_BADGE_LABEL_MAP.get(role, "\u672a\u77e5")
return _ROLE_BADGE_LABEL_MAP.get(role, "未知")
def format_tool_call_for_display(tool_call: Any) -> dict[str, Any]:

View File

@@ -0,0 +1,58 @@
"""Maisaka Prompt 预览路径工具。"""
from __future__ import annotations
from pathlib import Path
from urllib.parse import quote
import re
from src.chat.message_receive.chat_manager import chat_manager
REPO_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
# Runs of characters outside [A-Za-z0-9._-] collapse to a single underscore.
SAFE_NAME_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")


def normalize_preview_name(value: str) -> str:
    """Sanitize *value* into a filesystem-safe name, or "unknown" when empty."""
    cleaned = str(value or "").strip()
    safe = SAFE_NAME_PATTERN.sub("_", cleaned).strip("._")
    return safe if safe else "unknown"
def normalize_platform_name(platform: str) -> str:
    """Lowercase *platform*, map known aliases, then sanitize the result."""
    lowered = str(platform or "").strip().lower()
    # Long platform names get shortened for compact directory names.
    alias = {"telegram": "tg"}.get(lowered, lowered)
    return normalize_preview_name(alias)
def build_preview_chat_dir_name(chat_id: str) -> str:
    """Derive a preview directory name for *chat_id*.

    Prefers session metadata (platform + group/user id); falls back to the
    sanitized chat id, and finally to "unknown_chat".
    """
    session = chat_manager.get_session_by_session_id(chat_id)
    if session is not None:
        platform_part = normalize_platform_name(session.platform)
        if session.is_group_session and session.group_id:
            group_part = normalize_preview_name(session.group_id)
            return f"{platform_part}_group_{group_part}"
        if session.user_id:
            user_part = normalize_preview_name(session.user_id)
            return f"{platform_part}_private_{user_part}"
    fallback_name = normalize_preview_name(chat_id)
    return fallback_name if fallback_name != "unknown" else "unknown_chat"
def build_display_path(file_path: Path) -> str:
    """Render *file_path* for display; files inside the repo show as relative."""
    absolute_path = file_path.resolve()
    try:
        relative_path = absolute_path.relative_to(REPO_ROOT)
    except ValueError:
        # Outside the repository: fall back to the absolute POSIX form.
        return absolute_path.as_posix()
    return relative_path.as_posix()
def build_file_uri(file_path: Path) -> str:
    """Build a clickable, percent-encoded file:// URI for *file_path*.

    NOTE(review): on POSIX this yields four slashes (file:////abs/path)
    because as_posix() already begins with "/" — confirm consumers accept it.
    """
    posix_form = file_path.resolve().as_posix()
    encoded = quote(posix_form, safe="/:")
    return "file:///" + encoded

View File

@@ -7,7 +7,6 @@ from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal
from urllib.parse import quote
import hashlib
import html
@@ -27,10 +26,10 @@ from .display_utils import (
get_role_badge_label as get_shared_role_badge_label,
get_role_badge_style as get_shared_role_badge_style,
)
from .preview_path_utils import build_display_path, build_file_uri, REPO_ROOT
from .prompt_preview_logger import PromptPreviewLogger
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
DATA_IMAGE_DIR = PROJECT_ROOT / "data" / "images"
DATA_IMAGE_DIR = REPO_ROOT / "data" / "images"
class PromptImageDisplayMode(str, Enum):
@@ -115,11 +114,6 @@ class PromptCLIVisualizer:
digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest()
return root / f"{digest}.{image_format}"
@staticmethod
def _build_file_uri(file_path: Path) -> str:
normalized = file_path.resolve().as_posix()
return f"file:///{quote(normalized, safe='/:')}"
@staticmethod
def _build_official_image_path(image_format: str, image_base64: str) -> Path | None:
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format)
@@ -140,7 +134,7 @@ class PromptCLIVisualizer:
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format) or "bin"
official_path = PromptCLIVisualizer._build_official_image_path(image_format, image_base64)
if official_path is not None:
return PromptCLIVisualizer._build_file_uri(official_path), official_path
return build_file_uri(official_path), official_path
try:
image_bytes = b64decode(image_base64)
@@ -153,7 +147,7 @@ class PromptCLIVisualizer:
path.write_bytes(image_bytes)
except Exception:
return None
return PromptCLIVisualizer._build_file_uri(path), path
return build_file_uri(path), path
@classmethod
def _render_image_item(cls, image_format: str, image_base64: str, settings: PromptImageDisplaySettings) -> Panel:
@@ -169,8 +163,9 @@ class PromptCLIVisualizer:
path_result = cls._build_image_file_link(image_format, image_base64)
if path_result is not None:
file_uri, file_path = path_result
display_path = build_display_path(file_path)
preview_parts: List[RenderableType] = [
Text(f"图片格式 image/{normalized_format} {size_text} 路径:{file_path}", style="magenta")
Text(f"图片格式 image/{normalized_format} {size_text} 路径:{display_path}", style="magenta")
]
preview_parts.append(Text.from_markup(f"[link={file_uri}]点击打开图片[/link]", style="cyan"))
@@ -181,6 +176,16 @@ class PromptCLIVisualizer:
padding=(0, 1),
)
@staticmethod
def _extract_image_pair(item: Any) -> tuple[str, str] | None:
"""兼容图片片段被序列化为 tuple 或 list 的两种形式。"""
if isinstance(item, (tuple, list)) and len(item) == 2:
image_format, image_base64 = item
if isinstance(image_format, str) and isinstance(image_base64, str):
return image_format, image_base64
return None
@classmethod
def _render_message_content(cls, content: Any, settings: PromptImageDisplaySettings) -> RenderableType:
if isinstance(content, str):
@@ -192,11 +197,11 @@ class PromptCLIVisualizer:
if isinstance(item, str):
parts.append(Text(item))
continue
if isinstance(item, tuple) and len(item) == 2:
image_format, image_base64 = item
if isinstance(image_format, str) and isinstance(image_base64, str):
parts.append(cls._render_image_item(image_format, image_base64, settings))
continue
image_pair = cls._extract_image_pair(item)
if image_pair is not None:
image_format, image_base64 = image_pair
parts.append(cls._render_image_item(image_format, image_base64, settings))
continue
if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
parts.append(Text(item["text"]))
else:
@@ -218,8 +223,9 @@ class PromptCLIVisualizer:
if isinstance(item, str):
parts.append(item)
continue
if isinstance(item, tuple) and len(item) == 2:
image_format, image_base64 = item
image_pair = cls._extract_image_pair(item)
if image_pair is not None:
image_format, image_base64 = image_pair
approx_size = max(0, len(str(image_base64)) * 3 // 4)
parts.append(f"[图片 image/{image_format} {approx_size} B]")
continue
@@ -242,6 +248,85 @@ class PromptCLIVisualizer:
def format_tool_call_for_display(cls, tool_call: Any) -> Dict[str, Any]:
return normalize_tool_call_for_display(tool_call)
@classmethod
def _build_tool_card_title(cls, tool_call: Any) -> str:
    """Return the collapsed-summary title for a tool card in the HTML view."""
    normalized = cls.format_tool_call_for_display(tool_call)
    name_text = str(normalized.get("name") or "").strip()
    if name_text:
        return name_text
    return "unknown"
@classmethod
def _build_tool_call_html(cls, tool_call: Any) -> str:
    """Render one tool invocation as a collapsed-by-default HTML card."""
    normalized = cls.format_tool_call_for_display(tool_call)
    card_title = cls._build_tool_card_title(tool_call)
    call_id = str(normalized.get("id") or "").strip()
    arguments_json = json.dumps(
        normalized.get("arguments"), ensure_ascii=False, indent=2, default=str
    )
    pieces: list[str] = [
        "<details class='tool-card tool-call-card'>",
        "<summary class='tool-card-summary'>",
        f"<span class='tool-card-name'>{html.escape(card_title)}</span>",
        "</summary>",
        "<div class='tool-card-body'>",
    ]
    if call_id:
        # The call id row is only rendered when an id is present.
        pieces.append(
            "<div class='tool-card-meta'>"
            "<span class='tool-card-meta-label'>调用 ID</span>"
            f"<code>{html.escape(call_id)}</code>"
            "</div>"
        )
    pieces.append(f"<pre>{html.escape(arguments_json)}</pre>")
    pieces.append("</div>")
    pieces.append("</details>")
    return "".join(pieces)
@classmethod
def _extract_tool_definition_fields(cls, tool_definition: dict[str, Any]) -> tuple[str, str, Any]:
"""提取工具定义中的名称、描述和详情内容。"""
function_info = tool_definition.get("function")
if isinstance(function_info, dict):
tool_name = str(function_info.get("name") or "").strip() or "unknown"
description = str(function_info.get("description") or "").strip()
detail_payload = function_info
else:
tool_name = str(tool_definition.get("name") or "").strip() or "unknown"
description = str(tool_definition.get("description") or "").strip()
detail_payload = tool_definition
return tool_name, description, detail_payload
@classmethod
def _build_tool_definition_html(cls, tool_definition: dict[str, Any]) -> str:
    """Render one incoming tool definition as a collapsed HTML card."""
    name, description, payload = cls._extract_tool_definition_fields(tool_definition)
    payload_json = json.dumps(payload, ensure_ascii=False, indent=2, default=str)
    description_block = ""
    if description:
        # The description row is only rendered when a description exists.
        description_block = (
            "<div class='tool-card-meta'>"
            "<span class='tool-card-meta-label'>说明</span>"
            f"<span>{html.escape(description)}</span>"
            "</div>"
        )
    pieces = [
        "<details class='tool-card tool-definition-card'>",
        "<summary class='tool-card-summary'>",
        f"<span class='tool-card-name'>{html.escape(name)}</span>",
        "</summary>",
        "<div class='tool-card-body'>",
        description_block,
        f"<pre>{html.escape(payload_json)}</pre>",
        "</div>",
        "</details>",
    ]
    return "".join(pieces)
@classmethod
def _render_tool_call_panel(cls, tool_call: Any, index: int, parent_index: int) -> Panel:
title = Text.assemble(
@@ -291,6 +376,20 @@ class PromptCLIVisualizer:
return "\n\n" + ("\n\n" + ("=" * 80) + "\n\n").join(sections) if sections else "[空 Prompt]"
@classmethod
def _build_tool_definition_dump_text(cls, tool_definitions: list[dict[str, Any]] | None) -> str:
    """Serialize the incoming tool definitions into a plain-text dump."""
    if not tool_definitions:
        return ""
    chunks = ["[tool_definitions]"]
    for position, definition in enumerate(tool_definitions, start=1):
        name, _, payload = cls._extract_tool_definition_fields(definition)
        chunks.append(f"[{position}] name={name}")
        chunks.append(json.dumps(payload, ensure_ascii=False, indent=2, default=str))
    return "\n\n".join(chunks).strip()
@classmethod
def _render_message_content_html(cls, content: Any) -> str:
if isinstance(content, str):
@@ -302,8 +401,9 @@ class PromptCLIVisualizer:
if isinstance(item, str):
parts.append(f"<pre>{html.escape(item)}</pre>")
continue
if isinstance(item, tuple) and len(item) == 2:
image_format, image_base64 = item
image_pair = cls._extract_image_pair(item)
if image_pair is not None:
image_format, image_base64 = image_pair
image_html = cls._render_image_item_html(str(image_format), str(image_base64))
parts.append(image_html)
continue
@@ -332,14 +432,44 @@ class PromptCLIVisualizer:
)
file_uri, file_path = path_result
display_path = build_display_path(file_path)
return (
"<div class='image-card'>"
f"<div class='image-meta'>图片 image/{html.escape(normalized_format)} {html.escape(size_text)}</div>"
f"<div class='image-path'>{html.escape(str(file_path))}</div>"
f"<a class='image-preview-link' href='{html.escape(file_uri, quote=True)}'>"
f"<img class='image-preview' src='{html.escape(file_uri, quote=True)}' alt='图片预览' />"
"</a>"
f"<div class='image-path'>{html.escape(display_path)}</div>"
f"<a class='image-link' href='{html.escape(file_uri, quote=True)}'>打开图片</a>"
"</div>"
)
@staticmethod
def _build_preview_access_body(
    *,
    viewer_label: str,
    viewer_path: Path,
    viewer_link_text: str,
    dump_label: str,
    dump_path: Path,
    dump_link_text: str,
) -> RenderableType:
    """Build the two-line renderable pointing at the viewer and dump files."""
    viewer_display = build_display_path(viewer_path)
    viewer_line = Text.from_markup(
        f"[bold green]{viewer_label}{viewer_display}[/bold green] "
        f"[link={build_file_uri(viewer_path)}]{viewer_link_text}[/link]"
    )
    dump_display = build_display_path(dump_path)
    dump_line = Text.from_markup(
        f"[magenta]{dump_label}{dump_display}[/magenta] "
        f"[cyan][link={build_file_uri(dump_path)}]{dump_link_text}[/link][/cyan]"
    )
    return Group(viewer_line, dump_line)
@classmethod
def _build_html_role_class(cls, role: str) -> str:
return {
@@ -356,6 +486,7 @@ class PromptCLIVisualizer:
*,
request_kind: str,
selection_reason: str,
tool_definitions: list[dict[str, Any]] | None = None,
) -> str:
panel_title, _ = cls.get_request_panel_style(request_kind)
message_cards: List[str] = []
@@ -378,16 +509,12 @@ class PromptCLIVisualizer:
tool_panels = ""
raw_tool_calls = message.get("tool_calls") or []
if isinstance(raw_tool_calls, list) and raw_tool_calls:
tool_items = []
for tool_call_index, tool_call in enumerate(raw_tool_calls, start=1):
normalized_tool_call = cls.format_tool_call_for_display(tool_call)
tool_items.append(
"<div class='tool-panel'>"
f"<div class='tool-panel-title'>工具调用 #{index}.{tool_call_index}</div>"
f"<pre>{html.escape(json.dumps(normalized_tool_call, ensure_ascii=False, indent=2, default=str))}</pre>"
"</div>"
)
tool_panels = "".join(tool_items)
tool_panels = (
"<div class='tool-list'>"
"<div class='tool-list-title'>工具调用</div>"
f"{''.join(cls._build_tool_call_html(tool_call) for tool_call in raw_tool_calls)}"
"</div>"
)
message_cards.append(
"<section class='message-card'>"
@@ -405,6 +532,21 @@ class PromptCLIVisualizer:
if selection_reason.strip():
subtitle_html = f"<div class='subtitle'>{html.escape(selection_reason)}</div>"
tool_definition_section_html = ""
if tool_definitions:
tool_definition_section_html = (
"<section class='message-card tool-definition-section'>"
"<div class='message-head'>"
"<span class='role-badge tool'>全部工具</span>"
f"<span class='message-index'>{len(tool_definitions)} 个</span>"
"</div>"
"<div class='tool-list'>"
"<div class='tool-list-title'>本次送入模型的工具定义</div>"
f"{''.join(cls._build_tool_definition_html(tool_definition) for tool_definition in tool_definitions)}"
"</div>"
"</section>"
)
return f"""<!DOCTYPE html>
<html lang="zh-CN">
<head>
@@ -491,7 +633,7 @@ class PromptCLIVisualizer:
font-weight: 600;
}}
.message-content pre,
.tool-panel pre {{
.tool-card pre {{
margin: 0;
white-space: pre-wrap;
word-break: break-word;
@@ -517,18 +659,81 @@ class PromptCLIVisualizer:
border-radius: 8px;
padding: 3px 8px;
}}
.tool-panel {{
.tool-list {{
margin-top: 14px;
}}
.tool-list-title {{
color: #86198f;
font-size: 13px;
font-weight: 800;
margin-bottom: 10px;
}}
.tool-card {{
margin-top: 12px;
background: #fcf4ff;
border: 1px solid #f0d7fb;
border-radius: 14px;
padding: 12px 14px;
overflow: hidden;
}}
.tool-panel-title {{
color: #a21caf;
.tool-call-card {{
border-color: #ff8700;
}}
.tool-card:first-of-type {{
margin-top: 0;
}}
.tool-card-summary {{
list-style: none;
cursor: pointer;
display: flex;
align-items: center;
justify-content: space-between;
padding: 12px 14px;
color: #86198f;
font-size: 13px;
font-weight: 800;
}}
.tool-card-summary::-webkit-details-marker {{
display: none;
}}
.tool-card-summary::after {{
content: "展开";
color: #a21caf;
font-size: 12px;
font-weight: 700;
margin-bottom: 8px;
}}
.tool-card[open] .tool-card-summary::after {{
content: "收起";
}}
.tool-card-name {{
word-break: break-word;
}}
.tool-card-body {{
border-top: 1px solid #f0d7fb;
padding: 12px 14px;
background: rgba(255, 255, 255, 0.52);
}}
.tool-call-card .tool-card-body {{
border-top-color: #ff8700;
}}
.tool-card-meta {{
margin-bottom: 10px;
color: #a21caf;
display: flex;
gap: 10px;
align-items: center;
flex-wrap: wrap;
}}
.tool-card-meta-label {{
font-weight: 700;
}}
.tool-card-meta code {{
background: #faf5ff;
border: 1px solid #e9d5ff;
border-radius: 8px;
padding: 3px 8px;
}}
.tool-card pre {{
color: #3b0764;
}}
.image-card {{
background: #f8fafc;
@@ -547,6 +752,22 @@ class PromptCLIVisualizer:
font-family: "Cascadia Mono", "JetBrains Mono", "Consolas", monospace;
word-break: break-all;
}}
.image-preview-link {{
display: block;
margin-top: 10px;
}}
.image-preview {{
display: block;
max-width: min(100%, 560px);
max-height: 420px;
width: auto;
height: auto;
border-radius: 12px;
border: 1px solid #dbe4f0;
background: #fff;
box-shadow: 0 8px 20px rgba(15, 23, 42, 0.08);
object-fit: contain;
}}
.image-link {{
display: inline-block;
margin-top: 8px;
@@ -564,6 +785,7 @@ class PromptCLIVisualizer:
{subtitle_html}
</header>
{''.join(message_cards)}
{tool_definition_section_html}
</main>
</body>
</html>"""
@@ -578,6 +800,7 @@ class PromptCLIVisualizer:
request_kind: str,
selection_reason: str,
image_display_mode: Literal["legacy", "path_link"],
tool_definitions: list[dict[str, Any]] | None = None,
) -> RenderableType:
"""构建用于查看完整 prompt 的折叠入口内容。"""
@@ -603,10 +826,14 @@ class PromptCLIVisualizer:
viewer_messages.append(normalized_message)
prompt_dump_text = cls._build_prompt_dump_text(messages)
tool_definition_dump_text = cls._build_tool_definition_dump_text(tool_definitions)
if tool_definition_dump_text:
prompt_dump_text = f"{prompt_dump_text}\n\n{'=' * 80}\n\n{tool_definition_dump_text}"
viewer_html_text = cls._build_prompt_viewer_html(
viewer_messages,
request_kind=request_kind,
selection_reason=selection_reason,
tool_definitions=tool_definitions,
)
saved_paths = PromptPreviewLogger.save_preview_files(
chat_id,
@@ -618,18 +845,13 @@ class PromptCLIVisualizer:
)
viewer_html_path = saved_paths[".html"]
prompt_dump_path = saved_paths[".txt"]
viewer_uri = cls._build_file_uri(viewer_html_path)
dump_uri = cls._build_file_uri(prompt_dump_path)
body = Group(
Text.from_markup(
f"[bold green]富文本预览:{viewer_html_path}[/bold green] "
f"[link={viewer_uri}]点击在浏览器打开富文本 Prompt 视图[/link]"
),
Text.from_markup(
f"[magenta]原始文本备份:{prompt_dump_path}[/magenta] "
f"[cyan][link={dump_uri}]点击直接打开 Prompt 文本[/link][/cyan]"
),
body = cls._build_preview_access_body(
viewer_label="html预览",
viewer_path=viewer_html_path,
viewer_link_text="在浏览器打开 Prompt",
dump_label="原始文本",
dump_path=prompt_dump_path,
dump_link_text="点击打开 Prompt 文本",
)
return body
@@ -644,6 +866,7 @@ class PromptCLIVisualizer:
selection_reason: str,
image_display_mode: Literal["legacy", "path_link"],
folded: bool,
tool_definitions: list[dict[str, Any]] | None = None,
) -> Panel:
"""构建用于嵌入结果面板中的 Prompt 区块。"""
@@ -656,6 +879,7 @@ class PromptCLIVisualizer:
request_kind=request_kind,
selection_reason=selection_reason,
image_display_mode=image_display_mode,
tool_definitions=tool_definitions,
)
else:
ordered_panels = cls.build_prompt_panels(
@@ -782,18 +1006,13 @@ class PromptCLIVisualizer:
)
viewer_html_path = saved_paths[".html"]
text_dump_path = saved_paths[".txt"]
viewer_uri = cls._build_file_uri(viewer_html_path)
dump_uri = cls._build_file_uri(text_dump_path)
body = Group(
Text.from_markup(
f"[bold green]富文本预览:{viewer_html_path}[/bold green] "
f"[link={viewer_uri}]点击在浏览器打开富文本 Prompt 视图[/link]"
),
Text.from_markup(
f"[magenta]原始文本备份:{text_dump_path}[/magenta] "
f"[cyan][link={dump_uri}]点击直接打开 Prompt 文本[/link][/cyan]"
),
body = cls._build_preview_access_body(
viewer_label="富文本预览",
viewer_path=viewer_html_path,
viewer_link_text="点击在浏览器打开富文本 Prompt 视图",
dump_label="原始文本备份",
dump_path=text_dump_path,
dump_link_text="点击直接打开 Prompt 文本",
)
return body

View File

@@ -2,34 +2,29 @@
from __future__ import annotations
import re
import time
from pathlib import Path
from typing import Dict
from uuid import uuid4
from src.config.config import global_config
from .preview_path_utils import build_preview_chat_dir_name, normalize_preview_name
class PromptPreviewLogger:
"""负责保存 Maisaka Prompt 预览文件并控制目录容量。"""
_BASE_DIR = Path("logs") / "maisaka_prompt"
_MAX_PREVIEW_GROUPS_PER_CHAT = 1024
_TRIM_COUNT = 100
_SAFE_NAME_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
@classmethod
def _get_max_per_chat(cls) -> int:
"""从配置中获取每个聊天流最大保存的预览数量。"""
return getattr(global_config.chat, "plan_reply_log_max_per_chat", 1000)
@classmethod
def _normalize_chat_id(cls, chat_id: str) -> str:
normalized_chat_id = cls._SAFE_NAME_PATTERN.sub("_", str(chat_id or "").strip()).strip("._")
if normalized_chat_id:
return normalized_chat_id
return "unknown_chat"
def _build_file_stem(cls, chat_dir: Path) -> str:
base_stem = str(int(time.time() * 1000))
candidate_stem = base_stem
suffix_index = 1
while any((chat_dir / f"{candidate_stem}{suffix}").exists() for suffix in (".html", ".txt")):
candidate_stem = f"{base_stem}_{suffix_index}"
suffix_index += 1
return candidate_stem
@classmethod
def save_preview_files(
@@ -40,10 +35,10 @@ class PromptPreviewLogger:
) -> Dict[str, Path]:
"""保存同一份 Prompt 预览的多个文件并执行超量清理。"""
normalized_category = cls._normalize_chat_id(category)
chat_dir = (cls._BASE_DIR / normalized_category / cls._normalize_chat_id(chat_id)).resolve()
normalized_category = normalize_preview_name(category)
chat_dir = (cls._BASE_DIR / normalized_category / build_preview_chat_dir_name(chat_id)).resolve()
chat_dir.mkdir(parents=True, exist_ok=True)
stem = f"{int(time.time() * 1000)}_{uuid4().hex[:8]}"
stem = cls._build_file_stem(chat_dir)
saved_paths: Dict[str, Path] = {}
try:
for suffix, content in files.items():
@@ -65,15 +60,14 @@ class PromptPreviewLogger:
continue
grouped_files.setdefault(file_path.stem, []).append(file_path)
max_per_chat = cls._get_max_per_chat()
if len(grouped_files) <= max_per_chat:
if len(grouped_files) <= cls._MAX_PREVIEW_GROUPS_PER_CHAT:
return
sorted_groups = sorted(
grouped_files.items(),
key=lambda item: min(path.stat().st_mtime for path in item[1]),
)
overflow_count = len(grouped_files) - max_per_chat
overflow_count = len(grouped_files) - cls._MAX_PREVIEW_GROUPS_PER_CHAT
trim_count = min(len(sorted_groups), max(cls._TRIM_COUNT, overflow_count))
for _, file_group in sorted_groups[:trim_count]:
for old_file in file_group:

View File

@@ -0,0 +1,163 @@
"""Maisaka 阶段状态看板。"""
from __future__ import annotations
from pathlib import Path
from typing import Any, Optional
import json
import os
import subprocess
import sys
import threading
import time
class MaisakaStageStatusBoard:
    """Maintain per-session Maisaka stage status and display it in a separate terminal.

    State is persisted to a JSON file under ``temp/`` using an atomic
    write-then-replace; on Windows a separate console process
    (``stage_status_viewer.py``) is spawned to render that file.
    All mutation of the fields happens while holding ``self._lock``.
    """
    def __init__(self) -> None:
        self._lock = threading.Lock()  # guards every field below
        self._enabled = False
        # session_id -> latest status payload (schema written in update())
        self._entries: dict[str, dict[str, Any]] = {}
        self._viewer_process: Optional[subprocess.Popen[Any]] = None
        self._state_file = Path("temp") / "maisaka_stage_status.json"
        # NOTE(review): creates the temp/ directory eagerly at construction time,
        # i.e. at module import (see the module-level singleton below the class).
        self._state_file.parent.mkdir(parents=True, exist_ok=True)
    def enable(self) -> None:
        """Enable the stage status board (idempotent) and start the viewer process."""
        with self._lock:
            if self._enabled:
                return
            self._enabled = True
            self._write_state_locked()
            self._ensure_viewer_process_locked()
    def disable(self) -> None:
        """Disable the board, clear all entries, and terminate the viewer process."""
        with self._lock:
            self._enabled = False
            self._entries.clear()
            # Persist the disabled flag so a running viewer exits on its own poll.
            self._write_state_locked()
            process = self._viewer_process
            self._viewer_process = None
            if process is not None and process.poll() is None:
                try:
                    process.terminate()
                except Exception:
                    pass  # best-effort: the viewer also exits when it sees enabled=False
    def update(
        self,
        *,
        session_id: str,
        session_name: str,
        stage: str,
        detail: str = "",
        round_text: str = "",
        agent_state: str = "",
    ) -> None:
        """Update one session's stage status; no-op while the board is disabled."""
        with self._lock:
            if not self._enabled:
                return
            now = time.time()
            current = self._entries.get(session_id, {})
            previous_stage = str(current.get("stage") or "").strip()
            # Keep the original start time while the stage is unchanged so the
            # viewer can show elapsed time per stage; reset it on stage change.
            stage_started_at = float(current.get("stage_started_at") or now)
            if previous_stage != stage:
                stage_started_at = now
            self._entries[session_id] = {
                "session_id": session_id,
                "session_name": session_name,
                "stage": stage,
                "detail": detail,
                "round_text": round_text,
                "agent_state": agent_state,
                "stage_started_at": stage_started_at,
                "updated_at": now,
            }
            self._write_state_locked()
    def remove(self, session_id: str) -> None:
        """Remove one session's stage status; no-op while the board is disabled."""
        with self._lock:
            if not self._enabled:
                return
            self._entries.pop(session_id, None)
            self._write_state_locked()
    def _write_state_locked(self) -> None:
        # Caller must hold self._lock. Atomic write: dump to a .tmp sibling,
        # then replace, so the viewer never reads a half-written file.
        payload = {
            "enabled": self._enabled,
            "host_pid": os.getpid(),
            "updated_at": time.time(),
            "entries": list(self._entries.values()),
        }
        tmp_file = self._state_file.with_suffix(".tmp")
        tmp_file.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        tmp_file.replace(self._state_file)
    def _ensure_viewer_process_locked(self) -> None:
        # Caller must hold self._lock. Spawns the viewer only on Windows, where
        # CREATE_NEW_CONSOLE opens it in its own console window.
        if not sys.platform.startswith("win"):
            return
        if self._viewer_process is not None and self._viewer_process.poll() is None:
            return  # viewer already running
        creationflags = getattr(subprocess, "CREATE_NEW_CONSOLE", 0)
        viewer_script = Path(__file__).resolve().with_name("stage_status_viewer.py")
        self._viewer_process = subprocess.Popen(
            [
                sys.executable,
                str(viewer_script),
                str(self._state_file.resolve()),
            ],
            creationflags=creationflags,
            cwd=str(Path.cwd()),
        )
# Module-level singleton shared by the helper functions below.
_stage_board = MaisakaStageStatusBoard()
def enable_stage_status_board() -> None:
    """Enable the console stage status board (delegates to the module singleton)."""
    _stage_board.enable()
def disable_stage_status_board() -> None:
    """Disable the console stage status board (delegates to the module singleton)."""
    _stage_board.disable()
def update_stage_status(
    *,
    session_id: str,
    session_name: str,
    stage: str,
    detail: str = "",
    round_text: str = "",
    agent_state: str = "",
) -> None:
    """Update the console stage status for one session.

    Thin delegation to the module singleton; a no-op while the board is disabled.
    """
    _stage_board.update(
        session_id=session_id,
        session_name=session_name,
        stage=stage,
        detail=detail,
        round_text=round_text,
        agent_state=agent_state,
    )
def remove_stage_status(session_id: str) -> None:
    """Remove one session from the console stage status board."""
    _stage_board.remove(session_id)

View File

@@ -0,0 +1,93 @@
"""Maisaka 阶段状态看板查看器。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import json
import os
import sys
import time
import traceback
def _clear_screen() -> None:
    """Clear the terminal using the platform-appropriate shell command."""
    if sys.platform.startswith("win"):
        command = "cls"
    else:
        command = "clear"
    os.system(command)
def _load_state(state_file: Path) -> dict[str, Any]:
if not state_file.exists():
return {}
try:
return json.loads(state_file.read_text(encoding="utf-8"))
except Exception:
return {}
def _render(state: dict[str, Any]) -> str:
    """Format the shared state dict as a plain-text dashboard string."""
    raw_entries = state.get("entries")
    output = ["Maisaka 阶段看板", "=" * 72, ""]
    # No usable entry list at all -> idle banner.
    if not isinstance(raw_entries, list) or not raw_entries:
        output.append("当前没有活跃会话。")
        return "\n".join(output)

    def sort_key(entry: dict[str, Any]) -> str:
        return str(entry.get("session_name") or entry.get("session_id") or "")

    # Keep only dict-shaped entries, ordered by display name.
    visible = sorted((entry for entry in raw_entries if isinstance(entry, dict)), key=sort_key)
    now = time.time()
    for entry in visible:
        name = str(entry.get("session_name") or entry.get("session_id") or "").strip() or "unknown"
        ident = str(entry.get("session_id") or "").strip()
        stage_label = str(entry.get("stage") or "").strip() or "未知"
        detail_label = str(entry.get("detail") or "").strip() or "-"
        round_label = str(entry.get("round_text") or "").strip()
        state_label = str(entry.get("agent_state") or "").strip() or "-"
        started = float(entry.get("stage_started_at") or now)
        elapsed_seconds = max(0.0, now - started)

        block = [f"Chat: {name}"]
        if ident and ident != name:
            block.append(f"ID: {ident}")
        block.append(f"阶段: {stage_label}")
        if round_label:
            block.append(f"轮次: {round_label}")
        block.extend(
            [
                f"详情: {detail_label}",
                f"状态: {state_label}",
                f"阶段耗时: {elapsed_seconds:.1f}s",
                "-" * 72,
            ]
        )
        output.extend(block)
    return "\n".join(output)
def main() -> int:
    """Poll the state file and redraw the dashboard until the board is disabled.

    argv[1] is the path of the JSON state file written by the host process.
    Returns 1 on bad usage, 0 once the host marks the board disabled.
    """
    if len(sys.argv) < 2:
        return 1
    state_file = Path(sys.argv[1]).resolve()
    # Crash log lives next to the state file; overwritten on each failure.
    log_file = state_file.with_name("maisaka_stage_status_viewer.log")
    last_render = ""
    while True:
        try:
            state = _load_state(state_file)
            # Host cleared the enabled flag (or the file vanished) -> exit cleanly.
            if not state.get("enabled", False):
                return 0
            rendered = _render(state)
            # Only clear/redraw when the content actually changed, to avoid flicker.
            if rendered != last_render:
                _clear_screen()
                print(rendered, flush=True)
                last_render = rendered
            time.sleep(0.5)
        except Exception:
            # Best-effort crash note, then back off before retrying.
            log_file.write_text(traceback.format_exc(), encoding="utf-8")
            time.sleep(3)
    # NOTE(review): unreachable — the loop above only exits via return.
    return 1
# Script entry point: propagate main()'s result as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -0,0 +1,125 @@
"""Maisaka 历史消息轮次结束后处理。"""
from dataclasses import dataclass
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
from .history_utils import drop_leading_orphan_tool_results, drop_orphan_tool_results
# Tool names whose assistant/tool-result chains count as gating/termination records.
TIMING_HISTORY_TOOL_NAMES = {"continue", "finish", "no_reply", "wait"}
# Fraction of the earliest matching records trimmed on each post-cycle pass.
EARLY_TRIM_RATIO = 0.2
@dataclass(slots=True)
class HistoryPostProcessResult:
    """Result of the post-cycle history trim/cleanup pass."""
    # History after all trimming and cleanup steps.
    history: list[LLMContextMessage]
    # Total number of messages removed across every cleanup step.
    removed_count: int
    # Messages in the processed history that still count toward the context budget.
    remaining_context_count: int
def process_chat_history_after_cycle(
    chat_history: list[LLMContextMessage],
    *,
    max_context_size: int,
) -> HistoryPostProcessResult:
    """Run the unified history trim/cleanup after each cycle.

    Works on a copy of *chat_history*: trims early timing-tool chains and early
    assistant thoughts, drops orphan tool results, then evicts from the front
    until the number of context-counted messages fits *max_context_size*.
    """
    working = list(chat_history)
    trimmed_timing = _remove_early_timing_tool_records(working)
    trimmed_thoughts = _remove_early_assistant_thoughts(working)
    working, trimmed_orphans = drop_orphan_tool_results(working)

    def count_context(messages: list[LLMContextMessage]) -> int:
        return sum(1 for msg in messages if msg.count_in_context)

    in_context = count_context(working)
    trimmed_overflow = 0
    # Evict oldest messages until the context-counted total fits the budget;
    # after each eviction, also drop any tool results left without their call.
    while in_context > max_context_size and working:
        evicted = working.pop(0)
        trimmed_overflow += 1
        if evicted.count_in_context:
            in_context -= 1
        working, leading_dropped = drop_leading_orphan_tool_results(working)
        trimmed_overflow += leading_dropped
        in_context = count_context(working)

    return HistoryPostProcessResult(
        history=working,
        removed_count=trimmed_timing + trimmed_thoughts + trimmed_orphans + trimmed_overflow,
        remaining_context_count=in_context,
    )
def _remove_early_timing_tool_records(chat_history: list[LLMContextMessage]) -> int:
    """Drop the earliest 20% of gating/termination tool-call chains, in place.

    Removes the selected assistant messages together with the tool results
    that reference their call ids; returns the number of messages removed.
    """
    timing_indexes = [
        position
        for position, message in enumerate(chat_history)
        if _is_timing_tool_assistant_message(message)
    ]
    trim_count = int(len(timing_indexes) * EARLY_TRIM_RATIO)
    if trim_count <= 0:
        return 0
    indexes_to_drop = set(timing_indexes[:trim_count])
    call_ids_to_drop: set = set()
    for position in indexes_to_drop:
        for tool_call in chat_history[position].tool_calls:
            if tool_call.call_id:
                call_ids_to_drop.add(tool_call.call_id)
    kept: list[LLMContextMessage] = []
    dropped = 0
    for position, message in enumerate(chat_history):
        drop_as_assistant = position in indexes_to_drop
        drop_as_result = (
            isinstance(message, ToolResultMessage) and message.tool_call_id in call_ids_to_drop
        )
        if drop_as_assistant or drop_as_result:
            dropped += 1
            continue
        kept.append(message)
    chat_history[:] = kept
    return dropped
def _remove_early_assistant_thoughts(chat_history: list[LLMContextMessage]) -> int:
    """Drop the earliest 20% of plain (non-tool) assistant thoughts, in place.

    A "thought" is an AssistantMessage with no tool calls, non-perception
    source, and non-blank content. Returns the number of messages removed.
    """

    def is_plain_thought(message: LLMContextMessage) -> bool:
        return (
            isinstance(message, AssistantMessage)
            and not message.tool_calls
            and message.source_kind != "perception"
            and bool(message.content.strip())
        )

    thought_indexes = [
        position for position, message in enumerate(chat_history) if is_plain_thought(message)
    ]
    trim_count = int(len(thought_indexes) * EARLY_TRIM_RATIO)
    if trim_count <= 0:
        return 0
    indexes_to_drop = set(thought_indexes[:trim_count])
    kept = [message for position, message in enumerate(chat_history) if position not in indexes_to_drop]
    dropped = len(chat_history) - len(kept)
    chat_history[:] = kept
    return dropped
def _is_timing_tool_assistant_message(message: LLMContextMessage) -> bool:
    """Return True for an assistant message whose every tool call is a timing tool."""
    if not isinstance(message, AssistantMessage):
        return False
    if not message.tool_calls:
        return False
    return all(call.func_name in TIMING_HISTORY_TOOL_NAMES for call in message.tool_calls)

View File

@@ -14,7 +14,7 @@ from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import global_config
from src.config.config import config_manager, global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.exceptions import ReqAbortException
from src.llm_models.payload_content.tool_option import ToolCall
@@ -35,7 +35,8 @@ from .context_messages import (
ToolResultMessage,
contains_complex_message,
)
from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text, drop_leading_orphan_tool_results
from .history_post_processor import process_chat_history_after_cycle
from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text
from .monitor_events import (
emit_cycle_start,
emit_message_ingested,
@@ -53,7 +54,7 @@ logger = get_logger("maisaka_reasoning_engine")
TIMING_GATE_CONTEXT_LIMIT = 24
TIMING_GATE_MAX_TOKENS = 384
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply", "wait"}
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply"}
ACTION_BUILTIN_TOOL_NAMES = {tool_spec.name for tool_spec in get_action_tool_specs()}
@@ -94,6 +95,7 @@ class MaisakaReasoningEngine:
async def _run_interruptible_planner(
self,
*,
injected_user_messages: Optional[list[str]] = None,
tool_definitions: Optional[list[dict[str, Any]]] = None,
) -> Any:
"""运行一轮可被新消息打断的主 planner 请求。"""
@@ -105,6 +107,7 @@ class MaisakaReasoningEngine:
try:
return await self._runtime._chat_loop_service.chat_loop_step(
self._runtime._chat_history,
injected_user_messages=injected_user_messages,
tool_definitions=tool_definitions,
)
except ReqAbortException:
@@ -117,36 +120,27 @@ class MaisakaReasoningEngine:
)
self._runtime._chat_loop_service.set_interrupt_flag(None)
async def _run_interruptible_sub_agent(
async def _run_timing_gate_sub_agent(
self,
*,
context_message_limit: int,
system_prompt: str,
tool_definitions: list[dict[str, Any]],
) -> Any:
"""运行一轮可被新消息打断的临时子代理请求。"""
"""运行一轮 Timing Gate 子代理请求。
interrupt_flag = asyncio.Event()
interrupted = False
self._runtime._bind_planner_interrupt_flag(interrupt_flag)
try:
return await self._runtime.run_sub_agent(
context_message_limit=context_message_limit,
system_prompt=system_prompt,
request_kind="timing_gate",
interrupt_flag=interrupt_flag,
max_tokens=TIMING_GATE_MAX_TOKENS,
temperature=0.1,
tool_definitions=tool_definitions,
)
except ReqAbortException:
interrupted = True
raise
finally:
self._runtime._unbind_planner_interrupt_flag(
interrupt_flag,
interrupted=interrupted,
)
Timing Gate 阶段不再响应新的 planner 打断,只有主 planner 阶段允许被打断。
"""
return await self._runtime.run_sub_agent(
context_message_limit=context_message_limit,
system_prompt=system_prompt,
request_kind="timing_gate",
interrupt_flag=None,
max_tokens=TIMING_GATE_MAX_TOKENS,
temperature=0.1,
tool_definitions=tool_definitions,
)
@staticmethod
def _build_timing_gate_fallback_prompt() -> str:
@@ -174,22 +168,34 @@ class MaisakaReasoningEngine:
except Exception:
return self._build_timing_gate_fallback_prompt()
async def _build_action_tool_definitions(self) -> list[dict[str, Any]]:
"""构造 Action Loop 阶段可见的工具定义。"""
async def _build_action_tool_definitions(self) -> tuple[list[dict[str, Any]], str]:
"""构造 Action Loop 阶段可见的工具定义与 deferred tools 提示"""
if self._runtime._tool_registry is None:
return []
self._runtime.update_deferred_tool_specs([])
self._runtime.set_current_action_tool_names([])
return [], ""
tool_specs = await self._runtime._tool_registry.list_tools()
return [
tool_spec.to_llm_definition()
for tool_spec in tool_specs
if tool_spec.name not in ACTION_HIDDEN_TOOL_NAMES
and (
tool_spec.provider_name != "maisaka_builtin"
or tool_spec.name in ACTION_BUILTIN_TOOL_NAMES
)
]
visible_builtin_tool_specs: list[ToolSpec] = []
deferred_tool_specs: list[ToolSpec] = []
for tool_spec in tool_specs:
if tool_spec.name in ACTION_HIDDEN_TOOL_NAMES:
continue
if tool_spec.provider_name == "maisaka_builtin":
if tool_spec.name in ACTION_BUILTIN_TOOL_NAMES:
visible_builtin_tool_specs.append(tool_spec)
continue
deferred_tool_specs.append(tool_spec)
self._runtime.update_deferred_tool_specs(deferred_tool_specs)
discovered_deferred_tool_specs = self._runtime.get_discovered_deferred_tool_specs()
visible_tool_specs = [*visible_builtin_tool_specs, *discovered_deferred_tool_specs]
self._runtime.set_current_action_tool_names([tool_spec.name for tool_spec in visible_tool_specs])
return (
[tool_spec.to_llm_definition() for tool_spec in visible_tool_specs],
self._runtime.build_deferred_tools_reminder(),
)
async def _invoke_tool_call(
self,
@@ -227,18 +233,19 @@ class MaisakaReasoningEngine:
async def _run_timing_gate(
self,
anchor_message: SessionMessage,
) -> tuple[Literal["continue", "no_reply", "wait"], Any, list[str]]:
) -> tuple[Literal["continue", "no_reply", "wait"], Any, list[str], list[dict[str, Any]]]:
"""运行 Timing Gate 子代理并返回控制决策。"""
if self._runtime._force_continue_until_reply:
if self._runtime._force_next_timing_continue:
return self._build_forced_continue_timing_result()
response = await self._run_interruptible_sub_agent(
response = await self._run_timing_gate_sub_agent(
context_message_limit=TIMING_GATE_CONTEXT_LIMIT,
system_prompt=self._build_timing_gate_system_prompt(),
tool_definitions=get_timing_tools(),
)
tool_result_summaries: list[str] = []
tool_monitor_results: list[dict[str, Any]] = []
selected_tool_call: Optional[ToolCall] = None
for tool_call in response.tool_calls:
if tool_call.func_name in TIMING_GATE_TOOL_NAMES:
@@ -247,11 +254,11 @@ class MaisakaReasoningEngine:
if selected_tool_call is None:
logger.warning(f"{self._runtime.log_prefix} Timing Gate 未返回有效控制工具,默认继续执行 Action Loop")
return "continue", response, tool_result_summaries
return "continue", response, tool_result_summaries, tool_monitor_results
append_history = selected_tool_call.func_name != "continue"
append_history = False
store_record = selected_tool_call.func_name != "continue"
_, result, _ = await self._invoke_tool_call(
invocation, result, tool_spec = await self._invoke_tool_call(
selected_tool_call,
response.content or "",
anchor_message,
@@ -259,19 +266,31 @@ class MaisakaReasoningEngine:
store_record=store_record,
)
tool_result_summaries.append(self._build_tool_result_summary(selected_tool_call, result))
tool_monitor_results.append(
self._build_tool_monitor_result(
selected_tool_call,
invocation,
result,
duration_ms=0.0,
tool_spec=tool_spec,
)
)
self._append_timing_gate_execution_result(response, selected_tool_call, result)
timing_action = str(result.metadata.get("timing_action") or selected_tool_call.func_name).strip()
if timing_action not in TIMING_GATE_TOOL_NAMES:
logger.warning(
f"{self._runtime.log_prefix} Timing Gate 返回未知动作 {timing_action!r},将按 continue 处理"
)
return "continue", response, tool_result_summaries
return timing_action, response, tool_result_summaries
return "continue", response, tool_result_summaries, tool_monitor_results
return timing_action, response, tool_result_summaries, tool_monitor_results
def _build_forced_continue_timing_result(self) -> tuple[Literal["continue"], ChatResponse, list[str]]:
def _build_forced_continue_timing_result(
self,
) -> tuple[Literal["continue"], ChatResponse, list[str], list[dict[str, Any]]]:
"""构造跳过 Timing Gate 时使用的伪 continue 结果。"""
reason = self._runtime._build_force_continue_timing_reason()
reason = self._runtime._consume_force_next_timing_continue_reason() or "本轮直接跳过 Timing Gate 并视作 continue。"
logger.info(f"{self._runtime.log_prefix} {reason}")
return (
"continue",
@@ -296,8 +315,24 @@ class MaisakaReasoningEngine:
prompt_section=None,
),
[f"- continue [强制跳过]: {reason}"],
[],
)
@staticmethod
def _mark_timing_gate_completed(timing_action: str) -> bool:
"""根据门控动作决定下一轮是否还需要重新执行 timing。"""
return timing_action != "continue"
@staticmethod
def _should_retry_planner_after_interrupt(
*,
round_index: int,
max_internal_rounds: int,
has_pending_messages: bool,
) -> bool:
return has_pending_messages and round_index + 1 < max_internal_rounds
async def run_loop(self) -> None:
"""独立消费消息批次,并执行对应的内部思考轮次。"""
try:
@@ -314,13 +349,20 @@ class MaisakaReasoningEngine:
if self._runtime._has_pending_messages()
else []
)
if not timeout_triggered and not cached_messages and not message_triggered:
if not timeout_triggered and not cached_messages:
continue
self._runtime._agent_state = self._runtime._STATE_RUNNING
self._runtime._update_stage_status(
"消息整理",
f"待处理消息 {len(cached_messages)}" if cached_messages else "准备复用超时锚点",
)
if cached_messages:
asyncio.create_task(self._runtime._trigger_batch_learning(cached_messages))
self._append_wait_interrupted_message_if_needed()
if timeout_triggered:
self._runtime._chat_history.append(
self._build_wait_completed_message(has_new_messages=True)
)
await self._ingest_messages(cached_messages)
anchor_message = cached_messages[-1]
else:
@@ -332,13 +374,16 @@ class MaisakaReasoningEngine:
continue
logger.info(f"{self._runtime.log_prefix} 等待超时后开始新一轮思考")
if self._runtime._pending_wait_tool_call_id:
self._runtime._chat_history.append(self._build_wait_timeout_message())
self._trim_chat_history()
self._runtime._chat_history.append(
self._build_wait_completed_message(has_new_messages=False)
)
try:
timing_gate_required = True
for round_index in range(self._runtime._max_internal_rounds):
cycle_detail = self._start_cycle()
round_text = f"{round_index + 1}/{self._runtime._max_internal_rounds}"
self._runtime._log_cycle_started(cycle_detail, round_index)
self._runtime._update_stage_status("启动循环", f"循环 {cycle_detail.cycle_id}", round_text=round_text)
await emit_cycle_start(
session_id=self._runtime.session_id,
cycle_id=cycle_detail.cycle_id,
@@ -349,10 +394,14 @@ class MaisakaReasoningEngine:
planner_started_at = 0.0
planner_duration_ms = 0.0
timing_duration_ms = 0.0
current_stage_started_at = 0.0
timing_action: Optional[str] = None
timing_response: Optional[ChatResponse] = None
timing_tool_results: Optional[list[str]] = None
timing_tool_monitor_results: Optional[list[dict[str, Any]]] = None
response: Optional[ChatResponse] = None
action_tool_definitions: list[dict[str, Any]] = []
planner_extra_lines: list[str] = []
tool_result_summaries: list[str] = []
tool_monitor_results: list[dict[str, Any]] = []
try:
@@ -364,30 +413,46 @@ class MaisakaReasoningEngine:
f"{self._runtime.log_prefix} 本轮思考前已刷新 {refreshed_message_count} 条视觉占位历史消息"
)
timing_started_at = time.time()
timing_action, timing_response, timing_tool_results = await self._run_timing_gate(anchor_message)
timing_duration_ms = (time.time() - timing_started_at) * 1000
cycle_detail.time_records["timing_gate"] = timing_duration_ms / 1000
await emit_timing_gate_result(
session_id=self._runtime.session_id,
cycle_id=cycle_detail.cycle_id,
action=timing_action,
content=timing_response.content,
tool_calls=timing_response.tool_calls,
messages=[],
prompt_tokens=timing_response.prompt_tokens,
selected_history_count=timing_response.selected_history_count,
duration_ms=timing_duration_ms,
)
if timing_action != "continue":
logger.info(
f"{self._runtime.log_prefix} Timing Gate 结束当前回合: "
f"回合={round_index + 1} 动作={timing_action}"
if timing_gate_required:
self._runtime._update_stage_status("Timing Gate", "等待门控决策", round_text=round_text)
current_stage_started_at = time.time()
timing_started_at = time.time()
(
timing_action,
timing_response,
timing_tool_results,
timing_tool_monitor_results,
) = await self._run_timing_gate(anchor_message)
timing_duration_ms = (time.time() - timing_started_at) * 1000
cycle_detail.time_records["timing_gate"] = timing_duration_ms / 1000
await emit_timing_gate_result(
session_id=self._runtime.session_id,
cycle_id=cycle_detail.cycle_id,
action=timing_action,
content=timing_response.content,
tool_calls=timing_response.tool_calls,
messages=[],
prompt_tokens=timing_response.prompt_tokens,
selected_history_count=timing_response.selected_history_count,
duration_ms=timing_duration_ms,
)
timing_gate_required = self._mark_timing_gate_completed(timing_action)
if timing_action != "continue":
logger.debug(
f"{self._runtime.log_prefix} Timing Gate 结束当前回合: "
f"回合={round_index + 1} 动作={timing_action}"
)
break
else:
logger.info(
f"{self._runtime.log_prefix} 跳过 Timing Gate继续执行 Planner: "
f"回合={round_index + 1}"
)
break
planner_started_at = time.time()
action_tool_definitions = await self._build_action_tool_definitions()
current_stage_started_at = planner_started_at
self._runtime._update_stage_status("Planner", "组织上下文并请求模型", round_text=round_text)
action_tool_definitions, deferred_tools_reminder = await self._build_action_tool_definitions()
logger.info(
f"{self._runtime.log_prefix} 规划器开始执行: "
f"回合={round_index + 1} "
@@ -395,6 +460,7 @@ class MaisakaReasoningEngine:
f"开始时间={planner_started_at:.3f}"
)
response = await self._run_interruptible_planner(
injected_user_messages=[deferred_tools_reminder] if deferred_tools_reminder else None,
tool_definitions=action_tool_definitions,
)
planner_duration_ms = (time.time() - planner_started_at) * 1000
@@ -406,8 +472,8 @@ class MaisakaReasoningEngine:
)
reasoning_content = response.content or ""
if self._should_replace_reasoning(reasoning_content):
response.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后先输出想法再使用工具"
response.raw_message.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后先输出想法再使用工具"
response.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后直接输出我的想法"
response.raw_message.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后直接输出我的想法"
logger.info(f"{self._runtime.log_prefix} 当前思考与上一轮过于相似,已替换为重新思考提示")
self._last_reasoning_content = reasoning_content
@@ -428,20 +494,73 @@ class MaisakaReasoningEngine:
if not response.content:
break
except ReqAbortException:
interrupted_at = time.time()
logger.info(
f"{self._runtime.log_prefix} 规划器打断成功: "
f"回合={round_index + 1} "
f"开始时间={planner_started_at:.3f} "
f"打断时间={interrupted_at:.3f} "
f"耗时={interrupted_at - planner_started_at:.3f}"
except ReqAbortException as exc:
self._runtime._update_stage_status(
"Planner 已打断",
str(exc) or "收到外部中断信号",
round_text=round_text,
)
break
interrupted_at = time.time()
interrupted_stage_label = "Planner"
interrupted_text = "Planner 收到新消息,开始重新决策"
interrupted_response = ChatResponse(
content=interrupted_text or None,
tool_calls=[],
request_messages=[],
raw_message=AssistantMessage(
content=interrupted_text,
timestamp=datetime.now(),
tool_calls=[],
source_kind="perception",
),
selected_history_count=len(self._runtime._chat_history),
tool_count=len(action_tool_definitions),
prompt_tokens=0,
built_message_count=0,
completion_tokens=0,
total_tokens=0,
prompt_section=None,
)
interrupted_extra_lines = [
"状态:已被新消息打断",
f"打断位置:{interrupted_stage_label} 请求流式响应阶段",
f"打断耗时:{interrupted_at - current_stage_started_at:.3f}",
]
response = interrupted_response
planner_extra_lines = interrupted_extra_lines
logger.info(
f"{self._runtime.log_prefix} {interrupted_stage_label} 打断成功: "
f"回合={round_index + 1} "
f"开始时间={current_stage_started_at:.3f} "
f"打断时间={interrupted_at:.3f} "
f"耗时={interrupted_at - current_stage_started_at:.3f}"
)
if not self._should_retry_planner_after_interrupt(
round_index=round_index,
max_internal_rounds=self._runtime._max_internal_rounds,
has_pending_messages=self._runtime._has_pending_messages(),
):
break
await self._runtime._wait_for_message_quiet_period()
self._runtime._message_turn_scheduled = False
interrupted_messages = self._runtime._collect_pending_messages()
if not interrupted_messages:
break
asyncio.create_task(self._runtime._trigger_batch_learning(interrupted_messages))
await self._ingest_messages(interrupted_messages)
anchor_message = interrupted_messages[-1]
logger.info(
f"{self._runtime.log_prefix} 淇濇寔娲昏穬鐘舵€侊紝璺宠繃 Timing Gate 鐩存帴閲嶈瘯 Planner: "
f"鍥炲悎={round_index + 2}"
)
continue
finally:
completed_cycle = self._end_cycle(cycle_detail)
self._runtime._render_context_usage_panel(
cycle_id=cycle_detail.cycle_id,
time_records=dict(completed_cycle.time_records),
timing_selected_history_count=(
timing_response.selected_history_count if timing_response is not None else None
),
@@ -452,6 +571,7 @@ class MaisakaReasoningEngine:
timing_response=timing_response.content or "" if timing_response is not None else "",
timing_tool_calls=timing_response.tool_calls if timing_response is not None else None,
timing_tool_results=timing_tool_results,
timing_tool_detail_results=timing_tool_monitor_results,
timing_prompt_section=(
timing_response.prompt_section if timing_response is not None else None
),
@@ -464,6 +584,7 @@ class MaisakaReasoningEngine:
planner_tool_results=tool_result_summaries,
planner_tool_detail_results=tool_monitor_results,
planner_prompt_section=response.prompt_section if response is not None else None,
planner_extra_lines=planner_extra_lines,
)
await emit_planner_finalized(
session_id=self._runtime.session_id,
@@ -505,6 +626,8 @@ class MaisakaReasoningEngine:
finally:
if self._runtime._agent_state == self._runtime._STATE_RUNNING:
self._runtime._agent_state = self._runtime._STATE_STOP
if self._runtime._running:
self._runtime._update_stage_status("等待消息", "本轮处理结束")
except asyncio.CancelledError:
self._runtime._log_internal_loop_cancelled()
raise
@@ -543,33 +666,22 @@ class MaisakaReasoningEngine:
return self._runtime.message_cache[-1]
return None
def _build_wait_timeout_message(self) -> ToolResultMessage:
"""构造 wait 超时后的工具结果消息。"""
def _build_wait_completed_message(self, *, has_new_messages: bool) -> ToolResultMessage:
"""构造 wait 完成后的工具结果消息。"""
tool_call_id = self._runtime._pending_wait_tool_call_id or "wait_timeout"
self._runtime._pending_wait_tool_call_id = None
content = (
"等待已结束,期间收到了新的用户输入。请结合这些新消息继续下一轮思考。"
if has_new_messages
else "等待已超时,期间没有收到新的用户输入。请基于现有上下文继续下一轮思考。"
)
return ToolResultMessage(
content="等待已超时,期间没有收到新的用户输入。请基于现有上下文继续下一轮思考。",
content=content,
timestamp=datetime.now(),
tool_call_id=tool_call_id,
tool_name="wait",
)
def _append_wait_interrupted_message_if_needed(self) -> None:
"""如果 wait 被新消息打断,则补一条对应的工具结果消息。"""
tool_call_id = self._runtime._pending_wait_tool_call_id
if not tool_call_id:
return
self._runtime._pending_wait_tool_call_id = None
self._runtime._chat_history.append(
ToolResultMessage(
content="等待过程被新的用户输入打断,已继续处理最新消息。",
timestamp=datetime.now(),
tool_call_id=tool_call_id,
tool_name="wait",
)
)
async def _ingest_messages(self, messages: list[SessionMessage]) -> None:
"""处理传入消息列表,将其转换为历史消息并加入聊天历史缓存。"""
for message in messages:
@@ -578,7 +690,6 @@ class MaisakaReasoningEngine:
continue
self._insert_chat_history_message(history_message)
self._trim_chat_history()
# 向监控前端广播新消息注入事件
user_info = message.message_info.user_info
@@ -628,10 +739,47 @@ class MaisakaReasoningEngine:
planner_prefix: str,
) -> MessageSequence:
message_sequence = build_prefixed_message_sequence(message.raw_message, planner_prefix)
if global_config.visual.multimodal_planner:
if self._resolve_enable_visual_planner():
await self._hydrate_visual_components(message_sequence.components)
return message_sequence
@staticmethod
def _resolve_enable_visual_planner() -> bool:
planner_mode = global_config.visual.planner_mode
planner_task_config = config_manager.get_model_config().model_task_config.planner
models_by_name = {model.name: model for model in config_manager.get_model_config().models}
if planner_mode == "text":
return False
planner_models: list[str] = list(planner_task_config.model_list)
missing_models = [model_name for model_name in planner_models if model_name not in models_by_name]
non_visual_models = [
model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual
]
if planner_mode == "multimodal":
if missing_models:
raise ValueError(
"planner_mode=multimodal但 planner 任务存在未定义的模型:"
f"{', '.join(missing_models)}"
)
if non_visual_models:
raise ValueError(
"planner_mode=multimodal但 planner 任务存在未开启 visual 的模型:"
f"{', '.join(non_visual_models)}"
)
return True
if missing_models:
logger.warning(
"planner_mode=auto 时发现 planner 任务存在未定义模型:"
f"{', '.join(missing_models)},将退化为纯文本 planner"
)
return False
return bool(planner_models) and not non_visual_models
async def _hydrate_visual_components(self, planner_components: list[object]) -> None:
"""在 Maisaka 真正需要图片或表情时,按需回填二进制数据。"""
load_tasks: list[asyncio.Task[None]] = []
@@ -681,6 +829,7 @@ class MaisakaReasoningEngine:
"""结束并记录一轮 Maisaka 思考循环。"""
cycle_detail.end_time = time.time()
self._runtime.history_loop.append(cycle_detail)
self._post_process_chat_history_after_cycle()
timer_strings = [
f"{name}: {duration:.2f}s"
@@ -690,26 +839,20 @@ class MaisakaReasoningEngine:
self._runtime._log_cycle_completed(cycle_detail, timer_strings)
return cycle_detail
def _trim_chat_history(self) -> None:
def _post_process_chat_history_after_cycle(self) -> None:
"""裁剪聊天历史,保证用户消息数量不超过配置限制。"""
conversation_message_count = sum(1 for message in self._runtime._chat_history if message.count_in_context)
if conversation_message_count <= self._runtime._max_context_size:
process_result = process_chat_history_after_cycle(
self._runtime._chat_history,
max_context_size=self._runtime._max_context_size,
)
if process_result.removed_count <= 0:
return
trimmed_history = list(self._runtime._chat_history)
removed_count = 0
while conversation_message_count > self._runtime._max_context_size and trimmed_history:
removed_message = trimmed_history.pop(0)
removed_count += 1
if removed_message.count_in_context:
conversation_message_count -= 1
trimmed_history, pruned_orphan_count = drop_leading_orphan_tool_results(trimmed_history)
removed_count += pruned_orphan_count
self._runtime._chat_history = trimmed_history
self._runtime._log_history_trimmed(removed_count, conversation_message_count)
self._runtime._chat_history = process_result.history
self._runtime._log_history_trimmed(
process_result.removed_count,
process_result.remaining_context_count,
)
@staticmethod
def _calculate_similarity(text1: str, text2: str) -> float:
@@ -934,6 +1077,9 @@ class MaisakaReasoningEngine:
if invocation.tool_name == "no_reply":
return "你暂停了当前对话循环,等待新的外部消息。"
if invocation.tool_name == "finish":
return "你结束了本轮思考,等待新的外部消息后再继续。"
if invocation.tool_name == "continue":
return "你允许当前对话继续进入下一轮完整思考与工具执行。"
@@ -1065,6 +1211,24 @@ class MaisakaReasoningEngine:
)
)
def _append_timing_gate_execution_result(
self,
response: ChatResponse,
tool_call: ToolCall,
result: ToolExecutionResult,
) -> None:
"""将 Timing Gate 的决策链写入历史,供后续门控复用。"""
self._runtime._chat_history.append(
AssistantMessage(
content=response.content or "",
timestamp=response.raw_message.timestamp,
tool_calls=[tool_call],
source_kind="timing_gate",
)
)
self._append_tool_execution_result(tool_call, result)
def _build_tool_result_summary(self, tool_call: ToolCall, result: ToolExecutionResult) -> str:
"""构建用于终端展示的工具结果摘要。"""
@@ -1084,6 +1248,7 @@ class MaisakaReasoningEngine:
invocation: ToolInvocation,
result: ToolExecutionResult,
duration_ms: float,
tool_spec: Optional[ToolSpec] = None,
) -> dict[str, Any]:
"""构建 planner.finalized 中单个工具的监控结果。"""
@@ -1092,9 +1257,20 @@ class MaisakaReasoningEngine:
if monitor_detail is not None:
normalized_detail = self._normalize_tool_record_value(monitor_detail)
monitor_card = result.metadata.get("monitor_card")
normalized_card = None
if monitor_card is not None:
normalized_card = self._normalize_tool_record_value(monitor_card)
monitor_sub_cards = result.metadata.get("monitor_sub_cards")
normalized_sub_cards = None
if monitor_sub_cards is not None:
normalized_sub_cards = self._normalize_tool_record_value(monitor_sub_cards)
return {
"tool_call_id": tool_call.call_id,
"tool_name": tool_call.func_name,
"tool_title": tool_spec.title.strip() if tool_spec is not None and tool_spec.title.strip() else "",
"tool_args": self._normalize_tool_record_value(
invocation.arguments if isinstance(invocation.arguments, dict) else {}
),
@@ -1102,6 +1278,8 @@ class MaisakaReasoningEngine:
"duration_ms": round(duration_ms, 2),
"summary": self._build_tool_result_summary(tool_call, result),
"detail": normalized_detail,
"card": normalized_card,
"sub_cards": normalized_sub_cards,
}
async def _handle_tool_calls(
@@ -1137,7 +1315,7 @@ class MaisakaReasoningEngine:
self._append_tool_execution_result(tool_call, result)
tool_result_summaries.append(self._build_tool_result_summary(tool_call, result))
tool_monitor_results.append(
self._build_tool_monitor_result(tool_call, invocation, result, duration_ms=0.0)
self._build_tool_monitor_result(tool_call, invocation, result, duration_ms=0.0, tool_spec=None)
)
return False, tool_result_summaries, tool_monitor_results
@@ -1146,10 +1324,25 @@ class MaisakaReasoningEngine:
tool_spec.name: tool_spec
for tool_spec in await self._runtime._tool_registry.list_tools()
}
for tool_call in tool_calls:
total_tool_count = len(tool_calls)
for tool_index, tool_call in enumerate(tool_calls, start=1):
invocation = self._build_tool_invocation(tool_call, latest_thought)
self._runtime._update_stage_status(
f"工具执行 · {invocation.tool_name}",
f"{tool_index}/{total_tool_count} 个工具",
)
tool_started_at = time.time()
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
if not self._runtime.is_action_tool_currently_available(invocation.tool_name):
result = ToolExecutionResult(
tool_name=invocation.tool_name,
success=False,
error_message=(
f"工具 {invocation.tool_name} 当前未直接暴露给 planner。"
"如果它在 deferred tools 提示中,请先调用 tool_search。"
),
)
else:
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
tool_duration_ms = (time.time() - tool_started_at) * 1000
await self._store_tool_execution_record(
invocation,
@@ -1159,7 +1352,13 @@ class MaisakaReasoningEngine:
self._append_tool_execution_result(tool_call, result)
tool_result_summaries.append(self._build_tool_result_summary(tool_call, result))
tool_monitor_results.append(
self._build_tool_monitor_result(tool_call, invocation, result, tool_duration_ms)
self._build_tool_monitor_result(
tool_call,
invocation,
result,
tool_duration_ms,
tool_spec=tool_spec_map.get(invocation.tool_name),
)
)
if not result.success and tool_call.func_name == "reply":

View File

@@ -21,7 +21,7 @@ from src.common.data_models.mai_message_data_model import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
from src.config.config import global_config
from src.core.tooling import ToolRegistry
from src.core.tooling import ToolRegistry, ToolSpec
from src.learners.expression_learner import ExpressionLearner
from src.learners.jargon_miner import JargonMiner
from src.llm_models.payload_content.resp_format import RespFormat
@@ -30,11 +30,13 @@ from src.mcp_module import MCPManager
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
from src.mcp_module.provider import MCPToolProvider
from src.plugin_runtime.tool_provider import PluginToolProvider
from src.plugin_runtime.hook_payloads import deserialize_prompt_messages
from .chat_loop_service import ChatResponse, MaisakaChatLoopService
from .context_messages import LLMContextMessage
from .display_utils import build_tool_call_summary_lines, format_token_count
from .prompt_cli_renderer import PromptCLIVisualizer
from .display.display_utils import build_tool_call_summary_lines, format_token_count
from .display.prompt_cli_renderer import PromptCLIVisualizer
from .display.stage_status_board import remove_stage_status, update_stage_status
from .reasoning_engine import MaisakaReasoningEngine
from .tool_provider import MaisakaBuiltinToolProvider
@@ -92,14 +94,16 @@ class MaisakaHeartFlowChatting:
self._max_internal_rounds = MAX_INTERNAL_ROUNDS
self._max_context_size = max(1, int(global_config.chat.max_context_size))
self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP
self._wait_until: Optional[float] = None
self._pending_wait_tool_call_id: Optional[str] = None
self._force_continue_until_reply = False
self._force_continue_trigger_message_id = ""
self._force_continue_trigger_reason = ""
self._force_next_timing_continue = False
self._force_next_timing_message_id = ""
self._force_next_timing_reason = ""
self._planner_interrupt_flag: Optional[asyncio.Event] = None
self._planner_interrupt_requested = False
self._planner_interrupt_consecutive_count = 0
self._current_action_tool_names: set[str] = set()
self.discovered_tool_names: set[str] = set()
self.deferred_tool_specs_by_name: dict[str, ToolSpec] = {}
self._planner_interrupt_max_consecutive_count = max(
0,
int(global_config.chat.planner_interrupt_max_consecutive_count),
@@ -118,6 +122,18 @@ class MaisakaHeartFlowChatting:
self._tool_registry = ToolRegistry()
self._register_tool_providers()
def _update_stage_status(self, stage: str, detail: str = "", *, round_text: str = "") -> None:
"""更新当前会话的阶段状态。"""
update_stage_status(
session_id=self.session_id,
session_name=self.session_name,
stage=stage,
detail=detail,
round_text=round_text,
agent_state=self._agent_state,
)
async def start(self) -> None:
"""启动运行时主循环。"""
if self._running:
@@ -130,6 +146,7 @@ class MaisakaHeartFlowChatting:
self._running = True
self._ensure_background_tasks_running()
self._schedule_message_turn()
self._update_stage_status("空闲", "等待消息触发")
logger.info(f"{self.log_prefix} Maisaka 运行时已启动")
async def stop(self) -> None:
@@ -157,6 +174,7 @@ class MaisakaHeartFlowChatting:
await self._tool_registry.close()
self._mcp_manager = None
self._mcp_host_bridge = None
remove_stage_status(self.session_id)
logger.info(f"{self.log_prefix} Maisaka 运行时已停止")
@@ -175,9 +193,6 @@ class MaisakaHeartFlowChatting:
self.message_cache.append(message)
self._message_received_at_by_id[message.message_id] = received_at
self._source_messages_by_id[message.message_id] = message
if self._agent_state == self._STATE_WAIT:
self._cancel_wait_timeout_task()
self._wait_until = None
if self._agent_state == self._STATE_RUNNING:
self._message_debounce_required = True
if self._agent_state == self._STATE_RUNNING and self._planner_interrupt_flag is not None:
@@ -248,7 +263,6 @@ class MaisakaHeartFlowChatting:
def _record_reply_sent(self) -> None:
"""在成功发送 reply 后记录本轮消息回复时长。"""
self._clear_force_continue_until_reply()
if self._reply_latency_measurement_started_at is None:
return
@@ -308,26 +322,26 @@ class MaisakaHeartFlowChatting:
if not message.is_at and not message.is_mentioned:
return
self._arm_force_continue_until_reply(
self._arm_force_next_timing_continue(
message,
is_at=message.is_at,
is_mentioned=message.is_mentioned,
)
def _arm_force_continue_until_reply(
def _arm_force_next_timing_continue(
self,
message: SessionMessage,
*,
is_at: bool,
is_mentioned: bool,
) -> None:
"""在检测到 @ 或提及时,要求后续轮次跳过 Timing Gate 直到成功 reply"""
"""在检测到 @ 或提及时,要求下一次 Timing Gate 直接 continue"""
trigger_reason = "@消息" if is_at else "提及消息" if is_mentioned else "触发消息"
was_armed = self._force_continue_until_reply
self._force_continue_until_reply = True
self._force_continue_trigger_message_id = message.message_id
self._force_continue_trigger_reason = trigger_reason
was_armed = self._force_next_timing_continue
self._force_next_timing_continue = True
self._force_next_timing_message_id = message.message_id
self._force_next_timing_reason = trigger_reason
if was_armed:
logger.info(
@@ -337,34 +351,31 @@ class MaisakaHeartFlowChatting:
return
logger.info(
f"{self.log_prefix} 检测到{trigger_reason}将跳过 Timing Gate 直到成功发送一条 reply"
f"{self.log_prefix} 检测到{trigger_reason}下一次 Timing Gate 将直接视作 continue"
f"消息编号={message.message_id}"
)
def _clear_force_continue_until_reply(self) -> None:
"""在成功发送 reply 后清理强制 continue 状态。"""
def _consume_force_next_timing_continue_reason(self) -> str | None:
"""消费一次性 Timing Gate continue 状态,并返回原因描述"""
if not self._force_continue_until_reply:
return
if not self._force_next_timing_continue:
return None
logger.info(
f"{self.log_prefix} 已成功发送 reply恢复 Timing Gate"
f"触发原因={self._force_continue_trigger_reason or '未知'} "
f"触发消息编号={self._force_continue_trigger_message_id or 'unknown'}"
)
self._force_continue_until_reply = False
self._force_continue_trigger_message_id = ""
self._force_continue_trigger_reason = ""
def _build_force_continue_timing_reason(self) -> str:
"""返回当前强制跳过 Timing Gate 的原因描述。"""
trigger_reason = self._force_continue_trigger_reason or "@/提及消息"
trigger_message_id = self._force_continue_trigger_message_id or "unknown"
return (
trigger_reason = self._force_next_timing_reason or "@/提及消息"
trigger_message_id = self._force_next_timing_message_id or "unknown"
reason = (
f"检测到新的{trigger_reason}(消息编号={trigger_message_id}"
"本轮直接跳过 Timing Gate 并视作 continue,直到成功发送一条 reply"
"本轮直接跳过 Timing Gate 并视作 continue。"
)
logger.info(
f"{self.log_prefix} 已结束本次强制 continue恢复 Timing Gate"
f"触发原因={trigger_reason} "
f"触发消息编号={trigger_message_id}"
)
self._force_next_timing_continue = False
self._force_next_timing_message_id = ""
self._force_next_timing_reason = ""
return reason
def _bind_planner_interrupt_flag(self, interrupt_flag: asyncio.Event) -> None:
"""绑定当前可打断请求使用的中断标记。"""
@@ -426,6 +437,7 @@ class MaisakaHeartFlowChatting:
selected_history, _ = MaisakaChatLoopService.select_llm_context_messages(
self._chat_history,
request_kind=request_kind,
max_context_size=context_message_limit,
)
sub_agent_history = list(selected_history)
@@ -447,11 +459,133 @@ class MaisakaHeartFlowChatting:
tool_definitions=[] if tool_definitions is None else tool_definitions,
)
def set_current_action_tool_names(self, tool_names: Sequence[str]) -> None:
"""记录当前 Action Loop 已实际暴露给 planner 的工具名集合。"""
self._current_action_tool_names = {tool_name for tool_name in tool_names if str(tool_name).strip()}
def is_action_tool_currently_available(self, tool_name: str) -> bool:
"""判断指定工具在当前 Action Loop 轮次中是否真实可用。"""
normalized_name = str(tool_name).strip()
return bool(normalized_name) and normalized_name in self._current_action_tool_names
def update_deferred_tool_specs(self, deferred_tool_specs: Sequence[ToolSpec]) -> None:
"""刷新当前会话的 deferred tools 池,并清理失效的已发现工具。"""
next_specs_by_name: dict[str, ToolSpec] = {}
for tool_spec in deferred_tool_specs:
normalized_name = tool_spec.name.strip()
if not normalized_name:
continue
next_specs_by_name[normalized_name] = tool_spec
self.deferred_tool_specs_by_name = next_specs_by_name
self.discovered_tool_names.intersection_update(next_specs_by_name.keys())
def get_discovered_deferred_tool_specs(self) -> list[ToolSpec]:
"""返回当前会话中已发现、且仍然有效的 deferred tools。"""
return [
tool_spec
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items()
if tool_name in self.discovered_tool_names
]
def build_deferred_tools_reminder(self) -> str:
"""构造供 planner 使用的 deferred tools 提示消息。"""
undiscovered_tool_specs = [
tool_spec
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items()
if tool_name not in self.discovered_tool_names
]
if not undiscovered_tool_specs:
return ""
tool_lines: list[str] = []
for index, tool_spec in enumerate(undiscovered_tool_specs, start=1):
tool_name = tool_spec.name.strip()
tool_description = tool_spec.brief_description.strip()
if tool_description:
tool_lines.append(f"{index}. {tool_name}: {tool_description}")
else:
tool_lines.append(f"{index}. {tool_name}")
reminder_lines = [
"<system-reminder>",
"以下工具当前未直接暴露给你,但可以通过 tool_search 工具发现并在后续轮次中使用:",
*tool_lines,
"",
"如需其中某个工具,请先调用 tool_search。tool_search 只负责发现工具,不直接执行业务。",
"</system-reminder>",
]
return "\n".join(reminder_lines)
def search_deferred_tool_specs(
self,
query: str,
*,
limit: int,
) -> list[ToolSpec]:
"""按名称或简要描述搜索 deferred tools。"""
normalized_query = " ".join(query.lower().split()).strip()
if not normalized_query:
return []
scored_matches: list[tuple[int, str, ToolSpec]] = []
query_terms = [term for term in normalized_query.replace("_", " ").replace("-", " ").split() if term]
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items():
lower_name = tool_name.lower()
lower_description = tool_spec.brief_description.lower()
score = 0
if normalized_query == lower_name:
score += 1000
if lower_name.startswith(normalized_query):
score += 300
if normalized_query in lower_name:
score += 200
if normalized_query in lower_description:
score += 100
for query_term in query_terms:
if query_term in lower_name:
score += 25
if query_term in lower_description:
score += 10
if score <= 0:
continue
scored_matches.append((score, tool_name, tool_spec))
scored_matches.sort(key=lambda item: (-item[0], item[1]))
return [tool_spec for _, _, tool_spec in scored_matches[: max(1, limit)]]
def discover_deferred_tools(self, tool_names: Sequence[str]) -> list[str]:
"""将指定 deferred tools 标记为已发现,并返回本次新发现的工具名。"""
newly_discovered_tool_names: list[str] = []
for raw_tool_name in tool_names:
normalized_name = str(raw_tool_name).strip()
if not normalized_name or normalized_name not in self.deferred_tool_specs_by_name:
continue
if normalized_name in self.discovered_tool_names:
continue
self.discovered_tool_names.add(normalized_name)
newly_discovered_tool_names.append(normalized_name)
return newly_discovered_tool_names
def _has_pending_messages(self) -> bool:
return self._last_processed_index < len(self.message_cache)
def _schedule_message_turn(self) -> None:
"""为当前待处理消息安排一次内部 turn。"""
if self._agent_state == self._STATE_WAIT:
return
if not self._has_pending_messages() or self._message_turn_scheduled:
return
@@ -531,8 +665,9 @@ class MaisakaHeartFlowChatting:
def _enter_wait_state(self, seconds: Optional[float] = None, tool_call_id: Optional[str] = None) -> None:
"""切换到等待状态。"""
self._agent_state = self._STATE_WAIT
self._wait_until = None if seconds is None else time.time() + seconds
self._pending_wait_tool_call_id = tool_call_id
self._message_turn_scheduled = False
self._cancel_deferred_message_turn_task()
self._cancel_wait_timeout_task()
if seconds is not None:
self._wait_timeout_task = asyncio.create_task(
@@ -542,7 +677,6 @@ class MaisakaHeartFlowChatting:
def _enter_stop_state(self) -> None:
"""切换到停止状态。"""
self._agent_state = self._STATE_STOP
self._wait_until = None
self._pending_wait_tool_call_id = None
self._cancel_wait_timeout_task()
@@ -567,7 +701,6 @@ class MaisakaHeartFlowChatting:
logger.info(f"{self.log_prefix} Maisaka 等待已超时")
self._agent_state = self._STATE_RUNNING
self._wait_until = None
await self._internal_turn_queue.put("timeout")
except asyncio.CancelledError:
return
@@ -616,7 +749,7 @@ class MaisakaHeartFlowChatting:
return True
async def _trigger_expression_learning(self, messages: list[SessionMessage]) -> None:
"""?????????????????"""
"""触发表达方式学习"""
pending_count = self._expression_learner.get_pending_count(self.message_cache)
if not self._should_trigger_learning(
enabled=self._enable_expression_learning,
@@ -629,21 +762,21 @@ class MaisakaHeartFlowChatting:
self._last_expression_extraction_time = time.time()
logger.info(
f"{self.log_prefix} ??????: "
f"??????={len(messages)} ??????={pending_count} "
f"?????={len(self.message_cache)} "
f"??????={self._enable_jargon_learning}"
f"{self.log_prefix} 触发表达方式学习: "
f"消息数量={len(messages)} 待处理消息数量={pending_count} "
f"缓存总量={len(self.message_cache)} "
f"是否启用黑话学习={self._enable_jargon_learning}"
)
try:
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
learnt_style = await self._expression_learner.learn(self.message_cache, jargon_miner)
if learnt_style:
logger.info(f"{self.log_prefix} ???????")
logger.info(f"{self.log_prefix} 表达方式学习成功")
else:
logger.debug(f"{self.log_prefix} ???????????????")
logger.debug(f"{self.log_prefix} 表达方式学习失败")
except Exception:
logger.exception(f"{self.log_prefix} ??????")
logger.exception(f"{self.log_prefix} 表达方式学习异常")
async def _init_mcp(self) -> None:
"""初始化 MCP 工具并注册到统一工具层。"""
@@ -655,12 +788,12 @@ class MaisakaHeartFlowChatting:
host_callbacks=self._mcp_host_bridge.build_callbacks(),
)
if self._mcp_manager is None:
logger.info(f"{self.log_prefix} MCP 管理器不可用")
logger.info(f"{self.log_prefix} Maisaka MCP 管理器不可用")
return
mcp_tool_specs = self._mcp_manager.get_tool_specs()
if not mcp_tool_specs:
logger.info(f"{self.log_prefix} 没有可供 Maisaka 使用的 MCP 工具")
logger.info(f"{self.log_prefix} Maisaka 没有可供使用的 MCP 工具")
return
self._tool_registry.register_provider(MCPToolProvider(self._mcp_manager))
@@ -694,6 +827,7 @@ class MaisakaHeartFlowChatting:
self,
*,
cycle_id: Optional[int] = None,
time_records: Optional[dict[str, float]] = None,
timing_selected_history_count: Optional[int] = None,
timing_prompt_tokens: Optional[int] = None,
timing_action: str = "",
@@ -709,6 +843,7 @@ class MaisakaHeartFlowChatting:
planner_tool_results: Optional[list[str]] = None,
planner_tool_detail_results: Optional[list[dict[str, Any]]] = None,
planner_prompt_section: Optional[RenderableType] = None,
planner_extra_lines: Optional[list[str]] = None,
) -> None:
"""在终端展示当前聊天流本轮 cycle 的最终结果。"""
if not global_config.debug.show_maisaka_thinking:
@@ -721,6 +856,7 @@ class MaisakaHeartFlowChatting:
if cycle_id is not None:
body_lines.append(f"循环编号:{cycle_id}")
panel_subtitle = self._build_cycle_time_records_text(time_records or {})
renderables: list[RenderableType] = [Text("\n".join(body_lines))]
timing_panel = self._build_cycle_stage_panel(
title="Timing Gate",
@@ -728,33 +864,49 @@ class MaisakaHeartFlowChatting:
selected_history_count=timing_selected_history_count,
prompt_tokens=timing_prompt_tokens,
response_text=timing_response,
tool_calls=timing_tool_calls,
tool_results=timing_tool_results,
tool_detail_results=timing_tool_detail_results,
prompt_section=timing_prompt_section,
extra_lines=[f"门控动作:{timing_action}"] if timing_action.strip() else None,
)
if timing_panel is not None:
renderables.append(timing_panel)
timing_tool_cards = self._build_tool_activity_cards(
stage_title="Timing Tool",
tool_calls=timing_tool_calls,
tool_results=timing_tool_results,
tool_detail_results=timing_tool_detail_results,
planner_style=False,
)
if timing_tool_cards:
renderables.extend(timing_tool_cards)
planner_panel = self._build_cycle_stage_panel(
title="Planner",
border_style="green",
selected_history_count=planner_selected_history_count,
prompt_tokens=planner_prompt_tokens,
response_text=planner_response,
tool_calls=planner_tool_calls,
tool_results=planner_tool_results,
tool_detail_results=planner_tool_detail_results,
prompt_section=planner_prompt_section,
extra_lines=planner_extra_lines,
)
if planner_panel is not None:
renderables.append(planner_panel)
planner_tool_cards = self._build_tool_activity_cards(
stage_title="Planner Tool",
tool_calls=planner_tool_calls,
tool_results=planner_tool_results,
tool_detail_results=planner_tool_detail_results,
planner_style=True,
)
if planner_tool_cards:
renderables.extend(planner_tool_cards)
console.print(
Panel(
Group(*renderables),
title="MaiSaka 循环",
subtitle=panel_subtitle,
border_style="bright_blue",
padding=(0, 1),
)
@@ -768,9 +920,6 @@ class MaisakaHeartFlowChatting:
selected_history_count: Optional[int],
prompt_tokens: Optional[int],
response_text: str = "",
tool_calls: Optional[list[Any]] = None,
tool_results: Optional[list[str]] = None,
tool_detail_results: Optional[list[dict[str, Any]]] = None,
prompt_section: Optional[RenderableType] = None,
extra_lines: Optional[list[str]] = None,
) -> Optional[Panel]:
@@ -780,9 +929,6 @@ class MaisakaHeartFlowChatting:
selected_history_count is not None,
prompt_tokens is not None,
bool(response_text.strip()),
bool(tool_calls),
bool(tool_results),
bool(tool_detail_results),
prompt_section is not None,
bool(extra_lines),
])
@@ -809,40 +955,11 @@ class MaisakaHeartFlowChatting:
Panel(
Text(normalized_response),
title="Maisaka 返回",
border_style="green",
border_style=border_style,
padding=(0, 1),
)
)
normalized_tool_calls = build_tool_call_summary_lines(tool_calls or [])
if normalized_tool_calls:
renderables.append(
Panel(
Text("\n".join(normalized_tool_calls)),
title="工具调用",
border_style="magenta",
padding=(0, 1),
)
)
normalized_tool_results = self._filter_redundant_tool_results(
tool_results=tool_results or [],
tool_detail_results=tool_detail_results or [],
)
if normalized_tool_results:
renderables.append(
Panel(
Text("\n".join(normalized_tool_results)),
title="工具结果",
border_style="yellow",
padding=(0, 1),
)
)
detail_panels = self._build_tool_detail_panels(tool_detail_results or [])
if detail_panels:
renderables.extend(detail_panels)
return Panel(
Group(*renderables),
title=title,
@@ -850,6 +967,75 @@ class MaisakaHeartFlowChatting:
padding=(0, 1),
)
def _build_tool_activity_cards(
self,
*,
stage_title: str,
tool_calls: Optional[list[Any]] = None,
tool_results: Optional[list[str]] = None,
tool_detail_results: Optional[list[dict[str, Any]]] = None,
planner_style: bool = False,
) -> list[RenderableType]:
"""构建与阶段同级的工具执行卡片列表。"""
detail_results = tool_detail_results or []
cards = self._build_tool_detail_cards(
detail_results,
stage_title=stage_title,
planner_style=planner_style,
)
if cards:
return cards
# 兼容旧数据结构:若尚无 detail则降级为简单文本卡片。
fallback_lines = self._filter_redundant_tool_results(
tool_results=tool_results or [],
tool_detail_results=detail_results,
)
if not fallback_lines and tool_calls:
fallback_lines = build_tool_call_summary_lines(tool_calls)
if not fallback_lines:
return []
fallback_border_style = "yellow"
return [
Panel(
Text("\n".join(fallback_lines)),
title=stage_title,
border_style=fallback_border_style,
padding=(0, 1),
)
]
@staticmethod
def _build_cycle_time_records_text(time_records: dict[str, float]) -> str:
"""构建循环最外层面板展示的阶段耗时文本。"""
if not time_records:
return "流程耗时:无"
label_map = {
"timing_gate": "Timing Gate",
"planner": "Planner",
"tool_calls": "工具执行",
}
ordered_keys = ["timing_gate", "planner", "tool_calls"]
parts: list[str] = []
for key in ordered_keys:
duration = time_records.get(key)
if isinstance(duration, (int, float)):
parts.append(f"{label_map.get(key, key)} {float(duration):.2f} s")
for key, duration in time_records.items():
if key in ordered_keys or not isinstance(duration, (int, float)):
continue
parts.append(f"{label_map.get(key, key)} {float(duration):.2f} s")
if not parts:
return "流程耗时:无"
return "流程耗时:" + " | ".join(parts)
@staticmethod
def _filter_redundant_tool_results(
*,
@@ -941,7 +1127,9 @@ class MaisakaHeartFlowChatting:
*,
tool_name: str,
prompt_text: str,
request_messages: Optional[list[Any]] = None,
tool_call_id: str,
border_style: str = "bright_yellow",
) -> Panel:
"""将工具 prompt 渲染为可点击查看的预览入口。"""
@@ -950,6 +1138,26 @@ class MaisakaHeartFlowChatting:
if tool_call_id:
subtitle += f"\n调用ID: {tool_call_id}"
if isinstance(request_messages, list) and request_messages:
try:
normalized_messages = deserialize_prompt_messages(request_messages)
except Exception as exc:
logger.warning(f"工具 {tool_name} 的 request_messages 无法反序列化,已回退为文本预览: {exc}")
else:
return Panel(
PromptCLIVisualizer.build_prompt_access_panel(
normalized_messages,
category=labels["prompt_category"],
chat_id=self.session_id,
request_kind=labels["request_kind"],
selection_reason=subtitle,
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
),
title=labels["prompt_title"],
border_style=border_style,
padding=(0, 1),
)
return Panel(
PromptCLIVisualizer.build_text_access_panel(
prompt_text,
@@ -959,116 +1167,235 @@ class MaisakaHeartFlowChatting:
subtitle=subtitle,
),
title=labels["prompt_title"],
border_style="bright_yellow",
border_style=border_style,
padding=(0, 1),
)
def _build_tool_detail_panels(self, tool_detail_results: list[dict[str, Any]]) -> list[RenderableType]:
""" tool monitor detail 渲染为 CLI 详情卡片"""
def _normalize_tool_card_body_lines(self, body: Any) -> list[str]:
"""工具卡片正文规范化为行列表"""
if isinstance(body, str):
return [line for line in body.splitlines() if line.strip()]
if isinstance(body, list):
return [
str(item).strip()
for item in body
if str(item).strip()
]
return []
def _build_custom_tool_sub_cards(
    self,
    sub_cards: Any,
    *,
    default_border_style: str,
) -> list[RenderableType]:
    """Build panels for tool-declared custom sub-cards.

    Non-list input, non-dict entries, and entries without any renderable
    body lines are silently skipped. Missing titles fall back to "附加信息"
    and missing border styles fall back to ``default_border_style``.
    """
    if not isinstance(sub_cards, list):
        return []
    panels: list[RenderableType] = []
    for entry in sub_cards:
        if not isinstance(entry, dict):
            continue
        body_lines = self._normalize_tool_card_body_lines(
            entry.get("body_lines", entry.get("content", ""))
        )
        if not body_lines:
            continue
        card_title = str(entry.get("title") or "").strip() or "附加信息"
        card_border = str(entry.get("border_style") or "").strip() or default_border_style
        panels.append(
            Panel(
                Text("\n".join(body_lines)),
                title=card_title,
                border_style=card_border,
                padding=(0, 1),
            )
        )
    return panels
def _build_default_tool_detail_parts(
    self,
    *,
    tool_name: str,
    tool_call_id: str,
    tool_args: Any,
    summary: str,
    duration_ms: Any,
    detail: dict[str, Any],
    planner_style: bool,
) -> list[RenderableType]:
    """Build the default content sections of a tool detail card.

    Sections are appended in a fixed order: header text, tool arguments,
    execution metrics, prompt preview, reasoning, output, then any
    tool-supplied extra sections. Empty sections are omitted.
    NOTE(review): ``planner_style`` is unused in this body — presumably kept
    for signature parity with the card builders; confirm before removing.
    """
    section_labels = self._get_tool_detail_labels(tool_name)
    parts: list[RenderableType] = []

    def add_text_panel(content: str, *, title: str, border_style: str) -> None:
        # Shared panel shape for the plain-text sections below.
        parts.append(
            Panel(
                Text(content),
                title=title,
                border_style=border_style,
                padding=(0, 1),
            )
        )

    header_lines = [summary] if summary else []
    if tool_call_id:
        header_lines.append(f"调用ID{tool_call_id}")
    if isinstance(duration_ms, (int, float)):
        header_lines.append(f"执行耗时:{round(float(duration_ms), 2)} ms")
    if header_lines:
        parts.append(Text("\n".join(header_lines)))
    if isinstance(tool_args, dict) and tool_args:
        parts.append(
            Panel(
                Pretty(tool_args, expand_all=True),
                title="工具参数",
                border_style="yellow",
                padding=(0, 1),
            )
        )
    metrics = detail.get("metrics")
    if isinstance(metrics, dict):
        metrics_text = self._build_tool_metrics_text(metrics)
        if metrics_text:
            add_text_panel(metrics_text, title="执行指标", border_style="bright_yellow")
    prompt_text = str(detail.get("prompt_text") or "").strip()
    if prompt_text:
        raw_messages = detail.get("request_messages")
        parts.append(
            self._build_tool_prompt_access_panel(
                tool_name=tool_name,
                prompt_text=prompt_text,
                request_messages=raw_messages if isinstance(raw_messages, list) else None,
                tool_call_id=tool_call_id,
                border_style="bright_yellow",
            )
        )
    reasoning_text = str(detail.get("reasoning_text") or "").strip()
    if reasoning_text:
        add_text_panel(
            reasoning_text,
            title=section_labels["reasoning_title"],
            border_style="yellow",
        )
    output_text = str(detail.get("output_text") or "").strip()
    if output_text:
        add_text_panel(
            output_text,
            title=section_labels["output_title"],
            border_style="bright_yellow",
        )
    extra_sections = detail.get("extra_sections")
    if isinstance(extra_sections, list):
        for section in extra_sections:
            if not isinstance(section, dict):
                continue
            section_content = str(section.get("content") or "").strip()
            if not section_content:
                continue
            section_title = str(section.get("title") or "").strip() or "附加信息"
            add_text_panel(section_content, title=section_title, border_style="yellow")
    return parts
def _build_tool_detail_cards(
self,
tool_detail_results: list[dict[str, Any]],
*,
stage_title: str,
planner_style: bool = False,
) -> list[RenderableType]:
"""将 tool monitor detail 渲染为与 Planner/Timing 平级的工具卡片。"""
detail_panel_border_style = "yellow"
sub_card_border_style = "bright_yellow"
panels: list[RenderableType] = []
for tool_result in tool_detail_results:
detail = tool_result.get("detail")
if not isinstance(detail, dict) or not detail:
continue
detail_dict = detail if isinstance(detail, dict) else {}
tool_name = str(tool_result.get("tool_name") or "unknown").strip() or "unknown"
detail_labels = self._get_tool_detail_labels(tool_name)
tool_title = str(tool_result.get("tool_title") or "").strip() or tool_name
tool_call_id = str(tool_result.get("tool_call_id") or "").strip()
tool_args = tool_result.get("tool_args")
summary = str(tool_result.get("summary") or "").strip()
duration_ms = tool_result.get("duration_ms")
custom_card = tool_result.get("card")
parts: list[RenderableType] = []
header_lines: list[str] = []
if summary:
header_lines.append(summary)
if tool_call_id:
header_lines.append(f"调用ID{tool_call_id}")
if isinstance(duration_ms, (int, float)):
header_lines.append(f"执行耗时:{round(float(duration_ms), 2)} ms")
if header_lines:
parts.append(Text("\n".join(header_lines)))
if isinstance(tool_args, dict) and tool_args:
parts.append(
Panel(
Pretty(tool_args, expand_all=True),
title="工具参数",
border_style="cyan",
padding=(0, 1),
)
custom_title = ""
card_border_style = detail_panel_border_style
replace_default_children = False
if isinstance(custom_card, dict):
custom_title = str(custom_card.get("title") or "").strip()
card_border_style = str(custom_card.get("border_style") or "").strip() or detail_panel_border_style
replace_default_children = bool(custom_card.get("replace_default_children", False))
custom_body_lines = self._normalize_tool_card_body_lines(
custom_card.get("body_lines", custom_card.get("content", ""))
)
if custom_body_lines:
parts.append(Text("\n".join(custom_body_lines)))
metrics = detail.get("metrics")
if isinstance(metrics, dict):
metrics_text = self._build_tool_metrics_text(metrics)
if metrics_text:
parts.append(
Panel(
Text(metrics_text),
title="执行指标",
border_style="bright_cyan",
padding=(0, 1),
)
)
prompt_text = str(detail.get("prompt_text") or "").strip()
if prompt_text:
parts.append(
self._build_tool_prompt_access_panel(
if not replace_default_children:
parts.extend(
self._build_default_tool_detail_parts(
tool_name=tool_name,
prompt_text=prompt_text,
tool_call_id=tool_call_id,
tool_args=tool_args,
summary=summary,
duration_ms=duration_ms,
detail=detail_dict,
planner_style=planner_style,
)
)
reasoning_text = str(detail.get("reasoning_text") or "").strip()
if reasoning_text:
parts.append(
Panel(
Text(reasoning_text),
title=detail_labels["reasoning_title"],
border_style="magenta",
padding=(0, 1),
if isinstance(custom_card, dict):
parts.extend(
self._build_custom_tool_sub_cards(
custom_card.get("sub_cards"),
default_border_style=sub_card_border_style,
)
)
output_text = str(detail.get("output_text") or "").strip()
if output_text:
parts.append(
Panel(
Text(output_text),
title=detail_labels["output_title"],
border_style="green",
padding=(0, 1),
)
parts.extend(
self._build_custom_tool_sub_cards(
tool_result.get("sub_cards"),
default_border_style=sub_card_border_style,
)
extra_sections = detail.get("extra_sections")
if isinstance(extra_sections, list):
for section in extra_sections:
if not isinstance(section, dict):
continue
section_title = str(section.get("title") or "").strip() or "附加信息"
section_content = str(section.get("content") or "").strip()
if not section_content:
continue
parts.append(
Panel(
Text(section_content),
title=section_title,
border_style="white",
padding=(0, 1),
)
)
)
if parts:
panels.append(
Panel(
Group(*parts),
title=f"{tool_name} 工具详情",
border_style="yellow",
title=custom_title or f"{stage_title} · {tool_title}",
border_style=card_border_style,
padding=(0, 1),
)
)

View File

@@ -521,9 +521,7 @@ class MCPHostLLMBridge:
tool_definitions.append(
{
"name": tool_name,
"description": "\n\n".join(
part for part in [brief_description, detailed_description] if part.strip()
).strip(),
"description": brief_description,
"parameters_schema": parameters_schema or {"type": "object", "properties": {}},
}
)

View File

@@ -672,6 +672,32 @@ class ComponentQueryService:
collected_specs[entry.name] = self._build_tool_spec(entry) # type: ignore[arg-type]
return collected_specs
@staticmethod
def _build_tool_context_payload(context: Optional[ToolExecutionContext]) -> Dict[str, Any]:
"""提取插件工具可复用的会话上下文字段。"""
if context is None:
return {}
payload: Dict[str, Any] = {}
stream_id = str(context.stream_id or context.session_id or "").strip()
if stream_id:
payload["stream_id"] = stream_id
payload["chat_id"] = stream_id
anchor_message = context.metadata.get("anchor_message")
message_info = getattr(anchor_message, "message_info", None)
group_info = getattr(message_info, "group_info", None)
user_info = getattr(message_info, "user_info", None)
group_id = str(getattr(group_info, "group_id", "") or "").strip()
user_id = str(getattr(user_info, "user_id", "") or "").strip()
if group_id:
payload["group_id"] = group_id
if user_id:
payload["user_id"] = user_id
return payload
@staticmethod
def _build_tool_invocation_payload(
entry: "ToolEntry",
@@ -690,16 +716,27 @@ class ComponentQueryService:
"""
payload = dict(invocation.arguments)
context_payload = ComponentQueryService._build_tool_context_payload(context)
if entry.invoke_method == "plugin.invoke_action":
stream_id = context.stream_id if context is not None else invocation.stream_id
stream_id = str(
context_payload.get("stream_id")
or (context.stream_id if context is not None else invocation.stream_id)
or invocation.stream_id
).strip()
reasoning = context.reasoning if context is not None else invocation.reasoning
payload = {
**payload,
**{key: value for key, value in context_payload.items() if key not in payload or not payload.get(key)},
"stream_id": stream_id,
"chat_id": stream_id,
"reasoning": reasoning,
"action_data": dict(invocation.arguments),
}
return payload
for key, value in context_payload.items():
if key not in payload or not payload.get(key):
payload[key] = value
return payload
@staticmethod

View File

@@ -11,8 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from rich.traceback import install
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
from src.chat.replyer.replyer_manager import replyer_manager
from src.chat.utils.utils import process_llm_response
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
@@ -20,8 +19,8 @@ from src.common.logger import get_logger
from src.core.types import ActionInfo
if TYPE_CHECKING:
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.common.data_models.planned_action_data_models import PlannedAction
from src.chat.message_receive.message import SessionMessage
install(extra_lines=3)
@@ -38,7 +37,7 @@ def _get_replyer(
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer | PrivateReplyer]:
) -> Optional[MaisakaReplyGenerator]:
"""获取回复器对象"""
if not chat_id and not chat_stream:
raise ValueError("chat_stream 和 chat_id 不可均为空")
@@ -100,7 +99,7 @@ async def generate_reply(
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
chosen_actions: Optional[List["PlannedAction"]] = None,
unknown_words: Optional[List[str]] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,

View File

@@ -267,6 +267,46 @@ def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
return image_format, image_base64
def _append_image_content(message_builder: MessageBuilder, content_item: Any) -> bool:
"""向消息构建器追加图片片段。
兼容两种输入格式:
1. 旧序列化格式中的 `(image_format, image_base64)` 元组。
2. 标准字典片段中的 Data URL 或 `image_format`/`image_base64` 字段。
"""
if isinstance(content_item, (tuple, list)) and len(content_item) == 2:
image_format, image_base64 = content_item
if not isinstance(image_format, str) or not isinstance(image_base64, str):
raise ValueError("图片元组片段必须包含字符串类型的 image_format 和 image_base64")
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
return True
if not isinstance(content_item, dict):
return False
part_type = str(content_item.get("type", "text")).strip().lower()
if part_type not in {"image", "image_url", "input_image"}:
return False
image_url = content_item.get("image_url")
if isinstance(image_url, dict):
image_url = image_url.get("url")
if isinstance(image_url, str):
image_format, image_base64 = _parse_data_url_image(image_url)
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
return True
image_format = content_item.get("image_format")
image_base64 = content_item.get("image_base64")
if isinstance(image_format, str) and isinstance(image_base64, str):
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
return True
raise ValueError("图片片段缺少可识别的图片数据")
def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None:
"""将原始消息内容追加到内部消息构建器。
@@ -293,8 +333,10 @@ def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None
if isinstance(content_item, str):
message_builder.add_text_content(content_item)
continue
if _append_image_content(message_builder, content_item):
continue
if not isinstance(content_item, dict):
raise ValueError("消息内容列表中仅支持字符串或字典片段")
raise ValueError("消息内容列表中仅支持字符串、图片元组或字典片段")
part_type = str(content_item.get("type", "text")).strip().lower()
if part_type == "text":
@@ -304,22 +346,6 @@ def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None
message_builder.add_text_content(text_content)
continue
if part_type in {"image", "image_url", "input_image"}:
image_url = content_item.get("image_url")
if isinstance(image_url, dict):
image_url = image_url.get("url")
if isinstance(image_url, str):
image_format, image_base64 = _parse_data_url_image(image_url)
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
continue
image_format = content_item.get("image_format")
image_base64 = content_item.get("image_base64")
if isinstance(image_format, str) and isinstance(image_base64, str):
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
continue
raise ValueError("图片片段缺少可识别的图片数据")
raise ValueError(f"不支持的消息片段类型: {part_type}")

View File

@@ -326,7 +326,7 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
if emoji.is_registered:
return EmojiUpdateResponse(success=True, message="??????????", data=emoji_to_response(emoji))
return EmojiUpdateResponse(success=True, message="表情包已注册", data=emoji_to_response(emoji))
emoji.is_registered = True
emoji.is_banned = False

View File

@@ -0,0 +1,25 @@
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
def test_retry_planner_after_interrupt_only_when_has_new_messages_and_more_rounds() -> None:
    """A first-round interrupt with pending messages should trigger a planner retry."""
    should_retry = MaisakaReasoningEngine._should_retry_planner_after_interrupt(
        round_index=0,
        max_internal_rounds=6,
        has_pending_messages=True,
    )
    assert should_retry
def test_do_not_retry_planner_after_interrupt_without_pending_messages() -> None:
    """Without newly arrived messages the planner must not be retried."""
    should_retry = MaisakaReasoningEngine._should_retry_planner_after_interrupt(
        round_index=0,
        max_internal_rounds=6,
        has_pending_messages=False,
    )
    assert not should_retry
def test_do_not_retry_planner_after_interrupt_on_last_round() -> None:
    """The final internal round never retries the planner, even with pending messages."""
    should_retry = MaisakaReasoningEngine._should_retry_planner_after_interrupt(
        round_index=5,
        max_internal_rounds=6,
        has_pending_messages=True,
    )
    assert not should_retry

View File

@@ -0,0 +1,63 @@
from src.core.tooling import ToolSpec
from src.llm_models.payload_content.message import RoleType
from src.maisaka.chat_loop_service import MaisakaChatLoopService
from src.maisaka.runtime import MaisakaHeartFlowChatting
def _build_runtime_stub() -> MaisakaHeartFlowChatting:
    """Create a bare runtime (bypassing __init__) with empty tool registries."""
    stub = object.__new__(MaisakaHeartFlowChatting)
    stub._current_action_tool_names = set()
    stub.deferred_tool_specs_by_name = {}
    stub.discovered_tool_names = set()
    return stub
def test_deferred_tools_reminder_only_lists_undiscovered_tools() -> None:
    """The reminder must mention only tools that have not been discovered yet."""
    runtime = _build_runtime_stub()
    specs = [
        ToolSpec(name="plugin_alpha", brief_description="alpha"),
        ToolSpec(name="plugin_beta", brief_description="beta"),
    ]
    runtime.update_deferred_tool_specs(specs)
    runtime.discover_deferred_tools(["plugin_alpha"])
    reminder = runtime.build_deferred_tools_reminder()
    assert "plugin_alpha" not in reminder
    assert "1. plugin_beta" in reminder
    assert "<system-reminder>" in reminder
    assert "tool_search" in reminder
def test_search_and_discover_deferred_tools() -> None:
    """Searching deferred specs should surface matches and mark them discovered."""
    runtime = _build_runtime_stub()
    runtime.update_deferred_tool_specs(
        [
            ToolSpec(name="mcp__slack__send_message", brief_description="向 Slack 发送消息"),
            ToolSpec(name="mcp__github__create_issue", brief_description="在 GitHub 创建 Issue"),
        ]
    )
    matches = runtime.search_deferred_tool_specs("slack send", limit=5)
    matched_names = [spec.name for spec in matches]
    newly_discovered = runtime.discover_deferred_tools(matched_names)
    discovered_names = [spec.name for spec in runtime.get_discovered_deferred_tool_specs()]
    assert matched_names == ["mcp__slack__send_message"]
    assert newly_discovered == ["mcp__slack__send_message"]
    assert discovered_names == ["mcp__slack__send_message"]
def test_build_request_messages_appends_injected_user_message() -> None:
    """Injected reminder text should be appended as a trailing user message."""
    service = MaisakaChatLoopService(chat_system_prompt="system prompt")
    reminder = "<system-reminder>\n1. plugin_beta\n</system-reminder>"
    messages = service._build_request_messages([], injected_user_messages=[reminder])
    assert len(messages) == 2
    assert messages[0].role == RoleType.System
    assert messages[1].role == RoleType.User
    assert messages[1].content == reminder

View File

@@ -0,0 +1,10 @@
from src.maisaka.reasoning_engine import MaisakaReasoningEngine
def test_continue_action_closes_timing_gate_for_following_rounds() -> None:
    """A ``continue`` action keeps the timing gate closed for later rounds."""
    gate_required = MaisakaReasoningEngine._mark_timing_gate_completed("continue")
    assert gate_required is False
def test_non_continue_actions_require_next_timing_gate() -> None:
    """Any action other than ``continue`` re-arms the timing gate."""
    for action in ("wait", "no_reply"):
        assert MaisakaReasoningEngine._mark_timing_gate_completed(action) is True

View File

@@ -0,0 +1,21 @@
from src.maisaka.builtin_tool import get_action_tool_specs, get_timing_tool_specs
def test_wait_tool_available_in_timing_stage() -> None:
    """The timing-stage tool set must expose ``wait``."""
    timing_tool_names = {spec.name for spec in get_timing_tool_specs()}
    assert "wait" in timing_tool_names
def test_wait_tool_not_available_in_action_stage() -> None:
    """The action stage exposes ``finish``/``tool_search`` but never ``wait``."""
    action_tool_names = {spec.name for spec in get_action_tool_specs()}
    assert "wait" not in action_tool_names
    assert "finish" in action_tool_names
    assert "tool_search" in action_tool_names
def test_tool_search_not_available_in_timing_stage() -> None:
    """``tool_search`` must not leak into the timing-stage tool set."""
    timing_tool_names = {spec.name for spec in get_timing_tool_specs()}
    assert "tool_search" not in timing_tool_names