Files
mai-bot/src/maisaka/chat_history_visual_refresher.py

139 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Maisaka 聊天历史视觉占位刷新器。"""
from typing import Awaitable, Callable, Optional
from sqlmodel import select
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import EmojiComponent, ForwardNodeComponent, ImageComponent
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from .context_messages import LLMContextMessage, SessionBackedMessage
logger = get_logger("maisaka_chat_history_visual_refresher")
BuildHistoryMessage = Callable[[SessionMessage, str], Awaitable[Optional[LLMContextMessage]]]
BuildVisibleText = Callable[[SessionMessage], str]
async def refresh_chat_history_visual_placeholders(
*,
chat_history: list[LLMContextMessage],
build_history_message: BuildHistoryMessage,
build_visible_text: BuildVisibleText,
) -> int:
"""在进入新一轮规划前,尝试用已完成的识图结果刷新历史占位。"""
refreshed_count = 0
for index, history_message in enumerate(chat_history):
if not isinstance(history_message, SessionBackedMessage):
continue
original_message = history_message.original_message
if original_message is None:
continue
visual_components_updated = _refresh_pending_visual_components(original_message.raw_message.components)
if visual_components_updated:
await original_message.process(
enable_heavy_media_analysis=False,
enable_voice_transcription=False,
)
refreshed_visible_text = build_visible_text(original_message)
if not visual_components_updated and refreshed_visible_text == history_message.visible_text:
continue
rebuilt_history_message = await build_history_message(original_message, history_message.source_kind)
if rebuilt_history_message is None:
continue
chat_history[index] = rebuilt_history_message
refreshed_count += 1
return refreshed_count
def _refresh_pending_visual_components(components: list[object]) -> bool:
"""用缓存中的描述更新尚未补全文本的图片与表情组件。"""
refreshed = False
for component in components:
if isinstance(component, ImageComponent):
if _should_refresh_image_component(component):
image_description = _lookup_cached_image_description(component.binary_hash)
if image_description:
component.content = f"[图片:{image_description}]"
refreshed = True
continue
if isinstance(component, EmojiComponent):
if _should_refresh_emoji_component(component):
emoji_description = _lookup_cached_emoji_description(component.binary_hash)
if emoji_description:
component.content = f"[表情包: {emoji_description}]"
refreshed = True
continue
if not isinstance(component, ForwardNodeComponent):
continue
for forward_component in component.forward_components:
if _refresh_pending_visual_components(forward_component.content):
refreshed = True
return refreshed
def _should_refresh_image_component(component: ImageComponent) -> bool:
"""判断图片组件当前是否仍处于待补全文本的占位状态。"""
return not component.content or component.content == "[图片]"
def _should_refresh_emoji_component(component: EmojiComponent) -> bool:
"""判断表情组件当前是否仍处于待补全文本的占位状态。"""
return not component.content or component.content == "[表情包]"
def _lookup_cached_image_description(image_hash: str) -> str:
"""从数据库读取已完成的图片描述,不触发新的识图请求。"""
if not image_hash:
return ""
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
if image_record := session.exec(statement).first():
if image_record.no_file_flag:
return ""
if image_record.vlm_processed and image_record.description:
return str(image_record.description).strip()
except Exception as exc:
logger.warning(f"读取图片缓存描述失败image_hash={image_hash}: {exc}")
return ""
def _lookup_cached_emoji_description(emoji_hash: str) -> str:
"""从数据库读取已完成的表情描述,不触发新的识别请求。"""
if not emoji_hash:
return ""
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
if image_record := session.exec(statement).first():
if image_record.no_file_flag or not image_record.description:
return ""
return str(image_record.description).strip()
except Exception as exc:
logger.warning(f"读取表情缓存描述失败emoji_hash={emoji_hash}: {exc}")
return ""