重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject

This commit is contained in:
DrSmoothl
2026-02-13 20:39:11 +08:00
parent c14736ffca
commit 16b16d2ca6
29 changed files with 2459 additions and 1737 deletions

View File

@@ -1,17 +1,17 @@
import time
import random
import re
from datetime import datetime
from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from sqlmodel import select, col
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords
from src.common.data_models.message_data_model import MessageAndActionModel
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.common.database.database import get_db_session
from src.common.database.database_model import ActionRecord, Images
from src.person_info.person_info import Person, get_person_id
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self
@@ -198,37 +198,38 @@ def get_actions_by_timestamp_with_chat(
limit_mode: str = "latest",
) -> List[DatabaseActionRecords]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time > timestamp_start) # type: ignore
& (ActionRecords.time < timestamp_end) # type: ignore
)
with get_db_session() as session:
statement = (
select(ActionRecord)
.where((col(ActionRecord.session_id) == chat_id))
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(timestamp_start))
.where(col(ActionRecord.timestamp) < datetime.fromtimestamp(timestamp_end))
)
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
actions.reverse()
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
else:
query = query.order_by(ActionRecords.time.asc())
actions = list(query)
if limit > 0:
if limit_mode == "latest":
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
actions = list(session.exec(statement).all())
actions = list(reversed(actions))
else:
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
actions = list(session.exec(statement).all())
else:
statement = statement.order_by(col(ActionRecord.timestamp))
actions = session.exec(statement).all()
return [
DatabaseActionRecords(
action_id=action.action_id,
time=action.time,
time=action.timestamp.timestamp(),
action_name=action.action_name,
action_data=action.action_data,
action_done=action.action_done,
action_build_into_prompt=action.action_build_into_prompt,
action_prompt_display=action.action_prompt_display,
chat_id=action.chat_id,
chat_info_stream_id=action.chat_info_stream_id,
chat_info_platform=action.chat_info_platform,
action_reasoning=action.action_reasoning,
action_data=action.action_data or "{}",
action_done=True,
action_build_into_prompt=bool(action.action_display_prompt),
action_prompt_display=action.action_display_prompt or "",
chat_id=action.session_id,
chat_info_stream_id=action.session_id,
chat_info_platform=global_config.bot.platform,
action_reasoning=action.action_reasoning or "",
)
for action in actions
]
@@ -238,25 +239,27 @@ def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time >= timestamp_start) # type: ignore
& (ActionRecords.time <= timestamp_end) # type: ignore
)
with get_db_session() as session:
statement = (
select(ActionRecord)
.where((col(ActionRecord.session_id) == chat_id))
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
)
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
return [action.__data__ for action in reversed(actions)]
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
else:
query = query.order_by(ActionRecords.time.asc())
if limit > 0:
if limit_mode == "latest":
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
actions = list(session.exec(statement).all())
actions = list(reversed(actions))
else:
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
actions = list(session.exec(statement).all())
else:
statement = statement.order_by(col(ActionRecord.timestamp))
actions = session.exec(statement).all()
actions = list(query)
return [action.__data__ for action in actions]
return [action.model_dump() for action in actions]
def get_raw_msg_by_timestamp_random(
@@ -278,7 +281,7 @@ def get_raw_msg_by_timestamp_random(
def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@@ -316,7 +319,7 @@ def get_raw_msg_before_timestamp_with_chat(
def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@@ -344,7 +347,7 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
def num_new_messages_since_with_users(
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str]
) -> int:
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
if not person_ids: # 保持空列表检查
@@ -358,7 +361,7 @@ def num_new_messages_since_with_users(
def _build_readable_messages_internal(
messages: List[MessageAndActionModel],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
truncate: bool = False,
@@ -413,7 +416,7 @@ def _build_readable_messages_internal(
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match: re.Match) -> str:
def replace_pic_id(match: re.Match[str]) -> str:
nonlocal current_pic_counter
nonlocal pic_counter
pic_id = match.group(1)
@@ -421,7 +424,8 @@ def _build_readable_messages_internal(
if pic_id not in pic_description_cache:
description = "内容正在阅读,请稍等"
try:
image = Images.get_or_none(Images.image_id == pic_id)
with get_db_session() as session:
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
if image and image.description:
description = image.description
except Exception:
@@ -438,16 +442,11 @@ def _build_readable_messages_internal(
# 1: 获取发送者信息并提取消息组件
for message in messages:
if message.is_action_record:
# 对于动作记录也处理图片ID
content = process_pic_ids(message.display_message)
detailed_messages_raw.append((message.time, message.user_nickname, content, True))
continue
platform = message.user_platform
user_id = message.user_id
user_nickname = message.user_nickname
user_cardname = message.user_cardname
user_info = message.user_info
platform = user_info.platform
user_id = user_info.user_id
user_nickname = user_info.user_nickname
user_cardname = user_info.user_cardname
timestamp = message.time
content = message.display_message or message.processed_plain_text or ""
@@ -525,12 +524,12 @@ def _build_readable_messages_internal(
if long_time_notice and prev_timestamp is not None:
time_diff = timestamp - prev_timestamp
time_diff_hours = time_diff / 3600
# 检查是否跨天
prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_timestamp))
current_date = time.strftime("%Y-%m-%d", time.localtime(timestamp))
is_cross_day = prev_date != current_date
# 如果间隔大于8小时或跨天插入提示
if time_diff_hours > 8 or is_cross_day:
# 格式化日期为中文格式xxxx年xx月xx日去掉前导零
@@ -542,20 +541,15 @@ def _build_readable_messages_internal(
hours_str = f"{int(time_diff_hours)}h"
notice = f"以下聊天开始时间:{date_str}。距离上一条消息过去了{hours_str}\n"
output_lines.append(notice)
readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)
# 查找消息id如果有并构建id_prefix
message_id = timestamp_to_id_mapping.get(timestamp, "")
id_prefix = f"[{message_id}]" if message_id else ""
if is_action:
# 对于动作记录,使用特殊格式
output_lines.append(f"{id_prefix}{readable_time}, {content}")
else:
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
output_lines.append("\n")
prev_timestamp = timestamp
formatted_string = "".join(output_lines).strip()
@@ -592,7 +586,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "内容正在阅读,请稍等"
try:
image = Images.get_or_none(Images.image_id == pic_id)
with get_db_session() as session:
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
if image and image.description:
description = image.description
except Exception:
@@ -663,7 +658,7 @@ async def build_readable_messages_with_list(
允许通过参数控制格式化行为。
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
[MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages],
messages,
replace_bot_name,
timestamp_mode,
truncate,
@@ -754,7 +749,7 @@ def build_readable_messages(
filtered_messages = []
for msg in messages:
# 获取消息内容
content = msg.processed_plain_text
content = msg.processed_plain_text or ""
# 移除表情包
emoji_pattern = r"\[表情包:[^\]]+\]"
content = re.sub(emoji_pattern, "", content)
@@ -765,17 +760,14 @@ def build_readable_messages(
messages = filtered_messages
copy_messages: List[MessageAndActionModel] = []
copy_messages: List[DatabaseMessages] = []
for msg in messages:
if remove_emoji_stickers:
# 创建 MessageAndActionModel 但移除表情包
model = MessageAndActionModel.from_DatabaseMessages(msg)
# 移除表情包
if model.processed_plain_text:
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
copy_messages.append(model)
msg.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", msg.processed_plain_text or "")
copy_messages.append(msg)
else:
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
copy_messages.append(msg)
if show_actions and copy_messages:
# 获取所有消息的时间范围
@@ -786,40 +778,45 @@ def build_readable_messages(
chat_id = messages[0].chat_id if messages else None
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (
ActionRecords.select()
.where(
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
)
.order_by(ActionRecords.time)
)
with get_db_session() as session:
actions_in_range = session.exec(
select(ActionRecord)
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(min_time))
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(max_time))
.where(col(ActionRecord.session_id) == chat_id)
.order_by(col(ActionRecord.timestamp))
).all()
# 获取最新消息之后的第一个动作记录
action_after_latest = (
ActionRecords.select()
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
with get_db_session() as session:
action_after_latest = session.exec(
select(ActionRecord)
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(max_time))
.where(col(ActionRecord.session_id) == chat_id)
.order_by(col(ActionRecord.timestamp))
.limit(1)
).all()
# 合并两部分动作记录
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
actions: List[ActionRecord] = list(actions_in_range) + list(action_after_latest)
# 将动作记录转换为消息格式
for action in actions:
# 只有当build_into_prompt为True时才添加动作记录
if action.action_build_into_prompt:
action_msg = MessageAndActionModel(
time=float(action.time), # type: ignore
user_id=global_config.bot.qq_account, # 使用机器人的QQ账号
user_platform=global_config.bot.platform, # 使用机器人的平台
user_nickname=global_config.bot.nickname, # 使用机器人的用户名
user_cardname="", # 机器人没有群名片
processed_plain_text=f"{action.action_prompt_display}",
display_message=f"{action.action_prompt_display}",
chat_info_platform=str(action.chat_info_platform),
is_action_record=True, # 添加标识字段
action_name=str(action.action_name), # 保存动作名称
action_display_prompt = action.action_display_prompt or ""
if action_display_prompt:
action_msg = DatabaseMessages(
message_id=f"action_{action.action_id}",
time=float(action.timestamp.timestamp()),
chat_id=chat_id or "",
processed_plain_text=action_display_prompt,
display_message=action_display_prompt,
user_platform=global_config.bot.platform,
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
user_cardname="",
chat_info_platform=str(global_config.bot.platform),
chat_info_stream_id=chat_id or "",
)
copy_messages.append(action_msg)
@@ -1026,17 +1023,13 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set = set() # 使用集合来自动去重
for msg in messages:
platform: str = msg.get("user_platform") # type: ignore
user_id: str = msg.get("user_id") # type: ignore
platform = msg.get("user_platform") or ""
user_id = msg.get("user_id") or ""
# 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
# 添加空值检查,防止 platform 为 None 时出错
if platform is None:
platform = "unknown"
if person_id := get_person_id(platform, user_id):
person_ids_set.add(person_id)

File diff suppressed because it is too large Load Diff

View File

@@ -1,18 +1,21 @@
import base64
from datetime import datetime
from typing import Optional, Tuple
import hashlib
import io
import os
import time
import hashlib
import uuid
import io
import numpy as np
from typing import Optional, Tuple
import numpy as np
from PIL import Image
from rich.traceback import install
from sqlmodel import select, col
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions, EmojiDescriptionCache
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -38,11 +41,7 @@ class ImageManager:
self._initialized = True
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions, EmojiDescriptionCache], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
get_db_session()
try:
self._cleanup_invalid_descriptions()
@@ -72,10 +71,12 @@ class ImageManager:
Optional[str]: 描述文本如果不存在则返回None
"""
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
)
record = session.exec(statement).first()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@@ -90,15 +91,27 @@ class ImageManager:
description_type: 描述类型 ('emoji''image')
"""
try:
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
image_description_hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
)
record = session.exec(statement).first()
if record:
record.description = description
session.add(record)
return
new_record = Images(
image_hash=image_hash,
description=description,
full_path="",
image_type=ImageType(description_type),
query_count=0,
is_registered=False,
is_banned=False,
vlm_processed=True,
)
session.add(new_record)
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
@@ -107,20 +120,18 @@ class ImageManager:
"""清理数据库中 description 为空或为 'None' 的记录"""
invalid_values = ["", "None"]
# 清理 Images 表
deleted_images = (
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
)
with get_db_session() as session:
statement = (
select(Images)
.where(col(Images.description).is_(None) | col(Images.description).in_(invalid_values))
.limit(1000)
)
records = session.exec(statement).all()
for record in records:
session.delete(record)
# 清理 ImageDescriptions 表
deleted_descriptions = (
ImageDescriptions.delete()
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
.execute()
)
if deleted_images or deleted_descriptions:
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}")
if records:
logger.info(f"[清理完成] 删除 Images: {len(records)}")
else:
logger.info("[清理完成] 未发现无效描述记录")
@@ -128,19 +139,15 @@ class ImageManager:
def _cleanup_emoji_from_image_descriptions():
"""清理Images和ImageDescriptions表中type为emoji的记录已迁移到EmojiDescriptionCache"""
try:
# 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute()
with get_db_session() as session:
statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
records = session.exec(statement).all()
for record in records:
session.delete(record)
# 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
total_deleted = deleted_images + deleted_descriptions
total_deleted = len(records)
if total_deleted > 0:
logger.info(
f"[清理完成] 从Images表中删除 {deleted_images} 条emoji类型记录, "
f"从ImageDescriptions表中删除 {deleted_descriptions} 条emoji类型记录, "
f"共删除 {total_deleted} 条记录"
)
logger.info(f"[清理完成] 从Images表中删除 {total_deleted} 条emoji类型记录")
else:
logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录")
except Exception as e:
@@ -148,14 +155,14 @@ class ImageManager:
raise
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
emoji_manager = get_emoji_manager()
emoji_manager = emoji_manager_instance
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
emoji = await emoji_manager.get_emoji_from_manager(image_hash)
emoji = emoji_manager.get_emoji_by_hash(image_hash)
if not emoji:
return "[表情包:未知]"
emotion_list = emoji.emotion
@@ -175,14 +182,14 @@ class ImageManager:
try:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
# 确保目录存在
os.makedirs(EMOJI_DIR, exist_ok=True)
# 检查是否已存在该表情包(通过哈希值)
emoji_manager = get_emoji_manager()
existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash)
emoji_manager = emoji_manager_instance
existing_emoji = emoji_manager.get_emoji_by_hash(image_hash)
if existing_emoji:
logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
return
@@ -212,14 +219,15 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
emoji_manager = get_emoji_manager()
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
emoji_manager = emoji_manager_instance
emoji = emoji_manager.get_emoji_by_hash(image_hash)
tags = emoji.emotion if emoji else None
if tags:
tag_str = ",".join(tags)
logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
@@ -227,29 +235,26 @@ class ImageManager:
except Exception as e:
logger.debug(f"查询EmojiManager时出错: {e}")
# 查询EmojiDescriptionCache表的缓存包含描述和情感标签
try:
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record:
# 优先使用情感标签,如果没有则使用详细描述
result_text = ""
if cache_record.emotion_tags:
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]"
if cache_record.emotion:
logger.info(f"[缓存命中] 使用Images表中的情感标签: {cache_record.emotion[:50]}...")
result_text = f"[表情包:{cache_record.emotion}]"
elif cache_record.description:
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
logger.info(f"[缓存命中] 使用Images表中的描述: {cache_record.description[:50]}...")
result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件
if result_text:
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
return result_text
except Exception as e:
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
logger.debug(f"查询Images缓存时出错: {e}")
# === 二步走识别流程 ===
@@ -309,33 +314,42 @@ class ImageManager:
logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
# 再次检查缓存(防止并发情况下其他线程已经保存)
try:
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
if cache_record and cache_record.emotion_tags:
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion_tags}")
return f"[表情包:{cache_record.emotion_tags}]"
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record and cache_record.emotion:
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion}")
return f"[表情包:{cache_record.emotion}]"
except Exception as e:
logger.debug(f"再次查询EmojiDescriptionCache时出错: {e}")
logger.debug(f"再次查询Images缓存时出错: {e}")
# 保存识别出的详细描述和情感标签到 emoji_description_cache
try:
current_timestamp = time.time()
cache_record, created = EmojiDescriptionCache.get_or_create(
emoji_hash=image_hash,
defaults={
"description": detailed_description,
"emotion_tags": final_emotion,
"timestamp": current_timestamp,
},
)
if not created:
# 更新已有记录
cache_record.description = detailed_description
cache_record.emotion_tags = final_emotion
cache_record.timestamp = current_timestamp
cache_record.save()
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到EmojiDescriptionCache: {image_hash[:8]}...")
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record:
cache_record.description = detailed_description
cache_record.emotion = final_emotion
session.add(cache_record)
else:
cache_record = Images(
image_hash=image_hash,
description=detailed_description,
full_path="",
image_type=ImageType.EMOJI,
emotion=final_emotion,
query_count=0,
is_registered=False,
is_banned=False,
vlm_processed=True,
)
session.add(cache_record)
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到Images: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
@@ -358,14 +372,13 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
with get_db_session() as session:
statement = select(Images).where(col(Images.image_hash) == image_hash)
existing_image = session.exec(statement).first()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
existing_image.save()
existing_image.query_count += 1
with get_db_session() as session:
session.add(existing_image)
# 如果已有描述,直接返回
if existing_image.description:
@@ -377,7 +390,7 @@ class ImageManager:
return f"[图片:{cached_description}]"
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
prompt = global_config.personality.visual_style
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
@@ -402,26 +415,27 @@ class ImageManager:
# 保存到数据库,补充缺失字段
if existing_image:
existing_image.path = file_path
existing_image.full_path = file_path
existing_image.description = description
existing_image.timestamp = current_timestamp
if not hasattr(existing_image, "image_id") or not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
existing_image.save()
existing_image.record_time = datetime.fromtimestamp(current_timestamp)
existing_image.vlm_processed = True
with get_db_session() as session:
session.add(existing_image)
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
Images.create(
image_id=str(uuid.uuid4()),
emoji_hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
vlm_processed=True,
count=1,
)
with get_db_session() as session:
new_record = Images(
image_hash=image_hash,
description=description,
full_path=file_path,
image_type=ImageType.IMAGE,
query_count=1,
is_registered=False,
is_banned=False,
record_time=datetime.fromtimestamp(current_timestamp),
vlm_processed=True,
)
session.add(new_record)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
@@ -575,30 +589,17 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.IMAGE)
)
existing_image = session.exec(statement).first()
if existing_image:
existing_image.query_count += 1
session.add(existing_image)
return str(existing_image.id), f"[picid:{existing_image.id}]"
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
image_id = str(uuid.uuid4())
# 保存新图片
current_timestamp = time.time()
@@ -612,15 +613,19 @@ class ImageManager:
f.write(image_bytes)
# 保存到数据库
Images.create(
image_id=image_id,
emoji_hash=image_hash,
path=file_path,
type="image",
timestamp=current_timestamp,
vlm_processed=False,
count=1,
)
with get_db_session() as session:
new_record = Images(
image_hash=image_hash,
description="",
full_path=file_path,
image_type=ImageType.IMAGE,
query_count=1,
is_registered=False,
is_banned=False,
record_time=datetime.fromtimestamp(current_timestamp),
vlm_processed=False,
)
session.add(new_record)
# 启动异步VLM处理
await self._process_image_with_vlm(image_id, image_base64)
@@ -647,17 +652,26 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 获取当前图片记录
image = Images.get(Images.image_id == image_id)
with get_db_session() as session:
image = session.get(Images, int(image_id)) if image_id.isdigit() else None
if image is None:
logger.warning(f"未找到图片记录: {image_id}")
return
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
)
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash)
& (col(Images.description).is_not(None))
& (col(Images.description) != "")
)
existing_with_description = session.exec(statement).first()
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
@@ -667,11 +681,12 @@ class ImageManager:
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
return
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
# 构建prompt
prompt = global_config.personality.visual_style
@@ -692,7 +707,8 @@ class ImageManager:
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")