feat:优化表情包注册,迁移数据库v3

This commit is contained in:
SengokuCola
2026-04-05 20:12:21 +08:00
parent 80be746be0
commit 526fc9b763
16 changed files with 926 additions and 534 deletions

View File

@@ -1,6 +1,6 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Literal, Optional
import asyncio
import hashlib
@@ -31,12 +31,12 @@ install(extra_lines=3)
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
DATA_DIR = PROJECT_ROOT / "data"
EmojiRegisterStatus = Literal["registered", "skipped", "failed"]
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注册目录
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
def register_emoji_hook_specs(registry: HookSpecRegistry) -> list[HookSpec]:
"""注册表情包系统内置 Hook 规格。
Args:
@@ -145,7 +145,7 @@ def _get_runtime_manager() -> Any:
return get_plugin_runtime_manager()
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[dict[str, Any]]:
"""将表情包对象序列化为 Hook 可传输载荷。
Args:
@@ -163,27 +163,27 @@ def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, A
"file_name": emoji.file_name,
"full_path": str(emoji.full_path),
"description": emoji.description,
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description or emoji.emotion)],
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description)],
"query_count": int(emoji.query_count),
}
def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
def _normalize_emoji_tag_text(raw_values: Any) -> list[str]:
"""将文本或标签列表转为去重的情绪标签列表。"""
if isinstance(raw_values, str):
if not raw_values:
return []
parts = re.split(r"[,,、;\s]+", raw_values.strip())
parts = re.split(r"[,,、;\r\n\t]+", raw_values.strip())
normalized_tags = [str(part).strip() for part in parts if str(part).strip()]
elif isinstance(raw_values, list):
normalized_tags: List[str] = []
normalized_tags: list[str] = []
for value in raw_values:
normalized_tags.extend(_normalize_emoji_tag_text(value))
else:
return []
deduped_tags: List[str] = []
seen: Set[str] = set()
deduped_tags: list[str] = []
seen: set[str] = set()
for tag in normalized_tags:
normalized_tag = tag.strip()
if not normalized_tag:
@@ -196,15 +196,23 @@ def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
return deduped_tags
def _get_emoji_emotions(emoji: MaiEmoji) -> List[str]:
"""获取兼容旧数据的表情包情绪标签。"""
return _normalize_emoji_tag_text(emoji.description or emoji.emotion)
def _get_emoji_emotions(emoji: MaiEmoji) -> list[str]:
    """Extract the deduplicated, normalized emotion tags from an emoji's description."""
    raw_description = emoji.description
    return _normalize_emoji_tag_text(raw_description)
def _ensure_directories() -> None:
    """Create the emoji storage directories if they do not already exist."""
    for directory in (EMOJI_DIR, EMOJI_REGISTERED_DIR):
        # parents=True also creates missing intermediate directories (e.g. data/).
        directory.mkdir(parents=True, exist_ok=True)
def _is_available_emoji_record(record: Images) -> bool:
    """Return True when the emoji file referenced by a database record is usable.

    A record is unavailable either when it is explicitly flagged as having no
    file, or when the path it points to is missing or not a regular file.
    """
    if record.no_file_flag:
        return False
    candidate = Path(record.full_path)
    if not candidate.exists():
        return False
    return candidate.is_file()
