diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 28054717..35c1e513 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -28,10 +28,7 @@ from src.services import ( message_service as message_api, database_service as database_api, ) -from src.chat.utils.chat_message_builder import ( - build_readable_messages_with_id, - get_raw_msg_before_timestamp_with_chat, -) +from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat if TYPE_CHECKING: from src.common.data_models.database_data_model import DatabaseMessages @@ -275,7 +272,7 @@ class BrainChatting: # 一次思考迭代:Think - Act - Observe # 获取聊天上下文 - message_list_before_now = get_raw_msg_before_timestamp_with_chat( + message_list_before_now = get_messages_before_time_in_chat( chat_id=self.stream_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.6), diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py index 281830b0..2aa9b293 100644 --- a/src/chat/brain_chat/brain_planner.py +++ b/src/chat/brain_chat/brain_planner.py @@ -14,11 +14,11 @@ from src.common.logger import get_logger from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.common.data_models.info_data_model import ActionPlannerInfo from src.prompt.prompt_manager import prompt_manager -from src.chat.utils.chat_message_builder import ( +from src.services.message_service import ( build_readable_actions, - get_actions_by_timestamp_with_chat, build_readable_messages_with_id, - get_raw_msg_before_timestamp_with_chat, + get_actions_by_timestamp_with_chat, + get_messages_before_time_in_chat, ) from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.planner_actions.action_manager import ActionManager @@ -163,7 +163,7 @@ class BrainPlanner: plan_start = time.perf_counter() # 获取聊天上下文 - message_list_before_now = get_raw_msg_before_timestamp_with_chat( + message_list_before_now = get_messages_before_time_in_chat( chat_id=self.chat_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.6), diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index fa8635ee..c7d64c1f 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -6,7 +6,7 @@ from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager from src.chat.planner_actions.action_manager import ActionManager -from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages +from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat from src.core.types import ActionActivationType, ActionInfo from src.core.announcement_manager import global_announcement_manager @@ -51,7 +51,7 @@ class ActionModifier: self.action_manager.restore_actions() all_actions = self.action_manager.get_using_actions() - message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_half = get_messages_before_time_in_chat( chat_id=self.chat_stream.stream_id, timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 10), diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index e7b013e6..b6b68fa3 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -15,17 +15,17 @@ from src.common.logger import get_logger from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.common.data_models.info_data_model import ActionPlannerInfo from src.prompt.prompt_manager import prompt_manager -from src.chat.utils.chat_message_builder import ( +from src.services.message_service import ( build_readable_messages_with_id, - get_raw_msg_before_timestamp_with_chat, replace_user_references, + get_messages_before_time_in_chat, + translate_pid_to_description, ) from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.planner_actions.action_manager import ActionManager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.core.types import ActionActivationType, ActionInfo, ComponentType from src.core.component_registry import component_registry -from src.services.message_service import translate_pid_to_description from src.person_info.person_info import Person if TYPE_CHECKING: @@ -389,7 +389,7 @@ class ActionPlanner: plan_start = time.perf_counter() # 获取聊天上下文 - message_list_before_now = get_raw_msg_before_timestamp_with_chat( + message_list_before_now = get_messages_before_time_in_chat( chat_id=self.chat_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.6), diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 4bfc78fc..5982e0c3 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -22,7 +22,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.prompt.prompt_manager import prompt_manager from src.services.message_service import ( build_readable_messages, - get_raw_msg_before_timestamp_with_chat, + get_messages_before_time_in_chat, replace_user_references, translate_pid_to_description, ) @@ -809,14 +809,14 @@ class DefaultReplyer: # 将[picid:xxx]替换为具体的图片描述 target = self._replace_picids_with_descriptions(target) - message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_long = get_messages_before_time_in_chat( chat_id=chat_id, timestamp=reply_time_point, limit=global_config.chat.max_context_size * 1, filter_intercept_message_level=1, ) - message_list_before_short = get_raw_msg_before_timestamp_with_chat( + message_list_before_short = get_messages_before_time_in_chat( chat_id=chat_id, timestamp=reply_time_point, limit=int(global_config.chat.max_context_size * 0.33), @@ -1022,7 +1022,7 @@ class DefaultReplyer: # 将[picid:xxx]替换为具体的图片描述 target = self._replace_picids_with_descriptions(target) - message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_half = get_messages_before_time_in_chat( chat_id=chat_id, timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 15), diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 13d73018..cadb734d 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -23,7 +23,7 @@ from src.prompt.prompt_manager import prompt_manager from src.chat.utils.common_utils import TempMethodsExpression from src.services.message_service import ( build_readable_messages, - get_raw_msg_before_timestamp_with_chat, + get_messages_before_time_in_chat, replace_user_references, translate_pid_to_description, ) @@ -650,7 +650,7 @@ class PrivateReplyer: # 将[picid:xxx]替换为具体的图片描述 target = self._replace_picids_with_descriptions(target) - message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_long = get_messages_before_time_in_chat( chat_id=chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size, @@ -666,7 +666,7 @@ class PrivateReplyer: long_time_notice=True, ) - message_list_before_short = get_raw_msg_before_timestamp_with_chat( + message_list_before_short = get_messages_before_time_in_chat( chat_id=chat_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.33), @@ -857,7 +857,7 @@ class PrivateReplyer: # 将[picid:xxx]替换为具体的图片描述 target = self._replace_picids_with_descriptions(target) - message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_half = get_messages_before_time_in_chat( chat_id=chat_id, timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 15), diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py index e5b183cd..a4f34b00 100644 --- a/src/plugin_runtime/capabilities/data.py +++ b/src/plugin_runtime/capabilities/data.py @@ -1,12 +1,38 @@ +import random +from pathlib import Path from typing import Any, Dict, List, Optional +import time + from src.chat.message_receive.chat_manager import BotChatSession, chat_manager +from src.common.data_models.image_data_model import MaiEmoji from src.common.logger import get_logger +from src.common.utils.utils_image import ImageUtils logger = get_logger("plugin_runtime.integration") class RuntimeDataCapabilityMixin: + @staticmethod + def _serialize_emoji_payload(emoji: MaiEmoji) -> Optional[Dict[str, str]]: + emoji_base64 = ImageUtils.image_path_to_base64(str(emoji.full_path)) + if not emoji_base64: + return None + + matched_emotion = emoji.emotion[0] if emoji.emotion else "" + return { + "base64": emoji_base64, + "description": emoji.description, + "emotion": matched_emotion, + } + + @staticmethod + def _build_emoji_temp_path() -> Path: + from src.chat.emoji_system.emoji_manager import EMOJI_DIR + + EMOJI_DIR.mkdir(parents=True, exist_ok=True) + return EMOJI_DIR / f"emoji_cap_{int(time.time() * 1000000)}.png" + async def _cap_database_query(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: from src.services import database_service as database_api @@ -338,7 +364,7 @@ class RuntimeDataCapabilityMixin: limit=args.get("limit", 0), ) - readable = message_api.build_readable_messages_to_str( + readable = message_api.build_readable_messages( messages=messages, replace_bot_name=args.get("replace_bot_name", True), timestamp_mode=args.get("timestamp_mode", "relative"), @@ -397,101 +423,173 @@ class RuntimeDataCapabilityMixin: return {"success": False, "error": str(e)} async def _cap_emoji_get_by_description(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.services import emoji_service as emoji_api + from src.chat.emoji_system.emoji_manager import emoji_manager description: str = args.get("description", "") if not description: return {"success": False, "error": "缺少必要参数 description"} try: - result = await emoji_api.get_by_description(description=description) - if result is None: + emoji = await emoji_manager.get_emoji_for_emotion(description) + if emoji is None: + return {"success": True, "emoji": None} + serialized = self._serialize_emoji_payload(emoji) + if serialized is None: return {"success": True, "emoji": None} - emoji_base64, emoji_desc, matched_emotion = result return { "success": True, - "emoji": { - "base64": emoji_base64, - "description": emoji_desc, - "emotion": matched_emotion, - }, + "emoji": serialized, } except Exception as e: logger.error(f"[cap.emoji.get_by_description] 执行失败: {e}", exc_info=True) return {"success": False, "error": str(e)} async def _cap_emoji_get_random(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.services import emoji_service as emoji_api + from src.chat.emoji_system.emoji_manager import emoji_manager count: int = args.get("count", 1) try: - results = await emoji_api.get_random(count=count) - emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] + if count < 0: + return {"success": False, "error": "count 不能为负数"} + + emojis_source = list(emoji_manager.emojis) + if count == 0 or not emojis_source: + return {"success": True, "emojis": []} + + selected = random.sample(emojis_source, min(count, len(emojis_source))) + emojis: List[Dict[str, str]] = [] + for emoji in selected: + emoji_manager.update_emoji_usage(emoji) + serialized = self._serialize_emoji_payload(emoji) + if serialized is not None: + if not serialized["emotion"]: + serialized["emotion"] = "随机表情" + emojis.append(serialized) return {"success": True, "emojis": emojis} except Exception as e: logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True) return {"success": False, "error": str(e)} async def _cap_emoji_get_count(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.services import emoji_service as emoji_api - try: - return {"success": True, "count": emoji_api.get_count()} + from src.chat.emoji_system.emoji_manager import emoji_manager + + return {"success": True, "count": len(emoji_manager.emojis)} except Exception as e: logger.error(f"[cap.emoji.get_count] 执行失败: {e}", exc_info=True) return {"success": False, "error": str(e)} async def _cap_emoji_get_emotions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.services import emoji_service as emoji_api - try: - return {"success": True, "emotions": emoji_api.get_emotions()} + from src.chat.emoji_system.emoji_manager import emoji_manager + + emotions = sorted({emotion for emoji in emoji_manager.emojis for emotion in emoji.emotion}) + return {"success": True, "emotions": emotions} except Exception as e: logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True) return {"success": False, "error": str(e)} async def _cap_emoji_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.services import emoji_service as emoji_api - try: - results = await emoji_api.get_all() - emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else [] + from src.chat.emoji_system.emoji_manager import emoji_manager + + emojis = [] + for emoji in emoji_manager.emojis: + serialized = self._serialize_emoji_payload(emoji) + if serialized is not None: + if not serialized["emotion"]: + serialized["emotion"] = "随机表情" + emojis.append(serialized) return {"success": True, "emojis": emojis} except Exception as e: logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True) return {"success": False, "error": str(e)} async def _cap_emoji_get_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.services import emoji_service as emoji_api - try: - return {"success": True, "info": emoji_api.get_info()} + from src.chat.emoji_system.emoji_manager import emoji_manager + from src.config.config import global_config + + current_count = len(emoji_manager.emojis) + return { + "success": True, + "info": { + "current_count": current_count, + "max_count": global_config.emoji.max_reg_num, + "available_emojis": current_count, + }, + } except Exception as e: logger.error(f"[cap.emoji.get_info] 执行失败: {e}", exc_info=True) return {"success": False, "error": str(e)} async def _cap_emoji_register(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.services import emoji_service as emoji_api + from src.chat.emoji_system.emoji_manager import emoji_manager emoji_base64: str = args.get("emoji_base64", "") if not emoji_base64: return {"success": False, "error": "缺少必要参数 emoji_base64"} try: - return await emoji_api.register_emoji(emoji_base64) + count_before = len(emoji_manager.emojis) + temp_file_path = self._build_emoji_temp_path() + if not ImageUtils.base64_to_image(emoji_base64, str(temp_file_path)): + return {"success": False, "message": "无法保存图片文件", "description": None, "emotions": None, "replaced": None, "hash": None} + + register_success = await emoji_manager.register_emoji_by_filename(temp_file_path) + if not register_success: + if temp_file_path.exists(): + temp_file_path.unlink(missing_ok=True) + return { + "success": False, + "message": "表情包注册失败,可能因为重复、格式不支持或审核未通过", + "description": None, + "emotions": None, + "replaced": None, + "hash": None, + } + + count_after = len(emoji_manager.emojis) + replaced = count_after <= count_before + new_emoji = next( + ( + item + for item in reversed(emoji_manager.emojis) + if temp_file_path.name == item.file_name or temp_file_path.name in str(item.full_path) + ), + None, + ) + return { + "success": True, + "message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}", + "description": None if new_emoji is None else new_emoji.description, + "emotions": None if new_emoji is None else new_emoji.emotion, + "replaced": replaced, + "hash": None if new_emoji is None else new_emoji.file_hash, + } except Exception as e: logger.error(f"[cap.emoji.register] 执行失败: {e}", exc_info=True) return {"success": False, "error": str(e)} async def _cap_emoji_delete(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: - from src.services import emoji_service as emoji_api + from src.chat.emoji_system.emoji_manager import emoji_manager emoji_hash: str = args.get("emoji_hash", "") if not emoji_hash: return {"success": False, "error": "缺少必要参数 emoji_hash"} try: - return await emoji_api.delete_emoji(emoji_hash) + emoji = emoji_manager.get_emoji_by_hash(emoji_hash) + if emoji is None: + return {"success": False, "message": f"未找到表情包: {emoji_hash}", "hash": emoji_hash} + + success = emoji_manager.delete_emoji(emoji, not bool(emoji.description and emoji.description.strip())) + if not success: + return {"success": False, "message": f"删除表情包失败: {emoji_hash}", "hash": emoji_hash} + + emoji_manager.emojis = [item for item in emoji_manager.emojis if item.file_hash != emoji_hash] + emoji_manager._emoji_num = len(emoji_manager.emojis) + return {"success": True, "message": f"成功删除表情包: {emoji_hash}", "hash": emoji_hash} except Exception as e: logger.error(f"[cap.emoji.delete] 执行失败: {e}", exc_info=True) return {"success": False, "error": str(e)} diff --git a/src/services/chat_service.py b/src/services/chat_service.py deleted file mode 100644 index effb0b93..00000000 --- a/src/services/chat_service.py +++ /dev/null @@ -1,171 +0,0 @@ -""" -聊天服务模块 - -提供聊天信息查询和管理的核心功能。 -""" - -from typing import Any, Callable, Dict, List, Optional -from enum import Enum - -from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager -from src.common.logger import get_logger - -logger = get_logger("chat_service") - - -class SpecialTypes(Enum): - """特殊枚举类型""" - - ALL_PLATFORMS = "all_platforms" - - -class ChatManager: - """聊天管理器 - 负责聊天信息的查询和管理""" - - @staticmethod - def _validate_platform(platform: Optional[str] | SpecialTypes) -> None: - if not isinstance(platform, (str, SpecialTypes)): - raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") - - @staticmethod - def _match_platform(chat_stream: BotChatSession, platform: Optional[str] | SpecialTypes) -> bool: - return platform == SpecialTypes.ALL_PLATFORMS or chat_stream.platform == platform - - @staticmethod - def _get_streams( - platform: Optional[str] | SpecialTypes = "qq", is_group_session: Optional[bool] = None - ) -> List[BotChatSession]: - ChatManager._validate_platform(platform) - - try: - streams = [ - stream - for stream in _chat_manager.sessions.values() - if ChatManager._match_platform(stream, platform) - and (is_group_session is None or stream.is_group_session == is_group_session) - ] - return streams - except Exception as e: - logger.error(f"[ChatService] 获取聊天流失败: {e}") - return [] - - @staticmethod - def _find_stream( - predicate: Callable[[BotChatSession], bool], - platform: Optional[str] | SpecialTypes = "qq", - ) -> Optional[BotChatSession]: - for stream in ChatManager._get_streams(platform=platform): - if predicate(stream): - return stream - return None - - @staticmethod - def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - streams = ChatManager._get_streams(platform=platform) - logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的聊天流") - return streams - - @staticmethod - def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - streams = ChatManager._get_streams(platform=platform, is_group_session=True) - logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的群聊流") - return streams - - @staticmethod - def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: - streams = ChatManager._get_streams(platform=platform, is_group_session=False) - logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的私聊流") - return streams - - @staticmethod - def get_group_stream_by_group_id( - group_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast - if not isinstance(group_id, str): - raise TypeError("group_id 必须是字符串类型") - ChatManager._validate_platform(platform) - if not group_id: - raise ValueError("group_id 不能为空") - try: - stream = ChatManager._find_stream( - lambda item: item.is_group_session and str(item.group_id) == str(group_id), - platform=platform, - ) - if stream is not None: - logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流") - return stream - logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流") - except Exception as e: - logger.error(f"[ChatService] 查找群聊流失败: {e}") - return None - - @staticmethod - def get_private_stream_by_user_id( - user_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast - if not isinstance(user_id, str): - raise TypeError("user_id 必须是字符串类型") - ChatManager._validate_platform(platform) - if not user_id: - raise ValueError("user_id 不能为空") - try: - stream = ChatManager._find_stream( - lambda item: (not item.is_group_session) and str(item.user_id) == str(user_id), - platform=platform, - ) - if stream is not None: - logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流") - return stream - logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流") - except Exception as e: - logger.error(f"[ChatService] 查找私聊流失败: {e}") - return None - - @staticmethod - def get_stream_type(chat_stream: BotChatSession) -> str: - if not isinstance(chat_stream, BotChatSession): - raise TypeError("chat_stream 必须是 BotChatSession 类型") - if not chat_stream: - raise ValueError("chat_stream 不能为 None") - - return "group" if chat_stream.is_group_session else "private" - - @staticmethod - def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]: - if not chat_stream: - raise ValueError("chat_stream 不能为 None") - if not isinstance(chat_stream, BotChatSession): - raise TypeError("chat_stream 必须是 BotChatSession 类型") - - try: - info: Dict[str, Any] = { - "session_id": chat_stream.session_id, - "platform": chat_stream.platform, - "type": ChatManager.get_stream_type(chat_stream), - } - - if chat_stream.is_group_session: - info["group_id"] = chat_stream.group_id - if ( - chat_stream.context - and chat_stream.context.message - and chat_stream.context.message.message_info.group_info - ): - info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊" - else: - info["group_name"] = "未知群聊" - else: - info["user_id"] = chat_stream.user_id - if ( - chat_stream.context - and chat_stream.context.message - and chat_stream.context.message.message_info.user_info - ): - info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname - else: - info["user_name"] = "未知用户" - - return info - except Exception as e: - logger.error(f"[ChatService] 获取聊天流信息失败: {e}") - return {} diff --git a/src/services/database_service.py b/src/services/database_service.py index 8d9192d2..5b8b716f 100644 --- a/src/services/database_service.py +++ b/src/services/database_service.py @@ -9,6 +9,7 @@ from typing import Any, Optional from sqlalchemy import delete, func, select from sqlmodel import SQLModel +from src.chat.message_receive.chat_manager import BotChatSession from src.common.database.database import get_db_session from src.common.database.database_model import ActionRecord from src.common.logger import get_logger @@ -23,12 +24,10 @@ def _to_dict(record: Any) -> dict[str, Any]: return record if hasattr(record, "model_dump"): return record.model_dump() - if hasattr(record, "__dict__"): - return dict(record.__dict__) - return {} + return dict(record.__dict__) if hasattr(record, "__dict__") else {} -def _get_model_field(model_class: type[SQLModel], field_name: str): +def _get_model_field(model_class: type[SQLModel], field_name: str) -> Any: field = getattr(model_class, field_name, None) if field is None: raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}") @@ -41,7 +40,7 @@ def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any] return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()] -def _apply_order_by(statement, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None): +def _apply_order_by(statement: Any, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None) -> Any: if not order_by: return statement @@ -60,7 +59,7 @@ async def db_save( data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None, -): +) -> Optional[dict[str, Any]]: try: with get_db_session() as session: record = None @@ -91,12 +90,11 @@ async def db_get( limit: Optional[int] = None, order_by: Optional[str | list[str]] = None, single_result: bool = False, -): +) -> Optional[dict[str, Any]] | list[dict[str, Any]]: try: with get_db_session(auto_commit=False) as session: statement = select(model_class) - conditions = _build_filters(model_class, filters) - if conditions: + if conditions := _build_filters(model_class, filters): statement = statement.where(*conditions) statement = _apply_order_by(statement, model_class, order_by) if limit: @@ -116,8 +114,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters: try: with get_db_session() as session: statement = select(model_class) - conditions = _build_filters(model_class, filters) - if conditions: + if conditions := _build_filters(model_class, filters): statement = statement.where(*conditions) records = session.exec(statement).all() for record in records: @@ -136,8 +133,7 @@ async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any try: with get_db_session() as session: statement = delete(model_class) - conditions = _build_filters(model_class, filters) - if conditions: + if conditions := _build_filters(model_class, filters): statement = statement.where(*conditions) result = session.exec(statement) return result.rowcount or 0 @@ -151,8 +147,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any] try: with get_db_session(auto_commit=False) as session: statement = select(func.count()).select_from(model_class) - conditions = _build_filters(model_class, filters) - if conditions: + if conditions := _build_filters(model_class, filters): statement = statement.where(*conditions) result = session.exec(statement).one() return int(result or 0) @@ -163,18 +158,15 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any] async def store_action_info( - chat_stream=None, + chat_stream: BotChatSession, builtin_prompt: Optional[str] = None, display_prompt: str = "", thinking_id: str = "", - action_data: Optional[dict] = None, + action_data: Optional[dict[str, Any]] = None, action_name: str = "", action_reasoning: str = "", -): +) -> Optional[dict[str, Any]]: try: - if chat_stream is None: - raise ValueError("store_action_info 需要 chat_stream") - record_data = { "action_id": thinking_id or str(int(time.time() * 1000000)), "timestamp": datetime.now(), diff --git a/src/services/emoji_service.py b/src/services/emoji_service.py deleted file mode 100644 index f6d14348..00000000 --- a/src/services/emoji_service.py +++ /dev/null @@ -1,406 +0,0 @@ -""" -表情服务模块 - -提供表情包相关的核心功能。 -""" - -import base64 -import os -import random -import uuid - -from typing import Any, Dict, List, Optional, Tuple - -from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR -from src.common.logger import get_logger -from src.common.utils.utils_image import ImageUtils -from src.config.config import global_config - -logger = get_logger("emoji_service") - - -# ============================================================================= -# 表情包获取函数 -# ============================================================================= - - -async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]: - """根据描述选择表情包""" - if not description: - raise ValueError("描述不能为空") - if not isinstance(description, str): - raise TypeError("描述必须是字符串类型") - try: - logger.debug(f"[EmojiService] 根据描述获取表情包: {description}") - - emoji_obj = await emoji_manager.get_emoji_for_emotion(description) - - if not emoji_obj: - logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包") - return None - - emoji_path = str(emoji_obj.full_path) - emoji_description = emoji_obj.description - matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "" - emoji_base64 = ImageUtils.image_path_to_base64(emoji_path) - - if not emoji_base64: - logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}") - return None - - logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}") - return emoji_base64, emoji_description, matched_emotion - - except Exception as e: - logger.error(f"[EmojiService] 获取表情包失败: {e}") - return None - - -async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: - """随机获取指定数量的表情包""" - if not isinstance(count, int): - raise TypeError("count 必须是整数类型") - if count < 0: - raise ValueError("count 不能为负数") - if count == 0: - logger.warning("[EmojiService] count 为0,返回空列表") - return [] - - try: - all_emojis = emoji_manager.emojis - - if not all_emojis: - logger.warning("[EmojiService] 没有可用的表情包") - return [] - - valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted] - if not valid_emojis: - logger.warning("[EmojiService] 没有有效的表情包") - return [] - - if len(valid_emojis) < count: - logger.debug( - f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包" - ) - count = len(valid_emojis) - - selected_emojis = random.sample(valid_emojis, count) - - results = [] - for selected_emoji in selected_emojis: - emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path)) - - if not emoji_base64: - logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}") - continue - - matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情" - - emoji_manager.update_emoji_usage(selected_emoji) - results.append((emoji_base64, selected_emoji.description, matched_emotion)) - - if not results and count > 0: - logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理") - return [] - - logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包") - return results - - except Exception as e: - logger.error(f"[EmojiService] 获取随机表情包失败: {e}") - return [] - - -async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: - """根据情感标签获取表情包""" - if not emotion: - raise ValueError("情感标签不能为空") - if not isinstance(emotion, str): - raise TypeError("情感标签必须是字符串类型") - try: - logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}") - - all_emojis = emoji_manager.emojis - - matching_emojis = [] - matching_emojis.extend( - emoji_obj - for emoji_obj in all_emojis - if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion] - ) - if not matching_emojis: - logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包") - return None - - selected_emoji = random.choice(matching_emojis) - emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path) - - if not emoji_base64: - logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}") - return None - - emoji_manager.update_emoji_usage(selected_emoji) - - logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}") - return emoji_base64, selected_emoji.description, emotion - - except Exception as e: - logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}") - return None - - -# ============================================================================= -# 表情包信息查询函数 -# ============================================================================= - - -def get_count() -> int: - try: - return len(emoji_manager.emojis) - except Exception as e: - logger.error(f"[EmojiService] 获取表情包数量失败: {e}") - return 0 - - -def get_info(): - try: - return { - "current_count": len(emoji_manager.emojis), - "max_count": global_config.emoji.max_reg_num, - "available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]), - } - except Exception as e: - logger.error(f"[EmojiService] 获取表情包信息失败: {e}") - return {"current_count": 0, "max_count": 0, "available_emojis": 0} - - -def get_emotions() -> List[str]: - try: - emotions = set() - - for emoji_obj in emoji_manager.emojis: - if not emoji_obj.is_deleted and emoji_obj.emotion: - emotions.update(emoji_obj.emotion) - - return sorted(list(emotions)) - except Exception as e: - logger.error(f"[EmojiService] 获取情感标签失败: {e}") - return [] - - -async def get_all() -> List[Tuple[str, str, str]]: - try: - all_emojis = emoji_manager.emojis - - if not all_emojis: - logger.warning("[EmojiService] 没有可用的表情包") - return [] - - results = [] - for emoji_obj in all_emojis: - if emoji_obj.is_deleted: - continue - - emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path)) - - if not emoji_base64: - logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}") - continue - - matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情" - results.append((emoji_base64, emoji_obj.description, matched_emotion)) - - logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包") - return results - - except Exception as e: - logger.error(f"[EmojiService] 获取所有表情包失败: {e}") - return [] - - -def get_descriptions() -> List[str]: - try: - descriptions = [] - - descriptions.extend( - emoji_obj.description - for emoji_obj in emoji_manager.emojis - if not emoji_obj.is_deleted and emoji_obj.description - ) - return descriptions - except Exception as e: - logger.error(f"[EmojiService] 获取表情包描述失败: {e}") - return [] - - -# ============================================================================= -# 表情包注册函数 -# ============================================================================= - - -async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]: - """注册新的表情包""" - if not image_base64: - raise ValueError("图片base64编码不能为空") - if not isinstance(image_base64, str): - raise TypeError("image_base64必须是字符串类型") - if filename is not None and not isinstance(filename, str): - raise TypeError("filename必须是字符串类型或None") - - try: - logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}") - - count_before = len(emoji_manager.emojis) - max_count = global_config.emoji.max_reg_num - - can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace) - - if not can_register: - return { - "success": False, - "message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - os.makedirs(EMOJI_DIR, exist_ok=True) - - if not filename: - import time as _time - - timestamp = int(_time.time()) - microseconds = int(_time.time() * 1000000) % 1000000 - - random_bytes = random.getrandbits(72).to_bytes(9, "big") - short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=") - short_id = short_id.replace("/", "_").replace("+", "-") - filename = f"emoji_{timestamp}_{microseconds}_{short_id}" - - if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")): - filename = f"{filename}.png" - - temp_file_path = os.path.join(EMOJI_DIR, filename) - attempts = 0 - max_attempts = 10 - while os.path.exists(temp_file_path) and attempts < max_attempts: - random_bytes = random.getrandbits(48).to_bytes(6, "big") - short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=") - short_id = short_id.replace("/", "_").replace("+", "-") - - name_part, ext = os.path.splitext(filename) - base_name = name_part.rsplit("_", 1)[0] - filename = f"{base_name}_{short_id}{ext}" - temp_file_path = os.path.join(EMOJI_DIR, filename) - attempts += 1 - - if os.path.exists(temp_file_path): - uuid_short = str(uuid.uuid4())[:8] - name_part, ext = os.path.splitext(filename) - base_name = name_part.rsplit("_", 1)[0] - filename = f"{base_name}_{uuid_short}{ext}" - temp_file_path = os.path.join(EMOJI_DIR, filename) - - counter = 1 - original_filename = filename - while os.path.exists(temp_file_path): - name_part, ext = os.path.splitext(original_filename) - filename = f"{name_part}_{counter}{ext}" - temp_file_path = os.path.join(EMOJI_DIR, filename) - counter += 1 - - if counter > 100: - logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}") - return { - "success": False, - "message": "无法生成唯一文件名,请稍后重试", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - try: - if not ImageUtils.base64_to_image(image_base64, temp_file_path): - logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}") - return { - "success": False, - "message": "无法保存图片文件", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}") - - except Exception as save_error: - logger.error(f"[EmojiService] 保存图片文件失败: {save_error}") - return { - "success": False, - "message": f"保存图片文件失败: {str(save_error)}", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - register_success = await emoji_manager.register_emoji_by_filename(filename) - - if not register_success and os.path.exists(temp_file_path): - try: - os.remove(temp_file_path) - logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}") - except Exception as cleanup_error: - logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}") - - if register_success: - count_after = len(emoji_manager.emojis) - replaced = count_after <= count_before - - new_emoji_info = None - if count_after > count_before or replaced: - try: - for emoji_obj in reversed(emoji_manager.emojis): - if not emoji_obj.is_deleted and ( - emoji_obj.file_name == filename - or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path)) - ): - new_emoji_info = emoji_obj - break - except Exception as find_error: - logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}") - - description = new_emoji_info.description if new_emoji_info else None - emotions = new_emoji_info.emotion if new_emoji_info else None - emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None - - return { - "success": True, - "message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}", - "description": description, - "emotions": emotions, - "replaced": replaced, - "hash": emoji_hash, - } - else: - return { - "success": False, - "message": "表情包注册失败,可能因为重复、格式不支持或审核未通过", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } - - except Exception as e: - logger.error(f"[EmojiService] 注册表情包时发生异常: {e}") - return { - "success": False, - "message": f"注册过程中发生错误: {str(e)}", - "description": None, - "emotions": None, - "replaced": None, - "hash": None, - } diff --git a/src/services/generator_service.py b/src/services/generator_service.py index 3587a244..278fc3f1 100644 --- a/src/services/generator_service.py +++ b/src/services/generator_service.py @@ -35,7 +35,7 @@ logger = get_logger("generator_service") # ============================================================================= -def get_replyer( +def _get_replyer( chat_stream: Optional[BotChatSession] = None, chat_id: Optional[str] = None, request_type: str = "replyer", @@ -58,6 +58,35 @@ def get_replyer( return None +def _extract_unknown_words(action_data: Optional[Dict[str, Any]]) -> Optional[List[str]]: + if not action_data: + return None + + unknown_words = action_data.get("unknown_words") + if not isinstance(unknown_words, list): + return None + + cleaned_words: List[str] = [] + for item in unknown_words: + if isinstance(item, str) and (cleaned_item := item.strip()): + cleaned_words.append(cleaned_item) + + return cleaned_words or None + + +def _build_message_sequence( + content: Optional[str], + *, + enable_splitter: bool, + enable_chinese_typo: bool, +) -> tuple[Optional[MessageSequence], List[str]]: + if not content: + return None, [] + + processed_output = process_llm_response(content, enable_splitter, enable_chinese_typo) + return MessageSequence(components=[TextComponent(text) for text in processed_output]), processed_output + + # ============================================================================= # 回复生成函数 # ============================================================================= @@ -87,7 +116,7 @@ async def generate_reply( reply_time_point = time.time() logger.debug("[GeneratorService] 开始生成回复") - replyer = get_replyer(chat_stream, chat_id, request_type=request_type) + replyer = _get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: logger.error("[GeneratorService] 无法获取回复器") return False, None @@ -98,16 +127,7 @@ async def generate_reply( if not reply_reason: reply_reason = action_data.get("reason", "") if unknown_words is None: - uw = action_data.get("unknown_words") - if isinstance(uw, list): - cleaned: List[str] = [] - for item in uw: - if isinstance(item, str): - s = item.strip() - if s: - cleaned.append(s) - if cleaned: - unknown_words = cleaned + unknown_words = _extract_unknown_words(action_data) success, llm_response = await replyer.generate_reply_with_context( extra_info=extra_info, @@ -126,13 +146,12 @@ async def generate_reply( if not success: logger.warning("[GeneratorService] 回复生成失败") return False, None - reply_set: Optional[MessageSequence] = None - if content := llm_response.content: - processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo) - llm_response.processed_output = processed_response - reply_set = MessageSequence(components=[]) - for text in processed_response: - reply_set.components.append(TextComponent(text)) + reply_set, processed_output = _build_message_sequence( + llm_response.content, + enable_splitter=enable_splitter, + enable_chinese_typo=enable_chinese_typo, + ) + llm_response.processed_output = processed_output llm_response.reply_set = reply_set logger.debug( f"[GeneratorService] 回复生成成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项" @@ -181,7 +200,7 @@ async def rewrite_reply( ) -> Tuple[bool, Optional["LLMGenerationDataModel"]]: """重写回复""" try: - replyer = get_replyer(chat_stream, chat_id, request_type=request_type) + replyer = _get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: logger.error("[GeneratorService] 无法获取回复器") return False, None @@ -198,9 +217,13 @@ async def rewrite_reply( reason=reason, reply_to=reply_to, ) - reply_set: Optional[MessageSequence] = None - if success and llm_response and (content := llm_response.content): - reply_set = process_human_text(content, enable_splitter, enable_chinese_typo) + reply_set, processed_output = _build_message_sequence( + llm_response.content if success and llm_response else None, + enable_splitter=enable_splitter, + enable_chinese_typo=enable_chinese_typo, + ) + if llm_response is not None: + llm_response.processed_output = processed_output llm_response.reply_set = reply_set if success: logger.info( @@ -219,44 +242,3 @@ async def rewrite_reply( return False, None -def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[MessageSequence]: - """将文本处理为更拟人化的文本""" - if not isinstance(content, str): - raise ValueError("content 必须是字符串类型") - try: - reply_set = MessageSequence(components=[]) - processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo) - - for text in processed_response: - reply_set.components.append(TextComponent(text)) - - return reply_set - - except Exception as e: - logger.error(f"[GeneratorService] 处理人形文本时出错: {e}") - return None - - -async def generate_response_custom( - chat_stream: Optional[BotChatSession] = None, - chat_id: Optional[str] = None, - request_type: str = "generator_api", - prompt: str = "", -) -> Optional[str]: - replyer = get_replyer(chat_stream, chat_id, request_type=request_type) - if not replyer: - logger.error("[GeneratorService] 无法获取回复器") - return None - - try: - logger.debug("[GeneratorService] 开始生成自定义回复") - response, _, _, _ = await replyer.llm_generate_content(prompt) - if response: - logger.debug("[GeneratorService] 自定义回复生成成功") - return response - else: - logger.warning("[GeneratorService] 自定义回复生成失败") - return None - except Exception as e: - logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}") - return None diff --git a/src/services/message_service.py b/src/services/message_service.py index 3a19431a..966d8709 100644 --- a/src/services/message_service.py +++ b/src/services/message_service.py @@ -14,7 +14,6 @@ from src.common.database.database_model import ActionRecord, Images, ImageType from src.common.message_repository import count_messages, find_messages from src.common.utils.math_utils import translate_timestamp_to_human_readable from src.common.utils.utils_action import ActionUtils -from src.chat.utils.utils import is_bot_self from src.config.config import global_config @@ -113,103 +112,6 @@ def get_messages_by_time_in_chat( return _normalize_messages(messages) -def get_messages_by_time_in_chat_inclusive( - chat_id: str, - start_time: float, - end_time: float, - limit: int = 0, - limit_mode: str = "latest", - filter_mai: bool = False, - filter_command: bool = False, - filter_intercept_message_level: Optional[int] = None, -) -> List[SessionMessage]: - if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): - raise ValueError("start_time 和 end_time 必须是数字类型") - if limit < 0: - raise ValueError("limit 不能为负数") - if not chat_id: - raise ValueError("chat_id 不能为空") - if not isinstance(chat_id, str): - raise ValueError("chat_id 必须是字符串类型") - messages = find_messages( - message_filter={ - "chat_id": chat_id, - "time": { - "$gte": start_time, - "$lte": end_time, - }, - }, - limit=limit, - limit_mode=limit_mode, - filter_bot=filter_mai, - filter_command=filter_command, - filter_intercept_message_level=filter_intercept_message_level, - ) - return _normalize_messages(messages) - - -def get_messages_by_time_in_chat_for_users( - chat_id: str, - start_time: float, - end_time: float, - person_ids: List[str], - limit: int = 0, - limit_mode: str = "latest", -) -> List[SessionMessage]: - if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): - raise ValueError("start_time 和 end_time 必须是数字类型") - if limit < 0: - raise ValueError("limit 不能为负数") - if not chat_id: - raise ValueError("chat_id 不能为空") - if not isinstance(chat_id, str): - raise ValueError("chat_id 必须是字符串类型") - messages = find_messages( - message_filter={ - "chat_id": chat_id, - "time": { - "$gte": start_time, - "$lte": end_time, - }, - "user_id": {"$in": person_ids}, - }, - limit=limit, - limit_mode=limit_mode, - ) - return _normalize_messages(messages) - - -def get_random_chat_messages( - start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> List[SessionMessage]: - if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): - raise ValueError("start_time 和 end_time 必须是数字类型") - if limit < 0: - raise ValueError("limit 不能为负数") - return get_messages_by_time(start_time, end_time, limit, limit_mode, filter_mai) - - -def get_messages_by_time_for_users( - start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" -) -> List[SessionMessage]: - if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): - raise ValueError("start_time 和 end_time 必须是数字类型") - if limit < 0: - raise ValueError("limit 不能为负数") - messages = find_messages( - message_filter={ - "time": { - "$gte": start_time, - "$lte": end_time, - }, - "user_id": {"$in": person_ids}, - }, - limit=limit, - limit_mode=limit_mode, - ) - return _normalize_messages(messages) - - def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]: if not isinstance(timestamp, (int, float)): raise ValueError("timestamp 必须是数字类型") @@ -252,24 +154,6 @@ def get_messages_before_time_in_chat( return _normalize_messages(messages) -def get_messages_before_time_for_users( - timestamp: float, person_ids: List[str], limit: int = 0 -) -> List[SessionMessage]: - if not isinstance(timestamp, (int, float)): - raise ValueError("timestamp 必须是数字类型") - if limit < 0: - raise ValueError("limit 不能为负数") - messages = find_messages( - message_filter={ - "time": {"$lt": timestamp}, - "user_id": {"$in": person_ids}, - }, - limit=limit, - limit_mode="latest", - ) - return _normalize_messages(messages) - - def get_recent_messages( chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False ) -> List[SessionMessage]: @@ -307,22 +191,6 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional return count_messages(message_filter) -def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: - if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): - raise ValueError("start_time 和 end_time 必须是数字类型") - if not chat_id: - raise ValueError("chat_id 不能为空") - if not isinstance(chat_id, str): - raise ValueError("chat_id 必须是字符串类型") - return count_messages( - { - "chat_id": chat_id, - "time": {"$gt": start_time, "$lte": end_time}, - "user_id": {"$in": person_ids}, - } - ) - - # ============================================================================= # 消息格式化函数 # ============================================================================= @@ -365,17 +233,6 @@ def build_readable_messages( return "\n".join(lines) -def build_readable_messages_to_str( - messages: List[SessionMessage], - replace_bot_name: bool = True, - timestamp_mode: str = "relative", - read_mark: float = 0.0, - truncate: bool = False, - show_actions: bool = False, -) -> str: - return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions) - - def build_readable_messages_with_id( messages: List[SessionMessage], replace_bot_name: bool = True, @@ -415,148 +272,6 @@ def build_readable_messages_with_id( return "\n".join(lines), message_id_list -async def build_readable_messages_with_details( - messages: List[SessionMessage], - replace_bot_name: bool = True, - timestamp_mode: str = "relative", - truncate: bool = False, -) -> Tuple[str, List[Tuple[float, str, str]]]: - normalized_messages = _normalize_messages(messages) - message_list = [ - ( - message.timestamp.timestamp(), - message.message_info.user_info.user_id, - message.processed_plain_text or "", - ) - for message in normalized_messages - ] - return build_readable_messages(normalized_messages, replace_bot_name, timestamp_mode, truncate=truncate), message_list - - -async def get_person_ids_from_messages(messages: List[Any]) -> List[str]: - person_ids: List[str] = [] - for message in messages: - if isinstance(message, SessionMessage): - person_ids.append(message.message_info.user_info.user_id) - elif isinstance(message, dict) and (user_id := message.get("user_id")): - person_ids.append(str(user_id)) - return person_ids - - -# ============================================================================= -# 消息过滤函数 -# ============================================================================= - - -def filter_mai_messages(messages: List[SessionMessage]) -> List[SessionMessage]: - """从消息列表中移除麦麦的消息""" - return [ - msg - for msg in messages - if not is_bot_self(msg.platform, msg.message_info.user_info.user_id) - ] - - -def get_raw_msg_by_timestamp( - timestamp_start: float, - timestamp_end: float, - limit: int = 0, - limit_mode: str = "latest", -) -> List[SessionMessage]: - return get_messages_by_time(timestamp_start, timestamp_end, limit, limit_mode) - - -def get_raw_msg_by_timestamp_with_chat( - chat_id: str, - timestamp_start: float, - timestamp_end: float, - limit: int = 0, - limit_mode: str = "latest", - filter_bot: bool = False, - filter_command: bool = False, - filter_intercept_message_level: Optional[int] = None, -) -> List[SessionMessage]: - return get_messages_by_time_in_chat( - chat_id, - timestamp_start, - timestamp_end, - limit, - limit_mode, - filter_bot, - filter_command, - filter_intercept_message_level, - ) - - -def get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id: str, - timestamp_start: float, - timestamp_end: float, - limit: int = 0, - limit_mode: str = "latest", - filter_bot: bool = False, - filter_command: bool = False, - filter_intercept_message_level: Optional[int] = None, -) -> List[SessionMessage]: - return get_messages_by_time_in_chat_inclusive( - chat_id, - timestamp_start, - timestamp_end, - limit, - limit_mode, - filter_bot, - filter_command, - filter_intercept_message_level, - ) - - -def get_raw_msg_by_timestamp_with_chat_users( - chat_id: str, - timestamp_start: float, - timestamp_end: float, - person_ids: List[str], - limit: int = 0, - limit_mode: str = "latest", -) -> List[SessionMessage]: - return get_messages_by_time_in_chat_for_users(chat_id, timestamp_start, timestamp_end, person_ids, limit, limit_mode) - - -def get_raw_msg_by_timestamp_with_users( - timestamp_start: float, - timestamp_end: float, - person_ids: List[str], - limit: int = 0, - limit_mode: str = "latest", -) -> List[SessionMessage]: - return get_messages_by_time_for_users(timestamp_start, timestamp_end, person_ids, limit, limit_mode) - - -def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[SessionMessage]: - return get_messages_before_time(timestamp, limit) - - -def get_raw_msg_before_timestamp_with_chat( - chat_id: str, - timestamp: float, - limit: int = 0, - filter_intercept_message_level: Optional[int] = None, -) -> List[SessionMessage]: - return get_messages_before_time_in_chat(chat_id, timestamp, limit, False, filter_intercept_message_level) - - -def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[SessionMessage]: - return get_messages_before_time_for_users(timestamp, person_ids, limit) - - -def get_raw_msg_by_timestamp_random( - timestamp_start: float, - timestamp_end: float, - limit: int = 0, - limit_mode: str = "latest", -) -> List[SessionMessage]: - return get_random_chat_messages(timestamp_start, timestamp_end, limit, limit_mode) - - def get_actions_by_timestamp_with_chat(chat_id: str, timestamp_start: float, timestamp_end: float) -> List[MaiActionRecord]: with get_db_session() as session: statement = (