Merge branch 'r-dev' of https://github.com/A-Dawn/MaiBot into r-dev
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,9 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
|
||||
@@ -30,7 +31,7 @@ class MaisakaExpressionSelectionResult:
|
||||
|
||||
|
||||
class MaisakaExpressionSelector:
|
||||
"""负责在 replyer 侧完成表达方式筛选与子代理选择。"""
|
||||
"""负责在 replyer 侧完成表达方式筛选与子代理二次选择。"""
|
||||
|
||||
def _can_use_expressions(self, session_id: str) -> bool:
|
||||
try:
|
||||
@@ -40,18 +41,34 @@ class MaisakaExpressionSelector:
|
||||
logger.error(f"检查表达方式使用开关失败: {exc}")
|
||||
return False
|
||||
|
||||
def _get_related_session_ids(self, session_id: str) -> List[str]:
|
||||
def _can_use_advanced_chosen(self, session_id: str) -> bool:
|
||||
try:
|
||||
return ExpressionConfigUtils.get_expression_advanced_chosen_for_chat(session_id)
|
||||
except Exception as exc:
|
||||
logger.error(f"检查表达方式二次选择开关失败: {exc}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_global_expression_group_marker(platform: str, item_id: str) -> bool:
|
||||
return platform == "*" and item_id == "*"
|
||||
|
||||
def _resolve_expression_group_scope(self, session_id: str) -> tuple[set[str], bool]:
|
||||
related_session_ids = {session_id}
|
||||
has_global_share = False
|
||||
expression_groups = global_config.expression.expression_groups
|
||||
|
||||
for expression_group in expression_groups:
|
||||
target_items = expression_group.expression_groups
|
||||
group_session_ids: set[str] = set()
|
||||
contains_current_session = False
|
||||
contains_global_share_marker = False
|
||||
|
||||
for target_item in target_items:
|
||||
platform = target_item.platform.strip()
|
||||
item_id = target_item.item_id.strip()
|
||||
if self._is_global_expression_group_marker(platform, item_id):
|
||||
contains_global_share_marker = True
|
||||
continue
|
||||
if not platform or not item_id:
|
||||
continue
|
||||
|
||||
@@ -65,19 +82,24 @@ class MaisakaExpressionSelector:
|
||||
if target_session_id == session_id:
|
||||
contains_current_session = True
|
||||
|
||||
if contains_global_share_marker:
|
||||
has_global_share = True
|
||||
if contains_current_session:
|
||||
related_session_ids.update(group_session_ids)
|
||||
|
||||
return list(related_session_ids)
|
||||
return related_session_ids, has_global_share
|
||||
|
||||
def _load_expression_candidates(self, session_id: str) -> List[dict[str, Any]]:
|
||||
related_session_ids = self._get_related_session_ids(session_id)
|
||||
related_session_ids, has_global_share = self._resolve_expression_group_scope(session_id)
|
||||
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
scoped_query = base_query.where(
|
||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
)
|
||||
if has_global_share:
|
||||
scoped_query = base_query
|
||||
else:
|
||||
scoped_query = base_query.where(
|
||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
)
|
||||
if global_config.expression.expression_checked_only:
|
||||
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
expressions = session.exec(scoped_query).all()
|
||||
@@ -87,7 +109,7 @@ class MaisakaExpressionSelector:
|
||||
"id": expression.id,
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
"count": expression.count if getattr(expression, "count", None) is not None else 1,
|
||||
"count": expression.count if expression.count is not None else 1,
|
||||
}
|
||||
for expression in expressions
|
||||
if expression.id is not None and expression.situation and expression.style
|
||||
@@ -171,7 +193,7 @@ class MaisakaExpressionSelector:
|
||||
"你只负责根据最近聊天上下文,为这一次可见回复挑选最合适的表达方式。\n"
|
||||
"请只从下面候选中选择 0 到 3 条最适合当前语境的表达方式。\n"
|
||||
"优先考虑自然、贴合上下文、不生硬、不模板化。\n"
|
||||
"如果没有明显合适的,就返回空列表。\n"
|
||||
"如果没有明显合适的,就返回空数组。\n"
|
||||
'严格只输出 JSON,对象格式为 {"selected_ids":[123,456]}。\n\n'
|
||||
f"最近上下文:\n{history_block}\n\n"
|
||||
f"目标消息:{target_text or '无'}\n"
|
||||
@@ -208,6 +230,32 @@ class MaisakaExpressionSelector:
|
||||
break
|
||||
return selected_ids
|
||||
|
||||
def _build_direct_selection_result(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
candidates: List[dict[str, Any]],
|
||||
) -> MaisakaExpressionSelectionResult:
|
||||
selected_ids = [
|
||||
candidate["id"]
|
||||
for candidate in candidates
|
||||
if isinstance(candidate.get("id"), int)
|
||||
]
|
||||
selected_expressions = [
|
||||
candidate
|
||||
for candidate in candidates
|
||||
if candidate.get("id") in selected_ids
|
||||
]
|
||||
self._update_last_active_time(selected_ids)
|
||||
logger.info(
|
||||
f"表达方式直接注入:session_id={session_id} 已选数={len(selected_ids)} "
|
||||
f"selected_ids={selected_ids!r} 已选预览={self._format_candidate_preview(selected_expressions)}"
|
||||
)
|
||||
return MaisakaExpressionSelectionResult(
|
||||
expression_habits=self._build_expression_habits_block(selected_expressions),
|
||||
selected_expression_ids=selected_ids,
|
||||
)
|
||||
|
||||
def _update_last_active_time(self, selected_ids: List[int]) -> None:
|
||||
if not selected_ids:
|
||||
return
|
||||
@@ -233,15 +281,22 @@ class MaisakaExpressionSelector:
|
||||
if not self._can_use_expressions(session_id):
|
||||
logger.info(f"表达方式选择已跳过:当前会话未启用表达方式,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
if sub_agent_runner is None:
|
||||
logger.info(f"表达方式选择已跳过:缺少 sub_agent_runner,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
candidates = self._load_expression_candidates(session_id)
|
||||
if not candidates:
|
||||
logger.info(f"表达方式选择已跳过:本地候选不足,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
if not self._can_use_advanced_chosen(session_id):
|
||||
return self._build_direct_selection_result(
|
||||
session_id=session_id,
|
||||
candidates=candidates,
|
||||
)
|
||||
|
||||
if sub_agent_runner is None:
|
||||
logger.info(f"表达方式选择已跳过:缺少 sub_agent_runner,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
logger.info(
|
||||
f"表达方式选择开始:session_id={session_id} 候选数={len(candidates)} "
|
||||
f"候选预览={self._format_candidate_preview(candidates)}"
|
||||
@@ -259,10 +314,9 @@ class MaisakaExpressionSelector:
|
||||
logger.exception("表达方式选择子代理执行失败")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
logger.info(f"表达方式子代理原始结果:session_id={session_id} response={raw_response!r}")
|
||||
selected_ids = self._parse_selected_ids(raw_response, candidates)
|
||||
if not selected_ids:
|
||||
logger.info(f"表达方式选择完成但未命中:session_id={session_id}")
|
||||
logger.info(f"表达方式选择完成但未命中,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
selected_expressions = [candidate for candidate in candidates if candidate.get("id") in selected_ids]
|
||||
|
||||
@@ -1,440 +1,29 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
from rich.panel import Panel
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.cli.console import console
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
GenerationMetrics,
|
||||
LLMCompletionResult,
|
||||
ReplyGenerationResult,
|
||||
build_reply_monitor_detail,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from src.maisaka.context_messages import (
|
||||
AssistantMessage,
|
||||
LLMContextMessage,
|
||||
ReferenceMessage,
|
||||
SessionBackedMessage,
|
||||
ToolResultMessage,
|
||||
)
|
||||
from src.maisaka.message_adapter import parse_speaker_content
|
||||
from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer
|
||||
|
||||
from .maisaka_expression_selector import maisaka_expression_selector
|
||||
|
||||
logger = get_logger("replyer")
|
||||
from .maisaka_generator_base import BaseMaisakaReplyGenerator
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaisakaReplyContext:
|
||||
"""Maisaka replyer 使用的回复上下文。"""
|
||||
|
||||
expression_habits: str = ""
|
||||
selected_expression_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class MaisakaReplyGenerator:
|
||||
"""生成 Maisaka 的最终可见回复。"""
|
||||
class MaisakaReplyGenerator(BaseMaisakaReplyGenerator):
|
||||
"""Maisaka replyer。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
request_type: str = "maisaka_replyer",
|
||||
llm_client_cls: Optional[Any] = None,
|
||||
load_prompt_func: Optional[Callable[..., str]] = None,
|
||||
enable_visual_message: Optional[bool] = None,
|
||||
) -> None:
|
||||
self.chat_stream = chat_stream
|
||||
self.request_type = request_type
|
||||
self.express_model = LLMServiceClient(
|
||||
task_name="replyer",
|
||||
super().__init__(
|
||||
chat_stream=chat_stream,
|
||||
request_type=request_type,
|
||||
llm_client_cls=llm_client_cls or LLMServiceClient,
|
||||
load_prompt_func=load_prompt_func or load_prompt,
|
||||
enable_visual_message=enable_visual_message,
|
||||
replyer_mode=global_config.visual.replyer_mode,
|
||||
)
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构建 replyer 使用的人设提示。"""
|
||||
try:
|
||||
bot_name = global_config.bot.nickname
|
||||
alias_names = global_config.bot.alias_names
|
||||
bot_aliases = f",也有人叫你{','.join(alias_names)}" if alias_names else ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
if (
|
||||
hasattr(global_config.personality, "states")
|
||||
and global_config.personality.states
|
||||
and hasattr(global_config.personality, "state_probability")
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
return f"你的名字是{bot_name}{bot_aliases},你{prompt_personality};"
|
||||
except Exception as exc:
|
||||
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
|
||||
return "你的名字是麦麦,你是一个活泼可爱的 AI 助手。"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_content(content: str, limit: int = 500) -> str:
|
||||
normalized = " ".join((content or "").split())
|
||||
if len(normalized) > limit:
|
||||
return normalized[:limit] + "..."
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _format_message_time(message: LLMContextMessage) -> str:
|
||||
return message.timestamp.strftime("%H:%M:%S")
|
||||
|
||||
@staticmethod
|
||||
def _extract_visible_assistant_reply(message: AssistantMessage) -> str:
|
||||
del message
|
||||
return ""
|
||||
|
||||
def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
|
||||
speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
if speaker_name == bot_nickname:
|
||||
return self._normalize_content(body.strip())
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
|
||||
"""按说话人拆分用户消息。"""
|
||||
segments: List[tuple[Optional[str], str]] = []
|
||||
current_speaker: Optional[str] = None
|
||||
current_lines: List[str] = []
|
||||
|
||||
for raw_line in raw_content.splitlines():
|
||||
speaker_name, content_body = parse_speaker_content(raw_line)
|
||||
if speaker_name is not None:
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
current_speaker = speaker_name
|
||||
current_lines = [content_body]
|
||||
continue
|
||||
|
||||
current_lines.append(raw_line)
|
||||
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
|
||||
return segments
|
||||
|
||||
def _format_chat_history(self, messages: List[LLMContextMessage]) -> str:
|
||||
"""格式化 replyer 使用的可见聊天记录。"""
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
parts: List[str] = []
|
||||
|
||||
for message in messages:
|
||||
timestamp = self._format_message_time(message)
|
||||
|
||||
if isinstance(message, (ReferenceMessage, ToolResultMessage)):
|
||||
continue
|
||||
|
||||
if isinstance(message, SessionBackedMessage):
|
||||
guided_reply = self._extract_guided_bot_reply(message)
|
||||
if guided_reply:
|
||||
parts.append(f"{timestamp} {bot_nickname}(you): {guided_reply}")
|
||||
continue
|
||||
|
||||
raw_content = message.processed_plain_text
|
||||
for speaker_name, content_body in self._split_user_message_segments(raw_content):
|
||||
content = self._normalize_content(content_body)
|
||||
if not content:
|
||||
continue
|
||||
visible_speaker = speaker_name or global_config.maisaka.cli_user_name.strip() or "User"
|
||||
parts.append(f"{timestamp} {visible_speaker}: {content}")
|
||||
continue
|
||||
|
||||
if isinstance(message, AssistantMessage):
|
||||
visible_reply = self._extract_visible_assistant_reply(message)
|
||||
if visible_reply:
|
||||
parts.append(f"{timestamp} {bot_nickname}(you): {visible_reply}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
|
||||
"""构建当前需要回复的目标消息摘要。"""
|
||||
if reply_message is None:
|
||||
return ""
|
||||
|
||||
user_info = reply_message.message_info.user_info
|
||||
sender_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
|
||||
target_message_id = reply_message.message_id.strip() if reply_message.message_id else "未知"
|
||||
target_content = self._normalize_content((reply_message.processed_plain_text or "").strip(), limit=300)
|
||||
if not target_content:
|
||||
target_content = "[无可见文本内容]"
|
||||
|
||||
return (
|
||||
"【本次回复目标】\n"
|
||||
f"- 目标消息ID:{target_message_id}\n"
|
||||
f"- 发送者:{sender_name}\n"
|
||||
f"- 消息内容:{target_content}\n"
|
||||
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要误把其他历史消息当成当前回复对象。"
|
||||
)
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> str:
|
||||
"""构建 Maisaka replyer 提示词。"""
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
formatted_history = self._format_chat_history(chat_history)
|
||||
target_message_block = self._build_target_message_block(reply_message)
|
||||
|
||||
try:
|
||||
system_prompt = load_prompt(
|
||||
"maisaka_replyer",
|
||||
bot_name=global_config.bot.nickname,
|
||||
time_block=f"当前时间:{current_time}",
|
||||
identity=self._personality_prompt,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
)
|
||||
except Exception:
|
||||
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
|
||||
|
||||
extra_sections: List[str] = []
|
||||
if expression_habits.strip():
|
||||
extra_sections.append(expression_habits.strip())
|
||||
|
||||
user_sections = [
|
||||
f"当前时间:{current_time}",
|
||||
f"【聊天记录】\n{formatted_history}",
|
||||
]
|
||||
if target_message_block:
|
||||
user_sections.append(target_message_block)
|
||||
if extra_sections:
|
||||
user_sections.append("\n\n".join(extra_sections))
|
||||
user_sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
user_sections.append("现在,你说:")
|
||||
|
||||
user_prompt = "\n\n".join(user_sections)
|
||||
return f"System: {system_prompt}\n\nUser: {user_prompt}"
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
"""解析当前回复使用的会话 ID。"""
|
||||
if stream_id:
|
||||
return stream_id
|
||||
if self.chat_stream is not None:
|
||||
return self.chat_stream.session_id
|
||||
return ""
|
||||
|
||||
async def _build_reply_context(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
stream_id: Optional[str],
|
||||
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]],
|
||||
) -> MaisakaReplyContext:
|
||||
"""构建回复上下文:表达习惯和已选表达 ID。"""
|
||||
session_id = self._resolve_session_id(stream_id)
|
||||
if not session_id:
|
||||
logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
if sub_agent_runner is None:
|
||||
logger.info("表达方式选择跳过:缺少子代理执行器")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
selection_result = await maisaka_expression_selector.select_for_reply(
|
||||
session_id=session_id,
|
||||
chat_history=chat_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
sub_agent_runner=sub_agent_runner,
|
||||
)
|
||||
return MaisakaReplyContext(
|
||||
expression_habits=selection_result.expression_habits,
|
||||
selected_expression_ids=selection_result.selected_expression_ids,
|
||||
)
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List[object]] = None,
|
||||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[SessionMessage] = None,
|
||||
reply_time_point: Optional[float] = None,
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
log_reply: bool = True,
|
||||
chat_history: Optional[List[LLMContextMessage]] = None,
|
||||
expression_habits: str = "",
|
||||
selected_expression_ids: Optional[List[int]] = None,
|
||||
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]] = None,
|
||||
) -> Tuple[bool, ReplyGenerationResult]:
|
||||
"""结合上下文生成 Maisaka 的最终可见回复。"""
|
||||
|
||||
def finalize(success_value: bool) -> Tuple[bool, ReplyGenerationResult]:
|
||||
result.monitor_detail = build_reply_monitor_detail(result)
|
||||
return success_value, result
|
||||
|
||||
del available_actions
|
||||
del chosen_actions
|
||||
del extra_info
|
||||
del from_plugin
|
||||
del log_reply
|
||||
del reply_time_point
|
||||
del think_level
|
||||
del unknown_words
|
||||
|
||||
result = ReplyGenerationResult()
|
||||
overall_started_at = time.perf_counter()
|
||||
if chat_history is None:
|
||||
result.error_message = "聊天历史为空"
|
||||
return finalize(False)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} "
|
||||
f"历史消息数={len(chat_history)} 目标消息编号={reply_message.message_id if reply_message else None}"
|
||||
)
|
||||
|
||||
filtered_history = [
|
||||
message
|
||||
for message in chat_history
|
||||
if not isinstance(message, (ReferenceMessage, ToolResultMessage))
|
||||
]
|
||||
logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}")
|
||||
|
||||
if self.express_model is None:
|
||||
logger.error("Maisaka 回复器的回复模型未初始化")
|
||||
result.error_message = "回复模型尚未初始化"
|
||||
return finalize(False)
|
||||
|
||||
try:
|
||||
reply_context = await self._build_reply_context(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
stream_id=stream_id,
|
||||
sub_agent_runner=sub_agent_runner,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建回复上下文失败: {exc}"
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
|
||||
)
|
||||
return finalize(False)
|
||||
|
||||
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
|
||||
result.selected_expression_ids = (
|
||||
list(selected_expression_ids)
|
||||
if selected_expression_ids is not None
|
||||
else list(reply_context.selected_expression_ids)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
|
||||
prompt_started_at = time.perf_counter()
|
||||
try:
|
||||
prompt = self._build_prompt(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建提示词失败: {exc}"
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
|
||||
)
|
||||
return finalize(False)
|
||||
|
||||
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
|
||||
result.completion.request_prompt = prompt
|
||||
show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
|
||||
show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))
|
||||
preview_chat_id = self._resolve_session_id(stream_id) or "unknown"
|
||||
|
||||
if show_replyer_prompt:
|
||||
console.print(
|
||||
Panel(
|
||||
PromptCLIVisualizer.build_text_access_panel(
|
||||
prompt,
|
||||
category="replyer",
|
||||
chat_id=preview_chat_id,
|
||||
request_kind="replyer",
|
||||
subtitle=f"流ID: {preview_chat_id}",
|
||||
),
|
||||
title="Maisaka 回复器 Prompt",
|
||||
border_style="bright_yellow",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
llm_started_at = time.perf_counter()
|
||||
try:
|
||||
generation_result = await self.express_model.generate_response(prompt)
|
||||
except Exception as exc:
|
||||
logger.exception("Maisaka 回复器调用失败")
|
||||
result.error_message = str(exc)
|
||||
result.metrics = GenerationMetrics(
|
||||
prompt_ms=prompt_ms,
|
||||
llm_ms=round((time.perf_counter() - llm_started_at) * 1000, 2),
|
||||
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
|
||||
)
|
||||
return finalize(False)
|
||||
|
||||
llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
|
||||
response_text = (generation_result.response or "").strip()
|
||||
result.success = bool(response_text)
|
||||
result.completion = LLMCompletionResult(
|
||||
request_prompt=prompt,
|
||||
response_text=response_text,
|
||||
reasoning_text=generation_result.reasoning or "",
|
||||
model_name=generation_result.model_name or "",
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
prompt_tokens=generation_result.prompt_tokens,
|
||||
completion_tokens=generation_result.completion_tokens,
|
||||
total_tokens=generation_result.total_tokens,
|
||||
)
|
||||
result.metrics = GenerationMetrics(
|
||||
prompt_ms=prompt_ms,
|
||||
llm_ms=llm_ms,
|
||||
overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2),
|
||||
stage_logs=[
|
||||
f"prompt: {prompt_ms} ms",
|
||||
f"llm: {llm_ms} ms",
|
||||
],
|
||||
)
|
||||
|
||||
if show_replyer_reasoning and result.completion.reasoning_text:
|
||||
logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}")
|
||||
|
||||
if not result.success:
|
||||
result.error_message = "回复器返回了空内容"
|
||||
logger.warning("Maisaka 回复器返回了空内容")
|
||||
return finalize(False)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器生成成功: 回复文本={response_text!r} "
|
||||
f"总耗时毫秒={result.metrics.overall_ms} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
result.text_fragments = [response_text]
|
||||
return finalize(True)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import random
|
||||
|
||||
from rich.console import Group, RenderableType
|
||||
from rich.panel import Panel
|
||||
@@ -10,6 +11,7 @@ from rich.text import Text
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.cli.console import console
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
@@ -19,18 +21,11 @@ from src.common.data_models.reply_generation_data_models import (
|
||||
build_reply_monitor_detail,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.config.model_configs import ModelInfo
|
||||
from src.core.types import ActionInfo
|
||||
from src.llm_models.payload_content.message import (
|
||||
ImageMessagePart,
|
||||
Message,
|
||||
MessageBuilder,
|
||||
RoleType,
|
||||
TextMessagePart,
|
||||
)
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.maisaka.context_messages import (
|
||||
AssistantMessage,
|
||||
LLMContextMessage,
|
||||
@@ -38,8 +33,9 @@ from src.maisaka.context_messages import (
|
||||
SessionBackedMessage,
|
||||
ToolResultMessage,
|
||||
)
|
||||
from src.maisaka.display.prompt_cli_renderer import PromptCLIVisualizer
|
||||
from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content
|
||||
from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer
|
||||
from src.plugin_runtime.hook_payloads import serialize_prompt_messages
|
||||
|
||||
from .maisaka_expression_selector import maisaka_expression_selector
|
||||
|
||||
@@ -54,17 +50,26 @@ class MaisakaReplyContext:
|
||||
selected_expression_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class MaisakaReplyGenerator:
|
||||
"""生成 Maisaka 的最终可见回复(多模态管线)。"""
|
||||
class BaseMaisakaReplyGenerator:
|
||||
"""Maisaka replyer 的共享实现。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
request_type: str = "maisaka_replyer",
|
||||
llm_client_cls: Any,
|
||||
load_prompt_func: Callable[..., str],
|
||||
enable_visual_message: Optional[bool],
|
||||
replyer_mode: Literal["text", "multimodal", "auto"],
|
||||
) -> None:
|
||||
self.chat_stream = chat_stream
|
||||
self.request_type = request_type
|
||||
self.express_model = LLMServiceClient(
|
||||
self._llm_client_cls = llm_client_cls
|
||||
self._load_prompt = load_prompt_func
|
||||
self._enable_visual_message = enable_visual_message
|
||||
self._replyer_mode = replyer_mode
|
||||
self.express_model = llm_client_cls(
|
||||
task_name="replyer",
|
||||
request_type=request_type,
|
||||
)
|
||||
@@ -111,28 +116,6 @@ class MaisakaReplyGenerator:
|
||||
return self._normalize_content(body.strip())
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
|
||||
segments: List[tuple[Optional[str], str]] = []
|
||||
current_speaker: Optional[str] = None
|
||||
current_lines: List[str] = []
|
||||
|
||||
for raw_line in raw_content.splitlines():
|
||||
speaker_name, content_body = parse_speaker_content(raw_line)
|
||||
if speaker_name is not None:
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
current_speaker = speaker_name
|
||||
current_lines = [content_body]
|
||||
continue
|
||||
|
||||
current_lines.append(raw_line)
|
||||
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
|
||||
return segments
|
||||
|
||||
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
|
||||
if reply_message is None:
|
||||
return ""
|
||||
@@ -152,19 +135,93 @@ class MaisakaReplyGenerator:
|
||||
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要把其他历史消息当成当前回复对象。"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
|
||||
"""根据聊天流 ID 获取匹配的额外 prompt。"""
|
||||
if not global_config.chat.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_item in global_config.chat.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "platform"):
|
||||
platform = str(chat_prompt_item.platform or "").strip()
|
||||
item_id = str(chat_prompt_item.item_id or "").strip()
|
||||
rule_type = str(chat_prompt_item.rule_type or "").strip()
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
elif isinstance(chat_prompt_item, str):
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
platform, item_id, rule_type, prompt_content = parts
|
||||
platform = platform.strip()
|
||||
item_id = item_id.strip()
|
||||
rule_type = rule_type.strip()
|
||||
prompt_content = prompt_content.strip()
|
||||
else:
|
||||
continue
|
||||
|
||||
if not platform or not item_id or not prompt_content:
|
||||
continue
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
|
||||
else:
|
||||
continue
|
||||
|
||||
if config_is_group != is_group_chat:
|
||||
continue
|
||||
if config_chat_id == chat_id:
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
|
||||
def _build_group_chat_attention_block(self, session_id: str) -> str:
|
||||
"""构建当前聊天场景下的额外注意事项块。"""
|
||||
if not session_id:
|
||||
return ""
|
||||
|
||||
try:
|
||||
is_group_chat, _ = get_chat_type_and_target_info(session_id)
|
||||
except Exception:
|
||||
is_group_chat = None
|
||||
|
||||
prompt_lines: List[str] = []
|
||||
|
||||
if is_group_chat is True:
|
||||
if group_chat_prompt := global_config.chat.group_chat_prompt.strip():
|
||||
prompt_lines.append(f"通用注意事项:\n{group_chat_prompt}")
|
||||
elif is_group_chat is False:
|
||||
if private_chat_prompt := global_config.chat.private_chat_prompts.strip():
|
||||
prompt_lines.append(f"通用注意事项:\n{private_chat_prompt}")
|
||||
|
||||
if chat_prompt := self._get_chat_prompt_for_chat(session_id, is_group_chat).strip():
|
||||
prompt_lines.append(f"当前聊天额外注意事项:\n{chat_prompt}")
|
||||
|
||||
if not prompt_lines:
|
||||
return ""
|
||||
|
||||
return "在该聊天中的注意事项:\n" + "\n\n".join(prompt_lines) + "\n"
|
||||
|
||||
def _build_system_prompt(
|
||||
self,
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
stream_id: Optional[str] = None,
|
||||
) -> str:
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
target_message_block = self._build_target_message_block(reply_message)
|
||||
session_id = self._resolve_session_id(stream_id)
|
||||
|
||||
try:
|
||||
system_prompt = load_prompt(
|
||||
system_prompt = self._load_prompt(
|
||||
"maisaka_replyer",
|
||||
bot_name=global_config.bot.nickname,
|
||||
group_chat_attention_block=self._build_group_chat_attention_block(session_id),
|
||||
time_block=f"当前时间:{current_time}",
|
||||
identity=self._personality_prompt,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
@@ -184,38 +241,35 @@ class MaisakaReplyGenerator:
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(sections)
|
||||
|
||||
def _build_reply_instruction(self) -> str:
|
||||
return "请自然地回复。不要输出多余说明、括号、at 或额外标记,只输出实际要发送的内容。"
|
||||
return "请自然地回复。不要输出多余说明、括号、@ 或额外标记,只输出实际要发送的内容。"
|
||||
|
||||
def _build_multimodal_user_message(
|
||||
def _build_visual_user_message(
|
||||
self,
|
||||
message: SessionBackedMessage,
|
||||
default_user_name: str,
|
||||
enable_visual_message: bool,
|
||||
) -> Optional[Message]:
|
||||
speaker_name, _ = parse_speaker_content(message.processed_plain_text.strip())
|
||||
visible_speaker = speaker_name or default_user_name
|
||||
if not enable_visual_message:
|
||||
return None
|
||||
|
||||
raw_message = clone_message_sequence(message.raw_message)
|
||||
if not raw_message.components:
|
||||
raw_message = MessageSequence([TextComponent(f"[{visible_speaker}]")])
|
||||
elif isinstance(raw_message.components[0], TextComponent):
|
||||
first_text = raw_message.components[0].text or ""
|
||||
raw_message.components[0] = TextComponent(f"[{visible_speaker}]{first_text}")
|
||||
else:
|
||||
raw_message.components.insert(0, TextComponent(f"[{visible_speaker}]"))
|
||||
raw_message = MessageSequence([TextComponent(message.processed_plain_text)])
|
||||
|
||||
multimodal_message = SessionBackedMessage(
|
||||
visual_message = SessionBackedMessage(
|
||||
raw_message=raw_message,
|
||||
visible_text=f"[{visible_speaker}]{message.processed_plain_text}",
|
||||
visible_text=message.processed_plain_text,
|
||||
timestamp=message.timestamp,
|
||||
message_id=message.message_id,
|
||||
original_message=message.original_message,
|
||||
source_kind=message.source_kind,
|
||||
)
|
||||
return multimodal_message.to_llm_message()
|
||||
return visual_message.to_llm_message()
|
||||
|
||||
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
default_user_name = global_config.maisaka.cli_user_name.strip() or "User"
|
||||
def _build_history_messages(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
enable_visual_message: bool,
|
||||
) -> List[Message]:
|
||||
messages: List[Message] = []
|
||||
|
||||
for message in chat_history:
|
||||
@@ -230,25 +284,14 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
continue
|
||||
|
||||
multimodal_message = self._build_multimodal_user_message(message, default_user_name)
|
||||
if multimodal_message is not None:
|
||||
messages.append(multimodal_message)
|
||||
visual_message = self._build_visual_user_message(message, enable_visual_message)
|
||||
if visual_message is not None:
|
||||
messages.append(visual_message)
|
||||
continue
|
||||
|
||||
for speaker_name, content_body in self._split_user_message_segments(message.processed_plain_text):
|
||||
content = self._normalize_content(content_body)
|
||||
if not content:
|
||||
continue
|
||||
|
||||
visible_speaker = speaker_name or default_user_name
|
||||
if visible_speaker == bot_nickname:
|
||||
messages.append(
|
||||
MessageBuilder().set_role(RoleType.Assistant).add_text_content(content).build()
|
||||
)
|
||||
continue
|
||||
|
||||
user_content = f"[{visible_speaker}]{content}"
|
||||
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(user_content).build())
|
||||
llm_message = message.to_llm_message()
|
||||
if llm_message is not None:
|
||||
messages.append(llm_message)
|
||||
continue
|
||||
|
||||
if isinstance(message, AssistantMessage):
|
||||
@@ -266,34 +309,33 @@ class MaisakaReplyGenerator:
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
stream_id: Optional[str] = None,
|
||||
enable_visual_message: bool = False,
|
||||
) -> List[Message]:
|
||||
messages: List[Message] = []
|
||||
system_prompt = self._build_system_prompt(
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
expression_habits=expression_habits,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
instruction = self._build_reply_instruction()
|
||||
|
||||
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
|
||||
messages.extend(self._build_history_messages(chat_history))
|
||||
messages.extend(self._build_history_messages(chat_history, enable_visual_message))
|
||||
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _build_request_prompt_preview(messages: List[Message]) -> str:
|
||||
preview_lines: List[str] = []
|
||||
for message in messages:
|
||||
role_name = message.role.value.capitalize()
|
||||
part_previews: List[str] = []
|
||||
for part in message.parts:
|
||||
if isinstance(part, TextMessagePart):
|
||||
part_previews.append(part.text)
|
||||
continue
|
||||
if isinstance(part, ImageMessagePart):
|
||||
part_previews.append(f"[图片:{part.normalized_image_format}]")
|
||||
preview_lines.append(f"{role_name}: {''.join(part_previews)}")
|
||||
return "\n\n".join(preview_lines)
|
||||
def _resolve_enable_visual_message(self, model_info: Optional[ModelInfo] = None) -> bool:
|
||||
if self._enable_visual_message is not None:
|
||||
return self._enable_visual_message
|
||||
if self._replyer_mode == "multimodal":
|
||||
if model_info is not None and not model_info.visual:
|
||||
raise ValueError(f"replyer_mode=multimodal,但模型 '{model_info.name}' 未开启 visual,无法使用多模态 replyer")
|
||||
return True
|
||||
if self._replyer_mode == "text":
|
||||
return False
|
||||
return bool(model_info.visual) if model_info is not None else False
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
if stream_id:
|
||||
@@ -349,7 +391,6 @@ class MaisakaReplyGenerator:
|
||||
selected_expression_ids: Optional[List[int]] = None,
|
||||
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]] = None,
|
||||
) -> Tuple[bool, ReplyGenerationResult]:
|
||||
|
||||
def finalize(success_value: bool) -> Tuple[bool, ReplyGenerationResult]:
|
||||
result.monitor_detail = build_reply_monitor_detail(result)
|
||||
return success_value, result
|
||||
@@ -411,7 +452,7 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"回复上下文完成: 流={stream_id} 已选表达={result.selected_expression_ids!r}"
|
||||
f"回复上下文完成 流={stream_id} 已选表达={result.selected_expression_ids!r}"
|
||||
)
|
||||
|
||||
prompt_started_at = time.perf_counter()
|
||||
@@ -421,6 +462,7 @@ class MaisakaReplyGenerator:
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
@@ -433,24 +475,36 @@ class MaisakaReplyGenerator:
|
||||
return finalize(False)
|
||||
|
||||
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
|
||||
prompt_preview = self._build_request_prompt_preview(request_messages)
|
||||
prompt_preview = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
|
||||
show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False))
|
||||
show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False))
|
||||
|
||||
def message_factory(_client: object) -> List[Message]:
|
||||
def message_factory(_client: object, model_info: Optional[ModelInfo] = None) -> List[Message]:
|
||||
nonlocal prompt_ms, prompt_preview, request_messages
|
||||
prompt_started_at = time.perf_counter()
|
||||
request_messages = self._build_request_messages(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
stream_id=stream_id,
|
||||
enable_visual_message=self._resolve_enable_visual_message(model_info),
|
||||
)
|
||||
prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2)
|
||||
prompt_preview = PromptCLIVisualizer._build_prompt_dump_text(request_messages)
|
||||
return request_messages
|
||||
|
||||
result.completion.request_prompt = prompt_preview
|
||||
preview_chat_id = self._resolve_session_id(stream_id)
|
||||
replyer_prompt_section: RenderableType | None = None
|
||||
if show_replyer_prompt:
|
||||
replyer_prompt_section = Panel(
|
||||
PromptCLIVisualizer.build_text_access_panel(
|
||||
prompt_preview,
|
||||
PromptCLIVisualizer.build_prompt_access_panel(
|
||||
request_messages,
|
||||
category="replyer",
|
||||
chat_id=preview_chat_id,
|
||||
request_kind="replyer",
|
||||
subtitle=f"流ID: {preview_chat_id}",
|
||||
selection_reason=f"ID: {preview_chat_id}",
|
||||
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
|
||||
),
|
||||
title="Reply Prompt",
|
||||
border_style="bright_yellow",
|
||||
@@ -472,6 +526,8 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
return finalize(False)
|
||||
|
||||
result.completion.request_prompt = prompt_preview
|
||||
result.request_messages = serialize_prompt_messages(request_messages)
|
||||
llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
|
||||
response_text = (generation_result.response or "").strip()
|
||||
result.success = bool(response_text)
|
||||
@@ -504,7 +560,7 @@ class MaisakaReplyGenerator:
|
||||
return finalize(False)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器生成成功: 文本={response_text!r} "
|
||||
f"Maisaka 回复器生成成功 文本={response_text!r} "
|
||||
f"总耗时ms={result.metrics.overall_ms} 已选表达={result.selected_expression_ids!r}"
|
||||
)
|
||||
if show_replyer_prompt or show_replyer_reasoning:
|
||||
@@ -1,21 +0,0 @@
|
||||
from typing import Type
|
||||
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def get_maisaka_replyer_class() -> Type[object]:
|
||||
"""根据配置返回 Maisaka replyer 类。"""
|
||||
generator_type = get_maisaka_replyer_generator_type()
|
||||
if generator_type == "multimodal":
|
||||
from .maisaka_generator_multi import MaisakaReplyGenerator
|
||||
|
||||
return MaisakaReplyGenerator
|
||||
|
||||
from .maisaka_generator import MaisakaReplyGenerator
|
||||
|
||||
return MaisakaReplyGenerator
|
||||
|
||||
|
||||
def get_maisaka_replyer_generator_type() -> str:
|
||||
"""返回当前配置的 Maisaka replyer 生成器类型。"""
|
||||
return "multimodal" if global_config.visual.multimodal_replyer else "legacy"
|
||||
@@ -1,15 +1,10 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.replyer.maisaka_replyer_factory import (
|
||||
get_maisaka_replyer_class,
|
||||
get_maisaka_replyer_generator_type,
|
||||
)
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
from .maisaka_generator import MaisakaReplyGenerator
|
||||
|
||||
logger = get_logger("ReplyerManager")
|
||||
|
||||
@@ -20,20 +15,25 @@ class ReplyerManager:
|
||||
def __init__(self) -> None:
|
||||
self._repliers: Dict[str, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_maisaka_generator_type() -> str:
|
||||
"""返回当前配置下 Maisaka replyer 的消息模式。"""
|
||||
return global_config.visual.replyer_mode
|
||||
|
||||
def get_replyer(
|
||||
self,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
replyer_type: str = "default",
|
||||
) -> Optional["DefaultReplyer | PrivateReplyer | Any"]:
|
||||
) -> Optional[MaisakaReplyGenerator]:
|
||||
"""按会话和 replyer 类型获取实例。"""
|
||||
stream_id = chat_stream.session_id if chat_stream else chat_id
|
||||
if not stream_id:
|
||||
logger.warning("[ReplyerManager] 缺少 stream_id,无法获取 replyer")
|
||||
return None
|
||||
|
||||
generator_type = get_maisaka_replyer_generator_type() if replyer_type == "maisaka" else ""
|
||||
generator_type = self._get_maisaka_generator_type() if replyer_type == "maisaka" else ""
|
||||
cache_key = f"{replyer_type}:{generator_type}:{stream_id}"
|
||||
if cache_key in self._repliers:
|
||||
logger.info(f"[ReplyerManager] 命中缓存 replyer: cache_key={cache_key}")
|
||||
@@ -51,29 +51,13 @@ class ReplyerManager:
|
||||
|
||||
try:
|
||||
if replyer_type == "maisaka":
|
||||
logger.info(f"[ReplyerManager] 选择 MaisakaReplyGenerator: generator_type={generator_type}")
|
||||
maisaka_replyer_class = get_maisaka_replyer_class()
|
||||
|
||||
replyer = maisaka_replyer_class(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
elif target_stream.is_group_session:
|
||||
logger.info("[ReplyerManager] importing DefaultReplyer")
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
|
||||
replyer = DefaultReplyer(
|
||||
replyer = MaisakaReplyGenerator(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
else:
|
||||
logger.info("[ReplyerManager] importing PrivateReplyer")
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
|
||||
replyer = PrivateReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
logger.warning(f"[ReplyerManager] 不支持的 replyer_type={replyer_type}")
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception(f"[ReplyerManager] 创建 replyer 失败: cache_key={cache_key}")
|
||||
raise
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("common_utils")
|
||||
|
||||
@@ -10,23 +10,14 @@ class TempMethodsExpression:
|
||||
"""用于临时存放一些方法的类"""
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
"""
|
||||
根据聊天流ID获取表达配置
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||
"""
|
||||
def _find_expression_config_item(chat_stream_id: Optional[str] = None):
|
||||
if not global_config.expression.learning_list:
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
if chat_stream_id:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue # 这是全局的
|
||||
continue
|
||||
stream_id = TempMethodsExpression._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
@@ -34,14 +25,44 @@ class TempMethodsExpression:
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id == chat_stream_id:
|
||||
if stream_id != chat_stream_id:
|
||||
continue
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_expression_advanced_chosen_for_chat(chat_stream_id: Optional[str] = None) -> bool:
|
||||
"""根据聊天流 ID 获取表达方式是否启用二次选择。"""
|
||||
config_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
|
||||
if config_item is None:
|
||||
return False
|
||||
return config_item.advanced_chosen
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
"""
|
||||
根据聊天流 ID 获取表达配置。
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流 ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用 jargon 学习)
|
||||
"""
|
||||
config_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
|
||||
if config_item is None:
|
||||
return True, True, True
|
||||
|
||||
return (
|
||||
config_item.use_expression,
|
||||
config_item.enable_learning,
|
||||
config_item.enable_jargon_learning,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(
|
||||
@@ -50,15 +71,15 @@ class TempMethodsExpression:
|
||||
is_group: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
根据平台、ID字符串和是否为群聊生成聊天流ID
|
||||
根据平台、ID 字符串和是否为群聊生成聊天流 ID。
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
id_str: 用户或群组的原始ID字符串
|
||||
id_str: 用户或群组的原始 ID 字符串
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
str: 生成的聊天流ID(哈希值)
|
||||
str: 生成的聊天流 ID(哈希值)
|
||||
"""
|
||||
try:
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
@@ -68,5 +89,5 @@ class TempMethodsExpression:
|
||||
else:
|
||||
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流ID失败: {e}")
|
||||
logger.error(f"生成聊天流 ID 失败: {e}")
|
||||
return None
|
||||
|
||||
@@ -20,7 +20,7 @@ from src.services.embedding_service import EmbeddingServiceClient
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
from src.common.data_models.chat_target_info_data_model import ChatTargetInfo
|
||||
|
||||
logger = get_logger("chat_utils")
|
||||
_warned_unconfigured_platforms: set[str] = set()
|
||||
@@ -699,7 +699,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
|
||||
|
||||
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetPersonInfo"]]:
|
||||
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["ChatTargetInfo"]]:
|
||||
"""
|
||||
获取聊天类型(是否群聊)和私聊对象信息。
|
||||
|
||||
@@ -734,13 +734,13 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
):
|
||||
user_nickname = chat_stream.context.message.message_info.user_info.user_nickname
|
||||
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题
|
||||
from src.common.data_models.chat_target_info_data_model import ChatTargetInfo # 解决循环导入问题
|
||||
|
||||
# Initialize target_info with basic info
|
||||
target_info = TargetPersonInfo(
|
||||
target_info = ChatTargetInfo(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname, # type: ignore
|
||||
session_nickname=user_nickname or "",
|
||||
person_id=None,
|
||||
person_name=None,
|
||||
)
|
||||
@@ -752,6 +752,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
logger.warning(f"用户 {user_nickname} 尚未认识")
|
||||
# 如果用户尚未认识,则返回False和None
|
||||
return False, None
|
||||
target_info.is_known = True
|
||||
if person.person_id:
|
||||
target_info.person_id = person.person_id
|
||||
target_info.person_name = person.person_name
|
||||
|
||||
@@ -24,125 +24,148 @@ logger = get_logger("emoji")
|
||||
class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
|
||||
if not full_path:
|
||||
# 创建时候即检测文件路径合法性
|
||||
raise ValueError("表情包路径不能为空")
|
||||
raise ValueError("图片路径不能为空")
|
||||
if Path(full_path).is_dir() or not Path(full_path).exists():
|
||||
raise FileNotFoundError(f"表情包路径无效: {full_path}")
|
||||
raise FileNotFoundError(f"图片路径无效: {full_path}")
|
||||
|
||||
resolved_path = Path(full_path).absolute().resolve()
|
||||
self.full_path: Path
|
||||
self.dir_path: Path
|
||||
self.file_name: str
|
||||
self._set_full_path(resolved_path)
|
||||
|
||||
self.file_hash: str = None # type: ignore
|
||||
|
||||
self.image_bytes: Optional[bytes] = image_bytes
|
||||
|
||||
self.image_format: str = "" # 图片格式
|
||||
self.image_format: str = ""
|
||||
|
||||
def _set_full_path(self, full_path: Path) -> None:
|
||||
"""同步更新文件路径相关的运行时元数据。"""
|
||||
"""同步刷新路径、目录和文件名等运行时元数据。"""
|
||||
resolved_path = full_path.absolute().resolve()
|
||||
self.full_path = resolved_path
|
||||
self.dir_path = resolved_path.parent.resolve()
|
||||
self.file_name = resolved_path.name
|
||||
|
||||
def _restore_image_format_from_path(self) -> None:
|
||||
"""根据文件扩展名恢复基础图片格式信息。"""
|
||||
"""根据文件扩展名恢复图片格式信息。"""
|
||||
self.image_format = self.full_path.suffix.removeprefix(".").lower()
|
||||
|
||||
def _build_non_conflicting_path(self, target_path: Path) -> Path:
|
||||
"""在目标路径被占用时,生成一个可用的新路径。"""
|
||||
candidate_path = target_path
|
||||
index = 1
|
||||
while candidate_path.exists():
|
||||
candidate_path = target_path.with_name(
|
||||
f"{target_path.stem}_{self.file_hash[:8]}_{index}{target_path.suffix}"
|
||||
)
|
||||
index += 1
|
||||
return candidate_path
|
||||
|
||||
def _rename_file_to_match_format(self) -> None:
|
||||
"""修正文件扩展名,并处理目标文件已存在的冲突。"""
|
||||
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self.image_format])
|
||||
new_full_path = self.dir_path / new_file_name
|
||||
if new_full_path == self.full_path:
|
||||
return
|
||||
|
||||
if new_full_path.exists():
|
||||
existing_file_hash = hashlib.sha256(self.read_image_bytes(new_full_path)).hexdigest()
|
||||
if existing_file_hash == self.file_hash:
|
||||
logger.info(f"[初始化] {new_full_path.name} 已存在且内容一致,复用已有文件")
|
||||
self.full_path.unlink()
|
||||
self._set_full_path(new_full_path)
|
||||
return
|
||||
|
||||
conflict_free_path = self._build_non_conflicting_path(new_full_path)
|
||||
logger.warning(
|
||||
f"[初始化] {new_full_path.name} 已存在且内容不同,改为保存到 {conflict_free_path.name}"
|
||||
)
|
||||
self.full_path.rename(conflict_free_path)
|
||||
self._set_full_path(conflict_free_path)
|
||||
return
|
||||
|
||||
self.full_path.rename(new_full_path)
|
||||
self._set_full_path(new_full_path)
|
||||
|
||||
def read_image_bytes(self, path: Path) -> bytes:
|
||||
"""
|
||||
同步读取图片文件的字节内容
|
||||
同步读取图片文件的字节内容。
|
||||
|
||||
Args:
|
||||
path (Path): 图片文件的完整路径
|
||||
path: 图片文件的完整路径。
|
||||
|
||||
Returns:
|
||||
return (bytes): 图片文件的字节内容
|
||||
Raises:
|
||||
FileNotFoundError: 如果文件不存在则抛出该异常
|
||||
Exception: 其他读取文件时发生的异常
|
||||
图片文件的字节内容。
|
||||
"""
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError as e:
|
||||
with open(path, "rb") as file:
|
||||
return file.read()
|
||||
except FileNotFoundError as exc:
|
||||
logger.error(f"[读取图片文件] 文件未找到: {path}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"[读取图片文件] 读取文件时发生错误: {e}")
|
||||
raise e
|
||||
raise exc
|
||||
except Exception as exc:
|
||||
logger.error(f"[读取图片文件] 读取文件时发生错误: {exc}")
|
||||
raise exc
|
||||
|
||||
def get_image_format(self, image_bytes: bytes) -> str:
|
||||
"""
|
||||
获取图片的格式
|
||||
获取图片的实际格式。
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): 图片的字节内容
|
||||
image_bytes: 图片的字节内容。
|
||||
|
||||
Returns:
|
||||
return (str): 图片的格式(小写)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果无法识别图片格式
|
||||
Exception: 其他读取图片格式时发生的异常
|
||||
小写格式名,例如 `png`、`jpeg`。
|
||||
"""
|
||||
try:
|
||||
with PILImage.open(io.BytesIO(image_bytes)) as img:
|
||||
if not img.format:
|
||||
raise ValueError("无法识别图片格式")
|
||||
return img.format.lower()
|
||||
except Exception as e:
|
||||
logger.error(f"[获取图片格式] 读取图片格式时发生错误: {e}")
|
||||
raise e
|
||||
except Exception as exc:
|
||||
logger.error(f"[获取图片格式] 读取图片格式时发生错误: {exc}")
|
||||
raise exc
|
||||
|
||||
async def calculate_hash_format(self) -> bool:
|
||||
"""
|
||||
异步计算表情包的哈希值和格式,初始化后应该执行此方法来确保对象的哈希值和格式正确
|
||||
计算图片哈希和实际格式,并在需要时修正扩展名。
|
||||
|
||||
Returns:
|
||||
return (bool): 如果成功计算哈希值和格式则返回True,否则返回False
|
||||
成功返回 `True`,失败返回 `False`。
|
||||
"""
|
||||
try:
|
||||
# 计算哈希值
|
||||
logger.debug(f"[初始化] 计算 {self.file_name} 的哈希值...")
|
||||
if not self.image_bytes:
|
||||
if self.image_bytes is None:
|
||||
logger.debug(f"[初始化] 正在读取文件: {self.full_path}")
|
||||
image_bytes = await asyncio.to_thread(self.read_image_bytes, self.full_path)
|
||||
else:
|
||||
image_bytes = self.image_bytes
|
||||
|
||||
self.image_bytes = image_bytes
|
||||
self.file_hash = hashlib.sha256(image_bytes).hexdigest()
|
||||
logger.debug(f"[初始化] {self.file_name} 计算哈希值成功: {self.file_hash}")
|
||||
|
||||
# 用PIL读取图片格式
|
||||
logger.debug(f"[初始化] 读取 {self.file_name} 的图片格式...")
|
||||
self.image_format = await asyncio.to_thread(self.get_image_format, image_bytes)
|
||||
logger.debug(f"[初始化] {self.file_name} 读取图片格式成功: {self.image_format}")
|
||||
|
||||
# 比对文件扩展名和实际格式
|
||||
file_ext = self.file_name.split(".")[-1].lower()
|
||||
if file_ext != self.image_format:
|
||||
logger.warning(
|
||||
f"[初始化] {self.file_name} 文件扩展名与实际格式不符: ext`{file_ext}`!=`{self.image_format}`"
|
||||
)
|
||||
# 重命名文件以匹配实际格式
|
||||
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self.image_format])
|
||||
new_full_path = self.dir_path / new_file_name
|
||||
self.full_path.rename(new_full_path)
|
||||
self._set_full_path(new_full_path)
|
||||
self._rename_file_to_match_format()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[初始化] 初始化图片时发生错误: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"[初始化] 初始化图片时发生错误: {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
class MaiEmoji(BaseImageDataModel):
|
||||
"""麦麦的表情包对象,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象"""
|
||||
"""表情包数据模型。"""
|
||||
|
||||
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
|
||||
# self.embedding = []
|
||||
self.description: str = ""
|
||||
self.emotion: List[str] = []
|
||||
self.query_count = 0
|
||||
@@ -152,33 +175,26 @@ class MaiEmoji(BaseImageDataModel):
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: Images):
|
||||
"""从数据库记录创建 MaiEmoji 对象,如果记录标记为文件不存在则**抛出异常**
|
||||
|
||||
调用者应该对数据库记录进行检查,如果 `no_file_flag` 为 True 则不应该调用此方法
|
||||
|
||||
Args:
|
||||
db_record (Images): 数据库中的图片记录
|
||||
Returns:
|
||||
return (MaiEmoji): 包含图片信息的 MaiEmoji 对象
|
||||
Raises:
|
||||
ValueError: 如果数据库记录标记为文件不存在则抛出该异常
|
||||
"""
|
||||
"""从数据库记录构建 `MaiEmoji` 对象。"""
|
||||
if db_record.no_file_flag:
|
||||
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
|
||||
|
||||
obj = cls(db_record.full_path)
|
||||
obj.file_hash = db_record.image_hash
|
||||
obj._restore_image_format_from_path()
|
||||
|
||||
description = db_record.description or ""
|
||||
obj.description = description
|
||||
normalized_tags = [
|
||||
str(item).strip()
|
||||
for item in str(description).replace(",", ",").replace("、", ",").replace(";", ",").split(",")
|
||||
for item in str(description).replace(",", ",").replace("。", ",").replace("、", ",").split(",")
|
||||
if str(item).strip()
|
||||
]
|
||||
deduped_tags: List[str] = []
|
||||
for item in normalized_tags:
|
||||
if item not in deduped_tags:
|
||||
deduped_tags.append(item)
|
||||
|
||||
obj.emotion = deduped_tags
|
||||
obj.query_count = db_record.query_count
|
||||
obj.last_used_time = db_record.last_used_time
|
||||
@@ -198,7 +214,7 @@ class MaiEmoji(BaseImageDataModel):
|
||||
|
||||
|
||||
class MaiImage(BaseImageDataModel):
|
||||
"""麦麦图片数据模型,仅当**图片文件存在**时才应该创建此对象,数据库记录如果标记为文件不存在`(no_file_flag = True)`则不应该调用 `from_db_instance` 方法来创建此对象"""
|
||||
"""普通图片数据模型。"""
|
||||
|
||||
def __init__(self, full_path: str | Path, image_bytes: Optional[bytes] = None):
|
||||
self.description: str = ""
|
||||
@@ -207,19 +223,10 @@ class MaiImage(BaseImageDataModel):
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: Images):
|
||||
"""从数据库记录创建 MaiImage 对象,如果记录标记为文件不存在则**抛出异常**
|
||||
|
||||
调用者应该对数据库记录进行检查,如果 `no_file_flag` 为 True 则不应该调用此方法
|
||||
|
||||
Args:
|
||||
db_record (Images): 数据库中的图片记录
|
||||
Returns:
|
||||
return (MaiImage): 包含图片信息的 MaiImage 对象
|
||||
Raises:
|
||||
ValueError: 如果数据库记录标记为文件不存在则抛出该异常
|
||||
"""
|
||||
"""从数据库记录构建 `MaiImage` 对象。"""
|
||||
if db_record.no_file_flag:
|
||||
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiImage 对象")
|
||||
|
||||
obj = cls(db_record.full_path)
|
||||
obj.file_hash = db_record.image_hash
|
||||
obj._set_full_path(Path(db_record.full_path))
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# from dataclasses import dataclass, field
|
||||
# from typing import Optional, Dict, TYPE_CHECKING
|
||||
# from . import BaseDataModel
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from .database_data_model import DatabaseMessages
|
||||
# from src.core.types import ActionInfo
|
||||
|
||||
|
||||
# # @dataclass
|
||||
# # class TargetPersonInfo(BaseDataModel):
|
||||
# # platform: str = field(default_factory=str)
|
||||
# # user_id: str = field(default_factory=str)
|
||||
# # user_nickname: str = field(default_factory=str)
|
||||
# # person_id: Optional[str] = None
|
||||
# # person_name: Optional[str] = None
|
||||
# 已重构,见src/common/data_models/chat_target_info_data_model.py
|
||||
|
||||
# @dataclass
|
||||
# class ActionPlannerInfo(BaseDataModel):
|
||||
# action_type: str = field(default_factory=str)
|
||||
# reasoning: Optional[str] = None
|
||||
# action_data: Optional[Dict] = None
|
||||
# action_message: Optional["DatabaseMessages"] = None
|
||||
# available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||
# loop_start_time: Optional[float] = None
|
||||
# action_reasoning: Optional[str] = None
|
||||
# 已重构,见src/common/data_models/planned_action_data_models.py
|
||||
@@ -14,7 +14,6 @@ from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message
|
||||
|
||||
|
||||
@@ -24,7 +23,7 @@ PromptMessage: TypeAlias = Dict[str, Any]
|
||||
PromptInput: TypeAlias = str | List[PromptMessage]
|
||||
"""统一的提示输入类型。"""
|
||||
|
||||
MessageFactory: TypeAlias = Callable[["BaseClient"], List["Message"]]
|
||||
MessageFactory: TypeAlias = Callable[..., List["Message"]]
|
||||
"""统一的消息工厂类型。"""
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
from src.common.data_models.llm_service_data_models import PromptMessage
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
|
||||
@@ -121,6 +122,10 @@ class ReplyGenerationResult(BaseDataModel):
|
||||
default=None,
|
||||
metadata={"description": "供监控层直接消费的通用 tool 展示详情。"},
|
||||
)
|
||||
request_messages: List["PromptMessage"] = field(
|
||||
default_factory=list,
|
||||
metadata={"description": "本次 replyer 实际发送给模型的消息列表。"},
|
||||
)
|
||||
|
||||
|
||||
def build_reply_monitor_detail(result: ReplyGenerationResult) -> Dict[str, Any]:
|
||||
@@ -133,6 +138,8 @@ def build_reply_monitor_detail(result: ReplyGenerationResult) -> Dict[str, Any]:
|
||||
|
||||
if prompt_text:
|
||||
detail["prompt_text"] = prompt_text
|
||||
if result.request_messages:
|
||||
detail["request_messages"] = result.request_messages
|
||||
if reasoning_text:
|
||||
detail["reasoning_text"] = reasoning_text
|
||||
if output_text:
|
||||
|
||||
@@ -84,9 +84,9 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode
|
||||
|
||||
def calculate_typing_time(
|
||||
input_string: str,
|
||||
chinese_time: float = 0.3,
|
||||
english_time: float = 0.15,
|
||||
line_break_time: float = 0.1,
|
||||
chinese_time: float = 0.2,
|
||||
english_time: float = 0.1,
|
||||
line_break_time: float = 0.05,
|
||||
is_emoji: bool = False,
|
||||
) -> float:
|
||||
"""
|
||||
|
||||
@@ -10,24 +10,14 @@ logger = get_logger("config_utils")
|
||||
|
||||
class ExpressionConfigUtils:
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(session_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据聊天会话ID获取表达配置
|
||||
|
||||
Args:
|
||||
session_id: 聊天会话ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||
"""
|
||||
def _find_expression_config_item(session_id: Optional[str] = None):
|
||||
if not global_config.expression.learning_list:
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
if session_id:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue # 这是全局的
|
||||
continue
|
||||
stream_id = ExpressionConfigUtils._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
@@ -35,28 +25,59 @@ class ExpressionConfigUtils:
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id == session_id:
|
||||
if stream_id != session_id:
|
||||
continue
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_expression_advanced_chosen_for_chat(session_id: Optional[str] = None) -> bool:
|
||||
"""根据聊天会话 ID 获取表达方式是否启用二次选择。"""
|
||||
config_item = ExpressionConfigUtils._find_expression_config_item(session_id)
|
||||
if config_item is None:
|
||||
return False
|
||||
return config_item.advanced_chosen
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(session_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据聊天会话 ID 获取表达配置。
|
||||
|
||||
Args:
|
||||
session_id: 聊天会话 ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用 jargon 学习)
|
||||
"""
|
||||
config_item = ExpressionConfigUtils._find_expression_config_item(session_id)
|
||||
if config_item is None:
|
||||
return True, True, True
|
||||
|
||||
return (
|
||||
config_item.use_expression,
|
||||
config_item.enable_learning,
|
||||
config_item.enable_jargon_learning,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(platform: str, id_str: str, is_group: bool = False) -> Optional[str]:
|
||||
# sourcery skip: remove-unnecessary-cast
|
||||
"""
|
||||
根据平台、ID字符串和是否为群聊生成聊天流ID
|
||||
根据平台、ID 字符串和是否为群聊生成聊天流 ID。
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
id_str: 用户或群组的原始ID字符串
|
||||
id_str: 用户或群组的原始 ID 字符串
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
str: 生成的聊天流ID(哈希值)
|
||||
str: 生成的聊天流 ID(哈希值)
|
||||
"""
|
||||
try:
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
@@ -66,7 +87,7 @@ class ExpressionConfigUtils:
|
||||
else:
|
||||
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流ID失败: {e}")
|
||||
logger.error(f"生成聊天流 ID 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -91,7 +112,7 @@ class ChatConfigUtils:
|
||||
else:
|
||||
rule_session_id = SessionUtils.calculate_session_id(rule.platform, user_id=str(rule.item_id))
|
||||
if rule_session_id != session_id:
|
||||
continue # 不匹配的会话ID,跳过
|
||||
continue # 不匹配的会话 ID,跳过
|
||||
parsed_range = ChatConfigUtils.parse_range(rule.time)
|
||||
if not parsed_range:
|
||||
continue # 无法解析的时间范围,跳过
|
||||
@@ -102,7 +123,7 @@ class ChatConfigUtils:
|
||||
else: # 跨天的时间范围
|
||||
in_range = now_min >= start_min or now_min <= end_min
|
||||
if in_range:
|
||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回0.0
|
||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回 0.0
|
||||
|
||||
# 没有匹配到会话相关的规则,继续匹配全局规则
|
||||
for rule in global_config.chat.talk_value_rules:
|
||||
@@ -118,7 +139,7 @@ class ChatConfigUtils:
|
||||
else: # 跨天的时间范围
|
||||
in_range = now_min >= start_min or now_min <= end_min
|
||||
if in_range:
|
||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回0.0
|
||||
return rule.value or 0.0 # 如果规则生效但没有设置值,返回 0.0
|
||||
return result # 如果没有任何规则生效,返回默认值
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -23,7 +23,6 @@ from .official_configs import (
|
||||
EmojiConfig,
|
||||
ExpressionConfig,
|
||||
KeywordReactionConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MaiSakaConfig,
|
||||
MaimMessageConfig,
|
||||
MCPConfig,
|
||||
@@ -54,9 +53,10 @@ PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.5.2"
|
||||
MODEL_CONFIG_VERSION: str = "1.13.1"
|
||||
CONFIG_VERSION: str = "8.7.1"
|
||||
MODEL_CONFIG_VERSION: str = "1.14.0"
|
||||
|
||||
logger = get_logger("config")
|
||||
|
||||
@@ -454,6 +454,20 @@ def generate_new_config_file(config_class: type[T], config_path: Path, inner_con
|
||||
write_config_to_file(config, config_path, inner_config_version)
|
||||
|
||||
|
||||
def remove_legacy_env_file(env_path: Path) -> None:
|
||||
"""删除已完成迁移的旧版 `.env` 文件。"""
|
||||
|
||||
if not env_path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
env_path.unlink()
|
||||
except OSError as exc:
|
||||
logger.warning(f"旧版 .env 配置文件删除失败,请手动删除: {env_path},原因: {exc}")
|
||||
else:
|
||||
logger.warning(f"检测到旧版环境变量绑定配置迁移成功,已删除旧版 .env 文件: {env_path}")
|
||||
|
||||
|
||||
def load_config_from_file(
|
||||
config_class: type[T], config_path: Path, new_ver: str, override_repr: bool = False
|
||||
) -> tuple[T, bool]:
|
||||
@@ -467,10 +481,12 @@ def load_config_from_file(
|
||||
if not isinstance(inner_version, str):
|
||||
raise TypeError(t("config.invalid_inner_version"))
|
||||
old_ver: str = inner_version
|
||||
env_migration_applied: bool = False
|
||||
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
|
||||
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
|
||||
if config_path.name == "bot_config.toml" and config_class.__name__ == "Config":
|
||||
env_migration = migrate_legacy_bind_env_to_bot_config_dict(config_data)
|
||||
env_migration_applied = env_migration.migrated
|
||||
if env_migration.migrated:
|
||||
logger.warning(f"检测到旧版环境变量绑定配置,已迁移到主配置: {env_migration.reason}")
|
||||
config_data = env_migration.data
|
||||
@@ -497,9 +513,11 @@ def load_config_from_file(
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
if compare_versions(old_ver, new_ver):
|
||||
if compare_versions(old_ver, new_ver) or env_migration_applied:
|
||||
output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
|
||||
write_config_to_file(target_config, config_path, new_ver, override_repr)
|
||||
if env_migration_applied:
|
||||
remove_legacy_env_file(LEGACY_ENV_PATH)
|
||||
updated = True
|
||||
return target_config, updated
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
"""
|
||||
legacy_migration.py
|
||||
|
||||
一个“可随时拔掉”的旧配置兼容层:
|
||||
- 仅在配置解析失败时尝试修复旧格式数据(7.x -> 8.x 这一类结构性变更)
|
||||
- 不依赖 Pydantic / ConfigBase,仅对 dict 做最小转换
|
||||
- 成功则返回(修复后的 dict, True),失败则返回(原 dict, False)
|
||||
|
||||
设计目标:与现有 config 加载逻辑的接触点尽可能小,未来不需要时可一键移除。
|
||||
旧配置兼容层。
|
||||
仅保留当前仍需要的“解析前结构修复”,避免老配置在 `from_dict` 前直接失败。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -16,12 +12,7 @@ from typing import Any, Optional
|
||||
|
||||
import os
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("legacy_migration")
|
||||
|
||||
|
||||
# 方便未来快速关闭/移除
|
||||
ENABLE_LEGACY_MIGRATION: bool = True
|
||||
|
||||
|
||||
@@ -43,6 +34,7 @@ def _as_list(x: Any) -> Optional[list[Any]]:
|
||||
def _parse_host_env(value: Any) -> Optional[str]:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
|
||||
normalized_value = value.strip()
|
||||
return normalized_value or None
|
||||
|
||||
@@ -75,116 +67,73 @@ def _migrate_env_value(section: dict[str, Any], key: str, parsed_env_value: Any,
|
||||
return True
|
||||
|
||||
|
||||
def _move_section_key(source: dict[str, Any], target: dict[str, Any], key: str) -> bool:
|
||||
"""将配置项从旧分组移动到新分组,若新分组已有值则保留新值。"""
|
||||
|
||||
if key not in source:
|
||||
return False
|
||||
|
||||
if key not in target:
|
||||
target[key] = source[key]
|
||||
source.pop(key, None)
|
||||
return True
|
||||
|
||||
|
||||
def _parse_triplet_target(s: str) -> Optional[dict[str, str]]:
|
||||
"""
|
||||
解析 "platform:id:type" -> {platform,item_id,rule_type}
|
||||
返回 None 表示无法解析。
|
||||
解析 "platform:id:type" -> {platform, item_id, rule_type}
|
||||
"""
|
||||
if not isinstance(s, str):
|
||||
return None
|
||||
|
||||
parts = s.split(":", 2)
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
platform, item_id, rule_type = parts
|
||||
if rule_type not in ("group", "private"):
|
||||
return None
|
||||
return {"platform": platform, "item_id": item_id, "rule_type": rule_type}
|
||||
|
||||
|
||||
def _parse_quad_prompt(s: str) -> Optional[dict[str, str]]:
|
||||
"""
|
||||
解析 "platform:id:type:prompt" -> {platform,item_id,rule_type,prompt}
|
||||
prompt 允许包含冒号,因此只切前三个冒号。
|
||||
"""
|
||||
if not isinstance(s, str):
|
||||
return None
|
||||
parts = s.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
return None
|
||||
platform, item_id, rule_type, prompt = parts
|
||||
if rule_type not in ("group", "private"):
|
||||
return None
|
||||
if not prompt:
|
||||
return None
|
||||
return {"platform": platform, "item_id": item_id, "rule_type": rule_type, "prompt": prompt}
|
||||
|
||||
|
||||
def _parse_enable_disable(v: Any) -> Optional[bool]:
|
||||
"""
|
||||
兼容旧值 "enable"/"disable" 以及 bool。
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
|
||||
if isinstance(v, str):
|
||||
vv = v.strip().lower()
|
||||
if vv == "enable":
|
||||
normalized_value = v.strip().lower()
|
||||
if normalized_value == "enable":
|
||||
return True
|
||||
if vv == "disable":
|
||||
if normalized_value == "disable":
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
|
||||
"""
|
||||
旧:
|
||||
learning_list = [
|
||||
["", "enable", "enable", "enable"],
|
||||
["qq:1919810:group", "enable", "enable", "enable"],
|
||||
]
|
||||
兼容旧旧格式:
|
||||
learning_list = [
|
||||
["qq:1919810:group", "enable", "enable", "0.5"],
|
||||
["", "disable", "disable", "0.1"],
|
||||
]
|
||||
新:
|
||||
[[expression.learning_list]]
|
||||
platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true
|
||||
将旧版 expression.learning_list 转成当前结构。
|
||||
"""
|
||||
ll = _as_list(expr.get("learning_list"))
|
||||
if ll is None:
|
||||
learning_list = _as_list(expr.get("learning_list"))
|
||||
if learning_list is None:
|
||||
return False
|
||||
|
||||
# 如果已经是新格式(列表里是 dict),跳过
|
||||
if ll and all(isinstance(i, dict) for i in ll):
|
||||
if learning_list and all(isinstance(item, dict) for item in learning_list):
|
||||
return False
|
||||
|
||||
migrated_items: list[dict[str, Any]] = []
|
||||
for row in ll:
|
||||
r = _as_list(row)
|
||||
if r is None or len(r) < 4:
|
||||
# 行结构不对,无法安全迁移
|
||||
for row in learning_list:
|
||||
row_items = _as_list(row)
|
||||
if row_items is None or len(row_items) < 4:
|
||||
return False
|
||||
|
||||
target_raw = r[0]
|
||||
use_expression = _parse_enable_disable(r[1])
|
||||
enable_learning = _parse_enable_disable(r[2])
|
||||
enable_jargon_learning = _parse_enable_disable(r[3])
|
||||
target_raw = row_items[0]
|
||||
use_expression = _parse_enable_disable(row_items[1])
|
||||
enable_learning = _parse_enable_disable(row_items[2])
|
||||
enable_jargon_learning = _parse_enable_disable(row_items[3])
|
||||
|
||||
if enable_jargon_learning is None:
|
||||
# 更早期的配置在第 4 列记录的是一个已废弃的数值权重/阈值,
|
||||
# 当前 schema 已没有对应字段。这里按保守策略兼容迁移:
|
||||
# 丢弃旧数值,并将 enable_jargon_learning 置为 False。
|
||||
# 更早期版本第 4 列是已废弃的数值阈值,这里仅做保守兼容。
|
||||
try:
|
||||
float(str(r[3]))
|
||||
float(str(row_items[3]))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
enable_jargon_learning = False
|
||||
|
||||
if use_expression is None or enable_learning is None or enable_jargon_learning is None:
|
||||
return False
|
||||
|
||||
# 旧格式中 target 允许为空字符串:表示全局;新结构必须有三元组字段
|
||||
if target_raw == "" or target_raw is None:
|
||||
target = {"platform": "", "item_id": "", "rule_type": "group"}
|
||||
else:
|
||||
@@ -209,99 +158,56 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
|
||||
|
||||
def _migrate_expression_groups(expr: dict[str, Any]) -> bool:
|
||||
"""
|
||||
旧:
|
||||
expression_groups = [
|
||||
["qq:1:group","qq:2:group"],
|
||||
["qq:3:group"],
|
||||
]
|
||||
新:
|
||||
expression_groups = [
|
||||
{ expression_groups = [ {platform="qq", item_id="1", rule_type="group"}, ... ] },
|
||||
{ expression_groups = [ ... ] },
|
||||
]
|
||||
将旧版 expression.expression_groups 转成当前结构。
|
||||
"""
|
||||
eg = _as_list(expr.get("expression_groups"))
|
||||
if eg is None:
|
||||
expression_groups = _as_list(expr.get("expression_groups"))
|
||||
if expression_groups is None:
|
||||
return False
|
||||
if expression_groups and all(isinstance(item, dict) for item in expression_groups):
|
||||
return False
|
||||
|
||||
# 已经是新格式(列表里是 dict 且包含 expression_groups),跳过
|
||||
if eg and all(isinstance(i, dict) for i in eg):
|
||||
return False
|
||||
|
||||
migrated: list[dict[str, Any]] = []
|
||||
for group in eg:
|
||||
g = _as_list(group)
|
||||
if g is None:
|
||||
migrated_groups: list[dict[str, Any]] = []
|
||||
for group in expression_groups:
|
||||
group_items = _as_list(group)
|
||||
if group_items is None:
|
||||
return False
|
||||
|
||||
targets: list[dict[str, str]] = []
|
||||
for item in g:
|
||||
for item in group_items:
|
||||
parsed = _parse_triplet_target(str(item))
|
||||
if parsed is None:
|
||||
return False
|
||||
targets.append(parsed)
|
||||
migrated.append({"expression_groups": targets})
|
||||
|
||||
expr["expression_groups"] = migrated
|
||||
migrated_groups.append({"expression_groups": targets})
|
||||
|
||||
expr["expression_groups"] = migrated_groups
|
||||
return True
|
||||
|
||||
|
||||
def _migrate_target_item_list(parent: dict[str, Any], key: str) -> bool:
|
||||
"""
|
||||
将 list[str] 的 "platform:id:type" 迁移为 list[{platform,item_id,rule_type}]
|
||||
用于:memory.global_memory_blacklist 等。
|
||||
将 list[str] 的 "platform:id:type" 迁移为 list[TargetItem]。
|
||||
"""
|
||||
raw = _as_list(parent.get(key))
|
||||
if raw is None:
|
||||
if raw is None or not raw:
|
||||
return False
|
||||
if raw and all(isinstance(i, dict) for i in raw):
|
||||
if all(isinstance(item, dict) for item in raw):
|
||||
return False
|
||||
|
||||
targets: list[dict[str, str]] = []
|
||||
for item in raw:
|
||||
parsed = _parse_triplet_target(str(item))
|
||||
if parsed is None:
|
||||
return False
|
||||
targets.append(parsed)
|
||||
|
||||
parent[key] = targets
|
||||
return True
|
||||
|
||||
|
||||
def _migrate_extra_prompt_list(exp: dict[str, Any], key: str) -> bool:
|
||||
"""
|
||||
将 list[str] 的 "platform:id:type:prompt" 迁移为 list[{platform,item_id,rule_type,prompt}]
|
||||
用于:experimental.chat_prompts
|
||||
"""
|
||||
raw = _as_list(exp.get(key))
|
||||
if raw is None:
|
||||
return False
|
||||
if raw and all(isinstance(i, dict) for i in raw):
|
||||
return False
|
||||
items: list[dict[str, str]] = []
|
||||
for item in raw:
|
||||
parsed = _parse_quad_prompt(str(item))
|
||||
if parsed is None:
|
||||
return False
|
||||
items.append(parsed)
|
||||
exp[key] = items
|
||||
return True
|
||||
|
||||
|
||||
def _parse_multimodal_replyer(v: Any) -> Optional[bool]:
|
||||
"""兼容旧 replyer_generator_type 到布尔开关的迁移。"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if not isinstance(v, str):
|
||||
return None
|
||||
|
||||
normalized_value = v.strip().lower()
|
||||
if normalized_value == "multimodal":
|
||||
return True
|
||||
if normalized_value == "legacy":
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
"""将旧版环境变量中的绑定地址迁移到主配置结构。"""
|
||||
"""将旧版 `.env` 中的绑定地址迁移到主配置结构。"""
|
||||
|
||||
migrated_any = False
|
||||
reasons: list[str] = []
|
||||
@@ -339,8 +245,7 @@ def migrate_legacy_bind_env_to_bot_config_dict(data: dict[str, Any]) -> Migratio
|
||||
|
||||
def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
"""
|
||||
尝试对“总配置 bot_config.toml”的 dict(已 unwrap)进行旧格式修复。
|
||||
仅做我们明确知道的结构性变更;其它字段不动。
|
||||
尝试修复 `bot_config.toml` 的少量旧结构,仅保留当前仍需要的兼容逻辑。
|
||||
"""
|
||||
if not ENABLE_LEGACY_MIGRATION:
|
||||
return MigrationResult(data=data, migrated=False, reason="disabled")
|
||||
@@ -353,41 +258,30 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
if _migrate_expression_learning_list(expr):
|
||||
migrated_any = True
|
||||
reasons.append("expression.learning_list")
|
||||
|
||||
if _migrate_expression_groups(expr):
|
||||
migrated_any = True
|
||||
reasons.append("expression.expression_groups")
|
||||
# allow_reflect: 旧 list[str] -> 新 list[TargetItem]
|
||||
|
||||
if _migrate_target_item_list(expr, "allow_reflect"):
|
||||
migrated_any = True
|
||||
reasons.append("expression.allow_reflect")
|
||||
# manual_reflect_operator_id: 旧 str -> 新 Optional[TargetItem]
|
||||
mroi = expr.get("manual_reflect_operator_id")
|
||||
if isinstance(mroi, str) and mroi.strip():
|
||||
parsed = _parse_triplet_target(mroi.strip())
|
||||
|
||||
manual_reflect_operator_id = expr.get("manual_reflect_operator_id")
|
||||
if isinstance(manual_reflect_operator_id, str) and manual_reflect_operator_id.strip():
|
||||
parsed = _parse_triplet_target(manual_reflect_operator_id.strip())
|
||||
if parsed is not None:
|
||||
expr["manual_reflect_operator_id"] = parsed
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id")
|
||||
|
||||
chat = _as_dict(data.get("chat"))
|
||||
if chat is None:
|
||||
chat = {}
|
||||
data["chat"] = chat
|
||||
elif "private_plan_style" in chat:
|
||||
chat.pop("private_plan_style", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.private_plan_style_removed")
|
||||
if isinstance(manual_reflect_operator_id, str) and not manual_reflect_operator_id.strip():
|
||||
expr.pop("manual_reflect_operator_id", None)
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id_empty")
|
||||
|
||||
personality = _as_dict(data.get("personality"))
|
||||
visual = _as_dict(data.get("visual"))
|
||||
if visual is None and (
|
||||
(personality is not None and "visual_style" in personality)
|
||||
or "multimodal_planner" in chat
|
||||
or "replyer_generator_type" in chat
|
||||
):
|
||||
visual = {}
|
||||
data["visual"] = visual
|
||||
|
||||
if visual is not None and personality is not None and "visual_style" in personality:
|
||||
if "visual_style" not in visual:
|
||||
visual["visual_style"] = personality["visual_style"]
|
||||
@@ -395,108 +289,19 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append("personality.visual_style_moved_to_visual.visual_style")
|
||||
|
||||
if visual is not None and "multimodal_planner" in chat:
|
||||
if "multimodal_planner" not in visual and isinstance(chat["multimodal_planner"], bool):
|
||||
visual["multimodal_planner"] = chat["multimodal_planner"]
|
||||
if "multimodal_planner" in visual:
|
||||
chat.pop("multimodal_planner", None)
|
||||
if visual is not None and "multimodal_planner" in visual and "planner_mode" not in visual:
|
||||
multimodal_planner = visual.pop("multimodal_planner")
|
||||
if isinstance(multimodal_planner, bool):
|
||||
visual["planner_mode"] = "multimodal" if multimodal_planner else "text"
|
||||
migrated_any = True
|
||||
reasons.append("chat.multimodal_planner_moved_to_visual.multimodal_planner")
|
||||
reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode")
|
||||
else:
|
||||
visual["multimodal_planner"] = multimodal_planner
|
||||
|
||||
if visual is not None and "replyer_generator_type" in chat:
|
||||
multimodal_replyer = _parse_multimodal_replyer(chat["replyer_generator_type"])
|
||||
if "multimodal_replyer" not in visual and multimodal_replyer is not None:
|
||||
visual["multimodal_replyer"] = multimodal_replyer
|
||||
if "multimodal_replyer" in visual:
|
||||
chat.pop("replyer_generator_type", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.replyer_generator_type_moved_to_visual.multimodal_replyer")
|
||||
|
||||
maisaka = _as_dict(data.get("maisaka"))
|
||||
mem = _as_dict(data.get("memory"))
|
||||
debug = _as_dict(data.get("debug"))
|
||||
if maisaka is not None:
|
||||
moved_memory_keys = ("enable_memory_query_tool", "memory_query_default_limit")
|
||||
if any(key in maisaka for key in moved_memory_keys) and mem is None:
|
||||
mem = {}
|
||||
data["memory"] = mem
|
||||
|
||||
if mem is not None:
|
||||
for moved_key in moved_memory_keys:
|
||||
if _move_section_key(maisaka, mem, moved_key):
|
||||
migrated_any = True
|
||||
reasons.append(f"maisaka.{moved_key}_moved_to_memory")
|
||||
|
||||
if mem is not None and "show_memory_prompt" in mem and debug is None:
|
||||
debug = {}
|
||||
data["debug"] = debug
|
||||
|
||||
if mem is not None:
|
||||
if _migrate_target_item_list(mem, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
reasons.append("memory.global_memory_blacklist")
|
||||
|
||||
if debug is not None and _move_section_key(mem, debug, "show_memory_prompt"):
|
||||
migrated_any = True
|
||||
reasons.append("memory.show_memory_prompt_moved_to_debug")
|
||||
|
||||
for removed_key in (
|
||||
"agent_timeout_seconds",
|
||||
"max_agent_iterations",
|
||||
):
|
||||
if removed_key in mem:
|
||||
mem.pop(removed_key, None)
|
||||
migrated_any = True
|
||||
reasons.append(f"memory.{removed_key}_removed")
|
||||
|
||||
relationship = _as_dict(data.get("relationship"))
|
||||
if relationship is not None:
|
||||
data.pop("relationship", None)
|
||||
memory = _as_dict(data.get("memory"))
|
||||
if memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
reasons.append("relationship_removed")
|
||||
|
||||
exp = _as_dict(data.get("experimental"))
|
||||
if exp is not None:
|
||||
if _migrate_extra_prompt_list(exp, "chat_prompts"):
|
||||
migrated_any = True
|
||||
reasons.append("experimental.chat_prompts")
|
||||
|
||||
if "private_plan_style" in exp:
|
||||
exp.pop("private_plan_style", None)
|
||||
migrated_any = True
|
||||
reasons.append("experimental.private_plan_style_removed")
|
||||
|
||||
for key in ("group_chat_prompt", "private_chat_prompts", "chat_prompts"):
|
||||
if key in exp and key not in chat:
|
||||
chat[key] = exp[key]
|
||||
migrated_any = True
|
||||
reasons.append(f"experimental.{key}_moved_to_chat")
|
||||
|
||||
data.pop("experimental", None)
|
||||
migrated_any = True
|
||||
reasons.append("experimental_removed")
|
||||
|
||||
if chat is not None and "think_mode" in chat:
|
||||
chat.pop("think_mode", None)
|
||||
migrated_any = True
|
||||
reasons.append("chat.think_mode_removed")
|
||||
|
||||
tool = _as_dict(data.get("tool"))
|
||||
if tool is not None:
|
||||
data.pop("tool", None)
|
||||
migrated_any = True
|
||||
reasons.append("tool_section_removed")
|
||||
|
||||
# ExpressionConfig 中的 manual_reflect_operator_id:
|
||||
# 旧版本可能是 ""(字符串),新版本期望 Optional[TargetItem]。
|
||||
# 空字符串视为未配置,转换为 None/删除键以避免校验错误。
|
||||
expr = _as_dict(data.get("expression"))
|
||||
if expr is not None:
|
||||
mroi = expr.get("manual_reflect_operator_id")
|
||||
if isinstance(mroi, str) and not mroi.strip():
|
||||
expr.pop("manual_reflect_operator_id", None)
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id_empty")
|
||||
reasons.append("memory.global_memory_blacklist")
|
||||
|
||||
reason = ",".join(reasons)
|
||||
return MigrationResult(data=data, migrated=migrated_any, reason=reason)
|
||||
|
||||
@@ -307,6 +307,15 @@ class ModelInfo(ConfigBase):
|
||||
)
|
||||
"""强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)"""
|
||||
|
||||
visual: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否为多模态模型。开启后表示该模型支持视觉输入。"""
|
||||
|
||||
extra_params: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
json_schema_extra={
|
||||
@@ -437,4 +446,4 @@ class ModelTaskConfig(ConfigBase):
|
||||
"x-icon": "database",
|
||||
},
|
||||
)
|
||||
"""嵌入模型配置"""
|
||||
"""嵌入模型配置"""
|
||||
|
||||
@@ -145,23 +145,23 @@ class VisualConfig(ConfigBase):
|
||||
__ui_label__ = "视觉"
|
||||
__ui_icon__ = "image"
|
||||
|
||||
multimodal_planner: bool = Field(
|
||||
default=True,
|
||||
planner_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="auto",
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否直接输入图片"""
|
||||
|
||||
multimodal_replyer: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""是否启用 Maisaka 多模态 replyer 生成器"""
|
||||
"""规划器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
|
||||
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="auto",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""回复器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
|
||||
visual_style: str = Field(
|
||||
default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本",
|
||||
@@ -239,16 +239,12 @@ class ChatConfig(ConfigBase):
|
||||
)
|
||||
"""Planner 连续被新消息打断的最大次数,0 表示不启用打断"""
|
||||
|
||||
plan_reply_log_max_per_chat: int = Field(
|
||||
default=1024,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "file-text",
|
||||
},
|
||||
)
|
||||
"""每个聊天流最大保存的Plan/Reply日志数量,超过此数量时会自动删除最老的日志"""
|
||||
group_chat_prompt: str = Field(
|
||||
default="你需要控制自己发言的频率,如果是一对一聊天,可以以较均匀的频率发言;如果用户较多,不要每句都回复,控制回复频率,不要回复的太频繁!控制回复的频率,不要每个人的消息都回复。",
|
||||
default="""
|
||||
你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片。
|
||||
回复尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。
|
||||
不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。
|
||||
""",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "users",
|
||||
@@ -257,7 +253,11 @@ class ChatConfig(ConfigBase):
|
||||
"""_wrap_群聊通用注意事项"""
|
||||
|
||||
private_chat_prompts: str = Field(
|
||||
default="你需要控制自己发言的频率,可以以较均匀的频率发言。",
|
||||
default="""
|
||||
你正在聊天,下面是正在聊的内容,其中包含聊天记录和聊天中的图片。
|
||||
回复尽量简短一些。请注意把握聊天内容。
|
||||
请考虑对方的发言频率,想法,思考自己何时回复以及回复内容。
|
||||
""",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user",
|
||||
@@ -740,6 +740,15 @@ class LearningItem(ConfigBase):
|
||||
)
|
||||
"""是否启用jargon学习"""
|
||||
|
||||
advanced_chosen: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "sparkles",
|
||||
},
|
||||
)
|
||||
"""是否启用基于子代理的二次表达方式选择"""
|
||||
|
||||
|
||||
class ExpressionGroup(ConfigBase):
|
||||
"""表达互通组配置类,若列表为空代表全局共享"""
|
||||
@@ -769,6 +778,7 @@ class ExpressionConfig(ConfigBase):
|
||||
use_expression=True,
|
||||
enable_learning=True,
|
||||
enable_jargon_learning=True,
|
||||
advanced_chosen=False,
|
||||
)
|
||||
],
|
||||
json_schema_extra={
|
||||
@@ -1640,35 +1650,6 @@ class MaiSakaConfig(ConfigBase):
|
||||
)
|
||||
"""MaiSaka 使用的用户名称"""
|
||||
|
||||
tool_filter_task_name: str = Field(
|
||||
default="utils",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "sparkles",
|
||||
},
|
||||
)
|
||||
"""工具筛选预判使用的模型任务名"""
|
||||
|
||||
tool_filter_threshold: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "filter",
|
||||
},
|
||||
)
|
||||
"""当可用工具总数超过该阈值时,先进行一轮工具筛选"""
|
||||
|
||||
tool_filter_max_keep: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "list-filter",
|
||||
},
|
||||
)
|
||||
"""工具筛选阶段最多保留的非内置工具数量"""
|
||||
|
||||
show_image_path: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -181,10 +181,7 @@ class ToolSpec:
|
||||
str: 合并后的单段工具描述。
|
||||
"""
|
||||
|
||||
parts = [self.brief_description.strip()]
|
||||
if self.detailed_description.strip():
|
||||
parts.append(self.detailed_description.strip())
|
||||
return "\n\n".join(part for part in parts if part).strip()
|
||||
return self.brief_description.strip()
|
||||
|
||||
def to_llm_definition(self) -> ToolDefinitionInput:
|
||||
"""转换为统一的 LLM 工具定义。
|
||||
@@ -389,7 +386,24 @@ class ToolRegistry:
|
||||
for provider in self._providers:
|
||||
provider_specs = await provider.list_tools()
|
||||
if any(spec.name == invocation.tool_name and spec.enabled for spec in provider_specs):
|
||||
return await provider.invoke(invocation, context)
|
||||
try:
|
||||
return await provider.invoke(invocation, context)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"工具调用异常: tool=%s provider=%s",
|
||||
invocation.tool_name,
|
||||
getattr(provider, "provider_name", ""),
|
||||
)
|
||||
error_message = str(exc).strip()
|
||||
if error_message:
|
||||
error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}: {error_message}"
|
||||
else:
|
||||
error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}"
|
||||
return ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
success=False,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
表达方式自动检查定时任务
|
||||
表达方式自动检查定时任务。
|
||||
|
||||
功能:
|
||||
1. 定期随机选取指定数量的表达方式
|
||||
@@ -9,52 +9,48 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.learners.expression_review_store import get_review_state, set_review_state
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.learners.expression_review_store import get_review_state, set_review_state
|
||||
from src.learners.expression_utils import parse_evaluation_response
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("expression_auto_check_task")
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
"""
|
||||
创建评估提示词
|
||||
创建评估提示词。
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
situation: 情景
|
||||
style: 风格
|
||||
|
||||
Returns:
|
||||
评估提示词
|
||||
"""
|
||||
# 基础评估标准
|
||||
base_criteria = [
|
||||
"表达方式或言语风格 是否与使用条件或使用情景 匹配",
|
||||
"允许部分语法错误或口头化或缺省出现",
|
||||
"表达方式或言语风格是否与使用条件或使用情景匹配",
|
||||
"允许部分语法错误或口语化或缺省出现",
|
||||
"表达方式不能太过特指,需要具有泛用性",
|
||||
"一般不涉及具体的人名或名称",
|
||||
]
|
||||
|
||||
# 从配置中获取额外的自定义标准
|
||||
custom_criteria = global_config.expression.expression_auto_check_custom_criteria
|
||||
|
||||
# 合并所有评估标准
|
||||
all_criteria = base_criteria.copy()
|
||||
if custom_criteria:
|
||||
all_criteria.extend(custom_criteria)
|
||||
|
||||
# 构建评估标准列表字符串
|
||||
criteria_list = "\n".join([f"{i + 1}. {criterion}" for i, criterion in enumerate(all_criteria)])
|
||||
|
||||
prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
|
||||
@@ -64,14 +60,13 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
请从以下方面进行评估:
|
||||
{criteria_list}
|
||||
|
||||
请以JSON格式输出评估结果:
|
||||
请以 JSON 格式输出评估结果:
|
||||
{{
|
||||
"suitable": true/false,
|
||||
"reason": "评估理由(如果不合适,请说明原因)"
|
||||
|
||||
}}
|
||||
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
如果合适,suitable 设为 true;如果不合适,suitable 设为 false,并在 reason 中说明原因。
|
||||
请严格按照 JSON 格式输出,不要包含其他内容。"""
|
||||
|
||||
return prompt
|
||||
|
||||
@@ -81,10 +76,10 @@ judge_llm = LLMServiceClient(task_name="utils", request_type="expression_check")
|
||||
|
||||
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
|
||||
"""
|
||||
执行单次LLM评估
|
||||
执行单次 LLM 评估。
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
situation: 情景
|
||||
style: 风格
|
||||
|
||||
Returns:
|
||||
@@ -101,20 +96,10 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
|
||||
response = generation_result.response
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
evaluation = json.loads(response)
|
||||
except json.JSONDecodeError as e:
|
||||
import re
|
||||
evaluation = parse_evaluation_response(response)
|
||||
|
||||
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
|
||||
if json_match:
|
||||
evaluation = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
|
||||
|
||||
suitable = evaluation.get("suitable", False)
|
||||
reason = evaluation.get("reason", "未提供理由")
|
||||
suitable = bool(evaluation.get("suitable", False))
|
||||
reason = str(evaluation.get("reason", "未提供理由"))
|
||||
|
||||
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
|
||||
return suitable, reason, None
|
||||
@@ -125,20 +110,19 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
|
||||
|
||||
|
||||
class ExpressionAutoCheckTask(AsyncTask):
|
||||
"""表达方式自动检查定时任务"""
|
||||
"""表达方式自动检查定时任务。"""
|
||||
|
||||
def __init__(self):
|
||||
# 从配置中获取检查间隔和一次检查数量
|
||||
check_interval = global_config.expression.expression_auto_check_interval
|
||||
super().__init__(
|
||||
task_name="Expression Auto Check Task",
|
||||
wait_before_start=60, # 启动后等待60秒再开始第一次检查
|
||||
wait_before_start=60,
|
||||
run_interval=check_interval,
|
||||
)
|
||||
|
||||
async def _select_expressions(self, count: int) -> List[Expression]:
|
||||
"""
|
||||
随机选择指定数量的未检查表达方式
|
||||
随机选择指定数量的未检查表达方式。
|
||||
|
||||
Args:
|
||||
count: 需要选择的数量
|
||||
@@ -158,11 +142,12 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
logger.info("没有未检查的表达方式")
|
||||
return []
|
||||
|
||||
# 随机选择指定数量
|
||||
selected_count = min(count, len(unevaluated_expressions))
|
||||
selected = random.sample(unevaluated_expressions, selected_count)
|
||||
|
||||
logger.info(f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条")
|
||||
logger.info(
|
||||
f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条"
|
||||
)
|
||||
return selected
|
||||
|
||||
except Exception as e:
|
||||
@@ -171,35 +156,35 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
|
||||
async def _evaluate_expression(self, expression: Expression) -> bool:
|
||||
"""
|
||||
评估单个表达方式
|
||||
评估单个表达方式。
|
||||
|
||||
Args:
|
||||
expression: 要评估的表达方式
|
||||
|
||||
Returns:
|
||||
True表示通过,False表示不通过
|
||||
True 表示通过,False 表示不通过
|
||||
"""
|
||||
|
||||
suitable, reason, error = await single_expression_check(
|
||||
expression.situation,
|
||||
expression.style,
|
||||
)
|
||||
|
||||
# 更新数据库
|
||||
try:
|
||||
set_review_state(expression.id, True, not suitable, "ai")
|
||||
|
||||
status = "通过" if suitable else "不通过"
|
||||
# 保留这段注释,方便后续需要时恢复更详细的审核日志。
|
||||
# logger.info(
|
||||
# f"表达方式评估完成 [ID: {expression.id}] - {status} | "
|
||||
# f"Situation: {expression.situation}... | "
|
||||
# f"Style: {expression.style}... | "
|
||||
# f"Reason: {reason[:50]}..."
|
||||
# f"表达方式评估完成 [ID: {expression.id}] - {status} | "
|
||||
# f"Situation: {expression.situation}... | "
|
||||
# f"Style: {expression.style}... | "
|
||||
# f"Reason: {reason[:50]}..."
|
||||
# )
|
||||
|
||||
if error:
|
||||
logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}")
|
||||
|
||||
logger.debug(f"表达方式 [ID: {expression.id}] 评估完成: {status}, reason={reason}")
|
||||
return suitable
|
||||
|
||||
except Exception as e:
|
||||
@@ -207,9 +192,8 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
return False
|
||||
|
||||
async def run(self):
|
||||
"""执行检查任务"""
|
||||
"""执行检查任务。"""
|
||||
try:
|
||||
# 检查是否启用自动检查
|
||||
if not global_config.expression.expression_self_reflect:
|
||||
logger.debug("表达方式自动检查未启用,跳过本次执行")
|
||||
return
|
||||
@@ -221,26 +205,22 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||
|
||||
logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count} 条")
|
||||
|
||||
# 选择要检查的表达方式
|
||||
expressions = await self._select_expressions(check_count)
|
||||
|
||||
if not expressions:
|
||||
logger.info("没有需要检查的表达方式")
|
||||
return
|
||||
|
||||
# 逐个评估
|
||||
passed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for i, expression in enumerate(expressions, 1):
|
||||
logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")
|
||||
for index, expression in enumerate(expressions, 1):
|
||||
logger.debug(f"正在评估 [{index}/{len(expressions)}]: ID={expression.id}")
|
||||
|
||||
if await self._evaluate_expression(expression):
|
||||
passed_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
# 避免请求过快
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -267,7 +267,7 @@ class ExpressionLearner:
|
||||
return normalized_entries
|
||||
|
||||
def get_pending_count(self, message_cache: List["SessionMessage"]) -> int:
|
||||
"""??????????????"""
|
||||
"""获取待处理消息数量"""
|
||||
return max(0, len(message_cache) - self._last_processed_index)
|
||||
|
||||
async def learn(
|
||||
@@ -275,10 +275,10 @@ class ExpressionLearner:
|
||||
message_cache: List["SessionMessage"],
|
||||
jargon_miner: Optional["JargonMiner"] = None,
|
||||
) -> bool:
|
||||
"""?????????????????????"""
|
||||
"""学习表达方式"""
|
||||
pending_messages = message_cache[self._last_processed_index :]
|
||||
if not pending_messages:
|
||||
logger.debug("??????????????????")
|
||||
logger.debug("没有待处理消息")
|
||||
return False
|
||||
if len(pending_messages) < self.min_messages_for_extraction:
|
||||
return False
|
||||
@@ -304,7 +304,7 @@ class ExpressionLearner:
|
||||
)
|
||||
response = generation_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"????????????????{e}")
|
||||
logger.error(f"学习表达方式失败: {e}")
|
||||
return False
|
||||
|
||||
expressions: List[Tuple[str, str, str]]
|
||||
@@ -319,14 +319,14 @@ class ExpressionLearner:
|
||||
continue
|
||||
jargon_entries.append((content, source_id))
|
||||
existing_contents.add(content)
|
||||
logger.info(f"??????????{content}")
|
||||
logger.info(f"从缓存中找到黑话: {content}")
|
||||
|
||||
if len(expressions) > 20:
|
||||
logger.info(f"?????????? 20 ???????????{len(expressions)}")
|
||||
logger.info(f"表达方式数量超过20: {len(expressions)}")
|
||||
expressions = []
|
||||
|
||||
if len(jargon_entries) > 30:
|
||||
logger.info(f"???????? 30 ???????????{len(jargon_entries)}")
|
||||
logger.info(f"黑话数量超过30: {len(jargon_entries)}")
|
||||
jargon_entries = []
|
||||
|
||||
after_extract_result = await self._get_runtime_manager().invoke_hook(
|
||||
@@ -337,7 +337,7 @@ class ExpressionLearner:
|
||||
jargon_entries=self._serialize_jargon_entries(jargon_entries),
|
||||
)
|
||||
if after_extract_result.aborted:
|
||||
logger.info(f"{self.session_id} ?????????? Hook ??")
|
||||
logger.info(f"{self.session_id} 表达方式选择 Hook 中止")
|
||||
self._last_processed_index = len(message_cache)
|
||||
return False
|
||||
|
||||
@@ -353,21 +353,21 @@ class ExpressionLearner:
|
||||
await self._process_jargon_entries(jargon_entries, pending_messages, jargon_miner)
|
||||
|
||||
if not expressions:
|
||||
logger.info("????????????")
|
||||
logger.info("没有可学习的表达方式")
|
||||
self._last_processed_index = len(message_cache)
|
||||
return False
|
||||
|
||||
logger.info(f"???? expressions: {expressions}")
|
||||
logger.info(f"???? jargon_entries: {jargon_entries}")
|
||||
logger.info(f"可学习的表达方式: {expressions}")
|
||||
logger.info(f"可学习的黑话: {jargon_entries}")
|
||||
|
||||
learnt_expressions = self._filter_expressions(expressions, pending_messages)
|
||||
if not learnt_expressions:
|
||||
logger.info("????????????")
|
||||
logger.info("没有可学习的表达方式通过过滤")
|
||||
self._last_processed_index = len(message_cache)
|
||||
return False
|
||||
|
||||
learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
|
||||
logger.info(f"? {self.session_id} ????????\n{learnt_expressions_str}")
|
||||
logger.info(f"{self.session_id} 可学习的表达方式: \n{learnt_expressions_str}")
|
||||
|
||||
for situation, style in learnt_expressions:
|
||||
before_upsert_result = await self._get_runtime_manager().invoke_hook(
|
||||
@@ -377,14 +377,14 @@ class ExpressionLearner:
|
||||
style=style,
|
||||
)
|
||||
if before_upsert_result.aborted:
|
||||
logger.info(f"{self.session_id} ???????? Hook ??: situation={situation!r}")
|
||||
logger.info(f"{self.session_id} 表达方式写入 Hook 中止: situation={situation!r}")
|
||||
continue
|
||||
|
||||
upsert_kwargs = before_upsert_result.kwargs
|
||||
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
|
||||
style = str(upsert_kwargs.get("style", style) or "").strip()
|
||||
if not situation or not style:
|
||||
logger.info(f"{self.session_id} ???????? Hook ??????")
|
||||
logger.info(f"{self.session_id} 表达方式写入 Hook 中止: situation={situation!r}")
|
||||
continue
|
||||
await self._upsert_expression_to_db(situation, style)
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from json_repair import repair_json
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("expression_utils")
|
||||
|
||||
@@ -16,17 +16,7 @@ judge_llm = LLMServiceClient(task_name="utils", request_type="expression_check")
|
||||
|
||||
|
||||
def _normalize_repair_json_result(repaired_result: Any) -> str:
|
||||
"""将 repair_json 的返回值规范化为 JSON 字符串。
|
||||
|
||||
Args:
|
||||
repaired_result: `repair_json` 的返回值,可能是字符串或带附加信息的元组。
|
||||
|
||||
Returns:
|
||||
str: 可供 `json.loads` 继续解析的 JSON 字符串。
|
||||
|
||||
Raises:
|
||||
TypeError: 当返回值无法规范化为字符串时抛出。
|
||||
"""
|
||||
"""将 `repair_json` 的返回结果统一转换为字符串。"""
|
||||
if isinstance(repaired_result, str):
|
||||
return repaired_result
|
||||
if isinstance(repaired_result, tuple) and repaired_result:
|
||||
@@ -37,22 +27,121 @@ def _normalize_repair_json_result(repaired_result: Any) -> str:
|
||||
raise TypeError(f"repair_json 返回了无法处理的结果类型: {type(repaired_result)}")
|
||||
|
||||
|
||||
def _strip_markdown_code_fence(text: str) -> str:
|
||||
"""移除 LLM 可能附带的 Markdown 代码块包裹。"""
|
||||
raw = text.strip()
|
||||
if match := re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL):
|
||||
return match[1].strip()
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
return raw.strip()
|
||||
|
||||
|
||||
def _extract_json_object_candidate(text: str) -> str:
|
||||
"""尽量从文本中提取首个 JSON 对象片段。"""
|
||||
start_index = text.find("{")
|
||||
end_index = text.rfind("}")
|
||||
if start_index != -1 and end_index != -1 and start_index < end_index:
|
||||
return text[start_index : end_index + 1].strip()
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _extract_reason_from_text(text: str) -> Optional[str]:
|
||||
"""从格式不完整的 JSON 文本中兜底提取 reason 字段。"""
|
||||
reason_key_match = re.search(r'["“”]?reason["“”]?\s*:\s*', text, re.IGNORECASE)
|
||||
if reason_key_match is None:
|
||||
return None
|
||||
|
||||
value_text = text[reason_key_match.end() :].strip()
|
||||
if not value_text:
|
||||
return None
|
||||
|
||||
if value_text.endswith("}"):
|
||||
value_text = value_text[:-1].rstrip()
|
||||
if value_text.endswith(","):
|
||||
value_text = value_text[:-1].rstrip()
|
||||
if not value_text:
|
||||
return None
|
||||
|
||||
if value_text[0] in {'"', "'", "“", "”", "‘", "’"}:
|
||||
value_text = value_text[1:]
|
||||
while value_text and value_text[-1] in {'"', "'", "“", "”", "‘", "’"}:
|
||||
value_text = value_text[:-1].rstrip()
|
||||
|
||||
return value_text.strip() or None
|
||||
|
||||
|
||||
def _normalize_reason_text(reason: Any) -> str:
|
||||
"""清理解析后 reason 中残留的包裹引号。"""
|
||||
normalized_reason = str(reason).strip()
|
||||
|
||||
if len(normalized_reason) >= 2 and normalized_reason[0] == normalized_reason[-1]:
|
||||
if normalized_reason[0] in {'"', "'", "“", "”", "‘", "’"}:
|
||||
normalized_reason = normalized_reason[1:-1].strip()
|
||||
|
||||
if normalized_reason.endswith('"') and normalized_reason.count('"') % 2 == 1:
|
||||
normalized_reason = normalized_reason[:-1].rstrip()
|
||||
if normalized_reason.endswith("'") and normalized_reason.count("'") % 2 == 1:
|
||||
normalized_reason = normalized_reason[:-1].rstrip()
|
||||
if normalized_reason.endswith('"') and not normalized_reason.startswith('"'):
|
||||
normalized_reason = normalized_reason[:-1].rstrip()
|
||||
if normalized_reason.endswith("'") and not normalized_reason.startswith("'"):
|
||||
normalized_reason = normalized_reason[:-1].rstrip()
|
||||
|
||||
return normalized_reason
|
||||
|
||||
|
||||
def parse_evaluation_response(response: str) -> Dict[str, Any]:
|
||||
"""解析表达方式评估结果,兼容不完全合法的 JSON。"""
|
||||
raw = _strip_markdown_code_fence(response)
|
||||
if not raw:
|
||||
raise ValueError("LLM 响应为空")
|
||||
|
||||
parse_candidates = [raw]
|
||||
json_candidate = _extract_json_object_candidate(raw)
|
||||
if json_candidate and json_candidate not in parse_candidates:
|
||||
parse_candidates.append(json_candidate)
|
||||
|
||||
for candidate in parse_candidates:
|
||||
parsed = _try_parse(candidate)
|
||||
if isinstance(parsed, dict):
|
||||
if "reason" in parsed:
|
||||
parsed["reason"] = _normalize_reason_text(parsed["reason"])
|
||||
return parsed
|
||||
|
||||
fixed_candidate = fix_chinese_quotes_in_json(candidate)
|
||||
if fixed_candidate != candidate:
|
||||
parsed = _try_parse(fixed_candidate)
|
||||
if isinstance(parsed, dict):
|
||||
if "reason" in parsed:
|
||||
parsed["reason"] = _normalize_reason_text(parsed["reason"])
|
||||
return parsed
|
||||
|
||||
suitable_match = re.search(r'["“”]?suitable["“”]?\s*:\s*(true|false)', raw, re.IGNORECASE)
|
||||
reason = _extract_reason_from_text(json_candidate or raw)
|
||||
if suitable_match is None or reason is None:
|
||||
raise ValueError(f"无法解析 LLM 响应为评估结果 JSON: {response}")
|
||||
|
||||
return {
|
||||
"suitable": suitable_match.group(1).lower() == "true",
|
||||
"reason": _normalize_reason_text(reason),
|
||||
}
|
||||
|
||||
|
||||
async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
|
||||
"""
|
||||
执行单次LLM评估
|
||||
执行单次 LLM 评估。
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
situation: 情景
|
||||
style: 风格
|
||||
|
||||
Returns:
|
||||
(suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息
|
||||
"""
|
||||
# 构建评估提示词
|
||||
# 基础评估标准
|
||||
base_criteria = [
|
||||
"表达方式或言语风格是否与使用条件或使用情景匹配",
|
||||
"允许部分语法错误或口头化或缺省出现",
|
||||
"允许部分语法错误或口语化或缺省出现",
|
||||
"表达方式不能太过特指,需要具有泛用性",
|
||||
"一般不涉及具体的人名或名称",
|
||||
]
|
||||
@@ -60,7 +149,6 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
|
||||
if custom_criteria := global_config.expression.expression_auto_check_custom_criteria:
|
||||
base_criteria.extend(custom_criteria)
|
||||
|
||||
# 构建评估标准列表字符串
|
||||
criteria_list = "\n".join([f"{i + 1}. {criterion}" for i, criterion in enumerate(base_criteria)])
|
||||
|
||||
prompt_template = prompt_manager.get_prompt("expression_evaluation")
|
||||
@@ -81,18 +169,13 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
|
||||
logger.debug(f"评估结果: {response}")
|
||||
|
||||
try:
|
||||
evaluation = json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
response_repaired = _normalize_repair_json_result(repair_json(response))
|
||||
evaluation = json.loads(response_repaired)
|
||||
except Exception as e:
|
||||
raise ValueError(f"无法解析LLM响应为JSON: {response}") from e
|
||||
evaluation = parse_evaluation_response(response)
|
||||
except Exception as e:
|
||||
return False, f"评估表达方式时发生错误: {e}", str(e)
|
||||
|
||||
try:
|
||||
suitable = evaluation.get("suitable", False)
|
||||
reason = evaluation.get("reason", "未提供理由")
|
||||
suitable = bool(evaluation.get("suitable", False))
|
||||
reason = _normalize_reason_text(evaluation.get("reason", "未提供理由"))
|
||||
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
|
||||
return suitable, reason, None
|
||||
except Exception as e:
|
||||
@@ -100,69 +183,48 @@ async def check_expression_suitability(situation: str, style: str) -> Tuple[bool
|
||||
|
||||
|
||||
def fix_chinese_quotes_in_json(text: str) -> str:
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号。"""
|
||||
result: List[str] = []
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
for char in text:
|
||||
if escape_next:
|
||||
# 当前字符是转义字符后的字符,直接添加
|
||||
result.append(char)
|
||||
escape_next = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
# 转义字符
|
||||
result.append(char)
|
||||
escape_next = True
|
||||
i += 1
|
||||
continue
|
||||
if char == '"' and not escape_next:
|
||||
# 遇到英文引号,切换字符串状态
|
||||
|
||||
if char == '"':
|
||||
in_string = not in_string
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_string and char in ["“", "”"]:
|
||||
result.append('\\"')
|
||||
else:
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
result.append(char)
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
"""
|
||||
解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。
|
||||
|
||||
期望的 JSON 结构:
|
||||
[
|
||||
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
|
||||
{"content": "词条", "source_id": "12"}, // 黑话
|
||||
...
|
||||
]
|
||||
解析 LLM 返回的表达方式总结和黑话 JSON,提取两个列表。
|
||||
|
||||
Returns:
|
||||
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
第一个列表是表达方式 (situation, style, source_id)
|
||||
第二个列表是黑话 (content, source_id)
|
||||
第一个列表是表达方式 (situation, style, source_id)
|
||||
第二个列表是黑话 (content, source_id)
|
||||
"""
|
||||
if not response:
|
||||
return [], []
|
||||
|
||||
raw = response.strip()
|
||||
|
||||
if match := re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL):
|
||||
raw = match[1].strip()
|
||||
else:
|
||||
# 去掉可能存在的通用 ``` 包裹
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
raw = raw.strip()
|
||||
raw = _strip_markdown_code_fence(response)
|
||||
|
||||
parsed = _try_parse(raw)
|
||||
if parsed is None:
|
||||
@@ -180,22 +242,21 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
|
||||
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
return [], []
|
||||
|
||||
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
|
||||
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
jargon_entries: List[Tuple[str, str]] = []
|
||||
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否是表达方式条目(有 situation 和 style)
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
|
||||
if situation and style and source_id:
|
||||
# 表达方式条目
|
||||
expressions.append((situation, style, source_id))
|
||||
continue
|
||||
|
||||
content = str(item.get("content", "")).strip()
|
||||
if content and source_id:
|
||||
jargon_entries.append((content, source_id))
|
||||
@@ -204,25 +265,16 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]]
|
||||
|
||||
|
||||
def is_single_char_jargon(content: str) -> bool:
|
||||
"""
|
||||
判断是否是单字黑话(单个汉字、英文或数字)
|
||||
|
||||
Args:
|
||||
content: 词条内容
|
||||
|
||||
Returns:
|
||||
bool: 如果是单字黑话返回True,否则返回False
|
||||
"""
|
||||
"""判断是否是单字黑话(单个汉字、英文或数字)。"""
|
||||
if not content or len(content) != 1:
|
||||
return False
|
||||
|
||||
char = content[0]
|
||||
# 判断是否是单个汉字、单个英文字母或单个数字
|
||||
return (
|
||||
"\u4e00" <= char <= "\u9fff" # 汉字
|
||||
or "a" <= char <= "z" # 小写字母
|
||||
or "A" <= char <= "Z" # 大写字母
|
||||
or "0" <= char <= "9" # 数字
|
||||
"\u4e00" <= char <= "\u9fff"
|
||||
or "a" <= char <= "z"
|
||||
or "A" <= char <= "Z"
|
||||
or "0" <= char <= "9"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -54,27 +54,30 @@ class RespNotOkException(Exception):
|
||||
return f"未知的异常响应代码:{self.status_code}"
|
||||
|
||||
|
||||
class RespParseException(Exception):
|
||||
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
|
||||
class ResponseContextException(Exception):
|
||||
"""携带原始响应上下文的异常基类。"""
|
||||
|
||||
def __init__(self, ext_info: Any, message: str | None = None):
|
||||
default_message: str = "请求失败"
|
||||
|
||||
def __init__(self, ext_info: Any = None, message: str | None = None):
|
||||
super().__init__(message)
|
||||
self.ext_info = ext_info
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
|
||||
return self.message or self.default_message
|
||||
|
||||
|
||||
class EmptyResponseException(Exception):
|
||||
class RespParseException(ResponseContextException):
|
||||
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
|
||||
|
||||
default_message = "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
|
||||
|
||||
|
||||
class EmptyResponseException(ResponseContextException):
|
||||
"""响应内容为空"""
|
||||
|
||||
def __init__(self, message: str = "响应内容为空,这可能是一个临时性问题"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
default_message = "响应内容为空,这可能是一个临时性问题"
|
||||
|
||||
|
||||
class ModelAttemptFailed(Exception):
|
||||
|
||||
@@ -552,7 +552,7 @@ def _build_stream_api_response(
|
||||
|
||||
_warn_if_max_tokens_truncated(last_response, response.content, response.tool_calls)
|
||||
if not response.content and not response.tool_calls and not response.reasoning_content:
|
||||
raise EmptyResponseException()
|
||||
raise EmptyResponseException(last_response)
|
||||
return response
|
||||
|
||||
|
||||
@@ -627,7 +627,7 @@ def _default_normal_response_parser(
|
||||
usage_record = _extract_usage_record(response)
|
||||
_warn_if_max_tokens_truncated(response, api_response.content, api_response.tool_calls)
|
||||
if not api_response.content and not api_response.tool_calls and not api_response.reasoning_content:
|
||||
raise EmptyResponseException("响应中既无文本内容也无工具调用")
|
||||
raise EmptyResponseException(response, "响应中既无文本内容也无工具调用")
|
||||
return api_response, usage_record
|
||||
|
||||
|
||||
|
||||
@@ -79,6 +79,25 @@ THINK_CONTENT_PATTERN = re.compile(
|
||||
)
|
||||
"""用于解析 `<think>` 推理块的正则表达式。"""
|
||||
|
||||
XML_TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(?P<body>.*?)\s*</tool_call>", re.DOTALL | re.IGNORECASE)
|
||||
"""用于兜底解析模型以 XML 文本返回的工具调用。
|
||||
|
||||
这是一个暂时性兼容方案,专门处理“思维链内容里夹带工具调用”的情况;
|
||||
后续如果上游稳定返回标准 tool_calls 字段,这里可能会调整或移除。
|
||||
"""
|
||||
|
||||
XML_FUNCTION_CALL_PATTERN = re.compile(
|
||||
r"<function=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<arguments>.*?)\s*</function>",
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
"""用于从 XML 风格工具调用块中提取函数名与参数。"""
|
||||
|
||||
XML_PARAMETER_PATTERN = re.compile(
|
||||
r"<parameter=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<value>.*?)\s*</parameter>",
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
"""用于从 XML 风格工具调用块中提取参数列表。"""
|
||||
|
||||
CHAT_COMPLETIONS_RESERVED_EXTRA_BODY_KEYS = {
|
||||
"max_tokens",
|
||||
"messages",
|
||||
@@ -346,6 +365,32 @@ def _convert_assistant_tool_calls(tool_calls: List[ToolCall]) -> List[ChatComple
|
||||
return converted_tool_calls
|
||||
|
||||
|
||||
def _sanitize_messages_for_toolless_request(messages: List[Message]) -> List[Message]:
|
||||
"""在无工具请求时清洗历史工具调用链,避免兼容接口拒收消息。"""
|
||||
sanitized_messages: List[Message] = []
|
||||
|
||||
for message in messages:
|
||||
if message.role == RoleType.Tool:
|
||||
continue
|
||||
|
||||
if message.role == RoleType.Assistant and message.tool_calls:
|
||||
if not message.parts:
|
||||
continue
|
||||
assistant_message = Message(
|
||||
role=message.role,
|
||||
parts=list(message.parts),
|
||||
tool_call_id=message.tool_call_id,
|
||||
tool_name=message.tool_name,
|
||||
tool_calls=None,
|
||||
)
|
||||
sanitized_messages.append(assistant_message)
|
||||
continue
|
||||
|
||||
sanitized_messages.append(message)
|
||||
|
||||
return sanitized_messages
|
||||
|
||||
|
||||
def _convert_messages(messages: List[Message]) -> List[ChatCompletionMessageParam]:
|
||||
"""将内部消息列表转换为 OpenAI 兼容消息列表。
|
||||
|
||||
@@ -515,6 +560,66 @@ def _extract_reasoning_and_content(
|
||||
return None, match.group("content_only").strip() or None
|
||||
|
||||
|
||||
def _extract_xml_tool_calls(
|
||||
raw_text: str | None,
|
||||
parse_mode: ToolArgumentParseMode,
|
||||
response: Any,
|
||||
) -> Tuple[str | None, List[ToolCall] | None]:
|
||||
"""从 XML 风格文本中兜底提取工具调用。"""
|
||||
if not isinstance(raw_text, str) or not raw_text.strip():
|
||||
return raw_text, None
|
||||
|
||||
tool_calls: List[ToolCall] = []
|
||||
|
||||
def _coerce_xml_parameter_value(raw_value: str) -> Any:
|
||||
normalized_value = raw_value.strip()
|
||||
if not normalized_value:
|
||||
return ""
|
||||
lowered_value = normalized_value.lower()
|
||||
if lowered_value == "true":
|
||||
return True
|
||||
if lowered_value == "false":
|
||||
return False
|
||||
if lowered_value in {"null", "none"}:
|
||||
return None
|
||||
if normalized_value.startswith(("{", "[")):
|
||||
try:
|
||||
return repair_json(normalized_value, return_objects=True, logging=False)
|
||||
except Exception:
|
||||
return normalized_value
|
||||
return normalized_value
|
||||
|
||||
def _parse_xml_parameters(raw_arguments: str) -> Dict[str, Any] | None:
|
||||
parameters = {
|
||||
match.group("name").strip(): _coerce_xml_parameter_value(match.group("value"))
|
||||
for match in XML_PARAMETER_PATTERN.finditer(raw_arguments)
|
||||
}
|
||||
return parameters or None
|
||||
|
||||
def _replace_tool_call(match: re.Match[str]) -> str:
|
||||
body = match.group("body")
|
||||
function_match = XML_FUNCTION_CALL_PATTERN.search(body)
|
||||
if function_match is None:
|
||||
return match.group(0)
|
||||
|
||||
function_name = function_match.group("name").strip()
|
||||
raw_arguments = function_match.group("arguments").strip()
|
||||
arguments = _parse_xml_parameters(raw_arguments)
|
||||
if arguments is None:
|
||||
arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=f"xml_tool_call_{len(tool_calls) + 1}",
|
||||
func_name=function_name,
|
||||
args=arguments,
|
||||
)
|
||||
)
|
||||
return ""
|
||||
|
||||
cleaned_text = XML_TOOL_CALL_PATTERN.sub(_replace_tool_call, raw_text).strip() or None
|
||||
return cleaned_text, tool_calls or None
|
||||
|
||||
|
||||
def _log_length_truncation(finish_reason: str | None, model_name: str | None) -> None:
|
||||
"""记录因长度截断导致的告警日志。
|
||||
|
||||
@@ -526,6 +631,38 @@ def _log_length_truncation(finish_reason: str | None, model_name: str | None) ->
|
||||
logger.info(f"模型{model_name or ''}因为超过最大 max_token 限制,可能仅输出部分内容,可视情况调整")
|
||||
|
||||
|
||||
def _apply_xml_tool_call_fallback(
|
||||
response: APIResponse,
|
||||
parse_mode: ToolArgumentParseMode,
|
||||
raw_response: Any,
|
||||
) -> None:
|
||||
"""当上游未返回标准 tool_calls 时,尝试从 XML 文本兜底解析。
|
||||
|
||||
这是一个暂时性处理方法,用来兼容思维链中混入工具调用的返回格式,
|
||||
后续可能随着模型或上游接口的规范化而变更。
|
||||
"""
|
||||
if response.tool_calls:
|
||||
return
|
||||
|
||||
reasoning_content, tool_calls = _extract_xml_tool_calls(response.reasoning_content, parse_mode, raw_response)
|
||||
if reasoning_content != response.reasoning_content:
|
||||
response.reasoning_content = reasoning_content
|
||||
if tool_calls:
|
||||
response.tool_calls = tool_calls
|
||||
if not response.content and reasoning_content:
|
||||
response.content = reasoning_content
|
||||
response.reasoning_content = None
|
||||
logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用")
|
||||
return
|
||||
|
||||
cleaned_content, tool_calls = _extract_xml_tool_calls(response.content, parse_mode, raw_response)
|
||||
if cleaned_content != response.content:
|
||||
response.content = cleaned_content
|
||||
if tool_calls:
|
||||
response.tool_calls = tool_calls
|
||||
logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用")
|
||||
|
||||
|
||||
def _coerce_openai_argument(value: Any) -> Any | Omit:
|
||||
"""将可选参数转换为 OpenAI SDK 期望的值。
|
||||
|
||||
@@ -561,7 +698,7 @@ def _build_api_status_message(error: APIStatusError) -> str:
|
||||
message_parts.append(str(error.message))
|
||||
response_text = getattr(getattr(error, "response", None), "text", None)
|
||||
if response_text:
|
||||
message_parts.append(str(response_text)[:300])
|
||||
message_parts.append(str(response_text))
|
||||
if message_parts:
|
||||
return " | ".join(message_parts)
|
||||
return f"上游接口返回状态码 {error.status_code}"
|
||||
@@ -722,9 +859,10 @@ class _OpenAIStreamAccumulator:
|
||||
response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))
|
||||
|
||||
response.raw_data = {"model": self.model_name} if self.model_name else None
|
||||
_apply_xml_tool_call_fallback(response, self.tool_argument_parse_mode, response.raw_data)
|
||||
|
||||
if not response.content and not response.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
raise EmptyResponseException(response.raw_data)
|
||||
|
||||
return response
|
||||
|
||||
@@ -808,7 +946,7 @@ def _default_normal_response_parser(
|
||||
"""
|
||||
choices = getattr(resp, "choices", None)
|
||||
if not choices:
|
||||
raise EmptyResponseException("响应解析失败,choices 为空或缺失")
|
||||
raise EmptyResponseException(resp, "响应解析失败,choices 为空或缺失")
|
||||
|
||||
api_response = APIResponse()
|
||||
message_part = choices[0].message
|
||||
@@ -847,9 +985,10 @@ def _default_normal_response_parser(
|
||||
|
||||
finish_reason = getattr(resp.choices[0], "finish_reason", None)
|
||||
_log_length_truncation(finish_reason, getattr(resp, "model", None))
|
||||
_apply_xml_tool_call_fallback(api_response, tool_argument_parse_mode, resp)
|
||||
|
||||
if not api_response.content and not api_response.tool_calls:
|
||||
raise EmptyResponseException()
|
||||
raise EmptyResponseException(resp)
|
||||
|
||||
return api_response, usage_record
|
||||
|
||||
@@ -965,7 +1104,12 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
model_info = request.model_info
|
||||
|
||||
try:
|
||||
messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request.message_list)
|
||||
request_messages = (
|
||||
list(request.message_list)
|
||||
if request.tool_options
|
||||
else _sanitize_messages_for_toolless_request(request.message_list)
|
||||
)
|
||||
messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request_messages)
|
||||
tools_payload: List[ChatCompletionToolParam] | None = (
|
||||
_convert_tool_options(request.tool_options) if request.tool_options else None
|
||||
)
|
||||
|
||||
@@ -58,6 +58,42 @@ def _json_friendly(value: Any) -> Any:
|
||||
return str(value)
|
||||
|
||||
|
||||
def extract_error_response_body(error: Exception) -> Any | None:
|
||||
"""尽量从异常对象中提取上游返回体,便于排查模型请求失败。"""
|
||||
candidate_errors = [error, getattr(error, "__cause__", None)]
|
||||
|
||||
for candidate in candidate_errors:
|
||||
if candidate is None:
|
||||
continue
|
||||
|
||||
response = getattr(candidate, "response", None)
|
||||
if response is not None:
|
||||
response_json = getattr(response, "json", None)
|
||||
if callable(response_json):
|
||||
try:
|
||||
return _json_friendly(response_json())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
response_text = getattr(response, "text", None)
|
||||
if response_text not in (None, ""):
|
||||
return str(response_text)
|
||||
|
||||
response_content = getattr(response, "content", None)
|
||||
if response_content not in (None, b"", ""):
|
||||
return _json_friendly(response_content)
|
||||
|
||||
response_body = getattr(candidate, "body", None)
|
||||
if response_body not in (None, "", b""):
|
||||
return _json_friendly(response_body)
|
||||
|
||||
ext_info = getattr(candidate, "ext_info", None)
|
||||
if ext_info is not None:
|
||||
return _json_friendly(ext_info)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_filename_component(value: str) -> str:
|
||||
"""将任意字符串转换为适合文件名使用的片段。"""
|
||||
normalized_value = FILENAME_SAFE_PATTERN.sub("-", value.strip())
|
||||
@@ -228,6 +264,7 @@ def serialize_model_info_snapshot(model_info: ModelInfo) -> dict[str, Any]:
|
||||
"model_identifier": model_info.model_identifier,
|
||||
"name": model_info.name,
|
||||
"temperature": model_info.temperature,
|
||||
"visual": model_info.visual,
|
||||
}
|
||||
|
||||
|
||||
@@ -244,6 +281,7 @@ def deserialize_model_info_snapshot(raw_model_info: Any) -> ModelInfo:
|
||||
model_identifier=str(raw_model_info.get("model_identifier") or ""),
|
||||
name=str(raw_model_info.get("name") or ""),
|
||||
temperature=raw_model_info.get("temperature"),
|
||||
visual=bool(raw_model_info.get("visual", False)),
|
||||
)
|
||||
|
||||
|
||||
@@ -386,6 +424,10 @@ def save_failed_request_snapshot(
|
||||
"snapshot_version": SNAPSHOT_VERSION,
|
||||
}
|
||||
|
||||
response_body = extract_error_response_body(error)
|
||||
if response_body is not None:
|
||||
snapshot_payload["error"]["response_body"] = response_body
|
||||
|
||||
snapshot_payload["replay"] = {
|
||||
"command": build_replay_command(snapshot_path),
|
||||
"file_uri": snapshot_path.as_uri(),
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
@@ -397,8 +398,6 @@ class LLMOrchestrator:
|
||||
start_time = time.time()
|
||||
|
||||
tool_built = self._build_tool_options(tools)
|
||||
if self.request_type.startswith("maisaka_"):
|
||||
logger.info(f"LLMOrchestrator[{self.request_type}] 已构建 {len(tool_built or [])} 个内部工具选项")
|
||||
|
||||
execution_result = await self._execute_request(
|
||||
request_type=RequestType.RESPONSE,
|
||||
@@ -912,7 +911,11 @@ class LLMOrchestrator:
|
||||
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
|
||||
message_list = []
|
||||
if message_factory:
|
||||
message_list = message_factory(client)
|
||||
parameter_count = len(inspect.signature(message_factory).parameters)
|
||||
if parameter_count >= 2:
|
||||
message_list = message_factory(client, model_info)
|
||||
else:
|
||||
message_list = message_factory(client)
|
||||
try:
|
||||
request = self._build_client_request(
|
||||
request_type=request_type,
|
||||
|
||||
@@ -18,6 +18,7 @@ from src.common.message_server.server import Server, get_global_server
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.maisaka.display.stage_status_board import disable_stage_status_board, enable_stage_status_board
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.memory_flow_service import memory_automation_service
|
||||
@@ -65,6 +66,7 @@ class MainSystem:
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化系统组件"""
|
||||
enable_stage_status_board()
|
||||
logger.info(t("startup.waking_up", nickname=global_config.bot.nickname))
|
||||
|
||||
# 其他初始化任务
|
||||
@@ -169,6 +171,7 @@ async def main() -> None:
|
||||
system.schedule_tasks(),
|
||||
)
|
||||
finally:
|
||||
disable_stage_status_board()
|
||||
emoji_manager.shutdown()
|
||||
await memory_automation_service.shutdown()
|
||||
await a_memorix_host_service.stop()
|
||||
|
||||
@@ -10,6 +10,8 @@ from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
from .continue_tool import get_tool_spec as get_continue_tool_spec
|
||||
from .continue_tool import handle_tool as handle_continue_tool
|
||||
from .finish import get_tool_spec as get_finish_tool_spec
|
||||
from .finish import handle_tool as handle_finish_tool
|
||||
from .no_reply import get_tool_spec as get_no_reply_tool_spec
|
||||
from .no_reply import handle_tool as handle_no_reply_tool
|
||||
from .query_jargon import get_tool_spec as get_query_jargon_tool_spec
|
||||
@@ -22,6 +24,8 @@ from .reply import get_tool_spec as get_reply_tool_spec
|
||||
from .reply import handle_tool as handle_reply_tool
|
||||
from .send_emoji import get_tool_spec as get_send_emoji_tool_spec
|
||||
from .send_emoji import handle_tool as handle_send_emoji_tool
|
||||
from .tool_search import get_tool_spec as get_tool_search_tool_spec
|
||||
from .tool_search import handle_tool as handle_tool_search_tool
|
||||
from .view_complex_message import get_tool_spec as get_view_complex_message_tool_spec
|
||||
from .view_complex_message import handle_tool as handle_view_complex_message_tool
|
||||
from .wait import get_tool_spec as get_wait_tool_spec
|
||||
@@ -44,11 +48,13 @@ def get_action_tool_specs() -> List[ToolSpec]:
|
||||
"""获取 Action Loop 阶段可用的内置工具声明。"""
|
||||
|
||||
return [
|
||||
get_finish_tool_spec(),
|
||||
get_reply_tool_spec(),
|
||||
get_view_complex_message_tool_spec(),
|
||||
get_query_jargon_tool_spec(),
|
||||
get_query_memory_tool_spec(enabled=bool(global_config.memory.enable_memory_query_tool)),
|
||||
get_send_emoji_tool_spec(),
|
||||
get_tool_search_tool_spec(),
|
||||
]
|
||||
|
||||
|
||||
@@ -63,12 +69,14 @@ def get_all_builtin_tool_specs() -> List[ToolSpec]:
|
||||
|
||||
return [
|
||||
*get_timing_tool_specs(),
|
||||
get_finish_tool_spec(),
|
||||
get_reply_tool_spec(),
|
||||
get_view_complex_message_tool_spec(),
|
||||
get_query_jargon_tool_spec(),
|
||||
get_query_memory_tool_spec(enabled=True),
|
||||
get_query_person_info_tool_spec(),
|
||||
get_send_emoji_tool_spec(),
|
||||
get_tool_search_tool_spec(),
|
||||
]
|
||||
|
||||
|
||||
@@ -95,6 +103,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str
|
||||
|
||||
return {
|
||||
"continue": lambda invocation, context=None: handle_continue_tool(tool_ctx, invocation, context),
|
||||
"finish": lambda invocation, context=None: handle_finish_tool(tool_ctx, invocation, context),
|
||||
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
|
||||
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
|
||||
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
|
||||
@@ -106,6 +115,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str
|
||||
),
|
||||
"wait": lambda invocation, context=None: handle_wait_tool(tool_ctx, invocation, context),
|
||||
"send_emoji": lambda invocation, context=None: handle_send_emoji_tool(tool_ctx, invocation, context),
|
||||
"tool_search": lambda invocation, context=None: handle_tool_search_tool(tool_ctx, invocation, context),
|
||||
"view_complex_message": lambda invocation, context=None: handle_view_complex_message_tool(
|
||||
tool_ctx,
|
||||
invocation,
|
||||
|
||||
34
src/maisaka/builtin_tool/finish.py
Normal file
34
src/maisaka/builtin_tool/finish.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""finish 内置工具。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 finish 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="finish",
|
||||
brief_description="结束本轮思考,等待后续新的外部消息再继续。",
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 finish 内置工具。"""
|
||||
|
||||
del context
|
||||
tool_ctx.runtime._enter_stop_state()
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"当前对话循环已结束本轮思考,等待新的消息到来。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
@@ -29,6 +29,6 @@ async def handle_tool(
|
||||
tool_ctx.runtime._enter_stop_state()
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"当前对话循环已暂停,等待新消息到来。",
|
||||
"当前暂时停止思考,等待新消息到来。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
|
||||
@@ -91,10 +91,6 @@ async def handle_tool(
|
||||
f"未找到要回复的目标消息,msg_id={target_message_id}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{tool_ctx.runtime.log_prefix} 已触发回复工具,"
|
||||
f"目标消息编号={target_message_id} 引用回复={set_quote} 最新思考={latest_thought!r}"
|
||||
)
|
||||
try:
|
||||
replyer = replyer_manager.get_replyer(
|
||||
chat_stream=tool_ctx.runtime.chat_stream,
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
import math
|
||||
from random import sample
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from PIL import ImageDraw, ImageFont
|
||||
@@ -20,12 +20,14 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.maisaka.context_messages import (
|
||||
LLMContextMessage,
|
||||
ReferenceMessage,
|
||||
ReferenceMessageType,
|
||||
SessionBackedMessage,
|
||||
)
|
||||
from src.plugin_runtime.hook_payloads import serialize_prompt_messages
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
@@ -242,34 +244,9 @@ def _build_emoji_candidate_summary(emojis: list[MaiEmoji]) -> str:
|
||||
return "\n".join(summary_lines).strip()
|
||||
|
||||
|
||||
def _build_send_emoji_prompt_preview(
|
||||
*,
|
||||
system_prompt: str,
|
||||
requested_emotion: str,
|
||||
grid_rows: int,
|
||||
grid_columns: int,
|
||||
sampled_emojis: list[MaiEmoji],
|
||||
) -> str:
|
||||
"""构建表情选择子代理的文本预览。"""
|
||||
|
||||
task_text = (
|
||||
"[选择任务]\n"
|
||||
f"requested_emotion: {requested_emotion or '未指定'}\n"
|
||||
f"候选总数: {len(sampled_emojis)}\n"
|
||||
f"拼图布局: {grid_rows}x{grid_columns}\n"
|
||||
"请只输出 JSON。"
|
||||
)
|
||||
candidate_summary = _build_emoji_candidate_summary(sampled_emojis)
|
||||
return (
|
||||
f"[System Prompt]\n{system_prompt}\n\n"
|
||||
f"{task_text}\n\n"
|
||||
f"[候选表情摘要]\n{candidate_summary or '无候选表情'}"
|
||||
).strip()
|
||||
|
||||
|
||||
def _build_send_emoji_monitor_detail(
|
||||
*,
|
||||
prompt_text: str = "",
|
||||
request_messages: Optional[list[dict[str, Any]]] = None,
|
||||
reasoning_text: str = "",
|
||||
output_text: str = "",
|
||||
metrics: Optional[Dict[str, Any]] = None,
|
||||
@@ -278,8 +255,8 @@ def _build_send_emoji_monitor_detail(
|
||||
"""构建 emotion tool 统一监控详情。"""
|
||||
|
||||
detail: Dict[str, Any] = {}
|
||||
if prompt_text.strip():
|
||||
detail["prompt_text"] = prompt_text.strip()
|
||||
if isinstance(request_messages, list) and request_messages:
|
||||
detail["request_messages"] = request_messages
|
||||
if reasoning_text.strip():
|
||||
detail["reasoning_text"] = reasoning_text.strip()
|
||||
if output_text.strip():
|
||||
@@ -387,13 +364,16 @@ async def _select_emoji_with_sub_agent(
|
||||
remaining_uses_value=1,
|
||||
display_prefix="[表情包选择任务]",
|
||||
)
|
||||
prompt_preview = _build_send_emoji_prompt_preview(
|
||||
system_prompt=system_prompt,
|
||||
requested_emotion=requested_emotion,
|
||||
grid_rows=grid_rows,
|
||||
grid_columns=grid_columns,
|
||||
sampled_emojis=sampled_emojis,
|
||||
)
|
||||
request_messages = [
|
||||
MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build(),
|
||||
]
|
||||
prompt_llm_message = prompt_message.to_llm_message()
|
||||
if prompt_llm_message is not None:
|
||||
request_messages.append(prompt_llm_message)
|
||||
candidate_llm_message = candidate_message.to_llm_message()
|
||||
if candidate_llm_message is not None:
|
||||
request_messages.append(candidate_llm_message)
|
||||
serialized_request_messages = serialize_prompt_messages(request_messages)
|
||||
|
||||
selection_started_at = datetime.now()
|
||||
response = await tool_ctx.runtime.run_sub_agent(
|
||||
@@ -421,7 +401,7 @@ async def _select_emoji_with_sub_agent(
|
||||
logger.warning(f"{tool_ctx.runtime.log_prefix} 表情包子代理结果解析失败,将回退到候选首项: {exc}")
|
||||
if selection_metadata is not None:
|
||||
selection_metadata["monitor_detail"] = _build_send_emoji_monitor_detail(
|
||||
prompt_text=prompt_preview,
|
||||
request_messages=serialized_request_messages,
|
||||
output_text=response.content or "",
|
||||
metrics=selection_metrics,
|
||||
extra_sections=[{
|
||||
@@ -435,7 +415,7 @@ async def _select_emoji_with_sub_agent(
|
||||
if selection_metadata is not None:
|
||||
selection_metadata["reason"] = selection.reason.strip()
|
||||
selection_metadata["monitor_detail"] = _build_send_emoji_monitor_detail(
|
||||
prompt_text=prompt_preview,
|
||||
request_messages=serialized_request_messages,
|
||||
reasoning_text=selection.reason,
|
||||
output_text=response.content or "",
|
||||
metrics=selection_metrics,
|
||||
|
||||
106
src/maisaka/builtin_tool/tool_search.py
Normal file
106
src/maisaka/builtin_tool/tool_search.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""tool_search 内置工具。"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 tool_search 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="tool_search",
|
||||
brief_description="在 deferred tools 列表中按名称或关键词搜索工具,并将命中的工具加入后续轮次的可用工具列表。",
|
||||
detailed_description=(
|
||||
"参数说明:\n"
|
||||
"- query:String,必填。工具名、前缀或关键词。\n"
|
||||
"- limit:Integer,可选。最多返回多少个匹配工具,默认为 5。"
|
||||
),
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "要搜索的工具名、前缀或关键词。",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "最多返回多少个匹配工具。",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 tool_search 内置工具。"""
|
||||
|
||||
del context
|
||||
raw_query = invocation.arguments.get("query")
|
||||
if not isinstance(raw_query, str) or not raw_query.strip():
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"tool_search 需要提供非空的 `query` 字符串参数。",
|
||||
)
|
||||
|
||||
raw_limit = invocation.arguments.get("limit", 5)
|
||||
try:
|
||||
limit = max(1, int(raw_limit))
|
||||
except (TypeError, ValueError):
|
||||
limit = 5
|
||||
|
||||
matched_tool_specs = tool_ctx.runtime.search_deferred_tool_specs(raw_query, limit=limit)
|
||||
matched_tool_names = [tool_spec.name for tool_spec in matched_tool_specs]
|
||||
newly_discovered_tool_names = tool_ctx.runtime.discover_deferred_tools(matched_tool_names)
|
||||
|
||||
structured_content: Dict[str, Any] = {
|
||||
"query": raw_query.strip(),
|
||||
"matched_tool_names": matched_tool_names,
|
||||
"newly_discovered_tool_names": newly_discovered_tool_names,
|
||||
}
|
||||
|
||||
if not matched_tool_names:
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"未找到匹配的 deferred tools,请尝试更完整的工具名、前缀或其他关键词。",
|
||||
structured_content=structured_content,
|
||||
metadata={"record_display_prompt": "tool_search 未找到匹配工具。"},
|
||||
)
|
||||
|
||||
content_lines: List[str] = [
|
||||
f"已找到 {len(matched_tool_names)} 个 deferred tools,它们会在后续轮次中加入可用工具列表:",
|
||||
*[f"- {tool_name}" for tool_name in matched_tool_names],
|
||||
]
|
||||
if newly_discovered_tool_names:
|
||||
content_lines.extend(
|
||||
[
|
||||
"",
|
||||
"本次新发现的工具:",
|
||||
*[f"- {tool_name}" for tool_name in newly_discovered_tool_names],
|
||||
]
|
||||
)
|
||||
else:
|
||||
content_lines.extend(["", "这些工具此前已经发现过,无需重复展开。"])
|
||||
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"\n".join(content_lines),
|
||||
structured_content=structured_content,
|
||||
metadata={
|
||||
"matched_tool_names": matched_tool_names,
|
||||
"newly_discovered_tool_names": newly_discovered_tool_names,
|
||||
"record_display_prompt": json.dumps(structured_content, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
@@ -12,8 +12,8 @@ def get_tool_spec() -> ToolSpec:
|
||||
|
||||
return ToolSpec(
|
||||
name="wait",
|
||||
brief_description="暂停当前对话并等待用户新的输入。",
|
||||
detailed_description="参数说明:\n- seconds:integer,必填。等待的秒数。",
|
||||
brief_description="暂停当前对话并固定等待一段时间,期间不因新消息提前恢复。",
|
||||
detailed_description="参数说明:\n- seconds:integer,必填。等待的秒数。等待期间收到的新消息只会暂存,直到超时后再继续处理。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -46,6 +46,6 @@ async def handle_tool(
|
||||
tool_ctx.runtime._enter_wait_state(seconds=wait_seconds, tool_call_id=invocation.call_id)
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
f"当前对话循环进入等待状态,最长等待 {wait_seconds} 秒。",
|
||||
f"当前对话循环进入等待状态,将固定等待 {wait_seconds} 秒;期间收到的新消息不会提前打断本次等待。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
|
||||
@@ -5,20 +5,18 @@ from datetime import datetime
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
from rich.console import RenderableType
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry, ToolSpec
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
|
||||
from src.plugin_runtime.hook_payloads import (
|
||||
deserialize_prompt_messages,
|
||||
@@ -32,9 +30,11 @@ from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistr
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .builtin_tool import get_builtin_tools
|
||||
from .context_messages import AssistantMessage, LLMContextMessage
|
||||
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
|
||||
from .history_utils import drop_orphan_tool_results
|
||||
from .prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .display.prompt_cli_renderer import PromptCLIVisualizer
|
||||
|
||||
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -54,13 +54,6 @@ class ChatResponse:
|
||||
prompt_section: Optional[RenderableType] = None
|
||||
|
||||
|
||||
class ToolFilterSelection(BaseModel):
|
||||
"""工具筛选响应。"""
|
||||
|
||||
selected_tool_names: list[str] = PydanticField(default_factory=list)
|
||||
"""经过预筛后保留的候选工具名称列表。"""
|
||||
|
||||
|
||||
logger = get_logger("maisaka_chat_loop")
|
||||
|
||||
|
||||
@@ -217,10 +210,6 @@ class MaisakaChatLoopService:
|
||||
else:
|
||||
self._chat_system_prompt = chat_system_prompt
|
||||
self._llm_chat = LLMServiceClient(task_name="planner", request_type="maisaka_planner")
|
||||
self._tool_filter_llm = LLMServiceClient(
|
||||
task_name=global_config.maisaka.tool_filter_task_name,
|
||||
request_type="maisaka_tool_filter",
|
||||
)
|
||||
|
||||
@property
|
||||
def personality_prompt(self) -> str:
|
||||
@@ -303,8 +292,15 @@ class MaisakaChatLoopService:
|
||||
"file_tools_section": tools_section,
|
||||
"group_chat_attention_block": self._build_group_chat_attention_block(),
|
||||
"identity": self._personality_prompt,
|
||||
"time_block": self._build_time_block(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_time_block() -> str:
|
||||
"""构建当前时间提示块。"""
|
||||
|
||||
return f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
def _build_group_chat_attention_block(self) -> str:
|
||||
"""构建当前聊天场景下的额外注意事项块。"""
|
||||
|
||||
@@ -399,6 +395,7 @@ class MaisakaChatLoopService:
|
||||
self,
|
||||
selected_history: List[LLMContextMessage],
|
||||
*,
|
||||
injected_user_messages: Sequence[str] | None = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> List[Message]:
|
||||
"""构造发给大模型的消息列表。
|
||||
@@ -420,254 +417,49 @@ class MaisakaChatLoopService:
|
||||
if llm_message is not None:
|
||||
messages.append(llm_message)
|
||||
|
||||
normalized_injected_messages: List[Message] = []
|
||||
for injected_message in injected_user_messages or []:
|
||||
normalized_message = str(injected_message or "").strip()
|
||||
if not normalized_message:
|
||||
continue
|
||||
normalized_injected_messages.append(
|
||||
MessageBuilder()
|
||||
.set_role(RoleType.User)
|
||||
.add_text_content(normalized_message)
|
||||
.build()
|
||||
)
|
||||
|
||||
if normalized_injected_messages:
|
||||
insertion_index = self._resolve_injected_user_messages_insertion_index(messages)
|
||||
messages[insertion_index:insertion_index] = normalized_injected_messages
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _is_builtin_tool_spec(tool_spec: ToolSpec) -> bool:
|
||||
"""判断一个工具是否属于默认内置工具。
|
||||
def _resolve_injected_user_messages_insertion_index(messages: Sequence[Message]) -> int:
|
||||
"""计算 injected meta user messages 在请求中的插入位置。
|
||||
|
||||
Args:
|
||||
tool_spec: 待判断的工具声明。
|
||||
|
||||
Returns:
|
||||
bool: 是否为默认内置工具。
|
||||
规则与 deferred attachment 更接近:
|
||||
- 从尾部向前寻找最近的 stopping point;
|
||||
- stopping point 为 assistant 消息或 tool 结果消息;
|
||||
- 找到后插入到其后面;
|
||||
- 若不存在 stopping point,则退回到 system 消息之后。
|
||||
"""
|
||||
|
||||
return tool_spec.provider_type == "builtin" or tool_spec.provider_name == "maisaka_builtin"
|
||||
for index in range(len(messages) - 1, -1, -1):
|
||||
message = messages[index]
|
||||
if message.role in {RoleType.Assistant, RoleType.Tool}:
|
||||
return index + 1
|
||||
|
||||
@classmethod
|
||||
def _split_builtin_and_candidate_tools(
|
||||
cls,
|
||||
tool_specs: List[ToolSpec],
|
||||
) -> tuple[List[ToolSpec], List[ToolSpec]]:
|
||||
"""拆分内置工具与可筛选工具列表。
|
||||
|
||||
Args:
|
||||
tool_specs: 当前全部工具声明。
|
||||
|
||||
Returns:
|
||||
tuple[List[ToolSpec], List[ToolSpec]]: `(内置工具, 可筛选工具)`。
|
||||
"""
|
||||
|
||||
builtin_tool_specs: List[ToolSpec] = []
|
||||
candidate_tool_specs: List[ToolSpec] = []
|
||||
for tool_spec in tool_specs:
|
||||
if cls._is_builtin_tool_spec(tool_spec):
|
||||
builtin_tool_specs.append(tool_spec)
|
||||
else:
|
||||
candidate_tool_specs.append(tool_spec)
|
||||
return builtin_tool_specs, candidate_tool_specs
|
||||
|
||||
@staticmethod
|
||||
def _truncate_tool_filter_text(text: str, max_length: int = 180) -> str:
|
||||
"""截断工具筛选阶段展示的文本。
|
||||
|
||||
Args:
|
||||
text: 原始文本。
|
||||
max_length: 最长保留字符数。
|
||||
|
||||
Returns:
|
||||
str: 截断后的文本。
|
||||
"""
|
||||
|
||||
normalized_text = text.strip()
|
||||
if len(normalized_text) <= max_length:
|
||||
return normalized_text
|
||||
return f"{normalized_text[: max_length - 1]}…"
|
||||
|
||||
def _build_tool_filter_prompt(
|
||||
self,
|
||||
selected_history: List[LLMContextMessage],
|
||||
candidate_tool_specs: List[ToolSpec],
|
||||
max_keep: int,
|
||||
) -> str:
|
||||
"""构造小模型工具预筛选提示词。
|
||||
|
||||
Args:
|
||||
selected_history: 已选中的对话上下文。
|
||||
candidate_tool_specs: 非内置候选工具列表。
|
||||
max_keep: 最多保留的候选工具数量。
|
||||
|
||||
Returns:
|
||||
str: 用于工具预筛的小模型提示词。
|
||||
"""
|
||||
|
||||
history_lines: List[str] = []
|
||||
for message in selected_history[-10:]:
|
||||
plain_text = message.processed_plain_text.strip()
|
||||
if not plain_text:
|
||||
continue
|
||||
history_lines.append(
|
||||
f"- {message.role}: {self._truncate_tool_filter_text(plain_text, max_length=200)}"
|
||||
)
|
||||
|
||||
if history_lines:
|
||||
history_section = "\n".join(history_lines)
|
||||
else:
|
||||
history_section = "- 当前没有可用的对话上下文。"
|
||||
|
||||
tool_lines = [
|
||||
f"- {tool_spec.name}: {tool_spec.brief_description.strip() or '无简要描述'}"
|
||||
for tool_spec in candidate_tool_specs
|
||||
]
|
||||
tool_section = "\n".join(tool_lines) if tool_lines else "- 当前没有候选工具。"
|
||||
|
||||
return (
|
||||
"你是 Maisaka 的工具预筛选器。\n"
|
||||
"你的任务是在正式进入 planner 前,根据当前情景从候选工具中挑出最可能马上会用到的工具。\n"
|
||||
"默认内置工具已经自动保留,不在候选列表中,你不需要再次选择它们。\n"
|
||||
"你只能参考工具的简要描述,不要假设未描述的隐藏能力。\n"
|
||||
f"最多保留 {max_keep} 个候选工具;如果都不合适,可以返回空数组。\n"
|
||||
"请严格返回 JSON 对象,格式为:"
|
||||
'{"selected_tool_names":["工具名1","工具名2"]}\n\n'
|
||||
f"【最近对话】\n{history_section}\n\n"
|
||||
f"【候选工具(仅简要描述)】\n{tool_section}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_filter_response(
|
||||
response_text: str,
|
||||
candidate_tool_specs: List[ToolSpec],
|
||||
max_keep: int,
|
||||
) -> List[ToolSpec] | None:
|
||||
"""解析工具预筛选响应。
|
||||
|
||||
Args:
|
||||
response_text: 小模型返回的原始文本。
|
||||
candidate_tool_specs: 非内置候选工具列表。
|
||||
max_keep: 最多保留的候选工具数量。
|
||||
|
||||
Returns:
|
||||
List[ToolSpec] | None: 成功解析时返回筛选后的工具列表;解析失败时返回 ``None``。
|
||||
"""
|
||||
|
||||
normalized_response = response_text.strip()
|
||||
if not normalized_response:
|
||||
return None
|
||||
|
||||
selected_tool_names: List[str]
|
||||
try:
|
||||
selected_tool_names = ToolFilterSelection.model_validate_json(normalized_response).selected_tool_names
|
||||
except Exception:
|
||||
try:
|
||||
parsed_payload = json.loads(normalized_response)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
if isinstance(parsed_payload, dict):
|
||||
raw_tool_names = parsed_payload.get("selected_tool_names", [])
|
||||
elif isinstance(parsed_payload, list):
|
||||
raw_tool_names = parsed_payload
|
||||
else:
|
||||
return None
|
||||
|
||||
if not isinstance(raw_tool_names, list):
|
||||
return None
|
||||
|
||||
selected_tool_names = []
|
||||
for item in raw_tool_names:
|
||||
normalized_name = str(item).strip()
|
||||
if normalized_name:
|
||||
selected_tool_names.append(normalized_name)
|
||||
|
||||
candidate_map = {tool_spec.name: tool_spec for tool_spec in candidate_tool_specs}
|
||||
filtered_tool_specs: List[ToolSpec] = []
|
||||
seen_names: set[str] = set()
|
||||
for tool_name in selected_tool_names:
|
||||
normalized_name = tool_name.strip()
|
||||
if not normalized_name or normalized_name in seen_names:
|
||||
continue
|
||||
tool_spec = candidate_map.get(normalized_name)
|
||||
if tool_spec is None:
|
||||
continue
|
||||
|
||||
seen_names.add(normalized_name)
|
||||
filtered_tool_specs.append(tool_spec)
|
||||
if len(filtered_tool_specs) >= max_keep:
|
||||
break
|
||||
|
||||
return filtered_tool_specs
|
||||
|
||||
async def _filter_tool_specs_for_planner(
|
||||
self,
|
||||
selected_history: List[LLMContextMessage],
|
||||
tool_specs: List[ToolSpec],
|
||||
) -> List[ToolSpec]:
|
||||
"""在将工具交给 planner 前进行快速预筛选。
|
||||
|
||||
Args:
|
||||
selected_history: 已选中的对话上下文。
|
||||
tool_specs: 当前全部可用工具声明。
|
||||
|
||||
Returns:
|
||||
List[ToolSpec]: 最终交给 planner 的工具声明列表。
|
||||
"""
|
||||
|
||||
threshold = max(1, int(global_config.maisaka.tool_filter_threshold))
|
||||
max_keep = max(1, int(global_config.maisaka.tool_filter_max_keep))
|
||||
if len(tool_specs) <= threshold:
|
||||
return tool_specs
|
||||
|
||||
builtin_tool_specs, candidate_tool_specs = self._split_builtin_and_candidate_tools(tool_specs)
|
||||
if not candidate_tool_specs:
|
||||
return tool_specs
|
||||
if len(candidate_tool_specs) <= max_keep:
|
||||
return [*builtin_tool_specs, *candidate_tool_specs]
|
||||
|
||||
filter_prompt = self._build_tool_filter_prompt(selected_history, candidate_tool_specs, max_keep)
|
||||
logger.info(
|
||||
"工具预筛选开始: "
|
||||
f"总工具数={len(tool_specs)} "
|
||||
f"内置工具数={len(builtin_tool_specs)} "
|
||||
f"候选工具数={len(candidate_tool_specs)} "
|
||||
f"最多保留候选数={max_keep}"
|
||||
)
|
||||
|
||||
try:
|
||||
generation_result = await self._tool_filter_llm.generate_response(
|
||||
prompt=filter_prompt,
|
||||
options=LLMGenerationOptions(
|
||||
temperature=0.0,
|
||||
max_tokens=256,
|
||||
response_format=RespFormat(
|
||||
format_type=RespFormatType.JSON_SCHEMA,
|
||||
schema=ToolFilterSelection,
|
||||
),
|
||||
),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"工具预筛选失败,保留全部工具。错误={exc}")
|
||||
return tool_specs
|
||||
|
||||
filtered_candidate_tool_specs = self._parse_tool_filter_response(
|
||||
generation_result.response or "",
|
||||
candidate_tool_specs,
|
||||
max_keep,
|
||||
)
|
||||
if filtered_candidate_tool_specs is None:
|
||||
logger.warning(
|
||||
"工具预筛选返回结果无法解析,保留全部工具。"
|
||||
f" 原始返回={generation_result.response or ''!r}"
|
||||
)
|
||||
return tool_specs
|
||||
|
||||
filtered_tool_specs = [*builtin_tool_specs, *filtered_candidate_tool_specs]
|
||||
if not filtered_tool_specs:
|
||||
logger.warning("工具预筛选得到空结果,保留全部工具以避免主流程失去工具能力。")
|
||||
return tool_specs
|
||||
|
||||
logger.info(
|
||||
"工具预筛选完成: "
|
||||
f"筛选前总数={len(tool_specs)} "
|
||||
f"筛选后总数={len(filtered_tool_specs)} "
|
||||
f"保留候选工具={[tool_spec.name for tool_spec in filtered_candidate_tool_specs]}"
|
||||
)
|
||||
return filtered_tool_specs
|
||||
if messages and messages[0].role == RoleType.System:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def chat_loop_step(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
*,
|
||||
injected_user_messages: Sequence[str] | None = None,
|
||||
request_kind: str = "planner",
|
||||
response_format: RespFormat | None = None,
|
||||
tool_definitions: Sequence[ToolDefinitionInput] | None = None,
|
||||
@@ -683,8 +475,14 @@ class MaisakaChatLoopService:
|
||||
|
||||
if not self._prompts_loaded:
|
||||
await self.ensure_chat_prompt_loaded()
|
||||
selected_history, selection_reason = self.select_llm_context_messages(chat_history)
|
||||
built_messages = self._build_request_messages(selected_history)
|
||||
selected_history, selection_reason = self.select_llm_context_messages(
|
||||
chat_history,
|
||||
request_kind=request_kind,
|
||||
)
|
||||
built_messages = self._build_request_messages(
|
||||
selected_history,
|
||||
injected_user_messages=injected_user_messages,
|
||||
)
|
||||
|
||||
def message_factory(_client: BaseClient) -> List[Message]:
|
||||
"""返回当前轮次已经构建好的请求消息。
|
||||
@@ -704,8 +502,7 @@ class MaisakaChatLoopService:
|
||||
all_tools = list(tool_definitions)
|
||||
elif self._tool_registry is not None:
|
||||
tool_specs = await self._tool_registry.list_tools()
|
||||
filtered_tool_specs = await self._filter_tool_specs_for_planner(selected_history, tool_specs)
|
||||
all_tools = [tool_spec.to_llm_definition() for tool_spec in filtered_tool_specs]
|
||||
all_tools = [tool_spec.to_llm_definition() for tool_spec in tool_specs]
|
||||
else:
|
||||
all_tools = [*get_builtin_tools(), *self._extra_tools]
|
||||
|
||||
@@ -740,15 +537,9 @@ class MaisakaChatLoopService:
|
||||
selection_reason=selection_reason,
|
||||
image_display_mode=image_display_mode,
|
||||
folded=global_config.debug.fold_maisaka_thinking,
|
||||
tool_definitions=list(all_tools),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"规划器请求开始: "
|
||||
f"已选上下文消息数={len(selected_history)} "
|
||||
f"大模型消息数={len(built_messages)} "
|
||||
f"工具数={len(all_tools)} "
|
||||
f"启用打断={self._interrupt_flag is not None}"
|
||||
)
|
||||
generation_result = await self._llm_chat.generate_response_with_messages(
|
||||
message_factory=message_factory,
|
||||
options=LLMGenerationOptions(
|
||||
@@ -760,15 +551,6 @@ class MaisakaChatLoopService:
|
||||
),
|
||||
)
|
||||
|
||||
prompt_stats_text = PromptCLIVisualizer.build_prompt_stats_text(
|
||||
selected_history_count=len(selected_history),
|
||||
built_message_count=len(built_messages),
|
||||
prompt_tokens=generation_result.prompt_tokens,
|
||||
completion_tokens=generation_result.completion_tokens,
|
||||
total_tokens=generation_result.total_tokens,
|
||||
)
|
||||
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
|
||||
|
||||
final_response = generation_result.response or ""
|
||||
final_tool_calls = list(generation_result.tool_calls or [])
|
||||
after_response_result = await self._get_runtime_manager().invoke_hook(
|
||||
@@ -822,16 +604,21 @@ class MaisakaChatLoopService:
|
||||
def select_llm_context_messages(
|
||||
chat_history: List[LLMContextMessage],
|
||||
*,
|
||||
request_kind: str = "planner",
|
||||
max_context_size: Optional[int] = None,
|
||||
) -> tuple[List[LLMContextMessage], str]:
|
||||
"""??????? LLM ???????"""
|
||||
"""选择LLM上下文消息"""
|
||||
|
||||
filtered_history = MaisakaChatLoopService._filter_history_for_request_kind(
|
||||
chat_history,
|
||||
request_kind=request_kind,
|
||||
)
|
||||
effective_context_size = max(1, int(max_context_size or global_config.chat.max_context_size))
|
||||
selected_indices: List[int] = []
|
||||
counted_message_count = 0
|
||||
|
||||
for index in range(len(chat_history) - 1, -1, -1):
|
||||
message = chat_history[index]
|
||||
for index in range(len(filtered_history) - 1, -1, -1):
|
||||
message = filtered_history[index]
|
||||
if message.to_llm_message() is None:
|
||||
continue
|
||||
|
||||
@@ -842,10 +629,10 @@ class MaisakaChatLoopService:
|
||||
break
|
||||
|
||||
if not selected_indices:
|
||||
return [], f"???????? {effective_context_size} ? user/assistant??? 0 ??"
|
||||
return [], f"没有选择到上下文消息,实际发送 {effective_context_size} 条 user/assistant 消息"
|
||||
|
||||
selected_indices.reverse()
|
||||
selected_history = [chat_history[index] for index in selected_indices]
|
||||
selected_history = [filtered_history[index] for index in selected_indices]
|
||||
selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
|
||||
selected_history, _ = drop_orphan_tool_results(selected_history)
|
||||
selection_reason = (
|
||||
@@ -860,45 +647,43 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _select_llm_context_messages(chat_history: List[LLMContextMessage]) -> tuple[List[LLMContextMessage], str]:
|
||||
"""选择真正发送给 LLM 的上下文消息。
|
||||
def _filter_history_for_request_kind(
|
||||
selected_history: List[LLMContextMessage],
|
||||
*,
|
||||
request_kind: str,
|
||||
) -> List[LLMContextMessage]:
|
||||
"""按请求类型过滤不应暴露的历史工具链。"""
|
||||
|
||||
Args:
|
||||
chat_history: 当前全部对话历史。
|
||||
if request_kind != "planner":
|
||||
return selected_history
|
||||
|
||||
Returns:
|
||||
tuple[List[LLMContextMessage], str]: `(已选上下文, 选择说明)`。
|
||||
"""
|
||||
|
||||
max_context_size = max(1, int(global_config.chat.max_context_size))
|
||||
selected_indices: List[int] = []
|
||||
counted_message_count = 0
|
||||
|
||||
for index in range(len(chat_history) - 1, -1, -1):
|
||||
message = chat_history[index]
|
||||
if message.to_llm_message() is None:
|
||||
filtered_history: List[LLMContextMessage] = []
|
||||
for message in selected_history:
|
||||
if isinstance(message, ToolResultMessage) and message.tool_name in TIMING_GATE_TOOL_NAMES:
|
||||
continue
|
||||
|
||||
selected_indices.append(index)
|
||||
if message.count_in_context:
|
||||
counted_message_count += 1
|
||||
if counted_message_count >= max_context_size:
|
||||
break
|
||||
if isinstance(message, AssistantMessage) and message.tool_calls:
|
||||
kept_tool_calls = [
|
||||
tool_call
|
||||
for tool_call in message.tool_calls
|
||||
if tool_call.func_name not in TIMING_GATE_TOOL_NAMES
|
||||
]
|
||||
if not kept_tool_calls:
|
||||
continue
|
||||
if len(kept_tool_calls) != len(message.tool_calls):
|
||||
filtered_history.append(
|
||||
AssistantMessage(
|
||||
content=message.content,
|
||||
timestamp=message.timestamp,
|
||||
tool_calls=kept_tool_calls,
|
||||
source_kind=message.source_kind,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if not selected_indices:
|
||||
return [], f"上下文判定:最近 {max_context_size} 条 user/assistant(当前 0 条)"
|
||||
filtered_history.append(message)
|
||||
|
||||
selected_indices.reverse()
|
||||
selected_history = [chat_history[index] for index in selected_indices]
|
||||
selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history)
|
||||
selected_history, _ = drop_orphan_tool_results(selected_history)
|
||||
return (
|
||||
selected_history,
|
||||
(
|
||||
f"上下文判定:最近 {max_context_size} 条 user/assistant;"
|
||||
f"展示并发送窗口内消息 {len(selected_history)} 条"
|
||||
),
|
||||
)
|
||||
return filtered_history
|
||||
|
||||
@staticmethod
|
||||
def _hide_early_assistant_messages(
|
||||
|
||||
@@ -51,7 +51,9 @@ def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent)
|
||||
if component.content:
|
||||
builder.add_text_content(component.content)
|
||||
return True
|
||||
return False
|
||||
|
||||
builder.add_text_content("[表情包]")
|
||||
return True
|
||||
|
||||
|
||||
def _append_image_component(builder: MessageBuilder, component: ImageComponent) -> bool:
|
||||
@@ -65,7 +67,9 @@ def _append_image_component(builder: MessageBuilder, component: ImageComponent)
|
||||
if component.content:
|
||||
builder.add_text_content(component.content)
|
||||
return True
|
||||
return False
|
||||
|
||||
builder.add_text_content("[图片]")
|
||||
return True
|
||||
|
||||
|
||||
def _append_reply_component(builder: MessageBuilder, component: ReplyComponent) -> bool:
|
||||
|
||||
33
src/maisaka/display/__init__.py
Normal file
33
src/maisaka/display/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Maisaka 展示模块。"""
|
||||
|
||||
from .display_utils import (
|
||||
build_tool_call_summary_lines,
|
||||
format_token_count,
|
||||
format_tool_call_for_display,
|
||||
get_request_panel_style,
|
||||
get_role_badge_label,
|
||||
get_role_badge_style,
|
||||
)
|
||||
from .prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .prompt_preview_logger import PromptPreviewLogger
|
||||
from .stage_status_board import (
|
||||
disable_stage_status_board,
|
||||
enable_stage_status_board,
|
||||
remove_stage_status,
|
||||
update_stage_status,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PromptCLIVisualizer",
|
||||
"PromptPreviewLogger",
|
||||
"build_tool_call_summary_lines",
|
||||
"disable_stage_status_board",
|
||||
"enable_stage_status_board",
|
||||
"format_token_count",
|
||||
"format_tool_call_for_display",
|
||||
"get_request_panel_style",
|
||||
"get_role_badge_label",
|
||||
"get_role_badge_style",
|
||||
"remove_stage_status",
|
||||
"update_stage_status",
|
||||
]
|
||||
@@ -4,14 +4,15 @@ from typing import Any
|
||||
|
||||
|
||||
_REQUEST_PANEL_STYLE_MAP: dict[str, tuple[str, str]] = {
|
||||
"timing_gate": ("\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u5927\u6a21\u578b\u8bf7\u6c42 - Timing Gate \u5b50\u4ee3\u7406", "bright_magenta"),
|
||||
"replyer": ("\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u56de\u590d\u5668 Prompt", "bright_yellow"),
|
||||
"planner": ("MaiSaka 大模型请求 - 对话单步", "green"),
|
||||
"timing_gate": ("MaiSaka 大模型请求 - Timing Gate 子代理", "bright_magenta"),
|
||||
"replyer": ("MaiSaka 回复器 Prompt", "bright_yellow"),
|
||||
"emotion": ("MaiSaka Emotion Tool Prompt", "bright_cyan"),
|
||||
"sub_agent": ("\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u5927\u6a21\u578b\u8bf7\u6c42 - \u5b50\u4ee3\u7406", "bright_blue"),
|
||||
"sub_agent": ("MaiSaka 大模型请求 - 子代理", "bright_blue"),
|
||||
}
|
||||
|
||||
_DEFAULT_REQUEST_PANEL_STYLE: tuple[str, str] = (
|
||||
"\u004d\u0061\u0069\u0053\u0061\u006b\u0061 \u5927\u6a21\u578b\u8bf7\u6c42 - \u5bf9\u8bdd\u5355\u6b65",
|
||||
"MaiSaka 大模型请求 - 对话单步",
|
||||
"cyan",
|
||||
)
|
||||
|
||||
@@ -23,10 +24,10 @@ _ROLE_BADGE_STYLE_MAP: dict[str, str] = {
|
||||
}
|
||||
|
||||
_ROLE_BADGE_LABEL_MAP: dict[str, str] = {
|
||||
"system": "\u7cfb\u7edf",
|
||||
"user": "\u7528\u6237",
|
||||
"assistant": "\u52a9\u624b",
|
||||
"tool": "\u5de5\u5177",
|
||||
"system": "系统",
|
||||
"user": "用户",
|
||||
"assistant": "助手",
|
||||
"tool": "工具",
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +55,7 @@ def get_role_badge_style(role: str) -> str:
|
||||
def get_role_badge_label(role: str) -> str:
|
||||
"""返回角色标签对应的展示文案。"""
|
||||
|
||||
return _ROLE_BADGE_LABEL_MAP.get(role, "\u672a\u77e5")
|
||||
return _ROLE_BADGE_LABEL_MAP.get(role, "未知")
|
||||
|
||||
|
||||
def format_tool_call_for_display(tool_call: Any) -> dict[str, Any]:
|
||||
58
src/maisaka/display/preview_path_utils.py
Normal file
58
src/maisaka/display/preview_path_utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Maisaka Prompt 预览路径工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
import re
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||
SAFE_NAME_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
|
||||
|
||||
|
||||
def normalize_preview_name(value: str) -> str:
|
||||
normalized_value = SAFE_NAME_PATTERN.sub("_", str(value or "").strip()).strip("._")
|
||||
if normalized_value:
|
||||
return normalized_value
|
||||
return "unknown"
|
||||
|
||||
|
||||
def normalize_platform_name(platform: str) -> str:
|
||||
normalized_platform = str(platform or "").strip().lower()
|
||||
platform_aliases = {
|
||||
"telegram": "tg",
|
||||
}
|
||||
return normalize_preview_name(platform_aliases.get(normalized_platform, normalized_platform))
|
||||
|
||||
|
||||
def build_preview_chat_dir_name(chat_id: str) -> str:
|
||||
session = chat_manager.get_session_by_session_id(chat_id)
|
||||
if session is not None:
|
||||
platform = normalize_platform_name(session.platform)
|
||||
if session.is_group_session and session.group_id:
|
||||
return f"{platform}_group_{normalize_preview_name(session.group_id)}"
|
||||
if session.user_id:
|
||||
return f"{platform}_private_{normalize_preview_name(session.user_id)}"
|
||||
|
||||
normalized_chat_id = normalize_preview_name(chat_id)
|
||||
if normalized_chat_id != "unknown":
|
||||
return normalized_chat_id
|
||||
return "unknown_chat"
|
||||
|
||||
|
||||
def build_display_path(file_path: Path) -> str:
|
||||
"""构造用于展示的路径,项目内文件优先显示相对路径。"""
|
||||
resolved_path = file_path.resolve()
|
||||
try:
|
||||
return resolved_path.relative_to(REPO_ROOT).as_posix()
|
||||
except ValueError:
|
||||
return resolved_path.as_posix()
|
||||
|
||||
|
||||
def build_file_uri(file_path: Path) -> str:
|
||||
normalized = file_path.resolve().as_posix()
|
||||
return f"file:///{quote(normalized, safe='/:')}"
|
||||
@@ -7,7 +7,6 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal
|
||||
from urllib.parse import quote
|
||||
|
||||
import hashlib
|
||||
import html
|
||||
@@ -27,10 +26,10 @@ from .display_utils import (
|
||||
get_role_badge_label as get_shared_role_badge_label,
|
||||
get_role_badge_style as get_shared_role_badge_style,
|
||||
)
|
||||
from .preview_path_utils import build_display_path, build_file_uri, REPO_ROOT
|
||||
from .prompt_preview_logger import PromptPreviewLogger
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
DATA_IMAGE_DIR = PROJECT_ROOT / "data" / "images"
|
||||
DATA_IMAGE_DIR = REPO_ROOT / "data" / "images"
|
||||
|
||||
|
||||
class PromptImageDisplayMode(str, Enum):
|
||||
@@ -115,11 +114,6 @@ class PromptCLIVisualizer:
|
||||
digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest()
|
||||
return root / f"{digest}.{image_format}"
|
||||
|
||||
@staticmethod
|
||||
def _build_file_uri(file_path: Path) -> str:
|
||||
normalized = file_path.resolve().as_posix()
|
||||
return f"file:///{quote(normalized, safe='/:')}"
|
||||
|
||||
@staticmethod
|
||||
def _build_official_image_path(image_format: str, image_base64: str) -> Path | None:
|
||||
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format)
|
||||
@@ -140,7 +134,7 @@ class PromptCLIVisualizer:
|
||||
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format) or "bin"
|
||||
official_path = PromptCLIVisualizer._build_official_image_path(image_format, image_base64)
|
||||
if official_path is not None:
|
||||
return PromptCLIVisualizer._build_file_uri(official_path), official_path
|
||||
return build_file_uri(official_path), official_path
|
||||
|
||||
try:
|
||||
image_bytes = b64decode(image_base64)
|
||||
@@ -153,7 +147,7 @@ class PromptCLIVisualizer:
|
||||
path.write_bytes(image_bytes)
|
||||
except Exception:
|
||||
return None
|
||||
return PromptCLIVisualizer._build_file_uri(path), path
|
||||
return build_file_uri(path), path
|
||||
|
||||
@classmethod
|
||||
def _render_image_item(cls, image_format: str, image_base64: str, settings: PromptImageDisplaySettings) -> Panel:
|
||||
@@ -169,8 +163,9 @@ class PromptCLIVisualizer:
|
||||
path_result = cls._build_image_file_link(image_format, image_base64)
|
||||
if path_result is not None:
|
||||
file_uri, file_path = path_result
|
||||
display_path = build_display_path(file_path)
|
||||
preview_parts: List[RenderableType] = [
|
||||
Text(f"图片格式 image/{normalized_format} {size_text} 路径:{file_path}", style="magenta")
|
||||
Text(f"图片格式 image/{normalized_format} {size_text} 路径:{display_path}", style="magenta")
|
||||
]
|
||||
|
||||
preview_parts.append(Text.from_markup(f"[link={file_uri}]点击打开图片[/link]", style="cyan"))
|
||||
@@ -181,6 +176,16 @@ class PromptCLIVisualizer:
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_pair(item: Any) -> tuple[str, str] | None:
|
||||
"""兼容图片片段被序列化为 tuple 或 list 的两种形式。"""
|
||||
|
||||
if isinstance(item, (tuple, list)) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
return image_format, image_base64
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _render_message_content(cls, content: Any, settings: PromptImageDisplaySettings) -> RenderableType:
|
||||
if isinstance(content, str):
|
||||
@@ -192,11 +197,11 @@ class PromptCLIVisualizer:
|
||||
if isinstance(item, str):
|
||||
parts.append(Text(item))
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
parts.append(cls._render_image_item(image_format, image_base64, settings))
|
||||
continue
|
||||
image_pair = cls._extract_image_pair(item)
|
||||
if image_pair is not None:
|
||||
image_format, image_base64 = image_pair
|
||||
parts.append(cls._render_image_item(image_format, image_base64, settings))
|
||||
continue
|
||||
if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
|
||||
parts.append(Text(item["text"]))
|
||||
else:
|
||||
@@ -218,8 +223,9 @@ class PromptCLIVisualizer:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
image_pair = cls._extract_image_pair(item)
|
||||
if image_pair is not None:
|
||||
image_format, image_base64 = image_pair
|
||||
approx_size = max(0, len(str(image_base64)) * 3 // 4)
|
||||
parts.append(f"[图片 image/{image_format} {approx_size} B]")
|
||||
continue
|
||||
@@ -242,6 +248,85 @@ class PromptCLIVisualizer:
|
||||
def format_tool_call_for_display(cls, tool_call: Any) -> Dict[str, Any]:
|
||||
return normalize_tool_call_for_display(tool_call)
|
||||
|
||||
@classmethod
|
||||
def _build_tool_card_title(cls, tool_call: Any) -> str:
|
||||
"""构建 HTML 中工具卡片的折叠标题。"""
|
||||
|
||||
normalized_tool_call = cls.format_tool_call_for_display(tool_call)
|
||||
tool_name = str(normalized_tool_call.get("name") or "").strip()
|
||||
return tool_name or "unknown"
|
||||
|
||||
@classmethod
|
||||
def _build_tool_call_html(cls, tool_call: Any) -> str:
|
||||
"""将单个工具调用渲染为默认折叠的 HTML 卡片。"""
|
||||
|
||||
normalized_tool_call = cls.format_tool_call_for_display(tool_call)
|
||||
tool_name = cls._build_tool_card_title(tool_call)
|
||||
tool_call_id = str(normalized_tool_call.get("id") or "").strip()
|
||||
tool_arguments = normalized_tool_call.get("arguments")
|
||||
|
||||
tool_meta_html = ""
|
||||
if tool_call_id:
|
||||
tool_meta_html = (
|
||||
"<div class='tool-card-meta'>"
|
||||
"<span class='tool-card-meta-label'>调用 ID</span>"
|
||||
f"<code>{html.escape(tool_call_id)}</code>"
|
||||
"</div>"
|
||||
)
|
||||
|
||||
return (
|
||||
"<details class='tool-card tool-call-card'>"
|
||||
"<summary class='tool-card-summary'>"
|
||||
f"<span class='tool-card-name'>{html.escape(tool_name)}</span>"
|
||||
"</summary>"
|
||||
"<div class='tool-card-body'>"
|
||||
f"{tool_meta_html}"
|
||||
f"<pre>{html.escape(json.dumps(tool_arguments, ensure_ascii=False, indent=2, default=str))}</pre>"
|
||||
"</div>"
|
||||
"</details>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_tool_definition_fields(cls, tool_definition: dict[str, Any]) -> tuple[str, str, Any]:
|
||||
"""提取工具定义中的名称、描述和详情内容。"""
|
||||
|
||||
function_info = tool_definition.get("function")
|
||||
if isinstance(function_info, dict):
|
||||
tool_name = str(function_info.get("name") or "").strip() or "unknown"
|
||||
description = str(function_info.get("description") or "").strip()
|
||||
detail_payload = function_info
|
||||
else:
|
||||
tool_name = str(tool_definition.get("name") or "").strip() or "unknown"
|
||||
description = str(tool_definition.get("description") or "").strip()
|
||||
detail_payload = tool_definition
|
||||
return tool_name, description, detail_payload
|
||||
|
||||
@classmethod
|
||||
def _build_tool_definition_html(cls, tool_definition: dict[str, Any]) -> str:
|
||||
"""将单个传入工具定义渲染为默认折叠的 HTML 卡片。"""
|
||||
|
||||
tool_name, description, detail_payload = cls._extract_tool_definition_fields(tool_definition)
|
||||
description_html = ""
|
||||
if description:
|
||||
description_html = (
|
||||
"<div class='tool-card-meta'>"
|
||||
"<span class='tool-card-meta-label'>说明</span>"
|
||||
f"<span>{html.escape(description)}</span>"
|
||||
"</div>"
|
||||
)
|
||||
|
||||
return (
|
||||
"<details class='tool-card tool-definition-card'>"
|
||||
"<summary class='tool-card-summary'>"
|
||||
f"<span class='tool-card-name'>{html.escape(tool_name)}</span>"
|
||||
"</summary>"
|
||||
"<div class='tool-card-body'>"
|
||||
f"{description_html}"
|
||||
f"<pre>{html.escape(json.dumps(detail_payload, ensure_ascii=False, indent=2, default=str))}</pre>"
|
||||
"</div>"
|
||||
"</details>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _render_tool_call_panel(cls, tool_call: Any, index: int, parent_index: int) -> Panel:
|
||||
title = Text.assemble(
|
||||
@@ -291,6 +376,20 @@ class PromptCLIVisualizer:
|
||||
|
||||
return "\n\n" + ("\n\n" + ("=" * 80) + "\n\n").join(sections) if sections else "[空 Prompt]"
|
||||
|
||||
@classmethod
|
||||
def _build_tool_definition_dump_text(cls, tool_definitions: list[dict[str, Any]] | None) -> str:
|
||||
"""构建传入工具定义的文本备份内容。"""
|
||||
|
||||
if not tool_definitions:
|
||||
return ""
|
||||
|
||||
sections: List[str] = ["[tool_definitions]"]
|
||||
for index, tool_definition in enumerate(tool_definitions, start=1):
|
||||
tool_name, _, detail_payload = cls._extract_tool_definition_fields(tool_definition)
|
||||
sections.append(f"[{index}] name={tool_name}")
|
||||
sections.append(json.dumps(detail_payload, ensure_ascii=False, indent=2, default=str))
|
||||
return "\n\n".join(sections).strip()
|
||||
|
||||
@classmethod
|
||||
def _render_message_content_html(cls, content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
@@ -302,8 +401,9 @@ class PromptCLIVisualizer:
|
||||
if isinstance(item, str):
|
||||
parts.append(f"<pre>{html.escape(item)}</pre>")
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
image_pair = cls._extract_image_pair(item)
|
||||
if image_pair is not None:
|
||||
image_format, image_base64 = image_pair
|
||||
image_html = cls._render_image_item_html(str(image_format), str(image_base64))
|
||||
parts.append(image_html)
|
||||
continue
|
||||
@@ -332,14 +432,44 @@ class PromptCLIVisualizer:
|
||||
)
|
||||
|
||||
file_uri, file_path = path_result
|
||||
display_path = build_display_path(file_path)
|
||||
return (
|
||||
"<div class='image-card'>"
|
||||
f"<div class='image-meta'>图片 image/{html.escape(normalized_format)} {html.escape(size_text)}</div>"
|
||||
f"<div class='image-path'>{html.escape(str(file_path))}</div>"
|
||||
f"<a class='image-preview-link' href='{html.escape(file_uri, quote=True)}'>"
|
||||
f"<img class='image-preview' src='{html.escape(file_uri, quote=True)}' alt='图片预览' />"
|
||||
"</a>"
|
||||
f"<div class='image-path'>{html.escape(display_path)}</div>"
|
||||
f"<a class='image-link' href='{html.escape(file_uri, quote=True)}'>打开图片</a>"
|
||||
"</div>"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_preview_access_body(
|
||||
*,
|
||||
viewer_label: str,
|
||||
viewer_path: Path,
|
||||
viewer_link_text: str,
|
||||
dump_label: str,
|
||||
dump_path: Path,
|
||||
dump_link_text: str,
|
||||
) -> RenderableType:
|
||||
viewer_uri = build_file_uri(viewer_path)
|
||||
dump_uri = build_file_uri(dump_path)
|
||||
viewer_display_path = build_display_path(viewer_path)
|
||||
dump_display_path = build_display_path(dump_path)
|
||||
|
||||
return Group(
|
||||
Text.from_markup(
|
||||
f"[bold green]{viewer_label}:{viewer_display_path}[/bold green] "
|
||||
f"[link={viewer_uri}]{viewer_link_text}[/link]"
|
||||
),
|
||||
Text.from_markup(
|
||||
f"[magenta]{dump_label}:{dump_display_path}[/magenta] "
|
||||
f"[cyan][link={dump_uri}]{dump_link_text}[/link][/cyan]"
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _build_html_role_class(cls, role: str) -> str:
|
||||
return {
|
||||
@@ -356,6 +486,7 @@ class PromptCLIVisualizer:
|
||||
*,
|
||||
request_kind: str,
|
||||
selection_reason: str,
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> str:
|
||||
panel_title, _ = cls.get_request_panel_style(request_kind)
|
||||
message_cards: List[str] = []
|
||||
@@ -378,16 +509,12 @@ class PromptCLIVisualizer:
|
||||
tool_panels = ""
|
||||
raw_tool_calls = message.get("tool_calls") or []
|
||||
if isinstance(raw_tool_calls, list) and raw_tool_calls:
|
||||
tool_items = []
|
||||
for tool_call_index, tool_call in enumerate(raw_tool_calls, start=1):
|
||||
normalized_tool_call = cls.format_tool_call_for_display(tool_call)
|
||||
tool_items.append(
|
||||
"<div class='tool-panel'>"
|
||||
f"<div class='tool-panel-title'>工具调用 #{index}.{tool_call_index}</div>"
|
||||
f"<pre>{html.escape(json.dumps(normalized_tool_call, ensure_ascii=False, indent=2, default=str))}</pre>"
|
||||
"</div>"
|
||||
)
|
||||
tool_panels = "".join(tool_items)
|
||||
tool_panels = (
|
||||
"<div class='tool-list'>"
|
||||
"<div class='tool-list-title'>工具调用</div>"
|
||||
f"{''.join(cls._build_tool_call_html(tool_call) for tool_call in raw_tool_calls)}"
|
||||
"</div>"
|
||||
)
|
||||
|
||||
message_cards.append(
|
||||
"<section class='message-card'>"
|
||||
@@ -405,6 +532,21 @@ class PromptCLIVisualizer:
|
||||
if selection_reason.strip():
|
||||
subtitle_html = f"<div class='subtitle'>{html.escape(selection_reason)}</div>"
|
||||
|
||||
tool_definition_section_html = ""
|
||||
if tool_definitions:
|
||||
tool_definition_section_html = (
|
||||
"<section class='message-card tool-definition-section'>"
|
||||
"<div class='message-head'>"
|
||||
"<span class='role-badge tool'>全部工具</span>"
|
||||
f"<span class='message-index'>{len(tool_definitions)} 个</span>"
|
||||
"</div>"
|
||||
"<div class='tool-list'>"
|
||||
"<div class='tool-list-title'>本次送入模型的工具定义</div>"
|
||||
f"{''.join(cls._build_tool_definition_html(tool_definition) for tool_definition in tool_definitions)}"
|
||||
"</div>"
|
||||
"</section>"
|
||||
)
|
||||
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
@@ -491,7 +633,7 @@ class PromptCLIVisualizer:
|
||||
font-weight: 600;
|
||||
}}
|
||||
.message-content pre,
|
||||
.tool-panel pre {{
|
||||
.tool-card pre {{
|
||||
margin: 0;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
@@ -517,18 +659,81 @@ class PromptCLIVisualizer:
|
||||
border-radius: 8px;
|
||||
padding: 3px 8px;
|
||||
}}
|
||||
.tool-panel {{
|
||||
.tool-list {{
|
||||
margin-top: 14px;
|
||||
}}
|
||||
.tool-list-title {{
|
||||
color: #86198f;
|
||||
font-size: 13px;
|
||||
font-weight: 800;
|
||||
margin-bottom: 10px;
|
||||
}}
|
||||
.tool-card {{
|
||||
margin-top: 12px;
|
||||
background: #fcf4ff;
|
||||
border: 1px solid #f0d7fb;
|
||||
border-radius: 14px;
|
||||
padding: 12px 14px;
|
||||
overflow: hidden;
|
||||
}}
|
||||
.tool-panel-title {{
|
||||
color: #a21caf;
|
||||
.tool-call-card {{
|
||||
border-color: #ff8700;
|
||||
}}
|
||||
.tool-card:first-of-type {{
|
||||
margin-top: 0;
|
||||
}}
|
||||
.tool-card-summary {{
|
||||
list-style: none;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 12px 14px;
|
||||
color: #86198f;
|
||||
font-size: 13px;
|
||||
font-weight: 800;
|
||||
}}
|
||||
.tool-card-summary::-webkit-details-marker {{
|
||||
display: none;
|
||||
}}
|
||||
.tool-card-summary::after {{
|
||||
content: "展开";
|
||||
color: #a21caf;
|
||||
font-size: 12px;
|
||||
font-weight: 700;
|
||||
margin-bottom: 8px;
|
||||
}}
|
||||
.tool-card[open] .tool-card-summary::after {{
|
||||
content: "收起";
|
||||
}}
|
||||
.tool-card-name {{
|
||||
word-break: break-word;
|
||||
}}
|
||||
.tool-card-body {{
|
||||
border-top: 1px solid #f0d7fb;
|
||||
padding: 12px 14px;
|
||||
background: rgba(255, 255, 255, 0.52);
|
||||
}}
|
||||
.tool-call-card .tool-card-body {{
|
||||
border-top-color: #ff8700;
|
||||
}}
|
||||
.tool-card-meta {{
|
||||
margin-bottom: 10px;
|
||||
color: #a21caf;
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}}
|
||||
.tool-card-meta-label {{
|
||||
font-weight: 700;
|
||||
}}
|
||||
.tool-card-meta code {{
|
||||
background: #faf5ff;
|
||||
border: 1px solid #e9d5ff;
|
||||
border-radius: 8px;
|
||||
padding: 3px 8px;
|
||||
}}
|
||||
.tool-card pre {{
|
||||
color: #3b0764;
|
||||
}}
|
||||
.image-card {{
|
||||
background: #f8fafc;
|
||||
@@ -547,6 +752,22 @@ class PromptCLIVisualizer:
|
||||
font-family: "Cascadia Mono", "JetBrains Mono", "Consolas", monospace;
|
||||
word-break: break-all;
|
||||
}}
|
||||
.image-preview-link {{
|
||||
display: block;
|
||||
margin-top: 10px;
|
||||
}}
|
||||
.image-preview {{
|
||||
display: block;
|
||||
max-width: min(100%, 560px);
|
||||
max-height: 420px;
|
||||
width: auto;
|
||||
height: auto;
|
||||
border-radius: 12px;
|
||||
border: 1px solid #dbe4f0;
|
||||
background: #fff;
|
||||
box-shadow: 0 8px 20px rgba(15, 23, 42, 0.08);
|
||||
object-fit: contain;
|
||||
}}
|
||||
.image-link {{
|
||||
display: inline-block;
|
||||
margin-top: 8px;
|
||||
@@ -564,6 +785,7 @@ class PromptCLIVisualizer:
|
||||
{subtitle_html}
|
||||
</header>
|
||||
{''.join(message_cards)}
|
||||
{tool_definition_section_html}
|
||||
</main>
|
||||
</body>
|
||||
</html>"""
|
||||
@@ -578,6 +800,7 @@ class PromptCLIVisualizer:
|
||||
request_kind: str,
|
||||
selection_reason: str,
|
||||
image_display_mode: Literal["legacy", "path_link"],
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> RenderableType:
|
||||
"""构建用于查看完整 prompt 的折叠入口内容。"""
|
||||
|
||||
@@ -603,10 +826,14 @@ class PromptCLIVisualizer:
|
||||
viewer_messages.append(normalized_message)
|
||||
|
||||
prompt_dump_text = cls._build_prompt_dump_text(messages)
|
||||
tool_definition_dump_text = cls._build_tool_definition_dump_text(tool_definitions)
|
||||
if tool_definition_dump_text:
|
||||
prompt_dump_text = f"{prompt_dump_text}\n\n{'=' * 80}\n\n{tool_definition_dump_text}"
|
||||
viewer_html_text = cls._build_prompt_viewer_html(
|
||||
viewer_messages,
|
||||
request_kind=request_kind,
|
||||
selection_reason=selection_reason,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
saved_paths = PromptPreviewLogger.save_preview_files(
|
||||
chat_id,
|
||||
@@ -618,18 +845,13 @@ class PromptCLIVisualizer:
|
||||
)
|
||||
viewer_html_path = saved_paths[".html"]
|
||||
prompt_dump_path = saved_paths[".txt"]
|
||||
viewer_uri = cls._build_file_uri(viewer_html_path)
|
||||
dump_uri = cls._build_file_uri(prompt_dump_path)
|
||||
|
||||
body = Group(
|
||||
Text.from_markup(
|
||||
f"[bold green]富文本预览:{viewer_html_path}[/bold green] "
|
||||
f"[link={viewer_uri}]点击在浏览器打开富文本 Prompt 视图[/link]"
|
||||
),
|
||||
Text.from_markup(
|
||||
f"[magenta]原始文本备份:{prompt_dump_path}[/magenta] "
|
||||
f"[cyan][link={dump_uri}]点击直接打开 Prompt 文本[/link][/cyan]"
|
||||
),
|
||||
body = cls._build_preview_access_body(
|
||||
viewer_label="html预览",
|
||||
viewer_path=viewer_html_path,
|
||||
viewer_link_text="在浏览器打开 Prompt",
|
||||
dump_label="原始文本",
|
||||
dump_path=prompt_dump_path,
|
||||
dump_link_text="点击打开 Prompt 文本",
|
||||
)
|
||||
return body
|
||||
|
||||
@@ -644,6 +866,7 @@ class PromptCLIVisualizer:
|
||||
selection_reason: str,
|
||||
image_display_mode: Literal["legacy", "path_link"],
|
||||
folded: bool,
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> Panel:
|
||||
"""构建用于嵌入结果面板中的 Prompt 区块。"""
|
||||
|
||||
@@ -656,6 +879,7 @@ class PromptCLIVisualizer:
|
||||
request_kind=request_kind,
|
||||
selection_reason=selection_reason,
|
||||
image_display_mode=image_display_mode,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
else:
|
||||
ordered_panels = cls.build_prompt_panels(
|
||||
@@ -782,18 +1006,13 @@ class PromptCLIVisualizer:
|
||||
)
|
||||
viewer_html_path = saved_paths[".html"]
|
||||
text_dump_path = saved_paths[".txt"]
|
||||
viewer_uri = cls._build_file_uri(viewer_html_path)
|
||||
dump_uri = cls._build_file_uri(text_dump_path)
|
||||
|
||||
body = Group(
|
||||
Text.from_markup(
|
||||
f"[bold green]富文本预览:{viewer_html_path}[/bold green] "
|
||||
f"[link={viewer_uri}]点击在浏览器打开富文本 Prompt 视图[/link]"
|
||||
),
|
||||
Text.from_markup(
|
||||
f"[magenta]原始文本备份:{text_dump_path}[/magenta] "
|
||||
f"[cyan][link={dump_uri}]点击直接打开 Prompt 文本[/link][/cyan]"
|
||||
),
|
||||
body = cls._build_preview_access_body(
|
||||
viewer_label="富文本预览",
|
||||
viewer_path=viewer_html_path,
|
||||
viewer_link_text="点击在浏览器打开富文本 Prompt 视图",
|
||||
dump_label="原始文本备份",
|
||||
dump_path=text_dump_path,
|
||||
dump_link_text="点击直接打开 Prompt 文本",
|
||||
)
|
||||
return body
|
||||
|
||||
@@ -2,34 +2,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from src.config.config import global_config
|
||||
from .preview_path_utils import build_preview_chat_dir_name, normalize_preview_name
|
||||
|
||||
|
||||
class PromptPreviewLogger:
|
||||
"""负责保存 Maisaka Prompt 预览文件并控制目录容量。"""
|
||||
|
||||
_BASE_DIR = Path("logs") / "maisaka_prompt"
|
||||
_MAX_PREVIEW_GROUPS_PER_CHAT = 1024
|
||||
_TRIM_COUNT = 100
|
||||
_SAFE_NAME_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
|
||||
|
||||
@classmethod
|
||||
def _get_max_per_chat(cls) -> int:
|
||||
"""从配置中获取每个聊天流最大保存的预览数量。"""
|
||||
|
||||
return getattr(global_config.chat, "plan_reply_log_max_per_chat", 1000)
|
||||
|
||||
@classmethod
|
||||
def _normalize_chat_id(cls, chat_id: str) -> str:
|
||||
normalized_chat_id = cls._SAFE_NAME_PATTERN.sub("_", str(chat_id or "").strip()).strip("._")
|
||||
if normalized_chat_id:
|
||||
return normalized_chat_id
|
||||
return "unknown_chat"
|
||||
def _build_file_stem(cls, chat_dir: Path) -> str:
|
||||
base_stem = str(int(time.time() * 1000))
|
||||
candidate_stem = base_stem
|
||||
suffix_index = 1
|
||||
while any((chat_dir / f"{candidate_stem}{suffix}").exists() for suffix in (".html", ".txt")):
|
||||
candidate_stem = f"{base_stem}_{suffix_index}"
|
||||
suffix_index += 1
|
||||
return candidate_stem
|
||||
|
||||
@classmethod
|
||||
def save_preview_files(
|
||||
@@ -40,10 +35,10 @@ class PromptPreviewLogger:
|
||||
) -> Dict[str, Path]:
|
||||
"""保存同一份 Prompt 预览的多个文件并执行超量清理。"""
|
||||
|
||||
normalized_category = cls._normalize_chat_id(category)
|
||||
chat_dir = (cls._BASE_DIR / normalized_category / cls._normalize_chat_id(chat_id)).resolve()
|
||||
normalized_category = normalize_preview_name(category)
|
||||
chat_dir = (cls._BASE_DIR / normalized_category / build_preview_chat_dir_name(chat_id)).resolve()
|
||||
chat_dir.mkdir(parents=True, exist_ok=True)
|
||||
stem = f"{int(time.time() * 1000)}_{uuid4().hex[:8]}"
|
||||
stem = cls._build_file_stem(chat_dir)
|
||||
saved_paths: Dict[str, Path] = {}
|
||||
try:
|
||||
for suffix, content in files.items():
|
||||
@@ -65,15 +60,14 @@ class PromptPreviewLogger:
|
||||
continue
|
||||
grouped_files.setdefault(file_path.stem, []).append(file_path)
|
||||
|
||||
max_per_chat = cls._get_max_per_chat()
|
||||
if len(grouped_files) <= max_per_chat:
|
||||
if len(grouped_files) <= cls._MAX_PREVIEW_GROUPS_PER_CHAT:
|
||||
return
|
||||
|
||||
sorted_groups = sorted(
|
||||
grouped_files.items(),
|
||||
key=lambda item: min(path.stat().st_mtime for path in item[1]),
|
||||
)
|
||||
overflow_count = len(grouped_files) - max_per_chat
|
||||
overflow_count = len(grouped_files) - cls._MAX_PREVIEW_GROUPS_PER_CHAT
|
||||
trim_count = min(len(sorted_groups), max(cls._TRIM_COUNT, overflow_count))
|
||||
for _, file_group in sorted_groups[:trim_count]:
|
||||
for old_file in file_group:
|
||||
163
src/maisaka/display/stage_status_board.py
Normal file
163
src/maisaka/display/stage_status_board.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Maisaka 阶段状态看板。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class MaisakaStageStatusBoard:
|
||||
"""维护 Maisaka 阶段状态,并在独立终端中展示。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._enabled = False
|
||||
self._entries: dict[str, dict[str, Any]] = {}
|
||||
self._viewer_process: Optional[subprocess.Popen[Any]] = None
|
||||
self._state_file = Path("temp") / "maisaka_stage_status.json"
|
||||
self._state_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def enable(self) -> None:
|
||||
"""启用阶段状态看板。"""
|
||||
|
||||
with self._lock:
|
||||
if self._enabled:
|
||||
return
|
||||
self._enabled = True
|
||||
self._write_state_locked()
|
||||
self._ensure_viewer_process_locked()
|
||||
|
||||
def disable(self) -> None:
|
||||
"""禁用阶段状态看板。"""
|
||||
|
||||
with self._lock:
|
||||
self._enabled = False
|
||||
self._entries.clear()
|
||||
self._write_state_locked()
|
||||
process = self._viewer_process
|
||||
self._viewer_process = None
|
||||
|
||||
if process is not None and process.poll() is None:
|
||||
try:
|
||||
process.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def update(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
session_name: str,
|
||||
stage: str,
|
||||
detail: str = "",
|
||||
round_text: str = "",
|
||||
agent_state: str = "",
|
||||
) -> None:
|
||||
"""更新一个会话的阶段状态。"""
|
||||
|
||||
with self._lock:
|
||||
if not self._enabled:
|
||||
return
|
||||
now = time.time()
|
||||
current = self._entries.get(session_id, {})
|
||||
previous_stage = str(current.get("stage") or "").strip()
|
||||
stage_started_at = float(current.get("stage_started_at") or now)
|
||||
if previous_stage != stage:
|
||||
stage_started_at = now
|
||||
self._entries[session_id] = {
|
||||
"session_id": session_id,
|
||||
"session_name": session_name,
|
||||
"stage": stage,
|
||||
"detail": detail,
|
||||
"round_text": round_text,
|
||||
"agent_state": agent_state,
|
||||
"stage_started_at": stage_started_at,
|
||||
"updated_at": now,
|
||||
}
|
||||
self._write_state_locked()
|
||||
|
||||
def remove(self, session_id: str) -> None:
|
||||
"""移除一个会话的阶段状态。"""
|
||||
|
||||
with self._lock:
|
||||
if not self._enabled:
|
||||
return
|
||||
self._entries.pop(session_id, None)
|
||||
self._write_state_locked()
|
||||
|
||||
def _write_state_locked(self) -> None:
|
||||
payload = {
|
||||
"enabled": self._enabled,
|
||||
"host_pid": os.getpid(),
|
||||
"updated_at": time.time(),
|
||||
"entries": list(self._entries.values()),
|
||||
}
|
||||
tmp_file = self._state_file.with_suffix(".tmp")
|
||||
tmp_file.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
tmp_file.replace(self._state_file)
|
||||
|
||||
def _ensure_viewer_process_locked(self) -> None:
|
||||
if not sys.platform.startswith("win"):
|
||||
return
|
||||
if self._viewer_process is not None and self._viewer_process.poll() is None:
|
||||
return
|
||||
creationflags = getattr(subprocess, "CREATE_NEW_CONSOLE", 0)
|
||||
viewer_script = Path(__file__).resolve().with_name("stage_status_viewer.py")
|
||||
self._viewer_process = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
str(viewer_script),
|
||||
str(self._state_file.resolve()),
|
||||
],
|
||||
creationflags=creationflags,
|
||||
cwd=str(Path.cwd()),
|
||||
)
|
||||
|
||||
|
||||
_stage_board = MaisakaStageStatusBoard()
|
||||
|
||||
|
||||
def enable_stage_status_board() -> None:
|
||||
"""启用控制台阶段状态看板。"""
|
||||
|
||||
_stage_board.enable()
|
||||
|
||||
|
||||
def disable_stage_status_board() -> None:
|
||||
"""禁用控制台阶段状态看板。"""
|
||||
|
||||
_stage_board.disable()
|
||||
|
||||
|
||||
def update_stage_status(
|
||||
*,
|
||||
session_id: str,
|
||||
session_name: str,
|
||||
stage: str,
|
||||
detail: str = "",
|
||||
round_text: str = "",
|
||||
agent_state: str = "",
|
||||
) -> None:
|
||||
"""更新控制台阶段状态。"""
|
||||
|
||||
_stage_board.update(
|
||||
session_id=session_id,
|
||||
session_name=session_name,
|
||||
stage=stage,
|
||||
detail=detail,
|
||||
round_text=round_text,
|
||||
agent_state=agent_state,
|
||||
)
|
||||
|
||||
|
||||
def remove_stage_status(session_id: str) -> None:
|
||||
"""移除控制台阶段状态。"""
|
||||
|
||||
_stage_board.remove(session_id)
|
||||
93
src/maisaka/display/stage_status_viewer.py
Normal file
93
src/maisaka/display/stage_status_viewer.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Maisaka 阶段状态看板查看器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
|
||||
def _clear_screen() -> None:
|
||||
os.system("cls" if sys.platform.startswith("win") else "clear")
|
||||
|
||||
|
||||
def _load_state(state_file: Path) -> dict[str, Any]:
|
||||
if not state_file.exists():
|
||||
return {}
|
||||
try:
|
||||
return json.loads(state_file.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _render(state: dict[str, Any]) -> str:
|
||||
entries = state.get("entries")
|
||||
if not isinstance(entries, list):
|
||||
entries = []
|
||||
|
||||
lines = ["Maisaka 阶段看板", "=" * 72, ""]
|
||||
if not entries:
|
||||
lines.append("当前没有活跃会话。")
|
||||
return "\n".join(lines)
|
||||
|
||||
entries = sorted(
|
||||
[entry for entry in entries if isinstance(entry, dict)],
|
||||
key=lambda item: str(item.get("session_name") or item.get("session_id") or ""),
|
||||
)
|
||||
now = time.time()
|
||||
for entry in entries:
|
||||
session_name = str(entry.get("session_name") or entry.get("session_id") or "").strip() or "unknown"
|
||||
session_id = str(entry.get("session_id") or "").strip()
|
||||
stage = str(entry.get("stage") or "").strip() or "未知"
|
||||
detail = str(entry.get("detail") or "").strip() or "-"
|
||||
round_text = str(entry.get("round_text") or "").strip()
|
||||
agent_state = str(entry.get("agent_state") or "").strip() or "-"
|
||||
stage_started_at = float(entry.get("stage_started_at") or now)
|
||||
elapsed = max(0.0, now - stage_started_at)
|
||||
|
||||
lines.append(f"Chat: {session_name}")
|
||||
if session_id and session_id != session_name:
|
||||
lines.append(f"ID: {session_id}")
|
||||
lines.append(f"阶段: {stage}")
|
||||
if round_text:
|
||||
lines.append(f"轮次: {round_text}")
|
||||
lines.append(f"详情: {detail}")
|
||||
lines.append(f"状态: {agent_state}")
|
||||
lines.append(f"阶段耗时: {elapsed:.1f}s")
|
||||
lines.append("-" * 72)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
if len(sys.argv) < 2:
|
||||
return 1
|
||||
|
||||
state_file = Path(sys.argv[1]).resolve()
|
||||
log_file = state_file.with_name("maisaka_stage_status_viewer.log")
|
||||
last_render = ""
|
||||
while True:
|
||||
try:
|
||||
state = _load_state(state_file)
|
||||
if not state.get("enabled", False):
|
||||
return 0
|
||||
|
||||
rendered = _render(state)
|
||||
if rendered != last_render:
|
||||
_clear_screen()
|
||||
print(rendered, flush=True)
|
||||
last_render = rendered
|
||||
time.sleep(0.5)
|
||||
except Exception:
|
||||
log_file.write_text(traceback.format_exc(), encoding="utf-8")
|
||||
time.sleep(3)
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
125
src/maisaka/history_post_processor.py
Normal file
125
src/maisaka/history_post_processor.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Maisaka 历史消息轮次结束后处理。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
|
||||
from .history_utils import drop_leading_orphan_tool_results, drop_orphan_tool_results
|
||||
|
||||
TIMING_HISTORY_TOOL_NAMES = {"continue", "finish", "no_reply", "wait"}
|
||||
EARLY_TRIM_RATIO = 0.2
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HistoryPostProcessResult:
|
||||
"""历史后处理结果。"""
|
||||
|
||||
history: list[LLMContextMessage]
|
||||
removed_count: int
|
||||
remaining_context_count: int
|
||||
|
||||
|
||||
def process_chat_history_after_cycle(
|
||||
chat_history: list[LLMContextMessage],
|
||||
*,
|
||||
max_context_size: int,
|
||||
) -> HistoryPostProcessResult:
|
||||
"""在每轮结束后统一执行历史裁切与清理。"""
|
||||
|
||||
processed_history = list(chat_history)
|
||||
removed_timing_tool_count = _remove_early_timing_tool_records(processed_history)
|
||||
removed_assistant_thought_count = _remove_early_assistant_thoughts(processed_history)
|
||||
|
||||
processed_history, orphan_removed_count = drop_orphan_tool_results(processed_history)
|
||||
remaining_context_count = sum(1 for message in processed_history if message.count_in_context)
|
||||
removed_overflow_count = 0
|
||||
|
||||
while remaining_context_count > max_context_size and processed_history:
|
||||
removed_message = processed_history.pop(0)
|
||||
removed_overflow_count += 1
|
||||
if removed_message.count_in_context:
|
||||
remaining_context_count -= 1
|
||||
|
||||
processed_history, leading_orphan_removed_count = drop_leading_orphan_tool_results(processed_history)
|
||||
removed_overflow_count += leading_orphan_removed_count
|
||||
remaining_context_count = sum(1 for message in processed_history if message.count_in_context)
|
||||
removed_count = (
|
||||
removed_timing_tool_count
|
||||
+ removed_assistant_thought_count
|
||||
+ orphan_removed_count
|
||||
+ removed_overflow_count
|
||||
)
|
||||
return HistoryPostProcessResult(
|
||||
history=processed_history,
|
||||
removed_count=removed_count,
|
||||
remaining_context_count=remaining_context_count,
|
||||
)
|
||||
|
||||
|
||||
def _remove_early_timing_tool_records(chat_history: list[LLMContextMessage]) -> int:
|
||||
"""移除最早 20% 的门控/结束类工具链记录。"""
|
||||
|
||||
candidate_assistant_indexes = [
|
||||
index
|
||||
for index, message in enumerate(chat_history)
|
||||
if _is_timing_tool_assistant_message(message)
|
||||
]
|
||||
remove_count = int(len(candidate_assistant_indexes) * EARLY_TRIM_RATIO)
|
||||
if remove_count <= 0:
|
||||
return 0
|
||||
|
||||
removed_indexes = set(candidate_assistant_indexes[:remove_count])
|
||||
removed_tool_call_ids = {
|
||||
tool_call.call_id
|
||||
for index in removed_indexes
|
||||
for tool_call in chat_history[index].tool_calls
|
||||
if tool_call.call_id
|
||||
}
|
||||
|
||||
filtered_history: list[LLMContextMessage] = []
|
||||
removed_total = 0
|
||||
for index, message in enumerate(chat_history):
|
||||
if index in removed_indexes:
|
||||
removed_total += 1
|
||||
continue
|
||||
if isinstance(message, ToolResultMessage) and message.tool_call_id in removed_tool_call_ids:
|
||||
removed_total += 1
|
||||
continue
|
||||
filtered_history.append(message)
|
||||
|
||||
chat_history[:] = filtered_history
|
||||
return removed_total
|
||||
|
||||
|
||||
def _remove_early_assistant_thoughts(chat_history: list[LLMContextMessage]) -> int:
|
||||
"""移除最早 20% 的非工具 assistant 思考内容。"""
|
||||
|
||||
candidate_indexes = [
|
||||
index
|
||||
for index, message in enumerate(chat_history)
|
||||
if isinstance(message, AssistantMessage)
|
||||
and not message.tool_calls
|
||||
and message.source_kind != "perception"
|
||||
and bool(message.content.strip())
|
||||
]
|
||||
remove_count = int(len(candidate_indexes) * EARLY_TRIM_RATIO)
|
||||
if remove_count <= 0:
|
||||
return 0
|
||||
|
||||
removed_indexes = set(candidate_indexes[:remove_count])
|
||||
filtered_history: list[LLMContextMessage] = []
|
||||
removed_total = 0
|
||||
for index, message in enumerate(chat_history):
|
||||
if index in removed_indexes:
|
||||
removed_total += 1
|
||||
continue
|
||||
filtered_history.append(message)
|
||||
|
||||
chat_history[:] = filtered_history
|
||||
return removed_total
|
||||
|
||||
|
||||
def _is_timing_tool_assistant_message(message: LLMContextMessage) -> bool:
|
||||
if not isinstance(message, AssistantMessage) or not message.tool_calls:
|
||||
return False
|
||||
|
||||
return all(tool_call.func_name in TIMING_HISTORY_TOOL_NAMES for tool_call in message.tool_calls)
|
||||
@@ -14,7 +14,7 @@ from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import global_config
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.llm_models.exceptions import ReqAbortException
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
@@ -35,7 +35,8 @@ from .context_messages import (
|
||||
ToolResultMessage,
|
||||
contains_complex_message,
|
||||
)
|
||||
from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text, drop_leading_orphan_tool_results
|
||||
from .history_post_processor import process_chat_history_after_cycle
|
||||
from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text
|
||||
from .monitor_events import (
|
||||
emit_cycle_start,
|
||||
emit_message_ingested,
|
||||
@@ -53,7 +54,7 @@ logger = get_logger("maisaka_reasoning_engine")
|
||||
TIMING_GATE_CONTEXT_LIMIT = 24
|
||||
TIMING_GATE_MAX_TOKENS = 384
|
||||
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
ACTION_HIDDEN_TOOL_NAMES = {"continue", "no_reply"}
|
||||
ACTION_BUILTIN_TOOL_NAMES = {tool_spec.name for tool_spec in get_action_tool_specs()}
|
||||
|
||||
|
||||
@@ -94,6 +95,7 @@ class MaisakaReasoningEngine:
|
||||
async def _run_interruptible_planner(
|
||||
self,
|
||||
*,
|
||||
injected_user_messages: Optional[list[str]] = None,
|
||||
tool_definitions: Optional[list[dict[str, Any]]] = None,
|
||||
) -> Any:
|
||||
"""运行一轮可被新消息打断的主 planner 请求。"""
|
||||
@@ -105,6 +107,7 @@ class MaisakaReasoningEngine:
|
||||
try:
|
||||
return await self._runtime._chat_loop_service.chat_loop_step(
|
||||
self._runtime._chat_history,
|
||||
injected_user_messages=injected_user_messages,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
except ReqAbortException:
|
||||
@@ -117,36 +120,27 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
self._runtime._chat_loop_service.set_interrupt_flag(None)
|
||||
|
||||
async def _run_interruptible_sub_agent(
|
||||
async def _run_timing_gate_sub_agent(
|
||||
self,
|
||||
*,
|
||||
context_message_limit: int,
|
||||
system_prompt: str,
|
||||
tool_definitions: list[dict[str, Any]],
|
||||
) -> Any:
|
||||
"""运行一轮可被新消息打断的临时子代理请求。"""
|
||||
"""运行一轮 Timing Gate 子代理请求。
|
||||
|
||||
interrupt_flag = asyncio.Event()
|
||||
interrupted = False
|
||||
self._runtime._bind_planner_interrupt_flag(interrupt_flag)
|
||||
try:
|
||||
return await self._runtime.run_sub_agent(
|
||||
context_message_limit=context_message_limit,
|
||||
system_prompt=system_prompt,
|
||||
request_kind="timing_gate",
|
||||
interrupt_flag=interrupt_flag,
|
||||
max_tokens=TIMING_GATE_MAX_TOKENS,
|
||||
temperature=0.1,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
except ReqAbortException:
|
||||
interrupted = True
|
||||
raise
|
||||
finally:
|
||||
self._runtime._unbind_planner_interrupt_flag(
|
||||
interrupt_flag,
|
||||
interrupted=interrupted,
|
||||
)
|
||||
Timing Gate 阶段不再响应新的 planner 打断,只有主 planner 阶段允许被打断。
|
||||
"""
|
||||
|
||||
return await self._runtime.run_sub_agent(
|
||||
context_message_limit=context_message_limit,
|
||||
system_prompt=system_prompt,
|
||||
request_kind="timing_gate",
|
||||
interrupt_flag=None,
|
||||
max_tokens=TIMING_GATE_MAX_TOKENS,
|
||||
temperature=0.1,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_timing_gate_fallback_prompt() -> str:
|
||||
@@ -174,22 +168,34 @@ class MaisakaReasoningEngine:
|
||||
except Exception:
|
||||
return self._build_timing_gate_fallback_prompt()
|
||||
|
||||
async def _build_action_tool_definitions(self) -> list[dict[str, Any]]:
|
||||
"""构造 Action Loop 阶段可见的工具定义。"""
|
||||
async def _build_action_tool_definitions(self) -> tuple[list[dict[str, Any]], str]:
|
||||
"""构造 Action Loop 阶段可见的工具定义与 deferred tools 提示。"""
|
||||
|
||||
if self._runtime._tool_registry is None:
|
||||
return []
|
||||
self._runtime.update_deferred_tool_specs([])
|
||||
self._runtime.set_current_action_tool_names([])
|
||||
return [], ""
|
||||
|
||||
tool_specs = await self._runtime._tool_registry.list_tools()
|
||||
return [
|
||||
tool_spec.to_llm_definition()
|
||||
for tool_spec in tool_specs
|
||||
if tool_spec.name not in ACTION_HIDDEN_TOOL_NAMES
|
||||
and (
|
||||
tool_spec.provider_name != "maisaka_builtin"
|
||||
or tool_spec.name in ACTION_BUILTIN_TOOL_NAMES
|
||||
)
|
||||
]
|
||||
visible_builtin_tool_specs: list[ToolSpec] = []
|
||||
deferred_tool_specs: list[ToolSpec] = []
|
||||
for tool_spec in tool_specs:
|
||||
if tool_spec.name in ACTION_HIDDEN_TOOL_NAMES:
|
||||
continue
|
||||
if tool_spec.provider_name == "maisaka_builtin":
|
||||
if tool_spec.name in ACTION_BUILTIN_TOOL_NAMES:
|
||||
visible_builtin_tool_specs.append(tool_spec)
|
||||
continue
|
||||
deferred_tool_specs.append(tool_spec)
|
||||
|
||||
self._runtime.update_deferred_tool_specs(deferred_tool_specs)
|
||||
discovered_deferred_tool_specs = self._runtime.get_discovered_deferred_tool_specs()
|
||||
visible_tool_specs = [*visible_builtin_tool_specs, *discovered_deferred_tool_specs]
|
||||
self._runtime.set_current_action_tool_names([tool_spec.name for tool_spec in visible_tool_specs])
|
||||
return (
|
||||
[tool_spec.to_llm_definition() for tool_spec in visible_tool_specs],
|
||||
self._runtime.build_deferred_tools_reminder(),
|
||||
)
|
||||
|
||||
async def _invoke_tool_call(
|
||||
self,
|
||||
@@ -227,18 +233,19 @@ class MaisakaReasoningEngine:
|
||||
async def _run_timing_gate(
|
||||
self,
|
||||
anchor_message: SessionMessage,
|
||||
) -> tuple[Literal["continue", "no_reply", "wait"], Any, list[str]]:
|
||||
) -> tuple[Literal["continue", "no_reply", "wait"], Any, list[str], list[dict[str, Any]]]:
|
||||
"""运行 Timing Gate 子代理并返回控制决策。"""
|
||||
|
||||
if self._runtime._force_continue_until_reply:
|
||||
if self._runtime._force_next_timing_continue:
|
||||
return self._build_forced_continue_timing_result()
|
||||
|
||||
response = await self._run_interruptible_sub_agent(
|
||||
response = await self._run_timing_gate_sub_agent(
|
||||
context_message_limit=TIMING_GATE_CONTEXT_LIMIT,
|
||||
system_prompt=self._build_timing_gate_system_prompt(),
|
||||
tool_definitions=get_timing_tools(),
|
||||
)
|
||||
tool_result_summaries: list[str] = []
|
||||
tool_monitor_results: list[dict[str, Any]] = []
|
||||
selected_tool_call: Optional[ToolCall] = None
|
||||
for tool_call in response.tool_calls:
|
||||
if tool_call.func_name in TIMING_GATE_TOOL_NAMES:
|
||||
@@ -247,11 +254,11 @@ class MaisakaReasoningEngine:
|
||||
|
||||
if selected_tool_call is None:
|
||||
logger.warning(f"{self._runtime.log_prefix} Timing Gate 未返回有效控制工具,默认继续执行 Action Loop")
|
||||
return "continue", response, tool_result_summaries
|
||||
return "continue", response, tool_result_summaries, tool_monitor_results
|
||||
|
||||
append_history = selected_tool_call.func_name != "continue"
|
||||
append_history = False
|
||||
store_record = selected_tool_call.func_name != "continue"
|
||||
_, result, _ = await self._invoke_tool_call(
|
||||
invocation, result, tool_spec = await self._invoke_tool_call(
|
||||
selected_tool_call,
|
||||
response.content or "",
|
||||
anchor_message,
|
||||
@@ -259,19 +266,31 @@ class MaisakaReasoningEngine:
|
||||
store_record=store_record,
|
||||
)
|
||||
tool_result_summaries.append(self._build_tool_result_summary(selected_tool_call, result))
|
||||
tool_monitor_results.append(
|
||||
self._build_tool_monitor_result(
|
||||
selected_tool_call,
|
||||
invocation,
|
||||
result,
|
||||
duration_ms=0.0,
|
||||
tool_spec=tool_spec,
|
||||
)
|
||||
)
|
||||
self._append_timing_gate_execution_result(response, selected_tool_call, result)
|
||||
|
||||
timing_action = str(result.metadata.get("timing_action") or selected_tool_call.func_name).strip()
|
||||
if timing_action not in TIMING_GATE_TOOL_NAMES:
|
||||
logger.warning(
|
||||
f"{self._runtime.log_prefix} Timing Gate 返回未知动作 {timing_action!r},将按 continue 处理"
|
||||
)
|
||||
return "continue", response, tool_result_summaries
|
||||
return timing_action, response, tool_result_summaries
|
||||
return "continue", response, tool_result_summaries, tool_monitor_results
|
||||
return timing_action, response, tool_result_summaries, tool_monitor_results
|
||||
|
||||
def _build_forced_continue_timing_result(self) -> tuple[Literal["continue"], ChatResponse, list[str]]:
|
||||
def _build_forced_continue_timing_result(
|
||||
self,
|
||||
) -> tuple[Literal["continue"], ChatResponse, list[str], list[dict[str, Any]]]:
|
||||
"""构造跳过 Timing Gate 时使用的伪 continue 结果。"""
|
||||
|
||||
reason = self._runtime._build_force_continue_timing_reason()
|
||||
reason = self._runtime._consume_force_next_timing_continue_reason() or "本轮直接跳过 Timing Gate 并视作 continue。"
|
||||
logger.info(f"{self._runtime.log_prefix} {reason}")
|
||||
return (
|
||||
"continue",
|
||||
@@ -296,8 +315,24 @@ class MaisakaReasoningEngine:
|
||||
prompt_section=None,
|
||||
),
|
||||
[f"- continue [强制跳过]: {reason}"],
|
||||
[],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _mark_timing_gate_completed(timing_action: str) -> bool:
|
||||
"""根据门控动作决定下一轮是否还需要重新执行 timing。"""
|
||||
|
||||
return timing_action != "continue"
|
||||
|
||||
@staticmethod
|
||||
def _should_retry_planner_after_interrupt(
|
||||
*,
|
||||
round_index: int,
|
||||
max_internal_rounds: int,
|
||||
has_pending_messages: bool,
|
||||
) -> bool:
|
||||
return has_pending_messages and round_index + 1 < max_internal_rounds
|
||||
|
||||
async def run_loop(self) -> None:
|
||||
"""独立消费消息批次,并执行对应的内部思考轮次。"""
|
||||
try:
|
||||
@@ -314,13 +349,20 @@ class MaisakaReasoningEngine:
|
||||
if self._runtime._has_pending_messages()
|
||||
else []
|
||||
)
|
||||
if not timeout_triggered and not cached_messages and not message_triggered:
|
||||
if not timeout_triggered and not cached_messages:
|
||||
continue
|
||||
|
||||
self._runtime._agent_state = self._runtime._STATE_RUNNING
|
||||
self._runtime._update_stage_status(
|
||||
"消息整理",
|
||||
f"待处理消息 {len(cached_messages)} 条" if cached_messages else "准备复用超时锚点",
|
||||
)
|
||||
if cached_messages:
|
||||
asyncio.create_task(self._runtime._trigger_batch_learning(cached_messages))
|
||||
self._append_wait_interrupted_message_if_needed()
|
||||
if timeout_triggered:
|
||||
self._runtime._chat_history.append(
|
||||
self._build_wait_completed_message(has_new_messages=True)
|
||||
)
|
||||
await self._ingest_messages(cached_messages)
|
||||
anchor_message = cached_messages[-1]
|
||||
else:
|
||||
@@ -332,13 +374,16 @@ class MaisakaReasoningEngine:
|
||||
continue
|
||||
logger.info(f"{self._runtime.log_prefix} 等待超时后开始新一轮思考")
|
||||
if self._runtime._pending_wait_tool_call_id:
|
||||
self._runtime._chat_history.append(self._build_wait_timeout_message())
|
||||
self._trim_chat_history()
|
||||
|
||||
self._runtime._chat_history.append(
|
||||
self._build_wait_completed_message(has_new_messages=False)
|
||||
)
|
||||
try:
|
||||
timing_gate_required = True
|
||||
for round_index in range(self._runtime._max_internal_rounds):
|
||||
cycle_detail = self._start_cycle()
|
||||
round_text = f"第 {round_index + 1}/{self._runtime._max_internal_rounds} 轮"
|
||||
self._runtime._log_cycle_started(cycle_detail, round_index)
|
||||
self._runtime._update_stage_status("启动循环", f"循环 {cycle_detail.cycle_id}", round_text=round_text)
|
||||
await emit_cycle_start(
|
||||
session_id=self._runtime.session_id,
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
@@ -349,10 +394,14 @@ class MaisakaReasoningEngine:
|
||||
planner_started_at = 0.0
|
||||
planner_duration_ms = 0.0
|
||||
timing_duration_ms = 0.0
|
||||
current_stage_started_at = 0.0
|
||||
timing_action: Optional[str] = None
|
||||
timing_response: Optional[ChatResponse] = None
|
||||
timing_tool_results: Optional[list[str]] = None
|
||||
timing_tool_monitor_results: Optional[list[dict[str, Any]]] = None
|
||||
response: Optional[ChatResponse] = None
|
||||
action_tool_definitions: list[dict[str, Any]] = []
|
||||
planner_extra_lines: list[str] = []
|
||||
tool_result_summaries: list[str] = []
|
||||
tool_monitor_results: list[dict[str, Any]] = []
|
||||
try:
|
||||
@@ -364,30 +413,46 @@ class MaisakaReasoningEngine:
|
||||
f"{self._runtime.log_prefix} 本轮思考前已刷新 {refreshed_message_count} 条视觉占位历史消息"
|
||||
)
|
||||
|
||||
timing_started_at = time.time()
|
||||
timing_action, timing_response, timing_tool_results = await self._run_timing_gate(anchor_message)
|
||||
timing_duration_ms = (time.time() - timing_started_at) * 1000
|
||||
cycle_detail.time_records["timing_gate"] = timing_duration_ms / 1000
|
||||
await emit_timing_gate_result(
|
||||
session_id=self._runtime.session_id,
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
action=timing_action,
|
||||
content=timing_response.content,
|
||||
tool_calls=timing_response.tool_calls,
|
||||
messages=[],
|
||||
prompt_tokens=timing_response.prompt_tokens,
|
||||
selected_history_count=timing_response.selected_history_count,
|
||||
duration_ms=timing_duration_ms,
|
||||
)
|
||||
if timing_action != "continue":
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} Timing Gate 结束当前回合: "
|
||||
f"回合={round_index + 1} 动作={timing_action}"
|
||||
if timing_gate_required:
|
||||
self._runtime._update_stage_status("Timing Gate", "等待门控决策", round_text=round_text)
|
||||
current_stage_started_at = time.time()
|
||||
timing_started_at = time.time()
|
||||
(
|
||||
timing_action,
|
||||
timing_response,
|
||||
timing_tool_results,
|
||||
timing_tool_monitor_results,
|
||||
) = await self._run_timing_gate(anchor_message)
|
||||
timing_duration_ms = (time.time() - timing_started_at) * 1000
|
||||
cycle_detail.time_records["timing_gate"] = timing_duration_ms / 1000
|
||||
await emit_timing_gate_result(
|
||||
session_id=self._runtime.session_id,
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
action=timing_action,
|
||||
content=timing_response.content,
|
||||
tool_calls=timing_response.tool_calls,
|
||||
messages=[],
|
||||
prompt_tokens=timing_response.prompt_tokens,
|
||||
selected_history_count=timing_response.selected_history_count,
|
||||
duration_ms=timing_duration_ms,
|
||||
)
|
||||
timing_gate_required = self._mark_timing_gate_completed(timing_action)
|
||||
if timing_action != "continue":
|
||||
logger.debug(
|
||||
f"{self._runtime.log_prefix} Timing Gate 结束当前回合: "
|
||||
f"回合={round_index + 1} 动作={timing_action}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 跳过 Timing Gate,继续执行 Planner: "
|
||||
f"回合={round_index + 1}"
|
||||
)
|
||||
break
|
||||
|
||||
planner_started_at = time.time()
|
||||
action_tool_definitions = await self._build_action_tool_definitions()
|
||||
current_stage_started_at = planner_started_at
|
||||
self._runtime._update_stage_status("Planner", "组织上下文并请求模型", round_text=round_text)
|
||||
action_tool_definitions, deferred_tools_reminder = await self._build_action_tool_definitions()
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 规划器开始执行: "
|
||||
f"回合={round_index + 1} "
|
||||
@@ -395,6 +460,7 @@ class MaisakaReasoningEngine:
|
||||
f"开始时间={planner_started_at:.3f}"
|
||||
)
|
||||
response = await self._run_interruptible_planner(
|
||||
injected_user_messages=[deferred_tools_reminder] if deferred_tools_reminder else None,
|
||||
tool_definitions=action_tool_definitions,
|
||||
)
|
||||
planner_duration_ms = (time.time() - planner_started_at) * 1000
|
||||
@@ -406,8 +472,8 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
reasoning_content = response.content or ""
|
||||
if self._should_replace_reasoning(reasoning_content):
|
||||
response.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后先输出想法再使用工具"
|
||||
response.raw_message.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后先输出想法再使用工具"
|
||||
response.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后直接输出我的想法"
|
||||
response.raw_message.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后直接输出我的想法"
|
||||
logger.info(f"{self._runtime.log_prefix} 当前思考与上一轮过于相似,已替换为重新思考提示")
|
||||
|
||||
self._last_reasoning_content = reasoning_content
|
||||
@@ -428,20 +494,73 @@ class MaisakaReasoningEngine:
|
||||
|
||||
if not response.content:
|
||||
break
|
||||
except ReqAbortException:
|
||||
interrupted_at = time.time()
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 规划器打断成功: "
|
||||
f"回合={round_index + 1} "
|
||||
f"开始时间={planner_started_at:.3f} "
|
||||
f"打断时间={interrupted_at:.3f} "
|
||||
f"耗时={interrupted_at - planner_started_at:.3f} 秒"
|
||||
except ReqAbortException as exc:
|
||||
self._runtime._update_stage_status(
|
||||
"Planner 已打断",
|
||||
str(exc) or "收到外部中断信号",
|
||||
round_text=round_text,
|
||||
)
|
||||
break
|
||||
interrupted_at = time.time()
|
||||
interrupted_stage_label = "Planner"
|
||||
interrupted_text = "Planner 收到新消息,开始重新决策"
|
||||
interrupted_response = ChatResponse(
|
||||
content=interrupted_text or None,
|
||||
tool_calls=[],
|
||||
request_messages=[],
|
||||
raw_message=AssistantMessage(
|
||||
content=interrupted_text,
|
||||
timestamp=datetime.now(),
|
||||
tool_calls=[],
|
||||
source_kind="perception",
|
||||
),
|
||||
selected_history_count=len(self._runtime._chat_history),
|
||||
tool_count=len(action_tool_definitions),
|
||||
prompt_tokens=0,
|
||||
built_message_count=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
prompt_section=None,
|
||||
)
|
||||
interrupted_extra_lines = [
|
||||
"状态:已被新消息打断",
|
||||
f"打断位置:{interrupted_stage_label} 请求流式响应阶段",
|
||||
f"打断耗时:{interrupted_at - current_stage_started_at:.3f} 秒",
|
||||
]
|
||||
response = interrupted_response
|
||||
planner_extra_lines = interrupted_extra_lines
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} {interrupted_stage_label} 打断成功: "
|
||||
f"回合={round_index + 1} "
|
||||
f"开始时间={current_stage_started_at:.3f} "
|
||||
f"打断时间={interrupted_at:.3f} "
|
||||
f"耗时={interrupted_at - current_stage_started_at:.3f} 秒"
|
||||
)
|
||||
if not self._should_retry_planner_after_interrupt(
|
||||
round_index=round_index,
|
||||
max_internal_rounds=self._runtime._max_internal_rounds,
|
||||
has_pending_messages=self._runtime._has_pending_messages(),
|
||||
):
|
||||
break
|
||||
|
||||
await self._runtime._wait_for_message_quiet_period()
|
||||
self._runtime._message_turn_scheduled = False
|
||||
interrupted_messages = self._runtime._collect_pending_messages()
|
||||
if not interrupted_messages:
|
||||
break
|
||||
|
||||
asyncio.create_task(self._runtime._trigger_batch_learning(interrupted_messages))
|
||||
await self._ingest_messages(interrupted_messages)
|
||||
anchor_message = interrupted_messages[-1]
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 淇濇寔娲昏穬鐘舵€侊紝璺宠繃 Timing Gate 鐩存帴閲嶈瘯 Planner: "
|
||||
f"鍥炲悎={round_index + 2}"
|
||||
)
|
||||
continue
|
||||
finally:
|
||||
completed_cycle = self._end_cycle(cycle_detail)
|
||||
self._runtime._render_context_usage_panel(
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
time_records=dict(completed_cycle.time_records),
|
||||
timing_selected_history_count=(
|
||||
timing_response.selected_history_count if timing_response is not None else None
|
||||
),
|
||||
@@ -452,6 +571,7 @@ class MaisakaReasoningEngine:
|
||||
timing_response=timing_response.content or "" if timing_response is not None else "",
|
||||
timing_tool_calls=timing_response.tool_calls if timing_response is not None else None,
|
||||
timing_tool_results=timing_tool_results,
|
||||
timing_tool_detail_results=timing_tool_monitor_results,
|
||||
timing_prompt_section=(
|
||||
timing_response.prompt_section if timing_response is not None else None
|
||||
),
|
||||
@@ -464,6 +584,7 @@ class MaisakaReasoningEngine:
|
||||
planner_tool_results=tool_result_summaries,
|
||||
planner_tool_detail_results=tool_monitor_results,
|
||||
planner_prompt_section=response.prompt_section if response is not None else None,
|
||||
planner_extra_lines=planner_extra_lines,
|
||||
)
|
||||
await emit_planner_finalized(
|
||||
session_id=self._runtime.session_id,
|
||||
@@ -505,6 +626,8 @@ class MaisakaReasoningEngine:
|
||||
finally:
|
||||
if self._runtime._agent_state == self._runtime._STATE_RUNNING:
|
||||
self._runtime._agent_state = self._runtime._STATE_STOP
|
||||
if self._runtime._running:
|
||||
self._runtime._update_stage_status("等待消息", "本轮处理结束")
|
||||
except asyncio.CancelledError:
|
||||
self._runtime._log_internal_loop_cancelled()
|
||||
raise
|
||||
@@ -543,33 +666,22 @@ class MaisakaReasoningEngine:
|
||||
return self._runtime.message_cache[-1]
|
||||
return None
|
||||
|
||||
def _build_wait_timeout_message(self) -> ToolResultMessage:
|
||||
"""构造 wait 超时后的工具结果消息。"""
|
||||
def _build_wait_completed_message(self, *, has_new_messages: bool) -> ToolResultMessage:
|
||||
"""构造 wait 完成后的工具结果消息。"""
|
||||
tool_call_id = self._runtime._pending_wait_tool_call_id or "wait_timeout"
|
||||
self._runtime._pending_wait_tool_call_id = None
|
||||
content = (
|
||||
"等待已结束,期间收到了新的用户输入。请结合这些新消息继续下一轮思考。"
|
||||
if has_new_messages
|
||||
else "等待已超时,期间没有收到新的用户输入。请基于现有上下文继续下一轮思考。"
|
||||
)
|
||||
return ToolResultMessage(
|
||||
content="等待已超时,期间没有收到新的用户输入。请基于现有上下文继续下一轮思考。",
|
||||
content=content,
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name="wait",
|
||||
)
|
||||
|
||||
def _append_wait_interrupted_message_if_needed(self) -> None:
|
||||
"""如果 wait 被新消息打断,则补一条对应的工具结果消息。"""
|
||||
tool_call_id = self._runtime._pending_wait_tool_call_id
|
||||
if not tool_call_id:
|
||||
return
|
||||
|
||||
self._runtime._pending_wait_tool_call_id = None
|
||||
self._runtime._chat_history.append(
|
||||
ToolResultMessage(
|
||||
content="等待过程被新的用户输入打断,已继续处理最新消息。",
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name="wait",
|
||||
)
|
||||
)
|
||||
|
||||
async def _ingest_messages(self, messages: list[SessionMessage]) -> None:
|
||||
"""处理传入消息列表,将其转换为历史消息并加入聊天历史缓存。"""
|
||||
for message in messages:
|
||||
@@ -578,7 +690,6 @@ class MaisakaReasoningEngine:
|
||||
continue
|
||||
|
||||
self._insert_chat_history_message(history_message)
|
||||
self._trim_chat_history()
|
||||
|
||||
# 向监控前端广播新消息注入事件
|
||||
user_info = message.message_info.user_info
|
||||
@@ -628,10 +739,47 @@ class MaisakaReasoningEngine:
|
||||
planner_prefix: str,
|
||||
) -> MessageSequence:
|
||||
message_sequence = build_prefixed_message_sequence(message.raw_message, planner_prefix)
|
||||
if global_config.visual.multimodal_planner:
|
||||
if self._resolve_enable_visual_planner():
|
||||
await self._hydrate_visual_components(message_sequence.components)
|
||||
return message_sequence
|
||||
|
||||
@staticmethod
|
||||
def _resolve_enable_visual_planner() -> bool:
|
||||
planner_mode = global_config.visual.planner_mode
|
||||
planner_task_config = config_manager.get_model_config().model_task_config.planner
|
||||
models_by_name = {model.name: model for model in config_manager.get_model_config().models}
|
||||
|
||||
if planner_mode == "text":
|
||||
return False
|
||||
|
||||
planner_models: list[str] = list(planner_task_config.model_list)
|
||||
missing_models = [model_name for model_name in planner_models if model_name not in models_by_name]
|
||||
non_visual_models = [
|
||||
model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual
|
||||
]
|
||||
|
||||
if planner_mode == "multimodal":
|
||||
if missing_models:
|
||||
raise ValueError(
|
||||
"planner_mode=multimodal,但 planner 任务存在未定义的模型:"
|
||||
f"{', '.join(missing_models)}"
|
||||
)
|
||||
if non_visual_models:
|
||||
raise ValueError(
|
||||
"planner_mode=multimodal,但 planner 任务存在未开启 visual 的模型:"
|
||||
f"{', '.join(non_visual_models)}"
|
||||
)
|
||||
return True
|
||||
|
||||
if missing_models:
|
||||
logger.warning(
|
||||
"planner_mode=auto 时发现 planner 任务存在未定义模型:"
|
||||
f"{', '.join(missing_models)},将退化为纯文本 planner"
|
||||
)
|
||||
return False
|
||||
|
||||
return bool(planner_models) and not non_visual_models
|
||||
|
||||
async def _hydrate_visual_components(self, planner_components: list[object]) -> None:
|
||||
"""在 Maisaka 真正需要图片或表情时,按需回填二进制数据。"""
|
||||
load_tasks: list[asyncio.Task[None]] = []
|
||||
@@ -681,6 +829,7 @@ class MaisakaReasoningEngine:
|
||||
"""结束并记录一轮 Maisaka 思考循环。"""
|
||||
cycle_detail.end_time = time.time()
|
||||
self._runtime.history_loop.append(cycle_detail)
|
||||
self._post_process_chat_history_after_cycle()
|
||||
|
||||
timer_strings = [
|
||||
f"{name}: {duration:.2f}s"
|
||||
@@ -690,26 +839,20 @@ class MaisakaReasoningEngine:
|
||||
self._runtime._log_cycle_completed(cycle_detail, timer_strings)
|
||||
return cycle_detail
|
||||
|
||||
def _trim_chat_history(self) -> None:
|
||||
def _post_process_chat_history_after_cycle(self) -> None:
|
||||
"""裁剪聊天历史,保证用户消息数量不超过配置限制。"""
|
||||
conversation_message_count = sum(1 for message in self._runtime._chat_history if message.count_in_context)
|
||||
if conversation_message_count <= self._runtime._max_context_size:
|
||||
process_result = process_chat_history_after_cycle(
|
||||
self._runtime._chat_history,
|
||||
max_context_size=self._runtime._max_context_size,
|
||||
)
|
||||
if process_result.removed_count <= 0:
|
||||
return
|
||||
|
||||
trimmed_history = list(self._runtime._chat_history)
|
||||
removed_count = 0
|
||||
|
||||
while conversation_message_count > self._runtime._max_context_size and trimmed_history:
|
||||
removed_message = trimmed_history.pop(0)
|
||||
removed_count += 1
|
||||
if removed_message.count_in_context:
|
||||
conversation_message_count -= 1
|
||||
|
||||
trimmed_history, pruned_orphan_count = drop_leading_orphan_tool_results(trimmed_history)
|
||||
removed_count += pruned_orphan_count
|
||||
|
||||
self._runtime._chat_history = trimmed_history
|
||||
self._runtime._log_history_trimmed(removed_count, conversation_message_count)
|
||||
self._runtime._chat_history = process_result.history
|
||||
self._runtime._log_history_trimmed(
|
||||
process_result.removed_count,
|
||||
process_result.remaining_context_count,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _calculate_similarity(text1: str, text2: str) -> float:
|
||||
@@ -934,6 +1077,9 @@ class MaisakaReasoningEngine:
|
||||
if invocation.tool_name == "no_reply":
|
||||
return "你暂停了当前对话循环,等待新的外部消息。"
|
||||
|
||||
if invocation.tool_name == "finish":
|
||||
return "你结束了本轮思考,等待新的外部消息后再继续。"
|
||||
|
||||
if invocation.tool_name == "continue":
|
||||
return "你允许当前对话继续进入下一轮完整思考与工具执行。"
|
||||
|
||||
@@ -1065,6 +1211,24 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
)
|
||||
|
||||
def _append_timing_gate_execution_result(
|
||||
self,
|
||||
response: ChatResponse,
|
||||
tool_call: ToolCall,
|
||||
result: ToolExecutionResult,
|
||||
) -> None:
|
||||
"""将 Timing Gate 的决策链写入历史,供后续门控复用。"""
|
||||
|
||||
self._runtime._chat_history.append(
|
||||
AssistantMessage(
|
||||
content=response.content or "",
|
||||
timestamp=response.raw_message.timestamp,
|
||||
tool_calls=[tool_call],
|
||||
source_kind="timing_gate",
|
||||
)
|
||||
)
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
|
||||
def _build_tool_result_summary(self, tool_call: ToolCall, result: ToolExecutionResult) -> str:
|
||||
"""构建用于终端展示的工具结果摘要。"""
|
||||
|
||||
@@ -1084,6 +1248,7 @@ class MaisakaReasoningEngine:
|
||||
invocation: ToolInvocation,
|
||||
result: ToolExecutionResult,
|
||||
duration_ms: float,
|
||||
tool_spec: Optional[ToolSpec] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""构建 planner.finalized 中单个工具的监控结果。"""
|
||||
|
||||
@@ -1092,9 +1257,20 @@ class MaisakaReasoningEngine:
|
||||
if monitor_detail is not None:
|
||||
normalized_detail = self._normalize_tool_record_value(monitor_detail)
|
||||
|
||||
monitor_card = result.metadata.get("monitor_card")
|
||||
normalized_card = None
|
||||
if monitor_card is not None:
|
||||
normalized_card = self._normalize_tool_record_value(monitor_card)
|
||||
|
||||
monitor_sub_cards = result.metadata.get("monitor_sub_cards")
|
||||
normalized_sub_cards = None
|
||||
if monitor_sub_cards is not None:
|
||||
normalized_sub_cards = self._normalize_tool_record_value(monitor_sub_cards)
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"tool_name": tool_call.func_name,
|
||||
"tool_title": tool_spec.title.strip() if tool_spec is not None and tool_spec.title.strip() else "",
|
||||
"tool_args": self._normalize_tool_record_value(
|
||||
invocation.arguments if isinstance(invocation.arguments, dict) else {}
|
||||
),
|
||||
@@ -1102,6 +1278,8 @@ class MaisakaReasoningEngine:
|
||||
"duration_ms": round(duration_ms, 2),
|
||||
"summary": self._build_tool_result_summary(tool_call, result),
|
||||
"detail": normalized_detail,
|
||||
"card": normalized_card,
|
||||
"sub_cards": normalized_sub_cards,
|
||||
}
|
||||
|
||||
async def _handle_tool_calls(
|
||||
@@ -1137,7 +1315,7 @@ class MaisakaReasoningEngine:
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
tool_result_summaries.append(self._build_tool_result_summary(tool_call, result))
|
||||
tool_monitor_results.append(
|
||||
self._build_tool_monitor_result(tool_call, invocation, result, duration_ms=0.0)
|
||||
self._build_tool_monitor_result(tool_call, invocation, result, duration_ms=0.0, tool_spec=None)
|
||||
)
|
||||
return False, tool_result_summaries, tool_monitor_results
|
||||
|
||||
@@ -1146,10 +1324,25 @@ class MaisakaReasoningEngine:
|
||||
tool_spec.name: tool_spec
|
||||
for tool_spec in await self._runtime._tool_registry.list_tools()
|
||||
}
|
||||
for tool_call in tool_calls:
|
||||
total_tool_count = len(tool_calls)
|
||||
for tool_index, tool_call in enumerate(tool_calls, start=1):
|
||||
invocation = self._build_tool_invocation(tool_call, latest_thought)
|
||||
self._runtime._update_stage_status(
|
||||
f"工具执行 · {invocation.tool_name}",
|
||||
f"第 {tool_index}/{total_tool_count} 个工具",
|
||||
)
|
||||
tool_started_at = time.time()
|
||||
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
|
||||
if not self._runtime.is_action_tool_currently_available(invocation.tool_name):
|
||||
result = ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
success=False,
|
||||
error_message=(
|
||||
f"工具 {invocation.tool_name} 当前未直接暴露给 planner。"
|
||||
"如果它在 deferred tools 提示中,请先调用 tool_search。"
|
||||
),
|
||||
)
|
||||
else:
|
||||
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
|
||||
tool_duration_ms = (time.time() - tool_started_at) * 1000
|
||||
await self._store_tool_execution_record(
|
||||
invocation,
|
||||
@@ -1159,7 +1352,13 @@ class MaisakaReasoningEngine:
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
tool_result_summaries.append(self._build_tool_result_summary(tool_call, result))
|
||||
tool_monitor_results.append(
|
||||
self._build_tool_monitor_result(tool_call, invocation, result, tool_duration_ms)
|
||||
self._build_tool_monitor_result(
|
||||
tool_call,
|
||||
invocation,
|
||||
result,
|
||||
tool_duration_ms,
|
||||
tool_spec=tool_spec_map.get(invocation.tool_name),
|
||||
)
|
||||
)
|
||||
|
||||
if not result.success and tool_call.func_name == "reply":
|
||||
|
||||
@@ -21,7 +21,7 @@ from src.common.data_models.mai_message_data_model import GroupInfo, UserInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry
|
||||
from src.core.tooling import ToolRegistry, ToolSpec
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.learners.jargon_miner import JargonMiner
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
@@ -30,11 +30,13 @@ from src.mcp_module import MCPManager
|
||||
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
|
||||
from src.mcp_module.provider import MCPToolProvider
|
||||
from src.plugin_runtime.tool_provider import PluginToolProvider
|
||||
from src.plugin_runtime.hook_payloads import deserialize_prompt_messages
|
||||
|
||||
from .chat_loop_service import ChatResponse, MaisakaChatLoopService
|
||||
from .context_messages import LLMContextMessage
|
||||
from .display_utils import build_tool_call_summary_lines, format_token_count
|
||||
from .prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .display.display_utils import build_tool_call_summary_lines, format_token_count
|
||||
from .display.prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .display.stage_status_board import remove_stage_status, update_stage_status
|
||||
from .reasoning_engine import MaisakaReasoningEngine
|
||||
from .tool_provider import MaisakaBuiltinToolProvider
|
||||
|
||||
@@ -92,14 +94,16 @@ class MaisakaHeartFlowChatting:
|
||||
self._max_internal_rounds = MAX_INTERNAL_ROUNDS
|
||||
self._max_context_size = max(1, int(global_config.chat.max_context_size))
|
||||
self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP
|
||||
self._wait_until: Optional[float] = None
|
||||
self._pending_wait_tool_call_id: Optional[str] = None
|
||||
self._force_continue_until_reply = False
|
||||
self._force_continue_trigger_message_id = ""
|
||||
self._force_continue_trigger_reason = ""
|
||||
self._force_next_timing_continue = False
|
||||
self._force_next_timing_message_id = ""
|
||||
self._force_next_timing_reason = ""
|
||||
self._planner_interrupt_flag: Optional[asyncio.Event] = None
|
||||
self._planner_interrupt_requested = False
|
||||
self._planner_interrupt_consecutive_count = 0
|
||||
self._current_action_tool_names: set[str] = set()
|
||||
self.discovered_tool_names: set[str] = set()
|
||||
self.deferred_tool_specs_by_name: dict[str, ToolSpec] = {}
|
||||
self._planner_interrupt_max_consecutive_count = max(
|
||||
0,
|
||||
int(global_config.chat.planner_interrupt_max_consecutive_count),
|
||||
@@ -118,6 +122,18 @@ class MaisakaHeartFlowChatting:
|
||||
self._tool_registry = ToolRegistry()
|
||||
self._register_tool_providers()
|
||||
|
||||
def _update_stage_status(self, stage: str, detail: str = "", *, round_text: str = "") -> None:
|
||||
"""更新当前会话的阶段状态。"""
|
||||
|
||||
update_stage_status(
|
||||
session_id=self.session_id,
|
||||
session_name=self.session_name,
|
||||
stage=stage,
|
||||
detail=detail,
|
||||
round_text=round_text,
|
||||
agent_state=self._agent_state,
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动运行时主循环。"""
|
||||
if self._running:
|
||||
@@ -130,6 +146,7 @@ class MaisakaHeartFlowChatting:
|
||||
self._running = True
|
||||
self._ensure_background_tasks_running()
|
||||
self._schedule_message_turn()
|
||||
self._update_stage_status("空闲", "等待消息触发")
|
||||
logger.info(f"{self.log_prefix} Maisaka 运行时已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
@@ -157,6 +174,7 @@ class MaisakaHeartFlowChatting:
|
||||
await self._tool_registry.close()
|
||||
self._mcp_manager = None
|
||||
self._mcp_host_bridge = None
|
||||
remove_stage_status(self.session_id)
|
||||
|
||||
logger.info(f"{self.log_prefix} Maisaka 运行时已停止")
|
||||
|
||||
@@ -175,9 +193,6 @@ class MaisakaHeartFlowChatting:
|
||||
self.message_cache.append(message)
|
||||
self._message_received_at_by_id[message.message_id] = received_at
|
||||
self._source_messages_by_id[message.message_id] = message
|
||||
if self._agent_state == self._STATE_WAIT:
|
||||
self._cancel_wait_timeout_task()
|
||||
self._wait_until = None
|
||||
if self._agent_state == self._STATE_RUNNING:
|
||||
self._message_debounce_required = True
|
||||
if self._agent_state == self._STATE_RUNNING and self._planner_interrupt_flag is not None:
|
||||
@@ -248,7 +263,6 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
def _record_reply_sent(self) -> None:
|
||||
"""在成功发送 reply 后记录本轮消息回复时长。"""
|
||||
self._clear_force_continue_until_reply()
|
||||
if self._reply_latency_measurement_started_at is None:
|
||||
return
|
||||
|
||||
@@ -308,26 +322,26 @@ class MaisakaHeartFlowChatting:
|
||||
if not message.is_at and not message.is_mentioned:
|
||||
return
|
||||
|
||||
self._arm_force_continue_until_reply(
|
||||
self._arm_force_next_timing_continue(
|
||||
message,
|
||||
is_at=message.is_at,
|
||||
is_mentioned=message.is_mentioned,
|
||||
)
|
||||
|
||||
def _arm_force_continue_until_reply(
|
||||
def _arm_force_next_timing_continue(
|
||||
self,
|
||||
message: SessionMessage,
|
||||
*,
|
||||
is_at: bool,
|
||||
is_mentioned: bool,
|
||||
) -> None:
|
||||
"""在检测到 @ 或提及时,要求后续轮次跳过 Timing Gate 直到成功 reply。"""
|
||||
"""在检测到 @ 或提及时,要求下一次 Timing Gate 直接 continue。"""
|
||||
|
||||
trigger_reason = "@消息" if is_at else "提及消息" if is_mentioned else "触发消息"
|
||||
was_armed = self._force_continue_until_reply
|
||||
self._force_continue_until_reply = True
|
||||
self._force_continue_trigger_message_id = message.message_id
|
||||
self._force_continue_trigger_reason = trigger_reason
|
||||
was_armed = self._force_next_timing_continue
|
||||
self._force_next_timing_continue = True
|
||||
self._force_next_timing_message_id = message.message_id
|
||||
self._force_next_timing_reason = trigger_reason
|
||||
|
||||
if was_armed:
|
||||
logger.info(
|
||||
@@ -337,34 +351,31 @@ class MaisakaHeartFlowChatting:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 检测到{trigger_reason},将跳过 Timing Gate 直到成功发送一条 reply;"
|
||||
f"{self.log_prefix} 检测到{trigger_reason},下一次 Timing Gate 将直接视作 continue;"
|
||||
f"消息编号={message.message_id}"
|
||||
)
|
||||
|
||||
def _clear_force_continue_until_reply(self) -> None:
|
||||
"""在成功发送 reply 后清理强制 continue 状态。"""
|
||||
def _consume_force_next_timing_continue_reason(self) -> str | None:
|
||||
"""消费一次性 Timing Gate continue 状态,并返回原因描述。"""
|
||||
|
||||
if not self._force_continue_until_reply:
|
||||
return
|
||||
if not self._force_next_timing_continue:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 已成功发送 reply,恢复 Timing Gate;"
|
||||
f"触发原因={self._force_continue_trigger_reason or '未知'} "
|
||||
f"触发消息编号={self._force_continue_trigger_message_id or 'unknown'}"
|
||||
)
|
||||
self._force_continue_until_reply = False
|
||||
self._force_continue_trigger_message_id = ""
|
||||
self._force_continue_trigger_reason = ""
|
||||
|
||||
def _build_force_continue_timing_reason(self) -> str:
|
||||
"""返回当前强制跳过 Timing Gate 的原因描述。"""
|
||||
|
||||
trigger_reason = self._force_continue_trigger_reason or "@/提及消息"
|
||||
trigger_message_id = self._force_continue_trigger_message_id or "unknown"
|
||||
return (
|
||||
trigger_reason = self._force_next_timing_reason or "@/提及消息"
|
||||
trigger_message_id = self._force_next_timing_message_id or "unknown"
|
||||
reason = (
|
||||
f"检测到新的{trigger_reason}(消息编号={trigger_message_id}),"
|
||||
"本轮直接跳过 Timing Gate 并视作 continue,直到成功发送一条 reply。"
|
||||
"本轮直接跳过 Timing Gate 并视作 continue。"
|
||||
)
|
||||
logger.info(
|
||||
f"{self.log_prefix} 已结束本次强制 continue,恢复 Timing Gate;"
|
||||
f"触发原因={trigger_reason} "
|
||||
f"触发消息编号={trigger_message_id}"
|
||||
)
|
||||
self._force_next_timing_continue = False
|
||||
self._force_next_timing_message_id = ""
|
||||
self._force_next_timing_reason = ""
|
||||
return reason
|
||||
|
||||
def _bind_planner_interrupt_flag(self, interrupt_flag: asyncio.Event) -> None:
|
||||
"""绑定当前可打断请求使用的中断标记。"""
|
||||
@@ -426,6 +437,7 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
selected_history, _ = MaisakaChatLoopService.select_llm_context_messages(
|
||||
self._chat_history,
|
||||
request_kind=request_kind,
|
||||
max_context_size=context_message_limit,
|
||||
)
|
||||
sub_agent_history = list(selected_history)
|
||||
@@ -447,11 +459,133 @@ class MaisakaHeartFlowChatting:
|
||||
tool_definitions=[] if tool_definitions is None else tool_definitions,
|
||||
)
|
||||
|
||||
def set_current_action_tool_names(self, tool_names: Sequence[str]) -> None:
|
||||
"""记录当前 Action Loop 已实际暴露给 planner 的工具名集合。"""
|
||||
|
||||
self._current_action_tool_names = {tool_name for tool_name in tool_names if str(tool_name).strip()}
|
||||
|
||||
def is_action_tool_currently_available(self, tool_name: str) -> bool:
|
||||
"""判断指定工具在当前 Action Loop 轮次中是否真实可用。"""
|
||||
|
||||
normalized_name = str(tool_name).strip()
|
||||
return bool(normalized_name) and normalized_name in self._current_action_tool_names
|
||||
|
||||
def update_deferred_tool_specs(self, deferred_tool_specs: Sequence[ToolSpec]) -> None:
|
||||
"""刷新当前会话的 deferred tools 池,并清理失效的已发现工具。"""
|
||||
|
||||
next_specs_by_name: dict[str, ToolSpec] = {}
|
||||
for tool_spec in deferred_tool_specs:
|
||||
normalized_name = tool_spec.name.strip()
|
||||
if not normalized_name:
|
||||
continue
|
||||
next_specs_by_name[normalized_name] = tool_spec
|
||||
|
||||
self.deferred_tool_specs_by_name = next_specs_by_name
|
||||
self.discovered_tool_names.intersection_update(next_specs_by_name.keys())
|
||||
|
||||
def get_discovered_deferred_tool_specs(self) -> list[ToolSpec]:
|
||||
"""返回当前会话中已发现、且仍然有效的 deferred tools。"""
|
||||
|
||||
return [
|
||||
tool_spec
|
||||
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items()
|
||||
if tool_name in self.discovered_tool_names
|
||||
]
|
||||
|
||||
def build_deferred_tools_reminder(self) -> str:
|
||||
"""构造供 planner 使用的 deferred tools 提示消息。"""
|
||||
|
||||
undiscovered_tool_specs = [
|
||||
tool_spec
|
||||
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items()
|
||||
if tool_name not in self.discovered_tool_names
|
||||
]
|
||||
if not undiscovered_tool_specs:
|
||||
return ""
|
||||
|
||||
tool_lines: list[str] = []
|
||||
for index, tool_spec in enumerate(undiscovered_tool_specs, start=1):
|
||||
tool_name = tool_spec.name.strip()
|
||||
tool_description = tool_spec.brief_description.strip()
|
||||
if tool_description:
|
||||
tool_lines.append(f"{index}. {tool_name}: {tool_description}")
|
||||
else:
|
||||
tool_lines.append(f"{index}. {tool_name}")
|
||||
|
||||
reminder_lines = [
|
||||
"<system-reminder>",
|
||||
"以下工具当前未直接暴露给你,但可以通过 tool_search 工具发现并在后续轮次中使用:",
|
||||
*tool_lines,
|
||||
"",
|
||||
"如需其中某个工具,请先调用 tool_search。tool_search 只负责发现工具,不直接执行业务。",
|
||||
"</system-reminder>",
|
||||
]
|
||||
return "\n".join(reminder_lines)
|
||||
|
||||
def search_deferred_tool_specs(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
limit: int,
|
||||
) -> list[ToolSpec]:
|
||||
"""按名称或简要描述搜索 deferred tools。"""
|
||||
|
||||
normalized_query = " ".join(query.lower().split()).strip()
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
scored_matches: list[tuple[int, str, ToolSpec]] = []
|
||||
query_terms = [term for term in normalized_query.replace("_", " ").replace("-", " ").split() if term]
|
||||
for tool_name, tool_spec in self.deferred_tool_specs_by_name.items():
|
||||
lower_name = tool_name.lower()
|
||||
lower_description = tool_spec.brief_description.lower()
|
||||
score = 0
|
||||
|
||||
if normalized_query == lower_name:
|
||||
score += 1000
|
||||
if lower_name.startswith(normalized_query):
|
||||
score += 300
|
||||
if normalized_query in lower_name:
|
||||
score += 200
|
||||
if normalized_query in lower_description:
|
||||
score += 100
|
||||
|
||||
for query_term in query_terms:
|
||||
if query_term in lower_name:
|
||||
score += 25
|
||||
if query_term in lower_description:
|
||||
score += 10
|
||||
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
scored_matches.append((score, tool_name, tool_spec))
|
||||
|
||||
scored_matches.sort(key=lambda item: (-item[0], item[1]))
|
||||
return [tool_spec for _, _, tool_spec in scored_matches[: max(1, limit)]]
|
||||
|
||||
def discover_deferred_tools(self, tool_names: Sequence[str]) -> list[str]:
|
||||
"""将指定 deferred tools 标记为已发现,并返回本次新发现的工具名。"""
|
||||
|
||||
newly_discovered_tool_names: list[str] = []
|
||||
for raw_tool_name in tool_names:
|
||||
normalized_name = str(raw_tool_name).strip()
|
||||
if not normalized_name or normalized_name not in self.deferred_tool_specs_by_name:
|
||||
continue
|
||||
if normalized_name in self.discovered_tool_names:
|
||||
continue
|
||||
self.discovered_tool_names.add(normalized_name)
|
||||
newly_discovered_tool_names.append(normalized_name)
|
||||
return newly_discovered_tool_names
|
||||
|
||||
def _has_pending_messages(self) -> bool:
|
||||
return self._last_processed_index < len(self.message_cache)
|
||||
|
||||
def _schedule_message_turn(self) -> None:
|
||||
"""为当前待处理消息安排一次内部 turn。"""
|
||||
if self._agent_state == self._STATE_WAIT:
|
||||
return
|
||||
|
||||
if not self._has_pending_messages() or self._message_turn_scheduled:
|
||||
return
|
||||
|
||||
@@ -531,8 +665,9 @@ class MaisakaHeartFlowChatting:
|
||||
def _enter_wait_state(self, seconds: Optional[float] = None, tool_call_id: Optional[str] = None) -> None:
|
||||
"""切换到等待状态。"""
|
||||
self._agent_state = self._STATE_WAIT
|
||||
self._wait_until = None if seconds is None else time.time() + seconds
|
||||
self._pending_wait_tool_call_id = tool_call_id
|
||||
self._message_turn_scheduled = False
|
||||
self._cancel_deferred_message_turn_task()
|
||||
self._cancel_wait_timeout_task()
|
||||
if seconds is not None:
|
||||
self._wait_timeout_task = asyncio.create_task(
|
||||
@@ -542,7 +677,6 @@ class MaisakaHeartFlowChatting:
|
||||
def _enter_stop_state(self) -> None:
|
||||
"""切换到停止状态。"""
|
||||
self._agent_state = self._STATE_STOP
|
||||
self._wait_until = None
|
||||
self._pending_wait_tool_call_id = None
|
||||
self._cancel_wait_timeout_task()
|
||||
|
||||
@@ -567,7 +701,6 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
logger.info(f"{self.log_prefix} Maisaka 等待已超时")
|
||||
self._agent_state = self._STATE_RUNNING
|
||||
self._wait_until = None
|
||||
await self._internal_turn_queue.put("timeout")
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
@@ -616,7 +749,7 @@ class MaisakaHeartFlowChatting:
|
||||
return True
|
||||
|
||||
async def _trigger_expression_learning(self, messages: list[SessionMessage]) -> None:
|
||||
"""?????????????????"""
|
||||
"""触发表达方式学习"""
|
||||
pending_count = self._expression_learner.get_pending_count(self.message_cache)
|
||||
if not self._should_trigger_learning(
|
||||
enabled=self._enable_expression_learning,
|
||||
@@ -629,21 +762,21 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
self._last_expression_extraction_time = time.time()
|
||||
logger.info(
|
||||
f"{self.log_prefix} ??????: "
|
||||
f"??????={len(messages)} ??????={pending_count} "
|
||||
f"?????={len(self.message_cache)} "
|
||||
f"??????={self._enable_jargon_learning}"
|
||||
f"{self.log_prefix} 触发表达方式学习: "
|
||||
f"消息数量={len(messages)} 待处理消息数量={pending_count} "
|
||||
f"缓存总量={len(self.message_cache)} "
|
||||
f"是否启用黑话学习={self._enable_jargon_learning}"
|
||||
)
|
||||
|
||||
try:
|
||||
jargon_miner = self._jargon_miner if self._enable_jargon_learning else None
|
||||
learnt_style = await self._expression_learner.learn(self.message_cache, jargon_miner)
|
||||
if learnt_style:
|
||||
logger.info(f"{self.log_prefix} ???????")
|
||||
logger.info(f"{self.log_prefix} 表达方式学习成功")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} ???????????????")
|
||||
logger.debug(f"{self.log_prefix} 表达方式学习失败")
|
||||
except Exception:
|
||||
logger.exception(f"{self.log_prefix} ??????")
|
||||
logger.exception(f"{self.log_prefix} 表达方式学习异常")
|
||||
|
||||
async def _init_mcp(self) -> None:
|
||||
"""初始化 MCP 工具并注册到统一工具层。"""
|
||||
@@ -655,12 +788,12 @@ class MaisakaHeartFlowChatting:
|
||||
host_callbacks=self._mcp_host_bridge.build_callbacks(),
|
||||
)
|
||||
if self._mcp_manager is None:
|
||||
logger.info(f"{self.log_prefix} MCP 管理器不可用")
|
||||
logger.info(f"{self.log_prefix} Maisaka MCP 管理器不可用")
|
||||
return
|
||||
|
||||
mcp_tool_specs = self._mcp_manager.get_tool_specs()
|
||||
if not mcp_tool_specs:
|
||||
logger.info(f"{self.log_prefix} 没有可供 Maisaka 使用的 MCP 工具")
|
||||
logger.info(f"{self.log_prefix} Maisaka 没有可供使用的 MCP 工具")
|
||||
return
|
||||
|
||||
self._tool_registry.register_provider(MCPToolProvider(self._mcp_manager))
|
||||
@@ -694,6 +827,7 @@ class MaisakaHeartFlowChatting:
|
||||
self,
|
||||
*,
|
||||
cycle_id: Optional[int] = None,
|
||||
time_records: Optional[dict[str, float]] = None,
|
||||
timing_selected_history_count: Optional[int] = None,
|
||||
timing_prompt_tokens: Optional[int] = None,
|
||||
timing_action: str = "",
|
||||
@@ -709,6 +843,7 @@ class MaisakaHeartFlowChatting:
|
||||
planner_tool_results: Optional[list[str]] = None,
|
||||
planner_tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
planner_prompt_section: Optional[RenderableType] = None,
|
||||
planner_extra_lines: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
"""在终端展示当前聊天流本轮 cycle 的最终结果。"""
|
||||
if not global_config.debug.show_maisaka_thinking:
|
||||
@@ -721,6 +856,7 @@ class MaisakaHeartFlowChatting:
|
||||
if cycle_id is not None:
|
||||
body_lines.append(f"循环编号:{cycle_id}")
|
||||
|
||||
panel_subtitle = self._build_cycle_time_records_text(time_records or {})
|
||||
renderables: list[RenderableType] = [Text("\n".join(body_lines))]
|
||||
timing_panel = self._build_cycle_stage_panel(
|
||||
title="Timing Gate",
|
||||
@@ -728,33 +864,49 @@ class MaisakaHeartFlowChatting:
|
||||
selected_history_count=timing_selected_history_count,
|
||||
prompt_tokens=timing_prompt_tokens,
|
||||
response_text=timing_response,
|
||||
tool_calls=timing_tool_calls,
|
||||
tool_results=timing_tool_results,
|
||||
tool_detail_results=timing_tool_detail_results,
|
||||
prompt_section=timing_prompt_section,
|
||||
extra_lines=[f"门控动作:{timing_action}"] if timing_action.strip() else None,
|
||||
)
|
||||
if timing_panel is not None:
|
||||
renderables.append(timing_panel)
|
||||
|
||||
timing_tool_cards = self._build_tool_activity_cards(
|
||||
stage_title="Timing Tool",
|
||||
tool_calls=timing_tool_calls,
|
||||
tool_results=timing_tool_results,
|
||||
tool_detail_results=timing_tool_detail_results,
|
||||
planner_style=False,
|
||||
)
|
||||
if timing_tool_cards:
|
||||
renderables.extend(timing_tool_cards)
|
||||
|
||||
planner_panel = self._build_cycle_stage_panel(
|
||||
title="Planner",
|
||||
border_style="green",
|
||||
selected_history_count=planner_selected_history_count,
|
||||
prompt_tokens=planner_prompt_tokens,
|
||||
response_text=planner_response,
|
||||
tool_calls=planner_tool_calls,
|
||||
tool_results=planner_tool_results,
|
||||
tool_detail_results=planner_tool_detail_results,
|
||||
prompt_section=planner_prompt_section,
|
||||
extra_lines=planner_extra_lines,
|
||||
)
|
||||
if planner_panel is not None:
|
||||
renderables.append(planner_panel)
|
||||
|
||||
planner_tool_cards = self._build_tool_activity_cards(
|
||||
stage_title="Planner Tool",
|
||||
tool_calls=planner_tool_calls,
|
||||
tool_results=planner_tool_results,
|
||||
tool_detail_results=planner_tool_detail_results,
|
||||
planner_style=True,
|
||||
)
|
||||
if planner_tool_cards:
|
||||
renderables.extend(planner_tool_cards)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Group(*renderables),
|
||||
title="MaiSaka 循环",
|
||||
subtitle=panel_subtitle,
|
||||
border_style="bright_blue",
|
||||
padding=(0, 1),
|
||||
)
|
||||
@@ -768,9 +920,6 @@ class MaisakaHeartFlowChatting:
|
||||
selected_history_count: Optional[int],
|
||||
prompt_tokens: Optional[int],
|
||||
response_text: str = "",
|
||||
tool_calls: Optional[list[Any]] = None,
|
||||
tool_results: Optional[list[str]] = None,
|
||||
tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
prompt_section: Optional[RenderableType] = None,
|
||||
extra_lines: Optional[list[str]] = None,
|
||||
) -> Optional[Panel]:
|
||||
@@ -780,9 +929,6 @@ class MaisakaHeartFlowChatting:
|
||||
selected_history_count is not None,
|
||||
prompt_tokens is not None,
|
||||
bool(response_text.strip()),
|
||||
bool(tool_calls),
|
||||
bool(tool_results),
|
||||
bool(tool_detail_results),
|
||||
prompt_section is not None,
|
||||
bool(extra_lines),
|
||||
])
|
||||
@@ -809,40 +955,11 @@ class MaisakaHeartFlowChatting:
|
||||
Panel(
|
||||
Text(normalized_response),
|
||||
title="Maisaka 返回",
|
||||
border_style="green",
|
||||
border_style=border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
normalized_tool_calls = build_tool_call_summary_lines(tool_calls or [])
|
||||
if normalized_tool_calls:
|
||||
renderables.append(
|
||||
Panel(
|
||||
Text("\n".join(normalized_tool_calls)),
|
||||
title="工具调用",
|
||||
border_style="magenta",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
normalized_tool_results = self._filter_redundant_tool_results(
|
||||
tool_results=tool_results or [],
|
||||
tool_detail_results=tool_detail_results or [],
|
||||
)
|
||||
if normalized_tool_results:
|
||||
renderables.append(
|
||||
Panel(
|
||||
Text("\n".join(normalized_tool_results)),
|
||||
title="工具结果",
|
||||
border_style="yellow",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
detail_panels = self._build_tool_detail_panels(tool_detail_results or [])
|
||||
if detail_panels:
|
||||
renderables.extend(detail_panels)
|
||||
|
||||
return Panel(
|
||||
Group(*renderables),
|
||||
title=title,
|
||||
@@ -850,6 +967,75 @@ class MaisakaHeartFlowChatting:
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
def _build_tool_activity_cards(
|
||||
self,
|
||||
*,
|
||||
stage_title: str,
|
||||
tool_calls: Optional[list[Any]] = None,
|
||||
tool_results: Optional[list[str]] = None,
|
||||
tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
planner_style: bool = False,
|
||||
) -> list[RenderableType]:
|
||||
"""构建与阶段同级的工具执行卡片列表。"""
|
||||
|
||||
detail_results = tool_detail_results or []
|
||||
cards = self._build_tool_detail_cards(
|
||||
detail_results,
|
||||
stage_title=stage_title,
|
||||
planner_style=planner_style,
|
||||
)
|
||||
if cards:
|
||||
return cards
|
||||
|
||||
# 兼容旧数据结构:若尚无 detail,则降级为简单文本卡片。
|
||||
fallback_lines = self._filter_redundant_tool_results(
|
||||
tool_results=tool_results or [],
|
||||
tool_detail_results=detail_results,
|
||||
)
|
||||
if not fallback_lines and tool_calls:
|
||||
fallback_lines = build_tool_call_summary_lines(tool_calls)
|
||||
if not fallback_lines:
|
||||
return []
|
||||
|
||||
fallback_border_style = "yellow"
|
||||
return [
|
||||
Panel(
|
||||
Text("\n".join(fallback_lines)),
|
||||
title=stage_title,
|
||||
border_style=fallback_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _build_cycle_time_records_text(time_records: dict[str, float]) -> str:
|
||||
"""构建循环最外层面板展示的阶段耗时文本。"""
|
||||
|
||||
if not time_records:
|
||||
return "流程耗时:无"
|
||||
|
||||
label_map = {
|
||||
"timing_gate": "Timing Gate",
|
||||
"planner": "Planner",
|
||||
"tool_calls": "工具执行",
|
||||
}
|
||||
ordered_keys = ["timing_gate", "planner", "tool_calls"]
|
||||
|
||||
parts: list[str] = []
|
||||
for key in ordered_keys:
|
||||
duration = time_records.get(key)
|
||||
if isinstance(duration, (int, float)):
|
||||
parts.append(f"{label_map.get(key, key)} {float(duration):.2f} s")
|
||||
|
||||
for key, duration in time_records.items():
|
||||
if key in ordered_keys or not isinstance(duration, (int, float)):
|
||||
continue
|
||||
parts.append(f"{label_map.get(key, key)} {float(duration):.2f} s")
|
||||
|
||||
if not parts:
|
||||
return "流程耗时:无"
|
||||
return "流程耗时:" + " | ".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _filter_redundant_tool_results(
|
||||
*,
|
||||
@@ -941,7 +1127,9 @@ class MaisakaHeartFlowChatting:
|
||||
*,
|
||||
tool_name: str,
|
||||
prompt_text: str,
|
||||
request_messages: Optional[list[Any]] = None,
|
||||
tool_call_id: str,
|
||||
border_style: str = "bright_yellow",
|
||||
) -> Panel:
|
||||
"""将工具 prompt 渲染为可点击查看的预览入口。"""
|
||||
|
||||
@@ -950,6 +1138,26 @@ class MaisakaHeartFlowChatting:
|
||||
if tool_call_id:
|
||||
subtitle += f"\n调用ID: {tool_call_id}"
|
||||
|
||||
if isinstance(request_messages, list) and request_messages:
|
||||
try:
|
||||
normalized_messages = deserialize_prompt_messages(request_messages)
|
||||
except Exception as exc:
|
||||
logger.warning(f"工具 {tool_name} 的 request_messages 无法反序列化,已回退为文本预览: {exc}")
|
||||
else:
|
||||
return Panel(
|
||||
PromptCLIVisualizer.build_prompt_access_panel(
|
||||
normalized_messages,
|
||||
category=labels["prompt_category"],
|
||||
chat_id=self.session_id,
|
||||
request_kind=labels["request_kind"],
|
||||
selection_reason=subtitle,
|
||||
image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy",
|
||||
),
|
||||
title=labels["prompt_title"],
|
||||
border_style=border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
return Panel(
|
||||
PromptCLIVisualizer.build_text_access_panel(
|
||||
prompt_text,
|
||||
@@ -959,116 +1167,235 @@ class MaisakaHeartFlowChatting:
|
||||
subtitle=subtitle,
|
||||
),
|
||||
title=labels["prompt_title"],
|
||||
border_style="bright_yellow",
|
||||
border_style=border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
def _build_tool_detail_panels(self, tool_detail_results: list[dict[str, Any]]) -> list[RenderableType]:
|
||||
"""将 tool monitor detail 渲染为 CLI 详情卡片。"""
|
||||
def _normalize_tool_card_body_lines(self, body: Any) -> list[str]:
|
||||
"""将工具卡片正文规范化为行列表。"""
|
||||
|
||||
if isinstance(body, str):
|
||||
return [line for line in body.splitlines() if line.strip()]
|
||||
if isinstance(body, list):
|
||||
return [
|
||||
str(item).strip()
|
||||
for item in body
|
||||
if str(item).strip()
|
||||
]
|
||||
return []
|
||||
|
||||
def _build_custom_tool_sub_cards(
|
||||
self,
|
||||
sub_cards: Any,
|
||||
*,
|
||||
default_border_style: str,
|
||||
) -> list[RenderableType]:
|
||||
"""构建工具自定义子卡片。"""
|
||||
|
||||
if not isinstance(sub_cards, list):
|
||||
return []
|
||||
|
||||
renderables: list[RenderableType] = []
|
||||
for sub_card in sub_cards:
|
||||
if not isinstance(sub_card, dict):
|
||||
continue
|
||||
title = str(sub_card.get("title") or "").strip() or "附加信息"
|
||||
border_style = str(sub_card.get("border_style") or "").strip() or default_border_style
|
||||
body_lines = self._normalize_tool_card_body_lines(
|
||||
sub_card.get("body_lines", sub_card.get("content", ""))
|
||||
)
|
||||
if not body_lines:
|
||||
continue
|
||||
renderables.append(
|
||||
Panel(
|
||||
Text("\n".join(body_lines)),
|
||||
title=title,
|
||||
border_style=border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
return renderables
|
||||
|
||||
def _build_default_tool_detail_parts(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
tool_call_id: str,
|
||||
tool_args: Any,
|
||||
summary: str,
|
||||
duration_ms: Any,
|
||||
detail: dict[str, Any],
|
||||
planner_style: bool,
|
||||
) -> list[RenderableType]:
|
||||
"""构建工具卡片默认内容块。"""
|
||||
|
||||
argument_border_style = "yellow"
|
||||
metrics_border_style = "bright_yellow"
|
||||
prompt_border_style = "bright_yellow"
|
||||
reasoning_border_style = "yellow"
|
||||
output_border_style = "bright_yellow"
|
||||
extra_info_border_style = "yellow"
|
||||
detail_labels = self._get_tool_detail_labels(tool_name)
|
||||
|
||||
parts: list[RenderableType] = []
|
||||
header_lines: list[str] = []
|
||||
if summary:
|
||||
header_lines.append(summary)
|
||||
if tool_call_id:
|
||||
header_lines.append(f"调用ID:{tool_call_id}")
|
||||
if isinstance(duration_ms, (int, float)):
|
||||
header_lines.append(f"执行耗时:{round(float(duration_ms), 2)} ms")
|
||||
if header_lines:
|
||||
parts.append(Text("\n".join(header_lines)))
|
||||
|
||||
if isinstance(tool_args, dict) and tool_args:
|
||||
parts.append(
|
||||
Panel(
|
||||
Pretty(tool_args, expand_all=True),
|
||||
title="工具参数",
|
||||
border_style=argument_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
metrics = detail.get("metrics")
|
||||
if isinstance(metrics, dict):
|
||||
metrics_text = self._build_tool_metrics_text(metrics)
|
||||
if metrics_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(metrics_text),
|
||||
title="执行指标",
|
||||
border_style=metrics_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
prompt_text = str(detail.get("prompt_text") or "").strip()
|
||||
if prompt_text:
|
||||
parts.append(
|
||||
self._build_tool_prompt_access_panel(
|
||||
tool_name=tool_name,
|
||||
prompt_text=prompt_text,
|
||||
request_messages=detail.get("request_messages") if isinstance(detail.get("request_messages"), list) else None,
|
||||
tool_call_id=tool_call_id,
|
||||
border_style=prompt_border_style,
|
||||
)
|
||||
)
|
||||
|
||||
reasoning_text = str(detail.get("reasoning_text") or "").strip()
|
||||
if reasoning_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(reasoning_text),
|
||||
title=detail_labels["reasoning_title"],
|
||||
border_style=reasoning_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
output_text = str(detail.get("output_text") or "").strip()
|
||||
if output_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(output_text),
|
||||
title=detail_labels["output_title"],
|
||||
border_style=output_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
extra_sections = detail.get("extra_sections")
|
||||
if isinstance(extra_sections, list):
|
||||
for section in extra_sections:
|
||||
if not isinstance(section, dict):
|
||||
continue
|
||||
section_title = str(section.get("title") or "").strip() or "附加信息"
|
||||
section_content = str(section.get("content") or "").strip()
|
||||
if not section_content:
|
||||
continue
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(section_content),
|
||||
title=section_title,
|
||||
border_style=extra_info_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
return parts
|
||||
|
||||
def _build_tool_detail_cards(
|
||||
self,
|
||||
tool_detail_results: list[dict[str, Any]],
|
||||
*,
|
||||
stage_title: str,
|
||||
planner_style: bool = False,
|
||||
) -> list[RenderableType]:
|
||||
"""将 tool monitor detail 渲染为与 Planner/Timing 平级的工具卡片。"""
|
||||
|
||||
detail_panel_border_style = "yellow"
|
||||
sub_card_border_style = "bright_yellow"
|
||||
|
||||
panels: list[RenderableType] = []
|
||||
for tool_result in tool_detail_results:
|
||||
detail = tool_result.get("detail")
|
||||
if not isinstance(detail, dict) or not detail:
|
||||
continue
|
||||
|
||||
detail_dict = detail if isinstance(detail, dict) else {}
|
||||
tool_name = str(tool_result.get("tool_name") or "unknown").strip() or "unknown"
|
||||
detail_labels = self._get_tool_detail_labels(tool_name)
|
||||
tool_title = str(tool_result.get("tool_title") or "").strip() or tool_name
|
||||
tool_call_id = str(tool_result.get("tool_call_id") or "").strip()
|
||||
tool_args = tool_result.get("tool_args")
|
||||
summary = str(tool_result.get("summary") or "").strip()
|
||||
duration_ms = tool_result.get("duration_ms")
|
||||
custom_card = tool_result.get("card")
|
||||
|
||||
parts: list[RenderableType] = []
|
||||
header_lines: list[str] = []
|
||||
if summary:
|
||||
header_lines.append(summary)
|
||||
if tool_call_id:
|
||||
header_lines.append(f"调用ID:{tool_call_id}")
|
||||
if isinstance(duration_ms, (int, float)):
|
||||
header_lines.append(f"执行耗时:{round(float(duration_ms), 2)} ms")
|
||||
if header_lines:
|
||||
parts.append(Text("\n".join(header_lines)))
|
||||
|
||||
if isinstance(tool_args, dict) and tool_args:
|
||||
parts.append(
|
||||
Panel(
|
||||
Pretty(tool_args, expand_all=True),
|
||||
title="工具参数",
|
||||
border_style="cyan",
|
||||
padding=(0, 1),
|
||||
)
|
||||
custom_title = ""
|
||||
card_border_style = detail_panel_border_style
|
||||
replace_default_children = False
|
||||
if isinstance(custom_card, dict):
|
||||
custom_title = str(custom_card.get("title") or "").strip()
|
||||
card_border_style = str(custom_card.get("border_style") or "").strip() or detail_panel_border_style
|
||||
replace_default_children = bool(custom_card.get("replace_default_children", False))
|
||||
custom_body_lines = self._normalize_tool_card_body_lines(
|
||||
custom_card.get("body_lines", custom_card.get("content", ""))
|
||||
)
|
||||
if custom_body_lines:
|
||||
parts.append(Text("\n".join(custom_body_lines)))
|
||||
|
||||
metrics = detail.get("metrics")
|
||||
if isinstance(metrics, dict):
|
||||
metrics_text = self._build_tool_metrics_text(metrics)
|
||||
if metrics_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(metrics_text),
|
||||
title="执行指标",
|
||||
border_style="bright_cyan",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
prompt_text = str(detail.get("prompt_text") or "").strip()
|
||||
if prompt_text:
|
||||
parts.append(
|
||||
self._build_tool_prompt_access_panel(
|
||||
if not replace_default_children:
|
||||
parts.extend(
|
||||
self._build_default_tool_detail_parts(
|
||||
tool_name=tool_name,
|
||||
prompt_text=prompt_text,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_args=tool_args,
|
||||
summary=summary,
|
||||
duration_ms=duration_ms,
|
||||
detail=detail_dict,
|
||||
planner_style=planner_style,
|
||||
)
|
||||
)
|
||||
|
||||
reasoning_text = str(detail.get("reasoning_text") or "").strip()
|
||||
if reasoning_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(reasoning_text),
|
||||
title=detail_labels["reasoning_title"],
|
||||
border_style="magenta",
|
||||
padding=(0, 1),
|
||||
if isinstance(custom_card, dict):
|
||||
parts.extend(
|
||||
self._build_custom_tool_sub_cards(
|
||||
custom_card.get("sub_cards"),
|
||||
default_border_style=sub_card_border_style,
|
||||
)
|
||||
)
|
||||
|
||||
output_text = str(detail.get("output_text") or "").strip()
|
||||
if output_text:
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(output_text),
|
||||
title=detail_labels["output_title"],
|
||||
border_style="green",
|
||||
padding=(0, 1),
|
||||
)
|
||||
parts.extend(
|
||||
self._build_custom_tool_sub_cards(
|
||||
tool_result.get("sub_cards"),
|
||||
default_border_style=sub_card_border_style,
|
||||
)
|
||||
|
||||
extra_sections = detail.get("extra_sections")
|
||||
if isinstance(extra_sections, list):
|
||||
for section in extra_sections:
|
||||
if not isinstance(section, dict):
|
||||
continue
|
||||
section_title = str(section.get("title") or "").strip() or "附加信息"
|
||||
section_content = str(section.get("content") or "").strip()
|
||||
if not section_content:
|
||||
continue
|
||||
parts.append(
|
||||
Panel(
|
||||
Text(section_content),
|
||||
title=section_title,
|
||||
border_style="white",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if parts:
|
||||
panels.append(
|
||||
Panel(
|
||||
Group(*parts),
|
||||
title=f"{tool_name} 工具详情",
|
||||
border_style="yellow",
|
||||
title=custom_title or f"{stage_title} · {tool_title}",
|
||||
border_style=card_border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -521,9 +521,7 @@ class MCPHostLLMBridge:
|
||||
tool_definitions.append(
|
||||
{
|
||||
"name": tool_name,
|
||||
"description": "\n\n".join(
|
||||
part for part in [brief_description, detailed_description] if part.strip()
|
||||
).strip(),
|
||||
"description": brief_description,
|
||||
"parameters_schema": parameters_schema or {"type": "object", "properties": {}},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -672,6 +672,32 @@ class ComponentQueryService:
|
||||
collected_specs[entry.name] = self._build_tool_spec(entry) # type: ignore[arg-type]
|
||||
return collected_specs
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_context_payload(context: Optional[ToolExecutionContext]) -> Dict[str, Any]:
|
||||
"""提取插件工具可复用的会话上下文字段。"""
|
||||
|
||||
if context is None:
|
||||
return {}
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
stream_id = str(context.stream_id or context.session_id or "").strip()
|
||||
if stream_id:
|
||||
payload["stream_id"] = stream_id
|
||||
payload["chat_id"] = stream_id
|
||||
|
||||
anchor_message = context.metadata.get("anchor_message")
|
||||
message_info = getattr(anchor_message, "message_info", None)
|
||||
group_info = getattr(message_info, "group_info", None)
|
||||
user_info = getattr(message_info, "user_info", None)
|
||||
|
||||
group_id = str(getattr(group_info, "group_id", "") or "").strip()
|
||||
user_id = str(getattr(user_info, "user_id", "") or "").strip()
|
||||
if group_id:
|
||||
payload["group_id"] = group_id
|
||||
if user_id:
|
||||
payload["user_id"] = user_id
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_invocation_payload(
|
||||
entry: "ToolEntry",
|
||||
@@ -690,16 +716,27 @@ class ComponentQueryService:
|
||||
"""
|
||||
|
||||
payload = dict(invocation.arguments)
|
||||
context_payload = ComponentQueryService._build_tool_context_payload(context)
|
||||
if entry.invoke_method == "plugin.invoke_action":
|
||||
stream_id = context.stream_id if context is not None else invocation.stream_id
|
||||
stream_id = str(
|
||||
context_payload.get("stream_id")
|
||||
or (context.stream_id if context is not None else invocation.stream_id)
|
||||
or invocation.stream_id
|
||||
).strip()
|
||||
reasoning = context.reasoning if context is not None else invocation.reasoning
|
||||
payload = {
|
||||
**payload,
|
||||
**{key: value for key, value in context_payload.items() if key not in payload or not payload.get(key)},
|
||||
"stream_id": stream_id,
|
||||
"chat_id": stream_id,
|
||||
"reasoning": reasoning,
|
||||
"action_data": dict(invocation.arguments),
|
||||
}
|
||||
return payload
|
||||
|
||||
for key, value in context_payload.items():
|
||||
if key not in payload or not payload.get(key):
|
||||
payload[key] = value
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -11,8 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
@@ -20,8 +19,8 @@ from src.common.logger import get_logger
|
||||
from src.core.types import ActionInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.common.data_models.planned_action_data_models import PlannedAction
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -38,7 +37,7 @@ def _get_replyer(
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
||||
) -> Optional[MaisakaReplyGenerator]:
|
||||
"""获取回复器对象"""
|
||||
if not chat_id and not chat_stream:
|
||||
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||
@@ -100,7 +99,7 @@ async def generate_reply(
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
|
||||
chosen_actions: Optional[List["PlannedAction"]] = None,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
|
||||
@@ -267,6 +267,46 @@ def _parse_data_url_image(image_url: str) -> Tuple[str, str]:
|
||||
return image_format, image_base64
|
||||
|
||||
|
||||
def _append_image_content(message_builder: MessageBuilder, content_item: Any) -> bool:
|
||||
"""向消息构建器追加图片片段。
|
||||
|
||||
兼容两种输入格式:
|
||||
1. 旧序列化格式中的 `(image_format, image_base64)` 元组。
|
||||
2. 标准字典片段中的 Data URL 或 `image_format`/`image_base64` 字段。
|
||||
"""
|
||||
|
||||
if isinstance(content_item, (tuple, list)) and len(content_item) == 2:
|
||||
image_format, image_base64 = content_item
|
||||
if not isinstance(image_format, str) or not isinstance(image_base64, str):
|
||||
raise ValueError("图片元组片段必须包含字符串类型的 image_format 和 image_base64")
|
||||
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
return True
|
||||
|
||||
if not isinstance(content_item, dict):
|
||||
return False
|
||||
|
||||
part_type = str(content_item.get("type", "text")).strip().lower()
|
||||
if part_type not in {"image", "image_url", "input_image"}:
|
||||
return False
|
||||
|
||||
image_url = content_item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
image_url = image_url.get("url")
|
||||
if isinstance(image_url, str):
|
||||
image_format, image_base64 = _parse_data_url_image(image_url)
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
return True
|
||||
|
||||
image_format = content_item.get("image_format")
|
||||
image_base64 = content_item.get("image_base64")
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
return True
|
||||
|
||||
raise ValueError("图片片段缺少可识别的图片数据")
|
||||
|
||||
|
||||
def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None:
|
||||
"""将原始消息内容追加到内部消息构建器。
|
||||
|
||||
@@ -293,8 +333,10 @@ def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None
|
||||
if isinstance(content_item, str):
|
||||
message_builder.add_text_content(content_item)
|
||||
continue
|
||||
if _append_image_content(message_builder, content_item):
|
||||
continue
|
||||
if not isinstance(content_item, dict):
|
||||
raise ValueError("消息内容列表中仅支持字符串或字典片段")
|
||||
raise ValueError("消息内容列表中仅支持字符串、图片元组或字典片段")
|
||||
|
||||
part_type = str(content_item.get("type", "text")).strip().lower()
|
||||
if part_type == "text":
|
||||
@@ -304,22 +346,6 @@ def _append_content_parts(message_builder: MessageBuilder, content: Any) -> None
|
||||
message_builder.add_text_content(text_content)
|
||||
continue
|
||||
|
||||
if part_type in {"image", "image_url", "input_image"}:
|
||||
image_url = content_item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
image_url = image_url.get("url")
|
||||
if isinstance(image_url, str):
|
||||
image_format, image_base64 = _parse_data_url_image(image_url)
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
|
||||
image_format = content_item.get("image_format")
|
||||
image_base64 = content_item.get("image_base64")
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
message_builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
raise ValueError("图片片段缺少可识别的图片数据")
|
||||
|
||||
raise ValueError(f"不支持的消息片段类型: {part_type}")
|
||||
|
||||
|
||||
|
||||
@@ -326,7 +326,7 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
if emoji.is_registered:
|
||||
return EmojiUpdateResponse(success=True, message="??????????", data=emoji_to_response(emoji))
|
||||
return EmojiUpdateResponse(success=True, message="表情包已注册", data=emoji_to_response(emoji))
|
||||
|
||||
emoji.is_registered = True
|
||||
emoji.is_banned = False
|
||||
|
||||
Reference in New Issue
Block a user