From eda8ce66f0264645d8198202f906b84147ab4418 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 4 Apr 2026 01:55:05 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E4=BC=98=E5=8C=96=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E6=80=9D=E8=80=83=EF=BC=8C=E8=A1=A8=E6=83=85=E7=8E=B0?= =?UTF-8?q?=E5=9C=A8=E4=B8=80=E6=AC=A1=E9=80=89=E6=89=80=E6=9C=89=E4=B8=94?= =?UTF-8?q?=E5=8F=AF=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/maisaka_generator.py | 49 ++++++++++++-- src/chat/replyer/maisaka_generator_multi.py | 71 ++++++++++++++------- src/config/official_configs.py | 15 ++--- src/maisaka/builtin_tool/send_emoji.py | 54 +++++++++++++--- 4 files changed, 141 insertions(+), 48 deletions(-) diff --git a/src/chat/replyer/maisaka_generator.py b/src/chat/replyer/maisaka_generator.py index 4d177ba5..05df760a 100644 --- a/src/chat/replyer/maisaka_generator.py +++ b/src/chat/replyer/maisaka_generator.py @@ -17,6 +17,7 @@ from src.common.data_models.reply_generation_data_models import ( ) from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt +from src.common.utils.utils_session import SessionUtils from src.config.config import global_config from src.core.types import ActionInfo from src.services.llm_service import LLMServiceClient @@ -285,18 +286,52 @@ class MaisakaReplyGenerator: ) return block, selected_ids + def _get_related_session_ids(self, session_id: str) -> List[str]: + """根据表达互通组配置,解析当前会话可共享的会话 ID。""" + related_session_ids = {session_id} + expression_groups = global_config.expression.expression_groups + + for expression_group in expression_groups: + target_items = expression_group.expression_groups + group_session_ids: set[str] = set() + contains_current_session = False + + for target_item in target_items: + platform = target_item.platform.strip() + item_id = target_item.item_id.strip() + if not platform or not item_id: + continue + + rule_type = target_item.rule_type + target_session_id = SessionUtils.calculate_session_id( + platform, + group_id=item_id if rule_type == "group" else None, + user_id=None if rule_type == "group" else item_id, + ) + group_session_ids.add(target_session_id) + if target_session_id == session_id: + contains_current_session = True + + if contains_current_session: + related_session_ids.update(group_session_ids) + + return list(related_session_ids) + def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]: """提取表达方式静态数据,避免 detached ORM 对象。""" - with get_db_session(auto_commit=False) as session: - query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined] - if global_config.expression.expression_checked_only: - query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined] + related_session_ids = self._get_related_session_ids(session_id) - query = query.where( - (Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined] + with get_db_session(auto_commit=False) as session: + base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined] + scoped_query = base_query.where( + (Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined] ).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined] - expressions = session.exec(query.limit(5)).all() + if global_config.expression.expression_checked_only: + scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined] + + expressions = session.exec(scoped_query.limit(5)).all() + return [ _ExpressionRecord( expression_id=expression.id, diff --git a/src/chat/replyer/maisaka_generator_multi.py b/src/chat/replyer/maisaka_generator_multi.py index 1db6c555..a3befb72 100644 --- a/src/chat/replyer/maisaka_generator_multi.py +++ b/src/chat/replyer/maisaka_generator_multi.py @@ -17,6 +17,7 @@ from src.common.data_models.reply_generation_data_models import ( ) from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt +from src.common.utils.utils_session import SessionUtils from src.config.config import global_config from src.core.types import ActionInfo from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart @@ -172,24 +173,15 @@ class MaisakaReplyGenerator: extra_sections.append(expression_habits.strip()) if target_message_block: extra_sections.append(target_message_block) + if reply_reason.strip(): + extra_sections.append(f"【回复信息参考】\n{reply_reason}") if not extra_sections: return system_prompt return f"{system_prompt}\n\n" + "\n\n".join(extra_sections) - def _build_reply_instruction( - self, - reply_message: Optional[SessionMessage], - reply_reason: str, - ) -> str: + def _build_reply_instruction(self) -> str: """构建追加在上下文末尾的回复指令。""" - sections: List[str] = [] - target_message_block = self._build_target_message_block(reply_message) - if target_message_block: - sections.append(target_message_block) - if reply_reason.strip(): - sections.append(f"【回复信息参考】\n{reply_reason}") - sections.append("请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。") - return "\n\n".join(sections) + return "请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。" def _build_multimodal_user_message( self, @@ -281,10 +273,7 @@ class MaisakaReplyGenerator: reply_reason=reply_reason, expression_habits=expression_habits, ) - instruction = self._build_reply_instruction( - reply_message=reply_message, - reply_reason=reply_reason, - ) + instruction = self._build_reply_instruction() messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build()) messages.extend(self._build_history_messages(chat_history)) @@ -369,18 +358,52 @@ class MaisakaReplyGenerator: ) return block, selected_ids + def _get_related_session_ids(self, session_id: str) -> List[str]: + """根据表达互通组配置,解析当前会话可共享的会话 ID。""" + related_session_ids = {session_id} + expression_groups = global_config.expression.expression_groups + + for expression_group in expression_groups: + target_items = expression_group.expression_groups + group_session_ids: set[str] = set() + contains_current_session = False + + for target_item in target_items: + platform = target_item.platform.strip() + item_id = target_item.item_id.strip() + if not platform or not item_id: + continue + + rule_type = target_item.rule_type + target_session_id = SessionUtils.calculate_session_id( + platform, + group_id=item_id if rule_type == "group" else None, + user_id=None if rule_type == "group" else item_id, + ) + group_session_ids.add(target_session_id) + if target_session_id == session_id: + contains_current_session = True + + if contains_current_session: + related_session_ids.update(group_session_ids) + + return list(related_session_ids) + def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]: """提取表达方式静态数据,避免 detached ORM 对象。""" - with get_db_session(auto_commit=False) as session: - query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined] - if global_config.expression.expression_checked_only: - query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined] + related_session_ids = self._get_related_session_ids(session_id) - query = query.where( - (Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined] + with get_db_session(auto_commit=False) as session: + base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined] + scoped_query = base_query.where( + (Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined] ).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined] - expressions = session.exec(query.limit(5)).all() + if global_config.expression.expression_checked_only: + scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined] + + expressions = session.exec(scoped_query.limit(5)).all() + return [ _ExpressionRecord( expression_id=expression.id, diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 2b3970f0..1b3944ef 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -716,17 +716,16 @@ class EmojiConfig(ConfigBase): __ui_label__ = "功能" __ui_icon__ = "puzzle" - emoji_chance: float = Field( - default=0.4, - ge=0, - le=1, + emoji_send_num: int = Field( + default=25, + ge=1, + le=64, json_schema_extra={ - "x-widget": "slider", - "x-icon": "smile", - "step": 0.1, + "x-widget": "input", + "x-icon": "grid", }, ) - """发送表情包的基础概率""" + """一次从多少个表情包中选择发送,最大为 64""" max_reg_num: int = Field( default=100, diff --git a/src/maisaka/builtin_tool/send_emoji.py b/src/maisaka/builtin_tool/send_emoji.py index dccf7b3d..4e014202 100644 --- a/src/maisaka/builtin_tool/send_emoji.py +++ b/src/maisaka/builtin_tool/send_emoji.py @@ -2,6 +2,7 @@ from datetime import datetime from io import BytesIO +import math from random import sample from typing import Any, Dict, Optional @@ -16,6 +17,7 @@ from src.chat.emoji_system.maisaka_tool import send_emoji_for_maisaka from src.common.data_models.image_data_model import MaiEmoji from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent from src.common.logger import get_logger +from src.config.config import global_config from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType from src.maisaka.context_messages import ( @@ -31,7 +33,7 @@ logger = get_logger("maisaka_builtin_send_emoji") _EMOJI_SUB_AGENT_CONTEXT_LIMIT = 12 _EMOJI_SUB_AGENT_MAX_TOKENS = 240 -_EMOJI_CANDIDATE_COUNT = 25 +_EMOJI_MAX_CANDIDATE_COUNT = 64 _EMOJI_CANDIDATE_TILE_SIZE = 256 _EMOJI_SUCCESS_MESSAGE = "表情包发送成功" @@ -70,6 +72,36 @@ async def _load_emoji_bytes(emoji: MaiEmoji) -> bytes: return await asyncio.to_thread(emoji.full_path.read_bytes) +def _get_emoji_candidate_count() -> int: + """获取本次表情包候选数量配置。""" + + configured_count = int(getattr(global_config.emoji, "emoji_send_num", 25)) + return max(1, min(configured_count, _EMOJI_MAX_CANDIDATE_COUNT)) + + +def _calculate_grid_shape(candidate_count: int) -> tuple[int, int]: + """根据候选数量计算尽量接近矩形的拼图行列数。""" + + if candidate_count <= 0: + return 1, 1 + + best_columns = candidate_count + best_rows = 1 + best_score: tuple[int, int] | None = None + + for columns in range(1, candidate_count + 1): + rows = math.ceil(candidate_count / columns) + empty_slots = rows * columns - candidate_count + aspect_gap = abs(columns - rows) + score = (aspect_gap, empty_slots) + if best_score is None or score < best_score: + best_score = score + best_columns = columns + best_rows = rows + + return best_rows, best_columns + + def _build_placeholder_tile(label: str, tile_size: int) -> PILImage.Image: """构建图片读取失败时使用的占位图。""" @@ -134,12 +166,12 @@ def _build_labeled_tile(image_bytes: bytes, index: int, tile_size: int) -> PILIm def _merge_emoji_tiles(image_bytes_list: list[bytes]) -> bytes: - """将候选表情图拼接成一张 5x5 网格图片。""" + """将候选表情图拼接成一张尽量接近矩形的网格图片。""" tile_size = _EMOJI_CANDIDATE_TILE_SIZE gap = 12 - grid_columns = 5 - grid_rows = 5 + candidate_count = len(image_bytes_list) + grid_rows, grid_columns = _calculate_grid_shape(candidate_count) tiles = [ _build_labeled_tile(image_bytes=image_bytes, index=index, tile_size=tile_size) for index, image_bytes in enumerate(image_bytes_list, start=1) @@ -161,7 +193,7 @@ def _merge_emoji_tiles(image_bytes_list: list[bytes]) -> bytes: async def _build_emoji_candidate_message(emojis: list[MaiEmoji]) -> SessionBackedMessage: - """构建供子代理挑选的 5x5 拼图候选消息。""" + """构建供子代理挑选的拼图候选消息。""" image_bytes_list = await asyncio.gather(*[_load_emoji_bytes(emoji) for emoji in emojis]) merged_image_bytes = await asyncio.to_thread(_merge_emoji_tiles, list(image_bytes_list)) @@ -195,15 +227,17 @@ async def _select_emoji_with_sub_agent( if not available_emojis: return None, "" - total_candidate_count = min(len(available_emojis), _EMOJI_CANDIDATE_COUNT) + total_candidate_count = min(len(available_emojis), _get_emoji_candidate_count()) sampled_emojis = sample(available_emojis, total_candidate_count) candidate_message = await _build_emoji_candidate_message(sampled_emojis) + grid_rows, grid_columns = _calculate_grid_shape(len(sampled_emojis)) system_prompt = ( "你是 Maisaka 的临时表情包选择子代理。\n" - "你会收到群聊上下文,以及 1 条额外候选消息,其中包含一张 5x5 的表情包拼图,一共 25 个位置。\n" - "每张小图左上角都有一个较大的序号,范围是 1 到 25。\n" - "你的任务是根据上下文和当前语气,从这 25 张图里选出最合适的一张表情包。\n" + f"你会收到群聊上下文,以及 1 条额外候选消息,其中包含一张 {grid_rows}x{grid_columns} 的表情包拼图," + f"一共 {len(sampled_emojis)} 个位置。\n" + f"每张小图左上角都有一个较大的序号,范围是 1 到 {len(sampled_emojis)}。\n" + f"你的任务是根据上下文和当前语气,从这 {len(sampled_emojis)} 张图里选出最合适的一张表情包。\n" "如果提供了 requested_emotion,请优先考虑与其接近的候选;如果没有完全匹配,则选择最符合上下文语气的候选。\n" "你必须返回一个 JSON 对象(json object),不要输出任何 JSON 之外的内容。\n" '返回格式固定为:{"emoji_index":1,"reason":"简短理由"}' @@ -212,6 +246,8 @@ async def _select_emoji_with_sub_agent( content=( f"[选择任务]\n" f"requested_emotion: {requested_emotion or '未指定'}\n" + f"候选总数: {len(sampled_emojis)}\n" + f"拼图布局: {grid_rows}x{grid_columns}\n" "请只输出 JSON。" ), timestamp=datetime.now(),