feat:修复表达方式的学习和使用,用subagent使用表达
1
This commit is contained in:
@@ -2,18 +2,15 @@ import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from rich.console import Group, RenderableType
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.cli.console import console
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
GenerationMetrics,
|
||||
@@ -22,7 +19,6 @@ from src.common.data_models.reply_generation_data_models import (
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
|
||||
@@ -35,6 +31,7 @@ from src.maisaka.context_messages import (
|
||||
SessionBackedMessage,
|
||||
ToolResultMessage,
|
||||
)
|
||||
from .maisaka_expression_selector import maisaka_expression_selector
|
||||
from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content
|
||||
from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer
|
||||
|
||||
@@ -49,17 +46,8 @@ class MaisakaReplyContext:
|
||||
selected_expression_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ExpressionRecord:
|
||||
"""表达方式的轻量记录。"""
|
||||
|
||||
expression_id: Optional[int]
|
||||
situation: str
|
||||
style: str
|
||||
|
||||
|
||||
class MaisakaReplyGenerator:
|
||||
"""生成 Maisaka 的最终可见回复。"""
|
||||
"""生成 Maisaka 的最终可见回复(多模态管线)。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -75,7 +63,7 @@ class MaisakaReplyGenerator:
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构建 replyer 使用的人设描述。"""
|
||||
"""构建 replyer 使用的人设提示。"""
|
||||
try:
|
||||
bot_name = global_config.bot.nickname
|
||||
alias_names = global_config.bot.alias_names
|
||||
@@ -117,7 +105,6 @@ class MaisakaReplyGenerator:
|
||||
|
||||
@staticmethod
|
||||
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
|
||||
"""按说话人拆分用户消息。"""
|
||||
segments: List[tuple[Optional[str], str]] = []
|
||||
current_speaker: Optional[str] = None
|
||||
current_lines: List[str] = []
|
||||
@@ -139,7 +126,6 @@ class MaisakaReplyGenerator:
|
||||
return segments
|
||||
|
||||
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
|
||||
"""构建当前需要回复的目标消息摘要。"""
|
||||
if reply_message is None:
|
||||
return ""
|
||||
|
||||
@@ -155,7 +141,7 @@ class MaisakaReplyGenerator:
|
||||
f"- 目标消息ID:{target_message_id}\n"
|
||||
f"- 发送者:{sender_name}\n"
|
||||
f"- 消息内容:{target_content}\n"
|
||||
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要误把其他历史消息当成当前回复对象。"
|
||||
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要把其他历史消息当成当前回复对象。"
|
||||
)
|
||||
|
||||
def _build_system_prompt(
|
||||
@@ -164,7 +150,6 @@ class MaisakaReplyGenerator:
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> str:
|
||||
"""构建 Maisaka replyer 使用的系统提示词。"""
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
target_message_block = self._build_target_message_block(reply_message)
|
||||
|
||||
@@ -179,27 +164,25 @@ class MaisakaReplyGenerator:
|
||||
except Exception:
|
||||
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
|
||||
|
||||
extra_sections: List[str] = []
|
||||
sections: List[str] = []
|
||||
if expression_habits.strip():
|
||||
extra_sections.append(expression_habits.strip())
|
||||
sections.append(expression_habits.strip())
|
||||
if target_message_block:
|
||||
extra_sections.append(target_message_block)
|
||||
sections.append(target_message_block)
|
||||
if reply_reason.strip():
|
||||
extra_sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
if not extra_sections:
|
||||
sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
if not sections:
|
||||
return system_prompt
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(extra_sections)
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(sections)
|
||||
|
||||
def _build_reply_instruction(self) -> str:
|
||||
"""构建追加在上下文末尾的回复指令。"""
|
||||
return "请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。"
|
||||
return "请基于以上上下文,自然地继续回复。直接输出你要说的话,不需要额外解释。"
|
||||
|
||||
def _build_multimodal_user_message(
|
||||
self,
|
||||
message: SessionBackedMessage,
|
||||
default_user_name: str,
|
||||
) -> Optional[Message]:
|
||||
"""构建保留图片等多模态片段的用户消息。"""
|
||||
speaker_name, _ = parse_speaker_content(message.processed_plain_text.strip())
|
||||
visible_speaker = speaker_name or default_user_name
|
||||
|
||||
@@ -223,7 +206,6 @@ class MaisakaReplyGenerator:
|
||||
return multimodal_message.to_llm_message()
|
||||
|
||||
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
|
||||
"""将 replyer 上下文拆成多条 LLM 消息。"""
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
default_user_name = global_config.maisaka.cli_user_name.strip() or "User"
|
||||
messages: List[Message] = []
|
||||
@@ -277,7 +259,6 @@ class MaisakaReplyGenerator:
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> List[Message]:
|
||||
"""构建发给大模型的消息列表。"""
|
||||
messages: List[Message] = []
|
||||
system_prompt = self._build_system_prompt(
|
||||
reply_message=reply_message,
|
||||
@@ -293,7 +274,6 @@ class MaisakaReplyGenerator:
|
||||
|
||||
@staticmethod
|
||||
def _build_request_prompt_preview(messages: List[Message]) -> str:
|
||||
"""将消息列表转为便于调试的文本预览。"""
|
||||
preview_lines: List[str] = []
|
||||
for message in messages:
|
||||
role_name = message.role.value.capitalize()
|
||||
@@ -308,7 +288,6 @@ class MaisakaReplyGenerator:
|
||||
return "\n\n".join(preview_lines)
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
"""解析当前回复使用的会话 ID。"""
|
||||
if stream_id:
|
||||
return stream_id
|
||||
if self.chat_stream is not None:
|
||||
@@ -321,109 +300,29 @@ class MaisakaReplyGenerator:
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
stream_id: Optional[str],
|
||||
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]],
|
||||
) -> MaisakaReplyContext:
|
||||
"""在 replyer 内部构建表达习惯和黑话解释。"""
|
||||
session_id = self._resolve_session_id(stream_id)
|
||||
if not session_id:
|
||||
logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
expression_habits, selected_expression_ids = self._build_expression_habits(
|
||||
if sub_agent_runner is None:
|
||||
logger.info("表达方式选择跳过:缺少子代理执行器")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
selection_result = await maisaka_expression_selector.select_for_reply(
|
||||
session_id=session_id,
|
||||
chat_history=chat_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
sub_agent_runner=sub_agent_runner,
|
||||
)
|
||||
return MaisakaReplyContext(
|
||||
expression_habits=expression_habits,
|
||||
selected_expression_ids=selected_expression_ids,
|
||||
expression_habits=selection_result.expression_habits,
|
||||
selected_expression_ids=selection_result.selected_expression_ids,
|
||||
)
|
||||
|
||||
def _build_expression_habits(
|
||||
self,
|
||||
session_id: str,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
) -> tuple[str, List[int]]:
|
||||
"""查询并格式化适合当前会话的表达习惯。"""
|
||||
del chat_history
|
||||
del reply_message
|
||||
del reply_reason
|
||||
|
||||
expression_records = self._load_expression_records(session_id)
|
||||
if not expression_records:
|
||||
return "", []
|
||||
|
||||
lines: List[str] = []
|
||||
selected_ids: List[int] = []
|
||||
for expression in expression_records:
|
||||
if expression.expression_id is not None:
|
||||
selected_ids.append(expression.expression_id)
|
||||
lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")
|
||||
|
||||
block = "【表达习惯参考】\n" + "\n".join(lines)
|
||||
logger.info(
|
||||
f"已构建 Maisaka 表达习惯: 会话标识={session_id} "
|
||||
f"数量={len(selected_ids)} 表达编号={selected_ids!r}"
|
||||
)
|
||||
return block, selected_ids
|
||||
|
||||
def _get_related_session_ids(self, session_id: str) -> List[str]:
|
||||
"""根据表达互通组配置,解析当前会话可共享的会话 ID。"""
|
||||
related_session_ids = {session_id}
|
||||
expression_groups = global_config.expression.expression_groups
|
||||
|
||||
for expression_group in expression_groups:
|
||||
target_items = expression_group.expression_groups
|
||||
group_session_ids: set[str] = set()
|
||||
contains_current_session = False
|
||||
|
||||
for target_item in target_items:
|
||||
platform = target_item.platform.strip()
|
||||
item_id = target_item.item_id.strip()
|
||||
if not platform or not item_id:
|
||||
continue
|
||||
|
||||
rule_type = target_item.rule_type
|
||||
target_session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
group_id=item_id if rule_type == "group" else None,
|
||||
user_id=None if rule_type == "group" else item_id,
|
||||
)
|
||||
group_session_ids.add(target_session_id)
|
||||
if target_session_id == session_id:
|
||||
contains_current_session = True
|
||||
|
||||
if contains_current_session:
|
||||
related_session_ids.update(group_session_ids)
|
||||
|
||||
return list(related_session_ids)
|
||||
|
||||
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
|
||||
"""提取表达方式静态数据,避免 detached ORM 对象。"""
|
||||
related_session_ids = self._get_related_session_ids(session_id)
|
||||
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
scoped_query = base_query.where(
|
||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
|
||||
|
||||
if global_config.expression.expression_checked_only:
|
||||
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(scoped_query.limit(5)).all()
|
||||
|
||||
return [
|
||||
_ExpressionRecord(
|
||||
expression_id=expression.id,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
)
|
||||
for expression in expressions
|
||||
]
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
@@ -440,8 +339,8 @@ class MaisakaReplyGenerator:
|
||||
chat_history: Optional[List[LLMContextMessage]] = None,
|
||||
expression_habits: str = "",
|
||||
selected_expression_ids: Optional[List[int]] = None,
|
||||
sub_agent_runner: Optional[Callable[[str], Awaitable[str]]] = None,
|
||||
) -> Tuple[bool, ReplyGenerationResult]:
|
||||
"""结合上下文生成 Maisaka 的最终可见回复。"""
|
||||
del available_actions
|
||||
del chosen_actions
|
||||
del extra_info
|
||||
@@ -457,9 +356,8 @@ class MaisakaReplyGenerator:
|
||||
return False, result
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} "
|
||||
f"历史消息数={len(chat_history)} 目标消息编号="
|
||||
f"{reply_message.message_id if reply_message else None}"
|
||||
f"Maisaka 回复器开始生成: 流={stream_id} 原因={reply_reason!r} "
|
||||
f"历史条数={len(chat_history)} 目标ID={reply_message.message_id if reply_message else None}"
|
||||
)
|
||||
|
||||
filtered_history = [
|
||||
@@ -468,11 +366,8 @@ class MaisakaReplyGenerator:
|
||||
if not isinstance(message, (ReferenceMessage, ToolResultMessage))
|
||||
]
|
||||
|
||||
logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}")
|
||||
|
||||
# Validate that express_model is properly initialized
|
||||
if self.express_model is None:
|
||||
logger.error("Maisaka 回复器的回复模型未初始化")
|
||||
logger.error("回复模型未初始化")
|
||||
result.error_message = "回复模型尚未初始化"
|
||||
return False, result
|
||||
|
||||
@@ -482,10 +377,11 @@ class MaisakaReplyGenerator:
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
stream_id=stream_id,
|
||||
sub_agent_runner=sub_agent_runner,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}")
|
||||
logger.error(f"构建回复上下文失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建回复上下文失败: {exc}"
|
||||
return False, result
|
||||
|
||||
@@ -497,8 +393,7 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
f"回复上下文完成: 流={stream_id} 已选表达={result.selected_expression_ids!r}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -510,7 +405,7 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}")
|
||||
logger.error(f"构建提示词失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建提示词失败: {exc}"
|
||||
return False, result
|
||||
|
||||
@@ -528,13 +423,15 @@ class MaisakaReplyGenerator:
|
||||
category="replyer",
|
||||
chat_id=preview_chat_id,
|
||||
request_kind="replyer",
|
||||
subtitle=f"会话流标识:{preview_chat_id}",
|
||||
subtitle=f"流ID: {preview_chat_id}",
|
||||
folded=global_config.debug.fold_maisaka_thinking,
|
||||
)
|
||||
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
generation_result = await self.express_model.generate_response_with_messages(message_factory=message_factory)
|
||||
generation_result = await self.express_model.generate_response_with_messages(
|
||||
message_factory=message_factory
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Maisaka 回复器调用失败")
|
||||
result.error_message = str(exc)
|
||||
@@ -565,17 +462,15 @@ class MaisakaReplyGenerator:
|
||||
return False, result
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器生成成功: 回复文本={response_text!r} "
|
||||
f"总耗时毫秒={result.metrics.overall_ms} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
f"Maisaka 回复器生成成功: 文本={response_text!r} 总耗时ms={result.metrics.overall_ms} 已选表达={result.selected_expression_ids!r}"
|
||||
)
|
||||
if global_config.debug.show_replyer_prompt or global_config.debug.show_replyer_reasoning:
|
||||
summary_lines = [
|
||||
f"会话流标识: {preview_chat_id or 'unknown'}",
|
||||
f"总耗时: {result.metrics.overall_ms} ms",
|
||||
f"流ID: {preview_chat_id or 'unknown'}",
|
||||
f"耗时: {result.metrics.overall_ms} ms",
|
||||
]
|
||||
if result.selected_expression_ids:
|
||||
summary_lines.append(f"表达习惯编号: {result.selected_expression_ids!r}")
|
||||
summary_lines.append(f"表达编号: {result.selected_expression_ids!r}")
|
||||
|
||||
renderables: List[RenderableType] = [Text("\n".join(summary_lines))]
|
||||
if replyer_prompt_section is not None:
|
||||
@@ -584,7 +479,7 @@ class MaisakaReplyGenerator:
|
||||
renderables.append(
|
||||
Panel(
|
||||
Text(result.completion.reasoning_text),
|
||||
title="回复器思考",
|
||||
title="思考内容",
|
||||
border_style="magenta",
|
||||
padding=(0, 1),
|
||||
)
|
||||
@@ -600,7 +495,7 @@ class MaisakaReplyGenerator:
|
||||
console.print(
|
||||
Panel(
|
||||
Group(*renderables),
|
||||
title="MaiSaka 回复器结果",
|
||||
title="MaiSaka 回复器",
|
||||
border_style="bright_yellow",
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user