部分模块的新数据结构适配

This commit is contained in:
DrSmoothl
2026-03-13 23:36:17 +08:00
parent 6201b862c9
commit 898fab6de9
7 changed files with 580 additions and 399 deletions

View File

@@ -8,7 +8,7 @@ from rich.traceback import install
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
@@ -35,7 +35,6 @@ from src.chat.utils.chat_message_builder import (
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ReplySetModel
ERROR_LOOP_INFO = {
@@ -513,7 +512,7 @@ class BrainChatting:
async def _send_response(
self,
reply_set: "ReplySetModel",
reply_set: MessageSequence,
message_data: "DatabaseMessages",
selected_expressions: Optional[List[int]] = None,
) -> str:
@@ -528,10 +527,10 @@ class BrainChatting:
reply_text = ""
first_replied = False
for reply_content in reply_set.reply_data:
if reply_content.content_type != ReplyContentType.TEXT:
for component in reply_set.components:
if not isinstance(component, TextComponent):
continue
data: str = reply_content.content # type: ignore
data = component.text
if not first_replied:
await send_api.text_to_stream(
text=data,

View File

@@ -7,29 +7,31 @@ import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message_old import UserInfo, Seg, MessageRecv, MessageSending
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import BaseMessageInfo, MessageBase, Seg
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager
from src.chat.utils.chat_message_builder import (
from src.services.message_service import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
translate_pid_to_description,
)
from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
from src.core.types import ActionInfo, EventType
from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
@@ -45,17 +47,17 @@ logger = get_logger("replyer")
class DefaultReplyer:
def __init__(
self,
chat_stream: ChatStream,
chat_stream: BotChatSession,
request_type: str = "replyer",
):
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
from src.chat.tool_executor import ToolExecutor
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
async def generate_reply_with_context(
self,
@@ -66,7 +68,7 @@ class DefaultReplyer:
enable_tool: bool = True,
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[DatabaseMessages] = None,
reply_message: Optional[SessionMessage] = None,
reply_time_point: float = time.time(),
think_level: int = 1,
unknown_words: Optional[List[str]] = None,
@@ -132,7 +134,7 @@ class DefaultReplyer:
if log_reply:
try:
PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id,
chat_id=self.chat_stream.session_id,
prompt="",
output=None,
processed_output=None,
@@ -146,12 +148,12 @@ class DefaultReplyer:
except Exception:
logger.exception("记录reply日志失败")
return False, llm_response
from src.plugin_system.core.events_manager import events_manager
from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message
if not from_plugin:
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
)
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
continue_flag, modified_message = await event_bus.emit(EventType.POST_LLM, _event_msg)
if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成")
if modified_message and modified_message._modify_flags.modify_llm_prompt:
@@ -202,7 +204,7 @@ class DefaultReplyer:
try:
if log_reply:
PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id,
chat_id=self.chat_stream.session_id,
prompt=prompt,
output=content,
processed_output=None,
@@ -214,9 +216,10 @@ class DefaultReplyer:
)
except Exception:
logger.exception("记录reply日志失败")
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
_event_msg = build_event_message(
EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id
)
continue_flag, modified_message = await event_bus.emit(EventType.AFTER_LLM, _event_msg)
if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成")
if modified_message:
@@ -259,7 +262,7 @@ class DefaultReplyer:
if log_reply:
try:
PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id,
chat_id=self.chat_stream.session_id,
prompt=prompt or "",
output=None,
processed_output=None,
@@ -353,14 +356,14 @@ class DefaultReplyer:
str: 表达习惯信息字符串
"""
# 检查是否允许在此聊天流中使用表达
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
if not use_expression:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id,
self.chat_stream.session_id,
chat_history,
max_num=8,
target_message=target,
@@ -594,7 +597,7 @@ class DefaultReplyer:
async def _build_jargon_explanation(
self,
chat_id: str,
messages_short: List[DatabaseMessages],
messages_short: List[SessionMessage],
chat_talking_prompt_short: str,
unknown_words: Optional[List[str]],
) -> str:
@@ -703,9 +706,13 @@ class DefaultReplyer:
is_group = stream_type == "group"
# 使用 ChatManager 提供的接口生成 chat_id避免在此重复实现逻辑
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.utils.utils_session import SessionUtils
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
chat_id = SessionUtils.calculate_session_id(
platform,
group_id=str(id_str) if is_group else None,
user_id=str(id_str) if not is_group else None,
)
return chat_id, prompt_content
except (ValueError, IndexError):
@@ -751,7 +758,7 @@ class DefaultReplyer:
async def build_prompt_reply_context(
self,
reply_message: Optional[DatabaseMessages] = None,
reply_message: Optional[SessionMessage] = None,
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
@@ -778,7 +785,7 @@ class DefaultReplyer:
if available_actions is None:
available_actions = {}
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
chat_id = chat_stream.session_id
_is_group_chat = bool(chat_stream.group_info)
platform = chat_stream.platform
@@ -1005,7 +1012,7 @@ class DefaultReplyer:
reply_to: str,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
chat_id = chat_stream.session_id
sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
@@ -1105,31 +1112,29 @@ class DefaultReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
anchor_message: Optional[MessageRecv] = None,
) -> MessageSending:
anchor_message: Optional[MaiMessage] = None,
) -> SessionMessage:
"""构建单个发送消息"""
bot_user_info = UserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform,
)
# await anchor_message.process()
sender_info = anchor_message.message_info.user_info if anchor_message else None
return MessageSending(
message_id=message_id, # 使用片段的唯一ID
chat_stream=self.chat_stream,
bot_user_info=bot_user_info,
sender_info=sender_info,
maim_message = MessageBase(
message_info=BaseMessageInfo(
platform=self.chat_stream.platform,
message_id=message_id,
time=thinking_start_time,
user_info=UserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
),
additional_config={},
),
message_segment=message_segment,
reply=anchor_message, # 回复原始锚点
is_head=reply_to,
is_emoji=is_emoji,
thinking_start_time=thinking_start_time, # 传递原始思考开始时间
display_message=display_message,
)
message = SessionMessage.from_maim_message(maim_message)
message.session_id = self.chat_stream.session_id
message.display_message = display_message
message.reply_to = anchor_message.message_id if reply_to and anchor_message else None
message.is_emoji = is_emoji
return message
async def llm_generate_content(self, prompt: str):
with Timer("LLM生成", {}): # 内部计时器,可选保留

View File

@@ -7,33 +7,31 @@ import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from maim_message import Seg
from maim_message import BaseMessageInfo, MessageBase, Seg
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message_old import MessageSending
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager
from src.chat.utils.common_utils import TempMethodsExpression
from src.chat.utils.chat_message_builder import (
from src.services.message_service import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
translate_pid_to_description,
)
from src.bw_learner.expression_selector import expression_selector
from src.services.message_service import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known
from src.core.types import ActionInfo, EventType
from src.services import llm_service as llm_api
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
from src.bw_learner.jargon_explainer_old import explain_jargon_in_context
@@ -69,7 +67,7 @@ class PrivateReplyer:
from_plugin: bool = True,
think_level: int = 1,
stream_id: Optional[str] = None,
reply_message: Optional[DatabaseMessages] = None,
reply_message: Optional[SessionMessage] = None,
reply_time_point: Optional[float] = time.time(),
unknown_words: Optional[List[str]] = None,
log_reply: bool = True,
@@ -604,7 +602,7 @@ class PrivateReplyer:
async def build_prompt_reply_context(
self,
reply_message: Optional[DatabaseMessages] = None,
reply_message: Optional[SessionMessage] = None,
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
@@ -954,28 +952,29 @@ class PrivateReplyer:
thinking_start_time: float,
display_message: str,
anchor_message: Optional[MaiMessage] = None,
) -> MessageSending:
) -> SessionMessage:
"""构建单个发送消息"""
bot_user_info = UserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
)
sender_info = anchor_message.message_info.user_info if anchor_message else None
return MessageSending(
message_id=message_id,
session=self.chat_stream,
bot_user_info=bot_user_info,
sender_info=sender_info,
maim_message = MessageBase(
message_info=BaseMessageInfo(
platform=self.chat_stream.platform,
message_id=message_id,
time=thinking_start_time,
user_info=UserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
),
group_info=None,
additional_config={},
),
message_segment=message_segment,
reply=anchor_message,
is_head=reply_to,
is_emoji=is_emoji,
thinking_start_time=thinking_start_time,
display_message=display_message,
)
message = SessionMessage.from_maim_message(maim_message)
message.session_id = self.chat_stream.session_id
message.display_message = display_message
message.reply_to = anchor_message.message_id if reply_to and anchor_message else None
message.is_emoji = is_emoji
return message
async def llm_generate_content(self, prompt: str):
with Timer("LLM生成", {}): # 内部计时器,可选保留
@@ -999,55 +998,9 @@ class PrivateReplyer:
return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, sender: str, target: str):
related_info = ""
start_time = time.time()
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 从LPMM知识库获取知识
try:
# 检查LPMM知识库是否启用
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return ""
if global_config.lpmm_knowledge.lpmm_mode == "agent":
return ""
prompt_template = prompt_manager.get_prompt("lpmm_get_knowledge")
prompt_template.add_context("bot_name", global_config.bot.nickname)
prompt_template.add_context("time_now", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
prompt_template.add_context("chat_history", message)
prompt_template.add_context("sender", sender)
prompt_template.add_context("target_message", target)
prompt = await prompt_manager.render_prompt(prompt_template)
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
)
if tool_calls:
result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
end_time = time.time()
if not result or not result.get("content"):
logger.debug("从LPMM知识库获取知识失败返回空知识...")
return ""
found_knowledge_from_lpmm = result.get("content", "")
logger.debug(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
related_info += found_knowledge_from_lpmm
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return f"你有以下这些**知识**\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
else:
logger.debug("模型认为不需要使用LPMM知识库")
return ""
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return ""
logger.debug(f"已跳过知识库信息获取,元消息:{message[:30]}...,消息长度: {len(message)}")
del message, sender, target
return ""
def weighted_sample_no_replacement(items, weights, k) -> list:

View File

@@ -1,5 +1,6 @@
import traceback
from datetime import datetime
from types import SimpleNamespace
from typing import Any
import json
@@ -9,7 +10,7 @@ from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger
from src.config.config import global_config
@@ -58,53 +59,37 @@ def _normalize_optional_str(value: object) -> str | None:
return str(value)
def _message_to_instance(message: Messages) -> DatabaseMessages:
def _message_to_instance(message: Messages) -> SessionMessage:
config = _parse_additional_config(message)
timestamp_value = message.timestamp
if isinstance(timestamp_value, datetime):
time_value = timestamp_value.timestamp()
else:
time_value = float(timestamp_value)
selected_expressions = _normalize_optional_str(config.get("selected_expressions"))
priority_info = _normalize_optional_str(config.get("priority_info"))
return DatabaseMessages(
message_id=message.message_id,
time=time_value,
chat_id=message.session_id,
reply_to=message.reply_to,
interest_value=config.get("interest_value"),
key_words=_normalize_optional_str(config.get("key_words")),
key_words_lite=_normalize_optional_str(config.get("key_words_lite")),
is_mentioned=message.is_mentioned,
is_at=message.is_at,
reply_probability_boost=config.get("reply_probability_boost"),
processed_plain_text=message.processed_plain_text,
display_message=message.display_message,
priority_mode=_normalize_optional_str(config.get("priority_mode")),
priority_info=priority_info,
additional_config=message.additional_config,
is_emoji=message.is_emoji,
is_picid=message.is_picture,
is_command=message.is_command,
intercept_message_level=config.get("intercept_message_level", 0),
is_notify=message.is_notify,
selected_expressions=selected_expressions,
user_id=message.user_id,
user_nickname=message.user_nickname,
user_cardname=message.user_cardname,
user_platform=message.platform,
chat_info_group_id=message.group_id,
chat_info_group_name=message.group_name,
chat_info_group_platform=message.platform,
chat_info_user_id=message.user_id,
chat_info_user_nickname=message.user_nickname,
chat_info_user_cardname=message.user_cardname,
chat_info_user_platform=message.platform,
chat_info_stream_id=message.session_id,
chat_info_platform=message.platform,
chat_info_create_time=0.0,
chat_info_last_active_time=0.0,
instance = SessionMessage.from_db_instance(message)
instance.interest_value = config.get("interest_value")
instance.key_words = _normalize_optional_str(config.get("key_words"))
instance.key_words_lite = _normalize_optional_str(config.get("key_words_lite"))
instance.reply_probability_boost = config.get("reply_probability_boost")
instance.priority_mode = _normalize_optional_str(config.get("priority_mode"))
instance.priority_info = _normalize_optional_str(config.get("priority_info"))
instance.intercept_message_level = config.get("intercept_message_level", 0)
instance.selected_expressions = _normalize_optional_str(config.get("selected_expressions"))
group_info = instance.message_info.group_info
legacy_group_info = None
if group_info:
legacy_group_info = SimpleNamespace(
group_id=group_info.group_id,
group_name=group_info.group_name,
)
instance.user_info = SimpleNamespace(
user_id=instance.message_info.user_info.user_id,
user_nickname=instance.message_info.user_info.user_nickname,
user_cardname=instance.message_info.user_info.user_cardname,
platform=instance.platform,
)
instance.chat_info = SimpleNamespace(
platform=instance.platform,
stream_id=instance.session_id,
group_info=legacy_group_info,
)
instance.time = instance.timestamp.timestamp()
return instance
def _coerce_datetime(value: Any) -> Any:
@@ -147,7 +132,7 @@ def find_messages(
filter_bot: bool = False,
filter_command: bool = False,
filter_intercept_message_level: int | None = None,
) -> list[DatabaseMessages]:
) -> list[SessionMessage]:
"""
根据提供的过滤器、排序和限制条件查找消息。

View File

@@ -16,14 +16,14 @@ from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer
from src.chat.replyer.replyer_manager import replyer_manager
from src.chat.utils.utils import process_llm_response
from src.common.data_models.message_data_model import ReplySetModel
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.logger import get_logger
from src.core.types import ActionInfo
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.chat.message_receive.message import SessionMessage
install(extra_lines=3)
@@ -67,7 +67,7 @@ async def generate_reply(
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
think_level: int = 1,
extra_info: str = "",
reply_reason: str = "",
@@ -126,15 +126,17 @@ async def generate_reply(
if not success:
logger.warning("[GeneratorService] 回复生成失败")
return False, None
reply_set: Optional[ReplySetModel] = None
reply_set: Optional[MessageSequence] = None
if content := llm_response.content:
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
llm_response.processed_output = processed_response
reply_set = ReplySetModel()
reply_set = MessageSequence(components=[])
for text in processed_response:
reply_set.add_text_content(text)
reply_set.components.append(TextComponent(text))
llm_response.reply_set = reply_set
logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
logger.debug(
f"[GeneratorService] 回复生成成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
)
try:
PlanReplyLogger.log_reply(
@@ -196,12 +198,14 @@ async def rewrite_reply(
reason=reason,
reply_to=reply_to,
)
reply_set: Optional[ReplySetModel] = None
reply_set: Optional[MessageSequence] = None
if success and llm_response and (content := llm_response.content):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
llm_response.reply_set = reply_set
if success:
logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
logger.info(
f"[GeneratorService] 重写回复成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
)
else:
logger.warning("[GeneratorService] 重写回复失败")
@@ -215,16 +219,16 @@ async def rewrite_reply(
return False, None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]:
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[MessageSequence]:
"""将文本处理为更拟人化的文本"""
if not isinstance(content, str):
raise ValueError("content 必须是字符串类型")
try:
reply_set = ReplySetModel()
reply_set = MessageSequence(components=[])
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
for text in processed_response:
reply_set.add_text_content(text)
reply_set.components.append(TextComponent(text))
return reply_set

View File

@@ -1,34 +1,21 @@
"""
消息服务模块
提供消息查询和构建成字符串的核心功能。
"""
"""消息服务模块。"""
import re
import time
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime
from typing import Any, List, Optional, Tuple
from sqlmodel import col, select
from src.chat.utils.chat_message_builder import (
build_readable_messages,
build_readable_messages_with_list,
get_person_id_list,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
)
from src.chat.utils.utils import is_bot_self
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.action_record_data_model import MaiActionRecord
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.common.database.database_model import ActionRecord, Images, ImageType
from src.common.message_repository import count_messages, find_messages
from src.common.utils.math_utils import translate_timestamp_to_human_readable
from src.common.utils.utils_action import ActionUtils
from src.chat.utils.utils import is_bot_self
from src.config.config import global_config
# =============================================================================
@@ -36,16 +23,62 @@ from src.common.database.database_model import Images, ImageType
# =============================================================================
def _build_time_range_filter(start_time: float, end_time: float) -> dict[str, Any]:
return {
"time": {
"$gte": start_time,
"$lte": end_time,
}
}
def _build_readable_line(
    message: SessionMessage,
    *,
    replace_bot_name: bool,
    timestamp_mode: Optional[str],
    show_message_id_prefix: bool,
) -> str:
    """Format one message as a single human-readable chat line.

    The line is assembled as ``[timestamp] [消息ID: ...] <name>说:<text>``,
    where the timestamp and message-id prefixes are optional.

    Args:
        message: The message to render.
        replace_bot_name: When true, strip the bot's configured nickname
            from the message text (presumably to avoid self-references —
            behavior depends on ``global_config.bot.nickname``).
        timestamp_mode: Mode string forwarded to
            ``translate_timestamp_to_human_readable``; ``None`` omits the
            timestamp prefix.
        show_message_id_prefix: Whether to include the ``[消息ID: ...]`` tag.

    Returns:
        The rendered line.
    """
    text = (message.processed_plain_text or "").strip()
    bot_name = global_config.bot.nickname
    if replace_bot_name and bot_name:
        text = text.replace(bot_name, "")
    # Prefer the group card name, then the nickname, then the raw user id.
    user = message.message_info.user_info
    display_name = user.user_cardname or user.user_nickname or user.user_id
    parts: List[str] = []
    if timestamp_mode:
        readable_ts = translate_timestamp_to_human_readable(
            message.timestamp.timestamp(), mode=timestamp_mode
        )
        parts.append(f"[{readable_ts}]")
    if show_message_id_prefix:
        parts.append(f"[消息ID: {message.message_id}]")
    parts.append(f"{display_name}说:")
    return " ".join(parts) + text
def _normalize_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
normalized: List[SessionMessage] = []
for message in messages:
if not message.processed_plain_text:
message.processed_plain_text = message.display_message or ""
normalized.append(message)
return normalized
def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
messages = find_messages(
message_filter=_build_time_range_filter(start_time, end_time),
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
)
return _normalize_messages(messages)
def get_messages_by_time_in_chat(
@@ -57,7 +90,7 @@ def get_messages_by_time_in_chat(
filter_mai: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -66,16 +99,18 @@ def get_messages_by_time_in_chat(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat(
chat_id=chat_id,
timestamp_start=start_time,
timestamp_end=end_time,
messages = find_messages(
message_filter={
"chat_id": chat_id,
**_build_time_range_filter(start_time, end_time),
},
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
filter_command=filter_command,
filter_intercept_message_level=filter_intercept_message_level,
)
return _normalize_messages(messages)
def get_messages_by_time_in_chat_inclusive(
@@ -87,7 +122,7 @@ def get_messages_by_time_in_chat_inclusive(
filter_mai: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -96,19 +131,21 @@ def get_messages_by_time_in_chat_inclusive(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=chat_id,
timestamp_start=start_time,
timestamp_end=end_time,
messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {
"$gte": start_time,
"$lte": end_time,
},
},
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
filter_command=filter_command,
filter_intercept_message_level=filter_intercept_message_level,
)
if filter_mai:
return filter_mai_messages(messages)
return messages
return _normalize_messages(messages)
def get_messages_by_time_in_chat_for_users(
@@ -118,7 +155,7 @@ def get_messages_by_time_in_chat_for_users(
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -127,39 +164,64 @@ def get_messages_by_time_in_chat_for_users(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {
"$gte": start_time,
"$lte": end_time,
},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode=limit_mode,
)
return _normalize_messages(messages)
def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
return get_messages_by_time(start_time, end_time, limit, limit_mode, filter_mai)
def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
messages = find_messages(
message_filter={
"time": {
"$gte": start_time,
"$lte": end_time,
},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode=limit_mode,
)
return _normalize_messages(messages)
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit))
return get_raw_msg_before_timestamp(timestamp, limit)
messages = find_messages(
message_filter={"time": {"$lt": timestamp}},
limit=limit,
limit_mode="latest",
filter_bot=filter_mai,
)
return _normalize_messages(messages)
def get_messages_before_time_in_chat(
@@ -168,7 +230,7 @@ def get_messages_before_time_in_chat(
limit: int = 0,
filter_mai: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
@@ -177,30 +239,40 @@ def get_messages_before_time_in_chat(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=timestamp,
messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {"$lt": timestamp},
},
limit=limit,
limit_mode="latest",
filter_bot=filter_mai,
filter_intercept_message_level=filter_intercept_message_level,
)
if filter_mai:
return filter_mai_messages(messages)
return messages
return _normalize_messages(messages)
def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
messages = find_messages(
message_filter={
"time": {"$lt": timestamp},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode="latest",
)
return _normalize_messages(messages)
def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(hours, (int, float)) or hours < 0:
raise ValueError("hours 不能是负数")
if not isinstance(limit, int) or limit < 0:
@@ -211,9 +283,7 @@ def get_recent_messages(
raise ValueError("chat_id 必须是字符串类型")
now = time.time()
start_time = now - hours * 3600
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode))
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
return get_messages_by_time_in_chat(chat_id, start_time, now, limit, limit_mode, filter_mai)
# =============================================================================
@@ -228,7 +298,13 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since(chat_id, start_time, end_time)
message_filter: dict[str, Any] = {
"chat_id": chat_id,
"time": {"$gt": start_time},
}
if end_time is not None:
message_filter["time"]["$lte"] = end_time
return count_messages(message_filter)
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
@@ -238,7 +314,13 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
return count_messages(
{
"chat_id": chat_id,
"time": {"$gt": start_time, "$lte": end_time},
"user_id": {"$in": person_ids},
}
)
# =============================================================================
@@ -246,8 +328,45 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
# =============================================================================
def build_readable_messages(
    messages: List[SessionMessage],
    replace_bot_name: bool = True,
    timestamp_mode: str = "relative",
    read_mark: float = 0.0,
    truncate: bool = False,
    show_actions: bool = False,
) -> str:
    """Render a list of messages as a single human-readable text block.

    Args:
        messages: Messages to render; normalized via ``_normalize_messages`` first.
        replace_bot_name: Forwarded to the per-line renderer.
        timestamp_mode: Forwarded to the per-line renderer (e.g. "relative").
        read_mark: When non-zero, a divider line is inserted once, right before
            the first message whose timestamp is newer than this mark.
        truncate: When True, lines longer than 200 chars are cut with a suffix.
        show_actions: When True, readable action records covering the rendered
            time span are appended at the end.

    Returns:
        Newline-joined readable text ("" for an empty message list).
    """
    msgs = _normalize_messages(messages)
    if not msgs:
        return ""

    rendered: List[str] = []
    divider_inserted = False
    for msg in msgs:
        # Insert the unread divider exactly once, before the first unread message.
        if read_mark and not divider_inserted and msg.timestamp.timestamp() > read_mark:
            rendered.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---")
            divider_inserted = True
        text = _build_readable_line(
            msg,
            replace_bot_name=replace_bot_name,
            timestamp_mode=timestamp_mode,
            show_message_id_prefix=False,
        )
        if truncate and len(text) > 200:
            text = f"{text[:200]}......(内容太长了)"
        rendered.append(text)

    if show_actions:
        # Append readable action records spanning the rendered time window.
        actions_text = build_readable_actions(
            get_actions_by_timestamp_with_chat(
                msgs[0].session_id,
                msgs[0].timestamp.timestamp(),
                msgs[-1].timestamp.timestamp(),
            )
        )
        if actions_text:
            rendered.append(actions_text)
    return "\n".join(rendered)
def build_readable_messages_to_str(
messages: List[DatabaseMessages],
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
@@ -257,17 +376,71 @@ def build_readable_messages_to_str(
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
def build_readable_messages_with_id(
    messages: List[SessionMessage],
    replace_bot_name: bool = True,
    timestamp_mode: str = "relative",
    read_mark: float = 0.0,
    truncate: bool = False,
    show_actions: bool = False,
) -> Tuple[str, List[Tuple[str, SessionMessage]]]:
    """Render messages to readable text and return (message_id, message) pairs.

    Same rendering rules as the plain readable builder, but each line is
    rendered with its message-id prefix and a parallel list mapping each
    message id to its message object is returned alongside the text.
    """
    msgs = _normalize_messages(messages)
    if not msgs:
        return "", []

    rendered: List[str] = []
    id_pairs: List[Tuple[str, SessionMessage]] = []
    divider_inserted = False
    for msg in msgs:
        # Insert the unread divider exactly once, before the first unread message.
        if read_mark and not divider_inserted and msg.timestamp.timestamp() > read_mark:
            rendered.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---")
            divider_inserted = True
        text = _build_readable_line(
            msg,
            replace_bot_name=replace_bot_name,
            timestamp_mode=timestamp_mode,
            show_message_id_prefix=True,
        )
        if truncate and len(text) > 200:
            text = f"{text[:200]}......(内容太长了)"
        rendered.append(text)
        id_pairs.append((msg.message_id, msg))

    if show_actions:
        # Append readable action records spanning the rendered time window.
        actions_text = build_readable_actions(
            get_actions_by_timestamp_with_chat(
                msgs[0].session_id,
                msgs[0].timestamp.timestamp(),
                msgs[-1].timestamp.timestamp(),
            )
        )
        if actions_text:
            rendered.append(actions_text)
    return "\n".join(rendered), id_pairs
async def build_readable_messages_with_details(
messages: List[DatabaseMessages],
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)
normalized_messages = _normalize_messages(messages)
message_list = [
(
message.timestamp.timestamp(),
message.message_info.user_info.user_id,
message.processed_plain_text or "",
)
for message in normalized_messages
]
return build_readable_messages(normalized_messages, replace_bot_name, timestamp_mode, truncate=truncate), message_list
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
return await get_person_id_list(messages)
async def get_person_ids_from_messages(messages: List[Any]) -> List[str]:
    """Collect user ids from a heterogeneous message list.

    Accepts ``SessionMessage`` objects (id read from ``message_info.user_info``)
    or plain dicts carrying a truthy ``"user_id"`` key; any other entry, or a
    dict with a falsy user id, is silently ignored.
    """
    collected: List[str] = []
    for entry in messages:
        if isinstance(entry, SessionMessage):
            collected.append(entry.message_info.user_info.user_id)
        elif isinstance(entry, dict):
            uid = entry.get("user_id")
            if uid:
                collected.append(str(uid))
    return collected
# =============================================================================
@@ -275,9 +448,145 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
# =============================================================================
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
def filter_mai_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
"""从消息列表中移除麦麦的消息"""
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]
return [
msg
for msg in messages
if not is_bot_self(msg.platform, msg.message_info.user_info.user_id)
]
def get_raw_msg_by_timestamp(
    timestamp_start: float,
    timestamp_end: float,
    limit: int = 0,
    limit_mode: str = "latest",
) -> List[SessionMessage]:
    """Legacy-named wrapper; delegates to ``get_messages_by_time``."""
    result = get_messages_by_time(timestamp_start, timestamp_end, limit, limit_mode)
    return result
def get_raw_msg_by_timestamp_with_chat(
    chat_id: str,
    timestamp_start: float,
    timestamp_end: float,
    limit: int = 0,
    limit_mode: str = "latest",
    filter_bot: bool = False,
    filter_command: bool = False,
    filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
    """Legacy-named wrapper; delegates to ``get_messages_by_time_in_chat``."""
    result = get_messages_by_time_in_chat(
        chat_id, timestamp_start, timestamp_end, limit, limit_mode,
        filter_bot, filter_command, filter_intercept_message_level,
    )
    return result
def get_raw_msg_by_timestamp_with_chat_inclusive(
    chat_id: str,
    timestamp_start: float,
    timestamp_end: float,
    limit: int = 0,
    limit_mode: str = "latest",
    filter_bot: bool = False,
    filter_command: bool = False,
    filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
    """Legacy-named wrapper; delegates to ``get_messages_by_time_in_chat_inclusive``."""
    result = get_messages_by_time_in_chat_inclusive(
        chat_id, timestamp_start, timestamp_end, limit, limit_mode,
        filter_bot, filter_command, filter_intercept_message_level,
    )
    return result
def get_raw_msg_by_timestamp_with_chat_users(
    chat_id: str,
    timestamp_start: float,
    timestamp_end: float,
    person_ids: List[str],
    limit: int = 0,
    limit_mode: str = "latest",
) -> List[SessionMessage]:
    """Legacy-named wrapper; delegates to ``get_messages_by_time_in_chat_for_users``."""
    result = get_messages_by_time_in_chat_for_users(
        chat_id, timestamp_start, timestamp_end, person_ids, limit, limit_mode
    )
    return result
def get_raw_msg_by_timestamp_with_users(
    timestamp_start: float,
    timestamp_end: float,
    person_ids: List[str],
    limit: int = 0,
    limit_mode: str = "latest",
) -> List[SessionMessage]:
    """Legacy-named wrapper; delegates to ``get_messages_by_time_for_users``."""
    result = get_messages_by_time_for_users(
        timestamp_start, timestamp_end, person_ids, limit, limit_mode
    )
    return result
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[SessionMessage]:
    """Legacy-named wrapper; delegates to ``get_messages_before_time``."""
    result = get_messages_before_time(timestamp, limit)
    return result
def get_raw_msg_before_timestamp_with_chat(
    chat_id: str,
    timestamp: float,
    limit: int = 0,
    filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
    """Legacy-named wrapper; delegates to ``get_messages_before_time_in_chat``.

    ``filter_mai`` is always passed as False here.
    """
    result = get_messages_before_time_in_chat(
        chat_id, timestamp, limit, False, filter_intercept_message_level
    )
    return result
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[SessionMessage]:
    """Legacy-named wrapper; delegates to ``get_messages_before_time_for_users``."""
    result = get_messages_before_time_for_users(timestamp, person_ids, limit)
    return result
def get_raw_msg_by_timestamp_random(
    timestamp_start: float,
    timestamp_end: float,
    limit: int = 0,
    limit_mode: str = "latest",
) -> List[SessionMessage]:
    """Legacy-named wrapper; delegates to ``get_random_chat_messages``."""
    result = get_random_chat_messages(timestamp_start, timestamp_end, limit, limit_mode)
    return result
def get_actions_by_timestamp_with_chat(chat_id: str, timestamp_start: float, timestamp_end: float) -> List[MaiActionRecord]:
    """Load action records for a chat within [timestamp_start, timestamp_end].

    Timestamps are Unix seconds; results are ordered by record timestamp
    ascending and converted to ``MaiActionRecord`` instances.
    """
    # Convert the Unix-second bounds once, outside the query expression.
    window_start = datetime.fromtimestamp(timestamp_start)
    window_end = datetime.fromtimestamp(timestamp_end)
    with get_db_session() as session:
        stmt = (
            select(ActionRecord)
            .where(col(ActionRecord.session_id) == chat_id)
            .where(col(ActionRecord.timestamp) >= window_start)
            .where(col(ActionRecord.timestamp) <= window_end)
            .order_by(col(ActionRecord.timestamp))
        )
        rows = session.exec(stmt).all()
        return [MaiActionRecord.from_db_instance(row) for row in rows]
def build_readable_actions(actions: List[MaiActionRecord], timestamp_mode: str = "relative") -> str:
    """Render action records as readable text via ``ActionUtils``."""
    rendered = ActionUtils.build_readable_action_records(actions, timestamp_mode)
    return rendered
def replace_user_references(text: str, platform: str, replace_bot_name: bool = False) -> str:
    """Collapse ``<name:id>`` user references in *text* to the bare name.

    An optional leading "回复" or "@" prefix is preserved. When
    ``replace_bot_name`` is True and the referenced name equals the bot's
    configured nickname, the name itself is dropped (only the prefix remains).
    ``platform`` is accepted for signature compatibility but unused.
    """
    del platform  # unused; kept for backward-compatible signature
    if not text:
        return text

    def _substitute(match: re.Match[str]) -> str:
        prefix = match.group(1) or ""
        name = match.group(2)
        if replace_bot_name and name == global_config.bot.nickname:
            name = ""
        return f"{prefix}{name}"

    return re.sub(r"(回复|@)?<([^:<>]+):[^<>]+>", _substitute, text)
def translate_pid_to_description(pid: str) -> str:

View File

@@ -6,21 +6,20 @@
import traceback
import time
from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
from typing import Optional, Union, Dict, List, TYPE_CHECKING
from maim_message import MessageBase, BaseMessageInfo, Seg
from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.common.data_models.mai_message_data_model import MaiMessage
from src.common.data_models.message_component_data_model import DictComponent, MessageSequence
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel
from src.chat.message_receive.message import SessionMessage
logger = get_logger("send_service")
@@ -36,7 +35,7 @@ async def _send_to_target(
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
storage_message: bool = True,
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
@@ -60,12 +59,6 @@ async def _send_to_target(
current_time = time.time()
message_id = f"send_api_{int(current_time * 1000)}"
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
)
reply_to_platform_id = ""
anchor_message: Optional[MaiMessage] = None
if reply_message:
anchor_message = db_message_to_mai_message(reply_message)
@@ -73,31 +66,50 @@ async def _send_to_target(
logger.debug(
f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}"
)
reply_to_platform_id = f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
sender_info = None
if target_stream.context and target_stream.context.message:
sender_info = target_stream.context.message.message_info.user_info
group_info = None
if target_stream.group_id:
group_name = ""
if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info:
group_name = target_stream.context.message.message_info.group_info.group_name
group_info = MaimGroupInfo(
group_id=target_stream.group_id,
group_name=group_name,
platform=target_stream.platform,
)
bot_message = MessageSending(
message_id=message_id,
session=target_stream,
bot_user_info=bot_user_info,
sender_info=sender_info,
additional_config: dict[str, object] = {}
if selected_expressions is not None:
additional_config["selected_expressions"] = selected_expressions
maim_message = MessageBase(
message_info=BaseMessageInfo(
platform=target_stream.platform,
message_id=message_id,
time=current_time,
user_info=MaimUserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
platform=target_stream.platform,
),
group_info=group_info,
additional_config=additional_config,
),
message_segment=message_segment,
display_message=display_message,
reply=anchor_message,
is_head=True,
is_emoji=(message_segment.type == "emoji"),
thinking_start_time=current_time,
reply_to=reply_to_platform_id,
selected_expressions=selected_expressions,
)
bot_message = SessionMessage.from_maim_message(maim_message)
bot_message.session_id = target_stream.session_id
bot_message.display_message = display_message
bot_message.reply_to = anchor_message.message_id if anchor_message else None
bot_message.is_emoji = message_segment.type == "emoji"
bot_message.is_picture = message_segment.type == "image"
bot_message.is_command = message_segment.type == "command"
sent_msg = await message_sender.send_message(
bot_message,
typing=typing,
set_reply=set_reply,
reply_message_id=anchor_message.message_id if anchor_message else None,
storage_message=storage_message,
show_log=show_log,
)
@@ -115,37 +127,9 @@ async def _send_to_target(
return False
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
def db_message_to_mai_message(message_obj: "SessionMessage") -> Optional[MaiMessage]:
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。"""
from datetime import datetime
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import MessageSequence
user_info = UserInfo(
user_id=message_obj.user_info.user_id or "",
user_nickname=message_obj.user_info.user_nickname or "",
user_cardname=message_obj.user_info.user_cardname,
)
group_info = None
if message_obj.chat_info.group_info:
group_info = GroupInfo(
group_id=message_obj.chat_info.group_info.group_id or "",
group_name=message_obj.chat_info.group_info.group_name or "",
)
msg = MaiMessage(
message_id=message_obj.message_id,
timestamp=datetime.fromtimestamp(message_obj.time) if message_obj.time else datetime.now(),
)
msg.message_info = MessageInfo(user_info=user_info, group_info=group_info)
msg.platform = message_obj.chat_info.platform or ""
msg.session_id = message_obj.chat_info.stream_id or ""
msg.processed_plain_text = message_obj.processed_plain_text
msg.raw_message = MessageSequence(components=[])
msg.initialized = True
return msg
return message_obj.deepcopy()
# =============================================================================
@@ -158,7 +142,7 @@ async def text_to_stream(
stream_id: str,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
storage_message: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
@@ -180,7 +164,7 @@ async def emoji_to_stream(
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
) -> bool:
"""向指定流发送表情包"""
return await _send_to_target(
@@ -199,7 +183,7 @@ async def image_to_stream(
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
) -> bool:
"""向指定流发送图片"""
return await _send_to_target(
@@ -236,7 +220,7 @@ async def custom_to_stream(
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
@@ -255,25 +239,27 @@ async def custom_to_stream(
async def custom_reply_set_to_stream(
reply_set: "ReplySetModel",
reply_set: MessageSequence,
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
"""向指定流发送混合型消息集"""
"""向指定流发送消息组件序列。"""
flag: bool = True
for reply_content in reply_set.reply_data:
status: bool = False
message_seg, need_typing = _parse_content_to_seg(reply_content)
for component in reply_set.components:
if isinstance(component, DictComponent):
message_seg = Seg(type="dict", data=component.data) # type: ignore
else:
message_seg = await component.to_seg()
status = await _send_to_target(
message_segment=message_seg,
stream_id=stream_id,
display_message=display_message,
typing=bool(need_typing and typing),
typing=typing,
reply_message=reply_message,
set_reply=set_reply,
storage_message=storage_message,
@@ -281,67 +267,7 @@ async def custom_reply_set_to_stream(
)
if not status:
flag = False
logger.error(
f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}"
)
logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}")
set_reply = False
return flag
def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
    """Convert a ReplyContent into a Seg structure.

    Returns a ``(seg, need_typing)`` pair: ``seg`` is the message segment to
    send, and the bool tells the caller whether a typing indicator should be
    simulated before sending (True for text-like content, False for
    media/command/forward content) — it is unpacked as ``need_typing`` at the
    call site.
    """
    content_type = reply_content.content_type
    if content_type == ReplyContentType.TEXT:
        text_data: str = reply_content.content  # type: ignore
        return Seg(type="text", data=text_data), True
    elif content_type == ReplyContentType.IMAGE:
        return Seg(type="image", data=reply_content.content), False  # type: ignore
    elif content_type == ReplyContentType.EMOJI:
        return Seg(type="emoji", data=reply_content.content), False  # type: ignore
    elif content_type == ReplyContentType.COMMAND:
        return Seg(type="command", data=reply_content.content), False  # type: ignore
    elif content_type == ReplyContentType.VOICE:
        return Seg(type="voice", data=reply_content.content), False  # type: ignore
    elif content_type == ReplyContentType.HYBRID:
        # A hybrid reply carries a list of sub-contents; only text/image/emoji
        # children are supported, flattened into a single "seglist" segment.
        hybrid_message_list_data: List[ReplyContent] = reply_content.content  # type: ignore
        assert isinstance(hybrid_message_list_data, list), "混合类型内容必须是列表"
        sub_seg_list: List[Seg] = []
        for sub_content in hybrid_message_list_data:
            sub_content_type = sub_content.content_type
            sub_content_data = sub_content.content
            if sub_content_type == ReplyContentType.TEXT:
                sub_seg_list.append(Seg(type="text", data=sub_content_data))  # type: ignore
            elif sub_content_type == ReplyContentType.IMAGE:
                sub_seg_list.append(Seg(type="image", data=sub_content_data))  # type: ignore
            elif sub_content_type == ReplyContentType.EMOJI:
                sub_seg_list.append(Seg(type="emoji", data=sub_content_data))  # type: ignore
            else:
                # Unsupported child types are skipped with a warning rather
                # than failing the whole hybrid message.
                logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
                continue
        return Seg(type="seglist", data=sub_seg_list), True
    elif content_type == ReplyContentType.FORWARD:
        # A forward reply is a list of ForwardNode entries, each serialized as
        # a MessageBase dict so the adapter can rebuild the forwarded messages.
        forward_message_list_data: List["ForwardNode"] = reply_content.content  # type: ignore
        assert isinstance(forward_message_list_data, list), "转发类型内容必须是列表"
        forward_message_list: List[Dict] = []
        for forward_node in forward_message_list_data:
            # Default: reference the original message by id.
            message_segment = Seg(type="id", data=forward_node.content)  # type: ignore
            user_info: Optional[UserInfo] = None
            if forward_node.user_id and forward_node.user_nickname:
                assert isinstance(forward_node.content, list), "转发节点内容必须是列表"
                user_info = UserInfo(user_id=forward_node.user_id, user_nickname=forward_node.user_nickname)
                single_node_content: List[Seg] = []
                for sub_content in forward_node.content:
                    # NOTE(review): nested FORWARD children are skipped, not
                    # expanded recursively — confirm this is intended.
                    if sub_content.content_type != ReplyContentType.FORWARD:
                        sub_seg, _ = _parse_content_to_seg(sub_content)
                        single_node_content.append(sub_seg)
                message_segment = Seg(type="seglist", data=single_node_content)
            forward_message_list.append(
                MessageBase(
                    message_segment=message_segment, message_info=BaseMessageInfo(user_info=user_info)
                ).to_dict()
            )
        return Seg(type="forward", data=forward_message_list), False  # type: ignore
    else:
        # Fallback: use the enum value (or raw string) directly as the Seg type.
        message_type_in_str = content_type.value if isinstance(content_type, ReplyContentType) else str(content_type)
        return Seg(type=message_type_in_str, data=reply_content.content), True  # type: ignore