feat:优化表情包注册,迁移数据库v3

This commit is contained in:
SengokuCola
2026-04-05 20:12:21 +08:00
parent 80be746be0
commit 526fc9b763
16 changed files with 926 additions and 534 deletions

View File

@@ -1,6 +1,6 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Literal, Optional
import asyncio
import hashlib
@@ -31,12 +31,12 @@ install(extra_lines=3)
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
DATA_DIR = PROJECT_ROOT / "data"
EmojiRegisterStatus = Literal["registered", "skipped", "failed"]
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注册目录
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
def register_emoji_hook_specs(registry: HookSpecRegistry) -> list[HookSpec]:
"""注册表情包系统内置 Hook 规格。
Args:
@@ -145,7 +145,7 @@ def _get_runtime_manager() -> Any:
return get_plugin_runtime_manager()
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[dict[str, Any]]:
"""将表情包对象序列化为 Hook 可传输载荷。
Args:
@@ -163,27 +163,27 @@ def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, A
"file_name": emoji.file_name,
"full_path": str(emoji.full_path),
"description": emoji.description,
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description or emoji.emotion)],
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description)],
"query_count": int(emoji.query_count),
}
def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
def _normalize_emoji_tag_text(raw_values: Any) -> list[str]:
"""将文本或标签列表转为去重的情绪标签列表。"""
if isinstance(raw_values, str):
if not raw_values:
return []
parts = re.split(r"[,,、;\s]+", raw_values.strip())
parts = re.split(r"[,,、;\r\n\t]+", raw_values.strip())
normalized_tags = [str(part).strip() for part in parts if str(part).strip()]
elif isinstance(raw_values, list):
normalized_tags: List[str] = []
normalized_tags: list[str] = []
for value in raw_values:
normalized_tags.extend(_normalize_emoji_tag_text(value))
else:
return []
deduped_tags: List[str] = []
seen: Set[str] = set()
deduped_tags: list[str] = []
seen: set[str] = set()
for tag in normalized_tags:
normalized_tag = tag.strip()
if not normalized_tag:
@@ -196,15 +196,23 @@ def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
return deduped_tags
def _get_emoji_emotions(emoji: MaiEmoji) -> List[str]:
    """Return the emoji's emotion tags while staying compatible with legacy data.

    The tags are parsed from ``emoji.description``; when that value is falsy,
    ``emoji.emotion`` is parsed instead.
    """
    return _normalize_emoji_tag_text(emoji.description or emoji.emotion)
def _get_emoji_emotions(emoji: MaiEmoji) -> list[str]:
    """Return the emoji's emotion tags, parsed from ``emoji.description``."""
    return _normalize_emoji_tag_text(emoji.description)
def _ensure_directories() -> None:
    """Create the emoji storage directories when they are missing.

    Both ``EMOJI_DIR`` and ``EMOJI_REGISTERED_DIR`` are created together with
    any missing parent directories; directories that already exist are left
    untouched.
    """
    for directory in (EMOJI_DIR, EMOJI_REGISTERED_DIR):
        directory.mkdir(parents=True, exist_ok=True)
def _is_available_emoji_record(record: Images) -> bool:
    """Return whether the emoji file behind a database record is usable right now.

    Args:
        record: Emoji row to check; only ``no_file_flag`` and ``full_path``
            are read.

    Returns:
        ``True`` only when the record is not flagged as missing its file, has
        a non-empty ``full_path``, and that path currently points at a regular
        file. ``False`` in every other case.
    """
    # An empty or NULL path cannot reference a file. Bail out before building
    # a ``Path`` so that a NULL column value does not raise ``TypeError``.
    if record.no_file_flag or not record.full_path:
        return False
    # ``is_file()`` is already False for paths that do not exist, so a
    # separate ``exists()`` probe is unnecessary.
    return Path(record.full_path).is_file()
