feat:优化表情包注册,迁移数据库v3
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
@@ -31,12 +31,12 @@ install(extra_lines=3)
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||
DATA_DIR = PROJECT_ROOT / "data"
|
||||
EmojiRegisterStatus = Literal["registered", "skipped", "failed"]
|
||||
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
|
||||
EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注册目录
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
|
||||
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
def register_emoji_hook_specs(registry: HookSpecRegistry) -> list[HookSpec]:
|
||||
"""注册表情包系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
@@ -145,7 +145,7 @@ def _get_runtime_manager() -> Any:
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
|
||||
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[dict[str, Any]]:
|
||||
"""将表情包对象序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
@@ -163,27 +163,27 @@ def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, A
|
||||
"file_name": emoji.file_name,
|
||||
"full_path": str(emoji.full_path),
|
||||
"description": emoji.description,
|
||||
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description or emoji.emotion)],
|
||||
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description)],
|
||||
"query_count": int(emoji.query_count),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
|
||||
def _normalize_emoji_tag_text(raw_values: Any) -> list[str]:
|
||||
"""将文本或标签列表转为去重的情绪标签列表。"""
|
||||
if isinstance(raw_values, str):
|
||||
if not raw_values:
|
||||
return []
|
||||
parts = re.split(r"[,,、;;\s]+", raw_values.strip())
|
||||
parts = re.split(r"[,,、;;\r\n\t]+", raw_values.strip())
|
||||
normalized_tags = [str(part).strip() for part in parts if str(part).strip()]
|
||||
elif isinstance(raw_values, list):
|
||||
normalized_tags: List[str] = []
|
||||
normalized_tags: list[str] = []
|
||||
for value in raw_values:
|
||||
normalized_tags.extend(_normalize_emoji_tag_text(value))
|
||||
else:
|
||||
return []
|
||||
|
||||
deduped_tags: List[str] = []
|
||||
seen: Set[str] = set()
|
||||
deduped_tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for tag in normalized_tags:
|
||||
normalized_tag = tag.strip()
|
||||
if not normalized_tag:
|
||||
@@ -196,15 +196,23 @@ def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
|
||||
return deduped_tags
|
||||
|
||||
|
||||
def _get_emoji_emotions(emoji: MaiEmoji) -> List[str]:
|
||||
"""获取兼容旧数据的表情包情绪标签。"""
|
||||
return _normalize_emoji_tag_text(emoji.description or emoji.emotion)
|
||||
def _get_emoji_emotions(emoji: MaiEmoji) -> list[str]:
|
||||
"""获取表情包情绪标签。"""
|
||||
return _normalize_emoji_tag_text(emoji.description)
|
||||
|
||||
|
||||
def _ensure_directories() -> None:
|
||||
"""确保表情包相关目录存在"""
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _is_available_emoji_record(record: Images) -> bool:
|
||||
"""判断数据库记录对应的表情包文件当前是否可用。"""
|
||||
if record.no_file_flag:
|
||||
return False
|
||||
|
||||
record_path = Path(record.full_path)
|
||||
return record_path.exists() and record_path.is_file()
|
||||
|
||||
|
||||
# TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法
|
||||
@@ -224,9 +232,9 @@ class EmojiManager:
|
||||
_ensure_directories()
|
||||
|
||||
self._emoji_num: int = 0
|
||||
self.emojis: List[MaiEmoji] = []
|
||||
self.emojis: list[MaiEmoji] = []
|
||||
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
|
||||
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
|
||||
self._pending_description_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._reload_callback_registered: bool = False
|
||||
|
||||
config_manager.register_reload_callback(self.reload_runtime_config)
|
||||
@@ -254,7 +262,7 @@ class EmojiManager:
|
||||
emoji_bytes: Optional[bytes] = None,
|
||||
emoji_hash: Optional[str] = None,
|
||||
wait_for_build: bool = True,
|
||||
) -> Optional[Tuple[str, List[str]]]:
|
||||
) -> Optional[tuple[str, list[str]]]:
|
||||
"""
|
||||
根据表情包哈希获取表情包描述和情感列表的封装方法
|
||||
|
||||
@@ -275,12 +283,28 @@ class EmojiManager:
|
||||
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
|
||||
|
||||
if emoji := self.get_emoji_by_hash(emoji_hash):
|
||||
emoji_path = Path(emoji.full_path) if emoji.full_path else None
|
||||
if emoji_bytes and (emoji_path is None or not emoji_path.exists()):
|
||||
try:
|
||||
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||
emoji.full_path = restored_emoji.full_path
|
||||
emoji.file_name = restored_emoji.file_name
|
||||
except Exception as e:
|
||||
logger.warning(f"表情包缓存命中但本地文件缺失,回填失败: {e}")
|
||||
return emoji.description, _normalize_emoji_tag_text(emoji.description or "")
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if result := session.exec(statement).first():
|
||||
cached_description = result.description or result.emotion or ""
|
||||
record_path = Path(result.full_path) if result.full_path else None
|
||||
if emoji_bytes and (result.no_file_flag or record_path is None or not record_path.exists()):
|
||||
try:
|
||||
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||
result.full_path = str(restored_emoji.full_path)
|
||||
result.no_file_flag = False
|
||||
except Exception as e:
|
||||
logger.warning(f"数据库命中表情包记录但本地文件缺失,回填失败: {e}")
|
||||
cached_description = result.description or ""
|
||||
cached_emotions = _normalize_emoji_tag_text(cached_description)
|
||||
return (
|
||||
cached_description,
|
||||
@@ -400,7 +424,7 @@ class EmojiManager:
|
||||
self,
|
||||
emoji_hash: str,
|
||||
emoji_bytes: bytes,
|
||||
) -> Optional[Tuple[str, List[str]]]:
|
||||
) -> Optional[tuple[str, list[str]]]:
|
||||
"""构建并缓存表情包描述(返回标签化结果,不再走额外识别流程)。"""
|
||||
logger.info(f"Start building cached emoji description, hash={emoji_hash}")
|
||||
new_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||
@@ -461,87 +485,70 @@ class EmojiManager:
|
||||
self._emoji_num = 0
|
||||
raise e
|
||||
|
||||
def register_emoji_to_db(self, emoji: MaiEmoji) -> bool:
|
||||
def register_emoji_to_db(self, emoji: MaiEmoji) -> EmojiRegisterStatus:
|
||||
# sourcery skip: extract-method
|
||||
"""
|
||||
将表情包注册到数据库中
|
||||
|
||||
Args:
|
||||
emoji (MaiEmoji): 需要注册的表情包对象
|
||||
Returns:
|
||||
return (bool): 注册是否成功
|
||||
"""
|
||||
"""Register an emoji in the database without moving its source file."""
|
||||
if not emoji or not isinstance(emoji, MaiEmoji):
|
||||
logger.error("[注册表情包] 无效的表情包对象")
|
||||
return False
|
||||
logger.error("[register_emoji] Invalid emoji object")
|
||||
return "failed"
|
||||
if not emoji.full_path.exists():
|
||||
logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}")
|
||||
return False
|
||||
logger.error(f"[register_emoji] Emoji file does not exist: {emoji.full_path}")
|
||||
return "failed"
|
||||
|
||||
target_path = EMOJI_REGISTERED_DIR / emoji.file_name
|
||||
|
||||
# 先查库,避免重复记录导致文件被误移动后无法回收
|
||||
original_path = emoji.full_path
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
existing_record = session.exec(statement).first()
|
||||
if existing_record and not existing_record.no_file_flag:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
||||
return False
|
||||
try:
|
||||
emoji.full_path.replace(target_path)
|
||||
emoji.full_path = target_path
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 移动表情包文件时出错: {e}")
|
||||
return False
|
||||
if existing_record:
|
||||
if existing_record.is_registered and _is_available_emoji_record(existing_record):
|
||||
logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}")
|
||||
return "skipped"
|
||||
|
||||
normalized_description = str(emoji.description or existing_record.description or "").strip()
|
||||
normalized_emotions = _normalize_emoji_tag_text(normalized_description)
|
||||
register_time = existing_record.register_time or datetime.now()
|
||||
|
||||
existing_record.full_path = str(emoji.full_path)
|
||||
existing_record.description = normalized_description
|
||||
existing_record.query_count = max(int(existing_record.query_count), int(emoji.query_count))
|
||||
existing_record.last_used_time = emoji.last_used_time or existing_record.last_used_time
|
||||
existing_record.is_registered = True
|
||||
existing_record.is_banned = False
|
||||
existing_record.no_file_flag = False
|
||||
existing_record.register_time = register_time
|
||||
session.add(existing_record)
|
||||
|
||||
emoji.description = normalized_description
|
||||
emoji.emotion = normalized_emotions
|
||||
emoji.query_count = existing_record.query_count
|
||||
emoji.last_used_time = existing_record.last_used_time
|
||||
emoji.register_time = register_time
|
||||
|
||||
logger.info(
|
||||
f"[register_emoji] Updated existing record and registered emoji, ID: {existing_record.id}, path: {emoji.full_path}"
|
||||
)
|
||||
return "registered"
|
||||
except Exception as e:
|
||||
logger.error(f"[register_emoji] Database query failed: {e}")
|
||||
return "failed"
|
||||
|
||||
# 注册到数据库
|
||||
restore_file = False
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if existing_record := session.exec(statement).first():
|
||||
if existing_record.no_file_flag:
|
||||
existing_record.no_file_flag = False
|
||||
existing_record.is_banned = False
|
||||
existing_record.full_path = str(emoji.full_path)
|
||||
existing_record.description = emoji.description
|
||||
existing_record.query_count = emoji.query_count
|
||||
existing_record.last_used_time = emoji.last_used_time
|
||||
existing_record.register_time = emoji.register_time
|
||||
session.add(existing_record)
|
||||
logger.info(
|
||||
f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
||||
restore_file = True
|
||||
return False
|
||||
else:
|
||||
image_record = emoji.to_db_instance()
|
||||
image_record.is_registered = True
|
||||
image_record.is_banned = False
|
||||
image_record.register_time = datetime.now()
|
||||
session.add(image_record)
|
||||
session.flush()
|
||||
record_id = image_record.id
|
||||
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
|
||||
image_record = emoji.to_db_instance()
|
||||
image_record.is_registered = True
|
||||
image_record.is_banned = False
|
||||
image_record.no_file_flag = False
|
||||
image_record.register_time = datetime.now()
|
||||
session.add(image_record)
|
||||
session.flush()
|
||||
emoji.register_time = image_record.register_time
|
||||
record_id = image_record.id
|
||||
logger.info(f"[register_emoji] Registered emoji to database, ID: {record_id}, path: {emoji.full_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
|
||||
restore_file = True
|
||||
return False
|
||||
finally:
|
||||
if restore_file:
|
||||
try:
|
||||
emoji.full_path.replace(original_path)
|
||||
emoji.full_path = original_path
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 回滚文件移动失败: {e}")
|
||||
return True
|
||||
logger.error(f"[register_emoji] Failed to write database record: {e}")
|
||||
return "failed"
|
||||
|
||||
return "registered"
|
||||
|
||||
def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool:
|
||||
"""
|
||||
@@ -636,7 +643,6 @@ class EmojiManager:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if image_record := session.exec(statement).first():
|
||||
image_record.description = emoji.description
|
||||
image_record.emotion = None
|
||||
session.add(image_record)
|
||||
logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}")
|
||||
else:
|
||||
@@ -794,12 +800,15 @@ class EmojiManager:
|
||||
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
|
||||
if self.delete_emoji(emoji_to_delete):
|
||||
self.emojis.remove(emoji_to_delete)
|
||||
if self.register_emoji_to_db(new_emoji):
|
||||
register_status = self.register_emoji_to_db(new_emoji)
|
||||
if register_status == "registered":
|
||||
self.emojis.append(new_emoji)
|
||||
logger.info(f"[注册表情包] 成功替换并注册新表情包: {new_emoji.description}")
|
||||
logger.info(f"[register_emoji] Replaced old emoji with new emoji: {new_emoji.description}")
|
||||
return True
|
||||
if register_status == "skipped":
|
||||
logger.info(f"[register_emoji] Replacement emoji was already registered: {new_emoji.description}")
|
||||
else:
|
||||
logger.error(f"[注册表情包] 注册新表情包失败: {new_emoji.description}")
|
||||
logger.error(f"[register_emoji] Failed to register replacement emoji: {new_emoji.description}")
|
||||
else:
|
||||
logger.error("[错误] 删除表情包失败,无法完成替换")
|
||||
else:
|
||||
@@ -808,7 +817,7 @@ class EmojiManager:
|
||||
logger.error("[决策] 未能解析删除编号")
|
||||
return False
|
||||
|
||||
async def build_emoji_description(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
|
||||
async def build_emoji_description(self, target_emoji: MaiEmoji) -> tuple[bool, MaiEmoji]:
|
||||
"""
|
||||
构建表情包描述
|
||||
|
||||
@@ -915,16 +924,12 @@ class EmojiManager:
|
||||
logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}")
|
||||
return True, target_emoji
|
||||
|
||||
async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
|
||||
"""兼容保留:表情包情绪标签已在 build_emoji_description 中一次性构建。"""
|
||||
return await self.build_emoji_description(target_emoji)
|
||||
|
||||
def check_emoji_file_integrity(self) -> None:
|
||||
"""
|
||||
检查表情包完整性,删除文件缺失的表情包记录
|
||||
"""
|
||||
logger.info("[完整性检查] 开始检查表情包文件完整性...")
|
||||
to_delete_emojis: list[Tuple[MaiEmoji, bool]] = []
|
||||
to_delete_emojis: list[tuple[MaiEmoji, bool]] = []
|
||||
removal_count = 0
|
||||
for emoji in self.emojis:
|
||||
if not emoji.full_path.exists():
|
||||
@@ -945,59 +950,34 @@ class EmojiManager:
|
||||
|
||||
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
|
||||
|
||||
def remove_untracked_emoji_files(self) -> None:
|
||||
"""
|
||||
删除未被数据库记录跟踪的表情包文件
|
||||
"""
|
||||
logger.info("[未跟踪表情包清理] 开始清理未被数据库记录跟踪的表情包文件...")
|
||||
tracked_files = {emoji.full_path.name for emoji in self.emojis}
|
||||
all_files = set(EMOJI_REGISTERED_DIR.glob("*"))
|
||||
removal_count = 0
|
||||
|
||||
for file_path in all_files:
|
||||
if file_path.name not in tracked_files:
|
||||
try:
|
||||
file_path.unlink()
|
||||
removal_count += 1
|
||||
logger.info(f"[未跟踪表情包清理] 删除未跟踪的表情包文件: {file_path.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"[未跟踪表情包清理] 删除文件 {file_path.name} 时出错: {e}")
|
||||
|
||||
logger.info(f"[未跟踪表情包清理] 未跟踪表情包文件清理完成,删除了 {removal_count} 个文件")
|
||||
|
||||
async def periodic_emoji_maintenance(self) -> None:
|
||||
"""
|
||||
定期执行表情包维护任务,包括完整性检查和未跟踪文件清理
|
||||
"""
|
||||
"""Run emoji maintenance tasks periodically."""
|
||||
while True:
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_ensure_directories()
|
||||
try:
|
||||
self.check_emoji_file_integrity()
|
||||
self.remove_untracked_emoji_files()
|
||||
except Exception as e:
|
||||
logger.error(f"[定期维护] 执行表情包维护任务时出错: {e}")
|
||||
logger.error(f"[emoji_maintenance] Maintenance task failed: {e}")
|
||||
|
||||
if global_config.emoji.steal_emoji and (
|
||||
self._emoji_num < global_config.emoji.max_reg_num
|
||||
or (self._emoji_num > global_config.emoji.max_reg_num and global_config.emoji.do_replace)
|
||||
):
|
||||
logger.info("[定期维护] 尝试从表情包盗取目录注册新表情包...")
|
||||
logger.info("[emoji_maintenance] Scanning data/emoji for new emojis...")
|
||||
for emoji_file in EMOJI_DIR.iterdir():
|
||||
if not emoji_file.is_file():
|
||||
continue
|
||||
try:
|
||||
register_success = await self.register_emoji_by_filename(emoji_file)
|
||||
register_status = await self.register_emoji_by_filename(emoji_file)
|
||||
except Exception as e:
|
||||
logger.error(f"[定期维护] 注册表情包 {emoji_file.name} 时发生未处理异常: {e}")
|
||||
register_success = False
|
||||
if register_success:
|
||||
break # 每次只注册一个表情包
|
||||
try:
|
||||
emoji_file.unlink()
|
||||
logger.info(f"[定期维护] 删除无法注册的表情包文件: {emoji_file.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"[定期维护] 删除文件 {emoji_file.name} 时出错: {e}")
|
||||
logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}")
|
||||
register_status = "failed"
|
||||
if register_status == "registered":
|
||||
break
|
||||
if register_status == "skipped":
|
||||
logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}")
|
||||
else:
|
||||
logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}")
|
||||
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
|
||||
try:
|
||||
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
|
||||
@@ -1006,84 +986,81 @@ class EmojiManager:
|
||||
finally:
|
||||
self._maintenance_wakeup_event.clear()
|
||||
|
||||
async def register_emoji_by_filename(self, filename: Path | str) -> bool:
|
||||
"""
|
||||
根据指定的表情包图片,分析并注册到数据库
|
||||
|
||||
Args:
|
||||
filename (Path | str): 表情包图片的完整文件路径(可能根据文件实际格式修正)
|
||||
|
||||
Returns:
|
||||
return (bool): 注册是否成功
|
||||
"""
|
||||
async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus:
|
||||
"""Register an emoji file from ``data/emoji`` without moving or deleting it."""
|
||||
file_full_path = Path(filename).absolute().resolve()
|
||||
if not file_full_path.exists():
|
||||
logger.error(f"[注册表情包] 表情包文件不存在: {file_full_path}")
|
||||
return False
|
||||
logger.error(f"[register_emoji] Emoji file does not exist: {file_full_path}")
|
||||
return "failed"
|
||||
try:
|
||||
target_emoji = MaiEmoji(full_path=file_full_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
|
||||
return False
|
||||
logger.error(f"[register_emoji] Failed to create emoji object: {e}")
|
||||
return "failed"
|
||||
|
||||
calc_success = await target_emoji.calculate_hash_format()
|
||||
if not calc_success:
|
||||
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
|
||||
return False
|
||||
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
|
||||
logger.error(f"[register_emoji] Failed to calculate hash and format: {file_full_path}")
|
||||
return "failed"
|
||||
file_full_path = target_emoji.full_path
|
||||
|
||||
# 2. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
|
||||
existing_record: Optional[Images] = None
|
||||
try:
|
||||
with get_db_session_manual() as session:
|
||||
statement = (
|
||||
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
)
|
||||
if image_record := session.exec(statement).first():
|
||||
if image_record.no_file_flag:
|
||||
image_record.no_file_flag = False
|
||||
image_record.is_banned = False
|
||||
image_record.is_registered = True
|
||||
image_record.full_path = str(target_emoji.full_path)
|
||||
session.add(image_record)
|
||||
session.commit()
|
||||
logger.info(f"表情包注册成功,Hash: {target_emoji.file_hash}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}")
|
||||
return False
|
||||
existing_record = session.exec(statement).first()
|
||||
except Exception as e:
|
||||
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
||||
return False
|
||||
logger.error(f"[register_emoji] Failed to query database: {e}")
|
||||
return "failed"
|
||||
|
||||
if existing_record is not None:
|
||||
if existing_record.is_registered and _is_available_emoji_record(existing_record):
|
||||
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
|
||||
return "skipped"
|
||||
|
||||
cached_description = str(existing_record.description or "").strip()
|
||||
if cached_description:
|
||||
normalized_emotions = _normalize_emoji_tag_text(cached_description)
|
||||
target_emoji.description = ",".join(normalized_emotions)
|
||||
target_emoji.emotion = normalized_emotions
|
||||
target_emoji.query_count = existing_record.query_count
|
||||
target_emoji.last_used_time = existing_record.last_used_time
|
||||
target_emoji.register_time = existing_record.register_time
|
||||
|
||||
# 3. 检查内存缓存是否已经存在
|
||||
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
|
||||
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
|
||||
return False
|
||||
# 3. 构建描述(包含情绪标签)
|
||||
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
|
||||
if not desc_success:
|
||||
logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}")
|
||||
return False
|
||||
logger.info(f"[register_emoji] Emoji already loaded in memory, skipping: {existing_emoji.file_name}")
|
||||
return "skipped"
|
||||
|
||||
if not target_emoji.description:
|
||||
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
|
||||
if not desc_success:
|
||||
logger.error(f"[register_emoji] Failed to build emoji description: {file_full_path}")
|
||||
return "failed"
|
||||
|
||||
# 4. 检查容量并决定是否替换或者直接注册
|
||||
if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace:
|
||||
logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包")
|
||||
logger.warning(
|
||||
f"[register_emoji] Emoji limit {global_config.emoji.max_reg_num} reached, trying replacement"
|
||||
)
|
||||
replaced = await self.replace_an_emoji_by_llm(target_emoji)
|
||||
if not replaced:
|
||||
logger.error("[注册表情包] 替换表情包失败,无法注册新表情包")
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
if self.register_emoji_to_db(target_emoji):
|
||||
self.emojis.append(target_emoji)
|
||||
self._emoji_num += 1
|
||||
logger.info(f"[注册表情包] 成功注册新表情包: {target_emoji.file_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[注册表情包] 注册表情包到数据库失败: {file_full_path}")
|
||||
return False
|
||||
logger.error("[register_emoji] Failed to replace existing emoji")
|
||||
return "failed"
|
||||
return "registered"
|
||||
|
||||
def _calculate_emotion_similarity_list(self, text_emotion: str) -> List[Tuple[MaiEmoji, float]]:
|
||||
register_status = self.register_emoji_to_db(target_emoji)
|
||||
if register_status == "registered":
|
||||
self.emojis.append(target_emoji)
|
||||
self._emoji_num = len(self.emojis)
|
||||
logger.info(f"[register_emoji] Registered new emoji: {target_emoji.file_name}")
|
||||
elif register_status == "failed":
|
||||
logger.error(f"[register_emoji] Failed to register emoji in database: {file_full_path}")
|
||||
else:
|
||||
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
|
||||
return register_status
|
||||
|
||||
def _calculate_emotion_similarity_list(self, text_emotion: str) -> list[tuple[MaiEmoji, float]]:
|
||||
"""
|
||||
计算文本情感标签与所有表情包情感标签的相似度列表
|
||||
|
||||
@@ -1096,7 +1073,7 @@ class EmojiManager:
|
||||
if not normalized_text_emotion:
|
||||
return []
|
||||
|
||||
similarity_list: List[Tuple[MaiEmoji, float]] = []
|
||||
similarity_list: list[tuple[MaiEmoji, float]] = []
|
||||
for emoji in self.emojis:
|
||||
candidate_emotions = _get_emoji_emotions(emoji)
|
||||
if not candidate_emotions:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Maisaka 表情工具内置能力。"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Sequence, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import random
|
||||
|
||||
@@ -120,8 +120,6 @@ def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
|
||||
"""提取并清洗单个表情的情绪标签。"""
|
||||
if emoji.description:
|
||||
return _normalize_emoji_tag_text(emoji.description)
|
||||
if emoji.emotion:
|
||||
return _normalize_emoji_tag_text(emoji.emotion)
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ class ImageManager:
|
||||
"""初始化图片管理器。"""
|
||||
_ensure_image_dir_exists()
|
||||
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
|
||||
self.cleanup_legacy_image_registration_records()
|
||||
|
||||
logger.info("图片管理器初始化完成")
|
||||
|
||||
@@ -50,6 +51,17 @@ class ImageManager:
|
||||
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
|
||||
return session.exec(statement).first()
|
||||
|
||||
def _normalize_image_registration_fields(self, record: Images) -> bool:
|
||||
"""Normalize accidental emoji registration fields on image records."""
|
||||
if record.image_type != ImageType.IMAGE:
|
||||
return False
|
||||
if not record.is_registered and record.register_time is None:
|
||||
return False
|
||||
|
||||
record.is_registered = False
|
||||
record.register_time = None
|
||||
return True
|
||||
|
||||
async def get_image_description(
|
||||
self,
|
||||
*,
|
||||
@@ -182,8 +194,9 @@ class ImageManager:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = image.to_db_instance()
|
||||
record.is_registered = True
|
||||
record.register_time = record.last_used_time = datetime.now()
|
||||
record.is_registered = False
|
||||
record.register_time = None
|
||||
record.last_used_time = datetime.now()
|
||||
session.add(record)
|
||||
session.flush() # 确保记录被写入数据库以获取ID
|
||||
record_id = record.id
|
||||
@@ -209,6 +222,7 @@ class ImageManager:
|
||||
if not record:
|
||||
logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法更新描述")
|
||||
return False
|
||||
self._normalize_image_registration_fields(record)
|
||||
record.description = image.description
|
||||
record.last_used_time = datetime.now()
|
||||
record.vlm_processed = image.vlm_processed
|
||||
@@ -258,6 +272,7 @@ class ImageManager:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
|
||||
if record := session.exec(statement).first():
|
||||
self._normalize_image_registration_fields(record)
|
||||
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
|
||||
record.last_used_time = datetime.now()
|
||||
record.query_count += 1
|
||||
@@ -335,6 +350,24 @@ class ImageManager:
|
||||
|
||||
logger.info(f"清理完成: {invalid_counter} 条无效描述记录,{null_path_counter} 条文件路径不存在记录")
|
||||
|
||||
def cleanup_legacy_image_registration_records(self) -> None:
|
||||
"""Clean up legacy image records with mistaken registration fields."""
|
||||
fixed_counter = 0
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_type=ImageType.IMAGE)
|
||||
for record in session.exec(statement).yield_per(100):
|
||||
if not self._normalize_image_registration_fields(record):
|
||||
continue
|
||||
session.add(record)
|
||||
fixed_counter += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clean image registration state: {e}")
|
||||
return
|
||||
|
||||
if fixed_counter:
|
||||
logger.info(f"Cleaned mistaken registration state on {fixed_counter} image records")
|
||||
|
||||
async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str:
|
||||
prompt = global_config.personality.visual_style
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import json
|
||||
|
||||
from src.common.database.database_model import ActionRecord
|
||||
from src.common.database.database_model import ToolRecord
|
||||
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
|
||||
class MaiActionRecord(BaseDatabaseDataModel[ActionRecord]):
|
||||
class MaiActionRecord(BaseDatabaseDataModel[ToolRecord]):
|
||||
"""``action_records`` 的兼容数据模型。
|
||||
|
||||
历史动作记录已统一并入 ``tool_records``,该类仅保留旧命名接口,
|
||||
底层读写对象统一映射为 ``ToolRecord``。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_id: str,
|
||||
@@ -21,45 +27,39 @@ class MaiActionRecord(BaseDatabaseDataModel[ActionRecord]):
|
||||
action_display_prompt: Optional[str] = None,
|
||||
):
|
||||
self.action_id = action_id
|
||||
"""动作ID"""
|
||||
self.timestamp = timestamp
|
||||
"""时间戳"""
|
||||
self.session_id = session_id
|
||||
"""会话ID"""
|
||||
self.action_name = action_name
|
||||
"""动作名称"""
|
||||
self.action_reasoning = action_reasoning
|
||||
"""动作推理过程"""
|
||||
self.action_data = action_data or {}
|
||||
"""动作数据"""
|
||||
self.action_builtin_prompt = action_builtin_prompt
|
||||
"""内置动作提示"""
|
||||
self.action_display_prompt = action_display_prompt
|
||||
"""最终输入到 Prompt 的内容"""
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: ActionRecord):
|
||||
"""Create a data model object from a database record."""
|
||||
def from_db_instance(cls, db_record: ToolRecord):
|
||||
"""从数据库实例创建兼容数据模型对象。"""
|
||||
|
||||
return cls(
|
||||
action_id=db_record.action_id,
|
||||
action_id=db_record.tool_id,
|
||||
timestamp=db_record.timestamp,
|
||||
session_id=db_record.session_id,
|
||||
action_name=db_record.action_name,
|
||||
action_reasoning=db_record.action_reasoning,
|
||||
action_data=json.loads(db_record.action_data) if db_record.action_data else None,
|
||||
action_builtin_prompt=db_record.action_builtin_prompt,
|
||||
action_display_prompt=db_record.action_display_prompt,
|
||||
action_name=db_record.tool_name,
|
||||
action_reasoning=db_record.tool_reasoning,
|
||||
action_data=json.loads(db_record.tool_data) if db_record.tool_data else None,
|
||||
action_builtin_prompt=db_record.tool_builtin_prompt,
|
||||
action_display_prompt=db_record.tool_display_prompt,
|
||||
)
|
||||
|
||||
def to_db_instance(self):
|
||||
"""Convert the data model object back to a database instance."""
|
||||
return ActionRecord(
|
||||
action_id=self.action_id,
|
||||
"""将兼容数据模型对象转换为 ``ToolRecord``。"""
|
||||
|
||||
return ToolRecord(
|
||||
tool_id=self.action_id,
|
||||
timestamp=self.timestamp,
|
||||
session_id=self.session_id,
|
||||
action_name=self.action_name,
|
||||
action_reasoning=self.action_reasoning,
|
||||
action_data=json.dumps(self.action_data) if self.action_data else None,
|
||||
action_builtin_prompt=self.action_builtin_prompt,
|
||||
action_display_prompt=self.action_display_prompt,
|
||||
tool_name=self.action_name,
|
||||
tool_reasoning=self.action_reasoning,
|
||||
tool_data=json.dumps(self.action_data) if self.action_data else None,
|
||||
tool_builtin_prompt=self.action_builtin_prompt,
|
||||
tool_display_prompt=self.action_display_prompt,
|
||||
)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from PIL import Image as PILImage
|
||||
from rich.traceback import install
|
||||
from typing import Optional, List
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import io
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
|
||||
@@ -152,7 +153,7 @@ class MaiEmoji(BaseImageDataModel):
|
||||
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
|
||||
obj = cls(db_record.full_path)
|
||||
obj.file_hash = db_record.image_hash
|
||||
description = db_record.description or db_record.emotion or ""
|
||||
description = db_record.description or ""
|
||||
obj.description = description
|
||||
normalized_tags = [
|
||||
str(item).strip()
|
||||
@@ -175,7 +176,6 @@ class MaiEmoji(BaseImageDataModel):
|
||||
description=self.description,
|
||||
full_path=str(self.full_path),
|
||||
image_type=ImageType.EMOJI,
|
||||
emotion=None,
|
||||
query_count=self.query_count,
|
||||
last_used_time=self.last_used_time,
|
||||
register_time=self.register_time,
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from typing import ContextManager, Generator, TYPE_CHECKING
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import event, text
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
@@ -63,41 +63,6 @@ _migration_bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
_db_initialized = False
|
||||
|
||||
|
||||
def _migrate_action_records_to_tool_records() -> None:
|
||||
"""将旧的 ``action_records`` 历史数据迁移到 ``tool_records``。"""
|
||||
migration_sql = text(
|
||||
"""
|
||||
INSERT INTO tool_records (
|
||||
tool_id,
|
||||
timestamp,
|
||||
session_id,
|
||||
tool_name,
|
||||
tool_reasoning,
|
||||
tool_data,
|
||||
tool_builtin_prompt,
|
||||
tool_display_prompt
|
||||
)
|
||||
SELECT
|
||||
action_id,
|
||||
timestamp,
|
||||
session_id,
|
||||
action_name,
|
||||
action_reasoning,
|
||||
action_data,
|
||||
action_builtin_prompt,
|
||||
action_display_prompt
|
||||
FROM action_records
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM tool_records
|
||||
WHERE tool_records.tool_id = action_records.action_id
|
||||
)
|
||||
"""
|
||||
)
|
||||
with engine.begin() as connection:
|
||||
connection.execute(migration_sql)
|
||||
|
||||
|
||||
def initialize_database() -> None:
|
||||
"""初始化数据库连接、结构与启动期迁移。
|
||||
|
||||
@@ -105,8 +70,7 @@ def initialize_database() -> None:
|
||||
1. 确保数据库目录存在;
|
||||
2. 加载 SQLModel 模型定义;
|
||||
3. 执行已注册的启动期迁移;
|
||||
4. 兜底执行 ``create_all`` 确保当前模型定义已建表;
|
||||
5. 执行项目现有的轻量数据补迁移逻辑。
|
||||
4. 兜底执行 ``create_all`` 确保当前模型定义已建表。
|
||||
"""
|
||||
global _db_initialized
|
||||
if _db_initialized:
|
||||
@@ -120,7 +84,6 @@ def initialize_database() -> None:
|
||||
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
_migrate_action_records_to_tool_records()
|
||||
_migration_bootstrapper.finalize_database(migration_state)
|
||||
_db_initialized = True
|
||||
|
||||
|
||||
@@ -94,7 +94,6 @@ class Images(SQLModel, table=True):
|
||||
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
|
||||
image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI)
|
||||
"""图片类型,例如 'emoji' 或 'image'"""
|
||||
emotion: Optional[str] = Field(default=None, nullable=True) # 表情包的情感标签,逗号分隔
|
||||
|
||||
query_count: int = Field(default=0) # 被查询次数
|
||||
is_registered: bool = Field(default=False) # 是否已经注册
|
||||
@@ -113,27 +112,6 @@ class Images(SQLModel, table=True):
|
||||
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
|
||||
|
||||
|
||||
class ActionRecord(SQLModel, table=True):
|
||||
"""存储动作记录"""
|
||||
|
||||
__tablename__ = "action_records" # type: ignore
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
# 元信息
|
||||
action_id: str = Field(index=True, max_length=255) # 动作ID
|
||||
timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间戳
|
||||
session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id
|
||||
|
||||
# 调用信息
|
||||
action_name: str = Field(index=True, max_length=255) # 动作名称
|
||||
action_reasoning: Optional[str] = Field(default=None) # 动作推理过程
|
||||
action_data: Optional[str] = Field(default=None) # 动作数据,JSON格式存储
|
||||
|
||||
action_builtin_prompt: Optional[str] = Field(default=None) # 内置动作提示
|
||||
action_display_prompt: Optional[str] = Field(default=None) # 最终输入到Prompt的内容
|
||||
|
||||
|
||||
class ToolRecord(SQLModel, table=True):
|
||||
"""存储工具调用记录"""
|
||||
|
||||
@@ -281,28 +259,6 @@ class ChatHistory(SQLModel, table=True):
|
||||
summary: str # 概括:对这段话的平文本概括
|
||||
|
||||
|
||||
class ThinkingQuestion(SQLModel, table=True):
|
||||
"""存储思考型问题的模型"""
|
||||
|
||||
__tablename__ = "thinking_questions" # type: ignore
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
||||
|
||||
# 问答对
|
||||
question: str # 问题内容
|
||||
context: Optional[str] = Field(default=None, nullable=True) # 上下文
|
||||
found_answer: bool = Field(default=False) # 是否找到答案
|
||||
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
|
||||
|
||||
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储
|
||||
created_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 创建时间
|
||||
updated_timestamp: datetime = Field(
|
||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
||||
) # 最后更新时间
|
||||
|
||||
|
||||
class BinaryData(SQLModel, table=True):
|
||||
"""存储二进制数据的模型"""
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from .builtin import (
|
||||
EMPTY_SCHEMA_VERSION,
|
||||
LATEST_SCHEMA_VERSION,
|
||||
LEGACY_V1_SCHEMA_VERSION,
|
||||
V2_SCHEMA_VERSION,
|
||||
build_default_migration_registry,
|
||||
build_default_schema_version_resolver,
|
||||
)
|
||||
@@ -61,6 +62,7 @@ __all__ = [
|
||||
"EMPTY_SCHEMA_VERSION",
|
||||
"LATEST_SCHEMA_VERSION",
|
||||
"LEGACY_V1_SCHEMA_VERSION",
|
||||
"V2_SCHEMA_VERSION",
|
||||
"MigrationExecutionContext",
|
||||
"MigrationPlan",
|
||||
"MigrationPlanner",
|
||||
|
||||
@@ -6,12 +6,14 @@ from .legacy_v1_to_v2 import migrate_legacy_v1_to_v2
|
||||
from .models import DatabaseSchemaSnapshot, MigrationStep
|
||||
from .registry import MigrationRegistry
|
||||
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
from .schema import SQLiteSchemaInspector
|
||||
from .v2_to_v3 import migrate_v2_to_v3
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
EMPTY_SCHEMA_VERSION = 0
|
||||
LEGACY_V1_SCHEMA_VERSION = 1
|
||||
LATEST_SCHEMA_VERSION = 2
|
||||
V2_SCHEMA_VERSION = 2
|
||||
LATEST_SCHEMA_VERSION = 3
|
||||
|
||||
_LEGACY_V1_EXCLUSIVE_TABLES = (
|
||||
"chat_streams",
|
||||
@@ -24,6 +26,13 @@ _LEGACY_V1_EXCLUSIVE_TABLES = (
|
||||
"messages",
|
||||
"thinking_back",
|
||||
)
|
||||
_COMMON_MARKER_TABLES = (
|
||||
"mai_messages",
|
||||
"chat_sessions",
|
||||
"expressions",
|
||||
"jargons",
|
||||
"tool_records",
|
||||
)
|
||||
|
||||
|
||||
class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
@@ -36,6 +45,7 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
Returns:
|
||||
str: 当前探测器名称。
|
||||
"""
|
||||
|
||||
return "latest_schema_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
@@ -47,18 +57,16 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
Returns:
|
||||
Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||
return None
|
||||
|
||||
latest_marker_tables = (
|
||||
"mai_messages",
|
||||
"chat_sessions",
|
||||
"expressions",
|
||||
"jargons",
|
||||
"thinking_questions",
|
||||
"tool_records",
|
||||
)
|
||||
if not all(snapshot.has_table(table_name) for table_name in latest_marker_tables):
|
||||
if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES):
|
||||
return None
|
||||
if snapshot.has_table("action_records"):
|
||||
return None
|
||||
if snapshot.has_table("thinking_questions"):
|
||||
return None
|
||||
if snapshot.has_column("images", "emotion"):
|
||||
return None
|
||||
if not snapshot.has_column("images", "image_hash"):
|
||||
return None
|
||||
@@ -66,13 +74,53 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
return None
|
||||
if not snapshot.has_column("images", "image_type"):
|
||||
return None
|
||||
if not snapshot.has_column("chat_history", "session_id"):
|
||||
return None
|
||||
if not snapshot.has_column("person_info", "user_nickname"):
|
||||
return None
|
||||
return LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
class V2SchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
"""v2 schema 结构探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""返回探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 当前探测器名称。
|
||||
"""
|
||||
|
||||
return "v2_schema_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""检测数据库是否为 v2 结构。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若识别为 v2 结构则返回 ``2``,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||
return None
|
||||
if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES):
|
||||
return None
|
||||
if not snapshot.has_table("action_records"):
|
||||
return None
|
||||
if not snapshot.has_table("thinking_questions"):
|
||||
return None
|
||||
if not snapshot.has_column("images", "emotion"):
|
||||
return None
|
||||
if not snapshot.has_column("action_records", "session_id"):
|
||||
return None
|
||||
if not snapshot.has_column("chat_history", "session_id"):
|
||||
return None
|
||||
if not snapshot.has_column("person_info", "user_nickname"):
|
||||
return None
|
||||
return LATEST_SCHEMA_VERSION
|
||||
return V2_SCHEMA_VERSION
|
||||
|
||||
|
||||
class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
||||
@@ -85,6 +133,7 @@ class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
||||
Returns:
|
||||
str: 当前探测器名称。
|
||||
"""
|
||||
|
||||
return "legacy_v1_schema_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
@@ -96,6 +145,7 @@ class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
||||
Returns:
|
||||
Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||
return LEGACY_V1_SCHEMA_VERSION
|
||||
|
||||
@@ -121,8 +171,10 @@ def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
|
||||
Returns:
|
||||
List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
LatestSchemaVersionDetector(),
|
||||
V2SchemaVersionDetector(),
|
||||
LegacyV1SchemaDetector(),
|
||||
]
|
||||
|
||||
@@ -133,6 +185,7 @@ def build_default_schema_version_resolver() -> SchemaVersionResolver:
|
||||
Returns:
|
||||
SchemaVersionResolver: 配置完成的 schema 版本解析器。
|
||||
"""
|
||||
|
||||
return SchemaVersionResolver(
|
||||
version_store=SQLiteUserVersionStore(),
|
||||
schema_inspector=SQLiteSchemaInspector(),
|
||||
@@ -146,14 +199,22 @@ def build_default_migration_registry() -> MigrationRegistry:
|
||||
Returns:
|
||||
MigrationRegistry: 含默认迁移步骤的注册表实例。
|
||||
"""
|
||||
|
||||
return MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=LEGACY_V1_SCHEMA_VERSION,
|
||||
version_to=LATEST_SCHEMA_VERSION,
|
||||
name="legacy_v1_to_latest_v2",
|
||||
description="将旧版 0.x 数据库整体迁移到当前最新 schema。",
|
||||
version_to=V2_SCHEMA_VERSION,
|
||||
name="legacy_v1_to_v2",
|
||||
description="将旧版 0.x 数据库迁移到 v2 schema。",
|
||||
handler=migrate_legacy_v1_to_v2,
|
||||
)
|
||||
),
|
||||
MigrationStep(
|
||||
version_from=V2_SCHEMA_VERSION,
|
||||
version_to=LATEST_SCHEMA_VERSION,
|
||||
name="v2_to_v3",
|
||||
description="移除废弃表,并将 emoji 标签统一收敛到 description 字段。",
|
||||
handler=migrate_v2_to_v3,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
298
src/common/database/migrations/frozen_v2_schema.py
Normal file
298
src/common/database/migrations/frozen_v2_schema.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""冻结的 v2 schema 快照。
|
||||
|
||||
该模块只用于 ``legacy_v1_to_v2`` 迁移,避免迁移过程依赖当前运行时代码中的
|
||||
最新 SQLModel 定义,导致历史迁移随着后续 schema 演进而失真。
|
||||
"""
|
||||
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
_V2_TABLE_STATEMENTS = (
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS action_records (
|
||||
id INTEGER NOT NULL,
|
||||
action_id VARCHAR(255) NOT NULL,
|
||||
timestamp DATETIME,
|
||||
session_id VARCHAR(255) NOT NULL,
|
||||
action_name VARCHAR(255) NOT NULL,
|
||||
action_reasoning VARCHAR,
|
||||
action_data VARCHAR,
|
||||
action_builtin_prompt VARCHAR,
|
||||
action_display_prompt VARCHAR,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS binary_data (
|
||||
id INTEGER NOT NULL,
|
||||
data_hash VARCHAR(255) NOT NULL,
|
||||
full_path VARCHAR(1024) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS chat_history (
|
||||
id INTEGER NOT NULL,
|
||||
session_id VARCHAR(255) NOT NULL,
|
||||
start_timestamp DATETIME,
|
||||
end_timestamp DATETIME,
|
||||
query_count INTEGER NOT NULL,
|
||||
query_forget_count INTEGER NOT NULL,
|
||||
original_messages VARCHAR NOT NULL,
|
||||
participants VARCHAR NOT NULL,
|
||||
theme VARCHAR NOT NULL,
|
||||
keywords VARCHAR NOT NULL,
|
||||
summary VARCHAR NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS chat_sessions (
|
||||
id INTEGER NOT NULL,
|
||||
session_id VARCHAR(255) NOT NULL,
|
||||
created_timestamp DATETIME,
|
||||
last_active_timestamp DATETIME,
|
||||
user_id VARCHAR(255),
|
||||
group_id VARCHAR(255),
|
||||
platform VARCHAR(100) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS command_records (
|
||||
id INTEGER NOT NULL,
|
||||
timestamp DATETIME,
|
||||
session_id VARCHAR(255) NOT NULL,
|
||||
command_name VARCHAR(255) NOT NULL,
|
||||
command_data VARCHAR,
|
||||
command_result VARCHAR,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS expressions (
|
||||
id INTEGER NOT NULL,
|
||||
situation VARCHAR(255) NOT NULL,
|
||||
style VARCHAR(255) NOT NULL,
|
||||
content_list VARCHAR NOT NULL,
|
||||
count INTEGER NOT NULL,
|
||||
last_active_time DATETIME,
|
||||
create_time DATETIME,
|
||||
session_id VARCHAR(255),
|
||||
checked BOOLEAN NOT NULL,
|
||||
rejected BOOLEAN NOT NULL,
|
||||
modified_by VARCHAR(4),
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
id INTEGER NOT NULL,
|
||||
image_hash VARCHAR(255) NOT NULL,
|
||||
description VARCHAR NOT NULL,
|
||||
full_path VARCHAR(1024) NOT NULL,
|
||||
image_type VARCHAR(5),
|
||||
emotion VARCHAR,
|
||||
query_count INTEGER NOT NULL,
|
||||
is_registered BOOLEAN NOT NULL,
|
||||
is_banned BOOLEAN NOT NULL,
|
||||
no_file_flag BOOLEAN NOT NULL,
|
||||
record_time DATETIME,
|
||||
register_time DATETIME,
|
||||
last_used_time DATETIME,
|
||||
vlm_processed BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS jargons (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||
content VARCHAR(255) NOT NULL,
|
||||
raw_content TEXT,
|
||||
meaning TEXT NOT NULL,
|
||||
session_id_dict TEXT NOT NULL,
|
||||
count INTEGER NOT NULL,
|
||||
is_jargon BOOLEAN,
|
||||
is_complete BOOLEAN NOT NULL,
|
||||
is_global BOOLEAN NOT NULL,
|
||||
last_inference_count INTEGER NOT NULL,
|
||||
inference_with_context TEXT,
|
||||
inference_with_content_only TEXT
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS llm_usage (
|
||||
id INTEGER NOT NULL,
|
||||
model_name VARCHAR(255) NOT NULL,
|
||||
model_assign_name VARCHAR(255),
|
||||
model_api_provider_name VARCHAR(255) NOT NULL,
|
||||
endpoint VARCHAR(255),
|
||||
user_type VARCHAR(6),
|
||||
request_type VARCHAR(50) NOT NULL,
|
||||
time_cost FLOAT,
|
||||
timestamp DATETIME,
|
||||
prompt_tokens INTEGER NOT NULL,
|
||||
completion_tokens INTEGER NOT NULL,
|
||||
total_tokens INTEGER NOT NULL,
|
||||
cost FLOAT NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS mai_knowledge (
|
||||
id INTEGER NOT NULL,
|
||||
knowledge_id VARCHAR(255) NOT NULL,
|
||||
category_id VARCHAR(32) NOT NULL,
|
||||
content VARCHAR NOT NULL,
|
||||
normalized_content VARCHAR NOT NULL,
|
||||
metadata_json VARCHAR,
|
||||
created_at DATETIME,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS mai_messages (
|
||||
id INTEGER NOT NULL,
|
||||
message_id VARCHAR(255) NOT NULL,
|
||||
timestamp DATETIME,
|
||||
platform VARCHAR(100) NOT NULL,
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
user_nickname VARCHAR(255) NOT NULL,
|
||||
user_cardname VARCHAR(255),
|
||||
group_id VARCHAR(255),
|
||||
group_name VARCHAR(255),
|
||||
is_mentioned BOOLEAN NOT NULL,
|
||||
is_at BOOLEAN NOT NULL,
|
||||
session_id VARCHAR(255) NOT NULL,
|
||||
reply_to VARCHAR(255),
|
||||
is_emoji BOOLEAN NOT NULL,
|
||||
is_picture BOOLEAN NOT NULL,
|
||||
is_command BOOLEAN NOT NULL,
|
||||
is_notify BOOLEAN NOT NULL,
|
||||
raw_content BLOB,
|
||||
processed_plain_text VARCHAR,
|
||||
display_message VARCHAR,
|
||||
additional_config VARCHAR,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS online_time (
|
||||
id INTEGER NOT NULL,
|
||||
timestamp DATETIME,
|
||||
duration_minutes INTEGER NOT NULL,
|
||||
start_timestamp DATETIME,
|
||||
end_timestamp DATETIME,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS person_info (
|
||||
id INTEGER NOT NULL,
|
||||
is_known BOOLEAN NOT NULL,
|
||||
person_id VARCHAR(255) NOT NULL,
|
||||
person_name VARCHAR(255),
|
||||
name_reason VARCHAR,
|
||||
platform VARCHAR(100) NOT NULL,
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
user_nickname VARCHAR(255) NOT NULL,
|
||||
group_cardname VARCHAR,
|
||||
memory_points VARCHAR,
|
||||
know_counts INTEGER NOT NULL,
|
||||
first_known_time DATETIME,
|
||||
last_known_time DATETIME,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS thinking_questions (
|
||||
id INTEGER NOT NULL,
|
||||
question VARCHAR NOT NULL,
|
||||
context VARCHAR,
|
||||
found_answer BOOLEAN NOT NULL,
|
||||
answer VARCHAR,
|
||||
thinking_steps VARCHAR,
|
||||
created_timestamp DATETIME,
|
||||
updated_timestamp DATETIME,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tool_records (
|
||||
id INTEGER NOT NULL,
|
||||
tool_id VARCHAR(255) NOT NULL,
|
||||
timestamp DATETIME,
|
||||
session_id VARCHAR(255) NOT NULL,
|
||||
tool_name VARCHAR(255) NOT NULL,
|
||||
tool_reasoning VARCHAR,
|
||||
tool_data VARCHAR,
|
||||
tool_builtin_prompt VARCHAR,
|
||||
tool_display_prompt VARCHAR,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
_V2_INDEX_STATEMENTS = (
|
||||
"CREATE INDEX IF NOT EXISTS ix_action_records_action_id ON action_records (action_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_action_records_action_name ON action_records (action_name)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_action_records_session_id ON action_records (session_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_action_records_timestamp ON action_records (timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_binary_data_data_hash ON binary_data (data_hash)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_chat_history_end_timestamp ON chat_history (end_timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_chat_history_session_id ON chat_history (session_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_chat_history_start_timestamp ON chat_history (start_timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_created_timestamp ON chat_sessions (created_timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_group_id ON chat_sessions (group_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_last_active_timestamp ON chat_sessions (last_active_timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_platform ON chat_sessions (platform)",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS ix_chat_sessions_session_id ON chat_sessions (session_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_user_id ON chat_sessions (user_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_command_records_command_name ON command_records (command_name)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_command_records_session_id ON command_records (session_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_command_records_timestamp ON command_records (timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_expressions_last_active_time ON expressions (last_active_time)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_expressions_situation ON expressions (situation)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_expressions_style ON expressions (style)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_images_image_hash ON images (image_hash)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_images_record_time ON images (record_time)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_jargons_content ON jargons (content)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_api_provider_name ON llm_usage (model_api_provider_name)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_assign_name ON llm_usage (model_assign_name)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_name ON llm_usage (model_name)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_llm_usage_timestamp ON llm_usage (timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_category_id ON mai_knowledge (category_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_created_at ON mai_knowledge (created_at)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_knowledge_id ON mai_knowledge (knowledge_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_normalized_content ON mai_knowledge (normalized_content)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_messages_group_id ON mai_messages (group_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_messages_message_id ON mai_messages (message_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_messages_platform ON mai_messages (platform)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_messages_session_id ON mai_messages (session_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_messages_user_id ON mai_messages (user_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_mai_messages_user_nickname ON mai_messages (user_nickname)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_online_time_timestamp ON online_time (timestamp)",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS ix_person_info_person_id ON person_info (person_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_person_info_platform ON person_info (platform)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_person_info_user_id ON person_info (user_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_person_info_user_nickname ON person_info (user_nickname)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_thinking_questions_created_timestamp ON thinking_questions (created_timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_thinking_questions_updated_timestamp ON thinking_questions (updated_timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_tool_records_session_id ON tool_records (session_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_tool_records_timestamp ON tool_records (timestamp)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_tool_records_tool_id ON tool_records (tool_id)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_tool_records_tool_name ON tool_records (tool_name)",
|
||||
)
|
||||
|
||||
|
||||
def create_frozen_v2_schema(connection: Connection) -> None:
|
||||
"""创建冻结的 v2 schema。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
"""
|
||||
|
||||
for statement in _V2_TABLE_STATEMENTS:
|
||||
connection.exec_driver_sql(statement)
|
||||
|
||||
for statement in _V2_INDEX_STATEMENTS:
|
||||
connection.exec_driver_sql(statement)
|
||||
@@ -1,4 +1,4 @@
|
||||
"""旧版 ``0.x`` 数据库升级到最新 schema 的迁移逻辑。"""
|
||||
"""旧版 ``0.x`` 数据库升级到 v2 schema 的迁移逻辑。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -7,15 +7,16 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
import json
|
||||
|
||||
import msgpack
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
import json
|
||||
import msgpack
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .exceptions import DatabaseMigrationExecutionError
|
||||
from .frozen_v2_schema import create_frozen_v2_schema
|
||||
from .models import DatabaseSchemaSnapshot, MigrationExecutionContext
|
||||
from .schema import SQLiteSchemaInspector
|
||||
|
||||
@@ -52,19 +53,15 @@ class LegacyTableData:
|
||||
|
||||
|
||||
def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None:
|
||||
"""执行旧版 ``0.x`` 数据库到最新 schema 的迁移。
|
||||
"""执行旧版 ``0.x`` 数据库到 v2 schema 的迁移。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
import src.common.database.database_model # noqa: F401
|
||||
|
||||
schema_inspector = SQLiteSchemaInspector()
|
||||
snapshot = schema_inspector.inspect(context.connection)
|
||||
_rename_legacy_v1_tables(context.connection, snapshot)
|
||||
SQLModel.metadata.create_all(context.connection)
|
||||
create_frozen_v2_schema(context.connection)
|
||||
|
||||
table_migration_jobs: List[Tuple[str, Callable[[MigrationExecutionContext], int]]] = [
|
||||
("chat_sessions", _migrate_chat_sessions),
|
||||
@@ -794,8 +791,6 @@ def _migrate_images(context: MigrationExecutionContext) -> int:
|
||||
if full_path and dedupe_key not in existing_keys:
|
||||
migrated_description = _normalize_required_text(row.get("description"))
|
||||
migrated_emotion = _normalize_optional_text(row.get("emotion"))
|
||||
if not migrated_description and migrated_emotion:
|
||||
migrated_description = migrated_emotion
|
||||
connection.execute(
|
||||
insert_sql,
|
||||
{
|
||||
@@ -803,7 +798,7 @@ def _migrate_images(context: MigrationExecutionContext) -> int:
|
||||
"description": migrated_description,
|
||||
"full_path": full_path,
|
||||
"image_type": "EMOJI",
|
||||
"emotion": None,
|
||||
"emotion": migrated_emotion,
|
||||
"query_count": _normalize_int(row.get("query_count"), default=0),
|
||||
"is_registered": _normalize_bool(row.get("is_registered"), default=False),
|
||||
"is_banned": _normalize_bool(row.get("is_banned"), default=False),
|
||||
|
||||
269
src/common/database/migrations/v2_to_v3.py
Normal file
269
src/common/database/migrations/v2_to_v3.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""v2 schema 升级到 v3 的迁移逻辑。"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .exceptions import DatabaseMigrationExecutionError
|
||||
from .models import MigrationExecutionContext
|
||||
from .schema import SQLiteSchemaInspector
|
||||
|
||||
logger = get_logger("database_migration")
|
||||
|
||||
_V2_IMAGES_BACKUP_TABLE = "__v2_images_backup"
|
||||
_V3_IMAGES_CREATE_SQL = """
|
||||
CREATE TABLE images (
|
||||
id INTEGER NOT NULL,
|
||||
image_hash VARCHAR(255) NOT NULL,
|
||||
description VARCHAR NOT NULL,
|
||||
full_path VARCHAR(1024) NOT NULL,
|
||||
image_type VARCHAR(5),
|
||||
query_count INTEGER NOT NULL,
|
||||
is_registered BOOLEAN NOT NULL,
|
||||
is_banned BOOLEAN NOT NULL,
|
||||
no_file_flag BOOLEAN NOT NULL,
|
||||
record_time DATETIME,
|
||||
register_time DATETIME,
|
||||
last_used_time DATETIME,
|
||||
vlm_processed BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
"""
|
||||
_V3_IMAGES_INDEX_STATEMENTS = (
|
||||
"CREATE INDEX ix_images_image_hash ON images (image_hash)",
|
||||
"CREATE INDEX ix_images_record_time ON images (record_time)",
|
||||
)
|
||||
|
||||
|
||||
def migrate_v2_to_v3(context: MigrationExecutionContext) -> None:
|
||||
"""执行 v2 到 v3 的 schema 迁移。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
|
||||
connection = context.connection
|
||||
total_records = (
|
||||
_count_table_rows(connection, "action_records")
|
||||
+ _count_table_rows(connection, "thinking_questions")
|
||||
+ _count_table_rows(connection, "images")
|
||||
)
|
||||
context.start_progress(
|
||||
total_tables=3,
|
||||
total_records=total_records,
|
||||
description="v2 -> v3 迁移进度",
|
||||
table_unit_name="表",
|
||||
record_unit_name="记录",
|
||||
)
|
||||
|
||||
migrated_tool_records = _migrate_action_records_to_tool_records(connection)
|
||||
action_record_count = _count_table_rows(connection, "action_records")
|
||||
_drop_table_if_exists(connection, "action_records")
|
||||
context.advance_progress(
|
||||
records=action_record_count,
|
||||
completed_tables=1,
|
||||
item_name="action_records",
|
||||
)
|
||||
|
||||
thinking_question_count = _count_table_rows(connection, "thinking_questions")
|
||||
_drop_table_if_exists(connection, "thinking_questions")
|
||||
context.advance_progress(
|
||||
records=thinking_question_count,
|
||||
completed_tables=1,
|
||||
item_name="thinking_questions",
|
||||
)
|
||||
|
||||
migrated_image_rows = _migrate_images_table_to_v3(connection)
|
||||
context.advance_progress(
|
||||
records=migrated_image_rows,
|
||||
completed_tables=1,
|
||||
item_name="images",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"v2 -> v3 数据库迁移完成: "
|
||||
f"tool_records补迁移={migrated_tool_records},"
|
||||
f"images重建={migrated_image_rows}"
|
||||
)
|
||||
|
||||
|
||||
def _count_table_rows(connection: Connection, table_name: str) -> int:
|
||||
"""统计表记录数,不存在时返回 0。"""
|
||||
|
||||
schema_inspector = SQLiteSchemaInspector()
|
||||
if not schema_inspector.table_exists(connection, table_name):
|
||||
return 0
|
||||
row = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"')).first()
|
||||
return int(row[0]) if row else 0
|
||||
|
||||
|
||||
def _drop_table_if_exists(connection: Connection, table_name: str) -> None:
|
||||
"""删除指定表,不存在时静默跳过。"""
|
||||
|
||||
connection.exec_driver_sql(f'DROP TABLE IF EXISTS "{table_name}"')
|
||||
|
||||
|
||||
def _migrate_action_records_to_tool_records(connection: Connection) -> int:
|
||||
"""把 v2 中残留的 ``action_records`` 数据转存到 ``tool_records``。"""
|
||||
|
||||
schema_inspector = SQLiteSchemaInspector()
|
||||
if not schema_inspector.table_exists(connection, "action_records"):
|
||||
return 0
|
||||
|
||||
inserted_count = _count_table_rows(connection, "action_records")
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO tool_records (
|
||||
tool_id,
|
||||
timestamp,
|
||||
session_id,
|
||||
tool_name,
|
||||
tool_reasoning,
|
||||
tool_data,
|
||||
tool_builtin_prompt,
|
||||
tool_display_prompt
|
||||
)
|
||||
SELECT
|
||||
action_id,
|
||||
timestamp,
|
||||
session_id,
|
||||
action_name,
|
||||
action_reasoning,
|
||||
action_data,
|
||||
action_builtin_prompt,
|
||||
action_display_prompt
|
||||
FROM action_records
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM tool_records
|
||||
WHERE tool_records.tool_id = action_records.action_id
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
return inserted_count
|
||||
|
||||
|
||||
def _migrate_images_table_to_v3(connection: Connection) -> int:
|
||||
"""重建 ``images`` 表并移除 ``emotion`` 列。"""
|
||||
|
||||
schema_inspector = SQLiteSchemaInspector()
|
||||
if not schema_inspector.table_exists(connection, "images"):
|
||||
return 0
|
||||
if not schema_inspector.get_table_schema(connection, "images").has_column("emotion"):
|
||||
return _count_table_rows(connection, "images")
|
||||
if schema_inspector.table_exists(connection, _V2_IMAGES_BACKUP_TABLE):
|
||||
raise DatabaseMigrationExecutionError(
|
||||
f"检测到残留备份表 {_V2_IMAGES_BACKUP_TABLE},无法安全执行 v2 -> v3 images 迁移。"
|
||||
)
|
||||
|
||||
connection.exec_driver_sql(f'ALTER TABLE "images" RENAME TO "{_V2_IMAGES_BACKUP_TABLE}"')
|
||||
connection.exec_driver_sql(_V3_IMAGES_CREATE_SQL)
|
||||
|
||||
legacy_rows = connection.execute(
|
||||
text(f'SELECT * FROM "{_V2_IMAGES_BACKUP_TABLE}" ORDER BY id')
|
||||
).mappings().all()
|
||||
insert_sql = text(
|
||||
"""
|
||||
INSERT INTO images (
|
||||
id,
|
||||
image_hash,
|
||||
description,
|
||||
full_path,
|
||||
image_type,
|
||||
query_count,
|
||||
is_registered,
|
||||
is_banned,
|
||||
no_file_flag,
|
||||
record_time,
|
||||
register_time,
|
||||
last_used_time,
|
||||
vlm_processed
|
||||
) VALUES (
|
||||
:id,
|
||||
:image_hash,
|
||||
:description,
|
||||
:full_path,
|
||||
:image_type,
|
||||
:query_count,
|
||||
:is_registered,
|
||||
:is_banned,
|
||||
:no_file_flag,
|
||||
:record_time,
|
||||
:register_time,
|
||||
:last_used_time,
|
||||
:vlm_processed
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
for row in legacy_rows:
|
||||
payload: Dict[str, Any] = {
|
||||
"id": row.get("id"),
|
||||
"image_hash": str(row.get("image_hash") or "").strip(),
|
||||
"description": _migrate_v3_emoji_description(row),
|
||||
"full_path": str(row.get("full_path") or "").strip(),
|
||||
"image_type": row.get("image_type"),
|
||||
"query_count": int(row.get("query_count") or 0),
|
||||
"is_registered": bool(row.get("is_registered")),
|
||||
"is_banned": bool(row.get("is_banned")),
|
||||
"no_file_flag": bool(row.get("no_file_flag")),
|
||||
"record_time": row.get("record_time"),
|
||||
"register_time": row.get("register_time"),
|
||||
"last_used_time": row.get("last_used_time"),
|
||||
"vlm_processed": bool(row.get("vlm_processed")),
|
||||
}
|
||||
connection.execute(insert_sql, payload)
|
||||
|
||||
connection.exec_driver_sql(f'DROP TABLE "{_V2_IMAGES_BACKUP_TABLE}"')
|
||||
for statement in _V3_IMAGES_INDEX_STATEMENTS:
|
||||
connection.exec_driver_sql(statement)
|
||||
return len(legacy_rows)
|
||||
|
||||
|
||||
def _migrate_v3_emoji_description(row: Dict[str, Any]) -> str:
|
||||
"""为 v3 统一 emoji 描述字段语义。
|
||||
|
||||
v3 中 `description` 对 emoji 统一承担“标签列表”的职责,因此迁移时:
|
||||
1. 若旧 `emotion` 非空,优先将其规范化后写入 `description`;
|
||||
2. 否则保留并规范化当前 `description`;
|
||||
3. 非 emoji 图片保持原描述不变。
|
||||
"""
|
||||
|
||||
image_type = str(row.get("image_type") or "").strip().upper()
|
||||
current_description = str(row.get("description") or "").strip()
|
||||
current_emotion = str(row.get("emotion") or "").strip()
|
||||
if image_type != "EMOJI":
|
||||
return current_description
|
||||
|
||||
normalized_tags = _normalize_emoji_tag_text(current_emotion or current_description)
|
||||
if normalized_tags:
|
||||
return ",".join(normalized_tags)
|
||||
return current_description
|
||||
|
||||
|
||||
def _normalize_emoji_tag_text(raw_value: Any) -> List[str]:
|
||||
"""将 emoji 标签文本转换为去重后的标签列表。"""
|
||||
|
||||
normalized_text = str(raw_value or "").strip()
|
||||
if not normalized_text:
|
||||
return []
|
||||
|
||||
separators = [",", ",", "、", ";", ";", "\n", "\r", "\t"]
|
||||
for separator in separators[1:]:
|
||||
normalized_text = normalized_text.replace(separator, separators[0])
|
||||
|
||||
deduped_tags: List[str] = []
|
||||
seen_tags: set[str] = set()
|
||||
for part in normalized_text.split(separators[0]):
|
||||
normalized_part = part.strip()
|
||||
lowered_part = normalized_part.lower()
|
||||
if not normalized_part or lowered_part in seen_tags:
|
||||
continue
|
||||
seen_tags.add(lowered_part)
|
||||
deduped_tags.append(normalized_part)
|
||||
return deduped_tags
|
||||
@@ -1,16 +1,14 @@
|
||||
import contextlib
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services import llm_service as llm_api
|
||||
from sqlmodel import select, col
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ThinkingQuestion
|
||||
|
||||
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
@@ -18,35 +16,6 @@ from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
|
||||
|
||||
logger = get_logger("memory_retrieval")
|
||||
|
||||
THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 36000 # 未找到答案记录保留时长
|
||||
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 3000 # 清理频率
|
||||
_last_not_found_cleanup_ts: float = 0.0
|
||||
|
||||
|
||||
def _cleanup_stale_not_found_thinking_back() -> None:
|
||||
"""定期清理过期的未找到答案记录"""
|
||||
global _last_not_found_cleanup_ts
|
||||
|
||||
now = time.time()
|
||||
if now - _last_not_found_cleanup_ts < THINKING_BACK_CLEANUP_INTERVAL_SECONDS:
|
||||
return
|
||||
|
||||
threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(ThinkingQuestion).where(
|
||||
col(ThinkingQuestion.found_answer).is_(False)
|
||||
& (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time))
|
||||
)
|
||||
records = session.exec(statement).all()
|
||||
for record in records:
|
||||
session.delete(record)
|
||||
if records:
|
||||
logger.info(f"清理过期的未找到答案thinking_question记录 {len(records)} 条")
|
||||
_last_not_found_cleanup_ts = now
|
||||
except Exception as e:
|
||||
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
|
||||
|
||||
|
||||
def init_memory_retrieval_sys():
|
||||
"""初始化记忆检索相关工具"""
|
||||
@@ -766,141 +735,6 @@ async def _react_agent_solve_question(
|
||||
|
||||
return False, "", thinking_steps, is_timeout
|
||||
|
||||
|
||||
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) -> str:
|
||||
"""获取最近一段时间内的查询历史(用于避免重复查询)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
time_window_seconds: 时间窗口(秒),默认10分钟
|
||||
|
||||
Returns:
|
||||
str: 格式化的查询历史字符串
|
||||
"""
|
||||
try:
|
||||
_current_time = time.time()
|
||||
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ThinkingQuestion)
|
||||
.where(col(ThinkingQuestion.context) == chat_id)
|
||||
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
||||
.limit(5)
|
||||
)
|
||||
records = session.exec(statement).all()
|
||||
|
||||
if not records:
|
||||
return ""
|
||||
|
||||
history_lines = ["最近已查询的问题和结果:"]
|
||||
|
||||
for record in records:
|
||||
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
|
||||
answer_preview = ""
|
||||
# 只有找到答案时才显示答案内容
|
||||
if record.found_answer and record.answer:
|
||||
# 截取答案前100字符
|
||||
answer_preview = record.answer[:100]
|
||||
if len(record.answer) > 100:
|
||||
answer_preview += "..."
|
||||
|
||||
history_lines.extend([f"- 问题:{record.question}", f" 状态:{status}"])
|
||||
if answer_preview:
|
||||
history_lines.append(f" 答案:{answer_preview}")
|
||||
history_lines.append("") # 空行分隔
|
||||
|
||||
return "\n".join(history_lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取查询历史失败: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) -> List[str]:
|
||||
"""获取最近一段时间内已找到答案的查询记录(用于返回给 replyer)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
time_window_seconds: 时间窗口(秒),默认10分钟
|
||||
|
||||
Returns:
|
||||
List[str]: 格式化的答案列表,每个元素格式为 "问题:xxx\n答案:xxx"
|
||||
"""
|
||||
try:
|
||||
_current_time = time.time()
|
||||
|
||||
# 查询最近时间窗口内已找到答案的记录,按更新时间倒序
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ThinkingQuestion)
|
||||
.where(col(ThinkingQuestion.context) == chat_id)
|
||||
.where(col(ThinkingQuestion.found_answer))
|
||||
.where(col(ThinkingQuestion.answer).is_not(None))
|
||||
.where(col(ThinkingQuestion.answer) != "")
|
||||
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
||||
.limit(3)
|
||||
)
|
||||
records = session.exec(statement).all()
|
||||
|
||||
if not records:
|
||||
return []
|
||||
|
||||
return [f"问题:{record.question}\n答案:{record.answer}" for record in records if record.answer]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取最近已找到答案的记录失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _store_thinking_back(
|
||||
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
question: 问题
|
||||
context: 上下文信息
|
||||
found_answer: 是否找到答案
|
||||
answer: 答案内容
|
||||
thinking_steps: 思考步骤列表
|
||||
"""
|
||||
try:
|
||||
now = time.time()
|
||||
|
||||
# 先查询是否已存在相同chat_id和问题的记录
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ThinkingQuestion)
|
||||
.where(col(ThinkingQuestion.context) == chat_id)
|
||||
.where(col(ThinkingQuestion.question) == question)
|
||||
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
||||
.limit(1)
|
||||
)
|
||||
if record := session.exec(statement).first():
|
||||
record.context = context
|
||||
record.found_answer = found_answer
|
||||
record.answer = answer
|
||||
record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False)
|
||||
record.updated_timestamp = datetime.fromtimestamp(now)
|
||||
session.add(record)
|
||||
logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...")
|
||||
return
|
||||
|
||||
new_record = ThinkingQuestion(
|
||||
question=question,
|
||||
context=chat_id,
|
||||
found_answer=found_answer,
|
||||
answer=answer,
|
||||
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
||||
created_timestamp=datetime.fromtimestamp(now),
|
||||
updated_timestamp=datetime.fromtimestamp(now),
|
||||
)
|
||||
session.add(new_record)
|
||||
except Exception as e:
|
||||
logger.error(f"存储思考过程失败: {e}")
|
||||
|
||||
|
||||
async def _process_memory_retrieval(
|
||||
chat_id: str,
|
||||
context: str,
|
||||
@@ -920,8 +754,6 @@ async def _process_memory_retrieval(
|
||||
Returns:
|
||||
Optional[str]: 如果找到答案,返回答案内容,否则返回None
|
||||
"""
|
||||
_cleanup_stale_not_found_thinking_back()
|
||||
|
||||
question_initial_info = initial_info or ""
|
||||
|
||||
# 直接使用ReAct Agent进行记忆检索
|
||||
|
||||
@@ -585,10 +585,8 @@ class RuntimeDataCapabilityMixin:
|
||||
if not ImageUtils.base64_to_image(emoji_base64, str(temp_file_path)):
|
||||
return {"success": False, "message": "无法保存图片文件", "description": None, "emotions": None, "replaced": None, "hash": None}
|
||||
|
||||
register_success = await emoji_manager.register_emoji_by_filename(temp_file_path)
|
||||
if not register_success:
|
||||
if temp_file_path.exists():
|
||||
temp_file_path.unlink(missing_ok=True)
|
||||
register_status = await emoji_manager.register_emoji_by_filename(temp_file_path)
|
||||
if register_status == "failed":
|
||||
return {
|
||||
"success": False,
|
||||
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
||||
@@ -597,6 +595,15 @@ class RuntimeDataCapabilityMixin:
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
if register_status == "skipped":
|
||||
return {
|
||||
"success": True,
|
||||
"message": "表情包已注册,已跳过本次注册",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": False,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
count_after = len(emoji_manager.emojis)
|
||||
replaced = count_after <= count_before
|
||||
|
||||
@@ -40,7 +40,7 @@ from .schemas import (
|
||||
emoji_to_response,
|
||||
)
|
||||
from .support import (
|
||||
EMOJI_REGISTERED_DIR,
|
||||
EMOJI_DIR,
|
||||
THUMBNAIL_CACHE_DIR,
|
||||
background_generate_thumbnail,
|
||||
cleanup_orphaned_thumbnails,
|
||||
@@ -326,7 +326,7 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
if emoji.is_registered:
|
||||
raise HTTPException(status_code=400, detail="该表情包已经注册")
|
||||
return EmojiUpdateResponse(success=True, message="??????????", data=emoji_to_response(emoji))
|
||||
|
||||
emoji.is_registered = True
|
||||
emoji.is_banned = False
|
||||
@@ -549,16 +549,16 @@ async def upload_emoji(
|
||||
if existing_emoji := session.exec(existing_statement).first():
|
||||
raise HTTPException(status_code=409, detail=f"已存在相同的表情包 (ID: {existing_emoji.id})")
|
||||
|
||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
|
||||
timestamp = int(datetime.now().timestamp())
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
full_path = os.path.join(EMOJI_DIR, filename)
|
||||
|
||||
counter = 1
|
||||
while os.path.exists(full_path):
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
full_path = os.path.join(EMOJI_DIR, filename)
|
||||
counter += 1
|
||||
|
||||
with open(full_path, "wb") as output_file:
|
||||
@@ -618,7 +618,7 @@ async def batch_upload_emoji(
|
||||
}
|
||||
|
||||
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
|
||||
for file in files:
|
||||
try:
|
||||
@@ -665,12 +665,12 @@ async def batch_upload_emoji(
|
||||
|
||||
timestamp = int(datetime.now().timestamp())
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
full_path = os.path.join(EMOJI_DIR, filename)
|
||||
|
||||
counter = 1
|
||||
while os.path.exists(full_path):
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
full_path = os.path.join(EMOJI_DIR, filename)
|
||||
counter += 1
|
||||
|
||||
with open(full_path, "wb") as output_file:
|
||||
|
||||
@@ -16,7 +16,8 @@ logger = get_logger("webui.emoji")
|
||||
THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails")
|
||||
THUMBNAIL_SIZE = (200, 200)
|
||||
THUMBNAIL_QUALITY = 80
|
||||
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
|
||||
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji")
|
||||
EMOJI_DIR = EMOJI_REGISTERED_DIR
|
||||
|
||||
_thumbnail_locks: Dict[str, threading.Lock] = {}
|
||||
_locks_lock = threading.Lock()
|
||||
|
||||
Reference in New Issue
Block a user