炸 service 层 x 2,把能归类为现有重构好的模块的都归类过去
This commit is contained in:
@@ -1,171 +0,0 @@
|
||||
"""
|
||||
聊天服务模块
|
||||
|
||||
提供聊天信息查询和管理的核心功能。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from enum import Enum
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("chat_service")
|
||||
|
||||
|
||||
class SpecialTypes(Enum):
|
||||
"""特殊枚举类型"""
|
||||
|
||||
ALL_PLATFORMS = "all_platforms"
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器 - 负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def _validate_platform(platform: Optional[str] | SpecialTypes) -> None:
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
|
||||
@staticmethod
|
||||
def _match_platform(chat_stream: BotChatSession, platform: Optional[str] | SpecialTypes) -> bool:
|
||||
return platform == SpecialTypes.ALL_PLATFORMS or chat_stream.platform == platform
|
||||
|
||||
@staticmethod
|
||||
def _get_streams(
|
||||
platform: Optional[str] | SpecialTypes = "qq", is_group_session: Optional[bool] = None
|
||||
) -> List[BotChatSession]:
|
||||
ChatManager._validate_platform(platform)
|
||||
|
||||
try:
|
||||
streams = [
|
||||
stream
|
||||
for stream in _chat_manager.sessions.values()
|
||||
if ChatManager._match_platform(stream, platform)
|
||||
and (is_group_session is None or stream.is_group_session == is_group_session)
|
||||
]
|
||||
return streams
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 获取聊天流失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _find_stream(
|
||||
predicate: Callable[[BotChatSession], bool],
|
||||
platform: Optional[str] | SpecialTypes = "qq",
|
||||
) -> Optional[BotChatSession]:
|
||||
for stream in ChatManager._get_streams(platform=platform):
|
||||
if predicate(stream):
|
||||
return stream
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
streams = ChatManager._get_streams(platform=platform)
|
||||
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
streams = ChatManager._get_streams(platform=platform, is_group_session=True)
|
||||
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
streams = ChatManager._get_streams(platform=platform, is_group_session=False)
|
||||
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
||||
if not isinstance(group_id, str):
|
||||
raise TypeError("group_id 必须是字符串类型")
|
||||
ChatManager._validate_platform(platform)
|
||||
if not group_id:
|
||||
raise ValueError("group_id 不能为空")
|
||||
try:
|
||||
stream = ChatManager._find_stream(
|
||||
lambda item: item.is_group_session and str(item.group_id) == str(group_id),
|
||||
platform=platform,
|
||||
)
|
||||
if stream is not None:
|
||||
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
|
||||
return stream
|
||||
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 查找群聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
||||
if not isinstance(user_id, str):
|
||||
raise TypeError("user_id 必须是字符串类型")
|
||||
ChatManager._validate_platform(platform)
|
||||
if not user_id:
|
||||
raise ValueError("user_id 不能为空")
|
||||
try:
|
||||
stream = ChatManager._find_stream(
|
||||
lambda item: (not item.is_group_session) and str(item.user_id) == str(user_id),
|
||||
platform=platform,
|
||||
)
|
||||
if stream is not None:
|
||||
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
|
||||
return stream
|
||||
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 查找私聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_type(chat_stream: BotChatSession) -> str:
|
||||
if not isinstance(chat_stream, BotChatSession):
|
||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
|
||||
return "group" if chat_stream.is_group_session else "private"
|
||||
|
||||
@staticmethod
|
||||
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
if not isinstance(chat_stream, BotChatSession):
|
||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
||||
|
||||
try:
|
||||
info: Dict[str, Any] = {
|
||||
"session_id": chat_stream.session_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
}
|
||||
|
||||
if chat_stream.is_group_session:
|
||||
info["group_id"] = chat_stream.group_id
|
||||
if (
|
||||
chat_stream.context
|
||||
and chat_stream.context.message
|
||||
and chat_stream.context.message.message_info.group_info
|
||||
):
|
||||
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
|
||||
else:
|
||||
info["group_name"] = "未知群聊"
|
||||
else:
|
||||
info["user_id"] = chat_stream.user_id
|
||||
if (
|
||||
chat_stream.context
|
||||
and chat_stream.context.message
|
||||
and chat_stream.context.message.message_info.user_info
|
||||
):
|
||||
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
|
||||
else:
|
||||
info["user_name"] = "未知用户"
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 获取聊天流信息失败: {e}")
|
||||
return {}
|
||||
@@ -9,6 +9,7 @@ from typing import Any, Optional
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ActionRecord
|
||||
from src.common.logger import get_logger
|
||||
@@ -23,12 +24,10 @@ def _to_dict(record: Any) -> dict[str, Any]:
|
||||
return record
|
||||
if hasattr(record, "model_dump"):
|
||||
return record.model_dump()
|
||||
if hasattr(record, "__dict__"):
|
||||
return dict(record.__dict__)
|
||||
return {}
|
||||
return dict(record.__dict__) if hasattr(record, "__dict__") else {}
|
||||
|
||||
|
||||
def _get_model_field(model_class: type[SQLModel], field_name: str):
|
||||
def _get_model_field(model_class: type[SQLModel], field_name: str) -> Any:
|
||||
field = getattr(model_class, field_name, None)
|
||||
if field is None:
|
||||
raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}")
|
||||
@@ -41,7 +40,7 @@ def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
||||
return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()]
|
||||
|
||||
|
||||
def _apply_order_by(statement, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None):
|
||||
def _apply_order_by(statement: Any, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None) -> Any:
|
||||
if not order_by:
|
||||
return statement
|
||||
|
||||
@@ -60,7 +59,7 @@ async def db_save(
|
||||
data: dict[str, Any],
|
||||
key_field: Optional[str] = None,
|
||||
key_value: Optional[Any] = None,
|
||||
):
|
||||
) -> Optional[dict[str, Any]]:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = None
|
||||
@@ -91,12 +90,11 @@ async def db_get(
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str | list[str]] = None,
|
||||
single_result: bool = False,
|
||||
):
|
||||
) -> Optional[dict[str, Any]] | list[dict[str, Any]]:
|
||||
try:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(model_class)
|
||||
conditions = _build_filters(model_class, filters)
|
||||
if conditions:
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
statement = _apply_order_by(statement, model_class, order_by)
|
||||
if limit:
|
||||
@@ -116,8 +114,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(model_class)
|
||||
conditions = _build_filters(model_class, filters)
|
||||
if conditions:
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
records = session.exec(statement).all()
|
||||
for record in records:
|
||||
@@ -136,8 +133,7 @@ async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = delete(model_class)
|
||||
conditions = _build_filters(model_class, filters)
|
||||
if conditions:
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
result = session.exec(statement)
|
||||
return result.rowcount or 0
|
||||
@@ -151,8 +147,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
||||
try:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(func.count()).select_from(model_class)
|
||||
conditions = _build_filters(model_class, filters)
|
||||
if conditions:
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
result = session.exec(statement).one()
|
||||
return int(result or 0)
|
||||
@@ -163,18 +158,15 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream=None,
|
||||
chat_stream: BotChatSession,
|
||||
builtin_prompt: Optional[str] = None,
|
||||
display_prompt: str = "",
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_data: Optional[dict[str, Any]] = None,
|
||||
action_name: str = "",
|
||||
action_reasoning: str = "",
|
||||
):
|
||||
) -> Optional[dict[str, Any]]:
|
||||
try:
|
||||
if chat_stream is None:
|
||||
raise ValueError("store_action_info 需要 chat_stream")
|
||||
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"timestamp": datetime.now(),
|
||||
|
||||
@@ -1,406 +0,0 @@
|
||||
"""
|
||||
表情服务模块
|
||||
|
||||
提供表情包相关的核心功能。
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("emoji_service")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包获取函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据描述选择表情包"""
|
||||
if not description:
|
||||
raise ValueError("描述不能为空")
|
||||
if not isinstance(description, str):
|
||||
raise TypeError("描述必须是字符串类型")
|
||||
try:
|
||||
logger.debug(f"[EmojiService] 根据描述获取表情包: {description}")
|
||||
|
||||
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
|
||||
|
||||
if not emoji_obj:
|
||||
logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包")
|
||||
return None
|
||||
|
||||
emoji_path = str(emoji_obj.full_path)
|
||||
emoji_description = emoji_obj.description
|
||||
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
|
||||
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}")
|
||||
return None
|
||||
|
||||
logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
|
||||
return emoji_base64, emoji_description, matched_emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
"""随机获取指定数量的表情包"""
|
||||
if not isinstance(count, int):
|
||||
raise TypeError("count 必须是整数类型")
|
||||
if count < 0:
|
||||
raise ValueError("count 不能为负数")
|
||||
if count == 0:
|
||||
logger.warning("[EmojiService] count 为0,返回空列表")
|
||||
return []
|
||||
|
||||
try:
|
||||
all_emojis = emoji_manager.emojis
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("[EmojiService] 没有可用的表情包")
|
||||
return []
|
||||
|
||||
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
|
||||
if not valid_emojis:
|
||||
logger.warning("[EmojiService] 没有有效的表情包")
|
||||
return []
|
||||
|
||||
if len(valid_emojis) < count:
|
||||
logger.debug(
|
||||
f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
|
||||
)
|
||||
count = len(valid_emojis)
|
||||
|
||||
selected_emojis = random.sample(valid_emojis, count)
|
||||
|
||||
results = []
|
||||
for selected_emoji in selected_emojis:
|
||||
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
continue
|
||||
|
||||
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
||||
|
||||
emoji_manager.update_emoji_usage(selected_emoji)
|
||||
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
||||
|
||||
if not results and count > 0:
|
||||
logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理")
|
||||
return []
|
||||
|
||||
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取随机表情包失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据情感标签获取表情包"""
|
||||
if not emotion:
|
||||
raise ValueError("情感标签不能为空")
|
||||
if not isinstance(emotion, str):
|
||||
raise TypeError("情感标签必须是字符串类型")
|
||||
try:
|
||||
logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}")
|
||||
|
||||
all_emojis = emoji_manager.emojis
|
||||
|
||||
matching_emojis = []
|
||||
matching_emojis.extend(
|
||||
emoji_obj
|
||||
for emoji_obj in all_emojis
|
||||
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
|
||||
)
|
||||
if not matching_emojis:
|
||||
logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包")
|
||||
return None
|
||||
|
||||
selected_emoji = random.choice(matching_emojis)
|
||||
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
return None
|
||||
|
||||
emoji_manager.update_emoji_usage(selected_emoji)
|
||||
|
||||
logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}")
|
||||
return emoji_base64, selected_emoji.description, emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包信息查询函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_count() -> int:
|
||||
try:
|
||||
return len(emoji_manager.emojis)
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取表情包数量失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def get_info():
|
||||
try:
|
||||
return {
|
||||
"current_count": len(emoji_manager.emojis),
|
||||
"max_count": global_config.emoji.max_reg_num,
|
||||
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取表情包信息失败: {e}")
|
||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||
|
||||
|
||||
def get_emotions() -> List[str]:
|
||||
try:
|
||||
emotions = set()
|
||||
|
||||
for emoji_obj in emoji_manager.emojis:
|
||||
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
||||
emotions.update(emoji_obj.emotion)
|
||||
|
||||
return sorted(list(emotions))
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取情感标签失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_all() -> List[Tuple[str, str, str]]:
|
||||
try:
|
||||
all_emojis = emoji_manager.emojis
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("[EmojiService] 没有可用的表情包")
|
||||
return []
|
||||
|
||||
results = []
|
||||
for emoji_obj in all_emojis:
|
||||
if emoji_obj.is_deleted:
|
||||
continue
|
||||
|
||||
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}")
|
||||
continue
|
||||
|
||||
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
|
||||
results.append((emoji_base64, emoji_obj.description, matched_emotion))
|
||||
|
||||
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取所有表情包失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_descriptions() -> List[str]:
|
||||
try:
|
||||
descriptions = []
|
||||
|
||||
descriptions.extend(
|
||||
emoji_obj.description
|
||||
for emoji_obj in emoji_manager.emojis
|
||||
if not emoji_obj.is_deleted and emoji_obj.description
|
||||
)
|
||||
return descriptions
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取表情包描述失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包注册函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""注册新的表情包"""
|
||||
if not image_base64:
|
||||
raise ValueError("图片base64编码不能为空")
|
||||
if not isinstance(image_base64, str):
|
||||
raise TypeError("image_base64必须是字符串类型")
|
||||
if filename is not None and not isinstance(filename, str):
|
||||
raise TypeError("filename必须是字符串类型或None")
|
||||
|
||||
try:
|
||||
logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}")
|
||||
|
||||
count_before = len(emoji_manager.emojis)
|
||||
max_count = global_config.emoji.max_reg_num
|
||||
|
||||
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
|
||||
|
||||
if not can_register:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
|
||||
if not filename:
|
||||
import time as _time
|
||||
|
||||
timestamp = int(_time.time())
|
||||
microseconds = int(_time.time() * 1000000) % 1000000
|
||||
|
||||
random_bytes = random.getrandbits(72).to_bytes(9, "big")
|
||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
|
||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
||||
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
|
||||
|
||||
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
|
||||
filename = f"{filename}.png"
|
||||
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
attempts = 0
|
||||
max_attempts = 10
|
||||
while os.path.exists(temp_file_path) and attempts < max_attempts:
|
||||
random_bytes = random.getrandbits(48).to_bytes(6, "big")
|
||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
|
||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
||||
|
||||
name_part, ext = os.path.splitext(filename)
|
||||
base_name = name_part.rsplit("_", 1)[0]
|
||||
filename = f"{base_name}_{short_id}{ext}"
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
attempts += 1
|
||||
|
||||
if os.path.exists(temp_file_path):
|
||||
uuid_short = str(uuid.uuid4())[:8]
|
||||
name_part, ext = os.path.splitext(filename)
|
||||
base_name = name_part.rsplit("_", 1)[0]
|
||||
filename = f"{base_name}_{uuid_short}{ext}"
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
|
||||
counter = 1
|
||||
original_filename = filename
|
||||
while os.path.exists(temp_file_path):
|
||||
name_part, ext = os.path.splitext(original_filename)
|
||||
filename = f"{name_part}_{counter}{ext}"
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
counter += 1
|
||||
|
||||
if counter > 100:
|
||||
logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "无法生成唯一文件名,请稍后重试",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
try:
|
||||
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
|
||||
logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "无法保存图片文件",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}")
|
||||
|
||||
except Exception as save_error:
|
||||
logger.error(f"[EmojiService] 保存图片文件失败: {save_error}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"保存图片文件失败: {str(save_error)}",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
register_success = await emoji_manager.register_emoji_by_filename(filename)
|
||||
|
||||
if not register_success and os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}")
|
||||
|
||||
if register_success:
|
||||
count_after = len(emoji_manager.emojis)
|
||||
replaced = count_after <= count_before
|
||||
|
||||
new_emoji_info = None
|
||||
if count_after > count_before or replaced:
|
||||
try:
|
||||
for emoji_obj in reversed(emoji_manager.emojis):
|
||||
if not emoji_obj.is_deleted and (
|
||||
emoji_obj.file_name == filename
|
||||
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
|
||||
):
|
||||
new_emoji_info = emoji_obj
|
||||
break
|
||||
except Exception as find_error:
|
||||
logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}")
|
||||
|
||||
description = new_emoji_info.description if new_emoji_info else None
|
||||
emotions = new_emoji_info.emotion if new_emoji_info else None
|
||||
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
|
||||
"description": description,
|
||||
"emotions": emotions,
|
||||
"replaced": replaced,
|
||||
"hash": emoji_hash,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 注册表情包时发生异常: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"注册过程中发生错误: {str(e)}",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
@@ -35,7 +35,7 @@ logger = get_logger("generator_service")
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_replyer(
|
||||
def _get_replyer(
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
@@ -58,6 +58,35 @@ def get_replyer(
|
||||
return None
|
||||
|
||||
|
||||
def _extract_unknown_words(action_data: Optional[Dict[str, Any]]) -> Optional[List[str]]:
|
||||
if not action_data:
|
||||
return None
|
||||
|
||||
unknown_words = action_data.get("unknown_words")
|
||||
if not isinstance(unknown_words, list):
|
||||
return None
|
||||
|
||||
cleaned_words: List[str] = []
|
||||
for item in unknown_words:
|
||||
if isinstance(item, str) and (cleaned_item := item.strip()):
|
||||
cleaned_words.append(cleaned_item)
|
||||
|
||||
return cleaned_words or None
|
||||
|
||||
|
||||
def _build_message_sequence(
|
||||
content: Optional[str],
|
||||
*,
|
||||
enable_splitter: bool,
|
||||
enable_chinese_typo: bool,
|
||||
) -> tuple[Optional[MessageSequence], List[str]]:
|
||||
if not content:
|
||||
return None, []
|
||||
|
||||
processed_output = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||
return MessageSequence(components=[TextComponent(text) for text in processed_output]), processed_output
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 回复生成函数
|
||||
# =============================================================================
|
||||
@@ -87,7 +116,7 @@ async def generate_reply(
|
||||
reply_time_point = time.time()
|
||||
|
||||
logger.debug("[GeneratorService] 开始生成回复")
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
replyer = _get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorService] 无法获取回复器")
|
||||
return False, None
|
||||
@@ -98,16 +127,7 @@ async def generate_reply(
|
||||
if not reply_reason:
|
||||
reply_reason = action_data.get("reason", "")
|
||||
if unknown_words is None:
|
||||
uw = action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
cleaned.append(s)
|
||||
if cleaned:
|
||||
unknown_words = cleaned
|
||||
unknown_words = _extract_unknown_words(action_data)
|
||||
|
||||
success, llm_response = await replyer.generate_reply_with_context(
|
||||
extra_info=extra_info,
|
||||
@@ -126,13 +146,12 @@ async def generate_reply(
|
||||
if not success:
|
||||
logger.warning("[GeneratorService] 回复生成失败")
|
||||
return False, None
|
||||
reply_set: Optional[MessageSequence] = None
|
||||
if content := llm_response.content:
|
||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||
llm_response.processed_output = processed_response
|
||||
reply_set = MessageSequence(components=[])
|
||||
for text in processed_response:
|
||||
reply_set.components.append(TextComponent(text))
|
||||
reply_set, processed_output = _build_message_sequence(
|
||||
llm_response.content,
|
||||
enable_splitter=enable_splitter,
|
||||
enable_chinese_typo=enable_chinese_typo,
|
||||
)
|
||||
llm_response.processed_output = processed_output
|
||||
llm_response.reply_set = reply_set
|
||||
logger.debug(
|
||||
f"[GeneratorService] 回复生成成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
|
||||
@@ -181,7 +200,7 @@ async def rewrite_reply(
|
||||
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
||||
"""重写回复"""
|
||||
try:
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
replyer = _get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorService] 无法获取回复器")
|
||||
return False, None
|
||||
@@ -198,9 +217,13 @@ async def rewrite_reply(
|
||||
reason=reason,
|
||||
reply_to=reply_to,
|
||||
)
|
||||
reply_set: Optional[MessageSequence] = None
|
||||
if success and llm_response and (content := llm_response.content):
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
reply_set, processed_output = _build_message_sequence(
|
||||
llm_response.content if success and llm_response else None,
|
||||
enable_splitter=enable_splitter,
|
||||
enable_chinese_typo=enable_chinese_typo,
|
||||
)
|
||||
if llm_response is not None:
|
||||
llm_response.processed_output = processed_output
|
||||
llm_response.reply_set = reply_set
|
||||
if success:
|
||||
logger.info(
|
||||
@@ -219,44 +242,3 @@ async def rewrite_reply(
|
||||
return False, None
|
||||
|
||||
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[MessageSequence]:
|
||||
"""将文本处理为更拟人化的文本"""
|
||||
if not isinstance(content, str):
|
||||
raise ValueError("content 必须是字符串类型")
|
||||
try:
|
||||
reply_set = MessageSequence(components=[])
|
||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
for text in processed_response:
|
||||
reply_set.components.append(TextComponent(text))
|
||||
|
||||
return reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorService] 处理人形文本时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_response_custom(
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "generator_api",
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorService] 无法获取回复器")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.debug("[GeneratorService] 开始生成自定义回复")
|
||||
response, _, _, _ = await replyer.llm_generate_content(prompt)
|
||||
if response:
|
||||
logger.debug("[GeneratorService] 自定义回复生成成功")
|
||||
return response
|
||||
else:
|
||||
logger.warning("[GeneratorService] 自定义回复生成失败")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}")
|
||||
return None
|
||||
|
||||
@@ -14,7 +14,6 @@ from src.common.database.database_model import ActionRecord, Images, ImageType
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.common.utils.math_utils import translate_timestamp_to_human_readable
|
||||
from src.common.utils.utils_action import ActionUtils
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -113,103 +112,6 @@ def get_messages_by_time_in_chat(
|
||||
return _normalize_messages(messages)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_inclusive(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[SessionMessage]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
messages = find_messages(
|
||||
message_filter={
|
||||
"chat_id": chat_id,
|
||||
"time": {
|
||||
"$gte": start_time,
|
||||
"$lte": end_time,
|
||||
},
|
||||
},
|
||||
limit=limit,
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_mai,
|
||||
filter_command=filter_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
return _normalize_messages(messages)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_for_users(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
person_ids: List[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[SessionMessage]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
messages = find_messages(
|
||||
message_filter={
|
||||
"chat_id": chat_id,
|
||||
"time": {
|
||||
"$gte": start_time,
|
||||
"$lte": end_time,
|
||||
},
|
||||
"user_id": {"$in": person_ids},
|
||||
},
|
||||
limit=limit,
|
||||
limit_mode=limit_mode,
|
||||
)
|
||||
return _normalize_messages(messages)
|
||||
|
||||
|
||||
def get_random_chat_messages(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[SessionMessage]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
return get_messages_by_time(start_time, end_time, limit, limit_mode, filter_mai)
|
||||
|
||||
|
||||
def get_messages_by_time_for_users(
|
||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[SessionMessage]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
messages = find_messages(
|
||||
message_filter={
|
||||
"time": {
|
||||
"$gte": start_time,
|
||||
"$lte": end_time,
|
||||
},
|
||||
"user_id": {"$in": person_ids},
|
||||
},
|
||||
limit=limit,
|
||||
limit_mode=limit_mode,
|
||||
)
|
||||
return _normalize_messages(messages)
|
||||
|
||||
|
||||
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
@@ -252,24 +154,6 @@ def get_messages_before_time_in_chat(
|
||||
return _normalize_messages(messages)
|
||||
|
||||
|
||||
def get_messages_before_time_for_users(
|
||||
timestamp: float, person_ids: List[str], limit: int = 0
|
||||
) -> List[SessionMessage]:
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
messages = find_messages(
|
||||
message_filter={
|
||||
"time": {"$lt": timestamp},
|
||||
"user_id": {"$in": person_ids},
|
||||
},
|
||||
limit=limit,
|
||||
limit_mode="latest",
|
||||
)
|
||||
return _normalize_messages(messages)
|
||||
|
||||
|
||||
def get_recent_messages(
|
||||
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[SessionMessage]:
|
||||
@@ -307,22 +191,6 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
|
||||
return count_messages(message_filter)
|
||||
|
||||
|
||||
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return count_messages(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
"time": {"$gt": start_time, "$lte": end_time},
|
||||
"user_id": {"$in": person_ids},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息格式化函数
|
||||
# =============================================================================
|
||||
@@ -365,17 +233,6 @@ def build_readable_messages(
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def build_readable_messages_to_str(
|
||||
messages: List[SessionMessage],
|
||||
replace_bot_name: bool = True,
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
) -> str:
|
||||
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
|
||||
|
||||
|
||||
def build_readable_messages_with_id(
|
||||
messages: List[SessionMessage],
|
||||
replace_bot_name: bool = True,
|
||||
@@ -415,148 +272,6 @@ def build_readable_messages_with_id(
|
||||
return "\n".join(lines), message_id_list
|
||||
|
||||
|
||||
async def build_readable_messages_with_details(
|
||||
messages: List[SessionMessage],
|
||||
replace_bot_name: bool = True,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
normalized_messages = _normalize_messages(messages)
|
||||
message_list = [
|
||||
(
|
||||
message.timestamp.timestamp(),
|
||||
message.message_info.user_info.user_id,
|
||||
message.processed_plain_text or "",
|
||||
)
|
||||
for message in normalized_messages
|
||||
]
|
||||
return build_readable_messages(normalized_messages, replace_bot_name, timestamp_mode, truncate=truncate), message_list
|
||||
|
||||
|
||||
async def get_person_ids_from_messages(messages: List[Any]) -> List[str]:
|
||||
person_ids: List[str] = []
|
||||
for message in messages:
|
||||
if isinstance(message, SessionMessage):
|
||||
person_ids.append(message.message_info.user_info.user_id)
|
||||
elif isinstance(message, dict) and (user_id := message.get("user_id")):
|
||||
person_ids.append(str(user_id))
|
||||
return person_ids
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息过滤函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def filter_mai_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
|
||||
"""从消息列表中移除麦麦的消息"""
|
||||
return [
|
||||
msg
|
||||
for msg in messages
|
||||
if not is_bot_self(msg.platform, msg.message_info.user_info.user_id)
|
||||
]
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp(
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[SessionMessage]:
|
||||
return get_messages_by_time(timestamp_start, timestamp_end, limit, limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_bot: bool = False,
|
||||
filter_command: bool = False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[SessionMessage]:
|
||||
return get_messages_by_time_in_chat(
|
||||
chat_id,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
limit,
|
||||
limit_mode,
|
||||
filter_bot,
|
||||
filter_command,
|
||||
filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_bot: bool = False,
|
||||
filter_command: bool = False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[SessionMessage]:
|
||||
return get_messages_by_time_in_chat_inclusive(
|
||||
chat_id,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
limit,
|
||||
limit_mode,
|
||||
filter_bot,
|
||||
filter_command,
|
||||
filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat_users(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
person_ids: List[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[SessionMessage]:
|
||||
return get_messages_by_time_in_chat_for_users(chat_id, timestamp_start, timestamp_end, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_users(
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
person_ids: List[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[SessionMessage]:
|
||||
return get_messages_by_time_for_users(timestamp_start, timestamp_end, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[SessionMessage]:
|
||||
return get_messages_before_time(timestamp, limit)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id: str,
|
||||
timestamp: float,
|
||||
limit: int = 0,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[SessionMessage]:
|
||||
return get_messages_before_time_in_chat(chat_id, timestamp, limit, False, filter_intercept_message_level)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[SessionMessage]:
|
||||
return get_messages_before_time_for_users(timestamp, person_ids, limit)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_random(
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[SessionMessage]:
|
||||
return get_random_chat_messages(timestamp_start, timestamp_end, limit, limit_mode)
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat(chat_id: str, timestamp_start: float, timestamp_end: float) -> List[MaiActionRecord]:
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
|
||||
Reference in New Issue
Block a user