feat:优化表情包注册,迁移数据库v3
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
@@ -31,12 +31,12 @@ install(extra_lines=3)
|
|||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||||
DATA_DIR = PROJECT_ROOT / "data"
|
DATA_DIR = PROJECT_ROOT / "data"
|
||||||
|
EmojiRegisterStatus = Literal["registered", "skipped", "failed"]
|
||||||
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
|
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
|
||||||
EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注册目录
|
|
||||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||||
|
|
||||||
|
|
||||||
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
def register_emoji_hook_specs(registry: HookSpecRegistry) -> list[HookSpec]:
|
||||||
"""注册表情包系统内置 Hook 规格。
|
"""注册表情包系统内置 Hook 规格。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -145,7 +145,7 @@ def _get_runtime_manager() -> Any:
|
|||||||
return get_plugin_runtime_manager()
|
return get_plugin_runtime_manager()
|
||||||
|
|
||||||
|
|
||||||
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
|
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[dict[str, Any]]:
|
||||||
"""将表情包对象序列化为 Hook 可传输载荷。
|
"""将表情包对象序列化为 Hook 可传输载荷。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -163,27 +163,27 @@ def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, A
|
|||||||
"file_name": emoji.file_name,
|
"file_name": emoji.file_name,
|
||||||
"full_path": str(emoji.full_path),
|
"full_path": str(emoji.full_path),
|
||||||
"description": emoji.description,
|
"description": emoji.description,
|
||||||
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description or emoji.emotion)],
|
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description)],
|
||||||
"query_count": int(emoji.query_count),
|
"query_count": int(emoji.query_count),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
|
def _normalize_emoji_tag_text(raw_values: Any) -> list[str]:
|
||||||
"""将文本或标签列表转为去重的情绪标签列表。"""
|
"""将文本或标签列表转为去重的情绪标签列表。"""
|
||||||
if isinstance(raw_values, str):
|
if isinstance(raw_values, str):
|
||||||
if not raw_values:
|
if not raw_values:
|
||||||
return []
|
return []
|
||||||
parts = re.split(r"[,,、;;\s]+", raw_values.strip())
|
parts = re.split(r"[,,、;;\r\n\t]+", raw_values.strip())
|
||||||
normalized_tags = [str(part).strip() for part in parts if str(part).strip()]
|
normalized_tags = [str(part).strip() for part in parts if str(part).strip()]
|
||||||
elif isinstance(raw_values, list):
|
elif isinstance(raw_values, list):
|
||||||
normalized_tags: List[str] = []
|
normalized_tags: list[str] = []
|
||||||
for value in raw_values:
|
for value in raw_values:
|
||||||
normalized_tags.extend(_normalize_emoji_tag_text(value))
|
normalized_tags.extend(_normalize_emoji_tag_text(value))
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
deduped_tags: List[str] = []
|
deduped_tags: list[str] = []
|
||||||
seen: Set[str] = set()
|
seen: set[str] = set()
|
||||||
for tag in normalized_tags:
|
for tag in normalized_tags:
|
||||||
normalized_tag = tag.strip()
|
normalized_tag = tag.strip()
|
||||||
if not normalized_tag:
|
if not normalized_tag:
|
||||||
@@ -196,15 +196,23 @@ def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
|
|||||||
return deduped_tags
|
return deduped_tags
|
||||||
|
|
||||||
|
|
||||||
def _get_emoji_emotions(emoji: MaiEmoji) -> List[str]:
|
def _get_emoji_emotions(emoji: MaiEmoji) -> list[str]:
|
||||||
"""获取兼容旧数据的表情包情绪标签。"""
|
"""获取表情包情绪标签。"""
|
||||||
return _normalize_emoji_tag_text(emoji.description or emoji.emotion)
|
return _normalize_emoji_tag_text(emoji.description)
|
||||||
|
|
||||||
|
|
||||||
def _ensure_directories() -> None:
|
def _ensure_directories() -> None:
|
||||||
"""确保表情包相关目录存在"""
|
"""确保表情包相关目录存在"""
|
||||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
|
def _is_available_emoji_record(record: Images) -> bool:
|
||||||
|
"""判断数据库记录对应的表情包文件当前是否可用。"""
|
||||||
|
if record.no_file_flag:
|
||||||
|
return False
|
||||||
|
|
||||||
|
record_path = Path(record.full_path)
|
||||||
|
return record_path.exists() and record_path.is_file()
|
||||||
|
|
||||||
|
|
||||||
# TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法
|
# TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法
|
||||||
@@ -224,9 +232,9 @@ class EmojiManager:
|
|||||||
_ensure_directories()
|
_ensure_directories()
|
||||||
|
|
||||||
self._emoji_num: int = 0
|
self._emoji_num: int = 0
|
||||||
self.emojis: List[MaiEmoji] = []
|
self.emojis: list[MaiEmoji] = []
|
||||||
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
|
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
|
||||||
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
|
self._pending_description_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
self._reload_callback_registered: bool = False
|
self._reload_callback_registered: bool = False
|
||||||
|
|
||||||
config_manager.register_reload_callback(self.reload_runtime_config)
|
config_manager.register_reload_callback(self.reload_runtime_config)
|
||||||
@@ -254,7 +262,7 @@ class EmojiManager:
|
|||||||
emoji_bytes: Optional[bytes] = None,
|
emoji_bytes: Optional[bytes] = None,
|
||||||
emoji_hash: Optional[str] = None,
|
emoji_hash: Optional[str] = None,
|
||||||
wait_for_build: bool = True,
|
wait_for_build: bool = True,
|
||||||
) -> Optional[Tuple[str, List[str]]]:
|
) -> Optional[tuple[str, list[str]]]:
|
||||||
"""
|
"""
|
||||||
根据表情包哈希获取表情包描述和情感列表的封装方法
|
根据表情包哈希获取表情包描述和情感列表的封装方法
|
||||||
|
|
||||||
@@ -275,12 +283,28 @@ class EmojiManager:
|
|||||||
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
|
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
|
||||||
|
|
||||||
if emoji := self.get_emoji_by_hash(emoji_hash):
|
if emoji := self.get_emoji_by_hash(emoji_hash):
|
||||||
|
emoji_path = Path(emoji.full_path) if emoji.full_path else None
|
||||||
|
if emoji_bytes and (emoji_path is None or not emoji_path.exists()):
|
||||||
|
try:
|
||||||
|
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||||
|
emoji.full_path = restored_emoji.full_path
|
||||||
|
emoji.file_name = restored_emoji.file_name
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"表情包缓存命中但本地文件缺失,回填失败: {e}")
|
||||||
return emoji.description, _normalize_emoji_tag_text(emoji.description or "")
|
return emoji.description, _normalize_emoji_tag_text(emoji.description or "")
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
if result := session.exec(statement).first():
|
if result := session.exec(statement).first():
|
||||||
cached_description = result.description or result.emotion or ""
|
record_path = Path(result.full_path) if result.full_path else None
|
||||||
|
if emoji_bytes and (result.no_file_flag or record_path is None or not record_path.exists()):
|
||||||
|
try:
|
||||||
|
restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||||
|
result.full_path = str(restored_emoji.full_path)
|
||||||
|
result.no_file_flag = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"数据库命中表情包记录但本地文件缺失,回填失败: {e}")
|
||||||
|
cached_description = result.description or ""
|
||||||
cached_emotions = _normalize_emoji_tag_text(cached_description)
|
cached_emotions = _normalize_emoji_tag_text(cached_description)
|
||||||
return (
|
return (
|
||||||
cached_description,
|
cached_description,
|
||||||
@@ -400,7 +424,7 @@ class EmojiManager:
|
|||||||
self,
|
self,
|
||||||
emoji_hash: str,
|
emoji_hash: str,
|
||||||
emoji_bytes: bytes,
|
emoji_bytes: bytes,
|
||||||
) -> Optional[Tuple[str, List[str]]]:
|
) -> Optional[tuple[str, list[str]]]:
|
||||||
"""构建并缓存表情包描述(返回标签化结果,不再走额外识别流程)。"""
|
"""构建并缓存表情包描述(返回标签化结果,不再走额外识别流程)。"""
|
||||||
logger.info(f"Start building cached emoji description, hash={emoji_hash}")
|
logger.info(f"Start building cached emoji description, hash={emoji_hash}")
|
||||||
new_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
new_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||||
@@ -461,87 +485,70 @@ class EmojiManager:
|
|||||||
self._emoji_num = 0
|
self._emoji_num = 0
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def register_emoji_to_db(self, emoji: MaiEmoji) -> bool:
|
def register_emoji_to_db(self, emoji: MaiEmoji) -> EmojiRegisterStatus:
|
||||||
# sourcery skip: extract-method
|
# sourcery skip: extract-method
|
||||||
"""
|
"""Register an emoji in the database without moving its source file."""
|
||||||
将表情包注册到数据库中
|
|
||||||
|
|
||||||
Args:
|
|
||||||
emoji (MaiEmoji): 需要注册的表情包对象
|
|
||||||
Returns:
|
|
||||||
return (bool): 注册是否成功
|
|
||||||
"""
|
|
||||||
if not emoji or not isinstance(emoji, MaiEmoji):
|
if not emoji or not isinstance(emoji, MaiEmoji):
|
||||||
logger.error("[注册表情包] 无效的表情包对象")
|
logger.error("[register_emoji] Invalid emoji object")
|
||||||
return False
|
return "failed"
|
||||||
if not emoji.full_path.exists():
|
if not emoji.full_path.exists():
|
||||||
logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}")
|
logger.error(f"[register_emoji] Emoji file does not exist: {emoji.full_path}")
|
||||||
return False
|
return "failed"
|
||||||
|
|
||||||
target_path = EMOJI_REGISTERED_DIR / emoji.file_name
|
|
||||||
|
|
||||||
# 先查库,避免重复记录导致文件被误移动后无法回收
|
|
||||||
original_path = emoji.full_path
|
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
existing_record = session.exec(statement).first()
|
existing_record = session.exec(statement).first()
|
||||||
if existing_record and not existing_record.no_file_flag:
|
if existing_record:
|
||||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
if existing_record.is_registered and _is_available_emoji_record(existing_record):
|
||||||
return False
|
logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}")
|
||||||
except Exception as e:
|
return "skipped"
|
||||||
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
|
||||||
return False
|
normalized_description = str(emoji.description or existing_record.description or "").strip()
|
||||||
try:
|
normalized_emotions = _normalize_emoji_tag_text(normalized_description)
|
||||||
emoji.full_path.replace(target_path)
|
register_time = existing_record.register_time or datetime.now()
|
||||||
emoji.full_path = target_path
|
|
||||||
except Exception as e:
|
existing_record.full_path = str(emoji.full_path)
|
||||||
logger.error(f"[注册表情包] 移动表情包文件时出错: {e}")
|
existing_record.description = normalized_description
|
||||||
return False
|
existing_record.query_count = max(int(existing_record.query_count), int(emoji.query_count))
|
||||||
|
existing_record.last_used_time = emoji.last_used_time or existing_record.last_used_time
|
||||||
|
existing_record.is_registered = True
|
||||||
|
existing_record.is_banned = False
|
||||||
|
existing_record.no_file_flag = False
|
||||||
|
existing_record.register_time = register_time
|
||||||
|
session.add(existing_record)
|
||||||
|
|
||||||
|
emoji.description = normalized_description
|
||||||
|
emoji.emotion = normalized_emotions
|
||||||
|
emoji.query_count = existing_record.query_count
|
||||||
|
emoji.last_used_time = existing_record.last_used_time
|
||||||
|
emoji.register_time = register_time
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[register_emoji] Updated existing record and registered emoji, ID: {existing_record.id}, path: {emoji.full_path}"
|
||||||
|
)
|
||||||
|
return "registered"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[register_emoji] Database query failed: {e}")
|
||||||
|
return "failed"
|
||||||
|
|
||||||
# 注册到数据库
|
|
||||||
restore_file = False
|
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
image_record = emoji.to_db_instance()
|
||||||
if existing_record := session.exec(statement).first():
|
image_record.is_registered = True
|
||||||
if existing_record.no_file_flag:
|
image_record.is_banned = False
|
||||||
existing_record.no_file_flag = False
|
image_record.no_file_flag = False
|
||||||
existing_record.is_banned = False
|
image_record.register_time = datetime.now()
|
||||||
existing_record.full_path = str(emoji.full_path)
|
session.add(image_record)
|
||||||
existing_record.description = emoji.description
|
session.flush()
|
||||||
existing_record.query_count = emoji.query_count
|
emoji.register_time = image_record.register_time
|
||||||
existing_record.last_used_time = emoji.last_used_time
|
record_id = image_record.id
|
||||||
existing_record.register_time = emoji.register_time
|
logger.info(f"[register_emoji] Registered emoji to database, ID: {record_id}, path: {emoji.full_path}")
|
||||||
session.add(existing_record)
|
|
||||||
logger.info(
|
|
||||||
f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}")
|
|
||||||
restore_file = True
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
image_record = emoji.to_db_instance()
|
|
||||||
image_record.is_registered = True
|
|
||||||
image_record.is_banned = False
|
|
||||||
image_record.register_time = datetime.now()
|
|
||||||
session.add(image_record)
|
|
||||||
session.flush()
|
|
||||||
record_id = image_record.id
|
|
||||||
logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[注册表情包] 注册到数据库时出错: {e}")
|
logger.error(f"[register_emoji] Failed to write database record: {e}")
|
||||||
restore_file = True
|
return "failed"
|
||||||
return False
|
|
||||||
finally:
|
return "registered"
|
||||||
if restore_file:
|
|
||||||
try:
|
|
||||||
emoji.full_path.replace(original_path)
|
|
||||||
emoji.full_path = original_path
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[注册表情包] 回滚文件移动失败: {e}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool:
|
def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -636,7 +643,6 @@ class EmojiManager:
|
|||||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
if image_record := session.exec(statement).first():
|
if image_record := session.exec(statement).first():
|
||||||
image_record.description = emoji.description
|
image_record.description = emoji.description
|
||||||
image_record.emotion = None
|
|
||||||
session.add(image_record)
|
session.add(image_record)
|
||||||
logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}")
|
logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}")
|
||||||
else:
|
else:
|
||||||
@@ -794,12 +800,15 @@ class EmojiManager:
|
|||||||
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
|
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
|
||||||
if self.delete_emoji(emoji_to_delete):
|
if self.delete_emoji(emoji_to_delete):
|
||||||
self.emojis.remove(emoji_to_delete)
|
self.emojis.remove(emoji_to_delete)
|
||||||
if self.register_emoji_to_db(new_emoji):
|
register_status = self.register_emoji_to_db(new_emoji)
|
||||||
|
if register_status == "registered":
|
||||||
self.emojis.append(new_emoji)
|
self.emojis.append(new_emoji)
|
||||||
logger.info(f"[注册表情包] 成功替换并注册新表情包: {new_emoji.description}")
|
logger.info(f"[register_emoji] Replaced old emoji with new emoji: {new_emoji.description}")
|
||||||
return True
|
return True
|
||||||
|
if register_status == "skipped":
|
||||||
|
logger.info(f"[register_emoji] Replacement emoji was already registered: {new_emoji.description}")
|
||||||
else:
|
else:
|
||||||
logger.error(f"[注册表情包] 注册新表情包失败: {new_emoji.description}")
|
logger.error(f"[register_emoji] Failed to register replacement emoji: {new_emoji.description}")
|
||||||
else:
|
else:
|
||||||
logger.error("[错误] 删除表情包失败,无法完成替换")
|
logger.error("[错误] 删除表情包失败,无法完成替换")
|
||||||
else:
|
else:
|
||||||
@@ -808,7 +817,7 @@ class EmojiManager:
|
|||||||
logger.error("[决策] 未能解析删除编号")
|
logger.error("[决策] 未能解析删除编号")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def build_emoji_description(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
|
async def build_emoji_description(self, target_emoji: MaiEmoji) -> tuple[bool, MaiEmoji]:
|
||||||
"""
|
"""
|
||||||
构建表情包描述
|
构建表情包描述
|
||||||
|
|
||||||
@@ -915,16 +924,12 @@ class EmojiManager:
|
|||||||
logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}")
|
logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}")
|
||||||
return True, target_emoji
|
return True, target_emoji
|
||||||
|
|
||||||
async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
|
|
||||||
"""兼容保留:表情包情绪标签已在 build_emoji_description 中一次性构建。"""
|
|
||||||
return await self.build_emoji_description(target_emoji)
|
|
||||||
|
|
||||||
def check_emoji_file_integrity(self) -> None:
|
def check_emoji_file_integrity(self) -> None:
|
||||||
"""
|
"""
|
||||||
检查表情包完整性,删除文件缺失的表情包记录
|
检查表情包完整性,删除文件缺失的表情包记录
|
||||||
"""
|
"""
|
||||||
logger.info("[完整性检查] 开始检查表情包文件完整性...")
|
logger.info("[完整性检查] 开始检查表情包文件完整性...")
|
||||||
to_delete_emojis: list[Tuple[MaiEmoji, bool]] = []
|
to_delete_emojis: list[tuple[MaiEmoji, bool]] = []
|
||||||
removal_count = 0
|
removal_count = 0
|
||||||
for emoji in self.emojis:
|
for emoji in self.emojis:
|
||||||
if not emoji.full_path.exists():
|
if not emoji.full_path.exists():
|
||||||
@@ -945,59 +950,34 @@ class EmojiManager:
|
|||||||
|
|
||||||
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
|
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
|
||||||
|
|
||||||
def remove_untracked_emoji_files(self) -> None:
|
|
||||||
"""
|
|
||||||
删除未被数据库记录跟踪的表情包文件
|
|
||||||
"""
|
|
||||||
logger.info("[未跟踪表情包清理] 开始清理未被数据库记录跟踪的表情包文件...")
|
|
||||||
tracked_files = {emoji.full_path.name for emoji in self.emojis}
|
|
||||||
all_files = set(EMOJI_REGISTERED_DIR.glob("*"))
|
|
||||||
removal_count = 0
|
|
||||||
|
|
||||||
for file_path in all_files:
|
|
||||||
if file_path.name not in tracked_files:
|
|
||||||
try:
|
|
||||||
file_path.unlink()
|
|
||||||
removal_count += 1
|
|
||||||
logger.info(f"[未跟踪表情包清理] 删除未跟踪的表情包文件: {file_path.name}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[未跟踪表情包清理] 删除文件 {file_path.name} 时出错: {e}")
|
|
||||||
|
|
||||||
logger.info(f"[未跟踪表情包清理] 未跟踪表情包文件清理完成,删除了 {removal_count} 个文件")
|
|
||||||
|
|
||||||
async def periodic_emoji_maintenance(self) -> None:
|
async def periodic_emoji_maintenance(self) -> None:
|
||||||
"""
|
"""Run emoji maintenance tasks periodically."""
|
||||||
定期执行表情包维护任务,包括完整性检查和未跟踪文件清理
|
|
||||||
"""
|
|
||||||
while True:
|
while True:
|
||||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
_ensure_directories()
|
||||||
EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
try:
|
try:
|
||||||
self.check_emoji_file_integrity()
|
self.check_emoji_file_integrity()
|
||||||
self.remove_untracked_emoji_files()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[定期维护] 执行表情包维护任务时出错: {e}")
|
logger.error(f"[emoji_maintenance] Maintenance task failed: {e}")
|
||||||
|
|
||||||
if global_config.emoji.steal_emoji and (
|
if global_config.emoji.steal_emoji and (
|
||||||
self._emoji_num < global_config.emoji.max_reg_num
|
self._emoji_num < global_config.emoji.max_reg_num
|
||||||
or (self._emoji_num > global_config.emoji.max_reg_num and global_config.emoji.do_replace)
|
or (self._emoji_num > global_config.emoji.max_reg_num and global_config.emoji.do_replace)
|
||||||
):
|
):
|
||||||
logger.info("[定期维护] 尝试从表情包盗取目录注册新表情包...")
|
logger.info("[emoji_maintenance] Scanning data/emoji for new emojis...")
|
||||||
for emoji_file in EMOJI_DIR.iterdir():
|
for emoji_file in EMOJI_DIR.iterdir():
|
||||||
if not emoji_file.is_file():
|
if not emoji_file.is_file():
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
register_success = await self.register_emoji_by_filename(emoji_file)
|
register_status = await self.register_emoji_by_filename(emoji_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[定期维护] 注册表情包 {emoji_file.name} 时发生未处理异常: {e}")
|
logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}")
|
||||||
register_success = False
|
register_status = "failed"
|
||||||
if register_success:
|
if register_status == "registered":
|
||||||
break # 每次只注册一个表情包
|
break
|
||||||
try:
|
if register_status == "skipped":
|
||||||
emoji_file.unlink()
|
logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}")
|
||||||
logger.info(f"[定期维护] 删除无法注册的表情包文件: {emoji_file.name}")
|
else:
|
||||||
except Exception as e:
|
logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}")
|
||||||
logger.error(f"[定期维护] 删除文件 {emoji_file.name} 时出错: {e}")
|
|
||||||
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
|
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
|
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
|
||||||
@@ -1006,84 +986,81 @@ class EmojiManager:
|
|||||||
finally:
|
finally:
|
||||||
self._maintenance_wakeup_event.clear()
|
self._maintenance_wakeup_event.clear()
|
||||||
|
|
||||||
async def register_emoji_by_filename(self, filename: Path | str) -> bool:
|
async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus:
|
||||||
"""
|
"""Register an emoji file from ``data/emoji`` without moving or deleting it."""
|
||||||
根据指定的表情包图片,分析并注册到数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename (Path | str): 表情包图片的完整文件路径(可能根据文件实际格式修正)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
return (bool): 注册是否成功
|
|
||||||
"""
|
|
||||||
file_full_path = Path(filename).absolute().resolve()
|
file_full_path = Path(filename).absolute().resolve()
|
||||||
if not file_full_path.exists():
|
if not file_full_path.exists():
|
||||||
logger.error(f"[注册表情包] 表情包文件不存在: {file_full_path}")
|
logger.error(f"[register_emoji] Emoji file does not exist: {file_full_path}")
|
||||||
return False
|
return "failed"
|
||||||
try:
|
try:
|
||||||
target_emoji = MaiEmoji(full_path=file_full_path)
|
target_emoji = MaiEmoji(full_path=file_full_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[注册表情包] 创建表情包对象时出错: {e}")
|
logger.error(f"[register_emoji] Failed to create emoji object: {e}")
|
||||||
return False
|
return "failed"
|
||||||
|
|
||||||
calc_success = await target_emoji.calculate_hash_format()
|
calc_success = await target_emoji.calculate_hash_format()
|
||||||
if not calc_success:
|
if not calc_success:
|
||||||
logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}")
|
logger.error(f"[register_emoji] Failed to calculate hash and format: {file_full_path}")
|
||||||
return False
|
return "failed"
|
||||||
file_full_path = target_emoji.full_path # 更新为可能修正后的路径
|
file_full_path = target_emoji.full_path
|
||||||
|
|
||||||
# 2. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建
|
existing_record: Optional[Images] = None
|
||||||
try:
|
try:
|
||||||
with get_db_session_manual() as session:
|
with get_db_session_manual() as session:
|
||||||
statement = (
|
statement = (
|
||||||
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||||
)
|
)
|
||||||
if image_record := session.exec(statement).first():
|
existing_record = session.exec(statement).first()
|
||||||
if image_record.no_file_flag:
|
|
||||||
image_record.no_file_flag = False
|
|
||||||
image_record.is_banned = False
|
|
||||||
image_record.is_registered = True
|
|
||||||
image_record.full_path = str(target_emoji.full_path)
|
|
||||||
session.add(image_record)
|
|
||||||
session.commit()
|
|
||||||
logger.info(f"表情包注册成功,Hash: {target_emoji.file_hash}")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[注册表情包] 查询数据库时出错: {e}")
|
logger.error(f"[register_emoji] Failed to query database: {e}")
|
||||||
return False
|
return "failed"
|
||||||
|
|
||||||
|
if existing_record is not None:
|
||||||
|
if existing_record.is_registered and _is_available_emoji_record(existing_record):
|
||||||
|
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
|
||||||
|
return "skipped"
|
||||||
|
|
||||||
|
cached_description = str(existing_record.description or "").strip()
|
||||||
|
if cached_description:
|
||||||
|
normalized_emotions = _normalize_emoji_tag_text(cached_description)
|
||||||
|
target_emoji.description = ",".join(normalized_emotions)
|
||||||
|
target_emoji.emotion = normalized_emotions
|
||||||
|
target_emoji.query_count = existing_record.query_count
|
||||||
|
target_emoji.last_used_time = existing_record.last_used_time
|
||||||
|
target_emoji.register_time = existing_record.register_time
|
||||||
|
|
||||||
# 3. 检查内存缓存是否已经存在
|
|
||||||
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
|
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
|
||||||
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
|
logger.info(f"[register_emoji] Emoji already loaded in memory, skipping: {existing_emoji.file_name}")
|
||||||
return False
|
return "skipped"
|
||||||
# 3. 构建描述(包含情绪标签)
|
|
||||||
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
|
if not target_emoji.description:
|
||||||
if not desc_success:
|
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
|
||||||
logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}")
|
if not desc_success:
|
||||||
return False
|
logger.error(f"[register_emoji] Failed to build emoji description: {file_full_path}")
|
||||||
|
return "failed"
|
||||||
|
|
||||||
# 4. 检查容量并决定是否替换或者直接注册
|
|
||||||
if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace:
|
if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace:
|
||||||
logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包")
|
logger.warning(
|
||||||
|
f"[register_emoji] Emoji limit {global_config.emoji.max_reg_num} reached, trying replacement"
|
||||||
|
)
|
||||||
replaced = await self.replace_an_emoji_by_llm(target_emoji)
|
replaced = await self.replace_an_emoji_by_llm(target_emoji)
|
||||||
if not replaced:
|
if not replaced:
|
||||||
logger.error("[注册表情包] 替换表情包失败,无法注册新表情包")
|
logger.error("[register_emoji] Failed to replace existing emoji")
|
||||||
return False
|
return "failed"
|
||||||
return True
|
return "registered"
|
||||||
else:
|
|
||||||
if self.register_emoji_to_db(target_emoji):
|
|
||||||
self.emojis.append(target_emoji)
|
|
||||||
self._emoji_num += 1
|
|
||||||
logger.info(f"[注册表情包] 成功注册新表情包: {target_emoji.file_name}")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.error(f"[注册表情包] 注册表情包到数据库失败: {file_full_path}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _calculate_emotion_similarity_list(self, text_emotion: str) -> List[Tuple[MaiEmoji, float]]:
|
register_status = self.register_emoji_to_db(target_emoji)
|
||||||
|
if register_status == "registered":
|
||||||
|
self.emojis.append(target_emoji)
|
||||||
|
self._emoji_num = len(self.emojis)
|
||||||
|
logger.info(f"[register_emoji] Registered new emoji: {target_emoji.file_name}")
|
||||||
|
elif register_status == "failed":
|
||||||
|
logger.error(f"[register_emoji] Failed to register emoji in database: {file_full_path}")
|
||||||
|
else:
|
||||||
|
logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}")
|
||||||
|
return register_status
|
||||||
|
|
||||||
|
def _calculate_emotion_similarity_list(self, text_emotion: str) -> list[tuple[MaiEmoji, float]]:
|
||||||
"""
|
"""
|
||||||
计算文本情感标签与所有表情包情感标签的相似度列表
|
计算文本情感标签与所有表情包情感标签的相似度列表
|
||||||
|
|
||||||
@@ -1096,7 +1073,7 @@ class EmojiManager:
|
|||||||
if not normalized_text_emotion:
|
if not normalized_text_emotion:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
similarity_list: List[Tuple[MaiEmoji, float]] = []
|
similarity_list: list[tuple[MaiEmoji, float]] = []
|
||||||
for emoji in self.emojis:
|
for emoji in self.emojis:
|
||||||
candidate_emotions = _get_emoji_emotions(emoji)
|
candidate_emotions = _get_emoji_emotions(emoji)
|
||||||
if not candidate_emotions:
|
if not candidate_emotions:
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Maisaka 表情工具内置能力。"""
|
"""Maisaka 表情工具内置能力。"""
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Optional, Sequence, TYPE_CHECKING
|
from typing import Any, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
@@ -120,8 +120,6 @@ def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
|
|||||||
"""提取并清洗单个表情的情绪标签。"""
|
"""提取并清洗单个表情的情绪标签。"""
|
||||||
if emoji.description:
|
if emoji.description:
|
||||||
return _normalize_emoji_tag_text(emoji.description)
|
return _normalize_emoji_tag_text(emoji.description)
|
||||||
if emoji.emotion:
|
|
||||||
return _normalize_emoji_tag_text(emoji.emotion)
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ class ImageManager:
|
|||||||
"""初始化图片管理器。"""
|
"""初始化图片管理器。"""
|
||||||
_ensure_image_dir_exists()
|
_ensure_image_dir_exists()
|
||||||
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
|
self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {}
|
||||||
|
self.cleanup_legacy_image_registration_records()
|
||||||
|
|
||||||
logger.info("图片管理器初始化完成")
|
logger.info("图片管理器初始化完成")
|
||||||
|
|
||||||
@@ -50,6 +51,17 @@ class ImageManager:
|
|||||||
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
|
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
|
||||||
return session.exec(statement).first()
|
return session.exec(statement).first()
|
||||||
|
|
||||||
|
def _normalize_image_registration_fields(self, record: Images) -> bool:
|
||||||
|
"""Normalize accidental emoji registration fields on image records."""
|
||||||
|
if record.image_type != ImageType.IMAGE:
|
||||||
|
return False
|
||||||
|
if not record.is_registered and record.register_time is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
record.is_registered = False
|
||||||
|
record.register_time = None
|
||||||
|
return True
|
||||||
|
|
||||||
async def get_image_description(
|
async def get_image_description(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -182,8 +194,9 @@ class ImageManager:
|
|||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
record = image.to_db_instance()
|
record = image.to_db_instance()
|
||||||
record.is_registered = True
|
record.is_registered = False
|
||||||
record.register_time = record.last_used_time = datetime.now()
|
record.register_time = None
|
||||||
|
record.last_used_time = datetime.now()
|
||||||
session.add(record)
|
session.add(record)
|
||||||
session.flush() # 确保记录被写入数据库以获取ID
|
session.flush() # 确保记录被写入数据库以获取ID
|
||||||
record_id = record.id
|
record_id = record.id
|
||||||
@@ -209,6 +222,7 @@ class ImageManager:
|
|||||||
if not record:
|
if not record:
|
||||||
logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法更新描述")
|
logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法更新描述")
|
||||||
return False
|
return False
|
||||||
|
self._normalize_image_registration_fields(record)
|
||||||
record.description = image.description
|
record.description = image.description
|
||||||
record.last_used_time = datetime.now()
|
record.last_used_time = datetime.now()
|
||||||
record.vlm_processed = image.vlm_processed
|
record.vlm_processed = image.vlm_processed
|
||||||
@@ -258,6 +272,7 @@ class ImageManager:
|
|||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
|
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
|
||||||
if record := session.exec(statement).first():
|
if record := session.exec(statement).first():
|
||||||
|
self._normalize_image_registration_fields(record)
|
||||||
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
|
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
|
||||||
record.last_used_time = datetime.now()
|
record.last_used_time = datetime.now()
|
||||||
record.query_count += 1
|
record.query_count += 1
|
||||||
@@ -335,6 +350,24 @@ class ImageManager:
|
|||||||
|
|
||||||
logger.info(f"清理完成: {invalid_counter} 条无效描述记录,{null_path_counter} 条文件路径不存在记录")
|
logger.info(f"清理完成: {invalid_counter} 条无效描述记录,{null_path_counter} 条文件路径不存在记录")
|
||||||
|
|
||||||
|
def cleanup_legacy_image_registration_records(self) -> None:
|
||||||
|
"""Clean up legacy image records with mistaken registration fields."""
|
||||||
|
fixed_counter = 0
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
statement = select(Images).filter_by(image_type=ImageType.IMAGE)
|
||||||
|
for record in session.exec(statement).yield_per(100):
|
||||||
|
if not self._normalize_image_registration_fields(record):
|
||||||
|
continue
|
||||||
|
session.add(record)
|
||||||
|
fixed_counter += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to clean image registration state: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if fixed_counter:
|
||||||
|
logger.info(f"Cleaned mistaken registration state on {fixed_counter} image records")
|
||||||
|
|
||||||
async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str:
|
async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str:
|
||||||
prompt = global_config.personality.visual_style
|
prompt = global_config.personality.visual_style
|
||||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from src.common.database.database_model import ActionRecord
|
from src.common.database.database_model import ToolRecord
|
||||||
|
|
||||||
from . import BaseDatabaseDataModel
|
from . import BaseDatabaseDataModel
|
||||||
|
|
||||||
|
|
||||||
class MaiActionRecord(BaseDatabaseDataModel[ActionRecord]):
|
class MaiActionRecord(BaseDatabaseDataModel[ToolRecord]):
|
||||||
|
"""``action_records`` 的兼容数据模型。
|
||||||
|
|
||||||
|
历史动作记录已统一并入 ``tool_records``,该类仅保留旧命名接口,
|
||||||
|
底层读写对象统一映射为 ``ToolRecord``。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
action_id: str,
|
action_id: str,
|
||||||
@@ -21,45 +27,39 @@ class MaiActionRecord(BaseDatabaseDataModel[ActionRecord]):
|
|||||||
action_display_prompt: Optional[str] = None,
|
action_display_prompt: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.action_id = action_id
|
self.action_id = action_id
|
||||||
"""动作ID"""
|
|
||||||
self.timestamp = timestamp
|
self.timestamp = timestamp
|
||||||
"""时间戳"""
|
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
"""会话ID"""
|
|
||||||
self.action_name = action_name
|
self.action_name = action_name
|
||||||
"""动作名称"""
|
|
||||||
self.action_reasoning = action_reasoning
|
self.action_reasoning = action_reasoning
|
||||||
"""动作推理过程"""
|
|
||||||
self.action_data = action_data or {}
|
self.action_data = action_data or {}
|
||||||
"""动作数据"""
|
|
||||||
self.action_builtin_prompt = action_builtin_prompt
|
self.action_builtin_prompt = action_builtin_prompt
|
||||||
"""内置动作提示"""
|
|
||||||
self.action_display_prompt = action_display_prompt
|
self.action_display_prompt = action_display_prompt
|
||||||
"""最终输入到 Prompt 的内容"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db_instance(cls, db_record: ActionRecord):
|
def from_db_instance(cls, db_record: ToolRecord):
|
||||||
"""Create a data model object from a database record."""
|
"""从数据库实例创建兼容数据模型对象。"""
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
action_id=db_record.action_id,
|
action_id=db_record.tool_id,
|
||||||
timestamp=db_record.timestamp,
|
timestamp=db_record.timestamp,
|
||||||
session_id=db_record.session_id,
|
session_id=db_record.session_id,
|
||||||
action_name=db_record.action_name,
|
action_name=db_record.tool_name,
|
||||||
action_reasoning=db_record.action_reasoning,
|
action_reasoning=db_record.tool_reasoning,
|
||||||
action_data=json.loads(db_record.action_data) if db_record.action_data else None,
|
action_data=json.loads(db_record.tool_data) if db_record.tool_data else None,
|
||||||
action_builtin_prompt=db_record.action_builtin_prompt,
|
action_builtin_prompt=db_record.tool_builtin_prompt,
|
||||||
action_display_prompt=db_record.action_display_prompt,
|
action_display_prompt=db_record.tool_display_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_db_instance(self):
|
def to_db_instance(self):
|
||||||
"""Convert the data model object back to a database instance."""
|
"""将兼容数据模型对象转换为 ``ToolRecord``。"""
|
||||||
return ActionRecord(
|
|
||||||
action_id=self.action_id,
|
return ToolRecord(
|
||||||
|
tool_id=self.action_id,
|
||||||
timestamp=self.timestamp,
|
timestamp=self.timestamp,
|
||||||
session_id=self.session_id,
|
session_id=self.session_id,
|
||||||
action_name=self.action_name,
|
tool_name=self.action_name,
|
||||||
action_reasoning=self.action_reasoning,
|
tool_reasoning=self.action_reasoning,
|
||||||
action_data=json.dumps(self.action_data) if self.action_data else None,
|
tool_data=json.dumps(self.action_data) if self.action_data else None,
|
||||||
action_builtin_prompt=self.action_builtin_prompt,
|
tool_builtin_prompt=self.action_builtin_prompt,
|
||||||
action_display_prompt=self.action_display_prompt,
|
tool_display_prompt=self.action_display_prompt,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image as PILImage
|
|
||||||
from rich.traceback import install
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
import traceback
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.common.database.database_model import Images, ImageType
|
from src.common.database.database_model import Images, ImageType
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from . import BaseDatabaseDataModel
|
from . import BaseDatabaseDataModel
|
||||||
|
|
||||||
|
|
||||||
@@ -152,7 +153,7 @@ class MaiEmoji(BaseImageDataModel):
|
|||||||
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
|
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
|
||||||
obj = cls(db_record.full_path)
|
obj = cls(db_record.full_path)
|
||||||
obj.file_hash = db_record.image_hash
|
obj.file_hash = db_record.image_hash
|
||||||
description = db_record.description or db_record.emotion or ""
|
description = db_record.description or ""
|
||||||
obj.description = description
|
obj.description = description
|
||||||
normalized_tags = [
|
normalized_tags = [
|
||||||
str(item).strip()
|
str(item).strip()
|
||||||
@@ -175,7 +176,6 @@ class MaiEmoji(BaseImageDataModel):
|
|||||||
description=self.description,
|
description=self.description,
|
||||||
full_path=str(self.full_path),
|
full_path=str(self.full_path),
|
||||||
image_type=ImageType.EMOJI,
|
image_type=ImageType.EMOJI,
|
||||||
emotion=None,
|
|
||||||
query_count=self.query_count,
|
query_count=self.query_count,
|
||||||
last_used_time=self.last_used_time,
|
last_used_time=self.last_used_time,
|
||||||
register_time=self.register_time,
|
register_time=self.register_time,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
|||||||
from typing import ContextManager, Generator, TYPE_CHECKING
|
from typing import ContextManager, Generator, TYPE_CHECKING
|
||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from sqlalchemy import event, text
|
from sqlalchemy import event
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from sqlmodel import SQLModel, Session, create_engine
|
from sqlmodel import SQLModel, Session, create_engine
|
||||||
@@ -63,41 +63,6 @@ _migration_bootstrapper = create_database_migration_bootstrapper(engine)
|
|||||||
_db_initialized = False
|
_db_initialized = False
|
||||||
|
|
||||||
|
|
||||||
def _migrate_action_records_to_tool_records() -> None:
|
|
||||||
"""将旧的 ``action_records`` 历史数据迁移到 ``tool_records``。"""
|
|
||||||
migration_sql = text(
|
|
||||||
"""
|
|
||||||
INSERT INTO tool_records (
|
|
||||||
tool_id,
|
|
||||||
timestamp,
|
|
||||||
session_id,
|
|
||||||
tool_name,
|
|
||||||
tool_reasoning,
|
|
||||||
tool_data,
|
|
||||||
tool_builtin_prompt,
|
|
||||||
tool_display_prompt
|
|
||||||
)
|
|
||||||
SELECT
|
|
||||||
action_id,
|
|
||||||
timestamp,
|
|
||||||
session_id,
|
|
||||||
action_name,
|
|
||||||
action_reasoning,
|
|
||||||
action_data,
|
|
||||||
action_builtin_prompt,
|
|
||||||
action_display_prompt
|
|
||||||
FROM action_records
|
|
||||||
WHERE NOT EXISTS (
|
|
||||||
SELECT 1
|
|
||||||
FROM tool_records
|
|
||||||
WHERE tool_records.tool_id = action_records.action_id
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
with engine.begin() as connection:
|
|
||||||
connection.execute(migration_sql)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_database() -> None:
|
def initialize_database() -> None:
|
||||||
"""初始化数据库连接、结构与启动期迁移。
|
"""初始化数据库连接、结构与启动期迁移。
|
||||||
|
|
||||||
@@ -105,8 +70,7 @@ def initialize_database() -> None:
|
|||||||
1. 确保数据库目录存在;
|
1. 确保数据库目录存在;
|
||||||
2. 加载 SQLModel 模型定义;
|
2. 加载 SQLModel 模型定义;
|
||||||
3. 执行已注册的启动期迁移;
|
3. 执行已注册的启动期迁移;
|
||||||
4. 兜底执行 ``create_all`` 确保当前模型定义已建表;
|
4. 兜底执行 ``create_all`` 确保当前模型定义已建表。
|
||||||
5. 执行项目现有的轻量数据补迁移逻辑。
|
|
||||||
"""
|
"""
|
||||||
global _db_initialized
|
global _db_initialized
|
||||||
if _db_initialized:
|
if _db_initialized:
|
||||||
@@ -120,7 +84,6 @@ def initialize_database() -> None:
|
|||||||
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
|
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
|
||||||
)
|
)
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
_migrate_action_records_to_tool_records()
|
|
||||||
_migration_bootstrapper.finalize_database(migration_state)
|
_migration_bootstrapper.finalize_database(migration_state)
|
||||||
_db_initialized = True
|
_db_initialized = True
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,6 @@ class Images(SQLModel, table=True):
|
|||||||
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
|
full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名)
|
||||||
image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI)
|
image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI)
|
||||||
"""图片类型,例如 'emoji' 或 'image'"""
|
"""图片类型,例如 'emoji' 或 'image'"""
|
||||||
emotion: Optional[str] = Field(default=None, nullable=True) # 表情包的情感标签,逗号分隔
|
|
||||||
|
|
||||||
query_count: int = Field(default=0) # 被查询次数
|
query_count: int = Field(default=0) # 被查询次数
|
||||||
is_registered: bool = Field(default=False) # 是否已经注册
|
is_registered: bool = Field(default=False) # 是否已经注册
|
||||||
@@ -113,27 +112,6 @@ class Images(SQLModel, table=True):
|
|||||||
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
|
vlm_processed: bool = Field(default=False) # 是否已经过VLM处理
|
||||||
|
|
||||||
|
|
||||||
class ActionRecord(SQLModel, table=True):
|
|
||||||
"""存储动作记录"""
|
|
||||||
|
|
||||||
__tablename__ = "action_records" # type: ignore
|
|
||||||
|
|
||||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
|
||||||
|
|
||||||
# 元信息
|
|
||||||
action_id: str = Field(index=True, max_length=255) # 动作ID
|
|
||||||
timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间戳
|
|
||||||
session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id
|
|
||||||
|
|
||||||
# 调用信息
|
|
||||||
action_name: str = Field(index=True, max_length=255) # 动作名称
|
|
||||||
action_reasoning: Optional[str] = Field(default=None) # 动作推理过程
|
|
||||||
action_data: Optional[str] = Field(default=None) # 动作数据,JSON格式存储
|
|
||||||
|
|
||||||
action_builtin_prompt: Optional[str] = Field(default=None) # 内置动作提示
|
|
||||||
action_display_prompt: Optional[str] = Field(default=None) # 最终输入到Prompt的内容
|
|
||||||
|
|
||||||
|
|
||||||
class ToolRecord(SQLModel, table=True):
|
class ToolRecord(SQLModel, table=True):
|
||||||
"""存储工具调用记录"""
|
"""存储工具调用记录"""
|
||||||
|
|
||||||
@@ -281,28 +259,6 @@ class ChatHistory(SQLModel, table=True):
|
|||||||
summary: str # 概括:对这段话的平文本概括
|
summary: str # 概括:对这段话的平文本概括
|
||||||
|
|
||||||
|
|
||||||
class ThinkingQuestion(SQLModel, table=True):
|
|
||||||
"""存储思考型问题的模型"""
|
|
||||||
|
|
||||||
__tablename__ = "thinking_questions" # type: ignore
|
|
||||||
|
|
||||||
id: Optional[int] = Field(default=None, primary_key=True) # 自增主键
|
|
||||||
|
|
||||||
# 问答对
|
|
||||||
question: str # 问题内容
|
|
||||||
context: Optional[str] = Field(default=None, nullable=True) # 上下文
|
|
||||||
found_answer: bool = Field(default=False) # 是否找到答案
|
|
||||||
answer: Optional[str] = Field(default=None, nullable=True) # 问题答案
|
|
||||||
|
|
||||||
thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储
|
|
||||||
created_timestamp: datetime = Field(
|
|
||||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
|
||||||
) # 创建时间
|
|
||||||
updated_timestamp: datetime = Field(
|
|
||||||
default_factory=datetime.now, sa_column=Column(DateTime, index=True)
|
|
||||||
) # 最后更新时间
|
|
||||||
|
|
||||||
|
|
||||||
class BinaryData(SQLModel, table=True):
|
class BinaryData(SQLModel, table=True):
|
||||||
"""存储二进制数据的模型"""
|
"""存储二进制数据的模型"""
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from .builtin import (
|
|||||||
EMPTY_SCHEMA_VERSION,
|
EMPTY_SCHEMA_VERSION,
|
||||||
LATEST_SCHEMA_VERSION,
|
LATEST_SCHEMA_VERSION,
|
||||||
LEGACY_V1_SCHEMA_VERSION,
|
LEGACY_V1_SCHEMA_VERSION,
|
||||||
|
V2_SCHEMA_VERSION,
|
||||||
build_default_migration_registry,
|
build_default_migration_registry,
|
||||||
build_default_schema_version_resolver,
|
build_default_schema_version_resolver,
|
||||||
)
|
)
|
||||||
@@ -61,6 +62,7 @@ __all__ = [
|
|||||||
"EMPTY_SCHEMA_VERSION",
|
"EMPTY_SCHEMA_VERSION",
|
||||||
"LATEST_SCHEMA_VERSION",
|
"LATEST_SCHEMA_VERSION",
|
||||||
"LEGACY_V1_SCHEMA_VERSION",
|
"LEGACY_V1_SCHEMA_VERSION",
|
||||||
|
"V2_SCHEMA_VERSION",
|
||||||
"MigrationExecutionContext",
|
"MigrationExecutionContext",
|
||||||
"MigrationPlan",
|
"MigrationPlan",
|
||||||
"MigrationPlanner",
|
"MigrationPlanner",
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ from .legacy_v1_to_v2 import migrate_legacy_v1_to_v2
|
|||||||
from .models import DatabaseSchemaSnapshot, MigrationStep
|
from .models import DatabaseSchemaSnapshot, MigrationStep
|
||||||
from .registry import MigrationRegistry
|
from .registry import MigrationRegistry
|
||||||
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
||||||
from .version_store import SQLiteUserVersionStore
|
|
||||||
from .schema import SQLiteSchemaInspector
|
from .schema import SQLiteSchemaInspector
|
||||||
|
from .v2_to_v3 import migrate_v2_to_v3
|
||||||
|
from .version_store import SQLiteUserVersionStore
|
||||||
|
|
||||||
EMPTY_SCHEMA_VERSION = 0
|
EMPTY_SCHEMA_VERSION = 0
|
||||||
LEGACY_V1_SCHEMA_VERSION = 1
|
LEGACY_V1_SCHEMA_VERSION = 1
|
||||||
LATEST_SCHEMA_VERSION = 2
|
V2_SCHEMA_VERSION = 2
|
||||||
|
LATEST_SCHEMA_VERSION = 3
|
||||||
|
|
||||||
_LEGACY_V1_EXCLUSIVE_TABLES = (
|
_LEGACY_V1_EXCLUSIVE_TABLES = (
|
||||||
"chat_streams",
|
"chat_streams",
|
||||||
@@ -24,6 +26,13 @@ _LEGACY_V1_EXCLUSIVE_TABLES = (
|
|||||||
"messages",
|
"messages",
|
||||||
"thinking_back",
|
"thinking_back",
|
||||||
)
|
)
|
||||||
|
_COMMON_MARKER_TABLES = (
|
||||||
|
"mai_messages",
|
||||||
|
"chat_sessions",
|
||||||
|
"expressions",
|
||||||
|
"jargons",
|
||||||
|
"tool_records",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
||||||
@@ -36,6 +45,7 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 当前探测器名称。
|
str: 当前探测器名称。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return "latest_schema_detector"
|
return "latest_schema_detector"
|
||||||
|
|
||||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||||
@@ -47,18 +57,16 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。
|
Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||||
return None
|
return None
|
||||||
|
if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES):
|
||||||
latest_marker_tables = (
|
return None
|
||||||
"mai_messages",
|
if snapshot.has_table("action_records"):
|
||||||
"chat_sessions",
|
return None
|
||||||
"expressions",
|
if snapshot.has_table("thinking_questions"):
|
||||||
"jargons",
|
return None
|
||||||
"thinking_questions",
|
if snapshot.has_column("images", "emotion"):
|
||||||
"tool_records",
|
|
||||||
)
|
|
||||||
if not all(snapshot.has_table(table_name) for table_name in latest_marker_tables):
|
|
||||||
return None
|
return None
|
||||||
if not snapshot.has_column("images", "image_hash"):
|
if not snapshot.has_column("images", "image_hash"):
|
||||||
return None
|
return None
|
||||||
@@ -66,13 +74,53 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
|||||||
return None
|
return None
|
||||||
if not snapshot.has_column("images", "image_type"):
|
if not snapshot.has_column("images", "image_type"):
|
||||||
return None
|
return None
|
||||||
|
if not snapshot.has_column("chat_history", "session_id"):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_column("person_info", "user_nickname"):
|
||||||
|
return None
|
||||||
|
return LATEST_SCHEMA_VERSION
|
||||||
|
|
||||||
|
|
||||||
|
class V2SchemaVersionDetector(BaseSchemaVersionDetector):
|
||||||
|
"""v2 schema 结构探测器。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""返回探测器名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 当前探测器名称。
|
||||||
|
"""
|
||||||
|
|
||||||
|
return "v2_schema_detector"
|
||||||
|
|
||||||
|
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||||
|
"""检测数据库是否为 v2 结构。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snapshot: 当前数据库结构快照。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[int]: 若识别为 v2 结构则返回 ``2``,否则返回 ``None``。
|
||||||
|
"""
|
||||||
|
|
||||||
|
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||||
|
return None
|
||||||
|
if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_table("action_records"):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_table("thinking_questions"):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_column("images", "emotion"):
|
||||||
|
return None
|
||||||
if not snapshot.has_column("action_records", "session_id"):
|
if not snapshot.has_column("action_records", "session_id"):
|
||||||
return None
|
return None
|
||||||
if not snapshot.has_column("chat_history", "session_id"):
|
if not snapshot.has_column("chat_history", "session_id"):
|
||||||
return None
|
return None
|
||||||
if not snapshot.has_column("person_info", "user_nickname"):
|
if not snapshot.has_column("person_info", "user_nickname"):
|
||||||
return None
|
return None
|
||||||
return LATEST_SCHEMA_VERSION
|
return V2_SCHEMA_VERSION
|
||||||
|
|
||||||
|
|
||||||
class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
||||||
@@ -85,6 +133,7 @@ class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 当前探测器名称。
|
str: 当前探测器名称。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return "legacy_v1_schema_detector"
|
return "legacy_v1_schema_detector"
|
||||||
|
|
||||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||||
@@ -96,6 +145,7 @@ class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。
|
Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||||
return LEGACY_V1_SCHEMA_VERSION
|
return LEGACY_V1_SCHEMA_VERSION
|
||||||
|
|
||||||
@@ -121,8 +171,10 @@ def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
|
|||||||
Returns:
|
Returns:
|
||||||
List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。
|
List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return [
|
return [
|
||||||
LatestSchemaVersionDetector(),
|
LatestSchemaVersionDetector(),
|
||||||
|
V2SchemaVersionDetector(),
|
||||||
LegacyV1SchemaDetector(),
|
LegacyV1SchemaDetector(),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -133,6 +185,7 @@ def build_default_schema_version_resolver() -> SchemaVersionResolver:
|
|||||||
Returns:
|
Returns:
|
||||||
SchemaVersionResolver: 配置完成的 schema 版本解析器。
|
SchemaVersionResolver: 配置完成的 schema 版本解析器。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return SchemaVersionResolver(
|
return SchemaVersionResolver(
|
||||||
version_store=SQLiteUserVersionStore(),
|
version_store=SQLiteUserVersionStore(),
|
||||||
schema_inspector=SQLiteSchemaInspector(),
|
schema_inspector=SQLiteSchemaInspector(),
|
||||||
@@ -146,14 +199,22 @@ def build_default_migration_registry() -> MigrationRegistry:
|
|||||||
Returns:
|
Returns:
|
||||||
MigrationRegistry: 含默认迁移步骤的注册表实例。
|
MigrationRegistry: 含默认迁移步骤的注册表实例。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return MigrationRegistry(
|
return MigrationRegistry(
|
||||||
steps=[
|
steps=[
|
||||||
MigrationStep(
|
MigrationStep(
|
||||||
version_from=LEGACY_V1_SCHEMA_VERSION,
|
version_from=LEGACY_V1_SCHEMA_VERSION,
|
||||||
version_to=LATEST_SCHEMA_VERSION,
|
version_to=V2_SCHEMA_VERSION,
|
||||||
name="legacy_v1_to_latest_v2",
|
name="legacy_v1_to_v2",
|
||||||
description="将旧版 0.x 数据库整体迁移到当前最新 schema。",
|
description="将旧版 0.x 数据库迁移到 v2 schema。",
|
||||||
handler=migrate_legacy_v1_to_v2,
|
handler=migrate_legacy_v1_to_v2,
|
||||||
)
|
),
|
||||||
|
MigrationStep(
|
||||||
|
version_from=V2_SCHEMA_VERSION,
|
||||||
|
version_to=LATEST_SCHEMA_VERSION,
|
||||||
|
name="v2_to_v3",
|
||||||
|
description="移除废弃表,并将 emoji 标签统一收敛到 description 字段。",
|
||||||
|
handler=migrate_v2_to_v3,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
298
src/common/database/migrations/frozen_v2_schema.py
Normal file
298
src/common/database/migrations/frozen_v2_schema.py
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
"""冻结的 v2 schema 快照。
|
||||||
|
|
||||||
|
该模块只用于 ``legacy_v1_to_v2`` 迁移,避免迁移过程依赖当前运行时代码中的
|
||||||
|
最新 SQLModel 定义,导致历史迁移随着后续 schema 演进而失真。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
|
||||||
|
_V2_TABLE_STATEMENTS = (
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS action_records (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
action_id VARCHAR(255) NOT NULL,
|
||||||
|
timestamp DATETIME,
|
||||||
|
session_id VARCHAR(255) NOT NULL,
|
||||||
|
action_name VARCHAR(255) NOT NULL,
|
||||||
|
action_reasoning VARCHAR,
|
||||||
|
action_data VARCHAR,
|
||||||
|
action_builtin_prompt VARCHAR,
|
||||||
|
action_display_prompt VARCHAR,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS binary_data (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
data_hash VARCHAR(255) NOT NULL,
|
||||||
|
full_path VARCHAR(1024) NOT NULL,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS chat_history (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
session_id VARCHAR(255) NOT NULL,
|
||||||
|
start_timestamp DATETIME,
|
||||||
|
end_timestamp DATETIME,
|
||||||
|
query_count INTEGER NOT NULL,
|
||||||
|
query_forget_count INTEGER NOT NULL,
|
||||||
|
original_messages VARCHAR NOT NULL,
|
||||||
|
participants VARCHAR NOT NULL,
|
||||||
|
theme VARCHAR NOT NULL,
|
||||||
|
keywords VARCHAR NOT NULL,
|
||||||
|
summary VARCHAR NOT NULL,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS chat_sessions (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
session_id VARCHAR(255) NOT NULL,
|
||||||
|
created_timestamp DATETIME,
|
||||||
|
last_active_timestamp DATETIME,
|
||||||
|
user_id VARCHAR(255),
|
||||||
|
group_id VARCHAR(255),
|
||||||
|
platform VARCHAR(100) NOT NULL,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS command_records (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
timestamp DATETIME,
|
||||||
|
session_id VARCHAR(255) NOT NULL,
|
||||||
|
command_name VARCHAR(255) NOT NULL,
|
||||||
|
command_data VARCHAR,
|
||||||
|
command_result VARCHAR,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS expressions (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
situation VARCHAR(255) NOT NULL,
|
||||||
|
style VARCHAR(255) NOT NULL,
|
||||||
|
content_list VARCHAR NOT NULL,
|
||||||
|
count INTEGER NOT NULL,
|
||||||
|
last_active_time DATETIME,
|
||||||
|
create_time DATETIME,
|
||||||
|
session_id VARCHAR(255),
|
||||||
|
checked BOOLEAN NOT NULL,
|
||||||
|
rejected BOOLEAN NOT NULL,
|
||||||
|
modified_by VARCHAR(4),
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS images (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
image_hash VARCHAR(255) NOT NULL,
|
||||||
|
description VARCHAR NOT NULL,
|
||||||
|
full_path VARCHAR(1024) NOT NULL,
|
||||||
|
image_type VARCHAR(5),
|
||||||
|
emotion VARCHAR,
|
||||||
|
query_count INTEGER NOT NULL,
|
||||||
|
is_registered BOOLEAN NOT NULL,
|
||||||
|
is_banned BOOLEAN NOT NULL,
|
||||||
|
no_file_flag BOOLEAN NOT NULL,
|
||||||
|
record_time DATETIME,
|
||||||
|
register_time DATETIME,
|
||||||
|
last_used_time DATETIME,
|
||||||
|
vlm_processed BOOLEAN NOT NULL,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS jargons (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||||
|
content VARCHAR(255) NOT NULL,
|
||||||
|
raw_content TEXT,
|
||||||
|
meaning TEXT NOT NULL,
|
||||||
|
session_id_dict TEXT NOT NULL,
|
||||||
|
count INTEGER NOT NULL,
|
||||||
|
is_jargon BOOLEAN,
|
||||||
|
is_complete BOOLEAN NOT NULL,
|
||||||
|
is_global BOOLEAN NOT NULL,
|
||||||
|
last_inference_count INTEGER NOT NULL,
|
||||||
|
inference_with_context TEXT,
|
||||||
|
inference_with_content_only TEXT
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS llm_usage (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
model_name VARCHAR(255) NOT NULL,
|
||||||
|
model_assign_name VARCHAR(255),
|
||||||
|
model_api_provider_name VARCHAR(255) NOT NULL,
|
||||||
|
endpoint VARCHAR(255),
|
||||||
|
user_type VARCHAR(6),
|
||||||
|
request_type VARCHAR(50) NOT NULL,
|
||||||
|
time_cost FLOAT,
|
||||||
|
timestamp DATETIME,
|
||||||
|
prompt_tokens INTEGER NOT NULL,
|
||||||
|
completion_tokens INTEGER NOT NULL,
|
||||||
|
total_tokens INTEGER NOT NULL,
|
||||||
|
cost FLOAT NOT NULL,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS mai_knowledge (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
knowledge_id VARCHAR(255) NOT NULL,
|
||||||
|
category_id VARCHAR(32) NOT NULL,
|
||||||
|
content VARCHAR NOT NULL,
|
||||||
|
normalized_content VARCHAR NOT NULL,
|
||||||
|
metadata_json VARCHAR,
|
||||||
|
created_at DATETIME,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS mai_messages (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
message_id VARCHAR(255) NOT NULL,
|
||||||
|
timestamp DATETIME,
|
||||||
|
platform VARCHAR(100) NOT NULL,
|
||||||
|
user_id VARCHAR(255) NOT NULL,
|
||||||
|
user_nickname VARCHAR(255) NOT NULL,
|
||||||
|
user_cardname VARCHAR(255),
|
||||||
|
group_id VARCHAR(255),
|
||||||
|
group_name VARCHAR(255),
|
||||||
|
is_mentioned BOOLEAN NOT NULL,
|
||||||
|
is_at BOOLEAN NOT NULL,
|
||||||
|
session_id VARCHAR(255) NOT NULL,
|
||||||
|
reply_to VARCHAR(255),
|
||||||
|
is_emoji BOOLEAN NOT NULL,
|
||||||
|
is_picture BOOLEAN NOT NULL,
|
||||||
|
is_command BOOLEAN NOT NULL,
|
||||||
|
is_notify BOOLEAN NOT NULL,
|
||||||
|
raw_content BLOB,
|
||||||
|
processed_plain_text VARCHAR,
|
||||||
|
display_message VARCHAR,
|
||||||
|
additional_config VARCHAR,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS online_time (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
timestamp DATETIME,
|
||||||
|
duration_minutes INTEGER NOT NULL,
|
||||||
|
start_timestamp DATETIME,
|
||||||
|
end_timestamp DATETIME,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS person_info (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
is_known BOOLEAN NOT NULL,
|
||||||
|
person_id VARCHAR(255) NOT NULL,
|
||||||
|
person_name VARCHAR(255),
|
||||||
|
name_reason VARCHAR,
|
||||||
|
platform VARCHAR(100) NOT NULL,
|
||||||
|
user_id VARCHAR(255) NOT NULL,
|
||||||
|
user_nickname VARCHAR(255) NOT NULL,
|
||||||
|
group_cardname VARCHAR,
|
||||||
|
memory_points VARCHAR,
|
||||||
|
know_counts INTEGER NOT NULL,
|
||||||
|
first_known_time DATETIME,
|
||||||
|
last_known_time DATETIME,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS thinking_questions (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
question VARCHAR NOT NULL,
|
||||||
|
context VARCHAR,
|
||||||
|
found_answer BOOLEAN NOT NULL,
|
||||||
|
answer VARCHAR,
|
||||||
|
thinking_steps VARCHAR,
|
||||||
|
created_timestamp DATETIME,
|
||||||
|
updated_timestamp DATETIME,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS tool_records (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
tool_id VARCHAR(255) NOT NULL,
|
||||||
|
timestamp DATETIME,
|
||||||
|
session_id VARCHAR(255) NOT NULL,
|
||||||
|
tool_name VARCHAR(255) NOT NULL,
|
||||||
|
tool_reasoning VARCHAR,
|
||||||
|
tool_data VARCHAR,
|
||||||
|
tool_builtin_prompt VARCHAR,
|
||||||
|
tool_display_prompt VARCHAR,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
_V2_INDEX_STATEMENTS = (
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_action_records_action_id ON action_records (action_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_action_records_action_name ON action_records (action_name)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_action_records_session_id ON action_records (session_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_action_records_timestamp ON action_records (timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_binary_data_data_hash ON binary_data (data_hash)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_chat_history_end_timestamp ON chat_history (end_timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_chat_history_session_id ON chat_history (session_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_chat_history_start_timestamp ON chat_history (start_timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_created_timestamp ON chat_sessions (created_timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_group_id ON chat_sessions (group_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_last_active_timestamp ON chat_sessions (last_active_timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_platform ON chat_sessions (platform)",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS ix_chat_sessions_session_id ON chat_sessions (session_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_chat_sessions_user_id ON chat_sessions (user_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_command_records_command_name ON command_records (command_name)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_command_records_session_id ON command_records (session_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_command_records_timestamp ON command_records (timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_expressions_last_active_time ON expressions (last_active_time)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_expressions_situation ON expressions (situation)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_expressions_style ON expressions (style)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_images_image_hash ON images (image_hash)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_images_record_time ON images (record_time)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_jargons_content ON jargons (content)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_api_provider_name ON llm_usage (model_api_provider_name)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_assign_name ON llm_usage (model_assign_name)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_name ON llm_usage (model_name)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_llm_usage_timestamp ON llm_usage (timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_category_id ON mai_knowledge (category_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_created_at ON mai_knowledge (created_at)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_knowledge_id ON mai_knowledge (knowledge_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_normalized_content ON mai_knowledge (normalized_content)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_messages_group_id ON mai_messages (group_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_messages_message_id ON mai_messages (message_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_messages_platform ON mai_messages (platform)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_messages_session_id ON mai_messages (session_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_messages_user_id ON mai_messages (user_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_mai_messages_user_nickname ON mai_messages (user_nickname)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_online_time_timestamp ON online_time (timestamp)",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS ix_person_info_person_id ON person_info (person_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_person_info_platform ON person_info (platform)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_person_info_user_id ON person_info (user_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_person_info_user_nickname ON person_info (user_nickname)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_thinking_questions_created_timestamp ON thinking_questions (created_timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_thinking_questions_updated_timestamp ON thinking_questions (updated_timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_tool_records_session_id ON tool_records (session_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_tool_records_timestamp ON tool_records (timestamp)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_tool_records_tool_id ON tool_records (tool_id)",
|
||||||
|
"CREATE INDEX IF NOT EXISTS ix_tool_records_tool_name ON tool_records (tool_name)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_frozen_v2_schema(connection: Connection) -> None:
|
||||||
|
"""创建冻结的 v2 schema。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
"""
|
||||||
|
|
||||||
|
for statement in _V2_TABLE_STATEMENTS:
|
||||||
|
connection.exec_driver_sql(statement)
|
||||||
|
|
||||||
|
for statement in _V2_INDEX_STATEMENTS:
|
||||||
|
connection.exec_driver_sql(statement)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""旧版 ``0.x`` 数据库升级到最新 schema 的迁移逻辑。"""
|
"""旧版 ``0.x`` 数据库升级到 v2 schema 的迁移逻辑。"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -7,15 +7,16 @@ from dataclasses import dataclass
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import msgpack
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.engine import Connection
|
from sqlalchemy.engine import Connection
|
||||||
|
|
||||||
import json
|
|
||||||
import msgpack
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from .exceptions import DatabaseMigrationExecutionError
|
from .exceptions import DatabaseMigrationExecutionError
|
||||||
|
from .frozen_v2_schema import create_frozen_v2_schema
|
||||||
from .models import DatabaseSchemaSnapshot, MigrationExecutionContext
|
from .models import DatabaseSchemaSnapshot, MigrationExecutionContext
|
||||||
from .schema import SQLiteSchemaInspector
|
from .schema import SQLiteSchemaInspector
|
||||||
|
|
||||||
@@ -52,19 +53,15 @@ class LegacyTableData:
|
|||||||
|
|
||||||
|
|
||||||
def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None:
|
def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None:
|
||||||
"""执行旧版 ``0.x`` 数据库到最新 schema 的迁移。
|
"""执行旧版 ``0.x`` 数据库到 v2 schema 的迁移。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: 当前迁移步骤执行上下文。
|
context: 当前迁移步骤执行上下文。
|
||||||
"""
|
"""
|
||||||
from sqlmodel import SQLModel
|
|
||||||
|
|
||||||
import src.common.database.database_model # noqa: F401
|
|
||||||
|
|
||||||
schema_inspector = SQLiteSchemaInspector()
|
schema_inspector = SQLiteSchemaInspector()
|
||||||
snapshot = schema_inspector.inspect(context.connection)
|
snapshot = schema_inspector.inspect(context.connection)
|
||||||
_rename_legacy_v1_tables(context.connection, snapshot)
|
_rename_legacy_v1_tables(context.connection, snapshot)
|
||||||
SQLModel.metadata.create_all(context.connection)
|
create_frozen_v2_schema(context.connection)
|
||||||
|
|
||||||
table_migration_jobs: List[Tuple[str, Callable[[MigrationExecutionContext], int]]] = [
|
table_migration_jobs: List[Tuple[str, Callable[[MigrationExecutionContext], int]]] = [
|
||||||
("chat_sessions", _migrate_chat_sessions),
|
("chat_sessions", _migrate_chat_sessions),
|
||||||
@@ -794,8 +791,6 @@ def _migrate_images(context: MigrationExecutionContext) -> int:
|
|||||||
if full_path and dedupe_key not in existing_keys:
|
if full_path and dedupe_key not in existing_keys:
|
||||||
migrated_description = _normalize_required_text(row.get("description"))
|
migrated_description = _normalize_required_text(row.get("description"))
|
||||||
migrated_emotion = _normalize_optional_text(row.get("emotion"))
|
migrated_emotion = _normalize_optional_text(row.get("emotion"))
|
||||||
if not migrated_description and migrated_emotion:
|
|
||||||
migrated_description = migrated_emotion
|
|
||||||
connection.execute(
|
connection.execute(
|
||||||
insert_sql,
|
insert_sql,
|
||||||
{
|
{
|
||||||
@@ -803,7 +798,7 @@ def _migrate_images(context: MigrationExecutionContext) -> int:
|
|||||||
"description": migrated_description,
|
"description": migrated_description,
|
||||||
"full_path": full_path,
|
"full_path": full_path,
|
||||||
"image_type": "EMOJI",
|
"image_type": "EMOJI",
|
||||||
"emotion": None,
|
"emotion": migrated_emotion,
|
||||||
"query_count": _normalize_int(row.get("query_count"), default=0),
|
"query_count": _normalize_int(row.get("query_count"), default=0),
|
||||||
"is_registered": _normalize_bool(row.get("is_registered"), default=False),
|
"is_registered": _normalize_bool(row.get("is_registered"), default=False),
|
||||||
"is_banned": _normalize_bool(row.get("is_banned"), default=False),
|
"is_banned": _normalize_bool(row.get("is_banned"), default=False),
|
||||||
|
|||||||
269
src/common/database/migrations/v2_to_v3.py
Normal file
269
src/common/database/migrations/v2_to_v3.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""v2 schema 升级到 v3 的迁移逻辑。"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
from .exceptions import DatabaseMigrationExecutionError
|
||||||
|
from .models import MigrationExecutionContext
|
||||||
|
from .schema import SQLiteSchemaInspector
|
||||||
|
|
||||||
|
logger = get_logger("database_migration")
|
||||||
|
|
||||||
|
_V2_IMAGES_BACKUP_TABLE = "__v2_images_backup"
|
||||||
|
_V3_IMAGES_CREATE_SQL = """
|
||||||
|
CREATE TABLE images (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
image_hash VARCHAR(255) NOT NULL,
|
||||||
|
description VARCHAR NOT NULL,
|
||||||
|
full_path VARCHAR(1024) NOT NULL,
|
||||||
|
image_type VARCHAR(5),
|
||||||
|
query_count INTEGER NOT NULL,
|
||||||
|
is_registered BOOLEAN NOT NULL,
|
||||||
|
is_banned BOOLEAN NOT NULL,
|
||||||
|
no_file_flag BOOLEAN NOT NULL,
|
||||||
|
record_time DATETIME,
|
||||||
|
register_time DATETIME,
|
||||||
|
last_used_time DATETIME,
|
||||||
|
vlm_processed BOOLEAN NOT NULL,
|
||||||
|
PRIMARY KEY (id)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
_V3_IMAGES_INDEX_STATEMENTS = (
|
||||||
|
"CREATE INDEX ix_images_image_hash ON images (image_hash)",
|
||||||
|
"CREATE INDEX ix_images_record_time ON images (record_time)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_v2_to_v3(context: MigrationExecutionContext) -> None:
|
||||||
|
"""执行 v2 到 v3 的 schema 迁移。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 当前迁移步骤执行上下文。
|
||||||
|
"""
|
||||||
|
|
||||||
|
connection = context.connection
|
||||||
|
total_records = (
|
||||||
|
_count_table_rows(connection, "action_records")
|
||||||
|
+ _count_table_rows(connection, "thinking_questions")
|
||||||
|
+ _count_table_rows(connection, "images")
|
||||||
|
)
|
||||||
|
context.start_progress(
|
||||||
|
total_tables=3,
|
||||||
|
total_records=total_records,
|
||||||
|
description="v2 -> v3 迁移进度",
|
||||||
|
table_unit_name="表",
|
||||||
|
record_unit_name="记录",
|
||||||
|
)
|
||||||
|
|
||||||
|
migrated_tool_records = _migrate_action_records_to_tool_records(connection)
|
||||||
|
action_record_count = _count_table_rows(connection, "action_records")
|
||||||
|
_drop_table_if_exists(connection, "action_records")
|
||||||
|
context.advance_progress(
|
||||||
|
records=action_record_count,
|
||||||
|
completed_tables=1,
|
||||||
|
item_name="action_records",
|
||||||
|
)
|
||||||
|
|
||||||
|
thinking_question_count = _count_table_rows(connection, "thinking_questions")
|
||||||
|
_drop_table_if_exists(connection, "thinking_questions")
|
||||||
|
context.advance_progress(
|
||||||
|
records=thinking_question_count,
|
||||||
|
completed_tables=1,
|
||||||
|
item_name="thinking_questions",
|
||||||
|
)
|
||||||
|
|
||||||
|
migrated_image_rows = _migrate_images_table_to_v3(connection)
|
||||||
|
context.advance_progress(
|
||||||
|
records=migrated_image_rows,
|
||||||
|
completed_tables=1,
|
||||||
|
item_name="images",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"v2 -> v3 数据库迁移完成: "
|
||||||
|
f"tool_records补迁移={migrated_tool_records},"
|
||||||
|
f"images重建={migrated_image_rows}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_table_rows(connection: Connection, table_name: str) -> int:
|
||||||
|
"""统计表记录数,不存在时返回 0。"""
|
||||||
|
|
||||||
|
schema_inspector = SQLiteSchemaInspector()
|
||||||
|
if not schema_inspector.table_exists(connection, table_name):
|
||||||
|
return 0
|
||||||
|
row = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"')).first()
|
||||||
|
return int(row[0]) if row else 0
|
||||||
|
|
||||||
|
|
||||||
|
def _drop_table_if_exists(connection: Connection, table_name: str) -> None:
|
||||||
|
"""删除指定表,不存在时静默跳过。"""
|
||||||
|
|
||||||
|
connection.exec_driver_sql(f'DROP TABLE IF EXISTS "{table_name}"')
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_action_records_to_tool_records(connection: Connection) -> int:
|
||||||
|
"""把 v2 中残留的 ``action_records`` 数据转存到 ``tool_records``。"""
|
||||||
|
|
||||||
|
schema_inspector = SQLiteSchemaInspector()
|
||||||
|
if not schema_inspector.table_exists(connection, "action_records"):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
inserted_count = _count_table_rows(connection, "action_records")
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO tool_records (
|
||||||
|
tool_id,
|
||||||
|
timestamp,
|
||||||
|
session_id,
|
||||||
|
tool_name,
|
||||||
|
tool_reasoning,
|
||||||
|
tool_data,
|
||||||
|
tool_builtin_prompt,
|
||||||
|
tool_display_prompt
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
action_id,
|
||||||
|
timestamp,
|
||||||
|
session_id,
|
||||||
|
action_name,
|
||||||
|
action_reasoning,
|
||||||
|
action_data,
|
||||||
|
action_builtin_prompt,
|
||||||
|
action_display_prompt
|
||||||
|
FROM action_records
|
||||||
|
WHERE NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM tool_records
|
||||||
|
WHERE tool_records.tool_id = action_records.action_id
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return inserted_count
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_images_table_to_v3(connection: Connection) -> int:
|
||||||
|
"""重建 ``images`` 表并移除 ``emotion`` 列。"""
|
||||||
|
|
||||||
|
schema_inspector = SQLiteSchemaInspector()
|
||||||
|
if not schema_inspector.table_exists(connection, "images"):
|
||||||
|
return 0
|
||||||
|
if not schema_inspector.get_table_schema(connection, "images").has_column("emotion"):
|
||||||
|
return _count_table_rows(connection, "images")
|
||||||
|
if schema_inspector.table_exists(connection, _V2_IMAGES_BACKUP_TABLE):
|
||||||
|
raise DatabaseMigrationExecutionError(
|
||||||
|
f"检测到残留备份表 {_V2_IMAGES_BACKUP_TABLE},无法安全执行 v2 -> v3 images 迁移。"
|
||||||
|
)
|
||||||
|
|
||||||
|
connection.exec_driver_sql(f'ALTER TABLE "images" RENAME TO "{_V2_IMAGES_BACKUP_TABLE}"')
|
||||||
|
connection.exec_driver_sql(_V3_IMAGES_CREATE_SQL)
|
||||||
|
|
||||||
|
legacy_rows = connection.execute(
|
||||||
|
text(f'SELECT * FROM "{_V2_IMAGES_BACKUP_TABLE}" ORDER BY id')
|
||||||
|
).mappings().all()
|
||||||
|
insert_sql = text(
|
||||||
|
"""
|
||||||
|
INSERT INTO images (
|
||||||
|
id,
|
||||||
|
image_hash,
|
||||||
|
description,
|
||||||
|
full_path,
|
||||||
|
image_type,
|
||||||
|
query_count,
|
||||||
|
is_registered,
|
||||||
|
is_banned,
|
||||||
|
no_file_flag,
|
||||||
|
record_time,
|
||||||
|
register_time,
|
||||||
|
last_used_time,
|
||||||
|
vlm_processed
|
||||||
|
) VALUES (
|
||||||
|
:id,
|
||||||
|
:image_hash,
|
||||||
|
:description,
|
||||||
|
:full_path,
|
||||||
|
:image_type,
|
||||||
|
:query_count,
|
||||||
|
:is_registered,
|
||||||
|
:is_banned,
|
||||||
|
:no_file_flag,
|
||||||
|
:record_time,
|
||||||
|
:register_time,
|
||||||
|
:last_used_time,
|
||||||
|
:vlm_processed
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in legacy_rows:
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"id": row.get("id"),
|
||||||
|
"image_hash": str(row.get("image_hash") or "").strip(),
|
||||||
|
"description": _migrate_v3_emoji_description(row),
|
||||||
|
"full_path": str(row.get("full_path") or "").strip(),
|
||||||
|
"image_type": row.get("image_type"),
|
||||||
|
"query_count": int(row.get("query_count") or 0),
|
||||||
|
"is_registered": bool(row.get("is_registered")),
|
||||||
|
"is_banned": bool(row.get("is_banned")),
|
||||||
|
"no_file_flag": bool(row.get("no_file_flag")),
|
||||||
|
"record_time": row.get("record_time"),
|
||||||
|
"register_time": row.get("register_time"),
|
||||||
|
"last_used_time": row.get("last_used_time"),
|
||||||
|
"vlm_processed": bool(row.get("vlm_processed")),
|
||||||
|
}
|
||||||
|
connection.execute(insert_sql, payload)
|
||||||
|
|
||||||
|
connection.exec_driver_sql(f'DROP TABLE "{_V2_IMAGES_BACKUP_TABLE}"')
|
||||||
|
for statement in _V3_IMAGES_INDEX_STATEMENTS:
|
||||||
|
connection.exec_driver_sql(statement)
|
||||||
|
return len(legacy_rows)
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_v3_emoji_description(row: Dict[str, Any]) -> str:
|
||||||
|
"""为 v3 统一 emoji 描述字段语义。
|
||||||
|
|
||||||
|
v3 中 `description` 对 emoji 统一承担“标签列表”的职责,因此迁移时:
|
||||||
|
1. 若旧 `emotion` 非空,优先将其规范化后写入 `description`;
|
||||||
|
2. 否则保留并规范化当前 `description`;
|
||||||
|
3. 非 emoji 图片保持原描述不变。
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_type = str(row.get("image_type") or "").strip().upper()
|
||||||
|
current_description = str(row.get("description") or "").strip()
|
||||||
|
current_emotion = str(row.get("emotion") or "").strip()
|
||||||
|
if image_type != "EMOJI":
|
||||||
|
return current_description
|
||||||
|
|
||||||
|
normalized_tags = _normalize_emoji_tag_text(current_emotion or current_description)
|
||||||
|
if normalized_tags:
|
||||||
|
return ",".join(normalized_tags)
|
||||||
|
return current_description
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_emoji_tag_text(raw_value: Any) -> List[str]:
|
||||||
|
"""将 emoji 标签文本转换为去重后的标签列表。"""
|
||||||
|
|
||||||
|
normalized_text = str(raw_value or "").strip()
|
||||||
|
if not normalized_text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
separators = [",", ",", "、", ";", ";", "\n", "\r", "\t"]
|
||||||
|
for separator in separators[1:]:
|
||||||
|
normalized_text = normalized_text.replace(separator, separators[0])
|
||||||
|
|
||||||
|
deduped_tags: List[str] = []
|
||||||
|
seen_tags: set[str] = set()
|
||||||
|
for part in normalized_text.split(separators[0]):
|
||||||
|
normalized_part = part.strip()
|
||||||
|
lowered_part = normalized_part.lower()
|
||||||
|
if not normalized_part or lowered_part in seen_tags:
|
||||||
|
continue
|
||||||
|
seen_tags.add(lowered_part)
|
||||||
|
deduped_tags.append(normalized_part)
|
||||||
|
return deduped_tags
|
||||||
@@ -1,16 +1,14 @@
|
|||||||
import contextlib
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
import contextlib
|
||||||
from typing import List, Dict, Any, Optional, Tuple, Callable
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.services import llm_service as llm_api
|
from src.services import llm_service as llm_api
|
||||||
from sqlmodel import select, col
|
|
||||||
from src.common.database.database import get_db_session
|
|
||||||
from src.common.database.database_model import ThinkingQuestion
|
|
||||||
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
@@ -18,35 +16,6 @@ from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
|
|||||||
|
|
||||||
logger = get_logger("memory_retrieval")
|
logger = get_logger("memory_retrieval")
|
||||||
|
|
||||||
THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 36000 # 未找到答案记录保留时长
|
|
||||||
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 3000 # 清理频率
|
|
||||||
_last_not_found_cleanup_ts: float = 0.0
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_stale_not_found_thinking_back() -> None:
|
|
||||||
"""定期清理过期的未找到答案记录"""
|
|
||||||
global _last_not_found_cleanup_ts
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
if now - _last_not_found_cleanup_ts < THINKING_BACK_CLEANUP_INTERVAL_SECONDS:
|
|
||||||
return
|
|
||||||
|
|
||||||
threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS
|
|
||||||
try:
|
|
||||||
with get_db_session() as session:
|
|
||||||
statement = select(ThinkingQuestion).where(
|
|
||||||
col(ThinkingQuestion.found_answer).is_(False)
|
|
||||||
& (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time))
|
|
||||||
)
|
|
||||||
records = session.exec(statement).all()
|
|
||||||
for record in records:
|
|
||||||
session.delete(record)
|
|
||||||
if records:
|
|
||||||
logger.info(f"清理过期的未找到答案thinking_question记录 {len(records)} 条")
|
|
||||||
_last_not_found_cleanup_ts = now
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def init_memory_retrieval_sys():
|
def init_memory_retrieval_sys():
|
||||||
"""初始化记忆检索相关工具"""
|
"""初始化记忆检索相关工具"""
|
||||||
@@ -766,141 +735,6 @@ async def _react_agent_solve_question(
|
|||||||
|
|
||||||
return False, "", thinking_steps, is_timeout
|
return False, "", thinking_steps, is_timeout
|
||||||
|
|
||||||
|
|
||||||
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) -> str:
|
|
||||||
"""获取最近一段时间内的查询历史(用于避免重复查询)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
time_window_seconds: 时间窗口(秒),默认10分钟
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 格式化的查询历史字符串
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
_current_time = time.time()
|
|
||||||
|
|
||||||
with get_db_session() as session:
|
|
||||||
statement = (
|
|
||||||
select(ThinkingQuestion)
|
|
||||||
.where(col(ThinkingQuestion.context) == chat_id)
|
|
||||||
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
|
||||||
.limit(5)
|
|
||||||
)
|
|
||||||
records = session.exec(statement).all()
|
|
||||||
|
|
||||||
if not records:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
history_lines = ["最近已查询的问题和结果:"]
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
|
|
||||||
answer_preview = ""
|
|
||||||
# 只有找到答案时才显示答案内容
|
|
||||||
if record.found_answer and record.answer:
|
|
||||||
# 截取答案前100字符
|
|
||||||
answer_preview = record.answer[:100]
|
|
||||||
if len(record.answer) > 100:
|
|
||||||
answer_preview += "..."
|
|
||||||
|
|
||||||
history_lines.extend([f"- 问题:{record.question}", f" 状态:{status}"])
|
|
||||||
if answer_preview:
|
|
||||||
history_lines.append(f" 答案:{answer_preview}")
|
|
||||||
history_lines.append("") # 空行分隔
|
|
||||||
|
|
||||||
return "\n".join(history_lines)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取查询历史失败: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) -> List[str]:
|
|
||||||
"""获取最近一段时间内已找到答案的查询记录(用于返回给 replyer)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
time_window_seconds: 时间窗口(秒),默认10分钟
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: 格式化的答案列表,每个元素格式为 "问题:xxx\n答案:xxx"
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
_current_time = time.time()
|
|
||||||
|
|
||||||
# 查询最近时间窗口内已找到答案的记录,按更新时间倒序
|
|
||||||
with get_db_session() as session:
|
|
||||||
statement = (
|
|
||||||
select(ThinkingQuestion)
|
|
||||||
.where(col(ThinkingQuestion.context) == chat_id)
|
|
||||||
.where(col(ThinkingQuestion.found_answer))
|
|
||||||
.where(col(ThinkingQuestion.answer).is_not(None))
|
|
||||||
.where(col(ThinkingQuestion.answer) != "")
|
|
||||||
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
|
||||||
.limit(3)
|
|
||||||
)
|
|
||||||
records = session.exec(statement).all()
|
|
||||||
|
|
||||||
if not records:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [f"问题:{record.question}\n答案:{record.answer}" for record in records if record.answer]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取最近已找到答案的记录失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def _store_thinking_back(
|
|
||||||
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
|
|
||||||
) -> None:
|
|
||||||
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
question: 问题
|
|
||||||
context: 上下文信息
|
|
||||||
found_answer: 是否找到答案
|
|
||||||
answer: 答案内容
|
|
||||||
thinking_steps: 思考步骤列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
now = time.time()
|
|
||||||
|
|
||||||
# 先查询是否已存在相同chat_id和问题的记录
|
|
||||||
with get_db_session() as session:
|
|
||||||
statement = (
|
|
||||||
select(ThinkingQuestion)
|
|
||||||
.where(col(ThinkingQuestion.context) == chat_id)
|
|
||||||
.where(col(ThinkingQuestion.question) == question)
|
|
||||||
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
if record := session.exec(statement).first():
|
|
||||||
record.context = context
|
|
||||||
record.found_answer = found_answer
|
|
||||||
record.answer = answer
|
|
||||||
record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False)
|
|
||||||
record.updated_timestamp = datetime.fromtimestamp(now)
|
|
||||||
session.add(record)
|
|
||||||
logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...")
|
|
||||||
return
|
|
||||||
|
|
||||||
new_record = ThinkingQuestion(
|
|
||||||
question=question,
|
|
||||||
context=chat_id,
|
|
||||||
found_answer=found_answer,
|
|
||||||
answer=answer,
|
|
||||||
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
|
||||||
created_timestamp=datetime.fromtimestamp(now),
|
|
||||||
updated_timestamp=datetime.fromtimestamp(now),
|
|
||||||
)
|
|
||||||
session.add(new_record)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"存储思考过程失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def _process_memory_retrieval(
|
async def _process_memory_retrieval(
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
context: str,
|
context: str,
|
||||||
@@ -920,8 +754,6 @@ async def _process_memory_retrieval(
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 如果找到答案,返回答案内容,否则返回None
|
Optional[str]: 如果找到答案,返回答案内容,否则返回None
|
||||||
"""
|
"""
|
||||||
_cleanup_stale_not_found_thinking_back()
|
|
||||||
|
|
||||||
question_initial_info = initial_info or ""
|
question_initial_info = initial_info or ""
|
||||||
|
|
||||||
# 直接使用ReAct Agent进行记忆检索
|
# 直接使用ReAct Agent进行记忆检索
|
||||||
|
|||||||
@@ -585,10 +585,8 @@ class RuntimeDataCapabilityMixin:
|
|||||||
if not ImageUtils.base64_to_image(emoji_base64, str(temp_file_path)):
|
if not ImageUtils.base64_to_image(emoji_base64, str(temp_file_path)):
|
||||||
return {"success": False, "message": "无法保存图片文件", "description": None, "emotions": None, "replaced": None, "hash": None}
|
return {"success": False, "message": "无法保存图片文件", "description": None, "emotions": None, "replaced": None, "hash": None}
|
||||||
|
|
||||||
register_success = await emoji_manager.register_emoji_by_filename(temp_file_path)
|
register_status = await emoji_manager.register_emoji_by_filename(temp_file_path)
|
||||||
if not register_success:
|
if register_status == "failed":
|
||||||
if temp_file_path.exists():
|
|
||||||
temp_file_path.unlink(missing_ok=True)
|
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
||||||
@@ -597,6 +595,15 @@ class RuntimeDataCapabilityMixin:
|
|||||||
"replaced": None,
|
"replaced": None,
|
||||||
"hash": None,
|
"hash": None,
|
||||||
}
|
}
|
||||||
|
if register_status == "skipped":
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "表情包已注册,已跳过本次注册",
|
||||||
|
"description": None,
|
||||||
|
"emotions": None,
|
||||||
|
"replaced": False,
|
||||||
|
"hash": None,
|
||||||
|
}
|
||||||
|
|
||||||
count_after = len(emoji_manager.emojis)
|
count_after = len(emoji_manager.emojis)
|
||||||
replaced = count_after <= count_before
|
replaced = count_after <= count_before
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from .schemas import (
|
|||||||
emoji_to_response,
|
emoji_to_response,
|
||||||
)
|
)
|
||||||
from .support import (
|
from .support import (
|
||||||
EMOJI_REGISTERED_DIR,
|
EMOJI_DIR,
|
||||||
THUMBNAIL_CACHE_DIR,
|
THUMBNAIL_CACHE_DIR,
|
||||||
background_generate_thumbnail,
|
background_generate_thumbnail,
|
||||||
cleanup_orphaned_thumbnails,
|
cleanup_orphaned_thumbnails,
|
||||||
@@ -326,7 +326,7 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
|
|||||||
if not emoji:
|
if not emoji:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||||
if emoji.is_registered:
|
if emoji.is_registered:
|
||||||
raise HTTPException(status_code=400, detail="该表情包已经注册")
|
return EmojiUpdateResponse(success=True, message="??????????", data=emoji_to_response(emoji))
|
||||||
|
|
||||||
emoji.is_registered = True
|
emoji.is_registered = True
|
||||||
emoji.is_banned = False
|
emoji.is_banned = False
|
||||||
@@ -549,16 +549,16 @@ async def upload_emoji(
|
|||||||
if existing_emoji := session.exec(existing_statement).first():
|
if existing_emoji := session.exec(existing_statement).first():
|
||||||
raise HTTPException(status_code=409, detail=f"已存在相同的表情包 (ID: {existing_emoji.id})")
|
raise HTTPException(status_code=409, detail=f"已存在相同的表情包 (ID: {existing_emoji.id})")
|
||||||
|
|
||||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||||
|
|
||||||
timestamp = int(datetime.now().timestamp())
|
timestamp = int(datetime.now().timestamp())
|
||||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
full_path = os.path.join(EMOJI_DIR, filename)
|
||||||
|
|
||||||
counter = 1
|
counter = 1
|
||||||
while os.path.exists(full_path):
|
while os.path.exists(full_path):
|
||||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
full_path = os.path.join(EMOJI_DIR, filename)
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
with open(full_path, "wb") as output_file:
|
with open(full_path, "wb") as output_file:
|
||||||
@@ -618,7 +618,7 @@ async def batch_upload_emoji(
|
|||||||
}
|
}
|
||||||
|
|
||||||
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||||
|
|
||||||
for file in files:
|
for file in files:
|
||||||
try:
|
try:
|
||||||
@@ -665,12 +665,12 @@ async def batch_upload_emoji(
|
|||||||
|
|
||||||
timestamp = int(datetime.now().timestamp())
|
timestamp = int(datetime.now().timestamp())
|
||||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
full_path = os.path.join(EMOJI_DIR, filename)
|
||||||
|
|
||||||
counter = 1
|
counter = 1
|
||||||
while os.path.exists(full_path):
|
while os.path.exists(full_path):
|
||||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
full_path = os.path.join(EMOJI_DIR, filename)
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
with open(full_path, "wb") as output_file:
|
with open(full_path, "wb") as output_file:
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ logger = get_logger("webui.emoji")
|
|||||||
THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails")
|
THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails")
|
||||||
THUMBNAIL_SIZE = (200, 200)
|
THUMBNAIL_SIZE = (200, 200)
|
||||||
THUMBNAIL_QUALITY = 80
|
THUMBNAIL_QUALITY = 80
|
||||||
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
|
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji")
|
||||||
|
EMOJI_DIR = EMOJI_REGISTERED_DIR
|
||||||
|
|
||||||
_thumbnail_locks: Dict[str, threading.Lock] = {}
|
_thumbnail_locks: Dict[str, threading.Lock] = {}
|
||||||
_locks_lock = threading.Lock()
|
_locks_lock = threading.Lock()
|
||||||
|
|||||||
Reference in New Issue
Block a user