feat:优化重复思考,表情现在一次选所有且可配置

This commit is contained in:
SengokuCola
2026-04-04 01:55:05 +08:00
parent 97fb4cb36f
commit eda8ce66f0
4 changed files with 141 additions and 48 deletions

View File

@@ -17,6 +17,7 @@ from src.common.data_models.reply_generation_data_models import (
) )
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config from src.config.config import global_config
from src.core.types import ActionInfo from src.core.types import ActionInfo
from src.services.llm_service import LLMServiceClient from src.services.llm_service import LLMServiceClient
@@ -285,18 +286,52 @@ class MaisakaReplyGenerator:
) )
return block, selected_ids return block, selected_ids
def _get_related_session_ids(self, session_id: str) -> List[str]:
"""根据表达互通组配置,解析当前会话可共享的会话 ID。"""
related_session_ids = {session_id}
expression_groups = global_config.expression.expression_groups
for expression_group in expression_groups:
target_items = expression_group.expression_groups
group_session_ids: set[str] = set()
contains_current_session = False
for target_item in target_items:
platform = target_item.platform.strip()
item_id = target_item.item_id.strip()
if not platform or not item_id:
continue
rule_type = target_item.rule_type
target_session_id = SessionUtils.calculate_session_id(
platform,
group_id=item_id if rule_type == "group" else None,
user_id=None if rule_type == "group" else item_id,
)
group_session_ids.add(target_session_id)
if target_session_id == session_id:
contains_current_session = True
if contains_current_session:
related_session_ids.update(group_session_ids)
return list(related_session_ids)
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]: def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
"""提取表达方式静态数据,避免 detached ORM 对象。""" """提取表达方式静态数据,避免 detached ORM 对象。"""
with get_db_session(auto_commit=False) as session: related_session_ids = self._get_related_session_ids(session_id)
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
if global_config.expression.expression_checked_only:
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
query = query.where( with get_db_session(auto_commit=False) as session:
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined] base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
scoped_query = base_query.where(
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined] ).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
expressions = session.exec(query.limit(5)).all() if global_config.expression.expression_checked_only:
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
expressions = session.exec(scoped_query.limit(5)).all()
return [ return [
_ExpressionRecord( _ExpressionRecord(
expression_id=expression.id, expression_id=expression.id,

View File

@@ -17,6 +17,7 @@ from src.common.data_models.reply_generation_data_models import (
) )
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config from src.config.config import global_config
from src.core.types import ActionInfo from src.core.types import ActionInfo
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
@@ -172,24 +173,15 @@ class MaisakaReplyGenerator:
extra_sections.append(expression_habits.strip()) extra_sections.append(expression_habits.strip())
if target_message_block: if target_message_block:
extra_sections.append(target_message_block) extra_sections.append(target_message_block)
if reply_reason.strip():
extra_sections.append(f"【回复信息参考】\n{reply_reason}")
if not extra_sections: if not extra_sections:
return system_prompt return system_prompt
return f"{system_prompt}\n\n" + "\n\n".join(extra_sections) return f"{system_prompt}\n\n" + "\n\n".join(extra_sections)
def _build_reply_instruction( def _build_reply_instruction(self) -> str:
self,
reply_message: Optional[SessionMessage],
reply_reason: str,
) -> str:
"""构建追加在上下文末尾的回复指令。""" """构建追加在上下文末尾的回复指令。"""
sections: List[str] = [] return "请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。"
target_message_block = self._build_target_message_block(reply_message)
if target_message_block:
sections.append(target_message_block)
if reply_reason.strip():
sections.append(f"【回复信息参考】\n{reply_reason}")
sections.append("请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。")
return "\n\n".join(sections)
def _build_multimodal_user_message( def _build_multimodal_user_message(
self, self,
@@ -281,10 +273,7 @@ class MaisakaReplyGenerator:
reply_reason=reply_reason, reply_reason=reply_reason,
expression_habits=expression_habits, expression_habits=expression_habits,
) )
instruction = self._build_reply_instruction( instruction = self._build_reply_instruction()
reply_message=reply_message,
reply_reason=reply_reason,
)
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build()) messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
messages.extend(self._build_history_messages(chat_history)) messages.extend(self._build_history_messages(chat_history))
@@ -369,18 +358,52 @@ class MaisakaReplyGenerator:
) )
return block, selected_ids return block, selected_ids
def _get_related_session_ids(self, session_id: str) -> List[str]:
"""根据表达互通组配置,解析当前会话可共享的会话 ID。"""
related_session_ids = {session_id}
expression_groups = global_config.expression.expression_groups
for expression_group in expression_groups:
target_items = expression_group.expression_groups
group_session_ids: set[str] = set()
contains_current_session = False
for target_item in target_items:
platform = target_item.platform.strip()
item_id = target_item.item_id.strip()
if not platform or not item_id:
continue
rule_type = target_item.rule_type
target_session_id = SessionUtils.calculate_session_id(
platform,
group_id=item_id if rule_type == "group" else None,
user_id=None if rule_type == "group" else item_id,
)
group_session_ids.add(target_session_id)
if target_session_id == session_id:
contains_current_session = True
if contains_current_session:
related_session_ids.update(group_session_ids)
return list(related_session_ids)
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]: def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
"""提取表达方式静态数据,避免 detached ORM 对象。""" """提取表达方式静态数据,避免 detached ORM 对象。"""
with get_db_session(auto_commit=False) as session: related_session_ids = self._get_related_session_ids(session_id)
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
if global_config.expression.expression_checked_only:
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
query = query.where( with get_db_session(auto_commit=False) as session:
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined] base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
scoped_query = base_query.where(
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined] ).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
expressions = session.exec(query.limit(5)).all() if global_config.expression.expression_checked_only:
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
expressions = session.exec(scoped_query.limit(5)).all()
return [ return [
_ExpressionRecord( _ExpressionRecord(
expression_id=expression.id, expression_id=expression.id,

View File

@@ -716,17 +716,16 @@ class EmojiConfig(ConfigBase):
__ui_label__ = "功能" __ui_label__ = "功能"
__ui_icon__ = "puzzle" __ui_icon__ = "puzzle"
emoji_chance: float = Field( emoji_send_num: int = Field(
default=0.4, default=25,
ge=0, ge=1,
le=1, le=64,
json_schema_extra={ json_schema_extra={
"x-widget": "slider", "x-widget": "input",
"x-icon": "smile", "x-icon": "grid",
"step": 0.1,
}, },
) )
"""发送表情包的基础概率""" """一次从多少个表情包中选择发送,最大为 64"""
max_reg_num: int = Field( max_reg_num: int = Field(
default=100, default=100,

View File

@@ -2,6 +2,7 @@
from datetime import datetime from datetime import datetime
from io import BytesIO from io import BytesIO
import math
from random import sample from random import sample
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
@@ -16,6 +17,7 @@ from src.chat.emoji_system.maisaka_tool import send_emoji_for_maisaka
from src.common.data_models.image_data_model import MaiEmoji from src.common.data_models.image_data_model import MaiEmoji
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.maisaka.context_messages import ( from src.maisaka.context_messages import (
@@ -31,7 +33,7 @@ logger = get_logger("maisaka_builtin_send_emoji")
_EMOJI_SUB_AGENT_CONTEXT_LIMIT = 12 _EMOJI_SUB_AGENT_CONTEXT_LIMIT = 12
_EMOJI_SUB_AGENT_MAX_TOKENS = 240 _EMOJI_SUB_AGENT_MAX_TOKENS = 240
_EMOJI_CANDIDATE_COUNT = 25 _EMOJI_MAX_CANDIDATE_COUNT = 64
_EMOJI_CANDIDATE_TILE_SIZE = 256 _EMOJI_CANDIDATE_TILE_SIZE = 256
_EMOJI_SUCCESS_MESSAGE = "表情包发送成功" _EMOJI_SUCCESS_MESSAGE = "表情包发送成功"
@@ -70,6 +72,36 @@ async def _load_emoji_bytes(emoji: MaiEmoji) -> bytes:
return await asyncio.to_thread(emoji.full_path.read_bytes) return await asyncio.to_thread(emoji.full_path.read_bytes)
def _get_emoji_candidate_count() -> int:
"""获取本次表情包候选数量配置。"""
configured_count = int(getattr(global_config.emoji, "emoji_send_num", 25))
return max(1, min(configured_count, _EMOJI_MAX_CANDIDATE_COUNT))
def _calculate_grid_shape(candidate_count: int) -> tuple[int, int]:
"""根据候选数量计算尽量接近矩形的拼图行列数。"""
if candidate_count <= 0:
return 1, 1
best_columns = candidate_count
best_rows = 1
best_score: tuple[int, int] | None = None
for columns in range(1, candidate_count + 1):
rows = math.ceil(candidate_count / columns)
empty_slots = rows * columns - candidate_count
aspect_gap = abs(columns - rows)
score = (aspect_gap, empty_slots)
if best_score is None or score < best_score:
best_score = score
best_columns = columns
best_rows = rows
return best_rows, best_columns
def _build_placeholder_tile(label: str, tile_size: int) -> PILImage.Image: def _build_placeholder_tile(label: str, tile_size: int) -> PILImage.Image:
"""构建图片读取失败时使用的占位图。""" """构建图片读取失败时使用的占位图。"""
@@ -134,12 +166,12 @@ def _build_labeled_tile(image_bytes: bytes, index: int, tile_size: int) -> PILIm
def _merge_emoji_tiles(image_bytes_list: list[bytes]) -> bytes: def _merge_emoji_tiles(image_bytes_list: list[bytes]) -> bytes:
"""将候选表情图拼接成一张 5x5 网格图片。""" """将候选表情图拼接成一张尽量接近矩形的网格图片。"""
tile_size = _EMOJI_CANDIDATE_TILE_SIZE tile_size = _EMOJI_CANDIDATE_TILE_SIZE
gap = 12 gap = 12
grid_columns = 5 candidate_count = len(image_bytes_list)
grid_rows = 5 grid_rows, grid_columns = _calculate_grid_shape(candidate_count)
tiles = [ tiles = [
_build_labeled_tile(image_bytes=image_bytes, index=index, tile_size=tile_size) _build_labeled_tile(image_bytes=image_bytes, index=index, tile_size=tile_size)
for index, image_bytes in enumerate(image_bytes_list, start=1) for index, image_bytes in enumerate(image_bytes_list, start=1)
@@ -161,7 +193,7 @@ def _merge_emoji_tiles(image_bytes_list: list[bytes]) -> bytes:
async def _build_emoji_candidate_message(emojis: list[MaiEmoji]) -> SessionBackedMessage: async def _build_emoji_candidate_message(emojis: list[MaiEmoji]) -> SessionBackedMessage:
"""构建供子代理挑选的 5x5 拼图候选消息。""" """构建供子代理挑选的拼图候选消息。"""
image_bytes_list = await asyncio.gather(*[_load_emoji_bytes(emoji) for emoji in emojis]) image_bytes_list = await asyncio.gather(*[_load_emoji_bytes(emoji) for emoji in emojis])
merged_image_bytes = await asyncio.to_thread(_merge_emoji_tiles, list(image_bytes_list)) merged_image_bytes = await asyncio.to_thread(_merge_emoji_tiles, list(image_bytes_list))
@@ -195,15 +227,17 @@ async def _select_emoji_with_sub_agent(
if not available_emojis: if not available_emojis:
return None, "" return None, ""
total_candidate_count = min(len(available_emojis), _EMOJI_CANDIDATE_COUNT) total_candidate_count = min(len(available_emojis), _get_emoji_candidate_count())
sampled_emojis = sample(available_emojis, total_candidate_count) sampled_emojis = sample(available_emojis, total_candidate_count)
candidate_message = await _build_emoji_candidate_message(sampled_emojis) candidate_message = await _build_emoji_candidate_message(sampled_emojis)
grid_rows, grid_columns = _calculate_grid_shape(len(sampled_emojis))
system_prompt = ( system_prompt = (
"你是 Maisaka 的临时表情包选择子代理。\n" "你是 Maisaka 的临时表情包选择子代理。\n"
"你会收到群聊上下文,以及 1 条额外候选消息,其中包含一张 5x5 的表情包拼图,一共 25 个位置。\n" f"你会收到群聊上下文,以及 1 条额外候选消息,其中包含一张 {grid_rows}x{grid_columns} 的表情包拼图,"
"每张小图左上角都有一个较大的序号,范围是 1 到 25\n" f"一共 {len(sampled_emojis)} 个位置\n"
"你的任务是根据上下文和当前语气,从这 25 张图里选出最合适的一张表情包\n" f"每张小图左上角都有一个较大的序号,范围是 1 到 {len(sampled_emojis)}\n"
f"你的任务是根据上下文和当前语气,从这 {len(sampled_emojis)} 张图里选出最合适的一张表情包。\n"
"如果提供了 requested_emotion请优先考虑与其接近的候选如果没有完全匹配则选择最符合上下文语气的候选。\n" "如果提供了 requested_emotion请优先考虑与其接近的候选如果没有完全匹配则选择最符合上下文语气的候选。\n"
"你必须返回一个 JSON 对象json object不要输出任何 JSON 之外的内容。\n" "你必须返回一个 JSON 对象json object不要输出任何 JSON 之外的内容。\n"
'返回格式固定为:{"emoji_index":1,"reason":"简短理由"}' '返回格式固定为:{"emoji_index":1,"reason":"简短理由"}'
@@ -212,6 +246,8 @@ async def _select_emoji_with_sub_agent(
content=( content=(
f"[选择任务]\n" f"[选择任务]\n"
f"requested_emotion: {requested_emotion or '未指定'}\n" f"requested_emotion: {requested_emotion or '未指定'}\n"
f"候选总数: {len(sampled_emojis)}\n"
f"拼图布局: {grid_rows}x{grid_columns}\n"
"请只输出 JSON。" "请只输出 JSON。"
), ),
timestamp=datetime.now(), timestamp=datetime.now(),