feat:优化表情包注册,迁移数据库v3
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
@@ -31,12 +31,12 @@ install(extra_lines=3)
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||
DATA_DIR = PROJECT_ROOT / "data"
|
||||
EmojiRegisterStatus = Literal["registered", "skipped", "failed"]
|
||||
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
|
||||
EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注册目录
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
|
||||
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
def register_emoji_hook_specs(registry: HookSpecRegistry) -> list[HookSpec]:
|
||||
"""注册表情包系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
@@ -145,7 +145,7 @@ def _get_runtime_manager() -> Any:
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
|
||||
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[dict[str, Any]]:
|
||||
"""将表情包对象序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
@@ -163,27 +163,27 @@ def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, A
|
||||
"file_name": emoji.file_name,
|
||||
"full_path": str(emoji.full_path),
|
||||
"description": emoji.description,
|
||||
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description or emoji.emotion)],
|
||||
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description)],
|
||||
"query_count": int(emoji.query_count),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
|
||||
def _normalize_emoji_tag_text(raw_values: Any) -> list[str]:
|
||||
"""将文本或标签列表转为去重的情绪标签列表。"""
|
||||
if isinstance(raw_values, str):
|
||||
if not raw_values:
|
||||
return []
|
||||
parts = re.split(r"[,,、;;\s]+", raw_values.strip())
|
||||
parts = re.split(r"[,,、;;\r\n\t]+", raw_values.strip())
|
||||
normalized_tags = [str(part).strip() for part in parts if str(part).strip()]
|
||||
elif isinstance(raw_values, list):
|
||||
normalized_tags: List[str] = []
|
||||
normalized_tags: list[str] = []
|
||||
for value in raw_values:
|
||||
normalized_tags.extend(_normalize_emoji_tag_text(value))
|
||||
else:
|
||||
return []
|
||||
|
||||
deduped_tags: List[str] = []
|
||||
seen: Set[str] = set()
|
||||
deduped_tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for tag in normalized_tags:
|
||||
normalized_tag = tag.strip()
|
||||
if not normalized_tag:
|
||||
@@ -196,15 +196,23 @@ def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
|
||||
return deduped_tags
|
||||
|
||||
|
||||
def _get_emoji_emotions(emoji: MaiEmoji) -> List[str]:
|
||||
"""获取兼容旧数据的表情包情绪标签。"""
|
||||
return _normalize_emoji_tag_text(emoji.description or emoji.emotion)
|
||||
def _get_emoji_emotions(emoji: MaiEmoji) -> list[str]:
|
||||
"""获取表情包情绪标签。"""
|
||||
return _normalize_emoji_tag_text(emoji.description)
|
||||
|
||||
|
||||
def _ensure_directories() -> None:
|
||||
"""确保表情包相关目录存在"""
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _is_available_emoji_record(record: Images) -> bool:
|
||||
"""判断数据库记录对应的表情包文件当前是否可用。"""
|
||||
if record.no_file_flag:
|
||||
return False
|
||||
|
||||
record_path = Path(record.full_path)
|
||||
return record_path.exists() and record_path.is_file()
|
||||
|
||||
|
||||
# TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法
|
||||
@@ -224,9 +232,9 @@ class EmojiManager:
|
||||
_ensure_directories()
|
||||
|
||||
self._emoji_num: int = 0
|
||||
self.emojis: List[MaiEmoji] = []
|
||||
self.emojis: list[MaiEmoji] = []
|
||||
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
|
||||
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
|
||||
self._pending_description_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._reload_callback_registered: bool = False
|
||||
|
||||
config_manager.register_reload_callback(self.reload_runtime_config)
|
||||
@@ -254,7 +262,7 @@ class EmojiManager:
|
||||
emoji_bytes: Optional[bytes] = None,
|
||||
emoji_hash: Optional[str] = None,
|
||||
wait_for_build: bool = True,
|
||||
) -> Optional[Tuple[str, List[str]]]:
|
||||
) -> Optional[tuple[str, list[str]]]:
|
||||
"""
|
||||
根据表情包哈希获取表情包描述和情感列表的封装方法
|
||||
|
||||
@@ -275,12 +283,28 @@ class EmojiManager:
|
||||
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
|
||||
|
||||
if emoji := self.get_emoji_by_hash(emoji_hash):
|
||||
emoji_path = Path(emoji.full_path) if emoji.full_path else None
|
||||
if emoji_bytes and (emoji_path is None or not emoji_path.exists()):
|
||||
try:
|
||||
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||
emoji.full_path = restored_emoji.full_path
|
||||
emoji.file_name = restored_emoji.file_name
|
||||
except Exception as e:
|
||||
logger.warning(f"表情包缓存命中但本地文件缺失,回填失败: {e}")
|
||||
return emoji.description, _normalize_emoji_tag_text(emoji.description or "")
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if result := session.exec(statement).first():
|
||||
cached_description = result.description or result.emotion or ""
|
||||
record_path = Path(result.full_path) if result.full_path else None
|
||||
if emoji_bytes and (result.no_file_flag or record_path is None or not record_path.exists()):
|
||||
try:
|
||||
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||
result.full_path = str(restored_emoji.full_path)
|
||||
result.no_file_flag = False
|
||||
except Exception as e:
|
||||
logger.warning(f"数据库命中表情包记录但本地文件缺失,回填失败: {e}")
|
||||
cached_description = result.description or ""
|
||||
cached_emotions = _normalize_emoji_tag_text(cached_description)
|
||||
return (
|
||||
cached_description,
|
||||
@@ -400,7 +424,7 @@ class EmojiManager:
|
||||
self,
|
||||
emoji_hash: str,
|
||||
emoji_bytes: bytes,
|
||||
) -> Optional[Tuple[str, List[str]]]:
|
||||
) -> Optional[tuple[str, list[str]]]:
|
||||
"""构建并缓存表情包描述(返回标签化结果,不再走额外识别流程)。"""
|
||||
logger.info(f"Start building cached emoji description, hash={emoji_hash}")
|
||||
new_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||
@@ -461,87 +485,70 @@ class EmojiManager:
|
||||
self._emoji_num = 0
|
||||
raise e
|
||||
|
||||
def register_emoji_to_db(self, emoji: MaiEmoji) -> bool:
|
||||
def register_emoji_to_db(self, emoji: MaiEmoji) -> EmojiRegisterStatus:
|
||||
# sourcery skip: extract-method
|
||||
"""
|
||||
将表情包注册到数据库中
|
||||
|
||||
Args:
|
||||
emoji (MaiEmoji): 需要注册的表情包对象
|
||||
Returns:
|
||||
return (bool): 注册是否成功
|
||||
"""
|
||||
"""Register an emoji in the database without moving its source file."""
|
||||
if not emoji or not isinstance(emoji, MaiEmoji):
|
||||
logger.error("[注册表情包] 无效的表情包对象")
|
||||
return False
|
||||
logger.error("[register_emoji] Invalid emoji object")
|
||||
return "failed"
|
||||
if not emoji.full_path.exists():
|
||||
logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}")
|
||||
return False
|
||||
logger.error(f"[register_emoji] Emoji file does not exist: {emoji.full_path}")
|
||||
return "failed"
|
||||
|
||||
target_path = EMOJI_REGISTERED_DIR / emoji.file_name
|
||||
|
||||
# 先查库,避免重复记录导致文件被误移动后无法回收
|
||||
original_path = emoji.full_path
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
existing_record = session.exec(statement).first()
|
||||
if existing_record and not existing_record.no_file_flag:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
||||
return False
|
||||
try:
|
||||
emoji.full_path.replace(target_path)
|
||||
emoji.full_path = target_path
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 移动表情包文件时出错: {e}")
|
||||
return False
|
||||
if existing_record:
|
||||
if existing_record.is_registered and _is_available_emoji_record(existing_record):
|
||||
logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}")
|
||||
return "skipped"
|
||||
|
||||
normalized_description = str(emoji.description or existing_record.description or "").strip()
|
||||
normalized_emotions = _normalize_emoji_tag_text(normalized_description)
|
||||
register_time = existing_record.register_time or datetime.now()
|
||||
|
||||
existing_record.full_path = str(emoji.full_path)
|
||||
existing_record.description = normalized_description
|
||||
existing_record.query_count = max(int(existing_record.query_count), int(emoji.query_count))
|
||||
existing_record.last_used_time = emoji.last_used_time or existing_record.last_used_time
|
||||
existing_record.is_registered = True
|
||||
existing_record.is_banned = False
|
||||
existing_record.no_file_flag = False
|
||||
existing_record.register_time = register_time
|
||||
session.add(existing_record)
|
||||
|
||||
emoji.description = normalized_description
|
||||
emoji.emotion = normalized_emotions
|
||||
emoji.query_count = existing_record.query_count
|
||||
emoji.last_used_time = existing_record.last_used_time
|
||||
emoji.register_time = register_time
|
||||
|
||||
logger.info(
|
||||
f"[register_emoji] Updated existing record and registered emoji, ID: {existing_record.id}, path: {emoji.full_path}"
|
||||
)
|
||||
return "registered"
|
||||
except Exception as e:
|
||||
logger.error(f"[register_emoji] Database query failed: {e}")
|
||||
return "failed"
|
||||
|
||||
# 注册到数据库
|
||||
restore_file = False
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if existing_record := session.exec(statement).first():
|
||||
if existing_record.no_file_flag:
|
||||
existing_record.no_file_flag = False
|
||||
existing_record.is_banned = False
|
||||
existing_record.full_path = str(emoji.full_path)
|
||||
existing_record.description = emoji.description
|
||||
existing_record.query_count = emoji.query_count
|
||||
existing_record.last_used_time = emoji.last_used_time
|
||||
existing_record.register_time = emoji.register_time
|
||||
session.add(existing_record)
|
||||
logger.info(
|
||||
f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
||||
restore_file = True
|
||||
return False
|
||||
else:
|
||||
image_record = emoji.to_db_instance()
|
||||
image_record.is_registered = True
|
||||
image_record.is_banned = False
|
||||
image_record.register_time = datetime.now()
|
||||
session.add(image_record)
|
||||
session.flush()
|
||||
record_id = image_record.id
|
||||
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
|
||||
image_record = emoji.to_db_instance()
|
||||
image_record.is_registered = True
|
||||
image_record.is_banned = False
|
||||
image_record.no_file_flag = False
|
||||
image_record.register_time = datetime.now()
|
||||
session.add(image_record)
|
||||
session.flush()
|
||||
emoji.register_time = image_record.register_time
|
||||
record_id = image_record.id
|
||||
logger.info(f"[register_emoji] Registered emoji to database, ID: {record_id}, path: {emoji.full_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
|
||||
restore_file = True
|
||||
return False
|
||||
finally:
|
||||
if restore_file:
|
||||
try:
|
||||
emoji.full_path.replace(original_path)
|
||||
emoji.full_path = original_path
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 回滚文件移动失败: {e}")
|
||||
return True
|
||||
logger.error(f"[register_emoji] Failed to write database record: {e}")
|
||||
return "failed"
|
||||
|
||||
return "registered"
|
||||
|
||||
def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool:
|
||||
"""
|
||||
@@ -636,7 +643,6 @@ class EmojiManager:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if image_record := session.exec(statement).first():
|
||||
image_record.description = emoji.description
|
||||
image_record.emotion = None
|
||||
session.add(image_record)
|
||||
logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}")
|
||||
else:
|
||||
@@ -794,12 +800,15 @@ class EmojiManager:
|
||||
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
|
||||
if self.delete_emoji(emoji_to_delete):
|
||||
self.emojis.remove(emoji_to_delete)
|
||||
if self.register_emoji_to_db(new_emoji):
|
||||
register_status = self.register_emoji_to_db(new_emoji)
|
||||
if register_status == "registered":
|
||||
self.emojis.append(new_emoji)
|
||||
logger.info(f"[注册表情包] 成功替换并注册新表情包: {new_emoji.description}")
|
||||
logger.info(f"[register_emoji] Replaced old emoji with new emoji: {new_emoji.description}")
|
||||
return True
|
||||
if register_status == "skipped":
|
||||
logger.info(f"[register_emoji] Replacement emoji was already registered: {new_emoji.description}")
|
||||
else:
|
||||
logger.error(f"[注册表情包] 注册新表情包失败: {new_emoji.description}")
|
||||
logger.error(f"[register_emoji] Failed to register replacement emoji: {new_emoji.description}")
|
||||
else:
|
||||
logger.error("[错误] 删除表情包失败,无法完成替换")
|
||||
else:
|
||||
@@ -808,7 +817,7 @@ class EmojiManager:
|
||||
logger.error("[决策] 未能解析删除编号")
|
||||
return False
|
||||
|
||||
async def build_emoji_description(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
|
||||
async def build_emoji_description(self, target_emoji: MaiEmoji) -> tuple[bool, MaiEmoji]:
|
||||
"""
|
||||
构建表情包描述
|
||||
|
||||
@@ -915,16 +924,12 @@ class EmojiManager:
|
||||
logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}")
|
||||
return True, target_emoji
|
||||
|
||||
async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
|
||||
"""兼容保留:表情包情绪标签已在 build_emoji_description 中一次性构建。"""
|
||||
return await self.build_emoji_description(target_emoji)
|
||||
|
||||
def check_emoji_file_integrity(self) -> None:
|
||||
"""
|
||||
检查表情包完整性,删除文件缺失的表情包记录
|
||||
"""
|
||||
logger.info("[完整性检查] 开始检查表情包文件完整性...")
|
||||
to_delete_emojis: list[Tuple[MaiEmoji, bool]] = []
|
||||
to_delete_emojis: list[tuple[MaiEmoji, bool]] = []
|
||||
removal_count = 0
|
||||
for emoji in self.emojis:
|
||||
if not emoji.full_path.exists():
|
||||
@@ -945,59 +950,34 @@ class EmojiManager:
|
||||
|
||||
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
|
||||
|
||||
def remove_untracked_emoji_files(self) -> None:
|
||||
"""
|
||||
删除未被数据库记录跟踪的表情包文件
|
||||
"""
|
||||
logger.info("[未跟踪表情包清理] 开始清理未被数据库记录跟踪的表情包文件...")
|
||||
tracked_files = {emoji.full_path.name for emoji in self.emojis}
|
||||
all_files = set(EMOJI_REGISTERED_DIR.glob("*"))
|
||||
removal_count = 0
|
||||
|
||||
for file_path in all_files:
|
||||
if file_path.name not in tracked_files:
|
||||
try:
|
||||
file_path.unlink()
|
||||
removal_count += 1
|
||||
logger.info(f"[未跟踪表情包清理] 删除未跟踪的表情包文件: {file_path.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"[未跟踪表情包清理] 删除文件 {file_path.name} 时出错: {e}")
|
||||
|
||||
logger.info(f"[未跟踪表情包清理] 未跟踪表情包文件清理完成,删除了 {removal_count} 个文件")
|
||||
|
||||
async def periodic_emoji_maintenance(self) -> None:
|
||||
"""
|
||||
定期执行表情包维护任务,包括完整性检查和未跟踪文件清理
|
||||
"""
|
||||
"""Run emoji maintenance tasks periodically."""
|
||||
while True:
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_ensure_directories()
|
||||
try:
|
||||
self.check_emoji_file_integrity()
|
||||
self.remove_untracked_emoji_files()
|
||||
except Exception as e:
|
||||
logger.error(f"[定期维护] 执行表情包维护任务时出错: {e}")
|
||||
logger.error(f"[emoji_maintenance] Maintenance task failed: {e}")
|
||||
|
||||
if global_config.emoji.steal_emoji and (
|
||||
self._emoji_num < global_config.emoji.max_reg_num
|
||||
or (self._emoji_num > global_config.emoji.max_reg_num and global_config.emoji.do_replace)
|
||||
):
|
||||
logger.info("[定期维护] 尝试从表情包盗取目录注册新表情包...")
|
||||
logger.info("[emoji_maintenance] Scanning data/emoji for new emojis...")
|
||||
for emoji_file in EMOJI_DIR.iterdir():
|
||||
if not emoji_file.is_file():
|
||||
continue
|
||||
try:
|
||||
register_success = await self.register_emoji_by_filename(emoji_file)
|
||||
register_status = await self.register_emoji_by_filename(emoji_file)
|
||||
except Exception as e:
|
||||
logger.error(f"[定期维护] 注册表情包 {emoji_file.name} 时发生未处理异常: {e}")
|
||||
register_success = False
|
||||
if register_success:
|
||||
break # 每次只注册一个表情包
|
||||
try:
|
||||
emoji_file.unlink()
|
||||
logger.info(f"[定期维护] 删除无法注册的表情包文件: {emoji_file.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"[定期维护] 删除文件 {emoji_file.name} 时出错: {e}")
|
||||
logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}")
|
||||
register_status = "failed"
|
||||
if register_status == "registered":
|
||||
break
|
||||
if register_status == "skipped":
|
||||
logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}")
|
||||
else:
|
||||
logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}")
|
||||
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
|
||||
try:
|
||||
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
|
||||
@@ -1006,84 +986,81 @@ class EmojiManager:
|
||||
finally:
|
||||
self._maintenance_wakeup_event.clear()
|
||||
|
||||
async def register_emoji_by_filename(self, filename: Path | str) -> bool:
|
||||
"""
|
||||
根据指定的表情包图片,分析并注册到数据库
|
||||
|
||||
Args:
|
||||
filename (Path | str): 表情包图片的完整文件路径(可能根据文件实际格式修正)
|
||||
|
||||
Returns:
|
||||
return (bool): 注册是否成功
|
||||
"""
|
||||
async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus:
|
||||
"""Register an emoji file from ``data/emoji`` without moving or deleting it."""
|
||||
file_full_path = Path(filename).absolute().resolve()
|
||||
if not file_full_path.exists():
|
||||
logger.error(f"[注册表情包] 表情包文件不存在: {file_full_path}")
|
||||
return False
|
||||
logger.error(f"[register_emoji] Emoji file does not exist: {file_full_path}")
|
||||
return "failed"
|
||||
try:
|
||||
target_emoji = MaiEmoji(full_path=file_full_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
|
||||
return False
|
||||
logger.error(f"[register_emoji] Failed to create emoji object: {e}")
|
||||
return "failed"
|
||||
|
||||
calc_success = await target_emoji.calculate_hash_format()
|
||||
if not calc_success:
|
||||
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
|
||||
return False
|
||||
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
|
||||
logger.error(f"[register_emoji] Failed to calculate hash and format: {file_full_path}")
|
||||
return "failed"
|
||||
file_full_path = target_emoji.full_path
|
||||
|
||||
# 2. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
|
||||
existing_record: Optional[Images] = None
|
||||
try:
|
||||
with get_db_session_manual() as session:
|
||||
statement = (
|
||||
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
)
|
||||
if image_record := session.exec(statement).first():
|
||||
if image_record.no_file_flag:
|
||||
image_record.no_file_flag = False
|
||||
image_record.is_banned = False
|
||||
image_record.is_registered = True
|
||||
image_record.full_path = str(target_emoji.full_path)
|
||||
session.add(image_record)
|
||||
session.commit()
|
||||
logger.info(f"表情包注册成功,Hash: {target_emoji.file_hash}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}")
|
||||
return False
|
||||
existing_record = session.exec(statement).first()
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
||||
return False
|
||||
logger.error(f"[register_emoji] Failed to query database: {e}")
|
||||
return "failed"
|
||||
|
||||
if existing_record is not None:
|
||||
if existing_record.is_registered and _is_available_emoji_record(existing_record):
|
||||
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
|
||||
return "skipped"
|
||||
|
||||
cached_description = str(existing_record.description or "").strip()
|
||||
if cached_description:
|
||||
normalized_emotions = _normalize_emoji_tag_text(cached_description)
|
||||
target_emoji.description = ",".join(normalized_emotions)
|
||||
target_emoji.emotion = normalized_emotions
|
||||
target_emoji.query_count = existing_record.query_count
|
||||
target_emoji.last_used_time = existing_record.last_used_time
|
||||
target_emoji.register_time = existing_record.register_time
|
||||
|
||||
# 3. 检查内存缓存是否已经存在
|
||||
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
|
||||
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
|
||||
return False
|
||||
# 3. 构建描述(包含情绪标签)
|
||||
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
|
||||
if not desc_success:
|
||||
logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}")
|
||||
return False
|
||||
logger.info(f"[register_emoji] Emoji already loaded in memory, skipping: {existing_emoji.file_name}")
|
||||
return "skipped"
|
||||
|
||||
if not target_emoji.description:
|
||||
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
|
||||
if not desc_success:
|
||||
logger.error(f"[register_emoji] Failed to build emoji description: {file_full_path}")
|
||||
return "failed"
|
||||
|
||||
# 4. 检查容量并决定是否替换或者直接注册
|
||||
if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace:
|
||||
logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包")
|
||||
logger.warning(
|
||||
f"[register_emoji] Emoji limit {global_config.emoji.max_reg_num} reached, trying replacement"
|
||||
)
|
||||
replaced = await self.replace_an_emoji_by_llm(target_emoji)
|
||||
if not replaced:
|
||||
logger.error("[注册表情包] 替换表情包失败,无法注册新表情包")
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
if self.register_emoji_to_db(target_emoji):
|
||||
self.emojis.append(target_emoji)
|
||||
self._emoji_num += 1
|
||||
logger.info(f"[注册表情包] 成功注册新表情包: {target_emoji.file_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[注册表情包] 注册表情包到数据库失败: {file_full_path}")
|
||||
return False
|
||||
logger.error("[register_emoji] Failed to replace existing emoji")
|
||||
return "failed"
|
||||
return "registered"
|
||||
|
||||
def _calculate_emotion_similarity_list(self, text_emotion: str) -> List[Tuple[MaiEmoji, float]]:
|
||||
register_status = self.register_emoji_to_db(target_emoji)
|
||||
if register_status == "registered":
|
||||
self.emojis.append(target_emoji)
|
||||
self._emoji_num = len(self.emojis)
|
||||
logger.info(f"[register_emoji] Registered new emoji: {target_emoji.file_name}")
|
||||
elif register_status == "failed":
|
||||
logger.error(f"[register_emoji] Failed to register emoji in database: {file_full_path}")
|
||||
else:
|
||||
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
|
||||
return register_status
|
||||
|
||||
def _calculate_emotion_similarity_list(self, text_emotion: str) -> list[tuple[MaiEmoji, float]]:
|
||||
"""
|
||||
计算文本情感标签与所有表情包情感标签的相似度列表
|
||||
|
||||
@@ -1096,7 +1073,7 @@ class EmojiManager:
|
||||
if not normalized_text_emotion:
|
||||
return []
|
||||
|
||||
similarity_list: List[Tuple[MaiEmoji, float]] = []
|
||||
similarity_list: list[tuple[MaiEmoji, float]] = []
|
||||
for emoji in self.emojis:
|
||||
candidate_emotions = _get_emoji_emotions(emoji)
|
||||
if not candidate_emotions:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Maisaka 表情工具内置能力。"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Sequence, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import random
|
||||
|
||||
@@ -120,8 +120,6 @@ def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
|
||||
"""提取并清洗单个表情的情绪标签。"""
|
||||
if emoji.description:
|
||||
return _normalize_emoji_tag_text(emoji.description)
|
||||
if emoji.emotion:
|
||||
return _normalize_emoji_tag_text(emoji.emotion)
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ class ImageManager:
|
||||
"""初始化图片管理器。"""
|
||||
_ensure_image_dir_exists()
|
||||
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
|
||||
self.cleanup_legacy_image_registration_records()
|
||||
|
||||
logger.info("图片管理器初始化完成")
|
||||
|
||||
@@ -50,6 +51,17 @@ class ImageManager:
|
||||
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
|
||||
return session.exec(statement).first()
|
||||
|
||||
def _normalize_image_registration_fields(self, record: Images) -> bool:
|
||||
"""Normalize accidental emoji registration fields on image records."""
|
||||
if record.image_type != ImageType.IMAGE:
|
||||
return False
|
||||
if not record.is_registered and record.register_time is None:
|
||||
return False
|
||||
|
||||
record.is_registered = False
|
||||
record.register_time = None
|
||||
return True
|
||||
|
||||
async def get_image_description(
|
||||
self,
|
||||
*,
|
||||
@@ -182,8 +194,9 @@ class ImageManager:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = image.to_db_instance()
|
||||
record.is_registered = True
|
||||
record.register_time = record.last_used_time = datetime.now()
|
||||
record.is_registered = False
|
||||
record.register_time = None
|
||||
record.last_used_time = datetime.now()
|
||||
session.add(record)
|
||||
session.flush() # 确保记录被写入数据库以获取ID
|
||||
record_id = record.id
|
||||
@@ -209,6 +222,7 @@ class ImageManager:
|
||||
if not record:
|
||||
logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法更新描述")
|
||||
return False
|
||||
self._normalize_image_registration_fields(record)
|
||||
record.description = image.description
|
||||
record.last_used_time = datetime.now()
|
||||
record.vlm_processed = image.vlm_processed
|
||||
@@ -258,6 +272,7 @@ class ImageManager:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
|
||||
if record := session.exec(statement).first():
|
||||
self._normalize_image_registration_fields(record)
|
||||
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
|
||||
record.last_used_time = datetime.now()
|
||||
record.query_count += 1
|
||||
@@ -335,6 +350,24 @@ class ImageManager:
|
||||
|
||||
logger.info(f"清理完成: {invalid_counter} 条无效描述记录,{null_path_counter} 条文件路径不存在记录")
|
||||
|
||||
def cleanup_legacy_image_registration_records(self) -> None:
|
||||
"""Clean up legacy image records with mistaken registration fields."""
|
||||
fixed_counter = 0
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_type=ImageType.IMAGE)
|
||||
for record in session.exec(statement).yield_per(100):
|
||||
if not self._normalize_image_registration_fields(record):
|
||||
continue
|
||||
session.add(record)
|
||||
fixed_counter += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clean image registration state: {e}")
|
||||
return
|
||||
|
||||
if fixed_counter:
|
||||
logger.info(f"Cleaned mistaken registration state on {fixed_counter} image records")
|
||||
|
||||
async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str:
|
||||
prompt = global_config.personality.visual_style
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
Reference in New Issue
Block a user