# TODO: 修改这个vlm为获取的vlm client暂时使用这个VLM方法
@@ -224,9 +232,9 @@ class EmojiManager:
_ensure_directories()
self._emoji_num: int = 0
self.emojis: List[MaiEmoji] = []
self.emojis: list[MaiEmoji] = []
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
self._pending_description_tasks: dict[str, asyncio.Task[None]] = {}
self._reload_callback_registered: bool = False
config_manager.register_reload_callback(self.reload_runtime_config)
@@ -254,7 +262,7 @@ class EmojiManager:
emoji_bytes: Optional[bytes] = None,
emoji_hash: Optional[str] = None,
wait_for_build: bool = True,
) -> Optional[Tuple[str, List[str]]]:
) -> Optional[tuple[str, list[str]]]:
"""
根据表情包哈希获取表情包描述和情感列表的封装方法
@@ -275,12 +283,28 @@ class EmojiManager:
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
if emoji := self.get_emoji_by_hash(emoji_hash):
emoji_path = Path(emoji.full_path) if emoji.full_path else None
if emoji_bytes and (emoji_path is None or not emoji_path.exists()):
try:
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
emoji.full_path = restored_emoji.full_path
emoji.file_name = restored_emoji.file_name
except Exception as e:
logger.warning(f"表情包缓存命中但本地文件缺失,回填失败: {e}")
return emoji.description, _normalize_emoji_tag_text(emoji.description or "")
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
if result := session.exec(statement).first():
cached_description = result.description or result.emotion or ""
record_path = Path(result.full_path) if result.full_path else None
if emoji_bytes and (result.no_file_flag or record_path is None or not record_path.exists()):
try:
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
result.full_path = str(restored_emoji.full_path)
result.no_file_flag = False
except Exception as e:
logger.warning(f"数据库命中表情包记录但本地文件缺失,回填失败: {e}")
cached_description = result.description or ""
cached_emotions = _normalize_emoji_tag_text(cached_description)
return (
cached_description,
@@ -400,7 +424,7 @@ class EmojiManager:
self,
emoji_hash: str,
emoji_bytes: bytes,
) -> Optional[Tuple[str, List[str]]]:
) -> Optional[tuple[str, list[str]]]:
"""构建并缓存表情包描述(返回标签化结果,不再走额外识别流程)。"""
logger.info(f"Start building cached emoji description, hash={emoji_hash}")
new_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
@@ -461,87 +485,70 @@ class EmojiManager:
self._emoji_num = 0
raise e
def register_emoji_to_db(self, emoji: MaiEmoji) -> bool:
def register_emoji_to_db(self, emoji: MaiEmoji) -> EmojiRegisterStatus:
# sourcery skip: extract-method
"""
将表情包注册到数据库中
Args:
emoji (MaiEmoji): 需要注册的表情包对象
Returns:
return (bool): 注册是否成功
"""
"""Register an emoji in the database without moving its source file."""
if not emoji or not isinstance(emoji, MaiEmoji):
logger.error("[注册表情包] 无效的表情包对象")
return False
logger.error("[register_emoji] Invalid emoji object")
return "failed"
if not emoji.full_path.exists():
logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}")
return False
logger.error(f"[register_emoji] Emoji file does not exist: {emoji.full_path}")
return "failed"
target_path = EMOJI_REGISTERED_DIR / emoji.file_name
# 先查库,避免重复记录导致文件被误移动后无法回收
original_path = emoji.full_path
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
existing_record = session.exec(statement).first()
if existing_record and not existing_record.no_file_flag:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
return False
except Exception as e:
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
return False
try:
emoji.full_path.replace(target_path)
emoji.full_path = target_path
except Exception as e:
logger.error(f"[注册表情包] 移动表情包文件时出错: {e}")
return False
if existing_record:
if existing_record.is_registered and _is_available_emoji_record(existing_record):
logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}")
return "skipped"
normalized_description = str(emoji.description or existing_record.description or "").strip()
normalized_emotions = _normalize_emoji_tag_text(normalized_description)
register_time = existing_record.register_time or datetime.now()
existing_record.full_path = str(emoji.full_path)
existing_record.description = normalized_description
existing_record.query_count = max(int(existing_record.query_count), int(emoji.query_count))
existing_record.last_used_time = emoji.last_used_time or existing_record.last_used_time
existing_record.is_registered = True
existing_record.is_banned = False
existing_record.no_file_flag = False
existing_record.register_time = register_time
session.add(existing_record)
emoji.description = normalized_description
emoji.emotion = normalized_emotions
emoji.query_count = existing_record.query_count
emoji.last_used_time = existing_record.last_used_time
emoji.register_time = register_time
logger.info(
f"[register_emoji] Updated existing record and registered emoji, ID: {existing_record.id}, path: {emoji.full_path}"
)
return "registered"
except Exception as e:
logger.error(f"[register_emoji] Database query failed: {e}")
return "failed"
# 注册到数据库
restore_file = False
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
if existing_record := session.exec(statement).first():
if existing_record.no_file_flag:
existing_record.no_file_flag = False
existing_record.is_banned = False
existing_record.full_path = str(emoji.full_path)
existing_record.description = emoji.description
existing_record.query_count = emoji.query_count
existing_record.last_used_time = emoji.last_used_time
existing_record.register_time = emoji.register_time
session.add(existing_record)
logger.info(
f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}"
)
else:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
restore_file = True
return False
else:
image_record = emoji.to_db_instance()
image_record.is_registered = True
image_record.is_banned = False
image_record.register_time = datetime.now()
session.add(image_record)
session.flush()
record_id = image_record.id
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
image_record = emoji.to_db_instance()
image_record.is_registered = True
image_record.is_banned = False
image_record.no_file_flag = False
image_record.register_time = datetime.now()
session.add(image_record)
session.flush()
emoji.register_time = image_record.register_time
record_id = image_record.id
logger.info(f"[register_emoji] Registered emoji to database, ID: {record_id}, path: {emoji.full_path}")
except Exception as e:
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
restore_file = True
return False
finally:
if restore_file:
try:
emoji.full_path.replace(original_path)
emoji.full_path = original_path
except Exception as e:
logger.error(f"[注册表情包] 回滚文件移动失败: {e}")
return True
logger.error(f"[register_emoji] Failed to write database record: {e}")
return "failed"
return "registered"
def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool:
"""
@@ -636,7 +643,6 @@ class EmojiManager:
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
if image_record := session.exec(statement).first():
image_record.description = emoji.description
image_record.emotion = None
session.add(image_record)
logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}")
else:
@@ -794,12 +800,15 @@ class EmojiManager:
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
if self.delete_emoji(emoji_to_delete):
self.emojis.remove(emoji_to_delete)
if self.register_emoji_to_db(new_emoji):
register_status = self.register_emoji_to_db(new_emoji)
if register_status == "registered":
self.emojis.append(new_emoji)
logger.info(f"[注册表情包] 成功替换并注册新表情包: {new_emoji.description}")
logger.info(f"[register_emoji] Replaced old emoji with new emoji: {new_emoji.description}")
return True
if register_status == "skipped":
logger.info(f"[register_emoji] Replacement emoji was already registered: {new_emoji.description}")
else:
logger.error(f"[注册表情包] 注册新表情包失败: {new_emoji.description}")
logger.error(f"[register_emoji] Failed to register replacement emoji: {new_emoji.description}")
else:
logger.error("[错误] 删除表情包失败,无法完成替换")
else:
@@ -808,7 +817,7 @@ class EmojiManager:
logger.error("[决策] 未能解析删除编号")
return False
async def build_emoji_description(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
async def build_emoji_description(self, target_emoji: MaiEmoji) -> tuple[bool, MaiEmoji]:
"""
构建表情包描述
@@ -915,16 +924,12 @@ class EmojiManager:
logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}")
return True, target_emoji
async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
"""兼容保留:表情包情绪标签已在 build_emoji_description 中一次性构建。"""
return await self.build_emoji_description(target_emoji)
def check_emoji_file_integrity(self) -> None:
"""
检查表情包完整性,删除文件缺失的表情包记录
"""
logger.info("[完整性检查] 开始检查表情包文件完整性...")
to_delete_emojis: list[Tuple[MaiEmoji, bool]] = []
to_delete_emojis: list[tuple[MaiEmoji, bool]] = []
removal_count = 0
for emoji in self.emojis:
if not emoji.full_path.exists():
@@ -945,59 +950,34 @@ class EmojiManager:
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
def remove_untracked_emoji_files(self) -> None:
    """Delete files in the registered-emoji directory that no in-memory record tracks."""
    logger.info("[未跟踪表情包清理] 开始清理未被数据库记录跟踪的表情包文件...")
    # File names currently owned by loaded emoji records.
    tracked_files = {emoji.full_path.name for emoji in self.emojis}
    removal_count = 0
    for file_path in set(EMOJI_REGISTERED_DIR.glob("*")):
        if file_path.name in tracked_files:
            continue
        try:
            file_path.unlink()
        except Exception as e:
            logger.error(f"[未跟踪表情包清理] 删除文件 {file_path.name} 时出错: {e}")
        else:
            removal_count += 1
            logger.info(f"[未跟踪表情包清理] 删除未跟踪的表情包文件: {file_path.name}")
    logger.info(f"[未跟踪表情包清理] 未跟踪表情包文件清理完成,删除了 {removal_count} 个文件")
async def periodic_emoji_maintenance(self) -> None:
"""
定期执行表情包维护任务,包括完整性检查和未跟踪文件清理
"""
"""Run emoji maintenance tasks periodically."""
while True:
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True)
_ensure_directories()
try:
self.check_emoji_file_integrity()
self.remove_untracked_emoji_files()
except Exception as e:
logger.error(f"[定期维护] 执行表情包维护任务时出错: {e}")
logger.error(f"[emoji_maintenance] Maintenance task failed: {e}")
if global_config.emoji.steal_emoji and (
self._emoji_num < global_config.emoji.max_reg_num
or (self._emoji_num > global_config.emoji.max_reg_num and global_config.emoji.do_replace)
):
logger.info("[定期维护] 尝试从表情包盗取目录注册新表情包...")
logger.info("[emoji_maintenance] Scanning data/emoji for new emojis...")
for emoji_file in EMOJI_DIR.iterdir():
if not emoji_file.is_file():
continue
try:
register_success = await self.register_emoji_by_filename(emoji_file)
register_status = await self.register_emoji_by_filename(emoji_file)
except Exception as e:
logger.error(f"[定期维护] 注册表情包 {emoji_file.name} 时发生未处理异常: {e}")
register_success = False
if register_success:
break # 每次只注册一个表情包
try:
emoji_file.unlink()
logger.info(f"[定期维护] 删除无法注册的表情包文件: {emoji_file.name}")
except Exception as e:
logger.error(f"[定期维护] 删除文件 {emoji_file.name} 时出错: {e}")
logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}")
register_status = "failed"
if register_status == "registered":
break
if register_status == "skipped":
logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}")
else:
logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}")
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
try:
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
@@ -1006,84 +986,81 @@ class EmojiManager:
finally:
self._maintenance_wakeup_event.clear()
async def register_emoji_by_filename(self, filename: Path | str) -> bool:
"""
根据指定的表情包图片,分析并注册到数据库
Args:
filename (Path | str): 表情包图片的完整文件路径(可能根据文件实际格式修正)
Returns:
return (bool): 注册是否成功
"""
async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus:
"""Register an emoji file from ``data/emoji`` without moving or deleting it."""
file_full_path = Path(filename).absolute().resolve()
if not file_full_path.exists():
logger.error(f"[注册表情包] 表情包文件不存在: {file_full_path}")
return False
logger.error(f"[register_emoji] Emoji file does not exist: {file_full_path}")
return "failed"
try:
target_emoji = MaiEmoji(full_path=file_full_path)
except Exception as e:
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
return False
logger.error(f"[register_emoji] Failed to create emoji object: {e}")
return "failed"
calc_success = await target_emoji.calculate_hash_format()
if not calc_success:
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
return False
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
logger.error(f"[register_emoji] Failed to calculate hash and format: {file_full_path}")
return "failed"
file_full_path = target_emoji.full_path
# 2. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
existing_record: Optional[Images] = None
try:
with get_db_session_manual() as session:
statement = (
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
)
if image_record := session.exec(statement).first():
if image_record.no_file_flag:
image_record.no_file_flag = False
image_record.is_banned = False
image_record.is_registered = True
image_record.full_path = str(target_emoji.full_path)
session.add(image_record)
session.commit()
logger.info(f"表情包注册成功Hash: {target_emoji.file_hash}")
return True
else:
logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}")
return False
existing_record = session.exec(statement).first()
except Exception as e:
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
return False
logger.error(f"[register_emoji] Failed to query database: {e}")
return "failed"
if existing_record is not None:
if existing_record.is_registered and _is_available_emoji_record(existing_record):
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
return "skipped"
cached_description = str(existing_record.description or "").strip()
if cached_description:
normalized_emotions = _normalize_emoji_tag_text(cached_description)
target_emoji.description = ",".join(normalized_emotions)
target_emoji.emotion = normalized_emotions
target_emoji.query_count = existing_record.query_count
target_emoji.last_used_time = existing_record.last_used_time
target_emoji.register_time = existing_record.register_time
# 3. 检查内存缓存是否已经存在
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
return False
# 3. 构建描述(包含情绪标签)
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
if not desc_success:
logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}")
return False
logger.info(f"[register_emoji] Emoji already loaded in memory, skipping: {existing_emoji.file_name}")
return "skipped"
if not target_emoji.description:
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
if not desc_success:
logger.error(f"[register_emoji] Failed to build emoji description: {file_full_path}")
return "failed"
# 4. 检查容量并决定是否替换或者直接注册
if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace:
logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包")
logger.warning(
f"[register_emoji] Emoji limit {global_config.emoji.max_reg_num} reached, trying replacement"
)
replaced = await self.replace_an_emoji_by_llm(target_emoji)
if not replaced:
logger.error("[注册表情包] 替换表情包失败,无法注册新表情包")
return False
return True
else:
if self.register_emoji_to_db(target_emoji):
self.emojis.append(target_emoji)
self._emoji_num += 1
logger.info(f"[注册表情包] 成功注册新表情包: {target_emoji.file_name}")
return True
else:
logger.error(f"[注册表情包] 注册表情包到数据库失败: {file_full_path}")
return False
logger.error("[register_emoji] Failed to replace existing emoji")
return "failed"
return "registered"
def _calculate_emotion_similarity_list(self, text_emotion: str) -> List[Tuple[MaiEmoji, float]]:
register_status = self.register_emoji_to_db(target_emoji)
if register_status == "registered":
self.emojis.append(target_emoji)
self._emoji_num = len(self.emojis)
logger.info(f"[register_emoji] Registered new emoji: {target_emoji.file_name}")
elif register_status == "failed":
logger.error(f"[register_emoji] Failed to register emoji in database: {file_full_path}")
else:
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
return register_status
def _calculate_emotion_similarity_list(self, text_emotion: str) -> list[tuple[MaiEmoji, float]]:
"""
计算文本情感标签与所有表情包情感标签的相似度列表
@@ -1096,7 +1073,7 @@ class EmojiManager:
if not normalized_text_emotion:
return []
similarity_list: List[Tuple[MaiEmoji, float]] = []
similarity_list: list[tuple[MaiEmoji, float]] = []
for emoji in self.emojis:
candidate_emotions = _get_emoji_emotions(emoji)
if not candidate_emotions:

View File

@@ -1,8 +1,8 @@
"""Maisaka 表情工具内置能力。"""
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, Optional, Sequence, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import random
@@ -120,8 +120,6 @@ def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
"""提取并清洗单个表情的情绪标签。"""
if emoji.description:
return _normalize_emoji_tag_text(emoji.description)
if emoji.emotion:
return _normalize_emoji_tag_text(emoji.emotion)
return []

View File

@@ -41,6 +41,7 @@ class ImageManager:
"""初始化图片管理器。"""
_ensure_image_dir_exists()
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
self.cleanup_legacy_image_registration_records()
logger.info("图片管理器初始化完成")
@@ -50,6 +51,17 @@ class ImageManager:
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
return session.exec(statement).first()
def _normalize_image_registration_fields(self, record: Images) -> bool:
    """Reset registration fields that were mistakenly set on a plain image record.

    Only records of type ``ImageType.IMAGE`` are touched; emoji records keep
    their registration state. Returns True when the record was modified and
    therefore needs to be persisted by the caller.
    """
    if record.image_type != ImageType.IMAGE:
        return False
    # Nothing to fix when the record is unregistered and has no register time.
    needs_fix = record.is_registered or record.register_time is not None
    if not needs_fix:
        return False
    record.is_registered = False
    record.register_time = None
    return True
async def get_image_description(
self,
*,
@@ -182,8 +194,9 @@ class ImageManager:
try:
with get_db_session() as session:
record = image.to_db_instance()
record.is_registered = True
record.register_time = record.last_used_time = datetime.now()
record.is_registered = False
record.register_time = None
record.last_used_time = datetime.now()
session.add(record)
session.flush() # 确保记录被写入数据库以获取ID
record_id = record.id
@@ -209,6 +222,7 @@ class ImageManager:
if not record:
logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法更新描述")
return False
self._normalize_image_registration_fields(record)
record.description = image.description
record.last_used_time = datetime.now()
record.vlm_processed = image.vlm_processed
@@ -258,6 +272,7 @@ class ImageManager:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
if record := session.exec(statement).first():
self._normalize_image_registration_fields(record)
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
record.last_used_time = datetime.now()
record.query_count += 1
@@ -335,6 +350,24 @@ class ImageManager:
logger.info(f"清理完成: {invalid_counter} 条无效描述记录,{null_path_counter} 条文件路径不存在记录")
def cleanup_legacy_image_registration_records(self) -> None:
    """Clean up legacy image records with mistaken registration fields.

    Scans every ``ImageType.IMAGE`` row and clears registration state that
    older versions wrote onto plain images. Best-effort: any database error
    is logged and the cleanup aborts without raising.
    """
    fixed_counter = 0
    try:
        with get_db_session() as session:
            statement = select(Images).filter_by(image_type=ImageType.IMAGE)
            # yield_per streams rows in batches so a large images table
            # does not get fully materialized in memory.
            for record in session.exec(statement).yield_per(100):
                if not self._normalize_image_registration_fields(record):
                    continue
                session.add(record)
                fixed_counter += 1
    except Exception as e:
        logger.error(f"Failed to clean image registration state: {e}")
        return
    if fixed_counter:
        logger.info(f"Cleaned mistaken registration state on {fixed_counter} image records")
async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str:
prompt = global_config.personality.visual_style
image_base64 = base64.b64encode(image_bytes).decode("utf-8")

View File

@@ -1,14 +1,20 @@
from datetime import datetime
from typing import Optional, Dict
from typing import Dict, Optional
import json
from src.common.database.database_model import ActionRecord
from src.common.database.database_model import ToolRecord
from . import BaseDatabaseDataModel
class MaiActionRecord(BaseDatabaseDataModel[ActionRecord]):
class MaiActionRecord(BaseDatabaseDataModel[ToolRecord]):
"""``action_records`` 的兼容数据模型。
历史动作记录已统一并入 ``tool_records``,该类仅保留旧命名接口,
底层读写对象统一映射为 ``ToolRecord``。
"""
def __init__(
self,
action_id: str,
@@ -21,45 +27,39 @@ class MaiActionRecord(BaseDatabaseDataModel[ActionRecord]):
action_display_prompt: Optional[str] = None,
):
self.action_id = action_id
"""动作ID"""
self.timestamp = timestamp
"""时间戳"""
self.session_id = session_id
"""会话ID"""
self.action_name = action_name
"""动作名称"""
self.action_reasoning = action_reasoning
"""动作推理过程"""
self.action_data = action_data or {}
"""动作数据"""
self.action_builtin_prompt = action_builtin_prompt
"""内置动作提示"""
self.action_display_prompt = action_display_prompt
"""最终输入到 Prompt 的内容"""
@classmethod
def from_db_instance(cls, db_record: ActionRecord):
"""Create a data model object from a database record."""
def from_db_instance(cls, db_record: ToolRecord):
"""从数据库实例创建兼容数据模型对象。"""
return cls(
action_id=db_record.action_id,
action_id=db_record.tool_id,
timestamp=db_record.timestamp,
session_id=db_record.session_id,
action_name=db_record.action_name,
action_reasoning=db_record.action_reasoning,
action_data=json.loads(db_record.action_data) if db_record.action_data else None,
action_builtin_prompt=db_record.action_builtin_prompt,
action_display_prompt=db_record.action_display_prompt,
action_name=db_record.tool_name,
action_reasoning=db_record.tool_reasoning,
action_data=json.loads(db_record.tool_data) if db_record.tool_data else None,
action_builtin_prompt=db_record.tool_builtin_prompt,
action_display_prompt=db_record.tool_display_prompt,
)
def to_db_instance(self):
"""Convert the data model object back to a database instance."""
return ActionRecord(
action_id=self.action_id,
"""将兼容数据模型对象转换为 ``ToolRecord``。"""
return ToolRecord(
tool_id=self.action_id,
timestamp=self.timestamp,
session_id=self.session_id,
action_name=self.action_name,
action_reasoning=self.action_reasoning,
action_data=json.dumps(self.action_data) if self.action_data else None,
action_builtin_prompt=self.action_builtin_prompt,
action_display_prompt=self.action_display_prompt,
tool_name=self.action_name,
tool_reasoning=self.action_reasoning,
tool_data=json.dumps(self.action_data) if self.action_data else None,
tool_builtin_prompt=self.action_builtin_prompt,
tool_display_prompt=self.action_display_prompt,
)

View File

@@ -1,16 +1,17 @@
from datetime import datetime
from pathlib import Path
from PIL import Image as PILImage
from rich.traceback import install
from typing import Optional, List
import asyncio
import hashlib
import io
import traceback
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from PIL import Image as PILImage
from rich.traceback import install
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from . import BaseDatabaseDataModel
@@ -152,7 +153,7 @@ class MaiEmoji(BaseImageDataModel):
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
obj = cls(db_record.full_path)
obj.file_hash = db_record.image_hash
description = db_record.description or db_record.emotion or ""
description = db_record.description or ""
obj.description = description
normalized_tags = [
str(item).strip()
@@ -175,7 +176,6 @@ class MaiEmoji(BaseImageDataModel):
description=self.description,
full_path=str(self.full_path),
image_type=ImageType.EMOJI,
emotion=None,
query_count=self.query_count,
last_used_time=self.last_used_time,
register_time=self.register_time,

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from typing import ContextManager, Generator, TYPE_CHECKING
from rich.traceback import install
from sqlalchemy import event, text
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, Session, create_engine
@@ -63,41 +63,6 @@ _migration_bootstrapper = create_database_migration_bootstrapper(engine)
_db_initialized = False
def _migrate_action_records_to_tool_records() -> None:
    """Migrate legacy ``action_records`` history rows into ``tool_records``.

    Idempotent: the ``WHERE NOT EXISTS`` clause skips rows whose ``action_id``
    already exists as a ``tool_id``, so re-running the migration is safe.
    """
    migration_sql = text(
        """
        INSERT INTO tool_records (
        tool_id,
        timestamp,
        session_id,
        tool_name,
        tool_reasoning,
        tool_data,
        tool_builtin_prompt,
        tool_display_prompt
        )
        SELECT
        action_id,
        timestamp,
        session_id,
        action_name,
        action_reasoning,
        action_data,
        action_builtin_prompt,
        action_display_prompt
        FROM action_records
        WHERE NOT EXISTS (
        SELECT 1
        FROM tool_records
        WHERE tool_records.tool_id = action_records.action_id
        )
        """
    )
    # engine.begin() opens a transaction and commits on success / rolls back on error.
    with engine.begin() as connection:
        connection.execute(migration_sql)
def initialize_database() -> None:
"""初始化数据库连接、结构与启动期迁移。
@@ -105,8 +70,7 @@ def initialize_database() -> None:
1. 确保数据库目录存在;
2. 加载 SQLModel 模型定义;
3. 执行已注册的启动期迁移;
4. 兜底执行 ``create_all`` 确保当前模型定义已建表
5. 执行项目现有的轻量数据补迁移逻辑。
4. 兜底执行 ``create_all`` 确保当前模型定义已建表
"""
global _db_initialized
if _db_initialized:
@@ -120,7 +84,6 @@ def initialize_database() -> None:
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
)
SQLModel.metadata.create_all(engine)
_migrate_action_records_to_tool_records()
_migration_bootstrapper.finalize_database(migration_state)
_db_initialized = True

View File

@@ -94,7 +94,6 @@ class Images(SQLModel, table=True):
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI)
"""图片类型,例如 'emoji''image'"""
emotion: Optional[str] = Field(default=None, nullable=True) # 表情包的情感标签,逗号分隔
query_count: int = Field(default=0) # 被查询次数
is_registered: bool = Field(default=False) # 是否已经注册
@@ -113,27 +112,6 @@ class Images(SQLModel, table=True):
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
class ActionRecord(SQLModel, table=True):
    """Stores action invocation records."""
    __tablename__ = "action_records"  # type: ignore
    id: Optional[int] = Field(default=None, primary_key=True)  # auto-increment primary key
    # Metadata
    action_id: str = Field(index=True, max_length=255)  # action ID
    timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True))  # record timestamp
    session_id: str = Field(index=True, max_length=255)  # session_id of the owning ChatSession
    # Invocation details
    action_name: str = Field(index=True, max_length=255)  # action name
    action_reasoning: Optional[str] = Field(default=None)  # reasoning behind the action
    action_data: Optional[str] = Field(default=None)  # action payload, stored as JSON text
    action_builtin_prompt: Optional[str] = Field(default=None)  # built-in action prompt
    action_display_prompt: Optional[str] = Field(default=None)  # final content fed into the prompt
