重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject
This commit is contained in:
@@ -3,11 +3,11 @@ MaiBot模块系统
|
||||
包含聊天、情绪、记忆、日程等功能模块
|
||||
"""
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
# 导出主要组件供外部使用
|
||||
__all__ = [
|
||||
"get_chat_manager",
|
||||
"get_emoji_manager",
|
||||
"emoji_manager",
|
||||
]
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.common.logger import get_logger
|
||||
from sqlmodel import select, col
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages
|
||||
from maim_message import UserInfo
|
||||
from src.config.config import global_config
|
||||
@@ -16,17 +19,18 @@ logger = get_logger("chat_observer")
|
||||
|
||||
def _message_to_dict(message: Messages) -> Dict[str, Any]:
|
||||
"""Convert Peewee Message model to dict for PFC compatibility
|
||||
|
||||
|
||||
Args:
|
||||
message: Peewee Messages model instance
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Message dictionary
|
||||
"""
|
||||
message_timestamp = message.timestamp.timestamp() if isinstance(message.timestamp, datetime) else message.timestamp
|
||||
return {
|
||||
"message_id": message.message_id,
|
||||
"time": message.time,
|
||||
"chat_id": message.chat_id,
|
||||
"time": message_timestamp,
|
||||
"chat_id": message.session_id,
|
||||
"user_id": message.user_id,
|
||||
"user_nickname": message.user_nickname,
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
@@ -37,7 +41,7 @@ def _message_to_dict(message: Messages) -> Dict[str, Any]:
|
||||
"user_info": {
|
||||
"user_id": message.user_id,
|
||||
"user_nickname": message.user_nickname,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -109,10 +113,13 @@ class ChatObserver:
|
||||
"""
|
||||
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||
|
||||
new_message_exists = Messages.select().where(
|
||||
(Messages.chat_id == self.stream_id) &
|
||||
(Messages.time > self.last_check_time)
|
||||
).exists()
|
||||
last_check_time = self.last_check_time or 0.0
|
||||
last_check_dt = datetime.fromtimestamp(last_check_time)
|
||||
with get_db_session() as session:
|
||||
statement = select(Messages).where(
|
||||
(col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_check_dt)
|
||||
)
|
||||
new_message_exists = session.exec(statement).first() is not None
|
||||
|
||||
if new_message_exists:
|
||||
logger.debug(f"[私聊][{self.private_name}]发现新消息")
|
||||
@@ -183,20 +190,21 @@ class ChatObserver:
|
||||
)
|
||||
return has_new
|
||||
|
||||
|
||||
|
||||
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
||||
"""获取新消息
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 新消息列表
|
||||
"""
|
||||
query = Messages.select().where(
|
||||
(Messages.chat_id == self.stream_id) &
|
||||
(Messages.time > self.last_message_time)
|
||||
).order_by(Messages.time.asc())
|
||||
|
||||
new_messages = [_message_to_dict(msg) for msg in query]
|
||||
last_message_time = self.last_message_time or 0.0
|
||||
last_message_dt = datetime.fromtimestamp(last_message_time)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_message_dt))
|
||||
.order_by(col(Messages.timestamp))
|
||||
)
|
||||
new_messages = [_message_to_dict(msg) for msg in session.exec(statement).all()]
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]
|
||||
@@ -215,13 +223,16 @@ class ChatObserver:
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 最多5条消息
|
||||
"""
|
||||
query = Messages.select().where(
|
||||
(Messages.chat_id == self.stream_id) &
|
||||
(Messages.time < time_point)
|
||||
).order_by(Messages.time.desc()).limit(5)
|
||||
|
||||
messages = list(query)
|
||||
messages.reverse() # 需要按时间正序排列
|
||||
time_point_dt = datetime.fromtimestamp(time_point)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) < time_point_dt))
|
||||
.order_by(col(Messages.timestamp))
|
||||
.limit(5)
|
||||
)
|
||||
messages = list(session.exec(statement).all())
|
||||
messages.reverse()
|
||||
new_messages = [_message_to_dict(msg) for msg in messages]
|
||||
|
||||
if new_messages:
|
||||
|
||||
@@ -10,7 +10,9 @@ from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.chat_message_builder import replace_user_references
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import Person
|
||||
from src.common.database.database_model import Images
|
||||
from sqlmodel import select, col
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -47,6 +49,12 @@ class HeartFCMessageReceiver:
|
||||
# 1. 消息解析与初始化
|
||||
userinfo = message.message_info.user_info
|
||||
chat = message.chat_stream
|
||||
if userinfo is None or message.message_info.platform is None:
|
||||
raise ValueError("message userinfo or platform is missing")
|
||||
if userinfo.user_id is None or userinfo.user_nickname is None:
|
||||
raise ValueError("message userinfo id or nickname is missing")
|
||||
user_id = userinfo.user_id
|
||||
nickname = userinfo.user_nickname
|
||||
|
||||
# 2. 计算at信息
|
||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
@@ -70,7 +78,15 @@ class HeartFCMessageReceiver:
|
||||
processed_text = message.processed_plain_text
|
||||
if picid_list:
|
||||
for picid in picid_list:
|
||||
image = Images.get_or_none(Images.image_id == picid)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Images).where(
|
||||
(col(Images.id) == int(picid)) & (col(Images.image_type) == ImageType.IMAGE)
|
||||
)
|
||||
if picid.isdigit()
|
||||
else None
|
||||
)
|
||||
image = session.exec(statement).first() if statement is not None else None
|
||||
if image and image.description:
|
||||
# 将[picid:xxxx]替换成图片描述
|
||||
processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]")
|
||||
@@ -80,26 +96,24 @@ class HeartFCMessageReceiver:
|
||||
|
||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||
processed_plain_text = replace_user_references(
|
||||
processed_text,
|
||||
message.message_info.platform, # type: ignore
|
||||
replace_bot_name=True,
|
||||
processed_text, message.message_info.platform, replace_bot_name=True
|
||||
)
|
||||
# if not processed_plain_text:
|
||||
# print(message)
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
|
||||
|
||||
# 如果是群聊,获取群号和群昵称
|
||||
group_id = None
|
||||
group_nick_name = None
|
||||
if chat.group_info:
|
||||
group_id = chat.group_info.group_id # type: ignore
|
||||
group_nick_name = userinfo.user_cardname # type: ignore
|
||||
group_id = chat.group_info.group_id
|
||||
group_nick_name = userinfo.user_cardname
|
||||
|
||||
_ = Person.register_person(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_id=message.message_info.user_info.user_id, # type: ignore
|
||||
nickname=userinfo.user_nickname, # type: ignore
|
||||
platform=message.message_info.platform,
|
||||
user_id=user_id,
|
||||
nickname=nickname,
|
||||
group_id=group_id,
|
||||
group_nick_name=group_nick_name,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_emoji_manager",
|
||||
"get_chat_manager",
|
||||
"MessageStorage",
|
||||
"emoji_manager",
|
||||
]
|
||||
|
||||
@@ -2,13 +2,15 @@ import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from sqlmodel import select, col
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import ChatStreams # 新增导入
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ChatSession
|
||||
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
if TYPE_CHECKING:
|
||||
@@ -76,7 +78,7 @@ class ChatStream:
|
||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||
self.saved = False
|
||||
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
||||
self.context: Optional[ChatMessageContext] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
@@ -95,10 +97,13 @@ class ChatStream:
|
||||
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||
|
||||
if user_info is None:
|
||||
raise ValueError("user_info is required to build ChatStream")
|
||||
|
||||
return cls(
|
||||
stream_id=data["stream_id"],
|
||||
platform=data["platform"],
|
||||
user_info=user_info, # type: ignore
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
data=data,
|
||||
)
|
||||
@@ -128,12 +133,7 @@ class ChatManager:
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 确保 ChatStreams 表存在
|
||||
db.create_tables([ChatStreams], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
|
||||
get_db_session()
|
||||
|
||||
self._initialized = True
|
||||
# 在事件循环中启动初始化
|
||||
@@ -161,8 +161,13 @@ class ChatManager:
|
||||
|
||||
def register_message(self, message: "MessageRecv"):
|
||||
"""注册消息到聊天流"""
|
||||
platform = message.message_info.platform or ""
|
||||
if not platform:
|
||||
raise ValueError("platform is required for ChatStream")
|
||||
if message.message_info.user_info is None and message.message_info.group_info is None:
|
||||
raise ValueError("user_info or group_info is required for ChatStream")
|
||||
stream_id = self._generate_stream_id(
|
||||
message.message_info.platform, # type: ignore
|
||||
platform,
|
||||
message.message_info.user_info,
|
||||
message.message_info.group_info,
|
||||
)
|
||||
@@ -176,12 +181,18 @@ class ChatManager:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
if group_info is None and user_info is None:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
|
||||
if group_info:
|
||||
# 组合关键信息
|
||||
components = [platform, str(group_info.group_id)]
|
||||
else:
|
||||
components = [platform, str(user_info.user_id), "private"] # type: ignore
|
||||
if user_info is None:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
if user_info.user_id is None:
|
||||
raise ValueError("user_id is required for private stream")
|
||||
components = [platform, str(user_info.user_id), "private"]
|
||||
|
||||
# 使用MD5生成唯一ID
|
||||
key = "_".join(components)
|
||||
@@ -231,33 +242,35 @@ class ChatManager:
|
||||
|
||||
# 检查数据库中是否存在
|
||||
def _db_find_stream_sync(s_id: str):
|
||||
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(ChatSession).where(col(ChatSession.session_id) == s_id)
|
||||
return session.exec(statement).first()
|
||||
|
||||
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||
|
||||
if model_instance:
|
||||
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"platform": model_instance.platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
"user_nickname": "",
|
||||
"user_cardname": "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
|
||||
if model_instance.group_id:
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"platform": model_instance.platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
"group_name": "",
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"stream_id": model_instance.session_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
"create_time": model_instance.created_timestamp.timestamp(),
|
||||
"last_active_time": model_instance.last_active_timestamp.timestamp(),
|
||||
}
|
||||
stream = ChatStream.from_dict(data_for_from_dict)
|
||||
# 更新用户信息和群组信息
|
||||
@@ -329,20 +342,26 @@ class ChatManager:
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
group_info_d = s_data_dict.get("group_info")
|
||||
|
||||
fields_to_save = {
|
||||
"platform": s_data_dict["platform"],
|
||||
"create_time": s_data_dict["create_time"],
|
||||
"last_active_time": s_data_dict["last_active_time"],
|
||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||
"group_platform": group_info_d["platform"] if group_info_d else "",
|
||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
}
|
||||
with get_db_session() as session:
|
||||
statement = select(ChatSession).where(col(ChatSession.session_id) == s_data_dict["stream_id"])
|
||||
record = session.exec(statement).first()
|
||||
if record is None:
|
||||
record = ChatSession(
|
||||
session_id=s_data_dict["stream_id"],
|
||||
platform=s_data_dict["platform"],
|
||||
user_id=user_info_d["user_id"] if user_info_d else None,
|
||||
group_id=group_info_d["group_id"] if group_info_d else None,
|
||||
created_timestamp=datetime.fromtimestamp(s_data_dict["create_time"]),
|
||||
last_active_timestamp=datetime.fromtimestamp(s_data_dict["last_active_time"]),
|
||||
)
|
||||
session.add(record)
|
||||
return
|
||||
|
||||
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
||||
record.platform = s_data_dict["platform"]
|
||||
record.user_id = user_info_d["user_id"] if user_info_d else None
|
||||
record.group_id = group_info_d["group_id"] if group_info_d else None
|
||||
record.created_timestamp = datetime.fromtimestamp(s_data_dict["create_time"])
|
||||
record.last_active_timestamp = datetime.fromtimestamp(s_data_dict["last_active_time"])
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||
@@ -361,30 +380,32 @@ class ChatManager:
|
||||
|
||||
def _db_load_all_streams_sync():
|
||||
loaded_streams_data = []
|
||||
for model_instance in ChatStreams.select():
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id:
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
with get_db_session() as session:
|
||||
statement = select(ChatSession)
|
||||
for model_instance in session.exec(statement).all():
|
||||
user_info_data = {
|
||||
"platform": model_instance.platform,
|
||||
"user_id": model_instance.user_id or "",
|
||||
"user_nickname": "",
|
||||
"user_cardname": "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id:
|
||||
group_info_data = {
|
||||
"platform": model_instance.platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": "",
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.session_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.created_timestamp.timestamp(),
|
||||
"last_active_time": model_instance.last_active_timestamp.timestamp(),
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
return loaded_streams_data
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,36 +1,74 @@
|
||||
import re
|
||||
import json
|
||||
import traceback
|
||||
from typing import Union
|
||||
from datetime import datetime
|
||||
from collections.abc import Mapping
|
||||
from typing import cast
|
||||
|
||||
from src.common.database.database_model import Messages, Images
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from sqlmodel import col, select
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType, Messages
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.message_component_model import MessageSequence, TextComponent
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
from .message import MessageRecv, MessageSending
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
@staticmethod
|
||||
def _serialize_keywords(keywords) -> str:
|
||||
def _coerce_str_list(value: object) -> list[str]:
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [str(item) for item in value]
|
||||
if isinstance(value, set):
|
||||
return [str(item) for item in value]
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_str(mapping: Mapping[str, object], key: str, default: str = "") -> str:
|
||||
value = mapping.get(key)
|
||||
if value is None:
|
||||
return default
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def _get_optional_str(mapping: Mapping[str, object], key: str) -> str | None:
|
||||
value = mapping.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_keywords(keywords: list[str] | None) -> str:
|
||||
"""将关键词列表序列化为JSON字符串"""
|
||||
if isinstance(keywords, list):
|
||||
return json.dumps(keywords, ensure_ascii=False)
|
||||
return "[]"
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_keywords(keywords_str: str) -> list:
|
||||
def _deserialize_keywords(keywords_str: str) -> list[str]:
|
||||
"""将JSON字符串反序列化为关键词列表"""
|
||||
if not keywords_str:
|
||||
return []
|
||||
try:
|
||||
return json.loads(keywords_str)
|
||||
parsed = cast(object, json.loads(keywords_str))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
if isinstance(parsed, list):
|
||||
return [str(item) for item in parsed]
|
||||
if isinstance(parsed, str):
|
||||
return [parsed]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 通知消息不存储
|
||||
@@ -66,7 +104,7 @@ class MessageStorage:
|
||||
priority_mode = ""
|
||||
priority_info = {}
|
||||
is_emoji = False
|
||||
is_picid = False
|
||||
is_picture = False
|
||||
is_notify = False
|
||||
is_command = False
|
||||
key_words = ""
|
||||
@@ -83,66 +121,73 @@ class MessageStorage:
|
||||
priority_mode = message.priority_mode
|
||||
priority_info = message.priority_info
|
||||
is_emoji = message.is_emoji
|
||||
is_picid = message.is_picid
|
||||
is_picture = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
intercept_message_level = getattr(message, "intercept_message_level", 0)
|
||||
# 序列化关键词列表为JSON字符串
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
key_words = MessageStorage._serialize_keywords(MessageStorage._coerce_str_list(message.key_words))
|
||||
key_words_lite = MessageStorage._serialize_keywords(
|
||||
MessageStorage._coerce_str_list(message.key_words_lite)
|
||||
)
|
||||
selected_expressions = ""
|
||||
|
||||
chat_info_dict = chat_stream.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||
chat_info_dict = cast(dict[str, object], chat_stream.to_dict())
|
||||
if message.message_info.user_info is None:
|
||||
raise ValueError("message.user_info is required")
|
||||
user_info_dict = cast(dict[str, object], message.message_info.user_info.to_dict())
|
||||
|
||||
# message_id 现在是 TextField,直接使用字符串值
|
||||
msg_id = message.message_info.message_id
|
||||
msg_id = message.message_info.message_id or ""
|
||||
|
||||
# 安全地获取 group_info, 如果为 None 则视为空字典
|
||||
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
||||
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
||||
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
||||
group_info_from_chat = cast(dict[str, object], chat_info_dict.get("group_info") or {})
|
||||
|
||||
Messages.create(
|
||||
message_id=msg_id,
|
||||
time=float(message.message_info.time), # type: ignore
|
||||
chat_id=chat_stream.stream_id,
|
||||
# Flattened chat_info
|
||||
additional_config: dict[str, object] = dict(message.message_info.additional_config or {})
|
||||
additional_config.update(
|
||||
{
|
||||
"interest_value": interest_value,
|
||||
"priority_mode": priority_mode,
|
||||
"priority_info": priority_info,
|
||||
"reply_probability_boost": reply_probability_boost,
|
||||
"intercept_message_level": intercept_message_level,
|
||||
"key_words": key_words,
|
||||
"key_words_lite": key_words_lite,
|
||||
"selected_expressions": selected_expressions,
|
||||
"is_picid": is_picture,
|
||||
}
|
||||
)
|
||||
processed_text_for_raw = filtered_processed_plain_text or filtered_display_message or ""
|
||||
raw_sequence = MessageSequence([TextComponent(processed_text_for_raw)] if processed_text_for_raw else [])
|
||||
raw_content = MessageUtils.from_MaiSeq_to_db_record_msg(raw_sequence)
|
||||
|
||||
timestamp_value = message.message_info.time
|
||||
if timestamp_value is None:
|
||||
raise ValueError("message.message_info.time is required")
|
||||
db_message = Messages(
|
||||
message_id=str(msg_id),
|
||||
timestamp=datetime.fromtimestamp(float(timestamp_value)),
|
||||
platform=MessageStorage._get_str(chat_info_dict, "platform"),
|
||||
user_id=MessageStorage._get_str(user_info_dict, "user_id"),
|
||||
user_nickname=MessageStorage._get_str(user_info_dict, "user_nickname"),
|
||||
user_cardname=MessageStorage._get_optional_str(user_info_dict, "user_cardname"),
|
||||
group_id=MessageStorage._get_optional_str(group_info_from_chat, "group_id"),
|
||||
group_name=MessageStorage._get_optional_str(group_info_from_chat, "group_name"),
|
||||
is_mentioned=bool(is_mentioned),
|
||||
is_at=bool(is_at),
|
||||
session_id=chat_stream.stream_id,
|
||||
reply_to=reply_to,
|
||||
is_mentioned=is_mentioned,
|
||||
is_at=is_at,
|
||||
reply_probability_boost=reply_probability_boost,
|
||||
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
||||
chat_info_platform=chat_info_dict.get("platform"),
|
||||
chat_info_user_platform=user_info_from_chat.get("platform"),
|
||||
chat_info_user_id=user_info_from_chat.get("user_id"),
|
||||
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
|
||||
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
|
||||
chat_info_group_platform=group_info_from_chat.get("platform"),
|
||||
chat_info_group_id=group_info_from_chat.get("group_id"),
|
||||
chat_info_group_name=group_info_from_chat.get("group_name"),
|
||||
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
|
||||
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
|
||||
# Flattened user_info (message sender)
|
||||
user_platform=user_info_dict.get("platform"),
|
||||
user_id=user_info_dict.get("user_id"),
|
||||
user_nickname=user_info_dict.get("user_nickname"),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
# Text content
|
||||
is_emoji=is_emoji,
|
||||
is_picture=is_picture,
|
||||
is_command=is_command,
|
||||
is_notify=is_notify,
|
||||
raw_content=raw_content,
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
interest_value=interest_value,
|
||||
priority_mode=priority_mode,
|
||||
priority_info=priority_info,
|
||||
is_emoji=is_emoji,
|
||||
is_picid=is_picid,
|
||||
is_notify=is_notify,
|
||||
is_command=is_command,
|
||||
intercept_message_level=intercept_message_level,
|
||||
key_words=key_words,
|
||||
key_words_lite=key_words_lite,
|
||||
selected_expressions=selected_expressions,
|
||||
additional_config=json.dumps(additional_config, ensure_ascii=False),
|
||||
)
|
||||
with get_db_session() as session:
|
||||
session.add(db_message)
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
logger.error(f"消息:{message}")
|
||||
@@ -156,16 +201,21 @@ class MessageStorage:
|
||||
if not qq_message_id:
|
||||
logger.info("消息不存在message_id,无法更新")
|
||||
return False
|
||||
if matched_message := (
|
||||
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
||||
):
|
||||
# 更新找到的消息记录
|
||||
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
return True
|
||||
else:
|
||||
logger.debug("未找到匹配的消息")
|
||||
return False
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where(col(Messages.message_id) == mmc_message_id)
|
||||
.order_by(col(Messages.timestamp).desc())
|
||||
.limit(1)
|
||||
)
|
||||
matched_message = session.exec(statement).first()
|
||||
if matched_message:
|
||||
matched_message.message_id = qq_message_id
|
||||
session.add(matched_message)
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
return True
|
||||
logger.debug("未找到匹配的消息")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息ID失败: {e}")
|
||||
@@ -182,13 +232,18 @@ class MessageStorage:
|
||||
logger.debug("文本中没有图片标记,直接返回原文本")
|
||||
return text
|
||||
|
||||
def replace_match(match):
|
||||
def replace_match(match: re.Match[str]) -> str:
|
||||
description = match.group(1).strip()
|
||||
try:
|
||||
image_record = (
|
||||
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
|
||||
)
|
||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Images)
|
||||
.where((col(Images.description) == description) & (col(Images.image_type) == ImageType.IMAGE))
|
||||
.order_by(col(Images.record_time).desc())
|
||||
.limit(1)
|
||||
)
|
||||
image_record = session.exec(statement).first()
|
||||
return f"[picid:{image_record.id}]" if image_record else match.group(0)
|
||||
except Exception:
|
||||
return match.group(0)
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from rich.traceback import install
|
||||
|
||||
from sqlmodel import select, col
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords
|
||||
from src.common.data_models.message_data_model import MessageAndActionModel
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database_model import Images
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ActionRecord, Images
|
||||
from src.person_info.person_info import Person, get_person_id
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self
|
||||
|
||||
@@ -198,37 +198,38 @@ def get_actions_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
) -> List[DatabaseActionRecords]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time > timestamp_start) # type: ignore
|
||||
& (ActionRecords.time < timestamp_end) # type: ignore
|
||||
)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ActionRecord)
|
||||
.where((col(ActionRecord.session_id) == chat_id))
|
||||
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(timestamp_start))
|
||||
.where(col(ActionRecord.timestamp) < datetime.fromtimestamp(timestamp_end))
|
||||
)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
actions.reverse()
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
|
||||
actions = list(query)
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
|
||||
actions = list(session.exec(statement).all())
|
||||
actions = list(reversed(actions))
|
||||
else:
|
||||
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
|
||||
actions = list(session.exec(statement).all())
|
||||
else:
|
||||
statement = statement.order_by(col(ActionRecord.timestamp))
|
||||
actions = session.exec(statement).all()
|
||||
return [
|
||||
DatabaseActionRecords(
|
||||
action_id=action.action_id,
|
||||
time=action.time,
|
||||
time=action.timestamp.timestamp(),
|
||||
action_name=action.action_name,
|
||||
action_data=action.action_data,
|
||||
action_done=action.action_done,
|
||||
action_build_into_prompt=action.action_build_into_prompt,
|
||||
action_prompt_display=action.action_prompt_display,
|
||||
chat_id=action.chat_id,
|
||||
chat_info_stream_id=action.chat_info_stream_id,
|
||||
chat_info_platform=action.chat_info_platform,
|
||||
action_reasoning=action.action_reasoning,
|
||||
action_data=action.action_data or "{}",
|
||||
action_done=True,
|
||||
action_build_into_prompt=bool(action.action_display_prompt),
|
||||
action_prompt_display=action.action_display_prompt or "",
|
||||
chat_id=action.session_id,
|
||||
chat_info_stream_id=action.session_id,
|
||||
chat_info_platform=global_config.bot.platform,
|
||||
action_reasoning=action.action_reasoning or "",
|
||||
)
|
||||
for action in actions
|
||||
]
|
||||
@@ -238,25 +239,27 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time >= timestamp_start) # type: ignore
|
||||
& (ActionRecords.time <= timestamp_end) # type: ignore
|
||||
)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ActionRecord)
|
||||
.where((col(ActionRecord.session_id) == chat_id))
|
||||
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
|
||||
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
|
||||
)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
|
||||
actions = list(session.exec(statement).all())
|
||||
actions = list(reversed(actions))
|
||||
else:
|
||||
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
|
||||
actions = list(session.exec(statement).all())
|
||||
else:
|
||||
statement = statement.order_by(col(ActionRecord.timestamp))
|
||||
actions = session.exec(statement).all()
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
return [action.model_dump() for action in actions]
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_random(
|
||||
@@ -278,7 +281,7 @@ def get_raw_msg_by_timestamp_random(
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_users(
|
||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -316,7 +319,7 @@ def get_raw_msg_before_timestamp_with_chat(
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_users(
|
||||
timestamp: float, person_ids: list, limit: int = 0
|
||||
timestamp: float, person_ids: List[str], limit: int = 0
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -344,7 +347,7 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
|
||||
|
||||
|
||||
def num_new_messages_since_with_users(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str]
|
||||
) -> int:
|
||||
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
|
||||
if not person_ids: # 保持空列表检查
|
||||
@@ -358,7 +361,7 @@ def num_new_messages_since_with_users(
|
||||
|
||||
|
||||
def _build_readable_messages_internal(
|
||||
messages: List[MessageAndActionModel],
|
||||
messages: List[DatabaseMessages],
|
||||
replace_bot_name: bool = True,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
@@ -413,7 +416,7 @@ def _build_readable_messages_internal(
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(match: re.Match) -> str:
|
||||
def replace_pic_id(match: re.Match[str]) -> str:
|
||||
nonlocal current_pic_counter
|
||||
nonlocal pic_counter
|
||||
pic_id = match.group(1)
|
||||
@@ -421,7 +424,8 @@ def _build_readable_messages_internal(
|
||||
if pic_id not in pic_description_cache:
|
||||
description = "内容正在阅读,请稍等"
|
||||
try:
|
||||
image = Images.get_or_none(Images.image_id == pic_id)
|
||||
with get_db_session() as session:
|
||||
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
|
||||
if image and image.description:
|
||||
description = image.description
|
||||
except Exception:
|
||||
@@ -438,16 +442,11 @@ def _build_readable_messages_internal(
|
||||
|
||||
# 1: 获取发送者信息并提取消息组件
|
||||
for message in messages:
|
||||
if message.is_action_record:
|
||||
# 对于动作记录,也处理图片ID
|
||||
content = process_pic_ids(message.display_message)
|
||||
detailed_messages_raw.append((message.time, message.user_nickname, content, True))
|
||||
continue
|
||||
|
||||
platform = message.user_platform
|
||||
user_id = message.user_id
|
||||
user_nickname = message.user_nickname
|
||||
user_cardname = message.user_cardname
|
||||
user_info = message.user_info
|
||||
platform = user_info.platform
|
||||
user_id = user_info.user_id
|
||||
user_nickname = user_info.user_nickname
|
||||
user_cardname = user_info.user_cardname
|
||||
|
||||
timestamp = message.time
|
||||
content = message.display_message or message.processed_plain_text or ""
|
||||
@@ -525,12 +524,12 @@ def _build_readable_messages_internal(
|
||||
if long_time_notice and prev_timestamp is not None:
|
||||
time_diff = timestamp - prev_timestamp
|
||||
time_diff_hours = time_diff / 3600
|
||||
|
||||
|
||||
# 检查是否跨天
|
||||
prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_timestamp))
|
||||
current_date = time.strftime("%Y-%m-%d", time.localtime(timestamp))
|
||||
is_cross_day = prev_date != current_date
|
||||
|
||||
|
||||
# 如果间隔大于8小时或跨天,插入提示
|
||||
if time_diff_hours > 8 or is_cross_day:
|
||||
# 格式化日期为中文格式:xxxx年xx月xx日(去掉前导零)
|
||||
@@ -542,20 +541,15 @@ def _build_readable_messages_internal(
|
||||
hours_str = f"{int(time_diff_hours)}h"
|
||||
notice = f"以下聊天开始时间:{date_str}。距离上一条消息过去了{hours_str}\n"
|
||||
output_lines.append(notice)
|
||||
|
||||
|
||||
readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)
|
||||
|
||||
# 查找消息id(如果有)并构建id_prefix
|
||||
message_id = timestamp_to_id_mapping.get(timestamp, "")
|
||||
id_prefix = f"[{message_id}]" if message_id else ""
|
||||
|
||||
if is_action:
|
||||
# 对于动作记录,使用特殊格式
|
||||
output_lines.append(f"{id_prefix}{readable_time}, {content}")
|
||||
else:
|
||||
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
|
||||
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
|
||||
|
||||
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
|
||||
output_lines.append("\n")
|
||||
prev_timestamp = timestamp
|
||||
|
||||
formatted_string = "".join(output_lines).strip()
|
||||
@@ -592,7 +586,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# 从数据库中获取图片描述
|
||||
description = "内容正在阅读,请稍等"
|
||||
try:
|
||||
image = Images.get_or_none(Images.image_id == pic_id)
|
||||
with get_db_session() as session:
|
||||
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
|
||||
if image and image.description:
|
||||
description = image.description
|
||||
except Exception:
|
||||
@@ -663,7 +658,7 @@ async def build_readable_messages_with_list(
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
[MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages],
|
||||
messages,
|
||||
replace_bot_name,
|
||||
timestamp_mode,
|
||||
truncate,
|
||||
@@ -754,7 +749,7 @@ def build_readable_messages(
|
||||
filtered_messages = []
|
||||
for msg in messages:
|
||||
# 获取消息内容
|
||||
content = msg.processed_plain_text
|
||||
content = msg.processed_plain_text or ""
|
||||
# 移除表情包
|
||||
emoji_pattern = r"\[表情包:[^\]]+\]"
|
||||
content = re.sub(emoji_pattern, "", content)
|
||||
@@ -765,17 +760,14 @@ def build_readable_messages(
|
||||
|
||||
messages = filtered_messages
|
||||
|
||||
copy_messages: List[MessageAndActionModel] = []
|
||||
copy_messages: List[DatabaseMessages] = []
|
||||
for msg in messages:
|
||||
if remove_emoji_stickers:
|
||||
# 创建 MessageAndActionModel 但移除表情包
|
||||
model = MessageAndActionModel.from_DatabaseMessages(msg)
|
||||
# 移除表情包
|
||||
if model.processed_plain_text:
|
||||
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
|
||||
copy_messages.append(model)
|
||||
msg.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", msg.processed_plain_text or "")
|
||||
copy_messages.append(msg)
|
||||
else:
|
||||
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
|
||||
copy_messages.append(msg)
|
||||
|
||||
if show_actions and copy_messages:
|
||||
# 获取所有消息的时间范围
|
||||
@@ -786,40 +778,45 @@ def build_readable_messages(
|
||||
chat_id = messages[0].chat_id if messages else None
|
||||
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = (
|
||||
ActionRecords.select()
|
||||
.where(
|
||||
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
with get_db_session() as session:
|
||||
actions_in_range = session.exec(
|
||||
select(ActionRecord)
|
||||
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(min_time))
|
||||
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(max_time))
|
||||
.where(col(ActionRecord.session_id) == chat_id)
|
||||
.order_by(col(ActionRecord.timestamp))
|
||||
).all()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
action_after_latest = (
|
||||
ActionRecords.select()
|
||||
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
)
|
||||
with get_db_session() as session:
|
||||
action_after_latest = session.exec(
|
||||
select(ActionRecord)
|
||||
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(max_time))
|
||||
.where(col(ActionRecord.session_id) == chat_id)
|
||||
.order_by(col(ActionRecord.timestamp))
|
||||
.limit(1)
|
||||
).all()
|
||||
|
||||
# 合并两部分动作记录
|
||||
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
|
||||
actions: List[ActionRecord] = list(actions_in_range) + list(action_after_latest)
|
||||
|
||||
# 将动作记录转换为消息格式
|
||||
for action in actions:
|
||||
# 只有当build_into_prompt为True时才添加动作记录
|
||||
if action.action_build_into_prompt:
|
||||
action_msg = MessageAndActionModel(
|
||||
time=float(action.time), # type: ignore
|
||||
user_id=global_config.bot.qq_account, # 使用机器人的QQ账号
|
||||
user_platform=global_config.bot.platform, # 使用机器人的平台
|
||||
user_nickname=global_config.bot.nickname, # 使用机器人的用户名
|
||||
user_cardname="", # 机器人没有群名片
|
||||
processed_plain_text=f"{action.action_prompt_display}",
|
||||
display_message=f"{action.action_prompt_display}",
|
||||
chat_info_platform=str(action.chat_info_platform),
|
||||
is_action_record=True, # 添加标识字段
|
||||
action_name=str(action.action_name), # 保存动作名称
|
||||
action_display_prompt = action.action_display_prompt or ""
|
||||
if action_display_prompt:
|
||||
action_msg = DatabaseMessages(
|
||||
message_id=f"action_{action.action_id}",
|
||||
time=float(action.timestamp.timestamp()),
|
||||
chat_id=chat_id or "",
|
||||
processed_plain_text=action_display_prompt,
|
||||
display_message=action_display_prompt,
|
||||
user_platform=global_config.bot.platform,
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
user_cardname="",
|
||||
chat_info_platform=str(global_config.bot.platform),
|
||||
chat_info_stream_id=chat_id or "",
|
||||
)
|
||||
copy_messages.append(action_msg)
|
||||
|
||||
@@ -1026,17 +1023,13 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
person_ids_set = set() # 使用集合来自动去重
|
||||
|
||||
for msg in messages:
|
||||
platform: str = msg.get("user_platform") # type: ignore
|
||||
user_id: str = msg.get("user_id") # type: ignore
|
||||
platform = msg.get("user_platform") or ""
|
||||
user_id = msg.get("user_id") or ""
|
||||
|
||||
# 检查必要信息是否存在 且 不是机器人自己
|
||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||
continue
|
||||
|
||||
# 添加空值检查,防止 platform 为 None 时出错
|
||||
if platform is None:
|
||||
platform = "unknown"
|
||||
|
||||
if person_id := get_person_id(platform, user_id):
|
||||
person_ids_set.add(person_id)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,18 +1,21 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import hashlib
|
||||
import uuid
|
||||
import io
|
||||
import numpy as np
|
||||
|
||||
from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
|
||||
from sqlmodel import select, col
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import Images, ImageDescriptions, EmojiDescriptionCache
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
@@ -38,11 +41,7 @@ class ImageManager:
|
||||
self._initialized = True
|
||||
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([Images, ImageDescriptions, EmojiDescriptionCache], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或表创建失败: {e}")
|
||||
get_db_session()
|
||||
|
||||
try:
|
||||
self._cleanup_invalid_descriptions()
|
||||
@@ -72,10 +71,12 @@ class ImageManager:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
record = ImageDescriptions.get_or_none(
|
||||
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
|
||||
)
|
||||
return record.description if record else None
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
|
||||
)
|
||||
record = session.exec(statement).first()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||
return None
|
||||
@@ -90,15 +91,27 @@ class ImageManager:
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
"""
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
defaults = {"description": description, "timestamp": current_timestamp}
|
||||
desc_obj, created = ImageDescriptions.get_or_create(
|
||||
image_description_hash=image_hash, type=description_type, defaults=defaults
|
||||
)
|
||||
if not created: # 如果记录已存在,则更新
|
||||
desc_obj.description = description
|
||||
desc_obj.timestamp = current_timestamp
|
||||
desc_obj.save()
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
|
||||
)
|
||||
record = session.exec(statement).first()
|
||||
if record:
|
||||
record.description = description
|
||||
session.add(record)
|
||||
return
|
||||
|
||||
new_record = Images(
|
||||
image_hash=image_hash,
|
||||
description=description,
|
||||
full_path="",
|
||||
image_type=ImageType(description_type),
|
||||
query_count=0,
|
||||
is_registered=False,
|
||||
is_banned=False,
|
||||
vlm_processed=True,
|
||||
)
|
||||
session.add(new_record)
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
@@ -107,20 +120,18 @@ class ImageManager:
|
||||
"""清理数据库中 description 为空或为 'None' 的记录"""
|
||||
invalid_values = ["", "None"]
|
||||
|
||||
# 清理 Images 表
|
||||
deleted_images = (
|
||||
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
|
||||
)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Images)
|
||||
.where(col(Images.description).is_(None) | col(Images.description).in_(invalid_values))
|
||||
.limit(1000)
|
||||
)
|
||||
records = session.exec(statement).all()
|
||||
for record in records:
|
||||
session.delete(record)
|
||||
|
||||
# 清理 ImageDescriptions 表
|
||||
deleted_descriptions = (
|
||||
ImageDescriptions.delete()
|
||||
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
|
||||
.execute()
|
||||
)
|
||||
|
||||
if deleted_images or deleted_descriptions:
|
||||
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions} 条")
|
||||
if records:
|
||||
logger.info(f"[清理完成] 删除 Images: {len(records)} 条")
|
||||
else:
|
||||
logger.info("[清理完成] 未发现无效描述记录")
|
||||
|
||||
@@ -128,19 +139,15 @@ class ImageManager:
|
||||
def _cleanup_emoji_from_image_descriptions():
|
||||
"""清理Images和ImageDescriptions表中type为emoji的记录(已迁移到EmojiDescriptionCache)"""
|
||||
try:
|
||||
# 清理Images表中type为emoji的记录
|
||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
records = session.exec(statement).all()
|
||||
for record in records:
|
||||
session.delete(record)
|
||||
|
||||
# 清理ImageDescriptions表中type为emoji的记录
|
||||
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
|
||||
total_deleted = deleted_images + deleted_descriptions
|
||||
total_deleted = len(records)
|
||||
if total_deleted > 0:
|
||||
logger.info(
|
||||
f"[清理完成] 从Images表中删除 {deleted_images} 条emoji类型记录, "
|
||||
f"从ImageDescriptions表中删除 {deleted_descriptions} 条emoji类型记录, "
|
||||
f"共删除 {total_deleted} 条记录"
|
||||
)
|
||||
logger.info(f"[清理完成] 从Images表中删除 {total_deleted} 条emoji类型记录")
|
||||
else:
|
||||
logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录")
|
||||
except Exception as e:
|
||||
@@ -148,14 +155,14 @@ class ImageManager:
|
||||
raise
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
emoji_manager = emoji_manager_instance
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
emoji = await emoji_manager.get_emoji_from_manager(image_hash)
|
||||
emoji = emoji_manager.get_emoji_by_hash(image_hash)
|
||||
if not emoji:
|
||||
return "[表情包:未知]"
|
||||
emotion_list = emoji.emotion
|
||||
@@ -175,14 +182,14 @@ class ImageManager:
|
||||
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
|
||||
# 检查是否已存在该表情包(通过哈希值)
|
||||
emoji_manager = get_emoji_manager()
|
||||
existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash)
|
||||
emoji_manager = emoji_manager_instance
|
||||
existing_emoji = emoji_manager.get_emoji_by_hash(image_hash)
|
||||
if existing_emoji:
|
||||
logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
|
||||
return
|
||||
@@ -212,14 +219,15 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
|
||||
|
||||
# 优先使用EmojiManager查询已注册表情包的描述
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
|
||||
emoji_manager = emoji_manager_instance
|
||||
emoji = emoji_manager.get_emoji_by_hash(image_hash)
|
||||
tags = emoji.emotion if emoji else None
|
||||
if tags:
|
||||
tag_str = ",".join(tags)
|
||||
logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
|
||||
@@ -227,29 +235,26 @@ class ImageManager:
|
||||
except Exception as e:
|
||||
logger.debug(f"查询EmojiManager时出错: {e}")
|
||||
|
||||
# 查询EmojiDescriptionCache表的缓存(包含描述和情感标签)
|
||||
try:
|
||||
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
|
||||
)
|
||||
cache_record = session.exec(statement).first()
|
||||
if cache_record:
|
||||
# 优先使用情感标签,如果没有则使用详细描述
|
||||
result_text = ""
|
||||
if cache_record.emotion_tags:
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
||||
)
|
||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
||||
if cache_record.emotion:
|
||||
logger.info(f"[缓存命中] 使用Images表中的情感标签: {cache_record.emotion[:50]}...")
|
||||
result_text = f"[表情包:{cache_record.emotion}]"
|
||||
elif cache_record.description:
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
||||
)
|
||||
logger.info(f"[缓存命中] 使用Images表中的描述: {cache_record.description[:50]}...")
|
||||
result_text = f"[表情包:{cache_record.description}]"
|
||||
|
||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
||||
if result_text:
|
||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||
return result_text
|
||||
except Exception as e:
|
||||
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
|
||||
logger.debug(f"查询Images缓存时出错: {e}")
|
||||
|
||||
# === 二步走识别流程 ===
|
||||
|
||||
@@ -309,33 +314,42 @@ class ImageManager:
|
||||
|
||||
logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||
|
||||
# 再次检查缓存(防止并发情况下其他线程已经保存)
|
||||
try:
|
||||
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
|
||||
if cache_record and cache_record.emotion_tags:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion_tags}")
|
||||
return f"[表情包:{cache_record.emotion_tags}]"
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
|
||||
)
|
||||
cache_record = session.exec(statement).first()
|
||||
if cache_record and cache_record.emotion:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion}")
|
||||
return f"[表情包:{cache_record.emotion}]"
|
||||
except Exception as e:
|
||||
logger.debug(f"再次查询EmojiDescriptionCache时出错: {e}")
|
||||
logger.debug(f"再次查询Images缓存时出错: {e}")
|
||||
|
||||
# 保存识别出的详细描述和情感标签到 emoji_description_cache
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
cache_record, created = EmojiDescriptionCache.get_or_create(
|
||||
emoji_hash=image_hash,
|
||||
defaults={
|
||||
"description": detailed_description,
|
||||
"emotion_tags": final_emotion,
|
||||
"timestamp": current_timestamp,
|
||||
},
|
||||
)
|
||||
if not created:
|
||||
# 更新已有记录
|
||||
cache_record.description = detailed_description
|
||||
cache_record.emotion_tags = final_emotion
|
||||
cache_record.timestamp = current_timestamp
|
||||
cache_record.save()
|
||||
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到EmojiDescriptionCache: {image_hash[:8]}...")
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
|
||||
)
|
||||
cache_record = session.exec(statement).first()
|
||||
if cache_record:
|
||||
cache_record.description = detailed_description
|
||||
cache_record.emotion = final_emotion
|
||||
session.add(cache_record)
|
||||
else:
|
||||
cache_record = Images(
|
||||
image_hash=image_hash,
|
||||
description=detailed_description,
|
||||
full_path="",
|
||||
image_type=ImageType.EMOJI,
|
||||
emotion=final_emotion,
|
||||
query_count=0,
|
||||
is_registered=False,
|
||||
is_banned=False,
|
||||
vlm_processed=True,
|
||||
)
|
||||
session.add(cache_record)
|
||||
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到Images: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
|
||||
|
||||
@@ -358,14 +372,13 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(col(Images.image_hash) == image_hash)
|
||||
existing_image = session.exec(statement).first()
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
existing_image.count += 1
|
||||
else:
|
||||
existing_image.count = 1
|
||||
existing_image.save()
|
||||
existing_image.query_count += 1
|
||||
with get_db_session() as session:
|
||||
session.add(existing_image)
|
||||
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
@@ -377,7 +390,7 @@ class ImageManager:
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
|
||||
prompt = global_config.personality.visual_style
|
||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
@@ -402,26 +415,27 @@ class ImageManager:
|
||||
|
||||
# 保存到数据库,补充缺失字段
|
||||
if existing_image:
|
||||
existing_image.path = file_path
|
||||
existing_image.full_path = file_path
|
||||
existing_image.description = description
|
||||
existing_image.timestamp = current_timestamp
|
||||
if not hasattr(existing_image, "image_id") or not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = True
|
||||
existing_image.save()
|
||||
existing_image.record_time = datetime.fromtimestamp(current_timestamp)
|
||||
existing_image.vlm_processed = True
|
||||
with get_db_session() as session:
|
||||
session.add(existing_image)
|
||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||
else:
|
||||
Images.create(
|
||||
image_id=str(uuid.uuid4()),
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="image",
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
vlm_processed=True,
|
||||
count=1,
|
||||
)
|
||||
with get_db_session() as session:
|
||||
new_record = Images(
|
||||
image_hash=image_hash,
|
||||
description=description,
|
||||
full_path=file_path,
|
||||
image_type=ImageType.IMAGE,
|
||||
query_count=1,
|
||||
is_registered=False,
|
||||
is_banned=False,
|
||||
record_time=datetime.fromtimestamp(current_timestamp),
|
||||
vlm_processed=True,
|
||||
)
|
||||
session.add(new_record)
|
||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
@@ -575,30 +589,17 @@ class ImageManager:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
or not existing_image.image_id
|
||||
or not hasattr(existing_image, "count")
|
||||
or existing_image.count is None
|
||||
or not hasattr(existing_image, "vlm_processed")
|
||||
or existing_image.vlm_processed is None
|
||||
):
|
||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
||||
if not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if existing_image.count is None:
|
||||
existing_image.count = 0
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.IMAGE)
|
||||
)
|
||||
existing_image = session.exec(statement).first()
|
||||
if existing_image:
|
||||
existing_image.query_count += 1
|
||||
session.add(existing_image)
|
||||
return str(existing_image.id), f"[picid:{existing_image.id}]"
|
||||
|
||||
existing_image.count += 1
|
||||
existing_image.save()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
else:
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
# 保存新图片
|
||||
current_timestamp = time.time()
|
||||
@@ -612,15 +613,19 @@ class ImageManager:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
Images.create(
|
||||
image_id=image_id,
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="image",
|
||||
timestamp=current_timestamp,
|
||||
vlm_processed=False,
|
||||
count=1,
|
||||
)
|
||||
with get_db_session() as session:
|
||||
new_record = Images(
|
||||
image_hash=image_hash,
|
||||
description="",
|
||||
full_path=file_path,
|
||||
image_type=ImageType.IMAGE,
|
||||
query_count=1,
|
||||
is_registered=False,
|
||||
is_banned=False,
|
||||
record_time=datetime.fromtimestamp(current_timestamp),
|
||||
vlm_processed=False,
|
||||
)
|
||||
session.add(new_record)
|
||||
|
||||
# 启动异步VLM处理
|
||||
await self._process_image_with_vlm(image_id, image_base64)
|
||||
@@ -647,17 +652,26 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 获取当前图片记录
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
with get_db_session() as session:
|
||||
image = session.get(Images, int(image_id)) if image_id.isdigit() else None
|
||||
if image is None:
|
||||
logger.warning(f"未找到图片记录: {image_id}")
|
||||
return
|
||||
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = Images.get_or_none(
|
||||
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
|
||||
)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash)
|
||||
& (col(Images.description).is_not(None))
|
||||
& (col(Images.description) != "")
|
||||
)
|
||||
existing_with_description = session.exec(statement).first()
|
||||
if existing_with_description and existing_with_description.id != image.id:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
with get_db_session() as session:
|
||||
session.add(image)
|
||||
# 同时保存到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
@@ -667,11 +681,12 @@ class ImageManager:
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
with get_db_session() as session:
|
||||
session.add(image)
|
||||
return
|
||||
|
||||
# 获取图片格式
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
|
||||
|
||||
# 构建prompt
|
||||
prompt = global_config.personality.visual_style
|
||||
@@ -692,7 +707,8 @@ class ImageManager:
|
||||
# 更新数据库
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
with get_db_session() as session:
|
||||
session.add(image)
|
||||
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
Reference in New Issue
Block a user