重构整个插件系统,尝试恢复可启动性,新增插件系统maibot-plugin-sdk依赖

This commit is contained in:
DrSmoothl
2026-03-07 19:40:51 +08:00
parent 2e3dd44ee9
commit ce8d8dfd0a
90 changed files with 3785 additions and 10061 deletions

7
src/services/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
核心服务层
提供与具体插件系统无关的核心业务服务。
内部模块chat、dream、memory 等)应直接使用此层,
而 plugin_system.apis 仅作为面向插件的薄包装。
"""

View File

@@ -0,0 +1,159 @@
"""
聊天服务模块
提供聊天信息查询和管理的核心功能。
"""
from typing import Any, Dict, List, Optional
from enum import Enum
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.common.logger import get_logger
logger = get_logger("chat_service")
class SpecialTypes(Enum):
"""特殊枚举类型"""
ALL_PLATFORMS = "all_platforms"
class ChatManager:
"""聊天管理器 - 负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
logger.error(f"[ChatService] 获取聊天流失败: {e}")
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatService] 获取群聊流失败: {e}")
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
) and not stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatService] 获取私聊流失败: {e}")
return streams
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not group_id:
raise ValueError("group_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
stream.is_group_session
and str(stream.group_id) == str(group_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatService] 查找群聊流失败: {e}")
return None
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not user_id:
raise ValueError("user_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
not stream.is_group_session
and str(stream.user_id) == str(user_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatService] 查找私聊流失败: {e}")
return None
@staticmethod
def get_stream_type(chat_stream: BotChatSession) -> str:
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
return "group" if chat_stream.is_group_session else "private"
@staticmethod
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
try:
info: Dict[str, Any] = {
"session_id": chat_stream.session_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
}
if chat_stream.is_group_session:
info["group_id"] = chat_stream.group_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
else:
info["group_name"] = "未知群聊"
else:
info["user_id"] = chat_stream.user_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
else:
info["user_name"] = "未知用户"
return info
except Exception as e:
logger.error(f"[ChatService] 获取聊天流信息失败: {e}")
return {}

View File

@@ -0,0 +1,66 @@
"""配置服务模块
提供配置读取的核心功能。
"""
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("config_service")
def get_global_config(key: str, default: Any = None) -> Any:
"""
安全地从全局配置中获取一个值。
Args:
key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
keys = key.split(".")
current = global_config
try:
for k in keys:
if hasattr(current, k):
current = getattr(current, k)
else:
raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current
except Exception as e:
logger.warning(f"[ConfigService] 获取全局配置 {key} 失败: {e}")
return default
def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any:
"""
从插件配置中获取值,支持嵌套键访问
Args:
plugin_config: 插件配置字典
key: 配置键名,支持嵌套访问如 "section.subsection.key",大小写敏感
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
keys = key.split(".")
current = plugin_config
try:
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
elif hasattr(current, k):
current = getattr(current, k)
else:
raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current
except Exception as e:
logger.warning(f"[ConfigService] 获取插件配置 {key} 失败: {e}")
return default

View File

@@ -0,0 +1,173 @@
"""数据库服务模块
提供数据库操作相关的核心功能。
"""
import json
import time
import traceback
from typing import Any, Optional
from src.common.logger import get_logger
logger = get_logger("database_service")
def _to_dict(record: Any) -> dict[str, Any]:
if record is None:
return {}
if isinstance(record, dict):
return record
if hasattr(record, "model_dump"):
return record.model_dump()
if hasattr(record, "__dict__"):
return dict(record.__dict__)
return {}
async def db_query(
model_class,
data: Optional[dict[str, Any]] = None,
query_type: str = "get",
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[list[str]] = None,
single_result: bool = False,
):
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
if query_type == "get":
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if order_by:
query = query.order_by(*order_by)
if limit:
query = query.limit(limit)
results = list(query.dicts())
if single_result:
return results[0] if results else None
return results
if query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
record = model_class.create(**data)
return _to_dict(record)
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
return query.model_class.update(**data).where(*query.stmt._where_criteria).execute()
if query_type == "delete":
return model_class.delete().where(*query.stmt._where_criteria).execute()
return query.count()
except Exception as e:
logger.error(f"[DatabaseService] 数据库操作出错: {e}")
traceback.print_exc()
if query_type == "get":
return None if single_result else []
return None
async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None):
try:
if key_field and key_value is not None:
record = model_class.get_or_none(getattr(model_class, key_field) == key_value)
if record is not None:
for field, value in data.items():
setattr(record, field, value)
record.save()
return _to_dict(record)
new_record = model_class.create(**data)
return _to_dict(new_record)
except Exception as e:
logger.error(f"[DatabaseService] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
async def db_get(
model_class,
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: bool = False,
):
try:
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if order_by:
query = query.order_by(order_by)
if limit:
query = query.limit(limit)
results = list(query.dicts())
if single_result:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"[DatabaseService] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if single_result else []
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
action_reasoning: str = "",
):
try:
from src.common.database.database_model import ActionRecords
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_reasoning": action_reasoning,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
if chat_stream:
record_data.update(
{
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
}
)
else:
record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""})
saved_record = await db_save(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[DatabaseService] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[DatabaseService] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None

View File

@@ -0,0 +1,406 @@
"""
表情服务模块
提供表情包相关的核心功能。
"""
import base64
import os
import random
import uuid
from typing import Any, Dict, List, Optional, Tuple
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.config.config import global_config
logger = get_logger("emoji_service")
# =============================================================================
# 表情包获取函数
# =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
"""根据描述选择表情包"""
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("描述必须是字符串类型")
try:
logger.debug(f"[EmojiService] 根据描述获取表情包: {description}")
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
if not emoji_obj:
logger.warning(f"[EmojiService] 未找到匹配描述 '{description}' 的表情包")
return None
emoji_path = str(emoji_obj.full_path)
emoji_description = emoji_obj.description
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
if not emoji_base64:
logger.error(f"[EmojiService] 无法将表情包文件转换为base64: {emoji_path}")
return None
logger.debug(f"[EmojiService] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
return emoji_base64, emoji_description, matched_emotion
except Exception as e:
logger.error(f"[EmojiService] 获取表情包失败: {e}")
return None
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包"""
if not isinstance(count, int):
raise TypeError("count 必须是整数类型")
if count < 0:
raise ValueError("count 不能为负数")
if count == 0:
logger.warning("[EmojiService] count 为0返回空列表")
return []
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiService] 没有可用的表情包")
return []
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiService] 没有有效的表情包")
return []
if len(valid_emojis) < count:
logger.debug(
f"[EmojiService] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
)
count = len(valid_emojis)
selected_emojis = random.sample(valid_emojis, count)
results = []
for selected_emoji in selected_emojis:
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
continue
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
emoji_manager.update_emoji_usage(selected_emoji)
results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0:
logger.warning("[EmojiService] 随机获取表情包失败,没有一个可以成功处理")
return []
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个随机表情包")
return results
except Exception as e:
logger.error(f"[EmojiService] 获取随机表情包失败: {e}")
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
"""根据情感标签获取表情包"""
if not emotion:
raise ValueError("情感标签不能为空")
if not isinstance(emotion, str):
raise TypeError("情感标签必须是字符串类型")
try:
logger.info(f"[EmojiService] 根据情感获取表情包: {emotion}")
all_emojis = emoji_manager.emojis
matching_emojis = []
matching_emojis.extend(
emoji_obj
for emoji_obj in all_emojis
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
)
if not matching_emojis:
logger.warning(f"[EmojiService] 未找到匹配情感 '{emotion}' 的表情包")
return None
selected_emoji = random.choice(matching_emojis)
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {selected_emoji.full_path}")
return None
emoji_manager.update_emoji_usage(selected_emoji)
logger.info(f"[EmojiService] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion
except Exception as e:
logger.error(f"[EmojiService] 根据情感获取表情包失败: {e}")
return None
# =============================================================================
# 表情包信息查询函数
# =============================================================================
def get_count() -> int:
try:
return len(emoji_manager.emojis)
except Exception as e:
logger.error(f"[EmojiService] 获取表情包数量失败: {e}")
return 0
def get_info():
try:
return {
"current_count": len(emoji_manager.emojis),
"max_count": global_config.emoji.max_reg_num,
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
}
except Exception as e:
logger.error(f"[EmojiService] 获取表情包信息失败: {e}")
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> List[str]:
try:
emotions = set()
for emoji_obj in emoji_manager.emojis:
if not emoji_obj.is_deleted and emoji_obj.emotion:
emotions.update(emoji_obj.emotion)
return sorted(list(emotions))
except Exception as e:
logger.error(f"[EmojiService] 获取情感标签失败: {e}")
return []
async def get_all() -> List[Tuple[str, str, str]]:
try:
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiService] 没有可用的表情包")
return []
results = []
for emoji_obj in all_emojis:
if emoji_obj.is_deleted:
continue
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
if not emoji_base64:
logger.error(f"[EmojiService] 无法转换表情包为base64: {emoji_obj.full_path}")
continue
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "随机表情"
results.append((emoji_base64, emoji_obj.description, matched_emotion))
logger.debug(f"[EmojiService] 成功获取 {len(results)} 个表情包")
return results
except Exception as e:
logger.error(f"[EmojiService] 获取所有表情包失败: {e}")
return []
def get_descriptions() -> List[str]:
try:
descriptions = []
descriptions.extend(
emoji_obj.description
for emoji_obj in emoji_manager.emojis
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions
except Exception as e:
logger.error(f"[EmojiService] 获取表情包描述失败: {e}")
return []
# =============================================================================
# 表情包注册函数
# =============================================================================
async def register_emoji(image_base64: str, filename: Optional[str] = None) -> Dict[str, Any]:
"""注册新的表情包"""
if not image_base64:
raise ValueError("图片base64编码不能为空")
if not isinstance(image_base64, str):
raise TypeError("image_base64必须是字符串类型")
if filename is not None and not isinstance(filename, str):
raise TypeError("filename必须是字符串类型或None")
try:
logger.info(f"[EmojiService] 开始注册表情包,文件名: {filename or '自动生成'}")
count_before = len(emoji_manager.emojis)
max_count = global_config.emoji.max_reg_num
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
if not can_register:
return {
"success": False,
"message": f"表情包数量已达上限({count_before}/{max_count})且未启用替换功能",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
os.makedirs(EMOJI_DIR, exist_ok=True)
if not filename:
import time as _time
timestamp = int(_time.time())
microseconds = int(_time.time() * 1000000) % 1000000
random_bytes = random.getrandbits(72).to_bytes(9, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:12].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
filename = f"emoji_{timestamp}_{microseconds}_{short_id}"
if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
filename = f"{filename}.png"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts = 0
max_attempts = 10
while os.path.exists(temp_file_path) and attempts < max_attempts:
random_bytes = random.getrandbits(48).to_bytes(6, "big")
short_id = base64.b64encode(random_bytes).decode("ascii")[:8].rstrip("=")
short_id = short_id.replace("/", "_").replace("+", "-")
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{short_id}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
attempts += 1
if os.path.exists(temp_file_path):
uuid_short = str(uuid.uuid4())[:8]
name_part, ext = os.path.splitext(filename)
base_name = name_part.rsplit("_", 1)[0]
filename = f"{base_name}_{uuid_short}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter = 1
original_filename = filename
while os.path.exists(temp_file_path):
name_part, ext = os.path.splitext(original_filename)
filename = f"{name_part}_{counter}{ext}"
temp_file_path = os.path.join(EMOJI_DIR, filename)
counter += 1
if counter > 100:
logger.error(f"[EmojiService] 无法生成唯一文件名,尝试次数过多: {original_filename}")
return {
"success": False,
"message": "无法生成唯一文件名,请稍后重试",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
try:
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
logger.error(f"[EmojiService] 无法保存base64图片到文件: {temp_file_path}")
return {
"success": False,
"message": "无法保存图片文件",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
logger.debug(f"[EmojiService] 图片已保存到临时文件: {temp_file_path}")
except Exception as save_error:
logger.error(f"[EmojiService] 保存图片文件失败: {save_error}")
return {
"success": False,
"message": f"保存图片文件失败: {str(save_error)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
register_success = await emoji_manager.register_emoji_by_filename(filename)
if not register_success and os.path.exists(temp_file_path):
try:
os.remove(temp_file_path)
logger.debug(f"[EmojiService] 已清理临时文件: {temp_file_path}")
except Exception as cleanup_error:
logger.warning(f"[EmojiService] 清理临时文件失败: {cleanup_error}")
if register_success:
count_after = len(emoji_manager.emojis)
replaced = count_after <= count_before
new_emoji_info = None
if count_after > count_before or replaced:
try:
for emoji_obj in reversed(emoji_manager.emojis):
if not emoji_obj.is_deleted and (
emoji_obj.file_name == filename
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
):
new_emoji_info = emoji_obj
break
except Exception as find_error:
logger.warning(f"[EmojiService] 查找新注册表情包信息失败: {find_error}")
description = new_emoji_info.description if new_emoji_info else None
emotions = new_emoji_info.emotion if new_emoji_info else None
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
return {
"success": True,
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
"description": description,
"emotions": emotions,
"replaced": replaced,
"hash": emoji_hash,
}
else:
return {
"success": False,
"message": "表情包注册失败,可能因为重复、格式不支持或审核未通过",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}
except Exception as e:
logger.error(f"[EmojiService] 注册表情包时发生异常: {e}")
return {
"success": False,
"message": f"注册过程中发生错误: {str(e)}",
"description": None,
"emotions": None,
"replaced": None,
"hash": None,
}

View File

@@ -0,0 +1,21 @@
"""频率控制服务模块
提供聊天频率控制的核心功能。
"""
from src.chat.heart_flow.frequency_control import frequency_control_manager
from src.config.config import global_config
def get_current_talk_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(
chat_id
).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
frequency_control_manager.get_or_create_frequency_control(chat_id).set_talk_frequency_adjust(talk_frequency_adjust)
def get_talk_frequency_adjust(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()

View File

@@ -0,0 +1,256 @@
"""
回复器服务模块
提供回复器相关的核心功能。
"""
import traceback
import time
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from rich.traceback import install
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer
from src.chat.replyer.replyer_manager import replyer_manager
from src.chat.utils.utils import process_llm_response
from src.common.data_models.message_data_model import ReplySetModel
from src.common.logger import get_logger
from src.core.types import ActionInfo
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
install(extra_lines=3)
logger = get_logger("generator_service")
# =============================================================================
# 回复器获取函数
# =============================================================================
def get_replyer(
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer | PrivateReplyer]:
"""获取回复器对象"""
if not chat_id and not chat_stream:
raise ValueError("chat_stream 和 chat_id 不可均为空")
try:
logger.debug(f"[GeneratorService] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}")
return replyer_manager.get_replyer(
chat_stream=chat_stream,
chat_id=chat_id,
request_type=request_type,
)
except Exception as e:
logger.error(f"[GeneratorService] 获取回复器时发生意外错误: {e}", exc_info=True)
traceback.print_exc()
return None
# =============================================================================
# 回复生成函数
# =============================================================================
async def generate_reply(
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
think_level: int = 1,
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
unknown_words: Optional[List[str]] = None,
enable_tool: bool = False,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
request_type: str = "generator_api",
from_plugin: bool = True,
reply_time_point: Optional[float] = None,
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""生成回复"""
try:
if reply_time_point is None:
reply_time_point = time.time()
logger.debug("[GeneratorService] 开始生成回复")
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorService] 无法获取回复器")
return False, None
if action_data:
if not extra_info:
extra_info = action_data.get("extra_info", "")
if not reply_reason:
reply_reason = action_data.get("reason", "")
if unknown_words is None:
uw = action_data.get("unknown_words")
if isinstance(uw, list):
cleaned: List[str] = []
for item in uw:
if isinstance(item, str):
s = item.strip()
if s:
cleaned.append(s)
if cleaned:
unknown_words = cleaned
success, llm_response = await replyer.generate_reply_with_context(
extra_info=extra_info,
available_actions=available_actions,
chosen_actions=chosen_actions,
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,
unknown_words=unknown_words,
think_level=think_level,
from_plugin=from_plugin,
stream_id=chat_stream.session_id if chat_stream else chat_id,
reply_time_point=reply_time_point,
log_reply=False,
)
if not success:
logger.warning("[GeneratorService] 回复生成失败")
return False, None
reply_set: Optional[ReplySetModel] = None
if content := llm_response.content:
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
llm_response.processed_output = processed_response
reply_set = ReplySetModel()
for text in processed_response:
reply_set.add_text_content(text)
llm_response.reply_set = reply_set
logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
try:
PlanReplyLogger.log_reply(
chat_id=chat_stream.session_id if chat_stream else (chat_id or ""),
prompt=llm_response.prompt or "",
output=llm_response.content,
processed_output=llm_response.processed_output,
model=llm_response.model,
timing=llm_response.timing,
reasoning=llm_response.reasoning,
think_level=think_level,
success=True,
)
except Exception:
logger.exception("[GeneratorService] 记录reply日志失败")
return success, llm_response
except ValueError as ve:
raise ve
except UserWarning as uw:
logger.warning(f"[GeneratorService] 中断了生成: {uw}")
return False, None
except Exception as e:
logger.error(f"[GeneratorService] 生成回复时出错: {e}")
logger.error(traceback.format_exc())
return False, None
async def rewrite_reply(
chat_stream: Optional[BotChatSession] = None,
reply_data: Optional[Dict[str, Any]] = None,
chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
request_type: str = "generator_api",
) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
"""重写回复"""
try:
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorService] 无法获取回复器")
return False, None
logger.info("[GeneratorService] 开始重写回复")
if reply_data:
raw_reply = raw_reply or reply_data.get("raw_reply", "")
reason = reason or reply_data.get("reason", "")
reply_to = reply_to or reply_data.get("reply_to", "")
success, llm_response = await replyer.rewrite_reply_with_context(
raw_reply=raw_reply,
reason=reason,
reply_to=reply_to,
)
reply_set: Optional[ReplySetModel] = None
if success and llm_response and (content := llm_response.content):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
llm_response.reply_set = reply_set
if success:
logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
else:
logger.warning("[GeneratorService] 重写回复失败")
return success, llm_response
except ValueError as ve:
raise ve
except Exception as e:
logger.error(f"[GeneratorService] 重写回复时出错: {e}")
return False, None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]:
"""将文本处理为更拟人化的文本"""
if not isinstance(content, str):
raise ValueError("content 必须是字符串类型")
try:
reply_set = ReplySetModel()
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
for text in processed_response:
reply_set.add_text_content(text)
return reply_set
except Exception as e:
logger.error(f"[GeneratorService] 处理人形文本时出错: {e}")
return None
async def generate_response_custom(
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
request_type: str = "generator_api",
prompt: str = "",
) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorService] 无法获取回复器")
return None
try:
logger.debug("[GeneratorService] 开始生成自定义回复")
response, _, _, _ = await replyer.llm_generate_content(prompt)
if response:
logger.debug("[GeneratorService] 自定义回复生成成功")
return response
else:
logger.warning("[GeneratorService] 自定义回复生成失败")
return None
except Exception as e:
logger.error(f"[GeneratorService] 生成自定义回复时出错: {e}")
return None

