重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject

This commit is contained in:
DrSmoothl
2026-02-13 20:39:11 +08:00
parent c14736ffca
commit 16b16d2ca6
29 changed files with 2459 additions and 1737 deletions

View File

@@ -3,11 +3,11 @@ MaiBot模块系统
包含聊天、情绪、记忆、日程等功能模块
"""
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.emoji_system.emoji_manager import get_emoji_manager
# 导出主要组件供外部使用
__all__ = [
"get_chat_manager",
"get_emoji_manager",
"emoji_manager",
]

View File

@@ -1,8 +1,11 @@
import time
import asyncio
import traceback
from datetime import datetime
from typing import Optional, Dict, Any, List
from src.common.logger import get_logger
from sqlmodel import select, col
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages
from maim_message import UserInfo
from src.config.config import global_config
@@ -16,17 +19,18 @@ logger = get_logger("chat_observer")
def _message_to_dict(message: Messages) -> Dict[str, Any]:
"""Convert Peewee Message model to dict for PFC compatibility
Args:
message: Peewee Messages model instance
Returns:
Dict[str, Any]: Message dictionary
"""
message_timestamp = message.timestamp.timestamp() if isinstance(message.timestamp, datetime) else message.timestamp
return {
"message_id": message.message_id,
"time": message.time,
"chat_id": message.chat_id,
"time": message_timestamp,
"chat_id": message.session_id,
"user_id": message.user_id,
"user_nickname": message.user_nickname,
"processed_plain_text": message.processed_plain_text,
@@ -37,7 +41,7 @@ def _message_to_dict(message: Messages) -> Dict[str, Any]:
"user_info": {
"user_id": message.user_id,
"user_nickname": message.user_nickname,
}
},
}
@@ -109,10 +113,13 @@ class ChatObserver:
"""
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
new_message_exists = Messages.select().where(
(Messages.chat_id == self.stream_id) &
(Messages.time > self.last_check_time)
).exists()
last_check_time = self.last_check_time or 0.0
last_check_dt = datetime.fromtimestamp(last_check_time)
with get_db_session() as session:
statement = select(Messages).where(
(col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_check_dt)
)
new_message_exists = session.exec(statement).first() is not None
if new_message_exists:
logger.debug(f"[私聊][{self.private_name}]发现新消息")
@@ -183,20 +190,21 @@ class ChatObserver:
)
return has_new
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
"""获取新消息
Returns:
List[Dict[str, Any]]: 新消息列表
"""
query = Messages.select().where(
(Messages.chat_id == self.stream_id) &
(Messages.time > self.last_message_time)
).order_by(Messages.time.asc())
new_messages = [_message_to_dict(msg) for msg in query]
last_message_time = self.last_message_time or 0.0
last_message_dt = datetime.fromtimestamp(last_message_time)
with get_db_session() as session:
statement = (
select(Messages)
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_message_dt))
.order_by(col(Messages.timestamp))
)
new_messages = [_message_to_dict(msg) for msg in session.exec(statement).all()]
if new_messages:
self.last_message_read = new_messages[-1]
@@ -215,13 +223,16 @@ class ChatObserver:
Returns:
List[Dict[str, Any]]: 最多5条消息
"""
query = Messages.select().where(
(Messages.chat_id == self.stream_id) &
(Messages.time < time_point)
).order_by(Messages.time.desc()).limit(5)
messages = list(query)
messages.reverse() # 需要按时间正序排列
# Fetch the (at most) five messages immediately preceding time_point.
# Sort newest-first so that LIMIT keeps the rows closest to time_point;
# an ascending sort here would instead pick the five oldest messages of
# the whole conversation.
time_point_dt = datetime.fromtimestamp(time_point)
with get_db_session() as session:
    statement = (
        select(Messages)
        .where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) < time_point_dt))
        .order_by(col(Messages.timestamp).desc())
        .limit(5)
    )
    messages = list(session.exec(statement).all())
# Rows arrive newest-first; flip them back into chronological order.
messages.reverse()
new_messages = [_message_to_dict(msg) for msg in messages]
if new_messages:

View File

@@ -10,7 +10,9 @@ from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.chat_message_builder import replace_user_references
from src.common.logger import get_logger
from src.person_info.person_info import Person
from src.common.database.database_model import Images
from sqlmodel import select, col
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
if TYPE_CHECKING:
pass
@@ -47,6 +49,12 @@ class HeartFCMessageReceiver:
# 1. 消息解析与初始化
userinfo = message.message_info.user_info
chat = message.chat_stream
if userinfo is None or message.message_info.platform is None:
raise ValueError("message userinfo or platform is missing")
if userinfo.user_id is None or userinfo.user_nickname is None:
raise ValueError("message userinfo id or nickname is missing")
user_id = userinfo.user_id
nickname = userinfo.user_nickname
# 2. 计算at信息
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
@@ -70,7 +78,15 @@ class HeartFCMessageReceiver:
processed_text = message.processed_plain_text
if picid_list:
for picid in picid_list:
image = Images.get_or_none(Images.image_id == picid)
with get_db_session() as session:
statement = (
select(Images).where(
(col(Images.id) == int(picid)) & (col(Images.image_type) == ImageType.IMAGE)
)
if picid.isdigit()
else None
)
image = session.exec(statement).first() if statement is not None else None
if image and image.description:
# 将[picid:xxxx]替换成图片描述
processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]")
@@ -80,26 +96,24 @@ class HeartFCMessageReceiver:
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
processed_plain_text = replace_user_references(
processed_text,
message.message_info.platform, # type: ignore
replace_bot_name=True,
processed_text, message.message_info.platform, replace_bot_name=True
)
# if not processed_plain_text:
# print(message)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
# 如果是群聊,获取群号和群昵称
group_id = None
group_nick_name = None
if chat.group_info:
group_id = chat.group_info.group_id # type: ignore
group_nick_name = userinfo.user_cardname # type: ignore
group_id = chat.group_info.group_id
group_nick_name = userinfo.user_cardname
_ = Person.register_person(
platform=message.message_info.platform, # type: ignore
user_id=message.message_info.user_info.user_id, # type: ignore
nickname=userinfo.user_nickname, # type: ignore
platform=message.message_info.platform,
user_id=user_id,
nickname=nickname,
group_id=group_id,
group_nick_name=group_nick_name,
)

View File

@@ -1,10 +1,10 @@
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.storage import MessageStorage
__all__ = [
"get_emoji_manager",
"get_chat_manager",
"MessageStorage",
"emoji_manager",
]

View File

@@ -2,13 +2,15 @@ import asyncio
import hashlib
import time
import copy
from datetime import datetime
from typing import Dict, Optional, TYPE_CHECKING
from rich.traceback import install
from maim_message import GroupInfo, UserInfo
from sqlmodel import select, col
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
@@ -76,7 +78,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
self.context: Optional[ChatMessageContext] = None
def to_dict(self) -> dict:
"""转换为字典格式"""
@@ -95,10 +97,13 @@ class ChatStream:
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
if user_info is None:
raise ValueError("user_info is required to build ChatStream")
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info, # type: ignore
user_info=user_info,
group_info=group_info,
data=data,
)
@@ -128,12 +133,7 @@ class ChatManager:
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
get_db_session()
self._initialized = True
# 在事件循环中启动初始化
@@ -161,8 +161,13 @@ class ChatManager:
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
platform = message.message_info.platform or ""
if not platform:
raise ValueError("platform is required for ChatStream")
if message.message_info.user_info is None and message.message_info.group_info is None:
raise ValueError("user_info or group_info is required for ChatStream")
stream_id = self._generate_stream_id(
message.message_info.platform, # type: ignore
platform,
message.message_info.user_info,
message.message_info.group_info,
)
@@ -176,12 +181,18 @@ class ChatManager:
"""生成聊天流唯一ID"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
if group_info is None and user_info is None:
raise ValueError("用户信息或群组信息必须提供")
if group_info:
# 组合关键信息
components = [platform, str(group_info.group_id)]
else:
components = [platform, str(user_info.user_id), "private"] # type: ignore
if user_info is None:
raise ValueError("用户信息或群组信息必须提供")
if user_info.user_id is None:
raise ValueError("user_id is required for private stream")
components = [platform, str(user_info.user_id), "private"]
# 使用MD5生成唯一ID
key = "_".join(components)
@@ -231,33 +242,35 @@ class ChatManager:
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
with get_db_session() as session:
statement = select(ChatSession).where(col(ChatSession.session_id) == s_id)
return session.exec(statement).first()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"platform": model_instance.platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
"user_nickname": "",
"user_cardname": "",
}
group_info_data = None
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"platform": model_instance.platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
"group_name": "",
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"stream_id": model_instance.session_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
"create_time": model_instance.created_timestamp.timestamp(),
"last_active_time": model_instance.last_active_timestamp.timestamp(),
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息
@@ -329,20 +342,26 @@ class ChatManager:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
}
with get_db_session() as session:
statement = select(ChatSession).where(col(ChatSession.session_id) == s_data_dict["stream_id"])
record = session.exec(statement).first()
if record is None:
record = ChatSession(
session_id=s_data_dict["stream_id"],
platform=s_data_dict["platform"],
user_id=user_info_d["user_id"] if user_info_d else None,
group_id=group_info_d["group_id"] if group_info_d else None,
created_timestamp=datetime.fromtimestamp(s_data_dict["create_time"]),
last_active_timestamp=datetime.fromtimestamp(s_data_dict["last_active_time"]),
)
session.add(record)
return
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
record.platform = s_data_dict["platform"]
record.user_id = user_info_d["user_id"] if user_info_d else None
record.group_id = group_info_d["group_id"] if group_info_d else None
record.created_timestamp = datetime.fromtimestamp(s_data_dict["create_time"])
record.last_active_timestamp = datetime.fromtimestamp(s_data_dict["last_active_time"])
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
@@ -361,30 +380,32 @@ class ChatManager:
def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
with get_db_session() as session:
statement = select(ChatSession)
for model_instance in session.exec(statement).all():
user_info_data = {
"platform": model_instance.platform,
"user_id": model_instance.user_id or "",
"user_nickname": "",
"user_cardname": "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.platform,
"group_id": model_instance.group_id,
"group_name": "",
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
loaded_streams_data.append(data_for_from_dict)
data_for_from_dict = {
"stream_id": model_instance.session_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.created_timestamp.timestamp(),
"last_active_time": model_instance.last_active_timestamp.timestamp(),
}
loaded_streams_data.append(data_for_from_dict)
return loaded_streams_data
try:

View File

@@ -1,36 +1,74 @@
import re
import json
import traceback
from typing import Union
from datetime import datetime
from collections.abc import Mapping
from typing import cast
from src.common.database.database_model import Messages, Images
import json
import re
import traceback
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType, Messages
from src.common.logger import get_logger
from src.common.data_models.message_component_model import MessageSequence, TextComponent
from src.common.utils.utils_message import MessageUtils
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
from .message import MessageRecv, MessageSending
logger = get_logger("message_storage")
class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
def _coerce_str_list(value: object) -> list[str]:
    """Normalise a loosely-typed value into a list of strings.

    Args:
        value: A list, tuple or set of items, a single string, or any
            other object.

    Returns:
        For list/tuple/set input, every item passed through ``str()``.
        For a plain string, a one-element list holding that string.
        For anything else (including ``None``), an empty list.
    """
    # A bare string counts as one entry, not as an iterable of characters.
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple, set)):
        return [str(item) for item in value]
    return []
@staticmethod
def _get_str(mapping: Mapping[str, object], key: str, default: str = "") -> str:
    """Read ``key`` from ``mapping`` and return it as a string.

    Args:
        mapping: Mapping to read from.
        key: Key to look up.
        default: Returned when the key is absent or its value is ``None``.

    Returns:
        ``str(value)`` for a non-``None`` value, otherwise ``default``.
    """
    found = mapping.get(key)
    return default if found is None else str(found)
@staticmethod
def _get_optional_str(mapping: Mapping[str, object], key: str) -> str | None:
value = mapping.get(key)
if value is None:
return None
return str(value)
@staticmethod
def _serialize_keywords(keywords: list[str] | None) -> str:
"""将关键词列表序列化为JSON字符串"""
if isinstance(keywords, list):
return json.dumps(keywords, ensure_ascii=False)
return "[]"
@staticmethod
def _deserialize_keywords(keywords_str: str) -> list:
def _deserialize_keywords(keywords_str: str) -> list[str]:
"""将JSON字符串反序列化为关键词列表"""
if not keywords_str:
return []
try:
return json.loads(keywords_str)
parsed = cast(object, json.loads(keywords_str))
except (json.JSONDecodeError, TypeError):
return []
if isinstance(parsed, list):
return [str(item) for item in parsed]
if isinstance(parsed, str):
return [parsed]
return []
@staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 通知消息不存储
@@ -66,7 +104,7 @@ class MessageStorage:
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_picture = False
is_notify = False
is_command = False
key_words = ""
@@ -83,66 +121,73 @@ class MessageStorage:
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_picture = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
intercept_message_level = getattr(message, "intercept_message_level", 0)
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
key_words = MessageStorage._serialize_keywords(MessageStorage._coerce_str_list(message.key_words))
key_words_lite = MessageStorage._serialize_keywords(
MessageStorage._coerce_str_list(message.key_words_lite)
)
selected_expressions = ""
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
chat_info_dict = cast(dict[str, object], chat_stream.to_dict())
if message.message_info.user_info is None:
raise ValueError("message.user_info is required")
user_info_dict = cast(dict[str, object], message.message_info.user_info.to_dict())
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
msg_id = message.message_info.message_id or ""
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
group_info_from_chat = cast(dict[str, object], chat_info_dict.get("group_info") or {})
Messages.create(
message_id=msg_id,
time=float(message.message_info.time), # type: ignore
chat_id=chat_stream.stream_id,
# Flattened chat_info
additional_config: dict[str, object] = dict(message.message_info.additional_config or {})
additional_config.update(
{
"interest_value": interest_value,
"priority_mode": priority_mode,
"priority_info": priority_info,
"reply_probability_boost": reply_probability_boost,
"intercept_message_level": intercept_message_level,
"key_words": key_words,
"key_words_lite": key_words_lite,
"selected_expressions": selected_expressions,
"is_picid": is_picture,
}
)
processed_text_for_raw = filtered_processed_plain_text or filtered_display_message or ""
raw_sequence = MessageSequence([TextComponent(processed_text_for_raw)] if processed_text_for_raw else [])
raw_content = MessageUtils.from_MaiSeq_to_db_record_msg(raw_sequence)
timestamp_value = message.message_info.time
if timestamp_value is None:
raise ValueError("message.message_info.time is required")
db_message = Messages(
message_id=str(msg_id),
timestamp=datetime.fromtimestamp(float(timestamp_value)),
platform=MessageStorage._get_str(chat_info_dict, "platform"),
user_id=MessageStorage._get_str(user_info_dict, "user_id"),
user_nickname=MessageStorage._get_str(user_info_dict, "user_nickname"),
user_cardname=MessageStorage._get_optional_str(user_info_dict, "user_cardname"),
group_id=MessageStorage._get_optional_str(group_info_from_chat, "group_id"),
group_name=MessageStorage._get_optional_str(group_info_from_chat, "group_name"),
is_mentioned=bool(is_mentioned),
is_at=bool(is_at),
session_id=chat_stream.stream_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
is_at=is_at,
reply_probability_boost=reply_probability_boost,
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
# Flattened user_info (message sender)
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
# Text content
is_emoji=is_emoji,
is_picture=is_picture,
is_command=is_command,
is_notify=is_notify,
raw_content=raw_content,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
intercept_message_level=intercept_message_level,
key_words=key_words,
key_words_lite=key_words_lite,
selected_expressions=selected_expressions,
additional_config=json.dumps(additional_config, ensure_ascii=False),
)
with get_db_session() as session:
session.add(db_message)
except Exception:
logger.exception("存储消息失败")
logger.error(f"消息:{message}")
@@ -156,16 +201,21 @@ class MessageStorage:
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return False
if matched_message := (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
):
# 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
return True
else:
logger.debug("未找到匹配的消息")
return False
with get_db_session() as session:
statement = (
select(Messages)
.where(col(Messages.message_id) == mmc_message_id)
.order_by(col(Messages.timestamp).desc())
.limit(1)
)
matched_message = session.exec(statement).first()
if matched_message:
matched_message.message_id = qq_message_id
session.add(matched_message)
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
return True
logger.debug("未找到匹配的消息")
return False
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
@@ -182,13 +232,18 @@ class MessageStorage:
logger.debug("文本中没有图片标记,直接返回原文本")
return text
def replace_match(match):
def replace_match(match: re.Match[str]) -> str:
description = match.group(1).strip()
try:
image_record = (
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
)
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
with get_db_session() as session:
statement = (
select(Images)
.where((col(Images.description) == description) & (col(Images.image_type) == ImageType.IMAGE))
.order_by(col(Images.record_time).desc())
.limit(1)
)
image_record = session.exec(statement).first()
return f"[picid:{image_record.id}]" if image_record else match.group(0)
except Exception:
return match.group(0)

View File

@@ -1,17 +1,17 @@
import time
import random
import re
from datetime import datetime
from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from sqlmodel import select, col
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords
from src.common.data_models.message_data_model import MessageAndActionModel
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.common.database.database import get_db_session
from src.common.database.database_model import ActionRecord, Images
from src.person_info.person_info import Person, get_person_id
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self
@@ -198,37 +198,38 @@ def get_actions_by_timestamp_with_chat(
limit_mode: str = "latest",
) -> List[DatabaseActionRecords]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time > timestamp_start) # type: ignore
& (ActionRecords.time < timestamp_end) # type: ignore
)
with get_db_session() as session:
statement = (
select(ActionRecord)
.where((col(ActionRecord.session_id) == chat_id))
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(timestamp_start))
.where(col(ActionRecord.timestamp) < datetime.fromtimestamp(timestamp_end))
)
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
actions.reverse()
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
else:
query = query.order_by(ActionRecords.time.asc())
actions = list(query)
if limit > 0:
if limit_mode == "latest":
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
actions = list(session.exec(statement).all())
actions = list(reversed(actions))
else:
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
actions = list(session.exec(statement).all())
else:
statement = statement.order_by(col(ActionRecord.timestamp))
actions = session.exec(statement).all()
return [
DatabaseActionRecords(
action_id=action.action_id,
time=action.time,
time=action.timestamp.timestamp(),
action_name=action.action_name,
action_data=action.action_data,
action_done=action.action_done,
action_build_into_prompt=action.action_build_into_prompt,
action_prompt_display=action.action_prompt_display,
chat_id=action.chat_id,
chat_info_stream_id=action.chat_info_stream_id,
chat_info_platform=action.chat_info_platform,
action_reasoning=action.action_reasoning,
action_data=action.action_data or "{}",
action_done=True,
action_build_into_prompt=bool(action.action_display_prompt),
action_prompt_display=action.action_display_prompt or "",
chat_id=action.session_id,
chat_info_stream_id=action.session_id,
chat_info_platform=global_config.bot.platform,
action_reasoning=action.action_reasoning or "",
)
for action in actions
]
@@ -238,25 +239,27 @@ def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time >= timestamp_start) # type: ignore
& (ActionRecords.time <= timestamp_end) # type: ignore
)
with get_db_session() as session:
statement = (
select(ActionRecord)
.where((col(ActionRecord.session_id) == chat_id))
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
)
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
return [action.__data__ for action in reversed(actions)]
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
else:
query = query.order_by(ActionRecords.time.asc())
if limit > 0:
if limit_mode == "latest":
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
actions = list(session.exec(statement).all())
actions = list(reversed(actions))
else:
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
actions = list(session.exec(statement).all())
else:
statement = statement.order_by(col(ActionRecord.timestamp))
actions = session.exec(statement).all()
actions = list(query)
return [action.__data__ for action in actions]
return [action.model_dump() for action in actions]
def get_raw_msg_by_timestamp_random(
@@ -278,7 +281,7 @@ def get_raw_msg_by_timestamp_random(
def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@@ -316,7 +319,7 @@ def get_raw_msg_before_timestamp_with_chat(
def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@@ -344,7 +347,7 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
def num_new_messages_since_with_users(
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str]
) -> int:
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
if not person_ids: # 保持空列表检查
@@ -358,7 +361,7 @@ def num_new_messages_since_with_users(
def _build_readable_messages_internal(
messages: List[MessageAndActionModel],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
truncate: bool = False,
@@ -413,7 +416,7 @@ def _build_readable_messages_internal(
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match: re.Match) -> str:
def replace_pic_id(match: re.Match[str]) -> str:
nonlocal current_pic_counter
nonlocal pic_counter
pic_id = match.group(1)
@@ -421,7 +424,8 @@ def _build_readable_messages_internal(
if pic_id not in pic_description_cache:
description = "内容正在阅读,请稍等"
try:
image = Images.get_or_none(Images.image_id == pic_id)
with get_db_session() as session:
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
if image and image.description:
description = image.description
except Exception:
@@ -438,16 +442,11 @@ def _build_readable_messages_internal(
# 1: 获取发送者信息并提取消息组件
for message in messages:
if message.is_action_record:
# 对于动作记录也处理图片ID
content = process_pic_ids(message.display_message)
detailed_messages_raw.append((message.time, message.user_nickname, content, True))
continue
platform = message.user_platform
user_id = message.user_id
user_nickname = message.user_nickname
user_cardname = message.user_cardname
user_info = message.user_info
platform = user_info.platform
user_id = user_info.user_id
user_nickname = user_info.user_nickname
user_cardname = user_info.user_cardname
timestamp = message.time
content = message.display_message or message.processed_plain_text or ""
@@ -525,12 +524,12 @@ def _build_readable_messages_internal(
if long_time_notice and prev_timestamp is not None:
time_diff = timestamp - prev_timestamp
time_diff_hours = time_diff / 3600
# 检查是否跨天
prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_timestamp))
current_date = time.strftime("%Y-%m-%d", time.localtime(timestamp))
is_cross_day = prev_date != current_date
# 如果间隔大于8小时或跨天插入提示
if time_diff_hours > 8 or is_cross_day:
# 格式化日期为中文格式xxxx年xx月xx日去掉前导零
@@ -542,20 +541,15 @@ def _build_readable_messages_internal(
hours_str = f"{int(time_diff_hours)}h"
notice = f"以下聊天开始时间:{date_str}。距离上一条消息过去了{hours_str}\n"
output_lines.append(notice)
readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)
# 查找消息id如果有并构建id_prefix
message_id = timestamp_to_id_mapping.get(timestamp, "")
id_prefix = f"[{message_id}]" if message_id else ""
if is_action:
# 对于动作记录,使用特殊格式
output_lines.append(f"{id_prefix}{readable_time}, {content}")
else:
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
output_lines.append("\n")
prev_timestamp = timestamp
formatted_string = "".join(output_lines).strip()
@@ -592,7 +586,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "内容正在阅读,请稍等"
try:
image = Images.get_or_none(Images.image_id == pic_id)
with get_db_session() as session:
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
if image and image.description:
description = image.description
except Exception:
@@ -663,7 +658,7 @@ async def build_readable_messages_with_list(
允许通过参数控制格式化行为。
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
[MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages],
messages,
replace_bot_name,
timestamp_mode,
truncate,
@@ -754,7 +749,7 @@ def build_readable_messages(
filtered_messages = []
for msg in messages:
# 获取消息内容
content = msg.processed_plain_text
content = msg.processed_plain_text or ""
# 移除表情包
emoji_pattern = r"\[表情包:[^\]]+\]"
content = re.sub(emoji_pattern, "", content)
@@ -765,17 +760,14 @@ def build_readable_messages(
messages = filtered_messages
copy_messages: List[MessageAndActionModel] = []
copy_messages: List[DatabaseMessages] = []
for msg in messages:
if remove_emoji_stickers:
# 创建 MessageAndActionModel 但移除表情包
model = MessageAndActionModel.from_DatabaseMessages(msg)
# 移除表情包
if model.processed_plain_text:
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
copy_messages.append(model)
msg.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", msg.processed_plain_text or "")
copy_messages.append(msg)
else:
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
copy_messages.append(msg)
if show_actions and copy_messages:
# 获取所有消息的时间范围
@@ -786,40 +778,45 @@ def build_readable_messages(
chat_id = messages[0].chat_id if messages else None
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (
ActionRecords.select()
.where(
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
)
.order_by(ActionRecords.time)
)
with get_db_session() as session:
actions_in_range = session.exec(
select(ActionRecord)
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(min_time))
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(max_time))
.where(col(ActionRecord.session_id) == chat_id)
.order_by(col(ActionRecord.timestamp))
).all()
# 获取最新消息之后的第一个动作记录
action_after_latest = (
ActionRecords.select()
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
with get_db_session() as session:
action_after_latest = session.exec(
select(ActionRecord)
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(max_time))
.where(col(ActionRecord.session_id) == chat_id)
.order_by(col(ActionRecord.timestamp))
.limit(1)
).all()
# 合并两部分动作记录
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
actions: List[ActionRecord] = list(actions_in_range) + list(action_after_latest)
# 将动作记录转换为消息格式
for action in actions:
# 只有当build_into_prompt为True时才添加动作记录
if action.action_build_into_prompt:
action_msg = MessageAndActionModel(
time=float(action.time), # type: ignore
user_id=global_config.bot.qq_account, # 使用机器人的QQ账号
user_platform=global_config.bot.platform, # 使用机器人的平台
user_nickname=global_config.bot.nickname, # 使用机器人的用户名
user_cardname="", # 机器人没有群名片
processed_plain_text=f"{action.action_prompt_display}",
display_message=f"{action.action_prompt_display}",
chat_info_platform=str(action.chat_info_platform),
is_action_record=True, # 添加标识字段
action_name=str(action.action_name), # 保存动作名称
action_display_prompt = action.action_display_prompt or ""
if action_display_prompt:
action_msg = DatabaseMessages(
message_id=f"action_{action.action_id}",
time=float(action.timestamp.timestamp()),
chat_id=chat_id or "",
processed_plain_text=action_display_prompt,
display_message=action_display_prompt,
user_platform=global_config.bot.platform,
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
user_cardname="",
chat_info_platform=str(global_config.bot.platform),
chat_info_stream_id=chat_id or "",
)
copy_messages.append(action_msg)
@@ -1026,17 +1023,13 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set = set() # 使用集合来自动去重
for msg in messages:
platform: str = msg.get("user_platform") # type: ignore
user_id: str = msg.get("user_id") # type: ignore
platform = msg.get("user_platform") or ""
user_id = msg.get("user_id") or ""
# 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
# 添加空值检查,防止 platform 为 None 时出错
if platform is None:
platform = "unknown"
if person_id := get_person_id(platform, user_id):
person_ids_set.add(person_id)

File diff suppressed because it is too large Load Diff

View File

@@ -1,18 +1,21 @@
import base64
from datetime import datetime
from typing import Optional, Tuple
import hashlib
import io
import os
import time
import hashlib
import uuid
import io
import numpy as np
from typing import Optional, Tuple
import numpy as np
from PIL import Image
from rich.traceback import install
from sqlmodel import select, col
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions, EmojiDescriptionCache
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -38,11 +41,7 @@ class ImageManager:
self._initialized = True
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions, EmojiDescriptionCache], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
get_db_session()
try:
self._cleanup_invalid_descriptions()
@@ -72,10 +71,12 @@ class ImageManager:
Optional[str]: 描述文本,如果不存在则返回None
"""
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
)
record = session.exec(statement).first()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@@ -90,15 +91,27 @@ class ImageManager:
description_type: 描述类型 ('emoji' 或 'image')
"""
try:
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
image_description_hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
)
record = session.exec(statement).first()
if record:
record.description = description
session.add(record)
return
new_record = Images(
image_hash=image_hash,
description=description,
full_path="",
image_type=ImageType(description_type),
query_count=0,
is_registered=False,
is_banned=False,
vlm_processed=True,
)
session.add(new_record)
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
@@ -107,20 +120,18 @@ class ImageManager:
"""清理数据库中 description 为空或为 'None' 的记录"""
invalid_values = ["", "None"]
# 清理 Images 表
deleted_images = (
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
)
with get_db_session() as session:
statement = (
select(Images)
.where(col(Images.description).is_(None) | col(Images.description).in_(invalid_values))
.limit(1000)
)
records = session.exec(statement).all()
for record in records:
session.delete(record)
# 清理 ImageDescriptions 表
deleted_descriptions = (
ImageDescriptions.delete()
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
.execute()
)
if deleted_images or deleted_descriptions:
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}")
if records:
logger.info(f"[清理完成] 删除 Images: {len(records)}")
else:
logger.info("[清理完成] 未发现无效描述记录")
@@ -128,19 +139,15 @@ class ImageManager:
def _cleanup_emoji_from_image_descriptions():
"""清理Images和ImageDescriptions表中type为emoji的记录已迁移到EmojiDescriptionCache"""
try:
# 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute()
with get_db_session() as session:
statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
records = session.exec(statement).all()
for record in records:
session.delete(record)
# 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
total_deleted = deleted_images + deleted_descriptions
total_deleted = len(records)
if total_deleted > 0:
logger.info(
f"[清理完成] 从Images表中删除 {deleted_images} 条emoji类型记录, "
f"从ImageDescriptions表中删除 {deleted_descriptions} 条emoji类型记录, "
f"共删除 {total_deleted} 条记录"
)
logger.info(f"[清理完成] 从Images表中删除 {total_deleted} 条emoji类型记录")
else:
logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录")
except Exception as e:
@@ -148,14 +155,14 @@ class ImageManager:
raise
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
emoji_manager = get_emoji_manager()
emoji_manager = emoji_manager_instance
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
emoji = await emoji_manager.get_emoji_from_manager(image_hash)
emoji = emoji_manager.get_emoji_by_hash(image_hash)
if not emoji:
return "[表情包:未知]"
emotion_list = emoji.emotion
@@ -175,14 +182,14 @@ class ImageManager:
try:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
# 确保目录存在
os.makedirs(EMOJI_DIR, exist_ok=True)
# 检查是否已存在该表情包(通过哈希值)
emoji_manager = get_emoji_manager()
existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash)
emoji_manager = emoji_manager_instance
existing_emoji = emoji_manager.get_emoji_by_hash(image_hash)
if existing_emoji:
logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
return
@@ -212,14 +219,15 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
emoji_manager = get_emoji_manager()
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
emoji_manager = emoji_manager_instance
emoji = emoji_manager.get_emoji_by_hash(image_hash)
tags = emoji.emotion if emoji else None
if tags:
tag_str = ",".join(tags)
logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
@@ -227,29 +235,26 @@ class ImageManager:
except Exception as e:
logger.debug(f"查询EmojiManager时出错: {e}")
# 查询EmojiDescriptionCache表的缓存包含描述和情感标签
try:
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record:
# 优先使用情感标签,如果没有则使用详细描述
result_text = ""
if cache_record.emotion_tags:
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]"
if cache_record.emotion:
logger.info(f"[缓存命中] 使用Images表中的情感标签: {cache_record.emotion[:50]}...")
result_text = f"[表情包:{cache_record.emotion}]"
elif cache_record.description:
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
logger.info(f"[缓存命中] 使用Images表中的描述: {cache_record.description[:50]}...")
result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件
if result_text:
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
return result_text
except Exception as e:
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
logger.debug(f"查询Images缓存时出错: {e}")
# === 二步走识别流程 ===
@@ -309,33 +314,42 @@ class ImageManager:
logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
# 再次检查缓存(防止并发情况下其他线程已经保存)
try:
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
if cache_record and cache_record.emotion_tags:
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion_tags}")
return f"[表情包:{cache_record.emotion_tags}]"
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record and cache_record.emotion:
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion}")
return f"[表情包:{cache_record.emotion}]"
except Exception as e:
logger.debug(f"再次查询EmojiDescriptionCache时出错: {e}")
logger.debug(f"再次查询Images缓存时出错: {e}")
# 保存识别出的详细描述和情感标签到 emoji_description_cache
try:
current_timestamp = time.time()
cache_record, created = EmojiDescriptionCache.get_or_create(
emoji_hash=image_hash,
defaults={
"description": detailed_description,
"emotion_tags": final_emotion,
"timestamp": current_timestamp,
},
)
if not created:
# 更新已有记录
cache_record.description = detailed_description
cache_record.emotion_tags = final_emotion
cache_record.timestamp = current_timestamp
cache_record.save()
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到EmojiDescriptionCache: {image_hash[:8]}...")
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record:
cache_record.description = detailed_description
cache_record.emotion = final_emotion
session.add(cache_record)
else:
cache_record = Images(
image_hash=image_hash,
description=detailed_description,
full_path="",
image_type=ImageType.EMOJI,
emotion=final_emotion,
query_count=0,
is_registered=False,
is_banned=False,
vlm_processed=True,
)
session.add(cache_record)
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到Images: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
@@ -358,14 +372,13 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
with get_db_session() as session:
statement = select(Images).where(col(Images.image_hash) == image_hash)
existing_image = session.exec(statement).first()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
existing_image.save()
existing_image.query_count += 1
with get_db_session() as session:
session.add(existing_image)
# 如果已有描述,直接返回
if existing_image.description:
@@ -377,7 +390,7 @@ class ImageManager:
return f"[图片:{cached_description}]"
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
prompt = global_config.personality.visual_style
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
@@ -402,26 +415,27 @@ class ImageManager:
# 保存到数据库,补充缺失字段
if existing_image:
existing_image.path = file_path
existing_image.full_path = file_path
existing_image.description = description
existing_image.timestamp = current_timestamp
if not hasattr(existing_image, "image_id") or not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
existing_image.save()
existing_image.record_time = datetime.fromtimestamp(current_timestamp)
existing_image.vlm_processed = True
with get_db_session() as session:
session.add(existing_image)
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
Images.create(
image_id=str(uuid.uuid4()),
emoji_hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
vlm_processed=True,
count=1,
)
with get_db_session() as session:
new_record = Images(
image_hash=image_hash,
description=description,
full_path=file_path,
image_type=ImageType.IMAGE,
query_count=1,
is_registered=False,
is_banned=False,
record_time=datetime.fromtimestamp(current_timestamp),
vlm_processed=True,
)
session.add(new_record)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
@@ -575,30 +589,17 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.IMAGE)
)
existing_image = session.exec(statement).first()
if existing_image:
existing_image.query_count += 1
session.add(existing_image)
return str(existing_image.id), f"[picid:{existing_image.id}]"
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
image_id = str(uuid.uuid4())
# 保存新图片
current_timestamp = time.time()
@@ -612,15 +613,19 @@ class ImageManager:
f.write(image_bytes)
# 保存到数据库
Images.create(
image_id=image_id,
emoji_hash=image_hash,
path=file_path,
type="image",
timestamp=current_timestamp,
vlm_processed=False,
count=1,
)
with get_db_session() as session:
new_record = Images(
image_hash=image_hash,
description="",
full_path=file_path,
image_type=ImageType.IMAGE,
query_count=1,
is_registered=False,
is_banned=False,
record_time=datetime.fromtimestamp(current_timestamp),
vlm_processed=False,
)
session.add(new_record)
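# Start VLM processing for the newly stored image (awaited here).
# NOTE(review): image_id is a UUID string at this point, while the record lookup further down this file
# uses session.get(Images, int(image_id)) only when image_id.isdigit(). A UUID never satisfies that check,
# so the row saved above would not be found -- confirm, and pass the new row's integer primary key instead.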
await self._process_image_with_vlm(image_id, image_base64)
@@ -647,17 +652,26 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 获取当前图片记录
image = Images.get(Images.image_id == image_id)
with get_db_session() as session:
image = session.get(Images, int(image_id)) if image_id.isdigit() else None
if image is None:
logger.warning(f"未找到图片记录: {image_id}")
return
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
)
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash)
& (col(Images.description).is_not(None))
& (col(Images.description) != "")
)
existing_with_description = session.exec(statement).first()
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
# Also persist the description through _save_description_to_db as a fallback cache, keyed by image hash and type.
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
@@ -667,11 +681,12 @@ class ImageManager:
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
return
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
# 构建prompt
prompt = global_config.personality.visual_style
@@ -692,7 +707,8 @@ class ImageManager:
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
# Persist the description through _save_description_to_db as a fallback cache, keyed by image hash and type.
self._save_description_to_db(image_hash, description, "image")