炸 service 层 x 2,把能归类为现有重构好的模块的都归类过去

This commit is contained in:
DrSmoothl
2026-03-14 00:33:08 +08:00
parent 43c5b34623
commit 4bc9c5bf7e
12 changed files with 207 additions and 1000 deletions

View File

@@ -28,10 +28,7 @@ from src.services import (
message_service as message_api,
database_service as database_api,
)
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
)
from src.services.message_service import build_readable_messages_with_id, get_messages_before_time_in_chat
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
@@ -275,7 +272,7 @@ class BrainChatting:
# 一次思考迭代Think - Act - Observe
# 获取聊天上下文
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
message_list_before_now = get_messages_before_time_in_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),

View File

@@ -14,11 +14,11 @@ from src.common.logger import get_logger
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.prompt.prompt_manager import prompt_manager
from src.chat.utils.chat_message_builder import (
from src.services.message_service import (
build_readable_actions,
get_actions_by_timestamp_with_chat,
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
get_actions_by_timestamp_with_chat,
get_messages_before_time_in_chat,
)
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager
@@ -163,7 +163,7 @@ class BrainPlanner:
plan_start = time.perf_counter()
# 获取聊天上下文
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
message_list_before_now = get_messages_before_time_in_chat(
chat_id=self.chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),

View File

@@ -6,7 +6,7 @@ from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat
from src.core.types import ActionActivationType, ActionInfo
from src.core.announcement_manager import global_announcement_manager
@@ -51,7 +51,7 @@ class ActionModifier:
self.action_manager.restore_actions()
all_actions = self.action_manager.get_using_actions()
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_half = get_messages_before_time_in_chat(
chat_id=self.chat_stream.stream_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 10),

View File

@@ -15,17 +15,17 @@ from src.common.logger import get_logger
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.prompt.prompt_manager import prompt_manager
from src.chat.utils.chat_message_builder import (
from src.services.message_service import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
get_messages_before_time_in_chat,
translate_pid_to_description,
)
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.core.types import ActionActivationType, ActionInfo, ComponentType
from src.core.component_registry import component_registry
from src.services.message_service import translate_pid_to_description
from src.person_info.person_info import Person
if TYPE_CHECKING:
@@ -389,7 +389,7 @@ class ActionPlanner:
plan_start = time.perf_counter()
# 获取聊天上下文
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
message_list_before_now = get_messages_before_time_in_chat(
chat_id=self.chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),

View File

@@ -22,7 +22,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager
from src.services.message_service import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
get_messages_before_time_in_chat,
replace_user_references,
translate_pid_to_description,
)
@@ -809,14 +809,14 @@ class DefaultReplyer:
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_long = get_messages_before_time_in_chat(
chat_id=chat_id,
timestamp=reply_time_point,
limit=global_config.chat.max_context_size * 1,
filter_intercept_message_level=1,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
message_list_before_short = get_messages_before_time_in_chat(
chat_id=chat_id,
timestamp=reply_time_point,
limit=int(global_config.chat.max_context_size * 0.33),
@@ -1022,7 +1022,7 @@ class DefaultReplyer:
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_half = get_messages_before_time_in_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),

View File

@@ -23,7 +23,7 @@ from src.prompt.prompt_manager import prompt_manager
from src.chat.utils.common_utils import TempMethodsExpression
from src.services.message_service import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
get_messages_before_time_in_chat,
replace_user_references,
translate_pid_to_description,
)
@@ -650,7 +650,7 @@ class PrivateReplyer:
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_long = get_messages_before_time_in_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size,
@@ -666,7 +666,7 @@ class PrivateReplyer:
long_time_notice=True,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
message_list_before_short = get_messages_before_time_in_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
@@ -857,7 +857,7 @@ class PrivateReplyer:
# 将[picid:xxx]替换为具体的图片描述
target = self._replace_picids_with_descriptions(target)
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_half = get_messages_before_time_in_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),

View File

