重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject

This commit is contained in:
DrSmoothl
2026-02-13 20:39:11 +08:00
parent c14736ffca
commit 16b16d2ca6
29 changed files with 2459 additions and 1737 deletions

View File

@@ -1,18 +1,21 @@
import base64
from datetime import datetime
from typing import Optional, Tuple
import hashlib
import io
import os
import time
import hashlib
import uuid
import io
import numpy as np
from typing import Optional, Tuple
import numpy as np
from PIL import Image
from rich.traceback import install
from sqlmodel import select, col
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions, EmojiDescriptionCache
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -38,11 +41,7 @@ class ImageManager:
self._initialized = True
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions, EmojiDescriptionCache], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
get_db_session()
try:
self._cleanup_invalid_descriptions()
@@ -72,10 +71,12 @@ class ImageManager:
Optional[str]: 描述文本如果不存在则返回None
"""
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
)
record = session.exec(statement).first()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@@ -90,15 +91,27 @@ class ImageManager:
description_type: 描述类型 ('emoji''image')
"""
try:
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
image_description_hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
)
record = session.exec(statement).first()
if record:
record.description = description
session.add(record)
return
new_record = Images(
image_hash=image_hash,
description=description,
full_path="",
image_type=ImageType(description_type),
query_count=0,
is_registered=False,
is_banned=False,
vlm_processed=True,
)
session.add(new_record)
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
@@ -107,20 +120,18 @@ class ImageManager:
"""清理数据库中 description 为空或为 'None' 的记录"""
invalid_values = ["", "None"]
# 清理 Images 表
deleted_images = (
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
)
with get_db_session() as session:
statement = (
select(Images)
.where(col(Images.description).is_(None) | col(Images.description).in_(invalid_values))
.limit(1000)
)
records = session.exec(statement).all()
for record in records:
session.delete(record)
# 清理 ImageDescriptions 表
deleted_descriptions = (
ImageDescriptions.delete()
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
.execute()
)
if deleted_images or deleted_descriptions:
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}")
if records:
logger.info(f"[清理完成] 删除 Images: {len(records)}")
else:
logger.info("[清理完成] 未发现无效描述记录")
@@ -128,19 +139,15 @@ class ImageManager:
def _cleanup_emoji_from_image_descriptions():
"""清理Images和ImageDescriptions表中type为emoji的记录已迁移到EmojiDescriptionCache"""
try:
# 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute()
with get_db_session() as session:
statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
records = session.exec(statement).all()
for record in records:
session.delete(record)
# 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
total_deleted = deleted_images + deleted_descriptions
total_deleted = len(records)
if total_deleted > 0:
logger.info(
f"[清理完成] 从Images表中删除 {deleted_images} 条emoji类型记录, "
f"从ImageDescriptions表中删除 {deleted_descriptions} 条emoji类型记录, "
f"共删除 {total_deleted} 条记录"
)
logger.info(f"[清理完成] 从Images表中删除 {total_deleted} 条emoji类型记录")
else:
logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录")
except Exception as e:
@@ -148,14 +155,14 @@ class ImageManager:
raise
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
emoji_manager = get_emoji_manager()
emoji_manager = emoji_manager_instance
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
emoji = await emoji_manager.get_emoji_from_manager(image_hash)
emoji = emoji_manager.get_emoji_by_hash(image_hash)
if not emoji:
return "[表情包:未知]"
emotion_list = emoji.emotion
@@ -175,14 +182,14 @@ class ImageManager:
try:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
# 确保目录存在
os.makedirs(EMOJI_DIR, exist_ok=True)
# 检查是否已存在该表情包(通过哈希值)
emoji_manager = get_emoji_manager()
existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash)
emoji_manager = emoji_manager_instance
existing_emoji = emoji_manager.get_emoji_by_hash(image_hash)
if existing_emoji:
logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
return
@@ -212,14 +219,15 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
emoji_manager = get_emoji_manager()
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
emoji_manager = emoji_manager_instance
emoji = emoji_manager.get_emoji_by_hash(image_hash)
tags = emoji.emotion if emoji else None
if tags:
tag_str = ",".join(tags)
logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
@@ -227,29 +235,26 @@ class ImageManager:
except Exception as e:
logger.debug(f"查询EmojiManager时出错: {e}")
# 查询EmojiDescriptionCache表的缓存包含描述和情感标签
try:
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record:
# 优先使用情感标签,如果没有则使用详细描述
result_text = ""
if cache_record.emotion_tags:
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]"
if cache_record.emotion:
logger.info(f"[缓存命中] 使用Images表中的情感标签: {cache_record.emotion[:50]}...")
result_text = f"[表情包:{cache_record.emotion}]"
elif cache_record.description:
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
logger.info(f"[缓存命中] 使用Images表中的描述: {cache_record.description[:50]}...")
result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件
if result_text:
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
return result_text
except Exception as e:
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
logger.debug(f"查询Images缓存时出错: {e}")
# === 二步走识别流程 ===
@@ -309,33 +314,42 @@ class ImageManager:
logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
# 再次检查缓存(防止并发情况下其他线程已经保存)
try:
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
if cache_record and cache_record.emotion_tags:
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion_tags}")
return f"[表情包:{cache_record.emotion_tags}]"
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record and cache_record.emotion:
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion}")
return f"[表情包:{cache_record.emotion}]"
except Exception as e:
logger.debug(f"再次查询EmojiDescriptionCache时出错: {e}")
logger.debug(f"再次查询Images缓存时出错: {e}")
# 保存识别出的详细描述和情感标签到 emoji_description_cache
try:
current_timestamp = time.time()
cache_record, created = EmojiDescriptionCache.get_or_create(
emoji_hash=image_hash,
defaults={
"description": detailed_description,
"emotion_tags": final_emotion,
"timestamp": current_timestamp,
},
)
if not created:
# 更新已有记录
cache_record.description = detailed_description
cache_record.emotion_tags = final_emotion
cache_record.timestamp = current_timestamp
cache_record.save()
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到EmojiDescriptionCache: {image_hash[:8]}...")
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record:
cache_record.description = detailed_description
cache_record.emotion = final_emotion
session.add(cache_record)
else:
cache_record = Images(
image_hash=image_hash,
description=detailed_description,
full_path="",
image_type=ImageType.EMOJI,
emotion=final_emotion,
query_count=0,
is_registered=False,
is_banned=False,
vlm_processed=True,
)
session.add(cache_record)
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到Images: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
@@ -358,14 +372,13 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
with get_db_session() as session:
statement = select(Images).where(col(Images.image_hash) == image_hash)
existing_image = session.exec(statement).first()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
existing_image.save()
existing_image.query_count += 1
with get_db_session() as session:
session.add(existing_image)
# 如果已有描述,直接返回
if existing_image.description:
@@ -377,7 +390,7 @@ class ImageManager:
return f"[图片:{cached_description}]"
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
prompt = global_config.personality.visual_style
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
@@ -402,26 +415,27 @@ class ImageManager:
# 保存到数据库,补充缺失字段
if existing_image:
existing_image.path = file_path
existing_image.full_path = file_path
existing_image.description = description
existing_image.timestamp = current_timestamp
if not hasattr(existing_image, "image_id") or not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
existing_image.save()
existing_image.record_time = datetime.fromtimestamp(current_timestamp)
existing_image.vlm_processed = True
with get_db_session() as session:
session.add(existing_image)
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
Images.create(
image_id=str(uuid.uuid4()),
emoji_hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
vlm_processed=True,
count=1,
)
with get_db_session() as session:
new_record = Images(
image_hash=image_hash,
description=description,
full_path=file_path,
image_type=ImageType.IMAGE,
query_count=1,
is_registered=False,
is_banned=False,
record_time=datetime.fromtimestamp(current_timestamp),
vlm_processed=True,
)
session.add(new_record)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
@@ -575,30 +589,17 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.IMAGE)
)
existing_image = session.exec(statement).first()
if existing_image:
existing_image.query_count += 1
session.add(existing_image)
return str(existing_image.id), f"[picid:{existing_image.id}]"
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
image_id = str(uuid.uuid4())
# 保存新图片
current_timestamp = time.time()
@@ -612,15 +613,19 @@ class ImageManager:
f.write(image_bytes)
# 保存到数据库
Images.create(
image_id=image_id,
emoji_hash=image_hash,
path=file_path,
type="image",
timestamp=current_timestamp,
vlm_processed=False,
count=1,
)
with get_db_session() as session:
new_record = Images(
image_hash=image_hash,
description="",
full_path=file_path,
image_type=ImageType.IMAGE,
query_count=1,
is_registered=False,
is_banned=False,
record_time=datetime.fromtimestamp(current_timestamp),
vlm_processed=False,
)
session.add(new_record)
# 启动异步VLM处理
await self._process_image_with_vlm(image_id, image_base64)
@@ -647,17 +652,26 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 获取当前图片记录
image = Images.get(Images.image_id == image_id)
with get_db_session() as session:
image = session.get(Images, int(image_id)) if image_id.isdigit() else None
if image is None:
logger.warning(f"未找到图片记录: {image_id}")
return
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
)
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash)
& (col(Images.description).is_not(None))
& (col(Images.description) != "")
)
existing_with_description = session.exec(statement).first()
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
@@ -667,11 +681,12 @@ class ImageManager:
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
return
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
# 构建prompt
prompt = global_config.personality.visual_style
@@ -692,7 +707,8 @@ class ImageManager:
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")