class ToolRecord(SQLModel, table=True):
"""存储工具调用记录"""
@@ -281,28 +259,6 @@ class ChatHistory(SQLModel, table=True):
summary: str # 概括:对这段话的平文本概括
class ThinkingQuestion(SQLModel, table=True):
    """Model storing deliberative (thinking-type) questions."""
    __tablename__ = "thinking_questions"  # type: ignore
    id: Optional[int] = Field(default=None, primary_key=True)  # auto-increment primary key
    # Question/answer pair
    question: str  # question text
    context: Optional[str] = Field(default=None, nullable=True)  # surrounding context
    found_answer: bool = Field(default=False)  # whether an answer was found
    answer: Optional[str] = Field(default=None, nullable=True)  # the answer, if found
    thinking_steps: Optional[str] = Field(default=None, nullable=True)  # reasoning steps, stored as JSON text
    created_timestamp: datetime = Field(
        default_factory=datetime.now, sa_column=Column(DateTime, index=True)
    )  # creation time
    updated_timestamp: datetime = Field(
        default_factory=datetime.now, sa_column=Column(DateTime, index=True)
    )  # last update time
class BinaryData(SQLModel, table=True):
"""存储二进制数据的模型"""

View File

@@ -5,6 +5,7 @@ from .builtin import (
EMPTY_SCHEMA_VERSION,
LATEST_SCHEMA_VERSION,
LEGACY_V1_SCHEMA_VERSION,
V2_SCHEMA_VERSION,
build_default_migration_registry,
build_default_schema_version_resolver,
)
@@ -61,6 +62,7 @@ __all__ = [
"EMPTY_SCHEMA_VERSION",
"LATEST_SCHEMA_VERSION",
"LEGACY_V1_SCHEMA_VERSION",
"V2_SCHEMA_VERSION",
"MigrationExecutionContext",
"MigrationPlan",
"MigrationPlanner",

View File

@@ -6,12 +6,14 @@ from .legacy_v1_to_v2 import migrate_legacy_v1_to_v2
from .models import DatabaseSchemaSnapshot, MigrationStep
from .registry import MigrationRegistry
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
from .version_store import SQLiteUserVersionStore
from .schema import SQLiteSchemaInspector
from .v2_to_v3 import migrate_v2_to_v3
from .version_store import SQLiteUserVersionStore
EMPTY_SCHEMA_VERSION = 0
LEGACY_V1_SCHEMA_VERSION = 1
LATEST_SCHEMA_VERSION = 2
V2_SCHEMA_VERSION = 2
LATEST_SCHEMA_VERSION = 3
_LEGACY_V1_EXCLUSIVE_TABLES = (
"chat_streams",
@@ -24,6 +26,13 @@ _LEGACY_V1_EXCLUSIVE_TABLES = (
"messages",
"thinking_back",
)
_COMMON_MARKER_TABLES = (
"mai_messages",
"chat_sessions",
"expressions",
"jargons",
"tool_records",
)
class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
@@ -36,6 +45,7 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
Returns:
str: 当前探测器名称。
"""
return "latest_schema_detector"
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
@@ -47,18 +57,16 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
Returns:
Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。
"""
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
return None
latest_marker_tables = (
"mai_messages",
"chat_sessions",
"expressions",
"jargons",
"thinking_questions",
"tool_records",
)
if not all(snapshot.has_table(table_name) for table_name in latest_marker_tables):
if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES):
return None
if snapshot.has_table("action_records"):
return None
if snapshot.has_table("thinking_questions"):
return None
if snapshot.has_column("images", "emotion"):
return None
if not snapshot.has_column("images", "image_hash"):
return None
@@ -66,13 +74,53 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
return None
if not snapshot.has_column("images", "image_type"):
return None
if not snapshot.has_column("chat_history", "session_id"):
return None
if not snapshot.has_column("person_info", "user_nickname"):
return None
return LATEST_SCHEMA_VERSION
class V2SchemaVersionDetector(BaseSchemaVersionDetector):
"""v2 schema 结构探测器。"""
@property
def name(self) -> str:
"""返回探测器名称。
Returns:
str: 当前探测器名称。
"""
return "v2_schema_detector"
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
"""检测数据库是否为 v2 结构。
Args:
snapshot: 当前数据库结构快照。
Returns:
Optional[int]: 若识别为 v2 结构则返回 ``2``,否则返回 ``None``。
"""
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
return None
if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES):
return None
if not snapshot.has_table("action_records"):
return None
if not snapshot.has_table("thinking_questions"):
return None
if not snapshot.has_column("images", "emotion"):
return None
if not snapshot.has_column("action_records", "session_id"):
return None
if not snapshot.has_column("chat_history", "session_id"):
return None
if not snapshot.has_column("person_info", "user_nickname"):
return None
return LATEST_SCHEMA_VERSION
return V2_SCHEMA_VERSION
class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
@@ -85,6 +133,7 @@ class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
Returns:
str: 当前探测器名称。
"""
return "legacy_v1_schema_detector"
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
@@ -96,6 +145,7 @@ class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
Returns:
Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。
"""
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
return LEGACY_V1_SCHEMA_VERSION
@@ -121,8 +171,10 @@ def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
Returns:
List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。
"""
return [
LatestSchemaVersionDetector(),
V2SchemaVersionDetector(),
LegacyV1SchemaDetector(),
]
@@ -133,6 +185,7 @@ def build_default_schema_version_resolver() -> SchemaVersionResolver:
Returns:
SchemaVersionResolver: 配置完成的 schema 版本解析器。
"""
return SchemaVersionResolver(
version_store=SQLiteUserVersionStore(),
schema_inspector=SQLiteSchemaInspector(),
@@ -146,14 +199,22 @@ def build_default_migration_registry() -> MigrationRegistry:
Returns:
MigrationRegistry: 含默认迁移步骤的注册表实例。
"""
return MigrationRegistry(
steps=[
MigrationStep(
version_from=LEGACY_V1_SCHEMA_VERSION,
version_to=LATEST_SCHEMA_VERSION,
name="legacy_v1_to_latest_v2",
description="将旧版 0.x 数据库整体迁移到当前最新 schema。",
version_to=V2_SCHEMA_VERSION,
name="legacy_v1_to_v2",
description="将旧版 0.x 数据库迁移到 v2 schema。",
handler=migrate_legacy_v1_to_v2,
)
),
MigrationStep(
version_from=V2_SCHEMA_VERSION,
version_to=LATEST_SCHEMA_VERSION,
name="v2_to_v3",
description="移除废弃表,并将 emoji 标签统一收敛到 description 字段。",
handler=migrate_v2_to_v3,
),
]
)

View File

@@ -0,0 +1,298 @@
"""冻结的 v2 schema 快照。
该模块只用于 ``legacy_v1_to_v2`` 迁移,避免迁移过程依赖当前运行时代码中的
最新 SQLModel 定义,导致历史迁移随着后续 schema 演进而失真。
"""
from sqlalchemy.engine import Connection
_V2_TABLE_STATEMENTS = (
"""
CREATE TABLE IF NOT EXISTS action_records (
id INTEGER NOT NULL,
action_id VARCHAR(255) NOT NULL,
timestamp DATETIME,
session_id VARCHAR(255) NOT NULL,
action_name VARCHAR(255) NOT NULL,
action_reasoning VARCHAR,
action_data VARCHAR,
action_builtin_prompt VARCHAR,
action_display_prompt VARCHAR,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS binary_data (
id INTEGER NOT NULL,
data_hash VARCHAR(255) NOT NULL,
full_path VARCHAR(1024) NOT NULL,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS chat_history (
id INTEGER NOT NULL,
session_id VARCHAR(255) NOT NULL,
start_timestamp DATETIME,
end_timestamp DATETIME,
query_count INTEGER NOT NULL,
query_forget_count INTEGER NOT NULL,
original_messages VARCHAR NOT NULL,
participants VARCHAR NOT NULL,
theme VARCHAR NOT NULL,
keywords VARCHAR NOT NULL,
summary VARCHAR NOT NULL,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS chat_sessions (
id INTEGER NOT NULL,
session_id VARCHAR(255) NOT NULL,
created_timestamp DATETIME,
last_active_timestamp DATETIME,
user_id VARCHAR(255),
group_id VARCHAR(255),
platform VARCHAR(100) NOT NULL,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS command_records (
id INTEGER NOT NULL,
timestamp DATETIME,
session_id VARCHAR(255) NOT NULL,
command_name VARCHAR(255) NOT NULL,
command_data VARCHAR,
command_result VARCHAR,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS expressions (
id INTEGER NOT NULL,
situation VARCHAR(255) NOT NULL,
style VARCHAR(255) NOT NULL,
content_list VARCHAR NOT NULL,
count INTEGER NOT NULL,
last_active_time DATETIME,
create_time DATETIME,
session_id VARCHAR(255),
checked BOOLEAN NOT NULL,
rejected BOOLEAN NOT NULL,
modified_by VARCHAR(4),
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS images (
id INTEGER NOT NULL,
image_hash VARCHAR(255) NOT NULL,
description VARCHAR NOT NULL,
full_path VARCHAR(1024) NOT NULL,
image_type VARCHAR(5),
emotion VARCHAR,
query_count INTEGER NOT NULL,
is_registered BOOLEAN NOT NULL,
is_banned BOOLEAN NOT NULL,
no_file_flag BOOLEAN NOT NULL,
record_time DATETIME,
register_time DATETIME,
last_used_time DATETIME,
vlm_processed BOOLEAN NOT NULL,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS jargons (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
content VARCHAR(255) NOT NULL,
raw_content TEXT,
meaning TEXT NOT NULL,
session_id_dict TEXT NOT NULL,
count INTEGER NOT NULL,
is_jargon BOOLEAN,
is_complete BOOLEAN NOT NULL,
is_global BOOLEAN NOT NULL,
last_inference_count INTEGER NOT NULL,
inference_with_context TEXT,
inference_with_content_only TEXT
)
""",
"""
CREATE TABLE IF NOT EXISTS llm_usage (
id INTEGER NOT NULL,
model_name VARCHAR(255) NOT NULL,
model_assign_name VARCHAR(255),
model_api_provider_name VARCHAR(255) NOT NULL,
endpoint VARCHAR(255),
user_type VARCHAR(6),
request_type VARCHAR(50) NOT NULL,
time_cost FLOAT,
timestamp DATETIME,
prompt_tokens INTEGER NOT NULL,
completion_tokens INTEGER NOT NULL,
total_tokens INTEGER NOT NULL,
cost FLOAT NOT NULL,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS mai_knowledge (
id INTEGER NOT NULL,
knowledge_id VARCHAR(255) NOT NULL,
category_id VARCHAR(32) NOT NULL,
content VARCHAR NOT NULL,
normalized_content VARCHAR NOT NULL,
metadata_json VARCHAR,
created_at DATETIME,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS mai_messages (
id INTEGER NOT NULL,
message_id VARCHAR(255) NOT NULL,
timestamp DATETIME,
platform VARCHAR(100) NOT NULL,
user_id VARCHAR(255) NOT NULL,
user_nickname VARCHAR(255) NOT NULL,
user_cardname VARCHAR(255),
group_id VARCHAR(255),
group_name VARCHAR(255),
is_mentioned BOOLEAN NOT NULL,
is_at BOOLEAN NOT NULL,
session_id VARCHAR(255) NOT NULL,
reply_to VARCHAR(255),
is_emoji BOOLEAN NOT NULL,
is_picture BOOLEAN NOT NULL,
is_command BOOLEAN NOT NULL,
is_notify BOOLEAN NOT NULL,
raw_content BLOB,
processed_plain_text VARCHAR,
display_message VARCHAR,
additional_config VARCHAR,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS online_time (
id INTEGER NOT NULL,
timestamp DATETIME,
duration_minutes INTEGER NOT NULL,
start_timestamp DATETIME,
end_timestamp DATETIME,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS person_info (
id INTEGER NOT NULL,
is_known BOOLEAN NOT NULL,
person_id VARCHAR(255) NOT NULL,
person_name VARCHAR(255),
name_reason VARCHAR,
platform VARCHAR(100) NOT NULL,
user_id VARCHAR(255) NOT NULL,
user_nickname VARCHAR(255) NOT NULL,
group_cardname VARCHAR,
memory_points VARCHAR,
know_counts INTEGER NOT NULL,
first_known_time DATETIME,
last_known_time DATETIME,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS thinking_questions (
id INTEGER NOT NULL,
question VARCHAR NOT NULL,
context VARCHAR,
found_answer BOOLEAN NOT NULL,
answer VARCHAR,
thinking_steps VARCHAR,
created_timestamp DATETIME,
updated_timestamp DATETIME,
PRIMARY KEY (id)
)
""",
"""
CREATE TABLE IF NOT EXISTS tool_records (
id INTEGER NOT NULL,
tool_id VARCHAR(255) NOT NULL,
timestamp DATETIME,
session_id VARCHAR(255) NOT NULL,
tool_name VARCHAR(255) NOT NULL,
tool_reasoning VARCHAR,
tool_data VARCHAR,
tool_builtin_prompt VARCHAR,
tool_display_prompt VARCHAR,
PRIMARY KEY (id)
)
""",
)
# Frozen index DDL for the v2 schema; executed after _V2_TABLE_STATEMENTS so
# every referenced table already exists. Note the two UNIQUE indexes on
# chat_sessions.session_id and person_info.person_id.
_V2_INDEX_STATEMENTS = (
    "CREATE INDEX IF NOT EXISTS ix_action_records_action_id ON action_records (action_id)",
    "CREATE INDEX IF NOT EXISTS ix_action_records_action_name ON action_records (action_name)",
    "CREATE INDEX IF NOT EXISTS ix_action_records_session_id ON action_records (session_id)",
    "CREATE INDEX IF NOT EXISTS ix_action_records_timestamp ON action_records (timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_binary_data_data_hash ON binary_data (data_hash)",
    "CREATE INDEX IF NOT EXISTS ix_chat_history_end_timestamp ON chat_history (end_timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_chat_history_session_id ON chat_history (session_id)",
    "CREATE INDEX IF NOT EXISTS ix_chat_history_start_timestamp ON chat_history (start_timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_chat_sessions_created_timestamp ON chat_sessions (created_timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_chat_sessions_group_id ON chat_sessions (group_id)",
    "CREATE INDEX IF NOT EXISTS ix_chat_sessions_last_active_timestamp ON chat_sessions (last_active_timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_chat_sessions_platform ON chat_sessions (platform)",
    "CREATE UNIQUE INDEX IF NOT EXISTS ix_chat_sessions_session_id ON chat_sessions (session_id)",
    "CREATE INDEX IF NOT EXISTS ix_chat_sessions_user_id ON chat_sessions (user_id)",
    "CREATE INDEX IF NOT EXISTS ix_command_records_command_name ON command_records (command_name)",
    "CREATE INDEX IF NOT EXISTS ix_command_records_session_id ON command_records (session_id)",
    "CREATE INDEX IF NOT EXISTS ix_command_records_timestamp ON command_records (timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_expressions_last_active_time ON expressions (last_active_time)",
    "CREATE INDEX IF NOT EXISTS ix_expressions_situation ON expressions (situation)",
    "CREATE INDEX IF NOT EXISTS ix_expressions_style ON expressions (style)",
    "CREATE INDEX IF NOT EXISTS ix_images_image_hash ON images (image_hash)",
    "CREATE INDEX IF NOT EXISTS ix_images_record_time ON images (record_time)",
    "CREATE INDEX IF NOT EXISTS ix_jargons_content ON jargons (content)",
    "CREATE INDEX IF NOT EXISTS ix_llm_usage_model_api_provider_name ON llm_usage (model_api_provider_name)",
    "CREATE INDEX IF NOT EXISTS ix_llm_usage_model_assign_name ON llm_usage (model_assign_name)",
    "CREATE INDEX IF NOT EXISTS ix_llm_usage_model_name ON llm_usage (model_name)",
    "CREATE INDEX IF NOT EXISTS ix_llm_usage_timestamp ON llm_usage (timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_mai_knowledge_category_id ON mai_knowledge (category_id)",
    "CREATE INDEX IF NOT EXISTS ix_mai_knowledge_created_at ON mai_knowledge (created_at)",
    "CREATE INDEX IF NOT EXISTS ix_mai_knowledge_knowledge_id ON mai_knowledge (knowledge_id)",
    "CREATE INDEX IF NOT EXISTS ix_mai_knowledge_normalized_content ON mai_knowledge (normalized_content)",
    "CREATE INDEX IF NOT EXISTS ix_mai_messages_group_id ON mai_messages (group_id)",
    "CREATE INDEX IF NOT EXISTS ix_mai_messages_message_id ON mai_messages (message_id)",
    "CREATE INDEX IF NOT EXISTS ix_mai_messages_platform ON mai_messages (platform)",
    "CREATE INDEX IF NOT EXISTS ix_mai_messages_session_id ON mai_messages (session_id)",
    "CREATE INDEX IF NOT EXISTS ix_mai_messages_user_id ON mai_messages (user_id)",
    "CREATE INDEX IF NOT EXISTS ix_mai_messages_user_nickname ON mai_messages (user_nickname)",
    "CREATE INDEX IF NOT EXISTS ix_online_time_timestamp ON online_time (timestamp)",
    "CREATE UNIQUE INDEX IF NOT EXISTS ix_person_info_person_id ON person_info (person_id)",
    "CREATE INDEX IF NOT EXISTS ix_person_info_platform ON person_info (platform)",
    "CREATE INDEX IF NOT EXISTS ix_person_info_user_id ON person_info (user_id)",
    "CREATE INDEX IF NOT EXISTS ix_person_info_user_nickname ON person_info (user_nickname)",
    "CREATE INDEX IF NOT EXISTS ix_thinking_questions_created_timestamp ON thinking_questions (created_timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_thinking_questions_updated_timestamp ON thinking_questions (updated_timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_tool_records_session_id ON tool_records (session_id)",
    "CREATE INDEX IF NOT EXISTS ix_tool_records_timestamp ON tool_records (timestamp)",
    "CREATE INDEX IF NOT EXISTS ix_tool_records_tool_id ON tool_records (tool_id)",
    "CREATE INDEX IF NOT EXISTS ix_tool_records_tool_name ON tool_records (tool_name)",
)
def create_frozen_v2_schema(connection: Connection) -> None:
    """Create the frozen v2 schema on the given connection.

    All table DDL runs first, then all index DDL, so every index statement
    targets a table that already exists. Both statement groups use
    ``IF NOT EXISTS`` and are therefore safe to re-run.

    Args:
        connection: Active database connection.
    """
    for ddl_statement in (*_V2_TABLE_STATEMENTS, *_V2_INDEX_STATEMENTS):
        connection.exec_driver_sql(ddl_statement)

View File

@@ -1,4 +1,4 @@
"""旧版 ``0.x`` 数据库升级到最新 schema 的迁移逻辑。"""
"""旧版 ``0.x`` 数据库升级到 v2 schema 的迁移逻辑。"""
from __future__ import annotations
@@ -7,15 +7,16 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
import json
import msgpack
from sqlalchemy import text
from sqlalchemy.engine import Connection
import json
import msgpack
from src.common.logger import get_logger
from .exceptions import DatabaseMigrationExecutionError
from .frozen_v2_schema import create_frozen_v2_schema
from .models import DatabaseSchemaSnapshot, MigrationExecutionContext
from .schema import SQLiteSchemaInspector
@@ -52,19 +53,15 @@ class LegacyTableData:
def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None:
"""执行旧版 ``0.x`` 数据库到最新 schema 的迁移。
"""执行旧版 ``0.x`` 数据库到 v2 schema 的迁移。
Args:
context: 当前迁移步骤执行上下文。
"""
from sqlmodel import SQLModel
import src.common.database.database_model # noqa: F401
schema_inspector = SQLiteSchemaInspector()
snapshot = schema_inspector.inspect(context.connection)
_rename_legacy_v1_tables(context.connection, snapshot)
SQLModel.metadata.create_all(context.connection)
create_frozen_v2_schema(context.connection)
table_migration_jobs: List[Tuple[str, Callable[[MigrationExecutionContext], int]]] = [
("chat_sessions", _migrate_chat_sessions),
@@ -794,8 +791,6 @@ def _migrate_images(context: MigrationExecutionContext) -> int:
if full_path and dedupe_key not in existing_keys:
migrated_description = _normalize_required_text(row.get("description"))
migrated_emotion = _normalize_optional_text(row.get("emotion"))
if not migrated_description and migrated_emotion:
migrated_description = migrated_emotion
connection.execute(
insert_sql,
{
@@ -803,7 +798,7 @@ def _migrate_images(context: MigrationExecutionContext) -> int:
"description": migrated_description,
"full_path": full_path,
"image_type": "EMOJI",
"emotion": None,
"emotion": migrated_emotion,
"query_count": _normalize_int(row.get("query_count"), default=0),
"is_registered": _normalize_bool(row.get("is_registered"), default=False),
"is_banned": _normalize_bool(row.get("is_banned"), default=False),

View File

@@ -0,0 +1,269 @@
"""v2 schema 升级到 v3 的迁移逻辑。"""
from typing import Any, Dict, List
from sqlalchemy import text
from sqlalchemy.engine import Connection
from src.common.logger import get_logger
from .exceptions import DatabaseMigrationExecutionError
from .models import MigrationExecutionContext
from .schema import SQLiteSchemaInspector
logger = get_logger("database_migration")
_V2_IMAGES_BACKUP_TABLE = "__v2_images_backup"
_V3_IMAGES_CREATE_SQL = """
CREATE TABLE images (
id INTEGER NOT NULL,
image_hash VARCHAR(255) NOT NULL,
description VARCHAR NOT NULL,
full_path VARCHAR(1024) NOT NULL,
image_type VARCHAR(5),
query_count INTEGER NOT NULL,
is_registered BOOLEAN NOT NULL,
is_banned BOOLEAN NOT NULL,
no_file_flag BOOLEAN NOT NULL,
record_time DATETIME,
register_time DATETIME,
last_used_time DATETIME,
vlm_processed BOOLEAN NOT NULL,
PRIMARY KEY (id)
)
"""
_V3_IMAGES_INDEX_STATEMENTS = (
"CREATE INDEX ix_images_image_hash ON images (image_hash)",
"CREATE INDEX ix_images_record_time ON images (record_time)",
)
def migrate_v2_to_v3(context: MigrationExecutionContext) -> None:
    """Execute the v2 -> v3 schema migration.

    Steps, in order:
      1. Copy residual ``action_records`` rows into ``tool_records``, then
         drop ``action_records``.
      2. Drop the deprecated ``thinking_questions`` table.
      3. Rebuild the ``images`` table without the legacy ``emotion`` column.

    Args:
        context: Execution context of the current migration step.
    """
    connection = context.connection
    # Count everything up front so the progress bar has a stable denominator
    # before any table is mutated or dropped.
    total_records = (
        _count_table_rows(connection, "action_records")
        + _count_table_rows(connection, "thinking_questions")
        + _count_table_rows(connection, "images")
    )
    context.start_progress(
        total_tables=3,
        total_records=total_records,
        description="v2 -> v3 迁移进度",
        table_unit_name="",
        record_unit_name="记录",
    )
    # Step 1: salvage leftover action_records into tool_records first, then
    # re-count (the table still exists here) and drop it.
    migrated_tool_records = _migrate_action_records_to_tool_records(connection)
    action_record_count = _count_table_rows(connection, "action_records")
    _drop_table_if_exists(connection, "action_records")
    context.advance_progress(
        records=action_record_count,
        completed_tables=1,
        item_name="action_records",
    )
    # Step 2: thinking_questions is dropped outright; no data is carried over.
    thinking_question_count = _count_table_rows(connection, "thinking_questions")
    _drop_table_if_exists(connection, "thinking_questions")
    context.advance_progress(
        records=thinking_question_count,
        completed_tables=1,
        item_name="thinking_questions",
    )
    # Step 3: rebuild images without the emotion column.
    migrated_image_rows = _migrate_images_table_to_v3(connection)
    context.advance_progress(
        records=migrated_image_rows,
        completed_tables=1,
        item_name="images",
    )
    logger.info(
        "v2 -> v3 数据库迁移完成: "
        f"tool_records补迁移={migrated_tool_records}"
        f"images重建={migrated_image_rows}"
    )
def _count_table_rows(connection: Connection, table_name: str) -> int:
    """Return the row count of *table_name*, or 0 when the table is absent.

    Args:
        connection: Active database connection.
        table_name: Name of the table to count; quoted in the generated SQL.

    Returns:
        int: Number of rows, or 0 for a missing/empty table.
    """
    if not SQLiteSchemaInspector().table_exists(connection, table_name):
        return 0
    count_sql = text(f'SELECT COUNT(*) FROM "{table_name}"')
    first_row = connection.execute(count_sql).first()
    if first_row is None:
        return 0
    return int(first_row[0])
def _drop_table_if_exists(connection: Connection, table_name: str) -> None:
    """Drop *table_name*; a table that does not exist is silently ignored."""
    drop_sql = f'DROP TABLE IF EXISTS "{table_name}"'
    connection.exec_driver_sql(drop_sql)
def _migrate_action_records_to_tool_records(connection: Connection) -> int:
    """Copy residual v2 ``action_records`` rows into ``tool_records``.

    Rows whose ``action_id`` already exists in ``tool_records`` as ``tool_id``
    are skipped (``WHERE NOT EXISTS``), keeping the step idempotent if the
    migration is re-run.

    Args:
        connection: Active database connection.

    Returns:
        int: Number of rows actually inserted into ``tool_records``.
    """
    schema_inspector = SQLiteSchemaInspector()
    if not schema_inspector.table_exists(connection, "action_records"):
        return 0
    result = connection.execute(
        text(
            """
            INSERT INTO tool_records (
                tool_id,
                timestamp,
                session_id,
                tool_name,
                tool_reasoning,
                tool_data,
                tool_builtin_prompt,
                tool_display_prompt
            )
            SELECT
                action_id,
                timestamp,
                session_id,
                action_name,
                action_reasoning,
                action_data,
                action_builtin_prompt,
                action_display_prompt
            FROM action_records
            WHERE NOT EXISTS (
                SELECT 1
                FROM tool_records
                WHERE tool_records.tool_id = action_records.action_id
            )
        """
        )
    )
    # Bug fix: the previous implementation reported the full action_records
    # count, which overstated the migrated total whenever NOT EXISTS filtered
    # out duplicates. rowcount reflects the rows this INSERT actually added;
    # some DBAPIs report -1, which is clamped to 0.
    return max(result.rowcount or 0, 0)
def _migrate_images_table_to_v3(connection: Connection) -> int:
    """Rebuild the ``images`` table, removing the legacy ``emotion`` column.

    The table is renamed to a backup name, recreated with the v3 layout, rows
    are copied across one by one (folding emoji emotion tags into
    ``description``), then the backup table is dropped and the v3 indexes are
    recreated. The statement order is significant and must not change.

    Args:
        connection: Active database connection.

    Returns:
        int: Number of image rows written into the rebuilt table.

    Raises:
        DatabaseMigrationExecutionError: If a stale backup table left by a
            previously interrupted run is detected.
    """
    schema_inspector = SQLiteSchemaInspector()
    if not schema_inspector.table_exists(connection, "images"):
        return 0
    # Already migrated (no emotion column): nothing to rebuild, just report
    # the current row count for progress accounting.
    if not schema_inspector.get_table_schema(connection, "images").has_column("emotion"):
        return _count_table_rows(connection, "images")
    # Refuse to run over a leftover backup rather than silently clobber it.
    if schema_inspector.table_exists(connection, _V2_IMAGES_BACKUP_TABLE):
        raise DatabaseMigrationExecutionError(
            f"检测到残留备份表 {_V2_IMAGES_BACKUP_TABLE},无法安全执行 v2 -> v3 images 迁移。"
        )
    connection.exec_driver_sql(f'ALTER TABLE "images" RENAME TO "{_V2_IMAGES_BACKUP_TABLE}"')
    connection.exec_driver_sql(_V3_IMAGES_CREATE_SQL)
    # ORDER BY id keeps the copy deterministic and preserves primary keys.
    legacy_rows = connection.execute(
        text(f'SELECT * FROM "{_V2_IMAGES_BACKUP_TABLE}" ORDER BY id')
    ).mappings().all()
    insert_sql = text(
        """
        INSERT INTO images (
            id,
            image_hash,
            description,
            full_path,
            image_type,
            query_count,
            is_registered,
            is_banned,
            no_file_flag,
            record_time,
            register_time,
            last_used_time,
            vlm_processed
        ) VALUES (
            :id,
            :image_hash,
            :description,
            :full_path,
            :image_type,
            :query_count,
            :is_registered,
            :is_banned,
            :no_file_flag,
            :record_time,
            :register_time,
            :last_used_time,
            :vlm_processed
        )
        """
    )
    for row in legacy_rows:
        payload: Dict[str, Any] = {
            "id": row.get("id"),
            "image_hash": str(row.get("image_hash") or "").strip(),
            # Emoji rows fold the legacy emotion tags into description here.
            "description": _migrate_v3_emoji_description(row),
            "full_path": str(row.get("full_path") or "").strip(),
            "image_type": row.get("image_type"),
            "query_count": int(row.get("query_count") or 0),
            "is_registered": bool(row.get("is_registered")),
            "is_banned": bool(row.get("is_banned")),
            "no_file_flag": bool(row.get("no_file_flag")),
            "record_time": row.get("record_time"),
            "register_time": row.get("register_time"),
            "last_used_time": row.get("last_used_time"),
            "vlm_processed": bool(row.get("vlm_processed")),
        }
        connection.execute(insert_sql, payload)
    connection.exec_driver_sql(f'DROP TABLE "{_V2_IMAGES_BACKUP_TABLE}"')
    # Indexes are recreated only after the data copy has completed.
    for statement in _V3_IMAGES_INDEX_STATEMENTS:
        connection.exec_driver_sql(statement)
    return len(legacy_rows)
def _migrate_v3_emoji_description(row: Dict[str, Any]) -> str:
"""为 v3 统一 emoji 描述字段语义。
v3 中 `description` 对 emoji 统一承担“标签列表”的职责,因此迁移时:
1. 若旧 `emotion` 非空,优先将其规范化后写入 `description`
2. 否则保留并规范化当前 `description`
3. 非 emoji 图片保持原描述不变。
"""
image_type = str(row.get("image_type") or "").strip().upper()
current_description = str(row.get("description") or "").strip()
current_emotion = str(row.get("emotion") or "").strip()
if image_type != "EMOJI":
return current_description
normalized_tags = _normalize_emoji_tag_text(current_emotion or current_description)
if normalized_tags:
return ",".join(normalized_tags)
return current_description
def _normalize_emoji_tag_text(raw_value: Any) -> List[str]:
"""将 emoji 标签文本转换为去重后的标签列表。"""
normalized_text = str(raw_value or "").strip()
if not normalized_text:
return []
separators = [",", "", "", ";", "", "\n", "\r", "\t"]
for separator in separators[1:]:
normalized_text = normalized_text.replace(separator, separators[0])
deduped_tags: List[str] = []
seen_tags: set[str] = set()
for part in normalized_text.split(separators[0]):
normalized_part = part.strip()
lowered_part = normalized_part.lower()
if not normalized_part or lowered_part in seen_tags:
continue
seen_tags.add(lowered_part)
deduped_tags.append(normalized_part)
return deduped_tags

View File

@@ -1,16 +1,14 @@
import contextlib
import time
import json
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple, Callable
import contextlib
import json
import time
from typing import Any, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.services import llm_service as llm_api
from sqlmodel import select, col
from src.common.database.database import get_db_session
from src.common.database.database_model import ThinkingQuestion
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
@@ -18,35 +16,6 @@ from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
logger = get_logger("memory_retrieval")
THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 36000 # 未找到答案记录保留时长
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 3000 # 清理频率
_last_not_found_cleanup_ts: float = 0.0
def _cleanup_stale_not_found_thinking_back() -> None:
"""定期清理过期的未找到答案记录"""
global _last_not_found_cleanup_ts
now = time.time()
if now - _last_not_found_cleanup_ts < THINKING_BACK_CLEANUP_INTERVAL_SECONDS:
return
threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS
try:
with get_db_session() as session:
statement = select(ThinkingQuestion).where(
col(ThinkingQuestion.found_answer).is_(False)
& (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time))
)
records = session.exec(statement).all()
for record in records:
session.delete(record)
if records:
logger.info(f"清理过期的未找到答案thinking_question记录 {len(records)}")
_last_not_found_cleanup_ts = now
except Exception as e:
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
def init_memory_retrieval_sys():
"""初始化记忆检索相关工具"""
@@ -766,141 +735,6 @@ async def _react_agent_solve_question(
return False, "", thinking_steps, is_timeout
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) -> str:
"""获取最近一段时间内的查询历史(用于避免重复查询)
Args:
chat_id: 聊天ID
time_window_seconds: 时间窗口默认10分钟
Returns:
str: 格式化的查询历史字符串
"""
try:
_current_time = time.time()
with get_db_session() as session:
statement = (
select(ThinkingQuestion)
.where(col(ThinkingQuestion.context) == chat_id)
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
.limit(5)
)
records = session.exec(statement).all()
if not records:
return ""
history_lines = ["最近已查询的问题和结果:"]
for record in records:
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
answer_preview = ""
# 只有找到答案时才显示答案内容
if record.found_answer and record.answer:
# 截取答案前100字符
answer_preview = record.answer[:100]
if len(record.answer) > 100:
answer_preview += "..."
history_lines.extend([f"- 问题:{record.question}", f" 状态:{status}"])
if answer_preview:
history_lines.append(f" 答案:{answer_preview}")
history_lines.append("") # 空行分隔
return "\n".join(history_lines)
except Exception as e:
logger.error(f"获取查询历史失败: {e}")
return ""
def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) -> List[str]:
"""获取最近一段时间内已找到答案的查询记录(用于返回给 replyer
Args:
chat_id: 聊天ID
time_window_seconds: 时间窗口默认10分钟
Returns:
List[str]: 格式化的答案列表,每个元素格式为 "问题xxx\n答案xxx"
"""
try:
_current_time = time.time()
# 查询最近时间窗口内已找到答案的记录,按更新时间倒序
with get_db_session() as session:
statement = (
select(ThinkingQuestion)
.where(col(ThinkingQuestion.context) == chat_id)
.where(col(ThinkingQuestion.found_answer))
.where(col(ThinkingQuestion.answer).is_not(None))
.where(col(ThinkingQuestion.answer) != "")
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
.limit(3)
)
records = session.exec(statement).all()
if not records:
return []
return [f"问题:{record.question}\n答案:{record.answer}" for record in records if record.answer]
except Exception as e:
logger.error(f"获取最近已找到答案的记录失败: {e}")
return []
def _store_thinking_back(
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
) -> None:
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
Args:
chat_id: 聊天ID
question: 问题
context: 上下文信息
found_answer: 是否找到答案
answer: 答案内容
thinking_steps: 思考步骤列表
"""
try:
now = time.time()
# 先查询是否已存在相同chat_id和问题的记录
with get_db_session() as session:
statement = (
select(ThinkingQuestion)
.where(col(ThinkingQuestion.context) == chat_id)
.where(col(ThinkingQuestion.question) == question)
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
.limit(1)
)
if record := session.exec(statement).first():
record.context = context
record.found_answer = found_answer
record.answer = answer
record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False)
record.updated_timestamp = datetime.fromtimestamp(now)
session.add(record)
logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...")
return
new_record = ThinkingQuestion(
question=question,
context=chat_id,
found_answer=found_answer,
answer=answer,
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
created_timestamp=datetime.fromtimestamp(now),
updated_timestamp=datetime.fromtimestamp(now),
)
session.add(new_record)
except Exception as e:
logger.error(f"存储思考过程失败: {e}")
async def _process_memory_retrieval(
chat_id: str,
context: str,
@@ -920,8 +754,6 @@ async def _process_memory_retrieval(
Returns:
Optional[str]: 如果找到答案返回答案内容否则返回None
"""
_cleanup_stale_not_found_thinking_back()
question_initial_info = initial_info or ""
# 直接使用ReAct Agent进行记忆检索

View File

@@ -585,10 +585,8 @@ class RuntimeDataCapabilityMixin:
if not ImageUtils.base64_to_image(emoji_base64, str(temp_file_path)):
return {"success": False, "message": "无法保存图片文件", "description": None, "emotions": None, "replaced": None, "hash": None}
register_success = await emoji_manager.register_emoji_by_filename(temp_file_path)
if not register_success:
if temp_file_path.exists():
temp_file_path.unlink(missing_ok=True)
register_status = await emoji_manager.register_emoji_by_filename(temp_file_path)
if register_status == "failed":
return {
"success": False,
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
@@ -597,6 +595,15 @@ class RuntimeDataCapabilityMixin:
"replaced": None,
"hash": None,
}
if register_status == "skipped":
return {
"success": True,
"message": "表情包已注册,已跳过本次注册",
"description": None,
"emotions": None,
"replaced": False,
"hash": None,
}
count_after = len(emoji_manager.emojis)
replaced = count_after <= count_before

View File

@@ -40,7 +40,7 @@ from .schemas import (
emoji_to_response,
)
from .support import (
EMOJI_REGISTERED_DIR,
EMOJI_DIR,
THUMBNAIL_CACHE_DIR,
background_generate_thumbnail,
cleanup_orphaned_thumbnails,
@@ -326,7 +326,7 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
if emoji.is_registered:
raise HTTPException(status_code=400, detail="该表情包已经注册")
return EmojiUpdateResponse(success=True, message="??????????", data=emoji_to_response(emoji))
emoji.is_registered = True
emoji.is_banned = False
@@ -549,16 +549,16 @@ async def upload_emoji(
if existing_emoji := session.exec(existing_statement).first():
raise HTTPException(status_code=409, detail=f"已存在相同的表情包 (ID: {existing_emoji.id})")
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
os.makedirs(EMOJI_DIR, exist_ok=True)
timestamp = int(datetime.now().timestamp())
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
full_path = os.path.join(EMOJI_DIR, filename)
counter = 1
while os.path.exists(full_path):
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
full_path = os.path.join(EMOJI_DIR, filename)
counter += 1
with open(full_path, "wb") as output_file:
@@ -618,7 +618,7 @@ async def batch_upload_emoji(
}
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
os.makedirs(EMOJI_DIR, exist_ok=True)
for file in files:
try:
@@ -665,12 +665,12 @@ async def batch_upload_emoji(
timestamp = int(datetime.now().timestamp())
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
full_path = os.path.join(EMOJI_DIR, filename)
counter = 1
while os.path.exists(full_path):
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
full_path = os.path.join(EMOJI_DIR, filename)
counter += 1
with open(full_path, "wb") as output_file:

View File

@@ -16,7 +16,8 @@ logger = get_logger("webui.emoji")
THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails")
THUMBNAIL_SIZE = (200, 200)
THUMBNAIL_QUALITY = 80
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji")
EMOJI_DIR = EMOJI_REGISTERED_DIR
_thumbnail_locks: Dict[str, threading.Lock] = {}
_locks_lock = threading.Lock()