# TODO: 修改这个vlm为获取的vlm client暂时使用这个VLM方法
@@ -224,9 +232,9 @@ class EmojiManager:
_ensure_directories()
self._emoji_num: int = 0
self.emojis: List[MaiEmoji] = []
self.emojis: list[MaiEmoji] = []
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
self._pending_description_tasks: dict[str, asyncio.Task[None]] = {}
self._reload_callback_registered: bool = False
config_manager.register_reload_callback(self.reload_runtime_config)
@@ -254,7 +262,7 @@ class EmojiManager:
emoji_bytes: Optional[bytes] = None,
emoji_hash: Optional[str] = None,
wait_for_build: bool = True,
) -> Optional[Tuple[str, List[str]]]:
) -> Optional[tuple[str, list[str]]]:
"""
根据表情包哈希获取表情包描述和情感列表的封装方法
@@ -275,12 +283,28 @@ class EmojiManager:
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
if emoji := self.get_emoji_by_hash(emoji_hash):
emoji_path = Path(emoji.full_path) if emoji.full_path else None
if emoji_bytes and (emoji_path is None or not emoji_path.exists()):
try:
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
emoji.full_path = restored_emoji.full_path
emoji.file_name = restored_emoji.file_name
except Exception as e:
logger.warning(f"表情包缓存命中但本地文件缺失,回填失败: {e}")
return emoji.description, _normalize_emoji_tag_text(emoji.description or "")
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
if result := session.exec(statement).first():
cached_description = result.description or result.emotion or ""
record_path = Path(result.full_path) if result.full_path else None
if emoji_bytes and (result.no_file_flag or record_path is None or not record_path.exists()):
try:
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
result.full_path = str(restored_emoji.full_path)
result.no_file_flag = False
except Exception as e:
logger.warning(f"数据库命中表情包记录但本地文件缺失,回填失败: {e}")
cached_description = result.description or ""
cached_emotions = _normalize_emoji_tag_text(cached_description)
return (
cached_description,
@@ -400,7 +424,7 @@ class EmojiManager:
self,
emoji_hash: str,
emoji_bytes: bytes,
) -> Optional[Tuple[str, List[str]]]:
) -> Optional[tuple[str, list[str]]]:
"""构建并缓存表情包描述(返回标签化结果,不再走额外识别流程)。"""
logger.info(f"Start building cached emoji description, hash={emoji_hash}")
new_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
@@ -461,87 +485,70 @@ class EmojiManager:
self._emoji_num = 0
raise e
def register_emoji_to_db(self, emoji: MaiEmoji) -> bool:
def register_emoji_to_db(self, emoji: MaiEmoji) -> EmojiRegisterStatus:
# sourcery skip: extract-method
"""
将表情包注册到数据库中
Args:
emoji (MaiEmoji): 需要注册的表情包对象
Returns:
return (bool): 注册是否成功
"""
"""Register an emoji in the database without moving its source file."""
if not emoji or not isinstance(emoji, MaiEmoji):
logger.error("[注册表情包] 无效的表情包对象")
return False
logger.error("[register_emoji] Invalid emoji object")
return "failed"
if not emoji.full_path.exists():
logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}")
return False
logger.error(f"[register_emoji] Emoji file does not exist: {emoji.full_path}")
return "failed"
target_path = EMOJI_REGISTERED_DIR / emoji.file_name
# 先查库,避免重复记录导致文件被误移动后无法回收
original_path = emoji.full_path
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
existing_record = session.exec(statement).first()
if existing_record and not existing_record.no_file_flag:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
return False
except Exception as e:
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
return False
try:
emoji.full_path.replace(target_path)
emoji.full_path = target_path
except Exception as e:
logger.error(f"[注册表情包] 移动表情包文件时出错: {e}")
return False
if existing_record:
if existing_record.is_registered and _is_available_emoji_record(existing_record):
logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}")
return "skipped"
normalized_description = str(emoji.description or existing_record.description or "").strip()
normalized_emotions = _normalize_emoji_tag_text(normalized_description)
register_time = existing_record.register_time or datetime.now()
existing_record.full_path = str(emoji.full_path)
existing_record.description = normalized_description
existing_record.query_count = max(int(existing_record.query_count), int(emoji.query_count))
existing_record.last_used_time = emoji.last_used_time or existing_record.last_used_time
existing_record.is_registered = True
existing_record.is_banned = False
existing_record.no_file_flag = False
existing_record.register_time = register_time
session.add(existing_record)
emoji.description = normalized_description
emoji.emotion = normalized_emotions
emoji.query_count = existing_record.query_count
emoji.last_used_time = existing_record.last_used_time
emoji.register_time = register_time
logger.info(
f"[register_emoji] Updated existing record and registered emoji, ID: {existing_record.id}, path: {emoji.full_path}"
)
return "registered"
except Exception as e:
logger.error(f"[register_emoji] Database query failed: {e}")
return "failed"
# 注册到数据库
restore_file = False
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
if existing_record := session.exec(statement).first():
if existing_record.no_file_flag:
existing_record.no_file_flag = False
existing_record.is_banned = False
existing_record.full_path = str(emoji.full_path)
existing_record.description = emoji.description
existing_record.query_count = emoji.query_count
existing_record.last_used_time = emoji.last_used_time
existing_record.register_time = emoji.register_time
session.add(existing_record)
logger.info(
f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}"
)
else:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
restore_file = True
return False
else:
image_record = emoji.to_db_instance()
image_record.is_registered = True
image_record.is_banned = False
image_record.register_time = datetime.now()
session.add(image_record)
session.flush()
record_id = image_record.id
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
image_record = emoji.to_db_instance()
image_record.is_registered = True
image_record.is_banned = False
image_record.no_file_flag = False
image_record.register_time = datetime.now()
session.add(image_record)
session.flush()
emoji.register_time = image_record.register_time
record_id = image_record.id
logger.info(f"[register_emoji] Registered emoji to database, ID: {record_id}, path: {emoji.full_path}")
except Exception as e:
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
restore_file = True
return False
finally:
if restore_file:
try:
emoji.full_path.replace(original_path)
emoji.full_path = original_path
except Exception as e:
logger.error(f"[注册表情包] 回滚文件移动失败: {e}")
return True
logger.error(f"[register_emoji] Failed to write database record: {e}")
return "failed"
return "registered"
def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool:
"""
@@ -636,7 +643,6 @@ class EmojiManager:
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
if image_record := session.exec(statement).first():
image_record.description = emoji.description
image_record.emotion = None
session.add(image_record)
logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}")
else:
@@ -794,12 +800,15 @@ class EmojiManager:
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
if self.delete_emoji(emoji_to_delete):
self.emojis.remove(emoji_to_delete)
if self.register_emoji_to_db(new_emoji):
register_status = self.register_emoji_to_db(new_emoji)
if register_status == "registered":
self.emojis.append(new_emoji)
logger.info(f"[注册表情包] 成功替换并注册新表情包: {new_emoji.description}")
logger.info(f"[register_emoji] Replaced old emoji with new emoji: {new_emoji.description}")
return True
if register_status == "skipped":
logger.info(f"[register_emoji] Replacement emoji was already registered: {new_emoji.description}")
else:
logger.error(f"[注册表情包] 注册新表情包失败: {new_emoji.description}")
logger.error(f"[register_emoji] Failed to register replacement emoji: {new_emoji.description}")
else:
logger.error("[错误] 删除表情包失败,无法完成替换")
else:
@@ -808,7 +817,7 @@ class EmojiManager:
logger.error("[决策] 未能解析删除编号")
return False
async def build_emoji_description(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
async def build_emoji_description(self, target_emoji: MaiEmoji) -> tuple[bool, MaiEmoji]:
"""
构建表情包描述
@@ -915,16 +924,12 @@ class EmojiManager:
logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}")
return True, target_emoji
async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
    """Compatibility shim kept for existing callers.

    Emotion tags are already produced in one pass by
    ``build_emoji_description``, so this simply delegates to it and returns
    its result unchanged.
    """
    return await self.build_emoji_description(target_emoji)
def check_emoji_file_integrity(self) -> None:
"""
检查表情包完整性,删除文件缺失的表情包记录
"""
logger.info("[完整性检查] 开始检查表情包文件完整性...")
to_delete_emojis: list[Tuple[MaiEmoji, bool]] = []
to_delete_emojis: list[tuple[MaiEmoji, bool]] = []
removal_count = 0
for emoji in self.emojis:
if not emoji.full_path.exists():
@@ -945,59 +950,34 @@ class EmojiManager:
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
def remove_untracked_emoji_files(self) -> None:
    """Delete entries in the registered-emoji directory that no loaded emoji uses.

    An entry counts as tracked when its name equals the file name of some
    item in ``self.emojis``. Real sub-directories are skipped: ``Path.unlink``
    cannot remove them, so attempting it would only log an error on every
    run. A failure on one entry is logged and does not stop the sweep.
    """
    logger.info("[未跟踪表情包清理] 开始清理未被数据库记录跟踪的表情包文件...")
    tracked_files = {emoji.full_path.name for emoji in self.emojis}
    removal_count = 0
    # Snapshot the listing first so deletions cannot disturb the iteration.
    for file_path in list(EMOJI_REGISTERED_DIR.glob("*")):
        if file_path.name in tracked_files:
            continue
        # Only real directories are skipped; symlinks remain removable.
        if file_path.is_dir() and not file_path.is_symlink():
            continue
        try:
            file_path.unlink()
        except Exception as e:
            logger.error(f"[未跟踪表情包清理] 删除文件 {file_path.name} 时出错: {e}")
        else:
            removal_count += 1
            logger.info(f"[未跟踪表情包清理] 删除未跟踪的表情包文件: {file_path.name}")
    logger.info(f"[未跟踪表情包清理] 未跟踪表情包文件清理完成,删除了 {removal_count} 个文件")
async def periodic_emoji_maintenance(self) -> None:
"""
定期执行表情包维护任务,包括完整性检查和未跟踪文件清理
"""
"""Run emoji maintenance tasks periodically."""
while True:
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True)
_ensure_directories()
try:
self.check_emoji_file_integrity()
self.remove_untracked_emoji_files()
except Exception as e:
logger.error(f"[定期维护] 执行表情包维护任务时出错: {e}")
logger.error(f"[emoji_maintenance] Maintenance task failed: {e}")
if global_config.emoji.steal_emoji and (
self._emoji_num < global_config.emoji.max_reg_num
or (self._emoji_num > global_config.emoji.max_reg_num and global_config.emoji.do_replace)
):
logger.info("[定期维护] 尝试从表情包盗取目录注册新表情包...")
logger.info("[emoji_maintenance] Scanning data/emoji for new emojis...")
for emoji_file in EMOJI_DIR.iterdir():
if not emoji_file.is_file():
continue
try:
register_success = await self.register_emoji_by_filename(emoji_file)
register_status = await self.register_emoji_by_filename(emoji_file)
except Exception as e:
logger.error(f"[定期维护] 注册表情包 {emoji_file.name} 时发生未处理异常: {e}")
register_success = False
if register_success:
break # 每次只注册一个表情包
try:
emoji_file.unlink()
logger.info(f"[定期维护] 删除无法注册的表情包文件: {emoji_file.name}")
except Exception as e:
logger.error(f"[定期维护] 删除文件 {emoji_file.name} 时出错: {e}")
logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}")
register_status = "failed"
if register_status == "registered":
break
if register_status == "skipped":
logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}")
else:
logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}")
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
try:
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
@@ -1006,84 +986,81 @@ class EmojiManager:
finally:
self._maintenance_wakeup_event.clear()
async def register_emoji_by_filename(self, filename: Path | str) -> bool:
"""
根据指定的表情包图片,分析并注册到数据库
Args:
filename (Path | str): 表情包图片的完整文件路径(可能根据文件实际格式修正)
Returns:
return (bool): 注册是否成功
"""
async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus:
"""Register an emoji file from ``data/emoji`` without moving or deleting it."""
file_full_path = Path(filename).absolute().resolve()
if not file_full_path.exists():
logger.error(f"[注册表情包] 表情包文件不存在: {file_full_path}")
return False
logger.error(f"[register_emoji] Emoji file does not exist: {file_full_path}")
return "failed"
try:
target_emoji = MaiEmoji(full_path=file_full_path)
except Exception as e:
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
return False
logger.error(f"[register_emoji] Failed to create emoji object: {e}")
return "failed"
calc_success = await target_emoji.calculate_hash_format()
if not calc_success:
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
return False
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
logger.error(f"[register_emoji] Failed to calculate hash and format: {file_full_path}")
return "failed"
file_full_path = target_emoji.full_path
# 2. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
existing_record: Optional[Images] = None
try:
with get_db_session_manual() as session:
statement = (
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
)
if image_record := session.exec(statement).first():
if image_record.no_file_flag:
image_record.no_file_flag = False
image_record.is_banned = False
image_record.is_registered = True
image_record.full_path = str(target_emoji.full_path)
session.add(image_record)
session.commit()
logger.info(f"表情包注册成功Hash: {target_emoji.file_hash}")
return True
else:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}")
return False
existing_record = session.exec(statement).first()
except Exception as e:
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
return False
logger.error(f"[register_emoji] Failed to query database: {e}")
return "failed"
if existing_record is not None:
if existing_record.is_registered and _is_available_emoji_record(existing_record):
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
return "skipped"
cached_description = str(existing_record.description or "").strip()
if cached_description:
normalized_emotions = _normalize_emoji_tag_text(cached_description)
target_emoji.description = ",".join(normalized_emotions)
target_emoji.emotion = normalized_emotions
target_emoji.query_count = existing_record.query_count
target_emoji.last_used_time = existing_record.last_used_time
target_emoji.register_time = existing_record.register_time
# 3. 检查内存缓存是否已经存在
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
return False
# 3. 构建描述(包含情绪标签)
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
if not desc_success:
logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}")
return False
logger.info(f"[register_emoji] Emoji already loaded in memory, skipping: {existing_emoji.file_name}")
return "skipped"
if not target_emoji.description:
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
if not desc_success:
logger.error(f"[register_emoji] Failed to build emoji description: {file_full_path}")
return "failed"
# 4. 检查容量并决定是否替换或者直接注册
if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace:
logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包")
logger.warning(
f"[register_emoji] Emoji limit {global_config.emoji.max_reg_num} reached, trying replacement"
)
replaced = await self.replace_an_emoji_by_llm(target_emoji)
if not replaced:
logger.error("[注册表情包] 替换表情包失败,无法注册新表情包")
return False
return True
else:
if self.register_emoji_to_db(target_emoji):
self.emojis.append(target_emoji)
self._emoji_num += 1
logger.info(f"[注册表情包] 成功注册新表情包: {target_emoji.file_name}")
return True
else:
logger.error(f"[注册表情包] 注册表情包到数据库失败: {file_full_path}")
return False
logger.error("[register_emoji] Failed to replace existing emoji")
return "failed"
return "registered"
def _calculate_emotion_similarity_list(self, text_emotion: str) -> List[Tuple[MaiEmoji, float]]:
register_status = self.register_emoji_to_db(target_emoji)
if register_status == "registered":
self.emojis.append(target_emoji)
self._emoji_num = len(self.emojis)
logger.info(f"[register_emoji] Registered new emoji: {target_emoji.file_name}")
elif register_status == "failed":
logger.error(f"[register_emoji] Failed to register emoji in database: {file_full_path}")
else:
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
return register_status
def _calculate_emotion_similarity_list(self, text_emotion: str) -> list[tuple[MaiEmoji, float]]:
"""
计算文本情感标签与所有表情包情感标签的相似度列表
@@ -1096,7 +1073,7 @@ class EmojiManager:
if not normalized_text_emotion:
return []
similarity_list: List[Tuple[MaiEmoji, float]] = []
similarity_list: list[tuple[MaiEmoji, float]] = []
for emoji in self.emojis:
candidate_emotions = _get_emoji_emotions(emoji)
if not candidate_emotions:

View File

@@ -1,8 +1,8 @@
"""Maisaka 表情工具内置能力。"""
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, Optional, Sequence, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import random
@@ -120,8 +120,6 @@ def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
"""提取并清洗单个表情的情绪标签。"""
if emoji.description:
return _normalize_emoji_tag_text(emoji.description)
if emoji.emotion:
return _normalize_emoji_tag_text(emoji.emotion)
return []

View File

@@ -41,6 +41,7 @@ class ImageManager:
"""初始化图片管理器。"""
_ensure_image_dir_exists()
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
self.cleanup_legacy_image_registration_records()
logger.info("图片管理器初始化完成")
@@ -50,6 +51,17 @@ class ImageManager:
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
return session.exec(statement).first()
def _normalize_image_registration_fields(self, record: Images) -> bool:
    """Clear emoji-style registration fields that ended up on an image record.

    Args:
        record: Row to inspect and, when needed, modify in place.

    Returns:
        ``True`` when ``is_registered`` / ``register_time`` were reset on the
        record; ``False`` when the record is not of type ``ImageType.IMAGE``
        or already carries no registration state.
    """
    # Nothing to do for non-image rows or rows that are already clean.
    if record.image_type != ImageType.IMAGE or not (
        record.is_registered or record.register_time is not None
    ):
        return False

    record.is_registered = False
    record.register_time = None
    return True
async def get_image_description(
self,
*,
@@ -182,8 +194,9 @@ class ImageManager:
try:
with get_db_session() as session:
record = image.to_db_instance()
record.is_registered = True
record.register_time = record.last_used_time = datetime.now()
record.is_registered = False
record.register_time = None
record.last_used_time = datetime.now()
session.add(record)
session.flush() # 确保记录被写入数据库以获取ID
record_id = record.id
@@ -209,6 +222,7 @@ class ImageManager:
if not record:
logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法更新描述")
return False
self._normalize_image_registration_fields(record)
record.description = image.description
record.last_used_time = datetime.now()
record.vlm_processed = image.vlm_processed
@@ -258,6 +272,7 @@ class ImageManager:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
if record := session.exec(statement).first():
self._normalize_image_registration_fields(record)
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
record.last_used_time = datetime.now()
record.query_count += 1
@@ -335,6 +350,24 @@ class ImageManager:
logger.info(f"清理完成: {invalid_counter} 条无效描述记录,{null_path_counter} 条文件路径不存在记录")
def cleanup_legacy_image_registration_records(self) -> None:
    """Reset mistaken registration fields on legacy image records.

    Every ``ImageType.IMAGE`` row is passed through
    ``_normalize_image_registration_fields``; rows it changed are added back
    to the session and counted. Any exception is logged and ends the cleanup
    without a summary message.
    """
    fixed_counter = 0
    try:
        with get_db_session() as session:
            image_rows = session.exec(select(Images).filter_by(image_type=ImageType.IMAGE))
            # Iterate the result with yield_per(100) instead of calling .all().
            for row in image_rows.yield_per(100):
                if self._normalize_image_registration_fields(row):
                    session.add(row)
                    fixed_counter += 1
    except Exception as e:
        logger.error(f"Failed to clean image registration state: {e}")
        return

    if fixed_counter:
        logger.info(f"Cleaned mistaken registration state on {fixed_counter} image records")
async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str:
prompt = global_config.personality.visual_style
image_base64 = base64.b64encode(image_bytes).decode("utf-8")