@@ -1,12 +1,38 @@
import random
from pathlib import Path
from typing import Any, Dict, List, Optional
import time
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
from src.common.data_models.image_data_model import MaiEmoji
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
logger = get_logger("plugin_runtime.integration")
class RuntimeDataCapabilityMixin:
@staticmethod
def _serialize_emoji_payload(emoji: MaiEmoji) -> Optional[Dict[str, str]]:
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji.full_path))
if not emoji_base64:
return None
matched_emotion = emoji.emotion[0] if emoji.emotion else ""
return {
"base64": emoji_base64,
"description": emoji.description,
"emotion": matched_emotion,
}
@staticmethod
def _build_emoji_temp_path() -> Path:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
return EMOJI_DIR / f"emoji_cap_{int(time.time() * 1000000)}.png"
async def _cap_database_query(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import database_service as database_api
@@ -338,7 +364,7 @@ class RuntimeDataCapabilityMixin:
limit=args.get("limit", 0),
)
readable = message_api.build_readable_messages_to_str(
readable = message_api.build_readable_messages(
messages=messages,
replace_bot_name=args.get("replace_bot_name", True),
timestamp_mode=args.get("timestamp_mode", "relative"),
@@ -397,101 +423,173 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": str(e)}
async def _cap_emoji_get_by_description(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
from src.chat.emoji_system.emoji_manager import emoji_manager
description: str = args.get("description", "")
if not description:
return {"success": False, "error": "缺少必要参数 description"}
try:
result = await emoji_api.get_by_description(description=description)
if result is None:
emoji = await emoji_manager.get_emoji_for_emotion(description)
if emoji is None:
return {"success": True, "emoji": None}
serialized = self._serialize_emoji_payload(emoji)
if serialized is None:
return {"success": True, "emoji": None}
emoji_base64, emoji_desc, matched_emotion = result
return {
"success": True,
"emoji": {
"base64": emoji_base64,
"description": emoji_desc,
"emotion": matched_emotion,
},
"emoji": serialized,
}
except Exception as e:
logger.error(f"[cap.emoji.get_by_description] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_random(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
from src.chat.emoji_system.emoji_manager import emoji_manager
count: int = args.get("count", 1)
try:
results = await emoji_api.get_random(count=count)
emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results]
if count < 0:
return {"success": False, "error": "count 不能为负数"}
emojis_source = list(emoji_manager.emojis)
if count == 0 or not emojis_source:
return {"success": True, "emojis": []}
selected = random.sample(emojis_source, min(count, len(emojis_source)))
emojis: List[Dict[str, str]] = []
for emoji in selected:
emoji_manager.update_emoji_usage(emoji)
serialized = self._serialize_emoji_payload(emoji)
if serialized is not None:
if not serialized["emotion"]:
serialized["emotion"] = "随机表情"
emojis.append(serialized)
return {"success": True, "emojis": emojis}
except Exception as e:
logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_count(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
try:
return {"success": True, "count": emoji_api.get_count()}
from src.chat.emoji_system.emoji_manager import emoji_manager
return {"success": True, "count": len(emoji_manager.emojis)}
except Exception as e:
logger.error(f"[cap.emoji.get_count] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_emotions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
try:
return {"success": True, "emotions": emoji_api.get_emotions()}
from src.chat.emoji_system.emoji_manager import emoji_manager
emotions = sorted({emotion for emoji in emoji_manager.emojis for emotion in emoji.emotion})
return {"success": True, "emotions": emotions}
except Exception as e:
logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
try:
results = await emoji_api.get_all()
emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else []
from src.chat.emoji_system.emoji_manager import emoji_manager
emojis = []
for emoji in emoji_manager.emojis:
serialized = self._serialize_emoji_payload(emoji)
if serialized is not None:
if not serialized["emotion"]:
serialized["emotion"] = "随机表情"
emojis.append(serialized)
return {"success": True, "emojis": emojis}
except Exception as e:
logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_get_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
try:
return {"success": True, "info": emoji_api.get_info()}
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.config.config import global_config
current_count = len(emoji_manager.emojis)
return {
"success": True,
"info": {
"current_count": current_count,
"max_count": global_config.emoji.max_reg_num,
"available_emojis": current_count,
},
}
except Exception as e:
logger.error(f"[cap.emoji.get_info] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_register(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
from src.chat.emoji_system.emoji_manager import emoji_manager
emoji_base64: str = args.get("emoji_base64", "")
if not emoji_base64:
return {"success": False, "error": "缺少必要参数 emoji_base64"}
try:
return await emoji_api.register_emoji(emoji_base64)
count_before = len(emoji_manager.emojis)
temp_file_path = self._build_emoji_temp_path()
if not ImageUtils.base64_to_image(emoji_base64, str(temp_file_path)):
return {"success": False, "message": "无法保存图片文件", "description": None, "emotions": None, "replaced": None, "hash": None}
register_success = await emoji_manager.register_emoji_by_filename(temp_file_path)
if not register_success:
if temp_file_path.exists():
temp_file_path.unlink(missing_ok=True)
return {
"success": False,
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
count_after = len(emoji_manager.emojis)
replaced = count_after <= count_before
new_emoji = next(
(
item
for item in reversed(emoji_manager.emojis)
if temp_file_path.name == item.file_name or temp_file_path.name in str(item.full_path)
),
None,
)
return {
"success": True,
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
"description": None if new_emoji is None else new_emoji.description,
"emotions": None if new_emoji is None else new_emoji.emotion,
"replaced": replaced,
"hash": None if new_emoji is None else new_emoji.file_hash,
}
except Exception as e:
logger.error(f"[cap.emoji.register] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def _cap_emoji_delete(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.services import emoji_service as emoji_api
from src.chat.emoji_system.emoji_manager import emoji_manager
emoji_hash: str = args.get("emoji_hash", "")
if not emoji_hash:
return {"success": False, "error": "缺少必要参数 emoji_hash"}
try:
return await emoji_api.delete_emoji(emoji_hash)
emoji = emoji_manager.get_emoji_by_hash(emoji_hash)
if emoji is None:
return {"success": False, "message": f"未找到表情包: {emoji_hash}", "hash": emoji_hash}
success = emoji_manager.delete_emoji(emoji, not bool(emoji.description and emoji.description.strip()))
if not success:
return {"success": False, "message": f"删除表情包失败: {emoji_hash}", "hash": emoji_hash}
emoji_manager.emojis = [item for item in emoji_manager.emojis if item.file_hash != emoji_hash]
emoji_manager._emoji_num = len(emoji_manager.emojis)
return {"success": True, "message": f"成功删除表情包: {emoji_hash}", "hash": emoji_hash}
except Exception as e:
logger.error(f"[cap.emoji.delete] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}

View File

@@ -1,171 +0,0 @@
"""
聊天服务模块
提供聊天信息查询和管理的核心功能。
"""
from typing import Any, Callable, Dict, List, Optional
from enum import Enum
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.common.logger import get_logger
logger = get_logger("chat_service")
class SpecialTypes(Enum):
"""特殊枚举类型"""
ALL_PLATFORMS = "all_platforms"
class ChatManager:
"""聊天管理器 - 负责聊天信息的查询和管理"""
@staticmethod
def _validate_platform(platform: Optional[str] | SpecialTypes) -> None:
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
@staticmethod
def _match_platform(chat_stream: BotChatSession, platform: Optional[str] | SpecialTypes) -> bool:
return platform == SpecialTypes.ALL_PLATFORMS or chat_stream.platform == platform
@staticmethod
def _get_streams(
platform: Optional[str] | SpecialTypes = "qq", is_group_session: Optional[bool] = None
) -> List[BotChatSession]:
ChatManager._validate_platform(platform)
try:
streams = [
stream
for stream in _chat_manager.sessions.values()
if ChatManager._match_platform(stream, platform)
and (is_group_session is None or stream.is_group_session == is_group_session)
]
return streams
except Exception as e:
logger.error(f"[ChatService] 获取聊天流失败: {e}")
return []
@staticmethod
def _find_stream(
predicate: Callable[[BotChatSession], bool],
platform: Optional[str] | SpecialTypes = "qq",
) -> Optional[BotChatSession]:
for stream in ChatManager._get_streams(platform=platform):
if predicate(stream):
return stream
return None
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
streams = ChatManager._get_streams(platform=platform)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的聊天流")
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
streams = ChatManager._get_streams(platform=platform, is_group_session=True)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的群聊流")
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
streams = ChatManager._get_streams(platform=platform, is_group_session=False)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的私聊流")
return streams
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
ChatManager._validate_platform(platform)
if not group_id:
raise ValueError("group_id 不能为空")
try:
stream = ChatManager._find_stream(
lambda item: item.is_group_session and str(item.group_id) == str(group_id),
platform=platform,
)
if stream is not None:
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatService] 查找群聊流失败: {e}")
return None
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
ChatManager._validate_platform(platform)
if not user_id:
raise ValueError("user_id 不能为空")
try:
stream = ChatManager._find_stream(
lambda item: (not item.is_group_session) and str(item.user_id) == str(user_id),
platform=platform,
)
if stream is not None:
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatService] 查找私聊流失败: {e}")
return None
@staticmethod
def get_stream_type(chat_stream: BotChatSession) -> str:
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
return "group" if chat_stream.is_group_session else "private"
@staticmethod
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
try:
info: Dict[str, Any] = {
"session_id": chat_stream.session_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
}
if chat_stream.is_group_session:
info["group_id"] = chat_stream.group_id
if (
chat_stream.context
and chat_stream.context.message
and chat_stream.context.message.message_info.group_info
):
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
else:
info["group_name"] = "未知群聊"
else:
info["user_id"] = chat_stream.user_id
if (
chat_stream.context
and chat_stream.context.message
and chat_stream.context.message.message_info.user_info
):
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
else:
info["user_name"] = "未知用户"
return info
except Exception as e:
logger.error(f"[ChatService] 获取聊天流信息失败: {e}")
return {}

View File

@@ -9,6 +9,7 @@ from typing import Any, Optional
from sqlalchemy import delete, func, select
from sqlmodel import SQLModel
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.database.database import get_db_session
from src.common.database.database_model import ActionRecord
from src.common.logger import get_logger
@@ -23,12 +24,10 @@ def _to_dict(record: Any) -> dict[str, Any]:
return record
if hasattr(record, "model_dump"):
return record.model_dump()
if hasattr(record, "__dict__"):
return dict(record.__dict__)
return {}
return dict(record.__dict__) if hasattr(record, "__dict__") else {}
def _get_model_field(model_class: type[SQLModel], field_name: str):
def _get_model_field(model_class: type[SQLModel], field_name: str) -> Any:
field = getattr(model_class, field_name, None)
if field is None:
raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}")
@@ -41,7 +40,7 @@ def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any]
return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()]
def _apply_order_by(statement, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None):
def _apply_order_by(statement: Any, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None) -> Any:
if not order_by:
return statement
@@ -60,7 +59,7 @@ async def db_save(
data: dict[str, Any],
key_field: Optional[str] = None,
key_value: Optional[Any] = None,
):
) -> Optional[dict[str, Any]]:
try:
with get_db_session() as session:
record = None
@@ -91,12 +90,11 @@ async def db_get(
limit: Optional[int] = None,
order_by: Optional[str | list[str]] = None,
single_result: bool = False,
):
) -> Optional[dict[str, Any]] | list[dict[str, Any]]:
try:
with get_db_session(auto_commit=False) as session:
statement = select(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
if conditions := _build_filters(model_class, filters):
statement = statement.where(*conditions)
statement = _apply_order_by(statement, model_class, order_by)
if limit:
@@ -116,8 +114,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters:
try:
with get_db_session() as session:
statement = select(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
if conditions := _build_filters(model_class, filters):
statement = statement.where(*conditions)
records = session.exec(statement).all()
for record in records:
@@ -136,8 +133,7 @@ async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any
try:
with get_db_session() as session:
statement = delete(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
if conditions := _build_filters(model_class, filters):
statement = statement.where(*conditions)
result = session.exec(statement)
return result.rowcount or 0
@@ -151,8 +147,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
try:
with get_db_session(auto_commit=False) as session:
statement = select(func.count()).select_from(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
if conditions := _build_filters(model_class, filters):
statement = statement.where(*conditions)
result = session.exec(statement).one()
return int(result or 0)
@@ -163,18 +158,15 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
async def store_action_info(
chat_stream=None,
chat_stream: BotChatSession,
builtin_prompt: Optional[str] = None,
display_prompt: str = "",
thinking_id: str = "",
action_data: Optional[dict] = None,
action_data: Optional[dict[str, Any]] = None,
action_name: str = "",
action_reasoning: str = "",
):
) -> Optional[dict[str, Any]]:
try:
if chat_stream is None:
raise ValueError("store_action_info 需要 chat_stream")
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)),
"timestamp": datetime.now(),

View File

@@ -1,406 +0,0 @@
"""
表情服务模块
提供表情包相关的核心功能。
"""
import base64
import os
import random
import uuid
from typing import Any, Dict, List, Optional, Tuple
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.config.config import global_config
logger = get_logger("emoji_service")
# =============================================================================
# 表情包获取函数
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
"""根据描述选择表情包"""
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("描述必须是字符串类型")
try:
logger.debug(f"[EmojiService] 根据描述获取表情包: {description}")
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
if not emoji_obj:
logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包")
return None
emoji_path = str(emoji_obj.full_path)
emoji_description = emoji_obj.description
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
if not emoji_base64:
logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}")
return None
logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
return emoji_base64, emoji_description, matched_emotion
except Exception as e:
logger.error(f"[EmojiService] 获取表情包失败: {e}")
return None
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包"""
if not isinstance(count, int):
raise TypeError("count 必须是整数类型")
if count < 0:
raise ValueError("count 不能为负数")
if count == 0:
logger.warning("[EmojiService] count 为0返回空列表")
return []
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiService] 没有可用的表情包")
return []
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiService] 没有有效的表情包")
return []
if len(valid_emojis) < count:
logger.debug(
f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
)
count = len(valid_emojis)
selected_emojis = random.sample(valid_emojis, count)
results = []
for selected_emoji in selected_emojis:
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
continue
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
emoji_manager.update_emoji_usage(selected_emoji)
results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0:
logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理")
return []
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包")
return results
except Exception as e:
logger.error(f"[EmojiService] 获取随机表情包失败: {e}")
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
"""根据情感标签获取表情包"""
if not emotion:
raise ValueError("情感标签不能为空")
if not isinstance(emotion, str):
raise TypeError("情感标签必须是字符串类型")
try:
logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}")
all_emojis = emoji_manager.emojis
matching_emojis = []
matching_emojis.extend(
emoji_obj
for emoji_obj in all_emojis
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
)
if not matching_emojis:
logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包")
return None
selected_emoji = random.choice(matching_emojis)
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
return None
emoji_manager.update_emoji_usage(selected_emoji)
logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion
except Exception as e:
logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}")
return None
# =============================================================================
# 表情包信息查询函数
# =============================================================================
def get_count() -> int:
try:
return len(emoji_manager.emojis)
except Exception as e:
logger.error(f"[EmojiService] 获取表情包数量失败: {e}")
return 0
def get_info():
try:
return {
"current_count": len(emoji_manager.emojis),
"max_count": global_config.emoji.max_reg_num,
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
}
except Exception as e:
logger.error(f"[EmojiService] 获取表情包信息失败: {e}")
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> List[str]:
try:
emotions = set()
for emoji_obj in emoji_manager.emojis:
if not emoji_obj.is_deleted and emoji_obj.emotion:
emotions.update(emoji_obj.emotion)
return sorted(list(emotions))
except Exception as e:
logger.error(f"[EmojiService] 获取情感标签失败: {e}")
return []
async def get_all() -> List[Tuple[str, str, str]]:
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiService] 没有可用的表情包")
return []
results = []
for emoji_obj in all_emojis:
if emoji_obj.is_deleted:
continue
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}")
continue
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
results.append((emoji_base64, emoji_obj.description, matched_emotion))
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包")
return results
except Exception as e:
logger.error(f"[EmojiService] 获取所有表情包失败: {e}")
return []
def get_descriptions() -> List[str]:
try:
descriptions = []
descriptions.extend(
emoji_obj.description
for emoji_obj in emoji_manager.emojis
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions
except Exception as e:
logger.error(f"[EmojiService] 获取表情包描述失败: {e}")
return []
# =============================================================================
# 表情包注册函数
# =============================================================================
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
"""注册新的表情包"""
if not image_base64:
raise ValueError("图片base64编码不能为空")
if not isinstance(image_base64, str):
raise TypeError("image_base64必须是字符串类型")
if filename is not None and not isinstance(filename, str):
raise TypeError("filename必须是字符串类型或None")
try:
logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}")
count_before = len(emoji_manager.emojis)
max_count = global_config.emoji.max_reg_num
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
if not can_register:
return {
"success": False,
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
os.makedirs(EMOJI_DIR, exist_ok=True)
if not filename:
import time as _time
timestamp = int(_time.time())
microseconds = int(_time.time() * 1000000) % 1000000
random_bytes = random.getrandbits(72).to_bytes(9, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
filename = f"{filename}.png"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts = 0
max_attempts = 10
while os.path.exists(temp_file_path) and attempts < max_attempts:
random_bytes = random.getrandbits(48).to_bytes(6, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{short_id}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts += 1
if os.path.exists(temp_file_path):
uuid_short = str(uuid.uuid4())[:8]
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{uuid_short}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter = 1
original_filename = filename
while os.path.exists(temp_file_path):
name_part, ext = os.path.splitext(original_filename)
filename = f"{name_part}_{counter}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter += 1
if counter > 100:
logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}")
return {
"success": False,
"message": "无法生成唯一文件名,请稍后重试",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
try:
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}")
return {
"success": False,
"message": "无法保存图片文件",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}")
except Exception as save_error:
logger.error(f"[EmojiService] 保存图片文件失败: {save_error}")
return {
"success": False,
"message": f"保存图片文件失败: {str(save_error)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
register_success = await emoji_manager.register_emoji_by_filename(filename)
if not register_success and os.path.exists(temp_file_path):
try:
os.remove(temp_file_path)
logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}")
except Exception as cleanup_error:
logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}")
if register_success:
count_after = len(emoji_manager.emojis)
replaced = count_after <= count_before
new_emoji_info = None
if count_after > count_before or replaced:
try:
for emoji_obj in reversed(emoji_manager.emojis):
if not emoji_obj.is_deleted and (
emoji_obj.file_name == filename
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
):
new_emoji_info = emoji_obj
break
except Exception as find_error:
logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}")
description = new_emoji_info.description if new_emoji_info else None
emotions = new_emoji_info.emotion if new_emoji_info else None
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
return {
"success": True,
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
"description": description,
"emotions": emotions,
"replaced": replaced,
"hash": emoji_hash,
}
else:
return {
"success": False,
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
except Exception as e:
logger.error(f"[EmojiService] 注册表情包时发生异常: {e}")
return {
"success": False,
"message": f"注册过程中发生错误: {str(e)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}

View File

@@ -35,7 +35,7 @@ logger = get_logger("generator_service")
# =============================================================================
def get_replyer(
def _get_replyer(
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
request_type: str = "replyer",
@@ -58,6 +58,35 @@ def get_replyer(
return None
def _extract_unknown_words(action_data: Optional[Dict[str, Any]]) -> Optional[List[str]]:
if not action_data:
return None
unknown_words = action_data.get("unknown_words")
if not isinstance(unknown_words, list):
return None
cleaned_words: List[str] = []
for item in unknown_words:
if isinstance(item, str) and (cleaned_item := item.strip()):
cleaned_words.append(cleaned_item)
return cleaned_words or None
def _build_message_sequence(
content: Optional[str],
*,
enable_splitter: bool,
enable_chinese_typo: bool,
) -> tuple[Optional[MessageSequence], List[str]]:
if not content:
return None, []
processed_output = process_llm_response(content, enable_splitter, enable_chinese_typo)
return MessageSequence(components=[TextComponent(text) for text in processed_output]), processed_output
# =============================================================================
# 回复生成函数
# =============================================================================
@@ -87,7 +116,7 @@ async def generate_reply(
reply_time_point = time.time()
logger.debug("[GeneratorService] 开始生成回复")
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
replyer = _get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorService] 无法获取回复器")
return False, None
@@ -98,16 +127,7 @@ async def generate_reply(
if not reply_reason:
reply_reason = action_data.get("reason", "")
if unknown_words is None:
uw = action_data.get("unknown_words")
if isinstance(uw, list):
cleaned: List[str] = []
for item in uw:
if isinstance(item, str):
s = item.strip()
if s:
cleaned.append(s)
if cleaned:
unknown_words = cleaned
unknown_words = _extract_unknown_words(action_data)
success, llm_response = await replyer.generate_reply_with_context(
extra_info=extra_info,
@@ -126,13 +146,12 @@ async def generate_reply(
if not success:
logger.warning("[GeneratorService] 回复生成失败")
return False, None
reply_set: Optional[MessageSequence] = None
if content := llm_response.content:
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
llm_response.processed_output = processed_response
reply_set = MessageSequence(components=[])
for text in processed_response:
reply_set.components.append(TextComponent(text))
reply_set, processed_output = _build_message_sequence(
llm_response.content,
enable_splitter=enable_splitter,
enable_chinese_typo=enable_chinese_typo,
)
llm_response.processed_output = processed_output
llm_response.reply_set = reply_set
logger.debug(
f"[GeneratorService] 回复生成成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
@@ -181,7 +200,7 @@ async def rewrite_reply(
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""重写回复"""
try:
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
replyer = _get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorService] 无法获取回复器")
return False, None
@@ -198,9 +217,13 @@ async def rewrite_reply(
reason=reason,
reply_to=reply_to,
)
reply_set: Optional[MessageSequence] = None
if success and llm_response and (content := llm_response.content):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
reply_set, processed_output = _build_message_sequence(
llm_response.content if success and llm_response else None,
enable_splitter=enable_splitter,
enable_chinese_typo=enable_chinese_typo,
)
if llm_response is not None:
llm_response.processed_output = processed_output
llm_response.reply_set = reply_set
if success:
logger.info(
@@ -219,44 +242,3 @@ async def rewrite_reply(
return False, None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[MessageSequence]:
"""将文本处理为更拟人化的文本"""
if not isinstance(content, str):
raise ValueError("content 必须是字符串类型")
try:
reply_set = MessageSequence(components=[])
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
for text in processed_response:
reply_set.components.append(TextComponent(text))
return reply_set
except Exception as e:
logger.error(f"[GeneratorService] 处理人形文本时出错: {e}")
return None
async def generate_response_custom(
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
request_type: str = "generator_api",
prompt: str = "",
) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorService] 无法获取回复器")
return None
try:
logger.debug("[GeneratorService] 开始生成自定义回复")
response, _, _, _ = await replyer.llm_generate_content(prompt)
if response:
logger.debug("[GeneratorService] 自定义回复生成成功")
return response
else:
logger.warning("[GeneratorService] 自定义回复生成失败")
return None
except Exception as e:
logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}")
return None

View File

@@ -14,7 +14,6 @@ from src.common.database.database_model import ActionRecord, Images, ImageType
from src.common.message_repository import count_messages, find_messages
from src.common.utils.math_utils import translate_timestamp_to_human_readable
from src.common.utils.utils_action import ActionUtils
from src.chat.utils.utils import is_bot_self
from src.config.config import global_config
@@ -113,103 +112,6 @@ def get_messages_by_time_in_chat(
return _normalize_messages(messages)
def get_messages_by_time_in_chat_inclusive(
chat_id: str,
start_time: float,
end_time: float,
limit: int = 0,
limit_mode: str = "latest",
filter_mai: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {
"$gte": start_time,
"$lte": end_time,
},
},
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
filter_command=filter_command,
filter_intercept_message_level=filter_intercept_message_level,
)
return _normalize_messages(messages)
def get_messages_by_time_in_chat_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {
"$gte": start_time,
"$lte": end_time,
},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode=limit_mode,
)
return _normalize_messages(messages)
def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_messages_by_time(start_time, end_time, limit, limit_mode, filter_mai)
def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
messages = find_messages(
message_filter={
"time": {
"$gte": start_time,
"$lte": end_time,
},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode=limit_mode,
)
return _normalize_messages(messages)
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
@@ -252,24 +154,6 @@ def get_messages_before_time_in_chat(
return _normalize_messages(messages)
def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
messages = find_messages(
message_filter={
"time": {"$lt": timestamp},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode="latest",
)
return _normalize_messages(messages)
def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[SessionMessage]:
@@ -307,22 +191,6 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
return count_messages(message_filter)
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return count_messages(
{
"chat_id": chat_id,
"time": {"$gt": start_time, "$lte": end_time},
"user_id": {"$in": person_ids},
}
)
# =============================================================================
# 消息格式化函数
# =============================================================================
@@ -365,17 +233,6 @@ def build_readable_messages(
return "\n".join(lines)
def build_readable_messages_to_str(
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> str:
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
def build_readable_messages_with_id(
messages: List[SessionMessage],
replace_bot_name: bool = True,
@@ -415,148 +272,6 @@ def build_readable_messages_with_id(
return "\n".join(lines), message_id_list
async def build_readable_messages_with_details(
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
normalized_messages = _normalize_messages(messages)
message_list = [
(
message.timestamp.timestamp(),
message.message_info.user_info.user_id,
message.processed_plain_text or "",
)
for message in normalized_messages
]
return build_readable_messages(normalized_messages, replace_bot_name, timestamp_mode, truncate=truncate), message_list
async def get_person_ids_from_messages(messages: List[Any]) -> List[str]:
person_ids: List[str] = []
for message in messages:
if isinstance(message, SessionMessage):
person_ids.append(message.message_info.user_info.user_id)
elif isinstance(message, dict) and (user_id := message.get("user_id")):
person_ids.append(str(user_id))
return person_ids
# =============================================================================
# 消息过滤函数
# =============================================================================
def filter_mai_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
"""从消息列表中移除麦麦的消息"""
return [
msg
for msg in messages
if not is_bot_self(msg.platform, msg.message_info.user_info.user_id)
]
def get_raw_msg_by_timestamp(
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_messages_by_time(timestamp_start, timestamp_end, limit, limit_mode)
def get_raw_msg_by_timestamp_with_chat(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
filter_bot: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
return get_messages_by_time_in_chat(
chat_id,
timestamp_start,
timestamp_end,
limit,
limit_mode,
filter_bot,
filter_command,
filter_intercept_message_level,
)
def get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
filter_bot: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
return get_messages_by_time_in_chat_inclusive(
chat_id,
timestamp_start,
timestamp_end,
limit,
limit_mode,
filter_bot,
filter_command,
filter_intercept_message_level,
)
def get_raw_msg_by_timestamp_with_chat_users(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_messages_by_time_in_chat_for_users(chat_id, timestamp_start, timestamp_end, person_ids, limit, limit_mode)
def get_raw_msg_by_timestamp_with_users(
timestamp_start: float,
timestamp_end: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_messages_by_time_for_users(timestamp_start, timestamp_end, person_ids, limit, limit_mode)
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[SessionMessage]:
return get_messages_before_time(timestamp, limit)
def get_raw_msg_before_timestamp_with_chat(
chat_id: str,
timestamp: float,
limit: int = 0,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
return get_messages_before_time_in_chat(chat_id, timestamp, limit, False, filter_intercept_message_level)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[SessionMessage]:
return get_messages_before_time_for_users(timestamp, person_ids, limit)
def get_raw_msg_by_timestamp_random(
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_random_chat_messages(timestamp_start, timestamp_end, limit, limit_mode)
def get_actions_by_timestamp_with_chat(chat_id: str, timestamp_start: float, timestamp_end: float) -> List[MaiActionRecord]:
with get_db_session() as session:
statement = (