feat:优化重复思考,表情现在一次选所有且可配置
This commit is contained in:
@@ -17,6 +17,7 @@ from src.common.data_models.reply_generation_data_models import (
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
@@ -285,18 +286,52 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
return block, selected_ids
|
||||
|
||||
def _get_related_session_ids(self, session_id: str) -> List[str]:
|
||||
"""根据表达互通组配置,解析当前会话可共享的会话 ID。"""
|
||||
related_session_ids = {session_id}
|
||||
expression_groups = global_config.expression.expression_groups
|
||||
|
||||
for expression_group in expression_groups:
|
||||
target_items = expression_group.expression_groups
|
||||
group_session_ids: set[str] = set()
|
||||
contains_current_session = False
|
||||
|
||||
for target_item in target_items:
|
||||
platform = target_item.platform.strip()
|
||||
item_id = target_item.item_id.strip()
|
||||
if not platform or not item_id:
|
||||
continue
|
||||
|
||||
rule_type = target_item.rule_type
|
||||
target_session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
group_id=item_id if rule_type == "group" else None,
|
||||
user_id=None if rule_type == "group" else item_id,
|
||||
)
|
||||
group_session_ids.add(target_session_id)
|
||||
if target_session_id == session_id:
|
||||
contains_current_session = True
|
||||
|
||||
if contains_current_session:
|
||||
related_session_ids.update(group_session_ids)
|
||||
|
||||
return list(related_session_ids)
|
||||
|
||||
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
|
||||
"""提取表达方式静态数据,避免 detached ORM 对象。"""
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
if global_config.expression.expression_checked_only:
|
||||
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
related_session_ids = self._get_related_session_ids(session_id)
|
||||
|
||||
query = query.where(
|
||||
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
scoped_query = base_query.where(
|
||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(query.limit(5)).all()
|
||||
if global_config.expression.expression_checked_only:
|
||||
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(scoped_query.limit(5)).all()
|
||||
|
||||
return [
|
||||
_ExpressionRecord(
|
||||
expression_id=expression.id,
|
||||
|
||||
@@ -17,6 +17,7 @@ from src.common.data_models.reply_generation_data_models import (
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
|
||||
@@ -172,24 +173,15 @@ class MaisakaReplyGenerator:
|
||||
extra_sections.append(expression_habits.strip())
|
||||
if target_message_block:
|
||||
extra_sections.append(target_message_block)
|
||||
if reply_reason.strip():
|
||||
extra_sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
if not extra_sections:
|
||||
return system_prompt
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(extra_sections)
|
||||
|
||||
def _build_reply_instruction(
|
||||
self,
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
) -> str:
|
||||
def _build_reply_instruction(self) -> str:
|
||||
"""构建追加在上下文末尾的回复指令。"""
|
||||
sections: List[str] = []
|
||||
target_message_block = self._build_target_message_block(reply_message)
|
||||
if target_message_block:
|
||||
sections.append(target_message_block)
|
||||
if reply_reason.strip():
|
||||
sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
sections.append("请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。")
|
||||
return "\n\n".join(sections)
|
||||
return "请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。"
|
||||
|
||||
def _build_multimodal_user_message(
|
||||
self,
|
||||
@@ -281,10 +273,7 @@ class MaisakaReplyGenerator:
|
||||
reply_reason=reply_reason,
|
||||
expression_habits=expression_habits,
|
||||
)
|
||||
instruction = self._build_reply_instruction(
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
)
|
||||
instruction = self._build_reply_instruction()
|
||||
|
||||
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
|
||||
messages.extend(self._build_history_messages(chat_history))
|
||||
@@ -369,18 +358,52 @@ class MaisakaReplyGenerator:
|
||||
)
|
||||
return block, selected_ids
|
||||
|
||||
def _get_related_session_ids(self, session_id: str) -> List[str]:
|
||||
"""根据表达互通组配置,解析当前会话可共享的会话 ID。"""
|
||||
related_session_ids = {session_id}
|
||||
expression_groups = global_config.expression.expression_groups
|
||||
|
||||
for expression_group in expression_groups:
|
||||
target_items = expression_group.expression_groups
|
||||
group_session_ids: set[str] = set()
|
||||
contains_current_session = False
|
||||
|
||||
for target_item in target_items:
|
||||
platform = target_item.platform.strip()
|
||||
item_id = target_item.item_id.strip()
|
||||
if not platform or not item_id:
|
||||
continue
|
||||
|
||||
rule_type = target_item.rule_type
|
||||
target_session_id = SessionUtils.calculate_session_id(
|
||||
platform,
|
||||
group_id=item_id if rule_type == "group" else None,
|
||||
user_id=None if rule_type == "group" else item_id,
|
||||
)
|
||||
group_session_ids.add(target_session_id)
|
||||
if target_session_id == session_id:
|
||||
contains_current_session = True
|
||||
|
||||
if contains_current_session:
|
||||
related_session_ids.update(group_session_ids)
|
||||
|
||||
return list(related_session_ids)
|
||||
|
||||
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
|
||||
"""提取表达方式静态数据,避免 detached ORM 对象。"""
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
if global_config.expression.expression_checked_only:
|
||||
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
related_session_ids = self._get_related_session_ids(session_id)
|
||||
|
||||
query = query.where(
|
||||
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
scoped_query = base_query.where(
|
||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(query.limit(5)).all()
|
||||
if global_config.expression.expression_checked_only:
|
||||
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(scoped_query.limit(5)).all()
|
||||
|
||||
return [
|
||||
_ExpressionRecord(
|
||||
expression_id=expression.id,
|
||||
|
||||
@@ -716,17 +716,16 @@ class EmojiConfig(ConfigBase):
|
||||
__ui_label__ = "功能"
|
||||
__ui_icon__ = "puzzle"
|
||||
|
||||
emoji_chance: float = Field(
|
||||
default=0.4,
|
||||
ge=0,
|
||||
le=1,
|
||||
emoji_send_num: int = Field(
|
||||
default=25,
|
||||
ge=1,
|
||||
le=64,
|
||||
json_schema_extra={
|
||||
"x-widget": "slider",
|
||||
"x-icon": "smile",
|
||||
"step": 0.1,
|
||||
"x-widget": "input",
|
||||
"x-icon": "grid",
|
||||
},
|
||||
)
|
||||
"""发送表情包的基础概率"""
|
||||
"""一次从多少个表情包中选择发送,最大为 64"""
|
||||
|
||||
max_reg_num: int = Field(
|
||||
default=100,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
import math
|
||||
from random import sample
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -16,6 +17,7 @@ from src.chat.emoji_system.maisaka_tool import send_emoji_for_maisaka
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.maisaka.context_messages import (
|
||||
@@ -31,7 +33,7 @@ logger = get_logger("maisaka_builtin_send_emoji")
|
||||
|
||||
_EMOJI_SUB_AGENT_CONTEXT_LIMIT = 12
|
||||
_EMOJI_SUB_AGENT_MAX_TOKENS = 240
|
||||
_EMOJI_CANDIDATE_COUNT = 25
|
||||
_EMOJI_MAX_CANDIDATE_COUNT = 64
|
||||
_EMOJI_CANDIDATE_TILE_SIZE = 256
|
||||
_EMOJI_SUCCESS_MESSAGE = "表情包发送成功"
|
||||
|
||||
@@ -70,6 +72,36 @@ async def _load_emoji_bytes(emoji: MaiEmoji) -> bytes:
|
||||
return await asyncio.to_thread(emoji.full_path.read_bytes)
|
||||
|
||||
|
||||
def _get_emoji_candidate_count() -> int:
|
||||
"""获取本次表情包候选数量配置。"""
|
||||
|
||||
configured_count = int(getattr(global_config.emoji, "emoji_send_num", 25))
|
||||
return max(1, min(configured_count, _EMOJI_MAX_CANDIDATE_COUNT))
|
||||
|
||||
|
||||
def _calculate_grid_shape(candidate_count: int) -> tuple[int, int]:
|
||||
"""根据候选数量计算尽量接近矩形的拼图行列数。"""
|
||||
|
||||
if candidate_count <= 0:
|
||||
return 1, 1
|
||||
|
||||
best_columns = candidate_count
|
||||
best_rows = 1
|
||||
best_score: tuple[int, int] | None = None
|
||||
|
||||
for columns in range(1, candidate_count + 1):
|
||||
rows = math.ceil(candidate_count / columns)
|
||||
empty_slots = rows * columns - candidate_count
|
||||
aspect_gap = abs(columns - rows)
|
||||
score = (aspect_gap, empty_slots)
|
||||
if best_score is None or score < best_score:
|
||||
best_score = score
|
||||
best_columns = columns
|
||||
best_rows = rows
|
||||
|
||||
return best_rows, best_columns
|
||||
|
||||
|
||||
def _build_placeholder_tile(label: str, tile_size: int) -> PILImage.Image:
|
||||
"""构建图片读取失败时使用的占位图。"""
|
||||
|
||||
@@ -134,12 +166,12 @@ def _build_labeled_tile(image_bytes: bytes, index: int, tile_size: int) -> PILIm
|
||||
|
||||
|
||||
def _merge_emoji_tiles(image_bytes_list: list[bytes]) -> bytes:
|
||||
"""将候选表情图拼接成一张 5x5 网格图片。"""
|
||||
"""将候选表情图拼接成一张尽量接近矩形的网格图片。"""
|
||||
|
||||
tile_size = _EMOJI_CANDIDATE_TILE_SIZE
|
||||
gap = 12
|
||||
grid_columns = 5
|
||||
grid_rows = 5
|
||||
candidate_count = len(image_bytes_list)
|
||||
grid_rows, grid_columns = _calculate_grid_shape(candidate_count)
|
||||
tiles = [
|
||||
_build_labeled_tile(image_bytes=image_bytes, index=index, tile_size=tile_size)
|
||||
for index, image_bytes in enumerate(image_bytes_list, start=1)
|
||||
@@ -161,7 +193,7 @@ def _merge_emoji_tiles(image_bytes_list: list[bytes]) -> bytes:
|
||||
|
||||
|
||||
async def _build_emoji_candidate_message(emojis: list[MaiEmoji]) -> SessionBackedMessage:
|
||||
"""构建供子代理挑选的 5x5 拼图候选消息。"""
|
||||
"""构建供子代理挑选的拼图候选消息。"""
|
||||
|
||||
image_bytes_list = await asyncio.gather(*[_load_emoji_bytes(emoji) for emoji in emojis])
|
||||
merged_image_bytes = await asyncio.to_thread(_merge_emoji_tiles, list(image_bytes_list))
|
||||
@@ -195,15 +227,17 @@ async def _select_emoji_with_sub_agent(
|
||||
if not available_emojis:
|
||||
return None, ""
|
||||
|
||||
total_candidate_count = min(len(available_emojis), _EMOJI_CANDIDATE_COUNT)
|
||||
total_candidate_count = min(len(available_emojis), _get_emoji_candidate_count())
|
||||
sampled_emojis = sample(available_emojis, total_candidate_count)
|
||||
candidate_message = await _build_emoji_candidate_message(sampled_emojis)
|
||||
grid_rows, grid_columns = _calculate_grid_shape(len(sampled_emojis))
|
||||
|
||||
system_prompt = (
|
||||
"你是 Maisaka 的临时表情包选择子代理。\n"
|
||||
"你会收到群聊上下文,以及 1 条额外候选消息,其中包含一张 5x5 的表情包拼图,一共 25 个位置。\n"
|
||||
"每张小图左上角都有一个较大的序号,范围是 1 到 25。\n"
|
||||
"你的任务是根据上下文和当前语气,从这 25 张图里选出最合适的一张表情包。\n"
|
||||
f"你会收到群聊上下文,以及 1 条额外候选消息,其中包含一张 {grid_rows}x{grid_columns} 的表情包拼图,"
|
||||
f"一共 {len(sampled_emojis)} 个位置。\n"
|
||||
f"每张小图左上角都有一个较大的序号,范围是 1 到 {len(sampled_emojis)}。\n"
|
||||
f"你的任务是根据上下文和当前语气,从这 {len(sampled_emojis)} 张图里选出最合适的一张表情包。\n"
|
||||
"如果提供了 requested_emotion,请优先考虑与其接近的候选;如果没有完全匹配,则选择最符合上下文语气的候选。\n"
|
||||
"你必须返回一个 JSON 对象(json object),不要输出任何 JSON 之外的内容。\n"
|
||||
'返回格式固定为:{"emoji_index":1,"reason":"简短理由"}'
|
||||
@@ -212,6 +246,8 @@ async def _select_emoji_with_sub_agent(
|
||||
content=(
|
||||
f"[选择任务]\n"
|
||||
f"requested_emotion: {requested_emotion or '未指定'}\n"
|
||||
f"候选总数: {len(sampled_emojis)}\n"
|
||||
f"拼图布局: {grid_rows}x{grid_columns}\n"
|
||||
"请只输出 JSON。"
|
||||
),
|
||||
timestamp=datetime.now(),
|
||||
|
||||
Reference in New Issue
Block a user