From 526fc9b7635a6a06851c92fddf6a12feb7d18c9a Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sun, 5 Apr 2026 20:12:21 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E4=BC=98=E5=8C=96=E8=A1=A8=E6=83=85?= =?UTF-8?q?=E5=8C=85=E6=B3=A8=E5=86=8C=EF=BC=8C=E8=BF=81=E7=A7=BB=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93v3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/emoji_system/emoji_manager.py | 363 ++++++++---------- src/chat/emoji_system/maisaka_tool.py | 6 +- src/chat/image_system/image_manager.py | 37 +- .../data_models/action_record_data_model.py | 54 +-- src/common/data_models/image_data_model.py | 16 +- src/common/database/database.py | 41 +- src/common/database/database_model.py | 44 --- src/common/database/migrations/__init__.py | 2 + src/common/database/migrations/builtin.py | 95 ++++- .../database/migrations/frozen_v2_schema.py | 298 ++++++++++++++ .../database/migrations/legacy_v1_to_v2.py | 21 +- src/common/database/migrations/v2_to_v3.py | 269 +++++++++++++ src/memory_system/memory_retrieval.py | 180 +-------- src/plugin_runtime/capabilities/data.py | 15 +- src/webui/routers/emoji/routes.py | 16 +- src/webui/routers/emoji/support.py | 3 +- 16 files changed, 926 insertions(+), 534 deletions(-) create mode 100644 src/common/database/migrations/frozen_v2_schema.py create mode 100644 src/common/database/migrations/v2_to_v3.py diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index d88ad467..b03b3505 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -1,6 +1,6 @@ from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Literal, Optional import asyncio import hashlib @@ -31,12 +31,12 @@ install(extra_lines=3) PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve() DATA_DIR = PROJECT_ROOT / "data" +EmojiRegisterStatus = Literal["registered", "skipped", "failed"] EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录 -EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注册目录 MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中 -def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]: +def register_emoji_hook_specs(registry: HookSpecRegistry) -> list[HookSpec]: """注册表情包系统内置 Hook 规格。 Args: @@ -145,7 +145,7 @@ def _get_runtime_manager() -> Any: return get_plugin_runtime_manager() -def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]: +def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[dict[str, Any]]: """将表情包对象序列化为 Hook 可传输载荷。 Args: @@ -163,27 +163,27 @@ def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, A "file_name": emoji.file_name, "full_path": str(emoji.full_path), "description": emoji.description, - "emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description or emoji.emotion)], + "emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description)], "query_count": int(emoji.query_count), } -def _normalize_emoji_tag_text(raw_values: Any) -> List[str]: +def _normalize_emoji_tag_text(raw_values: Any) -> list[str]: """将文本或标签列表转为去重的情绪标签列表。""" if isinstance(raw_values, str): if not raw_values: return [] - parts = re.split(r"[,,、;;\s]+", raw_values.strip()) + parts = re.split(r"[,,、;;\r\n\t]+", raw_values.strip()) normalized_tags = [str(part).strip() for part in parts if str(part).strip()] elif isinstance(raw_values, list): - normalized_tags: List[str] = [] + normalized_tags: list[str] = [] for value in raw_values: normalized_tags.extend(_normalize_emoji_tag_text(value)) else: return [] - deduped_tags: List[str] = [] - seen: Set[str] = set() + deduped_tags: list[str] = [] + seen: set[str] = set() for tag in normalized_tags: normalized_tag = tag.strip() if not normalized_tag: @@ -196,15 +196,23 @@ def _normalize_emoji_tag_text(raw_values: Any) -> List[str]: return deduped_tags -def _get_emoji_emotions(emoji: MaiEmoji) -> List[str]: - """获取兼容旧数据的表情包情绪标签。""" - return _normalize_emoji_tag_text(emoji.description or emoji.emotion) +def _get_emoji_emotions(emoji: MaiEmoji) -> list[str]: + """获取表情包情绪标签。""" + return _normalize_emoji_tag_text(emoji.description) def _ensure_directories() -> None: """确保表情包相关目录存在""" EMOJI_DIR.mkdir(parents=True, exist_ok=True) - EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True) + + +def _is_available_emoji_record(record: Images) -> bool: + """判断数据库记录对应的表情包文件当前是否可用。""" + if record.no_file_flag: + return False + + record_path = Path(record.full_path) + return record_path.exists() and record_path.is_file() # TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法 @@ -224,9 +232,9 @@ class EmojiManager: _ensure_directories() self._emoji_num: int = 0 - self.emojis: List[MaiEmoji] = [] + self.emojis: list[MaiEmoji] = [] self._maintenance_wakeup_event: asyncio.Event = asyncio.Event() - self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {} + self._pending_description_tasks: dict[str, asyncio.Task[None]] = {} self._reload_callback_registered: bool = False config_manager.register_reload_callback(self.reload_runtime_config) @@ -254,7 +262,7 @@ class EmojiManager: emoji_bytes: Optional[bytes] = None, emoji_hash: Optional[str] = None, wait_for_build: bool = True, - ) -> Optional[Tuple[str, List[str]]]: + ) -> Optional[tuple[str, list[str]]]: """ 根据表情包哈希获取表情包描述和情感列表的封装方法 @@ -275,12 +283,28 @@ class EmojiManager: emoji_hash = hashlib.sha256(emoji_bytes).hexdigest() if emoji := self.get_emoji_by_hash(emoji_hash): + emoji_path = Path(emoji.full_path) if emoji.full_path else None + if emoji_bytes and (emoji_path is None or not emoji_path.exists()): + try: + restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash) + emoji.full_path = restored_emoji.full_path + emoji.file_name = restored_emoji.file_name + except Exception as e: + logger.warning(f"表情包缓存命中但本地文件缺失,回填失败: {e}") return emoji.description, _normalize_emoji_tag_text(emoji.description or "") try: with get_db_session() as session: statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1) if result := session.exec(statement).first(): - cached_description = result.description or result.emotion or "" + record_path = Path(result.full_path) if result.full_path else None + if emoji_bytes and (result.no_file_flag or record_path is None or not record_path.exists()): + try: + restored_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash) + result.full_path = str(restored_emoji.full_path) + result.no_file_flag = False + except Exception as e: + logger.warning(f"数据库命中表情包记录但本地文件缺失,回填失败: {e}") + cached_description = result.description or "" cached_emotions = _normalize_emoji_tag_text(cached_description) return ( cached_description, @@ -400,7 +424,7 @@ class EmojiManager: self, emoji_hash: str, emoji_bytes: bytes, - ) -> Optional[Tuple[str, List[str]]]: + ) -> Optional[tuple[str, list[str]]]: """构建并缓存表情包描述(返回标签化结果,不再走额外识别流程)。""" logger.info(f"Start building cached emoji description, hash={emoji_hash}") new_emoji = await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash) @@ -461,87 +485,70 @@ class EmojiManager: self._emoji_num = 0 raise e - def register_emoji_to_db(self, emoji: MaiEmoji) -> bool: + def register_emoji_to_db(self, emoji: MaiEmoji) -> EmojiRegisterStatus: # sourcery skip: extract-method - """ - 将表情包注册到数据库中 - - Args: - emoji (MaiEmoji): 需要注册的表情包对象 - Returns: - return (bool): 注册是否成功 - """ + """Register an emoji in the database without moving its source file.""" if not emoji or not isinstance(emoji, MaiEmoji): - logger.error("[注册表情包] 无效的表情包对象") - return False + logger.error("[register_emoji] Invalid emoji object") + return "failed" if not emoji.full_path.exists(): - logger.error(f"[注册表情包] 表情包文件不存在: {emoji.full_path}") - return False + logger.error(f"[register_emoji] Emoji file does not exist: {emoji.full_path}") + return "failed" - target_path = EMOJI_REGISTERED_DIR / emoji.file_name - - # 先查库,避免重复记录导致文件被误移动后无法回收 - original_path = emoji.full_path try: with get_db_session() as session: statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1) existing_record = session.exec(statement).first() - if existing_record and not existing_record.no_file_flag: - logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}") - return False - except Exception as e: - logger.error(f"[注册表情包] 查询数据库时出错: {e}") - return False - try: - emoji.full_path.replace(target_path) - emoji.full_path = target_path - except Exception as e: - logger.error(f"[注册表情包] 移动表情包文件时出错: {e}") - return False + if existing_record: + if existing_record.is_registered and _is_available_emoji_record(existing_record): + logger.info(f"[register_emoji] Emoji already registered, skipping: {emoji.file_hash}") + return "skipped" + + normalized_description = str(emoji.description or existing_record.description or "").strip() + normalized_emotions = _normalize_emoji_tag_text(normalized_description) + register_time = existing_record.register_time or datetime.now() + + existing_record.full_path = str(emoji.full_path) + existing_record.description = normalized_description + existing_record.query_count = max(int(existing_record.query_count), int(emoji.query_count)) + existing_record.last_used_time = emoji.last_used_time or existing_record.last_used_time + existing_record.is_registered = True + existing_record.is_banned = False + existing_record.no_file_flag = False + existing_record.register_time = register_time + session.add(existing_record) + + emoji.description = normalized_description + emoji.emotion = normalized_emotions + emoji.query_count = existing_record.query_count + emoji.last_used_time = existing_record.last_used_time + emoji.register_time = register_time + + logger.info( + f"[register_emoji] Updated existing record and registered emoji, ID: {existing_record.id}, path: {emoji.full_path}" + ) + return "registered" + except Exception as e: + logger.error(f"[register_emoji] Database query failed: {e}") + return "failed" - # 注册到数据库 - restore_file = False try: with get_db_session() as session: - statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1) - if existing_record := session.exec(statement).first(): - if existing_record.no_file_flag: - existing_record.no_file_flag = False - existing_record.is_banned = False - existing_record.full_path = str(emoji.full_path) - existing_record.description = emoji.description - existing_record.query_count = emoji.query_count - existing_record.last_used_time = emoji.last_used_time - existing_record.register_time = emoji.register_time - session.add(existing_record) - logger.info( - f"[注册表情包] 更新已有记录并注册表情包到数据库, ID: {existing_record.id}, 路径: {emoji.full_path}" - ) - else: - logger.warning(f"[注册表情包] 数据库中已存在表情包记录: {emoji.file_hash}") - restore_file = True - return False - else: - image_record = emoji.to_db_instance() - image_record.is_registered = True - image_record.is_banned = False - image_record.register_time = datetime.now() - session.add(image_record) - session.flush() - record_id = image_record.id - logger.info(f"[注册表情包] 成功注册表情包到数据库, ID: {record_id}, 路径: {emoji.full_path}") + image_record = emoji.to_db_instance() + image_record.is_registered = True + image_record.is_banned = False + image_record.no_file_flag = False + image_record.register_time = datetime.now() + session.add(image_record) + session.flush() + emoji.register_time = image_record.register_time + record_id = image_record.id + logger.info(f"[register_emoji] Registered emoji to database, ID: {record_id}, path: {emoji.full_path}") except Exception as e: - logger.error(f"[注册表情包] 注册到数据库时出错: {e}") - restore_file = True - return False - finally: - if restore_file: - try: - emoji.full_path.replace(original_path) - emoji.full_path = original_path - except Exception as e: - logger.error(f"[注册表情包] 回滚文件移动失败: {e}") - return True + logger.error(f"[register_emoji] Failed to write database record: {e}") + return "failed" + + return "registered" def delete_emoji(self, emoji: MaiEmoji, no_desc: bool = False) -> bool: """ @@ -636,7 +643,6 @@ class EmojiManager: statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1) if image_record := session.exec(statement).first(): image_record.description = emoji.description - image_record.emotion = None session.add(image_record) logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}") else: @@ -794,12 +800,15 @@ class EmojiManager: logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}") if self.delete_emoji(emoji_to_delete): self.emojis.remove(emoji_to_delete) - if self.register_emoji_to_db(new_emoji): + register_status = self.register_emoji_to_db(new_emoji) + if register_status == "registered": self.emojis.append(new_emoji) - logger.info(f"[注册表情包] 成功替换并注册新表情包: {new_emoji.description}") + logger.info(f"[register_emoji] Replaced old emoji with new emoji: {new_emoji.description}") return True + if register_status == "skipped": + logger.info(f"[register_emoji] Replacement emoji was already registered: {new_emoji.description}") else: - logger.error(f"[注册表情包] 注册新表情包失败: {new_emoji.description}") + logger.error(f"[register_emoji] Failed to register replacement emoji: {new_emoji.description}") else: logger.error("[错误] 删除表情包失败,无法完成替换") else: @@ -808,7 +817,7 @@ class EmojiManager: logger.error("[决策] 未能解析删除编号") return False - async def build_emoji_description(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]: + async def build_emoji_description(self, target_emoji: MaiEmoji) -> tuple[bool, MaiEmoji]: """ 构建表情包描述 @@ -915,16 +924,12 @@ class EmojiManager: logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}") return True, target_emoji - async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]: - """兼容保留:表情包情绪标签已在 build_emoji_description 中一次性构建。""" - return await self.build_emoji_description(target_emoji) - def check_emoji_file_integrity(self) -> None: """ 检查表情包完整性,删除文件缺失的表情包记录 """ logger.info("[完整性检查] 开始检查表情包文件完整性...") - to_delete_emojis: list[Tuple[MaiEmoji, bool]] = [] + to_delete_emojis: list[tuple[MaiEmoji, bool]] = [] removal_count = 0 for emoji in self.emojis: if not emoji.full_path.exists(): @@ -945,59 +950,34 @@ class EmojiManager: logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录") - def remove_untracked_emoji_files(self) -> None: - """ - 删除未被数据库记录跟踪的表情包文件 - """ - logger.info("[未跟踪表情包清理] 开始清理未被数据库记录跟踪的表情包文件...") - tracked_files = {emoji.full_path.name for emoji in self.emojis} - all_files = set(EMOJI_REGISTERED_DIR.glob("*")) - removal_count = 0 - - for file_path in all_files: - if file_path.name not in tracked_files: - try: - file_path.unlink() - removal_count += 1 - logger.info(f"[未跟踪表情包清理] 删除未跟踪的表情包文件: {file_path.name}") - except Exception as e: - logger.error(f"[未跟踪表情包清理] 删除文件 {file_path.name} 时出错: {e}") - - logger.info(f"[未跟踪表情包清理] 未跟踪表情包文件清理完成,删除了 {removal_count} 个文件") - async def periodic_emoji_maintenance(self) -> None: - """ - 定期执行表情包维护任务,包括完整性检查和未跟踪文件清理 - """ + """Run emoji maintenance tasks periodically.""" while True: - EMOJI_DIR.mkdir(parents=True, exist_ok=True) - EMOJI_REGISTERED_DIR.mkdir(parents=True, exist_ok=True) + _ensure_directories() try: self.check_emoji_file_integrity() - self.remove_untracked_emoji_files() except Exception as e: - logger.error(f"[定期维护] 执行表情包维护任务时出错: {e}") + logger.error(f"[emoji_maintenance] Maintenance task failed: {e}") if global_config.emoji.steal_emoji and ( self._emoji_num < global_config.emoji.max_reg_num or (self._emoji_num > global_config.emoji.max_reg_num and global_config.emoji.do_replace) ): - logger.info("[定期维护] 尝试从表情包盗取目录注册新表情包...") + logger.info("[emoji_maintenance] Scanning data/emoji for new emojis...") for emoji_file in EMOJI_DIR.iterdir(): if not emoji_file.is_file(): continue try: - register_success = await self.register_emoji_by_filename(emoji_file) + register_status = await self.register_emoji_by_filename(emoji_file) except Exception as e: - logger.error(f"[定期维护] 注册表情包 {emoji_file.name} 时发生未处理异常: {e}") - register_success = False - if register_success: - break # 每次只注册一个表情包 - try: - emoji_file.unlink() - logger.info(f"[定期维护] 删除无法注册的表情包文件: {emoji_file.name}") - except Exception as e: - logger.error(f"[定期维护] 删除文件 {emoji_file.name} 时出错: {e}") + logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}") + register_status = "failed" + if register_status == "registered": + break + if register_status == "skipped": + logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}") + else: + logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}") wait_seconds = max(global_config.emoji.check_interval * 60, 0) try: await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds) @@ -1006,84 +986,81 @@ class EmojiManager: finally: self._maintenance_wakeup_event.clear() - async def register_emoji_by_filename(self, filename: Path | str) -> bool: - """ - 根据指定的表情包图片,分析并注册到数据库 - - Args: - filename (Path | str): 表情包图片的完整文件路径(可能根据文件实际格式修正) - - Returns: - return (bool): 注册是否成功 - """ + async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus: + """Register an emoji file from ``data/emoji`` without moving or deleting it.""" file_full_path = Path(filename).absolute().resolve() if not file_full_path.exists(): - logger.error(f"[注册表情包] 表情包文件不存在: {file_full_path}") - return False + logger.error(f"[register_emoji] Emoji file does not exist: {file_full_path}") + return "failed" try: target_emoji = MaiEmoji(full_path=file_full_path) except Exception as e: - logger.error(f"[注册表情包] 创建表情包对象时出错: {e}") - return False + logger.error(f"[register_emoji] Failed to create emoji object: {e}") + return "failed" calc_success = await target_emoji.calculate_hash_format() if not calc_success: - logger.error(f"[注册表情包] 计算表情包哈希值和格式失败: {file_full_path}") - return False - file_full_path = target_emoji.full_path # 更新为可能修正后的路径 + logger.error(f"[register_emoji] Failed to calculate hash and format: {file_full_path}") + return "failed" + file_full_path = target_emoji.full_path - # 2. 先验证数据库中是否已经存在相同哈希的表情包,避免重复构建 + existing_record: Optional[Images] = None try: with get_db_session_manual() as session: statement = ( select(Images).filter_by(image_hash=target_emoji.file_hash, image_type=ImageType.EMOJI).limit(1) ) - if image_record := session.exec(statement).first(): - if image_record.no_file_flag: - image_record.no_file_flag = False - image_record.is_banned = False - image_record.is_registered = True - image_record.full_path = str(target_emoji.full_path) - session.add(image_record) - session.commit() - logger.info(f"表情包注册成功,Hash: {target_emoji.file_hash}") - return True - else: - logger.warning(f"[注册表情包] 数据库中已存在表情包记录,跳过注册: {target_emoji.file_name}") - return False + existing_record = session.exec(statement).first() except Exception as e: - logger.error(f"[注册表情包] 查询数据库时出错: {e}") - return False + logger.error(f"[register_emoji] Failed to query database: {e}") + return "failed" + + if existing_record is not None: + if existing_record.is_registered and _is_available_emoji_record(existing_record): + logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}") + return "skipped" + + cached_description = str(existing_record.description or "").strip() + if cached_description: + normalized_emotions = _normalize_emoji_tag_text(cached_description) + target_emoji.description = ",".join(normalized_emotions) + target_emoji.emotion = normalized_emotions + target_emoji.query_count = existing_record.query_count + target_emoji.last_used_time = existing_record.last_used_time + target_emoji.register_time = existing_record.register_time - # 3. 检查内存缓存是否已经存在 if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash): - logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}") - return False - # 3. 构建描述(包含情绪标签) - desc_success, target_emoji = await self.build_emoji_description(target_emoji) - if not desc_success: - logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}") - return False + logger.info(f"[register_emoji] Emoji already loaded in memory, skipping: {existing_emoji.file_name}") + return "skipped" + + if not target_emoji.description: + desc_success, target_emoji = await self.build_emoji_description(target_emoji) + if not desc_success: + logger.error(f"[register_emoji] Failed to build emoji description: {file_full_path}") + return "failed" - # 4. 检查容量并决定是否替换或者直接注册 if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace: - logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包") + logger.warning( + f"[register_emoji] Emoji limit {global_config.emoji.max_reg_num} reached, trying replacement" + ) replaced = await self.replace_an_emoji_by_llm(target_emoji) if not replaced: - logger.error("[注册表情包] 替换表情包失败,无法注册新表情包") - return False - return True - else: - if self.register_emoji_to_db(target_emoji): - self.emojis.append(target_emoji) - self._emoji_num += 1 - logger.info(f"[注册表情包] 成功注册新表情包: {target_emoji.file_name}") - return True - else: - logger.error(f"[注册表情包] 注册表情包到数据库失败: {file_full_path}") - return False + logger.error("[register_emoji] Failed to replace existing emoji") + return "failed" + return "registered" - def _calculate_emotion_similarity_list(self, text_emotion: str) -> List[Tuple[MaiEmoji, float]]: + register_status = self.register_emoji_to_db(target_emoji) + if register_status == "registered": + self.emojis.append(target_emoji) + self._emoji_num = len(self.emojis) + logger.info(f"[register_emoji] Registered new emoji: {target_emoji.file_name}") + elif register_status == "failed": + logger.error(f"[register_emoji] Failed to register emoji in database: {file_full_path}") + else: + logger.info(f"[register_emoji] Emoji already registered, skipping: {target_emoji.file_name}") + return register_status + + def _calculate_emotion_similarity_list(self, text_emotion: str) -> list[tuple[MaiEmoji, float]]: """ 计算文本情感标签与所有表情包情感标签的相似度列表 @@ -1096,7 +1073,7 @@ class EmojiManager: if not normalized_text_emotion: return [] - similarity_list: List[Tuple[MaiEmoji, float]] = [] + similarity_list: list[tuple[MaiEmoji, float]] = [] for emoji in self.emojis: candidate_emotions = _get_emoji_emotions(emoji) if not candidate_emotions: diff --git a/src/chat/emoji_system/maisaka_tool.py b/src/chat/emoji_system/maisaka_tool.py index 50000942..bbdc072d 100644 --- a/src/chat/emoji_system/maisaka_tool.py +++ b/src/chat/emoji_system/maisaka_tool.py @@ -1,8 +1,8 @@ """Maisaka 表情工具内置能力。""" -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass, field -from typing import Any, Optional, Sequence, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import random @@ -120,8 +120,6 @@ def _normalize_emotions(emoji: MaiEmoji) -> list[str]: """提取并清洗单个表情的情绪标签。""" if emoji.description: return _normalize_emoji_tag_text(emoji.description) - if emoji.emotion: - return _normalize_emoji_tag_text(emoji.emotion) return [] diff --git a/src/chat/image_system/image_manager.py b/src/chat/image_system/image_manager.py index c15f647f..81b95b38 100644 --- a/src/chat/image_system/image_manager.py +++ b/src/chat/image_system/image_manager.py @@ -41,6 +41,7 @@ class ImageManager: """初始化图片管理器。""" _ensure_image_dir_exists() self._pending_description_tasks: Dict[str, asyncio.Task[None]] = {} + self.cleanup_legacy_image_registration_records() logger.info("图片管理器初始化完成") @@ -50,6 +51,17 @@ class ImageManager: statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1) return session.exec(statement).first() + def _normalize_image_registration_fields(self, record: Images) -> bool: + """Normalize accidental emoji registration fields on image records.""" + if record.image_type != ImageType.IMAGE: + return False + if not record.is_registered and record.register_time is None: + return False + + record.is_registered = False + record.register_time = None + return True + async def get_image_description( self, *, @@ -182,8 +194,9 @@ class ImageManager: try: with get_db_session() as session: record = image.to_db_instance() - record.is_registered = True - record.register_time = record.last_used_time = datetime.now() + record.is_registered = False + record.register_time = None + record.last_used_time = datetime.now() session.add(record) session.flush() # 确保记录被写入数据库以获取ID record_id = record.id @@ -209,6 +222,7 @@ class ImageManager: if not record: logger.error(f"未找到哈希值为 {image.file_hash} 的图片记录,无法更新描述") return False + self._normalize_image_registration_fields(record) record.description = image.description record.last_used_time = datetime.now() record.vlm_processed = image.vlm_processed @@ -258,6 +272,7 @@ class ImageManager: with get_db_session() as session: statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1) if record := session.exec(statement).first(): + self._normalize_image_registration_fields(record) logger.info(f"图片已存在于数据库中,哈希值: {hash_str}") record.last_used_time = datetime.now() record.query_count += 1 @@ -335,6 +350,24 @@ class ImageManager: logger.info(f"清理完成: {invalid_counter} 条无效描述记录,{null_path_counter} 条文件路径不存在记录") + def cleanup_legacy_image_registration_records(self) -> None: + """Clean up legacy image records with mistaken registration fields.""" + fixed_counter = 0 + try: + with get_db_session() as session: + statement = select(Images).filter_by(image_type=ImageType.IMAGE) + for record in session.exec(statement).yield_per(100): + if not self._normalize_image_registration_fields(record): + continue + session.add(record) + fixed_counter += 1 + except Exception as e: + logger.error(f"Failed to clean image registration state: {e}") + return + + if fixed_counter: + logger.info(f"Cleaned mistaken registration state on {fixed_counter} image records") + async def _generate_image_description(self, image_bytes: bytes, image_format: str) -> str: prompt = global_config.personality.visual_style image_base64 = base64.b64encode(image_bytes).decode("utf-8") diff --git a/src/common/data_models/action_record_data_model.py b/src/common/data_models/action_record_data_model.py index be6529b3..f757b1d6 100644 --- a/src/common/data_models/action_record_data_model.py +++ b/src/common/data_models/action_record_data_model.py @@ -1,14 +1,20 @@ from datetime import datetime -from typing import Optional, Dict +from typing import Dict, Optional import json -from src.common.database.database_model import ActionRecord +from src.common.database.database_model import ToolRecord from . import BaseDatabaseDataModel -class MaiActionRecord(BaseDatabaseDataModel[ActionRecord]): +class MaiActionRecord(BaseDatabaseDataModel[ToolRecord]): + """``action_records`` 的兼容数据模型。 + + 历史动作记录已统一并入 ``tool_records``,该类仅保留旧命名接口, + 底层读写对象统一映射为 ``ToolRecord``。 + """ + def __init__( self, action_id: str, @@ -21,45 +27,39 @@ class MaiActionRecord(BaseDatabaseDataModel[ActionRecord]): action_display_prompt: Optional[str] = None, ): self.action_id = action_id - """动作ID""" self.timestamp = timestamp - """时间戳""" self.session_id = session_id - """会话ID""" self.action_name = action_name - """动作名称""" self.action_reasoning = action_reasoning - """动作推理过程""" self.action_data = action_data or {} - """动作数据""" self.action_builtin_prompt = action_builtin_prompt - """内置动作提示""" self.action_display_prompt = action_display_prompt - """最终输入到 Prompt 的内容""" @classmethod - def from_db_instance(cls, db_record: ActionRecord): - """Create a data model object from a database record.""" + def from_db_instance(cls, db_record: ToolRecord): + """从数据库实例创建兼容数据模型对象。""" + return cls( - action_id=db_record.action_id, + action_id=db_record.tool_id, timestamp=db_record.timestamp, session_id=db_record.session_id, - action_name=db_record.action_name, - action_reasoning=db_record.action_reasoning, - action_data=json.loads(db_record.action_data) if db_record.action_data else None, - action_builtin_prompt=db_record.action_builtin_prompt, - action_display_prompt=db_record.action_display_prompt, + action_name=db_record.tool_name, + action_reasoning=db_record.tool_reasoning, + action_data=json.loads(db_record.tool_data) if db_record.tool_data else None, + action_builtin_prompt=db_record.tool_builtin_prompt, + action_display_prompt=db_record.tool_display_prompt, ) def to_db_instance(self): - """Convert the data model object back to a database instance.""" - return ActionRecord( - action_id=self.action_id, + """将兼容数据模型对象转换为 ``ToolRecord``。""" + + return ToolRecord( + tool_id=self.action_id, timestamp=self.timestamp, session_id=self.session_id, - action_name=self.action_name, - action_reasoning=self.action_reasoning, - action_data=json.dumps(self.action_data) if self.action_data else None, - action_builtin_prompt=self.action_builtin_prompt, - action_display_prompt=self.action_display_prompt, + tool_name=self.action_name, + tool_reasoning=self.action_reasoning, + tool_data=json.dumps(self.action_data) if self.action_data else None, + tool_builtin_prompt=self.action_builtin_prompt, + tool_display_prompt=self.action_display_prompt, ) diff --git a/src/common/data_models/image_data_model.py b/src/common/data_models/image_data_model.py index 45caa97a..0d9a9923 100644 --- a/src/common/data_models/image_data_model.py +++ b/src/common/data_models/image_data_model.py @@ -1,16 +1,17 @@ -from datetime import datetime -from pathlib import Path -from PIL import Image as PILImage -from rich.traceback import install -from typing import Optional, List - import asyncio import hashlib import io import traceback +from datetime import datetime +from pathlib import Path +from typing import List, Optional + +from PIL import Image as PILImage +from rich.traceback import install from src.common.database.database_model import Images, ImageType from src.common.logger import get_logger + from . import BaseDatabaseDataModel @@ -152,7 +153,7 @@ class MaiEmoji(BaseImageDataModel): raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象") obj = cls(db_record.full_path) obj.file_hash = db_record.image_hash - description = db_record.description or db_record.emotion or "" + description = db_record.description or "" obj.description = description normalized_tags = [ str(item).strip() @@ -175,7 +176,6 @@ class MaiEmoji(BaseImageDataModel): description=self.description, full_path=str(self.full_path), image_type=ImageType.EMOJI, - emotion=None, query_count=self.query_count, last_used_time=self.last_used_time, register_time=self.register_time, diff --git a/src/common/database/database.py b/src/common/database/database.py index 2b22475a..04d0a5d0 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import ContextManager, Generator, TYPE_CHECKING from rich.traceback import install -from sqlalchemy import event, text +from sqlalchemy import event from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, Session, create_engine @@ -63,41 +63,6 @@ _migration_bootstrapper = create_database_migration_bootstrapper(engine) _db_initialized = False -def _migrate_action_records_to_tool_records() -> None: - """将旧的 ``action_records`` 历史数据迁移到 ``tool_records``。""" - migration_sql = text( - """ - INSERT INTO tool_records ( - tool_id, - timestamp, - session_id, - tool_name, - tool_reasoning, - tool_data, - tool_builtin_prompt, - tool_display_prompt - ) - SELECT - action_id, - timestamp, - session_id, - action_name, - action_reasoning, - action_data, - action_builtin_prompt, - action_display_prompt - FROM action_records - WHERE NOT EXISTS ( - SELECT 1 - FROM tool_records - WHERE tool_records.tool_id = action_records.action_id - ) - """ - ) - with engine.begin() as connection: - connection.execute(migration_sql) - - def initialize_database() -> None: """初始化数据库连接、结构与启动期迁移。 @@ -105,8 +70,7 @@ def initialize_database() -> None: 1. 确保数据库目录存在; 2. 加载 SQLModel 模型定义; 3. 执行已注册的启动期迁移; - 4. 兜底执行 ``create_all`` 确保当前模型定义已建表; - 5. 执行项目现有的轻量数据补迁移逻辑。 + 4. 兜底执行 ``create_all`` 确保当前模型定义已建表。 """ global _db_initialized if _db_initialized: @@ -120,7 +84,6 @@ def initialize_database() -> None: f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}" ) SQLModel.metadata.create_all(engine) - _migrate_action_records_to_tool_records() _migration_bootstrapper.finalize_database(migration_state) _db_initialized = True diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 9874af67..4164429f 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -94,7 +94,6 @@ class Images(SQLModel, table=True): full_path: str = Field(max_length=1024) # 文件的完整路径 (包括文件名) image_type: ImageType = Field(sa_column=Column(SQLEnum(ImageType)), default=ImageType.EMOJI) """图片类型,例如 'emoji' 或 'image'""" - emotion: Optional[str] = Field(default=None, nullable=True) # 表情包的情感标签,逗号分隔 query_count: int = Field(default=0) # 被查询次数 is_registered: bool = Field(default=False) # 是否已经注册 @@ -113,27 +112,6 @@ class Images(SQLModel, table=True): vlm_processed: bool = Field(default=False) # 是否已经过VLM处理 -class ActionRecord(SQLModel, table=True): - """存储动作记录""" - - __tablename__ = "action_records" # type: ignore - - id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 - - # 元信息 - action_id: str = Field(index=True, max_length=255) # 动作ID - timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True)) # 记录时间戳 - session_id: str = Field(index=True, max_length=255) # 对应的 ChatSession session_id - - # 调用信息 - action_name: str = Field(index=True, max_length=255) # 动作名称 - action_reasoning: Optional[str] = Field(default=None) # 动作推理过程 - action_data: Optional[str] = Field(default=None) # 动作数据,JSON格式存储 - - action_builtin_prompt: Optional[str] = Field(default=None) # 内置动作提示 - action_display_prompt: Optional[str] = Field(default=None) # 最终输入到Prompt的内容 - - class ToolRecord(SQLModel, table=True): """存储工具调用记录""" @@ -281,28 +259,6 @@ class ChatHistory(SQLModel, table=True): summary: str # 概括:对这段话的平文本概括 -class ThinkingQuestion(SQLModel, table=True): - """存储思考型问题的模型""" - - __tablename__ = "thinking_questions" # type: ignore - - id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 - - # 问答对 - question: str # 问题内容 - context: Optional[str] = Field(default=None, nullable=True) # 上下文 - found_answer: bool = Field(default=False) # 是否找到答案 - answer: Optional[str] = Field(default=None, nullable=True) # 问题答案 - - thinking_steps: Optional[str] = Field(default=None, nullable=True) # 思考步骤,JSON格式存储 - created_timestamp: datetime = Field( - default_factory=datetime.now, sa_column=Column(DateTime, index=True) - ) # 创建时间 - updated_timestamp: datetime = Field( - default_factory=datetime.now, sa_column=Column(DateTime, index=True) - ) # 最后更新时间 - - class BinaryData(SQLModel, table=True): """存储二进制数据的模型""" diff --git a/src/common/database/migrations/__init__.py b/src/common/database/migrations/__init__.py index e9a69bd1..34e59031 100644 --- a/src/common/database/migrations/__init__.py +++ b/src/common/database/migrations/__init__.py @@ -5,6 +5,7 @@ from .builtin import ( EMPTY_SCHEMA_VERSION, LATEST_SCHEMA_VERSION, LEGACY_V1_SCHEMA_VERSION, + V2_SCHEMA_VERSION, build_default_migration_registry, build_default_schema_version_resolver, ) @@ -61,6 +62,7 @@ __all__ = [ "EMPTY_SCHEMA_VERSION", "LATEST_SCHEMA_VERSION", "LEGACY_V1_SCHEMA_VERSION", + "V2_SCHEMA_VERSION", "MigrationExecutionContext", "MigrationPlan", "MigrationPlanner", diff --git a/src/common/database/migrations/builtin.py b/src/common/database/migrations/builtin.py index 5b16780b..501b3e3c 100644 --- a/src/common/database/migrations/builtin.py +++ b/src/common/database/migrations/builtin.py @@ -6,12 +6,14 @@ from .legacy_v1_to_v2 import migrate_legacy_v1_to_v2 from .models import DatabaseSchemaSnapshot, MigrationStep from .registry import MigrationRegistry from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver -from .version_store import SQLiteUserVersionStore from .schema import SQLiteSchemaInspector +from .v2_to_v3 import migrate_v2_to_v3 +from .version_store import SQLiteUserVersionStore EMPTY_SCHEMA_VERSION = 0 LEGACY_V1_SCHEMA_VERSION = 1 -LATEST_SCHEMA_VERSION = 2 +V2_SCHEMA_VERSION = 2 +LATEST_SCHEMA_VERSION = 3 _LEGACY_V1_EXCLUSIVE_TABLES = ( "chat_streams", @@ -24,6 +26,13 @@ _LEGACY_V1_EXCLUSIVE_TABLES = ( "messages", "thinking_back", ) +_COMMON_MARKER_TABLES = ( + "mai_messages", + "chat_sessions", + "expressions", + "jargons", + "tool_records", +) class LatestSchemaVersionDetector(BaseSchemaVersionDetector): @@ -36,6 +45,7 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector): Returns: str: 当前探测器名称。 """ + return "latest_schema_detector" def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]: @@ -47,18 +57,16 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector): Returns: Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。 """ + if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES): return None - - latest_marker_tables = ( - "mai_messages", - "chat_sessions", - "expressions", - "jargons", - "thinking_questions", - "tool_records", - ) - if not all(snapshot.has_table(table_name) for table_name in latest_marker_tables): + if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES): + return None + if snapshot.has_table("action_records"): + return None + if snapshot.has_table("thinking_questions"): + return None + if snapshot.has_column("images", "emotion"): return None if not snapshot.has_column("images", "image_hash"): return None @@ -66,13 +74,53 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector): return None if not snapshot.has_column("images", "image_type"): return None + if not snapshot.has_column("chat_history", "session_id"): + return None + if not snapshot.has_column("person_info", "user_nickname"): + return None + return LATEST_SCHEMA_VERSION + + +class V2SchemaVersionDetector(BaseSchemaVersionDetector): + """v2 schema 结构探测器。""" + + @property + def name(self) -> str: + """返回探测器名称。 + + Returns: + str: 当前探测器名称。 + """ + + return "v2_schema_detector" + + def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]: + """检测数据库是否为 v2 结构。 + + Args: + snapshot: 当前数据库结构快照。 + + Returns: + Optional[int]: 若识别为 v2 结构则返回 ``2``,否则返回 ``None``。 + """ + + if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES): + return None + if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES): + return None + if not snapshot.has_table("action_records"): + return None + if not snapshot.has_table("thinking_questions"): + return None + if not snapshot.has_column("images", "emotion"): + return None if not snapshot.has_column("action_records", "session_id"): return None if not snapshot.has_column("chat_history", "session_id"): return None if not snapshot.has_column("person_info", "user_nickname"): return None - return LATEST_SCHEMA_VERSION + return V2_SCHEMA_VERSION class LegacyV1SchemaDetector(BaseSchemaVersionDetector): @@ -85,6 +133,7 @@ class LegacyV1SchemaDetector(BaseSchemaVersionDetector): Returns: str: 当前探测器名称。 """ + return "legacy_v1_schema_detector" def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]: @@ -96,6 +145,7 @@ class LegacyV1SchemaDetector(BaseSchemaVersionDetector): Returns: Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。 """ + if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES): return LEGACY_V1_SCHEMA_VERSION @@ -121,8 +171,10 @@ def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]: Returns: List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。 """ + return [ LatestSchemaVersionDetector(), + V2SchemaVersionDetector(), LegacyV1SchemaDetector(), ] @@ -133,6 +185,7 @@ def build_default_schema_version_resolver() -> SchemaVersionResolver: Returns: SchemaVersionResolver: 配置完成的 schema 版本解析器。 """ + return SchemaVersionResolver( version_store=SQLiteUserVersionStore(), schema_inspector=SQLiteSchemaInspector(), @@ -146,14 +199,22 @@ def build_default_migration_registry() -> MigrationRegistry: Returns: MigrationRegistry: 含默认迁移步骤的注册表实例。 """ + return MigrationRegistry( steps=[ MigrationStep( version_from=LEGACY_V1_SCHEMA_VERSION, - version_to=LATEST_SCHEMA_VERSION, - name="legacy_v1_to_latest_v2", - description="将旧版 0.x 数据库整体迁移到当前最新 schema。", + version_to=V2_SCHEMA_VERSION, + name="legacy_v1_to_v2", + description="将旧版 0.x 数据库迁移到 v2 schema。", handler=migrate_legacy_v1_to_v2, - ) + ), + MigrationStep( + version_from=V2_SCHEMA_VERSION, + version_to=LATEST_SCHEMA_VERSION, + name="v2_to_v3", + description="移除废弃表,并将 emoji 标签统一收敛到 description 字段。", + handler=migrate_v2_to_v3, + ), ] ) diff --git a/src/common/database/migrations/frozen_v2_schema.py b/src/common/database/migrations/frozen_v2_schema.py new file mode 100644 index 00000000..f427ace5 --- /dev/null +++ b/src/common/database/migrations/frozen_v2_schema.py @@ -0,0 +1,298 @@ +"""冻结的 v2 schema 快照。 + +该模块只用于 ``legacy_v1_to_v2`` 迁移,避免迁移过程依赖当前运行时代码中的 +最新 SQLModel 定义,导致历史迁移随着后续 schema 演进而失真。 +""" + +from sqlalchemy.engine import Connection + +_V2_TABLE_STATEMENTS = ( + """ + CREATE TABLE IF NOT EXISTS action_records ( + id INTEGER NOT NULL, + action_id VARCHAR(255) NOT NULL, + timestamp DATETIME, + session_id VARCHAR(255) NOT NULL, + action_name VARCHAR(255) NOT NULL, + action_reasoning VARCHAR, + action_data VARCHAR, + action_builtin_prompt VARCHAR, + action_display_prompt VARCHAR, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS binary_data ( + id INTEGER NOT NULL, + data_hash VARCHAR(255) NOT NULL, + full_path VARCHAR(1024) NOT NULL, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS chat_history ( + id INTEGER NOT NULL, + session_id VARCHAR(255) NOT NULL, + start_timestamp DATETIME, + end_timestamp DATETIME, + query_count INTEGER NOT NULL, + query_forget_count INTEGER NOT NULL, + original_messages VARCHAR NOT NULL, + participants VARCHAR NOT NULL, + theme VARCHAR NOT NULL, + keywords VARCHAR NOT NULL, + summary VARCHAR NOT NULL, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS chat_sessions ( + id INTEGER NOT NULL, + session_id VARCHAR(255) NOT NULL, + created_timestamp DATETIME, + last_active_timestamp DATETIME, + user_id VARCHAR(255), + group_id VARCHAR(255), + platform VARCHAR(100) NOT NULL, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS command_records ( + id INTEGER NOT NULL, + timestamp DATETIME, + session_id VARCHAR(255) NOT NULL, + command_name VARCHAR(255) NOT NULL, + command_data VARCHAR, + command_result VARCHAR, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS expressions ( + id INTEGER NOT NULL, + situation VARCHAR(255) NOT NULL, + style VARCHAR(255) NOT NULL, + content_list VARCHAR NOT NULL, + count INTEGER NOT NULL, + last_active_time DATETIME, + create_time DATETIME, + session_id VARCHAR(255), + checked BOOLEAN NOT NULL, + rejected BOOLEAN NOT NULL, + modified_by VARCHAR(4), + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS images ( + id INTEGER NOT NULL, + image_hash VARCHAR(255) NOT NULL, + description VARCHAR NOT NULL, + full_path VARCHAR(1024) NOT NULL, + image_type VARCHAR(5), + emotion VARCHAR, + query_count INTEGER NOT NULL, + is_registered BOOLEAN NOT NULL, + is_banned BOOLEAN NOT NULL, + no_file_flag BOOLEAN NOT NULL, + record_time DATETIME, + register_time DATETIME, + last_used_time DATETIME, + vlm_processed BOOLEAN NOT NULL, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS jargons ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + content VARCHAR(255) NOT NULL, + raw_content TEXT, + meaning TEXT NOT NULL, + session_id_dict TEXT NOT NULL, + count INTEGER NOT NULL, + is_jargon BOOLEAN, + is_complete BOOLEAN NOT NULL, + is_global BOOLEAN NOT NULL, + last_inference_count INTEGER NOT NULL, + inference_with_context TEXT, + inference_with_content_only TEXT + ) + """, + """ + CREATE TABLE IF NOT EXISTS llm_usage ( + id INTEGER NOT NULL, + model_name VARCHAR(255) NOT NULL, + model_assign_name VARCHAR(255), + model_api_provider_name VARCHAR(255) NOT NULL, + endpoint VARCHAR(255), + user_type VARCHAR(6), + request_type VARCHAR(50) NOT NULL, + time_cost FLOAT, + timestamp DATETIME, + prompt_tokens INTEGER NOT NULL, + completion_tokens INTEGER NOT NULL, + total_tokens INTEGER NOT NULL, + cost FLOAT NOT NULL, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS mai_knowledge ( + id INTEGER NOT NULL, + knowledge_id VARCHAR(255) NOT NULL, + category_id VARCHAR(32) NOT NULL, + content VARCHAR NOT NULL, + normalized_content VARCHAR NOT NULL, + metadata_json VARCHAR, + created_at DATETIME, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS mai_messages ( + id INTEGER NOT NULL, + message_id VARCHAR(255) NOT NULL, + timestamp DATETIME, + platform VARCHAR(100) NOT NULL, + user_id VARCHAR(255) NOT NULL, + user_nickname VARCHAR(255) NOT NULL, + user_cardname VARCHAR(255), + group_id VARCHAR(255), + group_name VARCHAR(255), + is_mentioned BOOLEAN NOT NULL, + is_at BOOLEAN NOT NULL, + session_id VARCHAR(255) NOT NULL, + reply_to VARCHAR(255), + is_emoji BOOLEAN NOT NULL, + is_picture BOOLEAN NOT NULL, + is_command BOOLEAN NOT NULL, + is_notify BOOLEAN NOT NULL, + raw_content BLOB, + processed_plain_text VARCHAR, + display_message VARCHAR, + additional_config VARCHAR, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS online_time ( + id INTEGER NOT NULL, + timestamp DATETIME, + duration_minutes INTEGER NOT NULL, + start_timestamp DATETIME, + end_timestamp DATETIME, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS person_info ( + id INTEGER NOT NULL, + is_known BOOLEAN NOT NULL, + person_id VARCHAR(255) NOT NULL, + person_name VARCHAR(255), + name_reason VARCHAR, + platform VARCHAR(100) NOT NULL, + user_id VARCHAR(255) NOT NULL, + user_nickname VARCHAR(255) NOT NULL, + group_cardname VARCHAR, + memory_points VARCHAR, + know_counts INTEGER NOT NULL, + first_known_time DATETIME, + last_known_time DATETIME, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS thinking_questions ( + id INTEGER NOT NULL, + question VARCHAR NOT NULL, + context VARCHAR, + found_answer BOOLEAN NOT NULL, + answer VARCHAR, + thinking_steps VARCHAR, + created_timestamp DATETIME, + updated_timestamp DATETIME, + PRIMARY KEY (id) + ) + """, + """ + CREATE TABLE IF NOT EXISTS tool_records ( + id INTEGER NOT NULL, + tool_id VARCHAR(255) NOT NULL, + timestamp DATETIME, + session_id VARCHAR(255) NOT NULL, + tool_name VARCHAR(255) NOT NULL, + tool_reasoning VARCHAR, + tool_data VARCHAR, + tool_builtin_prompt VARCHAR, + tool_display_prompt VARCHAR, + PRIMARY KEY (id) + ) + """, +) + +_V2_INDEX_STATEMENTS = ( + "CREATE INDEX IF NOT EXISTS ix_action_records_action_id ON action_records (action_id)", + "CREATE INDEX IF NOT EXISTS ix_action_records_action_name ON action_records (action_name)", + "CREATE INDEX IF NOT EXISTS ix_action_records_session_id ON action_records (session_id)", + "CREATE INDEX IF NOT EXISTS ix_action_records_timestamp ON action_records (timestamp)", + "CREATE INDEX IF NOT EXISTS ix_binary_data_data_hash ON binary_data (data_hash)", + "CREATE INDEX IF NOT EXISTS ix_chat_history_end_timestamp ON chat_history (end_timestamp)", + "CREATE INDEX IF NOT EXISTS ix_chat_history_session_id ON chat_history (session_id)", + "CREATE INDEX IF NOT EXISTS ix_chat_history_start_timestamp ON chat_history (start_timestamp)", + "CREATE INDEX IF NOT EXISTS ix_chat_sessions_created_timestamp ON chat_sessions (created_timestamp)", + "CREATE INDEX IF NOT EXISTS ix_chat_sessions_group_id ON chat_sessions (group_id)", + "CREATE INDEX IF NOT EXISTS ix_chat_sessions_last_active_timestamp ON chat_sessions (last_active_timestamp)", + "CREATE INDEX IF NOT EXISTS ix_chat_sessions_platform ON chat_sessions (platform)", + "CREATE UNIQUE INDEX IF NOT EXISTS ix_chat_sessions_session_id ON chat_sessions (session_id)", + "CREATE INDEX IF NOT EXISTS ix_chat_sessions_user_id ON chat_sessions (user_id)", + "CREATE INDEX IF NOT EXISTS ix_command_records_command_name ON command_records (command_name)", + "CREATE INDEX IF NOT EXISTS ix_command_records_session_id ON command_records (session_id)", + "CREATE INDEX IF NOT EXISTS ix_command_records_timestamp ON command_records (timestamp)", + "CREATE INDEX IF NOT EXISTS ix_expressions_last_active_time ON expressions (last_active_time)", + "CREATE INDEX IF NOT EXISTS ix_expressions_situation ON expressions (situation)", + "CREATE INDEX IF NOT EXISTS ix_expressions_style ON expressions (style)", + "CREATE INDEX IF NOT EXISTS ix_images_image_hash ON images (image_hash)", + "CREATE INDEX IF NOT EXISTS ix_images_record_time ON images (record_time)", + "CREATE INDEX IF NOT EXISTS ix_jargons_content ON jargons (content)", + "CREATE INDEX IF NOT EXISTS ix_llm_usage_model_api_provider_name ON llm_usage (model_api_provider_name)", + "CREATE INDEX IF NOT EXISTS ix_llm_usage_model_assign_name ON llm_usage (model_assign_name)", + "CREATE INDEX IF NOT EXISTS ix_llm_usage_model_name ON llm_usage (model_name)", + "CREATE INDEX IF NOT EXISTS ix_llm_usage_timestamp ON llm_usage (timestamp)", + "CREATE INDEX IF NOT EXISTS ix_mai_knowledge_category_id ON mai_knowledge (category_id)", + "CREATE INDEX IF NOT EXISTS ix_mai_knowledge_created_at ON mai_knowledge (created_at)", + "CREATE INDEX IF NOT EXISTS ix_mai_knowledge_knowledge_id ON mai_knowledge (knowledge_id)", + "CREATE INDEX IF NOT EXISTS ix_mai_knowledge_normalized_content ON mai_knowledge (normalized_content)", + "CREATE INDEX IF NOT EXISTS ix_mai_messages_group_id ON mai_messages (group_id)", + "CREATE INDEX IF NOT EXISTS ix_mai_messages_message_id ON mai_messages (message_id)", + "CREATE INDEX IF NOT EXISTS ix_mai_messages_platform ON mai_messages (platform)", + "CREATE INDEX IF NOT EXISTS ix_mai_messages_session_id ON mai_messages (session_id)", + "CREATE INDEX IF NOT EXISTS ix_mai_messages_user_id ON mai_messages (user_id)", + "CREATE INDEX IF NOT EXISTS ix_mai_messages_user_nickname ON mai_messages (user_nickname)", + "CREATE INDEX IF NOT EXISTS ix_online_time_timestamp ON online_time (timestamp)", + "CREATE UNIQUE INDEX IF NOT EXISTS ix_person_info_person_id ON person_info (person_id)", + "CREATE INDEX IF NOT EXISTS ix_person_info_platform ON person_info (platform)", + "CREATE INDEX IF NOT EXISTS ix_person_info_user_id ON person_info (user_id)", + "CREATE INDEX IF NOT EXISTS ix_person_info_user_nickname ON person_info (user_nickname)", + "CREATE INDEX IF NOT EXISTS ix_thinking_questions_created_timestamp ON thinking_questions (created_timestamp)", + "CREATE INDEX IF NOT EXISTS ix_thinking_questions_updated_timestamp ON thinking_questions (updated_timestamp)", + "CREATE INDEX IF NOT EXISTS ix_tool_records_session_id ON tool_records (session_id)", + "CREATE INDEX IF NOT EXISTS ix_tool_records_timestamp ON tool_records (timestamp)", + "CREATE INDEX IF NOT EXISTS ix_tool_records_tool_id ON tool_records (tool_id)", + "CREATE INDEX IF NOT EXISTS ix_tool_records_tool_name ON tool_records (tool_name)", +) + + +def create_frozen_v2_schema(connection: Connection) -> None: + """创建冻结的 v2 schema。 + + Args: + connection: 当前数据库连接。 + """ + + for statement in _V2_TABLE_STATEMENTS: + connection.exec_driver_sql(statement) + + for statement in _V2_INDEX_STATEMENTS: + connection.exec_driver_sql(statement) diff --git a/src/common/database/migrations/legacy_v1_to_v2.py b/src/common/database/migrations/legacy_v1_to_v2.py index 1268b21c..c7cc9cb8 100644 --- a/src/common/database/migrations/legacy_v1_to_v2.py +++ b/src/common/database/migrations/legacy_v1_to_v2.py @@ -1,4 +1,4 @@ -"""旧版 ``0.x`` 数据库升级到最新 schema 的迁移逻辑。""" +"""旧版 ``0.x`` 数据库升级到 v2 schema 的迁移逻辑。""" from __future__ import annotations @@ -7,15 +7,16 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast +import json + +import msgpack from sqlalchemy import text from sqlalchemy.engine import Connection -import json -import msgpack - from src.common.logger import get_logger from .exceptions import DatabaseMigrationExecutionError +from .frozen_v2_schema import create_frozen_v2_schema from .models import DatabaseSchemaSnapshot, MigrationExecutionContext from .schema import SQLiteSchemaInspector @@ -52,19 +53,15 @@ class LegacyTableData: def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None: - """执行旧版 ``0.x`` 数据库到最新 schema 的迁移。 + """执行旧版 ``0.x`` 数据库到 v2 schema 的迁移。 Args: context: 当前迁移步骤执行上下文。 """ - from sqlmodel import SQLModel - - import src.common.database.database_model # noqa: F401 - schema_inspector = SQLiteSchemaInspector() snapshot = schema_inspector.inspect(context.connection) _rename_legacy_v1_tables(context.connection, snapshot) - SQLModel.metadata.create_all(context.connection) + create_frozen_v2_schema(context.connection) table_migration_jobs: List[Tuple[str, Callable[[MigrationExecutionContext], int]]] = [ ("chat_sessions", _migrate_chat_sessions), @@ -794,8 +791,6 @@ def _migrate_images(context: MigrationExecutionContext) -> int: if full_path and dedupe_key not in existing_keys: migrated_description = _normalize_required_text(row.get("description")) migrated_emotion = _normalize_optional_text(row.get("emotion")) - if not migrated_description and migrated_emotion: - migrated_description = migrated_emotion connection.execute( insert_sql, { @@ -803,7 +798,7 @@ def _migrate_images(context: MigrationExecutionContext) -> int: "description": migrated_description, "full_path": full_path, "image_type": "EMOJI", - "emotion": None, + "emotion": migrated_emotion, "query_count": _normalize_int(row.get("query_count"), default=0), "is_registered": _normalize_bool(row.get("is_registered"), default=False), "is_banned": _normalize_bool(row.get("is_banned"), default=False), diff --git a/src/common/database/migrations/v2_to_v3.py b/src/common/database/migrations/v2_to_v3.py new file mode 100644 index 00000000..808367b6 --- /dev/null +++ b/src/common/database/migrations/v2_to_v3.py @@ -0,0 +1,269 @@ +"""v2 schema 升级到 v3 的迁移逻辑。""" + +from typing import Any, Dict, List + +from sqlalchemy import text +from sqlalchemy.engine import Connection + +from src.common.logger import get_logger + +from .exceptions import DatabaseMigrationExecutionError +from .models import MigrationExecutionContext +from .schema import SQLiteSchemaInspector + +logger = get_logger("database_migration") + +_V2_IMAGES_BACKUP_TABLE = "__v2_images_backup" +_V3_IMAGES_CREATE_SQL = """ +CREATE TABLE images ( + id INTEGER NOT NULL, + image_hash VARCHAR(255) NOT NULL, + description VARCHAR NOT NULL, + full_path VARCHAR(1024) NOT NULL, + image_type VARCHAR(5), + query_count INTEGER NOT NULL, + is_registered BOOLEAN NOT NULL, + is_banned BOOLEAN NOT NULL, + no_file_flag BOOLEAN NOT NULL, + record_time DATETIME, + register_time DATETIME, + last_used_time DATETIME, + vlm_processed BOOLEAN NOT NULL, + PRIMARY KEY (id) +) +""" +_V3_IMAGES_INDEX_STATEMENTS = ( + "CREATE INDEX ix_images_image_hash ON images (image_hash)", + "CREATE INDEX ix_images_record_time ON images (record_time)", +) + + +def migrate_v2_to_v3(context: MigrationExecutionContext) -> None: + """执行 v2 到 v3 的 schema 迁移。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + + connection = context.connection + total_records = ( + _count_table_rows(connection, "action_records") + + _count_table_rows(connection, "thinking_questions") + + _count_table_rows(connection, "images") + ) + context.start_progress( + total_tables=3, + total_records=total_records, + description="v2 -> v3 迁移进度", + table_unit_name="表", + record_unit_name="记录", + ) + + migrated_tool_records = _migrate_action_records_to_tool_records(connection) + action_record_count = _count_table_rows(connection, "action_records") + _drop_table_if_exists(connection, "action_records") + context.advance_progress( + records=action_record_count, + completed_tables=1, + item_name="action_records", + ) + + thinking_question_count = _count_table_rows(connection, "thinking_questions") + _drop_table_if_exists(connection, "thinking_questions") + context.advance_progress( + records=thinking_question_count, + completed_tables=1, + item_name="thinking_questions", + ) + + migrated_image_rows = _migrate_images_table_to_v3(connection) + context.advance_progress( + records=migrated_image_rows, + completed_tables=1, + item_name="images", + ) + + logger.info( + "v2 -> v3 数据库迁移完成: " + f"tool_records补迁移={migrated_tool_records}," + f"images重建={migrated_image_rows}" + ) + + +def _count_table_rows(connection: Connection, table_name: str) -> int: + """统计表记录数,不存在时返回 0。""" + + schema_inspector = SQLiteSchemaInspector() + if not schema_inspector.table_exists(connection, table_name): + return 0 + row = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"')).first() + return int(row[0]) if row else 0 + + +def _drop_table_if_exists(connection: Connection, table_name: str) -> None: + """删除指定表,不存在时静默跳过。""" + + connection.exec_driver_sql(f'DROP TABLE IF EXISTS "{table_name}"') + + +def _migrate_action_records_to_tool_records(connection: Connection) -> int: + """把 v2 中残留的 ``action_records`` 数据转存到 ``tool_records``。""" + + schema_inspector = SQLiteSchemaInspector() + if not schema_inspector.table_exists(connection, "action_records"): + return 0 + + inserted_count = _count_table_rows(connection, "action_records") + connection.execute( + text( + """ + INSERT INTO tool_records ( + tool_id, + timestamp, + session_id, + tool_name, + tool_reasoning, + tool_data, + tool_builtin_prompt, + tool_display_prompt + ) + SELECT + action_id, + timestamp, + session_id, + action_name, + action_reasoning, + action_data, + action_builtin_prompt, + action_display_prompt + FROM action_records + WHERE NOT EXISTS ( + SELECT 1 + FROM tool_records + WHERE tool_records.tool_id = action_records.action_id + ) + """ + ) + ) + return inserted_count + + +def _migrate_images_table_to_v3(connection: Connection) -> int: + """重建 ``images`` 表并移除 ``emotion`` 列。""" + + schema_inspector = SQLiteSchemaInspector() + if not schema_inspector.table_exists(connection, "images"): + return 0 + if not schema_inspector.get_table_schema(connection, "images").has_column("emotion"): + return _count_table_rows(connection, "images") + if schema_inspector.table_exists(connection, _V2_IMAGES_BACKUP_TABLE): + raise DatabaseMigrationExecutionError( + f"检测到残留备份表 {_V2_IMAGES_BACKUP_TABLE},无法安全执行 v2 -> v3 images 迁移。" + ) + + connection.exec_driver_sql(f'ALTER TABLE "images" RENAME TO "{_V2_IMAGES_BACKUP_TABLE}"') + connection.exec_driver_sql(_V3_IMAGES_CREATE_SQL) + + legacy_rows = connection.execute( + text(f'SELECT * FROM "{_V2_IMAGES_BACKUP_TABLE}" ORDER BY id') + ).mappings().all() + insert_sql = text( + """ + INSERT INTO images ( + id, + image_hash, + description, + full_path, + image_type, + query_count, + is_registered, + is_banned, + no_file_flag, + record_time, + register_time, + last_used_time, + vlm_processed + ) VALUES ( + :id, + :image_hash, + :description, + :full_path, + :image_type, + :query_count, + :is_registered, + :is_banned, + :no_file_flag, + :record_time, + :register_time, + :last_used_time, + :vlm_processed + ) + """ + ) + + for row in legacy_rows: + payload: Dict[str, Any] = { + "id": row.get("id"), + "image_hash": str(row.get("image_hash") or "").strip(), + "description": _migrate_v3_emoji_description(row), + "full_path": str(row.get("full_path") or "").strip(), + "image_type": row.get("image_type"), + "query_count": int(row.get("query_count") or 0), + "is_registered": bool(row.get("is_registered")), + "is_banned": bool(row.get("is_banned")), + "no_file_flag": bool(row.get("no_file_flag")), + "record_time": row.get("record_time"), + "register_time": row.get("register_time"), + "last_used_time": row.get("last_used_time"), + "vlm_processed": bool(row.get("vlm_processed")), + } + connection.execute(insert_sql, payload) + + connection.exec_driver_sql(f'DROP TABLE "{_V2_IMAGES_BACKUP_TABLE}"') + for statement in _V3_IMAGES_INDEX_STATEMENTS: + connection.exec_driver_sql(statement) + return len(legacy_rows) + + +def _migrate_v3_emoji_description(row: Dict[str, Any]) -> str: + """为 v3 统一 emoji 描述字段语义。 + + v3 中 `description` 对 emoji 统一承担“标签列表”的职责,因此迁移时: + 1. 若旧 `emotion` 非空,优先将其规范化后写入 `description`; + 2. 否则保留并规范化当前 `description`; + 3. 非 emoji 图片保持原描述不变。 + """ + + image_type = str(row.get("image_type") or "").strip().upper() + current_description = str(row.get("description") or "").strip() + current_emotion = str(row.get("emotion") or "").strip() + if image_type != "EMOJI": + return current_description + + normalized_tags = _normalize_emoji_tag_text(current_emotion or current_description) + if normalized_tags: + return ",".join(normalized_tags) + return current_description + + +def _normalize_emoji_tag_text(raw_value: Any) -> List[str]: + """将 emoji 标签文本转换为去重后的标签列表。""" + + normalized_text = str(raw_value or "").strip() + if not normalized_text: + return [] + + separators = [",", ",", "、", ";", ";", "\n", "\r", "\t"] + for separator in separators[1:]: + normalized_text = normalized_text.replace(separator, separators[0]) + + deduped_tags: List[str] = [] + seen_tags: set[str] = set() + for part in normalized_text.split(separators[0]): + normalized_part = part.strip() + lowered_part = normalized_part.lower() + if not normalized_part or lowered_part in seen_tags: + continue + seen_tags.add(lowered_part) + deduped_tags.append(normalized_part) + return deduped_tags diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 2554ebd1..9bbc0267 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -1,16 +1,14 @@ -import contextlib -import time -import json import asyncio -from datetime import datetime -from typing import List, Dict, Any, Optional, Tuple, Callable +import contextlib +import json +import time +from typing import Any, Callable, Dict, List, Optional, Tuple + from src.common.logger import get_logger from src.config.config import global_config from src.prompt.prompt_manager import prompt_manager from src.services import llm_service as llm_api -from sqlmodel import select, col -from src.common.database.database import get_db_session -from src.common.database.database_model import ThinkingQuestion + from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.chat.message_receive.chat_manager import chat_manager as _chat_manager @@ -18,35 +16,6 @@ from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon logger = get_logger("memory_retrieval") -THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 36000 # 未找到答案记录保留时长 -THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 3000 # 清理频率 -_last_not_found_cleanup_ts: float = 0.0 - - -def _cleanup_stale_not_found_thinking_back() -> None: - """定期清理过期的未找到答案记录""" - global _last_not_found_cleanup_ts - - now = time.time() - if now - _last_not_found_cleanup_ts < THINKING_BACK_CLEANUP_INTERVAL_SECONDS: - return - - threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS - try: - with get_db_session() as session: - statement = select(ThinkingQuestion).where( - col(ThinkingQuestion.found_answer).is_(False) - & (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time)) - ) - records = session.exec(statement).all() - for record in records: - session.delete(record) - if records: - logger.info(f"清理过期的未找到答案thinking_question记录 {len(records)} 条") - _last_not_found_cleanup_ts = now - except Exception as e: - logger.error(f"清理未找到答案的thinking_back记录失败: {e}") - def init_memory_retrieval_sys(): """初始化记忆检索相关工具""" @@ -766,141 +735,6 @@ async def _react_agent_solve_question( return False, "", thinking_steps, is_timeout - -def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) -> str: - """获取最近一段时间内的查询历史(用于避免重复查询) - - Args: - chat_id: 聊天ID - time_window_seconds: 时间窗口(秒),默认10分钟 - - Returns: - str: 格式化的查询历史字符串 - """ - try: - _current_time = time.time() - - with get_db_session() as session: - statement = ( - select(ThinkingQuestion) - .where(col(ThinkingQuestion.context) == chat_id) - .order_by(col(ThinkingQuestion.updated_timestamp).desc()) - .limit(5) - ) - records = session.exec(statement).all() - - if not records: - return "" - - history_lines = ["最近已查询的问题和结果:"] - - for record in records: - status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案" - answer_preview = "" - # 只有找到答案时才显示答案内容 - if record.found_answer and record.answer: - # 截取答案前100字符 - answer_preview = record.answer[:100] - if len(record.answer) > 100: - answer_preview += "..." - - history_lines.extend([f"- 问题:{record.question}", f" 状态:{status}"]) - if answer_preview: - history_lines.append(f" 答案:{answer_preview}") - history_lines.append("") # 空行分隔 - - return "\n".join(history_lines) - - except Exception as e: - logger.error(f"获取查询历史失败: {e}") - return "" - - -def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) -> List[str]: - """获取最近一段时间内已找到答案的查询记录(用于返回给 replyer) - - Args: - chat_id: 聊天ID - time_window_seconds: 时间窗口(秒),默认10分钟 - - Returns: - List[str]: 格式化的答案列表,每个元素格式为 "问题:xxx\n答案:xxx" - """ - try: - _current_time = time.time() - - # 查询最近时间窗口内已找到答案的记录,按更新时间倒序 - with get_db_session() as session: - statement = ( - select(ThinkingQuestion) - .where(col(ThinkingQuestion.context) == chat_id) - .where(col(ThinkingQuestion.found_answer)) - .where(col(ThinkingQuestion.answer).is_not(None)) - .where(col(ThinkingQuestion.answer) != "") - .order_by(col(ThinkingQuestion.updated_timestamp).desc()) - .limit(3) - ) - records = session.exec(statement).all() - - if not records: - return [] - - return [f"问题:{record.question}\n答案:{record.answer}" for record in records if record.answer] - - except Exception as e: - logger.error(f"获取最近已找到答案的记录失败: {e}") - return [] - - -def _store_thinking_back( - chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]] -) -> None: - """存储或更新思考过程到数据库(如果已存在则更新,否则创建) - - Args: - chat_id: 聊天ID - question: 问题 - context: 上下文信息 - found_answer: 是否找到答案 - answer: 答案内容 - thinking_steps: 思考步骤列表 - """ - try: - now = time.time() - - # 先查询是否已存在相同chat_id和问题的记录 - with get_db_session() as session: - statement = ( - select(ThinkingQuestion) - .where(col(ThinkingQuestion.context) == chat_id) - .where(col(ThinkingQuestion.question) == question) - .order_by(col(ThinkingQuestion.updated_timestamp).desc()) - .limit(1) - ) - if record := session.exec(statement).first(): - record.context = context - record.found_answer = found_answer - record.answer = answer - record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False) - record.updated_timestamp = datetime.fromtimestamp(now) - session.add(record) - logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...") - return - - new_record = ThinkingQuestion( - question=question, - context=chat_id, - found_answer=found_answer, - answer=answer, - thinking_steps=json.dumps(thinking_steps, ensure_ascii=False), - created_timestamp=datetime.fromtimestamp(now), - updated_timestamp=datetime.fromtimestamp(now), - ) - session.add(new_record) - except Exception as e: - logger.error(f"存储思考过程失败: {e}") - - async def _process_memory_retrieval( chat_id: str, context: str, @@ -920,8 +754,6 @@ async def _process_memory_retrieval( Returns: Optional[str]: 如果找到答案,返回答案内容,否则返回None """ - _cleanup_stale_not_found_thinking_back() - question_initial_info = initial_info or "" # 直接使用ReAct Agent进行记忆检索 diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py index 440677a2..f3fb27b2 100644 --- a/src/plugin_runtime/capabilities/data.py +++ b/src/plugin_runtime/capabilities/data.py @@ -585,10 +585,8 @@ class RuntimeDataCapabilityMixin: if not ImageUtils.base64_to_image(emoji_base64, str(temp_file_path)): return {"success": False, "message": "无法保存图片文件", "description": None, "emotions": None, "replaced": None, "hash": None} - register_success = await emoji_manager.register_emoji_by_filename(temp_file_path) - if not register_success: - if temp_file_path.exists(): - temp_file_path.unlink(missing_ok=True) + register_status = await emoji_manager.register_emoji_by_filename(temp_file_path) + if register_status == "failed": return { "success": False, "message": "表情包注册失败,可能因为重复、格式不支持或审核未通过", @@ -597,6 +595,15 @@ class RuntimeDataCapabilityMixin: "replaced": None, "hash": None, } + if register_status == "skipped": + return { + "success": True, + "message": "表情包已注册,已跳过本次注册", + "description": None, + "emotions": None, + "replaced": False, + "hash": None, + } count_after = len(emoji_manager.emojis) replaced = count_after <= count_before diff --git a/src/webui/routers/emoji/routes.py b/src/webui/routers/emoji/routes.py index 0b836786..7243bc31 100644 --- a/src/webui/routers/emoji/routes.py +++ b/src/webui/routers/emoji/routes.py @@ -40,7 +40,7 @@ from .schemas import ( emoji_to_response, ) from .support import ( - EMOJI_REGISTERED_DIR, + EMOJI_DIR, THUMBNAIL_CACHE_DIR, background_generate_thumbnail, cleanup_orphaned_thumbnails, @@ -326,7 +326,7 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N if not emoji: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") if emoji.is_registered: - raise HTTPException(status_code=400, detail="该表情包已经注册") + return EmojiUpdateResponse(success=True, message="??????????", data=emoji_to_response(emoji)) emoji.is_registered = True emoji.is_banned = False @@ -549,16 +549,16 @@ async def upload_emoji( if existing_emoji := session.exec(existing_statement).first(): raise HTTPException(status_code=409, detail=f"已存在相同的表情包 (ID: {existing_emoji.id})") - os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True) + os.makedirs(EMOJI_DIR, exist_ok=True) timestamp = int(datetime.now().timestamp()) filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}" - full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) + full_path = os.path.join(EMOJI_DIR, filename) counter = 1 while os.path.exists(full_path): filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}" - full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) + full_path = os.path.join(EMOJI_DIR, filename) counter += 1 with open(full_path, "wb") as output_file: @@ -618,7 +618,7 @@ async def batch_upload_emoji( } allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"] - os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True) + os.makedirs(EMOJI_DIR, exist_ok=True) for file in files: try: @@ -665,12 +665,12 @@ async def batch_upload_emoji( timestamp = int(datetime.now().timestamp()) filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}" - full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) + full_path = os.path.join(EMOJI_DIR, filename) counter = 1 while os.path.exists(full_path): filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}" - full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) + full_path = os.path.join(EMOJI_DIR, filename) counter += 1 with open(full_path, "wb") as output_file: diff --git a/src/webui/routers/emoji/support.py b/src/webui/routers/emoji/support.py index 41b0be97..4b715997 100644 --- a/src/webui/routers/emoji/support.py +++ b/src/webui/routers/emoji/support.py @@ -16,7 +16,8 @@ logger = get_logger("webui.emoji") THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails") THUMBNAIL_SIZE = (200, 200) THUMBNAIL_QUALITY = 80 -EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed") +EMOJI_REGISTERED_DIR = os.path.join("data", "emoji") +EMOJI_DIR = EMOJI_REGISTERED_DIR _thumbnail_locks: Dict[str, threading.Lock] = {} _locks_lock = threading.Lock()