155
src/services/llm_service.py Normal file
View File

@@ -0,0 +1,155 @@
"""LLM 服务模块
提供与 LLM 模型交互的核心功能。
"""
from typing import Any, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import TaskConfig
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest
logger = get_logger("llm_service")
def get_available_models() -> Dict[str, TaskConfig]:
"""获取所有可用的模型配置
Returns:
Dict[str, Any]: 模型配置字典key为模型名称value为模型配置
"""
try:
models = config_manager.get_model_config().model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
for attr in attrs:
if not attr.startswith("__"):
try:
value = getattr(models, attr)
if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value
except Exception as e:
logger.debug(f"[LLMService] 获取属性 {attr} 失败: {e}")
continue
return rets
except Exception as e:
logger.error(f"[LLMService] 获取可用模型失败: {e}")
return {}
async def generate_with_model(
prompt: str,
model_config: TaskConfig,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
Args:
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
request_type: 请求类型标识
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
logger.debug(f"[LLMService] 完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
prompt, temperature=temperature, max_tokens=max_tokens
)
return True, response, reasoning_content, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", ""
async def generate_with_model_with_tools(
prompt: str,
model_config: TaskConfig,
tool_options: List[Dict[str, Any]] | None = None,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
"""使用指定模型和工具生成内容
Args:
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
tool_options: 工具选项列表
request_type: 请求类型标识
temperature: 温度参数
max_tokens: 最大token数
Returns:
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
"""
try:
model_name_list = model_config.model_list
logger.info(f"使用模型{model_name_list}生成内容")
logger.debug(f"完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
)
return True, response, reasoning_content, model_name, tool_call
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", "", None
async def generate_with_model_with_tools_by_message_factory(
message_factory: Callable[[BaseClient], List[Message]],
model_config: TaskConfig,
tool_options: List[Dict[str, Any]] | None = None,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
"""使用指定模型和工具生成内容(通过消息工厂构建消息列表)
Args:
message_factory: 消息工厂函数
model_config: 模型配置
tool_options: 工具选项列表
request_type: 请求类型标识
temperature: 温度参数
max_tokens: 最大token数
Returns:
Tuple[bool, str, str, str, List[ToolCall] | None]: (是否成功, 生成的内容, 推理过程, 模型名称, 工具调用列表)
"""
try:
model_name_list = model_config.model_list
logger.info(f"使用模型 {model_name_list} 生成内容")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
message_factory=message_factory,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return True, response, reasoning_content, model_name, tool_call
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMService] {error_msg}")
return False, error_msg, "", "", None

View File

@@ -0,0 +1,296 @@
"""
消息服务模块
提供消息查询和构建成字符串的核心功能。
"""
import time
from typing import Any, Dict, List, Optional, Tuple
from sqlmodel import col, select
from src.chat.utils.chat_message_builder import (
build_readable_messages,
build_readable_messages_with_list,
get_person_id_list,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
)
from src.chat.utils.utils import is_bot_self
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
# =============================================================================
# 消息查询函数
# =============================================================================
def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
def get_messages_by_time_in_chat(
chat_id: str,
start_time: float,
end_time: float,
limit: int = 0,
limit_mode: str = "latest",
filter_mai: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat(
chat_id=chat_id,
timestamp_start=start_time,
timestamp_end=end_time,
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
filter_command=filter_command,
filter_intercept_message_level=filter_intercept_message_level,
)
def get_messages_by_time_in_chat_inclusive(
chat_id: str,
start_time: float,
end_time: float,
limit: int = 0,
limit_mode: str = "latest",
filter_mai: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=chat_id,
timestamp_start=start_time,
timestamp_end=end_time,
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
filter_command=filter_command,
filter_intercept_message_level=filter_intercept_message_level,
)
if filter_mai:
return filter_mai_messages(messages)
return messages
def get_messages_by_time_in_chat_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[DatabaseMessages]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit))
return get_raw_msg_before_timestamp(timestamp, limit)
def get_messages_before_time_in_chat(
chat_id: str,
timestamp: float,
limit: int = 0,
filter_mai: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=timestamp,
limit=limit,
filter_intercept_message_level=filter_intercept_message_level,
)
if filter_mai:
return filter_mai_messages(messages)
return messages
def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
if not isinstance(hours, (int, float)) or hours < 0:
raise ValueError("hours 不能是负数")
if not isinstance(limit, int) or limit < 0:
raise ValueError("limit 必须是非负整数")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
now = time.time()
start_time = now - hours * 3600
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode))
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
# =============================================================================
# 消息计数函数
# =============================================================================
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
if not isinstance(start_time, (int, float)):
raise ValueError("start_time 必须是数字类型")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since(chat_id, start_time, end_time)
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if not chat_id:
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
# =============================================================================
# 消息格式化函数
# =============================================================================
def build_readable_messages_to_str(
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> str:
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
async def build_readable_messages_with_details(
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
return await get_person_id_list(messages)
# =============================================================================
# 消息过滤函数
# =============================================================================
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
"""从消息列表中移除麦麦的消息"""
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]
def translate_pid_to_description(pid: str) -> str:
with get_db_session() as session:
statement = (
select(Images).where((col(Images.id) == int(pid)) & (col(Images.image_type) == ImageType.IMAGE))
if pid.isdigit()
else None
)
image = session.exec(statement).first() if statement is not None else None
description = ""
if image and image.description and image.description.strip():
description = image.description.strip()
else:
description = "[图片]"
return description

View File

@@ -0,0 +1,65 @@
"""个人信息服务模块
提供个人信息查询的核心功能。
"""
from typing import Any
from src.common.logger import get_logger
from src.person_info.person_info import Person
logger = get_logger("person_service")
def get_person_id(platform: str, user_id: int | str) -> str:
"""根据平台和用户ID获取person_id
Args:
platform: 平台名称,如 "qq", "telegram"
user_id: 用户ID
Returns:
str: 唯一的person_idMD5哈希值
"""
try:
return Person(platform=platform, user_id=str(user_id)).person_id
except Exception as e:
logger.error(f"[PersonService] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
return ""
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
"""根据person_id和字段名获取某个值
Args:
person_id: 用户的唯一标识ID
field_name: 要获取的字段名,如 "nickname", "impression"
default: 当字段不存在或获取失败时返回的默认值
Returns:
Any: 字段值或默认值
"""
try:
person = Person(person_id=person_id)
value = getattr(person, field_name)
return value if value is not None else default
except Exception as e:
logger.error(f"[PersonService] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
return default
def get_person_id_by_name(person_name: str) -> str:
"""根据用户名获取person_id
Args:
person_name: 用户名
Returns:
str: person_id如果未找到返回空字符串
"""
try:
person = Person(person_name=person_name)
return person.person_id
except Exception as e:
logger.error(f"[PersonService] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
return ""

View File

@@ -0,0 +1,347 @@
"""
发送服务模块
提供发送各种类型消息的核心功能。
"""
import traceback
import time
from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
from maim_message import MessageBase, BaseMessageInfo, Seg
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel
logger = get_logger("send_service")
# =============================================================================
# 内部实现函数
# =============================================================================
async def _send_to_target(
message_segment: Seg,
stream_id: str,
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
"""向指定目标发送消息的内部实现"""
try:
if set_reply and not reply_message:
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
return False
if show_log:
logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}")
target_stream = _chat_manager.get_session_by_session_id(stream_id)
if not target_stream:
logger.error(f"[SendService] 未找到聊天流: {stream_id}")
return False
message_sender = UniversalMessageSender()
current_time = time.time()
message_id = f"send_api_{int(current_time * 1000)}"
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
)
reply_to_platform_id = ""
anchor_message: Optional[MaiMessage] = None
if reply_message:
anchor_message = db_message_to_mai_message(reply_message)
if anchor_message:
logger.debug(f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
reply_to_platform_id = (
f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
)
sender_info = None
if target_stream.context and target_stream.context.message:
sender_info = target_stream.context.message.message_info.user_info
bot_message = MessageSending(
message_id=message_id,
session=target_stream,
bot_user_info=bot_user_info,
sender_info=sender_info,
message_segment=message_segment,
display_message=display_message,
reply=anchor_message,
is_head=True,
is_emoji=(message_segment.type == "emoji"),
thinking_start_time=current_time,
reply_to=reply_to_platform_id,
selected_expressions=selected_expressions,
)
sent_msg = await message_sender.send_message(
bot_message,
typing=typing,
set_reply=set_reply,
storage_message=storage_message,
show_log=show_log,
)
if sent_msg:
logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
return True
else:
logger.error("[SendService] 发送消息失败")
return False
except Exception as e:
logger.error(f"[SendService] 发送消息时出错: {e}")
traceback.print_exc()
return False
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。"""
from datetime import datetime
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import MessageSequence
user_info = UserInfo(
user_id=message_obj.user_info.user_id or "",
user_nickname=message_obj.user_info.user_nickname or "",
user_cardname=message_obj.user_info.user_cardname,
)
group_info = None
if message_obj.chat_info.group_info:
group_info = GroupInfo(
group_id=message_obj.chat_info.group_info.group_id or "",
group_name=message_obj.chat_info.group_info.group_name or "",
)
msg = MaiMessage(
message_id=message_obj.message_id,
timestamp=datetime.fromtimestamp(message_obj.time) if message_obj.time else datetime.now(),
)
msg.message_info = MessageInfo(user_info=user_info, group_info=group_info)
msg.platform = message_obj.chat_info.platform or ""
msg.session_id = message_obj.chat_info.stream_id or ""
msg.processed_plain_text = message_obj.processed_plain_text
msg.raw_message = MessageSequence(components=[])
msg.initialized = True
return msg
# =============================================================================
# 公共函数 - 预定义类型的发送函数
# =============================================================================
async def text_to_stream(
text: str,
stream_id: str,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
"""向指定流发送文本消息"""
return await _send_to_target(
message_segment=Seg(type="text", data=text),
stream_id=stream_id,
display_message="",
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
selected_expressions=selected_expressions,
)
async def emoji_to_stream(
emoji_base64: str,
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送表情包"""
return await _send_to_target(
message_segment=Seg(type="emoji", data=emoji_base64),
stream_id=stream_id,
display_message="",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
async def image_to_stream(
image_base64: str,
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送图片"""
return await _send_to_target(
message_segment=Seg(type="image", data=image_base64),
stream_id=stream_id,
display_message="",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
async def command_to_stream(
command: Union[str, dict],
stream_id: str,
storage_message: bool = True,
display_message: str = "",
) -> bool:
"""向指定流发送命令"""
return await _send_to_target(
message_segment=Seg(type="command", data=command), # type: ignore
stream_id=stream_id,
display_message=display_message,
typing=False,
storage_message=storage_message,
set_reply=False,
)
async def custom_to_stream(
message_type: str,
content: str | Dict,
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
"""向指定流发送自定义类型消息"""
return await _send_to_target(
message_segment=Seg(type=message_type, data=content), # type: ignore
stream_id=stream_id,
display_message=display_message,
typing=typing,
reply_message=reply_message,
set_reply=set_reply,
storage_message=storage_message,
show_log=show_log,
)
async def custom_reply_set_to_stream(
reply_set: "ReplySetModel",
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
"""向指定流发送混合型消息集"""
flag: bool = True
for reply_content in reply_set.reply_data:
status: bool = False
message_seg, need_typing = _parse_content_to_seg(reply_content)
status = await _send_to_target(
message_segment=message_seg,
stream_id=stream_id,
display_message=display_message,
typing=bool(need_typing and typing),
reply_message=reply_message,
set_reply=set_reply,
storage_message=storage_message,
show_log=show_log,
)
if not status:
flag = False
logger.error(
f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}"
)
return flag
def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
"""把 ReplyContent 转换为 Seg 结构"""
content_type = reply_content.content_type
if content_type == ReplyContentType.TEXT:
text_data: str = reply_content.content # type: ignore
return Seg(type="text", data=text_data), True
elif content_type == ReplyContentType.IMAGE:
return Seg(type="image", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.EMOJI:
return Seg(type="emoji", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.COMMAND:
return Seg(type="command", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.VOICE:
return Seg(type="voice", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.HYBRID:
hybrid_message_list_data: List[ReplyContent] = reply_content.content # type: ignore
assert isinstance(hybrid_message_list_data, list), "混合类型内容必须是列表"
sub_seg_list: List[Seg] = []
for sub_content in hybrid_message_list_data:
sub_content_type = sub_content.content_type
sub_content_data = sub_content.content
if sub_content_type == ReplyContentType.TEXT:
sub_seg_list.append(Seg(type="text", data=sub_content_data)) # type: ignore
elif sub_content_type == ReplyContentType.IMAGE:
sub_seg_list.append(Seg(type="image", data=sub_content_data)) # type: ignore
elif sub_content_type == ReplyContentType.EMOJI:
sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore
else:
logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
continue
return Seg(type="seglist", data=sub_seg_list), True
elif content_type == ReplyContentType.FORWARD:
forward_message_list_data: List["ForwardNode"] = reply_content.content # type: ignore
assert isinstance(forward_message_list_data, list), "转发类型内容必须是列表"
forward_message_list: List[Dict] = []
for forward_node in forward_message_list_data:
message_segment = Seg(type="id", data=forward_node.content) # type: ignore
user_info: Optional[UserInfo] = None
if forward_node.user_id and forward_node.user_nickname:
assert isinstance(forward_node.content, list), "转发节点内容必须是列表"
user_info = UserInfo(user_id=forward_node.user_id, user_nickname=forward_node.user_nickname)
single_node_content: List[Seg] = []
for sub_content in forward_node.content:
if sub_content.content_type != ReplyContentType.FORWARD:
sub_seg, _ = _parse_content_to_seg(sub_content)
single_node_content.append(sub_seg)
message_segment = Seg(type="seglist", data=single_node_content)
forward_message_list.append(
MessageBase(
message_segment=message_segment, message_info=BaseMessageInfo(user_info=user_info)
).to_dict()
)
return Seg(type="forward", data=forward_message_list), False # type: ignore
else:
message_type_in_str = content_type.value if isinstance(content_type, ReplyContentType) else str(content_type)
return Seg(type=message_type_in_str, data=reply_content.content), True # type: ignore