重构整个插件系统,尝试恢复可启动性,新增插件系统maibot-plugin-sdk依赖
This commit is contained in:
7
src/services/__init__.py
Normal file
7
src/services/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
核心服务层
|
||||
|
||||
提供与具体插件系统无关的核心业务服务。
|
||||
内部模块(chat、dream、memory 等)应直接使用此层,
|
||||
而 plugin_system.apis 仅作为面向插件的薄包装。
|
||||
"""
|
||||
159
src/services/chat_service.py
Normal file
159
src/services/chat_service.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
聊天服务模块
|
||||
|
||||
提供聊天信息查询和管理的核心功能。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from enum import Enum
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("chat_service")
|
||||
|
||||
|
||||
class SpecialTypes(Enum):
|
||||
"""特殊枚举类型"""
|
||||
|
||||
ALL_PLATFORMS = "all_platforms"
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器 - 负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 获取聊天流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 获取群聊流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if (
|
||||
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
|
||||
) and not stream.is_group_session:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatService] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 获取私聊流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
||||
if not isinstance(group_id, str):
|
||||
raise TypeError("group_id 必须是字符串类型")
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
if not group_id:
|
||||
raise ValueError("group_id 不能为空")
|
||||
try:
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if (
|
||||
stream.is_group_session
|
||||
and str(stream.group_id) == str(group_id)
|
||||
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
|
||||
):
|
||||
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
|
||||
return stream
|
||||
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 查找群聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
||||
if not isinstance(user_id, str):
|
||||
raise TypeError("user_id 必须是字符串类型")
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
if not user_id:
|
||||
raise ValueError("user_id 不能为空")
|
||||
try:
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if (
|
||||
not stream.is_group_session
|
||||
and str(stream.user_id) == str(user_id)
|
||||
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
|
||||
):
|
||||
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
|
||||
return stream
|
||||
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 查找私聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_type(chat_stream: BotChatSession) -> str:
|
||||
if not isinstance(chat_stream, BotChatSession):
|
||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
|
||||
return "group" if chat_stream.is_group_session else "private"
|
||||
|
||||
@staticmethod
|
||||
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
if not isinstance(chat_stream, BotChatSession):
|
||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
||||
|
||||
try:
|
||||
info: Dict[str, Any] = {
|
||||
"session_id": chat_stream.session_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
}
|
||||
|
||||
if chat_stream.is_group_session:
|
||||
info["group_id"] = chat_stream.group_id
|
||||
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
|
||||
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
|
||||
else:
|
||||
info["group_name"] = "未知群聊"
|
||||
else:
|
||||
info["user_id"] = chat_stream.user_id
|
||||
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
|
||||
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
|
||||
else:
|
||||
info["user_name"] = "未知用户"
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatService] 获取聊天流信息失败: {e}")
|
||||
return {}
|
||||
66
src/services/config_service.py
Normal file
66
src/services/config_service.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""配置服务模块
|
||||
|
||||
提供配置读取的核心功能。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("config_service")
|
||||
|
||||
|
||||
def get_global_config(key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
安全地从全局配置中获取一个值。
|
||||
|
||||
Args:
|
||||
key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
keys = key.split(".")
|
||||
current = global_config
|
||||
|
||||
try:
|
||||
for k in keys:
|
||||
if hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigService] 获取全局配置 {key} 失败: {e}")
|
||||
return default
|
||||
|
||||
|
||||
def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
从插件配置中获取值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
plugin_config: 插件配置字典
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key",大小写敏感
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
keys = key.split(".")
|
||||
current = plugin_config
|
||||
|
||||
try:
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
elif hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigService] 获取插件配置 {key} 失败: {e}")
|
||||
return default
|
||||
173
src/services/database_service.py
Normal file
173
src/services/database_service.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""数据库服务模块
|
||||
|
||||
提供数据库操作相关的核心功能。
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database_service")
|
||||
|
||||
|
||||
def _to_dict(record: Any) -> dict[str, Any]:
|
||||
if record is None:
|
||||
return {}
|
||||
if isinstance(record, dict):
|
||||
return record
|
||||
if hasattr(record, "model_dump"):
|
||||
return record.model_dump()
|
||||
if hasattr(record, "__dict__"):
|
||||
return dict(record.__dict__)
|
||||
return {}
|
||||
|
||||
|
||||
async def db_query(
|
||||
model_class,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
query_type: str = "get",
|
||||
filters: Optional[dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[list[str]] = None,
|
||||
single_result: bool = False,
|
||||
):
|
||||
try:
|
||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
||||
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
|
||||
|
||||
if query_type == "get":
|
||||
query = model_class.select()
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
if order_by:
|
||||
query = query.order_by(*order_by)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
results = list(query.dicts())
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
if query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
record = model_class.create(**data)
|
||||
return _to_dict(record)
|
||||
|
||||
query = model_class.select()
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
if query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
return query.model_class.update(**data).where(*query.stmt._where_criteria).execute()
|
||||
|
||||
if query_type == "delete":
|
||||
return model_class.delete().where(*query.stmt._where_criteria).execute()
|
||||
|
||||
return query.count()
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseService] 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
return None
|
||||
|
||||
|
||||
async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None):
|
||||
try:
|
||||
if key_field and key_value is not None:
|
||||
record = model_class.get_or_none(getattr(model_class, key_field) == key_value)
|
||||
if record is not None:
|
||||
for field, value in data.items():
|
||||
setattr(record, field, value)
|
||||
record.save()
|
||||
return _to_dict(record)
|
||||
|
||||
new_record = model_class.create(**data)
|
||||
return _to_dict(new_record)
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseService] 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def db_get(
|
||||
model_class,
|
||||
filters: Optional[dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: bool = False,
|
||||
):
|
||||
try:
|
||||
query = model_class.select()
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
if order_by:
|
||||
query = query.order_by(order_by)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
results = list(query.dicts())
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseService] 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None if single_result else []
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream=None,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
action_reasoning: str = "",
|
||||
):
|
||||
try:
|
||||
from src.common.database.database_model import ActionRecords
|
||||
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
"action_done": action_done,
|
||||
"action_reasoning": action_reasoning,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
||||
if chat_stream:
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
}
|
||||
)
|
||||
else:
|
||||
record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""})
|
||||
|
||||
saved_record = await db_save(
|
||||
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
||||
)
|
||||
if saved_record:
|
||||
logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||
else:
|
||||
logger.error(f"[DatabaseService] 存储动作信息失败: {action_name}")
|
||||
return saved_record
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseService] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
406
src/services/emoji_service.py
Normal file
406
src/services/emoji_service.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
表情服务模块
|
||||
|
||||
提供表情包相关的核心功能。
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("emoji_service")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包获取函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据描述选择表情包"""
|
||||
if not description:
|
||||
raise ValueError("描述不能为空")
|
||||
if not isinstance(description, str):
|
||||
raise TypeError("描述必须是字符串类型")
|
||||
try:
|
||||
logger.debug(f"[EmojiService] 根据描述获取表情包: {description}")
|
||||
|
||||
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
|
||||
|
||||
if not emoji_obj:
|
||||
logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包")
|
||||
return None
|
||||
|
||||
emoji_path = str(emoji_obj.full_path)
|
||||
emoji_description = emoji_obj.description
|
||||
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
|
||||
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}")
|
||||
return None
|
||||
|
||||
logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
|
||||
return emoji_base64, emoji_description, matched_emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
"""随机获取指定数量的表情包"""
|
||||
if not isinstance(count, int):
|
||||
raise TypeError("count 必须是整数类型")
|
||||
if count < 0:
|
||||
raise ValueError("count 不能为负数")
|
||||
if count == 0:
|
||||
logger.warning("[EmojiService] count 为0,返回空列表")
|
||||
return []
|
||||
|
||||
try:
|
||||
all_emojis = emoji_manager.emojis
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("[EmojiService] 没有可用的表情包")
|
||||
return []
|
||||
|
||||
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
|
||||
if not valid_emojis:
|
||||
logger.warning("[EmojiService] 没有有效的表情包")
|
||||
return []
|
||||
|
||||
if len(valid_emojis) < count:
|
||||
logger.debug(
|
||||
f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
|
||||
)
|
||||
count = len(valid_emojis)
|
||||
|
||||
selected_emojis = random.sample(valid_emojis, count)
|
||||
|
||||
results = []
|
||||
for selected_emoji in selected_emojis:
|
||||
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
continue
|
||||
|
||||
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
||||
|
||||
emoji_manager.update_emoji_usage(selected_emoji)
|
||||
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
||||
|
||||
if not results and count > 0:
|
||||
logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理")
|
||||
return []
|
||||
|
||||
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取随机表情包失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据情感标签获取表情包"""
|
||||
if not emotion:
|
||||
raise ValueError("情感标签不能为空")
|
||||
if not isinstance(emotion, str):
|
||||
raise TypeError("情感标签必须是字符串类型")
|
||||
try:
|
||||
logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}")
|
||||
|
||||
all_emojis = emoji_manager.emojis
|
||||
|
||||
matching_emojis = []
|
||||
matching_emojis.extend(
|
||||
emoji_obj
|
||||
for emoji_obj in all_emojis
|
||||
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
|
||||
)
|
||||
if not matching_emojis:
|
||||
logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包")
|
||||
return None
|
||||
|
||||
selected_emoji = random.choice(matching_emojis)
|
||||
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
return None
|
||||
|
||||
emoji_manager.update_emoji_usage(selected_emoji)
|
||||
|
||||
logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}")
|
||||
return emoji_base64, selected_emoji.description, emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包信息查询函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_count() -> int:
|
||||
try:
|
||||
return len(emoji_manager.emojis)
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取表情包数量失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def get_info():
|
||||
try:
|
||||
return {
|
||||
"current_count": len(emoji_manager.emojis),
|
||||
"max_count": global_config.emoji.max_reg_num,
|
||||
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取表情包信息失败: {e}")
|
||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||
|
||||
|
||||
def get_emotions() -> List[str]:
|
||||
try:
|
||||
emotions = set()
|
||||
|
||||
for emoji_obj in emoji_manager.emojis:
|
||||
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
||||
emotions.update(emoji_obj.emotion)
|
||||
|
||||
return sorted(list(emotions))
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取情感标签失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_all() -> List[Tuple[str, str, str]]:
|
||||
try:
|
||||
all_emojis = emoji_manager.emojis
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("[EmojiService] 没有可用的表情包")
|
||||
return []
|
||||
|
||||
results = []
|
||||
for emoji_obj in all_emojis:
|
||||
if emoji_obj.is_deleted:
|
||||
continue
|
||||
|
||||
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}")
|
||||
continue
|
||||
|
||||
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
|
||||
results.append((emoji_base64, emoji_obj.description, matched_emotion))
|
||||
|
||||
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取所有表情包失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_descriptions() -> List[str]:
|
||||
try:
|
||||
descriptions = []
|
||||
|
||||
descriptions.extend(
|
||||
emoji_obj.description
|
||||
for emoji_obj in emoji_manager.emojis
|
||||
if not emoji_obj.is_deleted and emoji_obj.description
|
||||
)
|
||||
return descriptions
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 获取表情包描述失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包注册函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""注册新的表情包"""
|
||||
if not image_base64:
|
||||
raise ValueError("图片base64编码不能为空")
|
||||
if not isinstance(image_base64, str):
|
||||
raise TypeError("image_base64必须是字符串类型")
|
||||
if filename is not None and not isinstance(filename, str):
|
||||
raise TypeError("filename必须是字符串类型或None")
|
||||
|
||||
try:
|
||||
logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}")
|
||||
|
||||
count_before = len(emoji_manager.emojis)
|
||||
max_count = global_config.emoji.max_reg_num
|
||||
|
||||
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
|
||||
|
||||
if not can_register:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
|
||||
if not filename:
|
||||
import time as _time
|
||||
|
||||
timestamp = int(_time.time())
|
||||
microseconds = int(_time.time() * 1000000) % 1000000
|
||||
|
||||
random_bytes = random.getrandbits(72).to_bytes(9, "big")
|
||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
|
||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
||||
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
|
||||
|
||||
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
|
||||
filename = f"{filename}.png"
|
||||
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
attempts = 0
|
||||
max_attempts = 10
|
||||
while os.path.exists(temp_file_path) and attempts < max_attempts:
|
||||
random_bytes = random.getrandbits(48).to_bytes(6, "big")
|
||||
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
|
||||
short_id = short_id.replace("/", "_").replace("+", "-")
|
||||
|
||||
name_part, ext = os.path.splitext(filename)
|
||||
base_name = name_part.rsplit("_", 1)[0]
|
||||
filename = f"{base_name}_{short_id}{ext}"
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
attempts += 1
|
||||
|
||||
if os.path.exists(temp_file_path):
|
||||
uuid_short = str(uuid.uuid4())[:8]
|
||||
name_part, ext = os.path.splitext(filename)
|
||||
base_name = name_part.rsplit("_", 1)[0]
|
||||
filename = f"{base_name}_{uuid_short}{ext}"
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
|
||||
counter = 1
|
||||
original_filename = filename
|
||||
while os.path.exists(temp_file_path):
|
||||
name_part, ext = os.path.splitext(original_filename)
|
||||
filename = f"{name_part}_{counter}{ext}"
|
||||
temp_file_path = os.path.join(EMOJI_DIR, filename)
|
||||
counter += 1
|
||||
|
||||
if counter > 100:
|
||||
logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "无法生成唯一文件名,请稍后重试",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
try:
|
||||
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
|
||||
logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "无法保存图片文件",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}")
|
||||
|
||||
except Exception as save_error:
|
||||
logger.error(f"[EmojiService] 保存图片文件失败: {save_error}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"保存图片文件失败: {str(save_error)}",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
register_success = await emoji_manager.register_emoji_by_filename(filename)
|
||||
|
||||
if not register_success and os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}")
|
||||
|
||||
if register_success:
|
||||
count_after = len(emoji_manager.emojis)
|
||||
replaced = count_after <= count_before
|
||||
|
||||
new_emoji_info = None
|
||||
if count_after > count_before or replaced:
|
||||
try:
|
||||
for emoji_obj in reversed(emoji_manager.emojis):
|
||||
if not emoji_obj.is_deleted and (
|
||||
emoji_obj.file_name == filename
|
||||
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
|
||||
):
|
||||
new_emoji_info = emoji_obj
|
||||
break
|
||||
except Exception as find_error:
|
||||
logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}")
|
||||
|
||||
description = new_emoji_info.description if new_emoji_info else None
|
||||
emotions = new_emoji_info.emotion if new_emoji_info else None
|
||||
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
|
||||
"description": description,
|
||||
"emotions": emotions,
|
||||
"replaced": replaced,
|
||||
"hash": emoji_hash,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiService] 注册表情包时发生异常: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"注册过程中发生错误: {str(e)}",
|
||||
"description": None,
|
||||
"emotions": None,
|
||||
"replaced": None,
|
||||
"hash": None,
|
||||
}
|
||||
21
src/services/frequency_service.py
Normal file
21
src/services/frequency_service.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""频率控制服务模块
|
||||
|
||||
提供聊天频率控制的核心功能。
|
||||
"""
|
||||
|
||||
from src.chat.heart_flow.frequency_control import frequency_control_manager
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def get_current_talk_value(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(
|
||||
chat_id
|
||||
).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
|
||||
|
||||
|
||||
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
|
||||
frequency_control_manager.get_or_create_frequency_control(chat_id).set_talk_frequency_adjust(talk_frequency_adjust)
|
||||
|
||||
|
||||
def get_talk_frequency_adjust(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()
|
||||
256
src/services/generator_service.py
Normal file
256
src/services/generator_service.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
回复器服务模块
|
||||
|
||||
提供回复器相关的核心功能。
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.common.data_models.message_data_model import ReplySetModel
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import ActionInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("generator_service")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 回复器获取函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_replyer(
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
||||
"""获取回复器对象"""
|
||||
if not chat_id and not chat_stream:
|
||||
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||
try:
|
||||
logger.debug(f"[GeneratorService] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
||||
return replyer_manager.get_replyer(
|
||||
chat_stream=chat_stream,
|
||||
chat_id=chat_id,
|
||||
request_type=request_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorService] 获取回复器时发生意外错误: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 回复生成函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def generate_reply(
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
action_data: Optional[Dict[str, Any]] = None,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
think_level: int = 1,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
enable_tool: bool = False,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
request_type: str = "generator_api",
|
||||
from_plugin: bool = True,
|
||||
reply_time_point: Optional[float] = None,
|
||||
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
||||
"""生成回复"""
|
||||
try:
|
||||
if reply_time_point is None:
|
||||
reply_time_point = time.time()
|
||||
|
||||
logger.debug("[GeneratorService] 开始生成回复")
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorService] 无法获取回复器")
|
||||
return False, None
|
||||
|
||||
if action_data:
|
||||
if not extra_info:
|
||||
extra_info = action_data.get("extra_info", "")
|
||||
if not reply_reason:
|
||||
reply_reason = action_data.get("reason", "")
|
||||
if unknown_words is None:
|
||||
uw = action_data.get("unknown_words")
|
||||
if isinstance(uw, list):
|
||||
cleaned: List[str] = []
|
||||
for item in uw:
|
||||
if isinstance(item, str):
|
||||
s = item.strip()
|
||||
if s:
|
||||
cleaned.append(s)
|
||||
if cleaned:
|
||||
unknown_words = cleaned
|
||||
|
||||
success, llm_response = await replyer.generate_reply_with_context(
|
||||
extra_info=extra_info,
|
||||
available_actions=available_actions,
|
||||
chosen_actions=chosen_actions,
|
||||
enable_tool=enable_tool,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
unknown_words=unknown_words,
|
||||
think_level=think_level,
|
||||
from_plugin=from_plugin,
|
||||
stream_id=chat_stream.session_id if chat_stream else chat_id,
|
||||
reply_time_point=reply_time_point,
|
||||
log_reply=False,
|
||||
)
|
||||
if not success:
|
||||
logger.warning("[GeneratorService] 回复生成失败")
|
||||
return False, None
|
||||
reply_set: Optional[ReplySetModel] = None
|
||||
if content := llm_response.content:
|
||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||
llm_response.processed_output = processed_response
|
||||
reply_set = ReplySetModel()
|
||||
for text in processed_response:
|
||||
reply_set.add_text_content(text)
|
||||
llm_response.reply_set = reply_set
|
||||
logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
|
||||
|
||||
try:
|
||||
PlanReplyLogger.log_reply(
|
||||
chat_id=chat_stream.session_id if chat_stream else (chat_id or ""),
|
||||
prompt=llm_response.prompt or "",
|
||||
output=llm_response.content,
|
||||
processed_output=llm_response.processed_output,
|
||||
model=llm_response.model,
|
||||
timing=llm_response.timing,
|
||||
reasoning=llm_response.reasoning,
|
||||
think_level=think_level,
|
||||
success=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[GeneratorService] 记录reply日志失败")
|
||||
|
||||
return success, llm_response
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
|
||||
except UserWarning as uw:
|
||||
logger.warning(f"[GeneratorService] 中断了生成: {uw}")
|
||||
return False, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorService] 生成回复时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False, None
|
||||
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
reply_data: Optional[Dict[str, Any]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
raw_reply: str = "",
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
request_type: str = "generator_api",
|
||||
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
|
||||
"""重写回复"""
|
||||
try:
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorService] 无法获取回复器")
|
||||
return False, None
|
||||
|
||||
logger.info("[GeneratorService] 开始重写回复")
|
||||
|
||||
if reply_data:
|
||||
raw_reply = raw_reply or reply_data.get("raw_reply", "")
|
||||
reason = reason or reply_data.get("reason", "")
|
||||
reply_to = reply_to or reply_data.get("reply_to", "")
|
||||
|
||||
success, llm_response = await replyer.rewrite_reply_with_context(
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
reply_to=reply_to,
|
||||
)
|
||||
reply_set: Optional[ReplySetModel] = None
|
||||
if success and llm_response and (content := llm_response.content):
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
llm_response.reply_set = reply_set
|
||||
if success:
|
||||
logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
|
||||
else:
|
||||
logger.warning("[GeneratorService] 重写回复失败")
|
||||
|
||||
return success, llm_response
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorService] 重写回复时出错: {e}")
|
||||
return False, None
|
||||
|
||||
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]:
|
||||
"""将文本处理为更拟人化的文本"""
|
||||
if not isinstance(content, str):
|
||||
raise ValueError("content 必须是字符串类型")
|
||||
try:
|
||||
reply_set = ReplySetModel()
|
||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
for text in processed_response:
|
||||
reply_set.add_text_content(text)
|
||||
|
||||
return reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorService] 处理人形文本时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_response_custom(
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "generator_api",
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorService] 无法获取回复器")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.debug("[GeneratorService] 开始生成自定义回复")
|
||||
response, _, _, _ = await replyer.llm_generate_content(prompt)
|
||||
if response:
|
||||
logger.debug("[GeneratorService] 自定义回复生成成功")
|
||||
return response
|
||||
else:
|
||||
logger.warning("[GeneratorService] 自定义回复生成失败")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}")
|
||||
return None
|
||||
155
src/services/llm_service.py
Normal file
155
src/services/llm_service.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""LLM 服务模块
|
||||
|
||||
提供与 LLM 模型交互的核心功能。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("llm_service")
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, TaskConfig]:
|
||||
"""获取所有可用的模型配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
||||
"""
|
||||
try:
|
||||
models = config_manager.get_model_config().model_task_config
|
||||
attrs = dir(models)
|
||||
rets: Dict[str, TaskConfig] = {}
|
||||
for attr in attrs:
|
||||
if not attr.startswith("__"):
|
||||
try:
|
||||
value = getattr(models, attr)
|
||||
if not callable(value) and isinstance(value, TaskConfig):
|
||||
rets[attr] = value
|
||||
except Exception as e:
|
||||
logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}")
|
||||
continue
|
||||
return rets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMService] 获取可用模型失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
async def generate_with_model(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
request_type: 请求类型标识
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"[LLMService] 完整提示词: {prompt}")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
|
||||
prompt, temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
return True, response, reasoning_content, model_name
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMService] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
|
||||
|
||||
async def generate_with_model_with_tools(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
tool_options: 工具选项列表
|
||||
request_type: 请求类型标识
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"使用模型{model_name_list}生成内容")
|
||||
logger.debug(f"完整提示词: {prompt}")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
|
||||
prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMService] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
|
||||
|
||||
async def generate_with_model_with_tools_by_message_factory(
|
||||
message_factory: Callable[[BaseClient], List[Message]],
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容(通过消息工厂构建消息列表)
|
||||
|
||||
Args:
|
||||
message_factory: 消息工厂函数
|
||||
model_config: 模型配置
|
||||
tool_options: 工具选项列表
|
||||
request_type: 请求类型标识
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"使用模型 {model_name_list} 生成内容")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
|
||||
message_factory=message_factory,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMService] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
296
src/services/message_service.py
Normal file
296
src/services/message_service.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
消息服务模块
|
||||
|
||||
提供消息查询和构建成字符串的核心功能。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
build_readable_messages_with_list,
|
||||
get_person_id_list,
|
||||
get_raw_msg_before_timestamp,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_before_timestamp_with_users,
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_random,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_by_timestamp_with_chat_users,
|
||||
get_raw_msg_by_timestamp_with_users,
|
||||
num_new_messages_since,
|
||||
num_new_messages_since_with_users,
|
||||
)
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息查询函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_messages_by_time(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[DatabaseMessages]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
|
||||
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp_start=start_time,
|
||||
timestamp_end=end_time,
|
||||
limit=limit,
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_mai,
|
||||
filter_command=filter_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_inclusive(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=chat_id,
|
||||
timestamp_start=start_time,
|
||||
timestamp_end=end_time,
|
||||
limit=limit,
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_mai,
|
||||
filter_command=filter_command,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
if filter_mai:
|
||||
return filter_mai_messages(messages)
|
||||
return messages
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_for_users(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
person_ids: List[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[DatabaseMessages]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
def get_random_chat_messages(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[DatabaseMessages]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
|
||||
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_by_time_for_users(
|
||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[DatabaseMessages]:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit))
|
||||
return get_raw_msg_before_timestamp(timestamp, limit)
|
||||
|
||||
|
||||
def get_messages_before_time_in_chat(
|
||||
chat_id: str,
|
||||
timestamp: float,
|
||||
limit: int = 0,
|
||||
filter_mai: bool = False,
|
||||
filter_intercept_message_level: Optional[int] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=timestamp,
|
||||
limit=limit,
|
||||
filter_intercept_message_level=filter_intercept_message_level,
|
||||
)
|
||||
if filter_mai:
|
||||
return filter_mai_messages(messages)
|
||||
return messages
|
||||
|
||||
|
||||
def get_messages_before_time_for_users(
|
||||
timestamp: float, person_ids: List[str], limit: int = 0
|
||||
) -> List[DatabaseMessages]:
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
|
||||
|
||||
|
||||
def get_recent_messages(
|
||||
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[DatabaseMessages]:
|
||||
if not isinstance(hours, (int, float)) or hours < 0:
|
||||
raise ValueError("hours 不能是负数")
|
||||
if not isinstance(limit, int) or limit < 0:
|
||||
raise ValueError("limit 必须是非负整数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
now = time.time()
|
||||
start_time = now - hours * 3600
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode))
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息计数函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
|
||||
if not isinstance(start_time, (int, float)):
|
||||
raise ValueError("start_time 必须是数字类型")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return num_new_messages_since(chat_id, start_time, end_time)
|
||||
|
||||
|
||||
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息格式化函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def build_readable_messages_to_str(
|
||||
messages: List[DatabaseMessages],
|
||||
replace_bot_name: bool = True,
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
) -> str:
|
||||
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
|
||||
|
||||
|
||||
async def build_readable_messages_with_details(
|
||||
messages: List[DatabaseMessages],
|
||||
replace_bot_name: bool = True,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)
|
||||
|
||||
|
||||
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
return await get_person_id_list(messages)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息过滤函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
|
||||
"""从消息列表中移除麦麦的消息"""
|
||||
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]
|
||||
|
||||
|
||||
def translate_pid_to_description(pid: str) -> str:
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Images).where((col(Images.id) == int(pid)) & (col(Images.image_type) == ImageType.IMAGE))
|
||||
if pid.isdigit()
|
||||
else None
|
||||
)
|
||||
image = session.exec(statement).first() if statement is not None else None
|
||||
description = ""
|
||||
if image and image.description and image.description.strip():
|
||||
description = image.description.strip()
|
||||
else:
|
||||
description = "[图片]"
|
||||
return description
|
||||
65
src/services/person_service.py
Normal file
65
src/services/person_service.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""个人信息服务模块
|
||||
|
||||
提供个人信息查询的核心功能。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
logger = get_logger("person_service")
|
||||
|
||||
|
||||
def get_person_id(platform: str, user_id: int | str) -> str:
|
||||
"""根据平台和用户ID获取person_id
|
||||
|
||||
Args:
|
||||
platform: 平台名称,如 "qq", "telegram" 等
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
str: 唯一的person_id(MD5哈希值)
|
||||
"""
|
||||
try:
|
||||
return Person(platform=platform, user_id=str(user_id)).person_id
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonService] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
|
||||
"""根据person_id和字段名获取某个值
|
||||
|
||||
Args:
|
||||
person_id: 用户的唯一标识ID
|
||||
field_name: 要获取的字段名,如 "nickname", "impression" 等
|
||||
default: 当字段不存在或获取失败时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 字段值或默认值
|
||||
"""
|
||||
try:
|
||||
person = Person(person_id=person_id)
|
||||
value = getattr(person, field_name)
|
||||
return value if value is not None else default
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonService] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
|
||||
return default
|
||||
|
||||
|
||||
def get_person_id_by_name(person_name: str) -> str:
|
||||
"""根据用户名获取person_id
|
||||
|
||||
Args:
|
||||
person_name: 用户名
|
||||
|
||||
Returns:
|
||||
str: person_id,如果未找到返回空字符串
|
||||
"""
|
||||
try:
|
||||
person = Person(person_name=person_name)
|
||||
return person.person_id
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonService] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
|
||||
return ""
|
||||
347
src/services/send_service.py
Normal file
347
src/services/send_service.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
发送服务模块
|
||||
|
||||
提供发送各种类型消息的核心功能。
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
|
||||
|
||||
from maim_message import MessageBase, BaseMessageInfo, Seg
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.common.data_models.message_data_model import ReplyContentType
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel
|
||||
|
||||
logger = get_logger("send_service")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 内部实现函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def _send_to_target(
|
||||
message_segment: Seg,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> bool:
|
||||
"""向指定目标发送消息的内部实现"""
|
||||
try:
|
||||
if set_reply and not reply_message:
|
||||
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
|
||||
return False
|
||||
|
||||
if show_log:
|
||||
logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}")
|
||||
|
||||
target_stream = _chat_manager.get_session_by_session_id(stream_id)
|
||||
if not target_stream:
|
||||
logger.error(f"[SendService] 未找到聊天流: {stream_id}")
|
||||
return False
|
||||
|
||||
message_sender = UniversalMessageSender()
|
||||
|
||||
current_time = time.time()
|
||||
message_id = f"send_api_{int(current_time * 1000)}"
|
||||
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
)
|
||||
|
||||
reply_to_platform_id = ""
|
||||
anchor_message: Optional[MaiMessage] = None
|
||||
if reply_message:
|
||||
anchor_message = db_message_to_mai_message(reply_message)
|
||||
if anchor_message:
|
||||
logger.debug(f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
|
||||
reply_to_platform_id = (
|
||||
f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||
)
|
||||
|
||||
sender_info = None
|
||||
if target_stream.context and target_stream.context.message:
|
||||
sender_info = target_stream.context.message.message_info.user_info
|
||||
|
||||
bot_message = MessageSending(
|
||||
message_id=message_id,
|
||||
session=target_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=sender_info,
|
||||
message_segment=message_segment,
|
||||
display_message=display_message,
|
||||
reply=anchor_message,
|
||||
is_head=True,
|
||||
is_emoji=(message_segment.type == "emoji"),
|
||||
thinking_start_time=current_time,
|
||||
reply_to=reply_to_platform_id,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
|
||||
sent_msg = await message_sender.send_message(
|
||||
bot_message,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
if sent_msg:
|
||||
logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error("[SendService] 发送消息失败")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SendService] 发送消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
|
||||
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。"""
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
|
||||
user_info = UserInfo(
|
||||
user_id=message_obj.user_info.user_id or "",
|
||||
user_nickname=message_obj.user_info.user_nickname or "",
|
||||
user_cardname=message_obj.user_info.user_cardname,
|
||||
)
|
||||
|
||||
group_info = None
|
||||
if message_obj.chat_info.group_info:
|
||||
group_info = GroupInfo(
|
||||
group_id=message_obj.chat_info.group_info.group_id or "",
|
||||
group_name=message_obj.chat_info.group_info.group_name or "",
|
||||
)
|
||||
|
||||
msg = MaiMessage(
|
||||
message_id=message_obj.message_id,
|
||||
timestamp=datetime.fromtimestamp(message_obj.time) if message_obj.time else datetime.now(),
|
||||
)
|
||||
msg.message_info = MessageInfo(user_info=user_info, group_info=group_info)
|
||||
msg.platform = message_obj.chat_info.platform or ""
|
||||
msg.session_id = message_obj.chat_info.stream_id or ""
|
||||
msg.processed_plain_text = message_obj.processed_plain_text
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.initialized = True
|
||||
return msg
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 公共函数 - 预定义类型的发送函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def text_to_stream(
|
||||
text: str,
|
||||
stream_id: str,
|
||||
typing: bool = False,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
storage_message: bool = True,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> bool:
|
||||
"""向指定流发送文本消息"""
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type="text", data=text),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message=reply_message,
|
||||
storage_message=storage_message,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
|
||||
|
||||
async def emoji_to_stream(
|
||||
emoji_base64: str,
|
||||
stream_id: str,
|
||||
storage_message: bool = True,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
) -> bool:
|
||||
"""向指定流发送表情包"""
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type="emoji", data=emoji_base64),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=False,
|
||||
storage_message=storage_message,
|
||||
set_reply=set_reply,
|
||||
reply_message=reply_message,
|
||||
)
|
||||
|
||||
|
||||
async def image_to_stream(
|
||||
image_base64: str,
|
||||
stream_id: str,
|
||||
storage_message: bool = True,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
) -> bool:
|
||||
"""向指定流发送图片"""
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type="image", data=image_base64),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=False,
|
||||
storage_message=storage_message,
|
||||
set_reply=set_reply,
|
||||
reply_message=reply_message,
|
||||
)
|
||||
|
||||
|
||||
async def command_to_stream(
|
||||
command: Union[str, dict],
|
||||
stream_id: str,
|
||||
storage_message: bool = True,
|
||||
display_message: str = "",
|
||||
) -> bool:
|
||||
"""向指定流发送命令"""
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type="command", data=command), # type: ignore
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=False,
|
||||
storage_message=storage_message,
|
||||
set_reply=False,
|
||||
)
|
||||
|
||||
|
||||
async def custom_to_stream(
|
||||
message_type: str,
|
||||
content: str | Dict,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
set_reply: bool = False,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送自定义类型消息"""
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type=message_type, data=content), # type: ignore
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
|
||||
async def custom_reply_set_to_stream(
|
||||
reply_set: "ReplySetModel",
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
set_reply: bool = False,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送混合型消息集"""
|
||||
flag: bool = True
|
||||
for reply_content in reply_set.reply_data:
|
||||
status: bool = False
|
||||
message_seg, need_typing = _parse_content_to_seg(reply_content)
|
||||
status = await _send_to_target(
|
||||
message_segment=message_seg,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=bool(need_typing and typing),
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
if not status:
|
||||
flag = False
|
||||
logger.error(
|
||||
f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}"
|
||||
)
|
||||
|
||||
return flag
|
||||
|
||||
|
||||
def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
|
||||
"""把 ReplyContent 转换为 Seg 结构"""
|
||||
content_type = reply_content.content_type
|
||||
if content_type == ReplyContentType.TEXT:
|
||||
text_data: str = reply_content.content # type: ignore
|
||||
return Seg(type="text", data=text_data), True
|
||||
elif content_type == ReplyContentType.IMAGE:
|
||||
return Seg(type="image", data=reply_content.content), False # type: ignore
|
||||
elif content_type == ReplyContentType.EMOJI:
|
||||
return Seg(type="emoji", data=reply_content.content), False # type: ignore
|
||||
elif content_type == ReplyContentType.COMMAND:
|
||||
return Seg(type="command", data=reply_content.content), False # type: ignore
|
||||
elif content_type == ReplyContentType.VOICE:
|
||||
return Seg(type="voice", data=reply_content.content), False # type: ignore
|
||||
elif content_type == ReplyContentType.HYBRID:
|
||||
hybrid_message_list_data: List[ReplyContent] = reply_content.content # type: ignore
|
||||
assert isinstance(hybrid_message_list_data, list), "混合类型内容必须是列表"
|
||||
sub_seg_list: List[Seg] = []
|
||||
for sub_content in hybrid_message_list_data:
|
||||
sub_content_type = sub_content.content_type
|
||||
sub_content_data = sub_content.content
|
||||
|
||||
if sub_content_type == ReplyContentType.TEXT:
|
||||
sub_seg_list.append(Seg(type="text", data=sub_content_data)) # type: ignore
|
||||
elif sub_content_type == ReplyContentType.IMAGE:
|
||||
sub_seg_list.append(Seg(type="image", data=sub_content_data)) # type: ignore
|
||||
elif sub_content_type == ReplyContentType.EMOJI:
|
||||
sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore
|
||||
else:
|
||||
logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
|
||||
continue
|
||||
return Seg(type="seglist", data=sub_seg_list), True
|
||||
elif content_type == ReplyContentType.FORWARD:
|
||||
forward_message_list_data: List["ForwardNode"] = reply_content.content # type: ignore
|
||||
assert isinstance(forward_message_list_data, list), "转发类型内容必须是列表"
|
||||
forward_message_list: List[Dict] = []
|
||||
for forward_node in forward_message_list_data:
|
||||
message_segment = Seg(type="id", data=forward_node.content) # type: ignore
|
||||
user_info: Optional[UserInfo] = None
|
||||
if forward_node.user_id and forward_node.user_nickname:
|
||||
assert isinstance(forward_node.content, list), "转发节点内容必须是列表"
|
||||
user_info = UserInfo(user_id=forward_node.user_id, user_nickname=forward_node.user_nickname)
|
||||
single_node_content: List[Seg] = []
|
||||
for sub_content in forward_node.content:
|
||||
if sub_content.content_type != ReplyContentType.FORWARD:
|
||||
sub_seg, _ = _parse_content_to_seg(sub_content)
|
||||
single_node_content.append(sub_seg)
|
||||
message_segment = Seg(type="seglist", data=single_node_content)
|
||||
forward_message_list.append(
|
||||
MessageBase(
|
||||
message_segment=message_segment, message_info=BaseMessageInfo(user_info=user_info)
|
||||
).to_dict()
|
||||
)
|
||||
return Seg(type="forward", data=forward_message_list), False # type: ignore
|
||||
else:
|
||||
message_type_in_str = content_type.value if isinstance(content_type, ReplyContentType) else str(content_type)
|
||||
return Seg(type=message_type_in_str, data=reply_content.content), True # type: ignore
|
||||
Reference in New Issue
Block a user