Merge branch 'Mai-with-u:r-dev' into r-dev
This commit is contained in:
@@ -17,6 +17,7 @@ from src.common.data_models.reply_generation_data_models import (
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
@@ -162,15 +163,37 @@ class MaisakaReplyGenerator:
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
|
||||
"""构建当前需要回复的目标消息摘要。"""
|
||||
if reply_message is None:
|
||||
return ""
|
||||
|
||||
user_info = reply_message.message_info.user_info
|
||||
sender_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
|
||||
target_message_id = reply_message.message_id.strip() if reply_message.message_id else "未知"
|
||||
target_content = self._normalize_content((reply_message.processed_plain_text or "").strip(), limit=300)
|
||||
if not target_content:
|
||||
target_content = "[无可见文本内容]"
|
||||
|
||||
return (
|
||||
"【本次回复目标】\n"
|
||||
f"- 目标消息ID:{target_message_id}\n"
|
||||
f"- 发送者:{sender_name}\n"
|
||||
f"- 消息内容:{target_content}\n"
|
||||
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要误把其他历史消息当成当前回复对象。"
|
||||
)
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> str:
|
||||
"""构建 Maisaka replyer 提示词。"""
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
formatted_history = self._format_chat_history(chat_history)
|
||||
target_message_block = self._build_target_message_block(reply_message)
|
||||
|
||||
try:
|
||||
system_prompt = load_prompt(
|
||||
@@ -191,6 +214,8 @@ class MaisakaReplyGenerator:
|
||||
f"当前时间:{current_time}",
|
||||
f"【聊天记录】\n{formatted_history}",
|
||||
]
|
||||
if target_message_block:
|
||||
user_sections.append(target_message_block)
|
||||
if extra_sections:
|
||||
user_sections.append("\n\n".join(extra_sections))
|
||||
user_sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
@@ -261,18 +286,52 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
return block, selected_ids
|
||||
|
||||
def _get_related_session_ids(self, session_id: str) -> List[str]:
|
||||
"""根据表达互通组配置,解析当前会话可共享的会话 ID。"""
|
||||
related_session_ids = {session_id}
|
||||
expression_groups = global_config.expression.expression_groups
|
||||
|
||||
for expression_group in expression_groups:
|
||||
target_items = expression_group.expression_groups
|
||||
group_session_ids: set[str] = set()
|
||||
contains_current_session = False
|
||||
|
||||
for target_item in target_items:
|
||||
platform = target_item.platform.strip()
|
||||
item_id = target_item.item_id.strip()
|
||||
if not platform or not item_id:
|
||||
continue
|
||||
|
||||
rule_type = target_item.rule_type
|
||||
target_session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
group_id=item_id if rule_type == "group" else None,
|
||||
user_id=None if rule_type == "group" else item_id,
|
||||
)
|
||||
group_session_ids.add(target_session_id)
|
||||
if target_session_id == session_id:
|
||||
contains_current_session = True
|
||||
|
||||
if contains_current_session:
|
||||
related_session_ids.update(group_session_ids)
|
||||
|
||||
return list(related_session_ids)
|
||||
|
||||
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
|
||||
"""提取表达方式静态数据,避免 detached ORM 对象。"""
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
if global_config.expression.expression_checked_only:
|
||||
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
related_session_ids = self._get_related_session_ids(session_id)
|
||||
|
||||
query = query.where(
|
||||
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
scoped_query = base_query.where(
|
||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(query.limit(5)).all()
|
||||
if global_config.expression.expression_checked_only:
|
||||
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(scoped_query.limit(5)).all()
|
||||
|
||||
return [
|
||||
_ExpressionRecord(
|
||||
expression_id=expression.id,
|
||||
@@ -362,6 +421,7 @@ class MaisakaReplyGenerator:
|
||||
try:
|
||||
prompt = self._build_prompt(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
@@ -9,6 +8,7 @@ from sqlmodel import select
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
GenerationMetrics,
|
||||
@@ -17,14 +17,15 @@ from src.common.data_models.reply_generation_data_models import (
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, ReferenceMessage, SessionBackedMessage, ToolResultMessage
|
||||
from src.maisaka.message_adapter import parse_speaker_content
|
||||
from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
@@ -126,13 +127,35 @@ class MaisakaReplyGenerator:
|
||||
|
||||
return segments
|
||||
|
||||
def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str:
|
||||
"""构建当前需要回复的目标消息摘要。"""
|
||||
if reply_message is None:
|
||||
return ""
|
||||
|
||||
user_info = reply_message.message_info.user_info
|
||||
sender_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
|
||||
target_message_id = reply_message.message_id.strip() if reply_message.message_id else "未知"
|
||||
target_content = self._normalize_content((reply_message.processed_plain_text or "").strip(), limit=300)
|
||||
if not target_content:
|
||||
target_content = "[无可见文本内容]"
|
||||
|
||||
return (
|
||||
"【本次回复目标】\n"
|
||||
f"- 目标消息ID:{target_message_id}\n"
|
||||
f"- 发送者:{sender_name}\n"
|
||||
f"- 消息内容:{target_content}\n"
|
||||
"- 你这次要回复的就是这条目标消息,请结合整段上下文理解,但不要误把其他历史消息当成当前回复对象。"
|
||||
)
|
||||
|
||||
def _build_system_prompt(
|
||||
self,
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> str:
|
||||
"""构建 Maisaka replyer 使用的系统提示词。"""
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
target_message_block = self._build_target_message_block(reply_message)
|
||||
|
||||
try:
|
||||
system_prompt = load_prompt(
|
||||
@@ -148,9 +171,10 @@ class MaisakaReplyGenerator:
|
||||
extra_sections: List[str] = []
|
||||
if expression_habits.strip():
|
||||
extra_sections.append(expression_habits.strip())
|
||||
if target_message_block:
|
||||
extra_sections.append(target_message_block)
|
||||
if reply_reason.strip():
|
||||
extra_sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
|
||||
if not extra_sections:
|
||||
return system_prompt
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(extra_sections)
|
||||
@@ -159,6 +183,34 @@ class MaisakaReplyGenerator:
|
||||
"""构建追加在上下文末尾的回复指令。"""
|
||||
return "请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。"
|
||||
|
||||
def _build_multimodal_user_message(
|
||||
self,
|
||||
message: SessionBackedMessage,
|
||||
default_user_name: str,
|
||||
) -> Optional[Message]:
|
||||
"""构建保留图片等多模态片段的用户消息。"""
|
||||
speaker_name, _ = parse_speaker_content(message.processed_plain_text.strip())
|
||||
visible_speaker = speaker_name or default_user_name
|
||||
|
||||
raw_message = clone_message_sequence(message.raw_message)
|
||||
if not raw_message.components:
|
||||
raw_message = MessageSequence([TextComponent(f"[{visible_speaker}]")])
|
||||
elif isinstance(raw_message.components[0], TextComponent):
|
||||
first_text = raw_message.components[0].text or ""
|
||||
raw_message.components[0] = TextComponent(f"[{visible_speaker}]{first_text}")
|
||||
else:
|
||||
raw_message.components.insert(0, TextComponent(f"[{visible_speaker}]"))
|
||||
|
||||
multimodal_message = SessionBackedMessage(
|
||||
raw_message=raw_message,
|
||||
visible_text=f"[{visible_speaker}]{message.processed_plain_text}",
|
||||
timestamp=message.timestamp,
|
||||
message_id=message.message_id,
|
||||
original_message=message.original_message,
|
||||
source_kind=message.source_kind,
|
||||
)
|
||||
return multimodal_message.to_llm_message()
|
||||
|
||||
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
|
||||
"""将 replyer 上下文拆成多条 LLM 消息。"""
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
@@ -177,6 +229,11 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
continue
|
||||
|
||||
multimodal_message = self._build_multimodal_user_message(message, default_user_name)
|
||||
if multimodal_message is not None:
|
||||
messages.append(multimodal_message)
|
||||
continue
|
||||
|
||||
for speaker_name, content_body in self._split_user_message_segments(message.processed_plain_text):
|
||||
content = self._normalize_content(content_body)
|
||||
if not content:
|
||||
@@ -205,12 +262,14 @@ class MaisakaReplyGenerator:
|
||||
def _build_request_messages(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> List[Message]:
|
||||
"""构建发给大模型的消息列表。"""
|
||||
messages: List[Message] = []
|
||||
system_prompt = self._build_system_prompt(
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
expression_habits=expression_habits,
|
||||
)
|
||||
@@ -227,7 +286,14 @@ class MaisakaReplyGenerator:
|
||||
preview_lines: List[str] = []
|
||||
for message in messages:
|
||||
role_name = message.role.value.capitalize()
|
||||
preview_lines.append(f"{role_name}: {message.get_text_content()}")
|
||||
part_previews: List[str] = []
|
||||
for part in message.parts:
|
||||
if isinstance(part, TextMessagePart):
|
||||
part_previews.append(part.text)
|
||||
continue
|
||||
if isinstance(part, ImageMessagePart):
|
||||
part_previews.append(f"[图片:{part.normalized_image_format}]")
|
||||
preview_lines.append(f"{role_name}: {''.join(part_previews)}")
|
||||
return "\n\n".join(preview_lines)
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
@@ -292,18 +358,52 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
return block, selected_ids
|
||||
|
||||
def _get_related_session_ids(self, session_id: str) -> List[str]:
|
||||
"""根据表达互通组配置,解析当前会话可共享的会话 ID。"""
|
||||
related_session_ids = {session_id}
|
||||
expression_groups = global_config.expression.expression_groups
|
||||
|
||||
for expression_group in expression_groups:
|
||||
target_items = expression_group.expression_groups
|
||||
group_session_ids: set[str] = set()
|
||||
contains_current_session = False
|
||||
|
||||
for target_item in target_items:
|
||||
platform = target_item.platform.strip()
|
||||
item_id = target_item.item_id.strip()
|
||||
if not platform or not item_id:
|
||||
continue
|
||||
|
||||
rule_type = target_item.rule_type
|
||||
target_session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
group_id=item_id if rule_type == "group" else None,
|
||||
user_id=None if rule_type == "group" else item_id,
|
||||
)
|
||||
group_session_ids.add(target_session_id)
|
||||
if target_session_id == session_id:
|
||||
contains_current_session = True
|
||||
|
||||
if contains_current_session:
|
||||
related_session_ids.update(group_session_ids)
|
||||
|
||||
return list(related_session_ids)
|
||||
|
||||
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
|
||||
"""提取表达方式静态数据,避免 detached ORM 对象。"""
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
if global_config.expression.expression_checked_only:
|
||||
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
related_session_ids = self._get_related_session_ids(session_id)
|
||||
|
||||
query = query.where(
|
||||
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
scoped_query = base_query.where(
|
||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(query.limit(5)).all()
|
||||
if global_config.expression.expression_checked_only:
|
||||
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(scoped_query.limit(5)).all()
|
||||
|
||||
return [
|
||||
_ExpressionRecord(
|
||||
expression_id=expression.id,
|
||||
@@ -393,6 +493,7 @@ class MaisakaReplyGenerator:
|
||||
try:
|
||||
request_messages = self._build_request_messages(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
)
|
||||
|
||||
@@ -5,8 +5,8 @@ from src.config.config import global_config
|
||||
|
||||
def get_maisaka_replyer_class() -> Type[object]:
|
||||
"""根据配置返回 Maisaka replyer 类。"""
|
||||
generator_type = global_config.maisaka.replyer_generator_type
|
||||
if generator_type == "multi":
|
||||
generator_type = get_maisaka_replyer_generator_type()
|
||||
if generator_type == "multimodal":
|
||||
from .maisaka_generator_multi import MaisakaReplyGenerator
|
||||
|
||||
return MaisakaReplyGenerator
|
||||
@@ -18,4 +18,4 @@ def get_maisaka_replyer_class() -> Type[object]:
|
||||
|
||||
def get_maisaka_replyer_generator_type() -> str:
|
||||
"""返回当前配置的 Maisaka replyer 生成器类型。"""
|
||||
return global_config.maisaka.replyer_generator_type
|
||||
return global_config.chat.replyer_generator_type
|
||||
|
||||
@@ -16,8 +16,6 @@ class ExampleConfig(ConfigBase):
|
||||
\"""This is an example field\"""
|
||||
- 注释前面增加_warp_标记可以实现配置文件中注释在配置项前面单独一行显示
|
||||
"""
|
||||
|
||||
|
||||
class BotConfig(ConfigBase):
|
||||
"""机器人配置类"""
|
||||
|
||||
@@ -283,7 +281,7 @@ class ChatConfig(ConfigBase):
|
||||
},
|
||||
)
|
||||
|
||||
direct_image_input: bool = Field(
|
||||
multimodal_planner: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
@@ -292,14 +290,14 @@ class ChatConfig(ConfigBase):
|
||||
)
|
||||
"""是否直接输入图片"""
|
||||
|
||||
replyer_generator_type: Literal["legacy", "multi"] = Field(
|
||||
replyer_generator_type: Literal["legacy", "multimodal"] = Field(
|
||||
default="legacy",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""Maisaka replyer 生成器类型:legacy(旧版单 prompt)/ multi(多消息版)"""
|
||||
"""Maisaka replyer 生成器类型:legacy(旧版单 prompt)/ multimodal(多模态版,适合主循环直接展示图片)"""
|
||||
|
||||
enable_talk_value_rules: bool = Field(
|
||||
default=True,
|
||||
@@ -718,17 +716,16 @@ class EmojiConfig(ConfigBase):
|
||||
__ui_label__ = "功能"
|
||||
__ui_icon__ = "puzzle"
|
||||
|
||||
emoji_chance: float = Field(
|
||||
default=0.4,
|
||||
ge=0,
|
||||
le=1,
|
||||
emoji_send_num: int = Field(
|
||||
default=25,
|
||||
ge=1,
|
||||
le=64,
|
||||
json_schema_extra={
|
||||
"x-widget": "slider",
|
||||
"x-icon": "smile",
|
||||
"step": 0.1,
|
||||
"x-widget": "input",
|
||||
"x-icon": "grid",
|
||||
},
|
||||
)
|
||||
"""发送表情包的基础概率"""
|
||||
"""一次从多少个表情包中选择发送,最大为 64"""
|
||||
|
||||
max_reg_num: int = Field(
|
||||
default=100,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""reply 内置工具。"""
|
||||
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
@@ -82,6 +83,7 @@ async def handle_tool(
|
||||
logger.exception(
|
||||
f"{tool_ctx.runtime.log_prefix} 获取回复生成器时发生异常: 目标消息编号={target_message_id}"
|
||||
)
|
||||
logger.info(traceback.format_exc())
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"获取 Maisaka 回复生成器时发生异常。",
|
||||
|
||||
@@ -1,22 +1,31 @@
|
||||
"""send_emoji 内置工具。"""
|
||||
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
import math
|
||||
from random import sample
|
||||
from secrets import token_hex
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import asyncio
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from PIL import ImageDraw, ImageFont
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.emoji_system.maisaka_tool import send_emoji_for_maisaka
|
||||
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.maisaka.context_messages import LLMContextMessage, ReferenceMessage, ReferenceMessageType, SessionBackedMessage
|
||||
from src.maisaka.context_messages import (
|
||||
LLMContextMessage,
|
||||
ReferenceMessage,
|
||||
ReferenceMessageType,
|
||||
SessionBackedMessage,
|
||||
)
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
@@ -24,16 +33,16 @@ logger = get_logger("maisaka_builtin_send_emoji")
|
||||
|
||||
_EMOJI_SUB_AGENT_CONTEXT_LIMIT = 12
|
||||
_EMOJI_SUB_AGENT_MAX_TOKENS = 240
|
||||
_EMOJI_SUB_AGENT_SAMPLE_SIZE = 20
|
||||
_EMOJI_SUCCESS_MESSAGE = "???????"
|
||||
_EMOJI_MAX_CANDIDATE_COUNT = 64
|
||||
_EMOJI_CANDIDATE_TILE_SIZE = 256
|
||||
_EMOJI_SUCCESS_MESSAGE = "表情包发送成功"
|
||||
|
||||
|
||||
class EmojiSelectionResult(BaseModel):
|
||||
"""表情包子代理的结构化选择结果。"""
|
||||
|
||||
emoji_id: str = PydanticField(default="", description="选中的候选表情包 ID。")
|
||||
matched_emotion: str = PydanticField(default="", description="本次命中的情绪标签,可为空。")
|
||||
reason: str = PydanticField(default="", description="简短选择理由。")
|
||||
emoji_index: int = PydanticField(default=1, description="选中的表情包序号,从 1 开始计数。")
|
||||
reason: str = PydanticField(default="", description="选择这张表情包的简短理由。")
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
@@ -57,19 +66,146 @@ def get_tool_spec() -> ToolSpec:
|
||||
)
|
||||
|
||||
|
||||
async def _build_emoji_candidate_message(emoji: MaiEmoji, candidate_id: str) -> SessionBackedMessage:
|
||||
"""构建供子代理挑选的图片候选消息。"""
|
||||
async def _load_emoji_bytes(emoji: MaiEmoji) -> bytes:
|
||||
"""读取单个表情包图片字节。"""
|
||||
|
||||
image_bytes = await asyncio.to_thread(emoji.full_path.read_bytes)
|
||||
return await asyncio.to_thread(emoji.full_path.read_bytes)
|
||||
|
||||
|
||||
def _get_emoji_candidate_count() -> int:
|
||||
"""获取本次表情包候选数量配置。"""
|
||||
|
||||
configured_count = int(getattr(global_config.emoji, "emoji_send_num", 25))
|
||||
return max(1, min(configured_count, _EMOJI_MAX_CANDIDATE_COUNT))
|
||||
|
||||
|
||||
def _calculate_grid_shape(candidate_count: int) -> tuple[int, int]:
|
||||
"""根据候选数量计算尽量接近矩形的拼图行列数。"""
|
||||
|
||||
if candidate_count <= 0:
|
||||
return 1, 1
|
||||
|
||||
best_columns = candidate_count
|
||||
best_rows = 1
|
||||
best_score: tuple[int, int] | None = None
|
||||
|
||||
for columns in range(1, candidate_count + 1):
|
||||
rows = math.ceil(candidate_count / columns)
|
||||
empty_slots = rows * columns - candidate_count
|
||||
aspect_gap = abs(columns - rows)
|
||||
score = (aspect_gap, empty_slots)
|
||||
if best_score is None or score < best_score:
|
||||
best_score = score
|
||||
best_columns = columns
|
||||
best_rows = rows
|
||||
|
||||
return best_rows, best_columns
|
||||
|
||||
|
||||
def _build_placeholder_tile(label: str, tile_size: int) -> PILImage.Image:
|
||||
"""构建图片读取失败时使用的占位图。"""
|
||||
|
||||
tile = PILImage.new("RGB", (tile_size, tile_size), color=(245, 245, 245))
|
||||
draw = ImageDraw.Draw(tile)
|
||||
font = ImageFont.load_default()
|
||||
text_bbox = draw.textbbox((0, 0), label, font=font)
|
||||
text_width = text_bbox[2] - text_bbox[0]
|
||||
text_height = text_bbox[3] - text_bbox[1]
|
||||
draw.text(
|
||||
((tile_size - text_width) / 2, (tile_size - text_height) / 2),
|
||||
label,
|
||||
fill=(80, 80, 80),
|
||||
font=font,
|
||||
)
|
||||
return tile
|
||||
|
||||
|
||||
def _build_labeled_tile(image_bytes: bytes, index: int, tile_size: int) -> PILImage.Image:
|
||||
"""构建带序号角标的候选图片块。"""
|
||||
|
||||
try:
|
||||
with PILImage.open(BytesIO(image_bytes)) as raw_image:
|
||||
image = raw_image.convert("RGBA")
|
||||
except Exception:
|
||||
return _build_placeholder_tile(str(index), tile_size)
|
||||
|
||||
image.thumbnail((tile_size, tile_size))
|
||||
tile = PILImage.new("RGBA", (tile_size, tile_size), color=(255, 255, 255, 255))
|
||||
offset_x = (tile_size - image.width) // 2
|
||||
offset_y = (tile_size - image.height) // 2
|
||||
tile.paste(image, (offset_x, offset_y), image)
|
||||
|
||||
draw = ImageDraw.Draw(tile)
|
||||
font = ImageFont.load_default()
|
||||
badge_size = 56
|
||||
badge_margin = 14
|
||||
draw.rounded_rectangle(
|
||||
(
|
||||
badge_margin,
|
||||
badge_margin,
|
||||
badge_margin + badge_size,
|
||||
badge_margin + badge_size,
|
||||
),
|
||||
radius=8,
|
||||
fill=(0, 0, 0, 180),
|
||||
)
|
||||
label = str(index)
|
||||
text_bbox = draw.textbbox((0, 0), label, font=font)
|
||||
text_width = text_bbox[2] - text_bbox[0]
|
||||
text_height = text_bbox[3] - text_bbox[1]
|
||||
draw.text(
|
||||
(
|
||||
badge_margin + (badge_size - text_width) / 2,
|
||||
badge_margin + (badge_size - text_height) / 2 - 1,
|
||||
),
|
||||
label,
|
||||
fill=(255, 255, 255, 255),
|
||||
font=font,
|
||||
)
|
||||
return tile
|
||||
|
||||
|
||||
def _merge_emoji_tiles(image_bytes_list: list[bytes]) -> bytes:
|
||||
"""将候选表情图拼接成一张尽量接近矩形的网格图片。"""
|
||||
|
||||
tile_size = _EMOJI_CANDIDATE_TILE_SIZE
|
||||
gap = 12
|
||||
candidate_count = len(image_bytes_list)
|
||||
grid_rows, grid_columns = _calculate_grid_shape(candidate_count)
|
||||
tiles = [
|
||||
_build_labeled_tile(image_bytes=image_bytes, index=index, tile_size=tile_size)
|
||||
for index, image_bytes in enumerate(image_bytes_list, start=1)
|
||||
]
|
||||
canvas_width = tile_size * grid_columns + gap * (grid_columns - 1)
|
||||
canvas_height = tile_size * grid_rows + gap * (grid_rows - 1)
|
||||
canvas = PILImage.new("RGBA", (canvas_width, canvas_height), color=(255, 255, 255, 255))
|
||||
|
||||
for index, tile in enumerate(tiles):
|
||||
row = index // grid_columns
|
||||
column = index % grid_columns
|
||||
offset_x = column * (tile_size + gap)
|
||||
offset_y = row * (tile_size + gap)
|
||||
canvas.paste(tile, (offset_x, offset_y), tile)
|
||||
|
||||
output = BytesIO()
|
||||
canvas.convert("RGB").save(output, format="PNG")
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
async def _build_emoji_candidate_message(emojis: list[MaiEmoji]) -> SessionBackedMessage:
|
||||
"""构建供子代理挑选的拼图候选消息。"""
|
||||
|
||||
image_bytes_list = await asyncio.gather(*[_load_emoji_bytes(emoji) for emoji in emojis])
|
||||
merged_image_bytes = await asyncio.to_thread(_merge_emoji_tiles, list(image_bytes_list))
|
||||
raw_message = MessageSequence(
|
||||
[
|
||||
TextComponent(f"ID: {candidate_id}"),
|
||||
ImageComponent(binary_hash=str(emoji.file_hash or ""), binary_data=image_bytes),
|
||||
TextComponent("请从这张 5x5 拼图中选择一个序号。"),
|
||||
ImageComponent(binary_hash="", binary_data=merged_image_bytes),
|
||||
]
|
||||
)
|
||||
return SessionBackedMessage(
|
||||
raw_message=raw_message,
|
||||
visible_text=f"ID: {candidate_id}",
|
||||
visible_text="[表情包拼图候选]",
|
||||
timestamp=datetime.now(),
|
||||
source_kind="emoji_candidate",
|
||||
)
|
||||
@@ -81,42 +217,38 @@ async def _select_emoji_with_sub_agent(
|
||||
reasoning: str,
|
||||
context_texts: list[str],
|
||||
sample_size: int,
|
||||
selection_metadata: Optional[Dict[str, str]] = None,
|
||||
) -> tuple[MaiEmoji | None, str]:
|
||||
"""通过临时子代理从候选表情包中选出一个结果。"""
|
||||
|
||||
del reasoning, context_texts, sample_size
|
||||
|
||||
available_emojis = list(emoji_manager.emojis)
|
||||
if not available_emojis:
|
||||
return None, ""
|
||||
|
||||
effective_sample_size = min(max(sample_size, 1), _EMOJI_SUB_AGENT_SAMPLE_SIZE, len(available_emojis))
|
||||
sampled_emojis = sample(available_emojis, effective_sample_size)
|
||||
total_candidate_count = min(len(available_emojis), _get_emoji_candidate_count())
|
||||
sampled_emojis = sample(available_emojis, total_candidate_count)
|
||||
candidate_message = await _build_emoji_candidate_message(sampled_emojis)
|
||||
grid_rows, grid_columns = _calculate_grid_shape(len(sampled_emojis))
|
||||
|
||||
candidate_map: dict[str, MaiEmoji] = {}
|
||||
candidate_messages: list[LLMContextMessage] = []
|
||||
for emoji in sampled_emojis:
|
||||
candidate_id = token_hex(4)
|
||||
while candidate_id in candidate_map:
|
||||
candidate_id = token_hex(4)
|
||||
candidate_map[candidate_id] = emoji
|
||||
candidate_messages.append(await _build_emoji_candidate_message(emoji, candidate_id))
|
||||
|
||||
context_text = "\n".join(context_texts[-5:]) if context_texts else "(暂无额外上下文)"
|
||||
system_prompt = (
|
||||
"你是 Maisaka 的临时表情包选择子代理。\n"
|
||||
"你会收到一段群聊上下文,以及若干条候选表情包消息。每条候选消息里都有一个临时 ID。\n"
|
||||
"你的任务是根据上下文、当前语气和发送意图,从候选里选出最合适的一个表情包。\n"
|
||||
"必须只从候选消息中选择,不能编造新的 ID。\n"
|
||||
f"你会收到群聊上下文,以及 1 条额外候选消息,其中包含一张 {grid_rows}x{grid_columns} 的表情包拼图,"
|
||||
f"一共 {len(sampled_emojis)} 个位置。\n"
|
||||
f"每张小图左上角都有一个较大的序号,范围是 1 到 {len(sampled_emojis)}。\n"
|
||||
f"你的任务是根据上下文和当前语气,从这 {len(sampled_emojis)} 张图里选出最合适的一张表情包。\n"
|
||||
"如果提供了 requested_emotion,请优先考虑与其接近的候选;如果没有完全匹配,则选择最符合上下文语气的候选。\n"
|
||||
"你必须返回一个 JSON 对象(json object),不要输出任何 JSON 之外的内容。\n"
|
||||
'返回格式固定为:{"emoji_id":"候选ID","matched_emotion":"情绪标签","reason":"简短理由"}'
|
||||
'返回格式固定为:{"emoji_index":1,"reason":"简短理由"}'
|
||||
)
|
||||
prompt_message = ReferenceMessage(
|
||||
content=(
|
||||
f"[选择任务]\n"
|
||||
f"requested_emotion: {requested_emotion or '未指定'}\n"
|
||||
f"reasoning: {reasoning or '辅助表达当前语气和情绪'}\n"
|
||||
f"recent_context:\n{context_text}\n"
|
||||
'请只输出 JSON。'
|
||||
f"候选总数: {len(sampled_emojis)}\n"
|
||||
f"拼图布局: {grid_rows}x{grid_columns}\n"
|
||||
"请只输出 JSON。"
|
||||
),
|
||||
timestamp=datetime.now(),
|
||||
reference_type=ReferenceMessageType.TOOL_HINT,
|
||||
@@ -127,7 +259,7 @@ async def _select_emoji_with_sub_agent(
|
||||
response = await tool_ctx.runtime.run_sub_agent(
|
||||
context_message_limit=_EMOJI_SUB_AGENT_CONTEXT_LIMIT,
|
||||
system_prompt=system_prompt,
|
||||
extra_messages=[prompt_message, *candidate_messages],
|
||||
extra_messages=[prompt_message, candidate_message],
|
||||
max_tokens=_EMOJI_SUB_AGENT_MAX_TOKENS,
|
||||
response_format=RespFormat(
|
||||
format_type=RespFormatType.JSON_SCHEMA,
|
||||
@@ -140,20 +272,19 @@ async def _select_emoji_with_sub_agent(
|
||||
except Exception as exc:
|
||||
logger.warning(f"{tool_ctx.runtime.log_prefix} 表情包子代理结果解析失败,将回退到候选首项: {exc}")
|
||||
fallback_emoji = sampled_emojis[0] if sampled_emojis else None
|
||||
return fallback_emoji, requested_emotion
|
||||
return fallback_emoji, ""
|
||||
|
||||
selected_emoji = candidate_map.get(selection.emoji_id.strip())
|
||||
if selected_emoji is None:
|
||||
if selection_metadata is not None:
|
||||
selection_metadata["reason"] = selection.reason.strip()
|
||||
|
||||
emoji_index = int(selection.emoji_index)
|
||||
if emoji_index < 1 or emoji_index > len(sampled_emojis):
|
||||
logger.warning(
|
||||
f"{tool_ctx.runtime.log_prefix} 表情包子代理返回了无效 ID: {selection.emoji_id!r},将回退到候选首项"
|
||||
f"{tool_ctx.runtime.log_prefix} 表情包子代理返回了无效序号: {emoji_index!r},将回退到第 1 张"
|
||||
)
|
||||
fallback_emoji = sampled_emojis[0] if sampled_emojis else None
|
||||
return fallback_emoji, requested_emotion
|
||||
emoji_index = 1
|
||||
|
||||
matched_emotion = selection.matched_emotion.strip()
|
||||
if not matched_emotion:
|
||||
matched_emotion = requested_emotion.strip()
|
||||
return selected_emoji, matched_emotion
|
||||
return sampled_emojis[emoji_index - 1], ""
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
@@ -177,7 +308,9 @@ async def handle_tool(
|
||||
"emotion": [],
|
||||
"requested_emotion": emotion,
|
||||
"matched_emotion": "",
|
||||
"reason": "",
|
||||
}
|
||||
selection_metadata: Dict[str, str] = {"reason": ""}
|
||||
|
||||
logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具,请求情绪={emotion!r}")
|
||||
|
||||
@@ -193,6 +326,7 @@ async def handle_tool(
|
||||
reasoning,
|
||||
list(context_texts or []),
|
||||
sample_size,
|
||||
selection_metadata,
|
||||
),
|
||||
)
|
||||
except Exception as exc:
|
||||
@@ -206,10 +340,11 @@ async def handle_tool(
|
||||
|
||||
if send_result.success:
|
||||
structured_result["message"] = _EMOJI_SUCCESS_MESSAGE
|
||||
structured_result["reason"] = selection_metadata["reason"]
|
||||
logger.info(
|
||||
f"{tool_ctx.runtime.log_prefix} ??????? "
|
||||
f"??={send_result.description!r} ????={send_result.emotions} "
|
||||
f"????={emotion!r} ????={send_result.matched_emotion!r}"
|
||||
f"{tool_ctx.runtime.log_prefix} 表情包发送成功 "
|
||||
f"描述={send_result.description!r} 情绪标签={send_result.emotions} "
|
||||
f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}"
|
||||
)
|
||||
tool_ctx.append_sent_emoji_to_chat_history(
|
||||
emoji_base64=send_result.emoji_base64,
|
||||
@@ -218,7 +353,7 @@ async def handle_tool(
|
||||
structured_result["success"] = True
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
_EMOJI_SUCCESS_MESSAGE,
|
||||
selection_metadata["reason"] or _EMOJI_SUCCESS_MESSAGE,
|
||||
structured_content=structured_result,
|
||||
)
|
||||
|
||||
|
||||
@@ -51,6 +51,18 @@ class MaisakaReasoningEngine:
|
||||
self._runtime = runtime
|
||||
self._last_reasoning_content: str = ""
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@property
|
||||
def last_reasoning_content(self) -> str:
|
||||
"""返回最近一轮思考文本。"""
|
||||
@@ -122,8 +134,8 @@ class MaisakaReasoningEngine:
|
||||
|
||||
reasoning_content = response.content or ""
|
||||
if self._should_replace_reasoning(reasoning_content):
|
||||
response.content = "让我根据新情况重新思考:"
|
||||
response.raw_message.content = "让我根据新情况重新思考:"
|
||||
response.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后先输出想法再使用工具"
|
||||
response.raw_message.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后先输出想法再使用工具"
|
||||
logger.info(f"{self._runtime.log_prefix} 当前思考与上一轮过于相似,已替换为重新思考提示")
|
||||
|
||||
self._last_reasoning_content = reasoning_content
|
||||
@@ -266,7 +278,7 @@ class MaisakaReasoningEngine:
|
||||
source_sequence = message.raw_message
|
||||
|
||||
planner_components = clone_message_sequence(source_sequence).components
|
||||
if global_config.chat.direct_image_input:
|
||||
if global_config.chat.multimodal_planner:
|
||||
await self._hydrate_visual_components(planner_components)
|
||||
if planner_components and isinstance(planner_components[0], TextComponent):
|
||||
planner_components[0].text = planner_prefix + planner_components[0].text
|
||||
|
||||
@@ -235,8 +235,18 @@ class PluginMessageUtils:
|
||||
if isinstance(raw_forward_nodes, list):
|
||||
for node in raw_forward_nodes:
|
||||
if not isinstance(node, dict):
|
||||
logger.info(f"解析转发节点时跳过非字典节点: {node!r}")
|
||||
continue
|
||||
raw_content = node.get("content", [])
|
||||
logger.info(
|
||||
"开始解析转发节点: "
|
||||
f"message_id={node.get('message_id')!r} "
|
||||
f"user_id={node.get('user_id')!r} "
|
||||
f"user_nickname={node.get('user_nickname')!r} "
|
||||
f"user_cardname={node.get('user_cardname')!r} "
|
||||
f"raw_content_type={type(raw_content).__name__} "
|
||||
f"raw_content={raw_content!r}"
|
||||
)
|
||||
node_components: List[StandardMessageComponents] = []
|
||||
if isinstance(raw_content, list):
|
||||
node_components = [
|
||||
@@ -244,7 +254,17 @@ class PluginMessageUtils:
|
||||
for content in raw_content
|
||||
if isinstance(content, dict)
|
||||
]
|
||||
logger.info(
|
||||
"转发节点解析结果: "
|
||||
f"message_id={node.get('message_id')!r} "
|
||||
f"component_types={[component.__class__.__name__ for component in node_components]!r} "
|
||||
f"component_values={[getattr(component, 'text', None) for component in node_components]!r}"
|
||||
)
|
||||
if not node_components:
|
||||
logger.warning(
|
||||
"转发节点内容为空,使用占位文本回退: "
|
||||
f"message_id={node.get('message_id')!r} raw_content={raw_content!r}"
|
||||
)
|
||||
node_components = [TextComponent(text="[empty forward node]")]
|
||||
forward_nodes.append(
|
||||
ForwardComponent(
|
||||
|
||||
Reference in New Issue
Block a user