部分模块的新数据结构适配

This commit is contained in:
DrSmoothl
2026-03-13 23:36:17 +08:00
parent 6201b862c9
commit 898fab6de9
7 changed files with 580 additions and 399 deletions

View File

@@ -8,7 +8,7 @@ from rich.traceback import install
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.message_data_model import ReplyContentType from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
@@ -35,7 +35,6 @@ from src.chat.utils.chat_message_builder import (
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ReplySetModel
ERROR_LOOP_INFO = { ERROR_LOOP_INFO = {
@@ -513,7 +512,7 @@ class BrainChatting:
async def _send_response( async def _send_response(
self, self,
reply_set: "ReplySetModel", reply_set: MessageSequence,
message_data: "DatabaseMessages", message_data: "DatabaseMessages",
selected_expressions: Optional[List[int]] = None, selected_expressions: Optional[List[int]] = None,
) -> str: ) -> str:
@@ -528,10 +527,10 @@ class BrainChatting:
reply_text = "" reply_text = ""
first_replied = False first_replied = False
for reply_content in reply_set.reply_data: for component in reply_set.components:
if reply_content.content_type != ReplyContentType.TEXT: if not isinstance(component, TextComponent):
continue continue
data: str = reply_content.content # type: ignore data = component.text
if not first_replied: if not first_replied:
await send_api.text_to_stream( await send_api.text_to_stream(
text=data, text=data,

View File

@@ -7,29 +7,31 @@ import re
from typing import List, Optional, Dict, Any, Tuple from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime from datetime import datetime
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message_old import UserInfo, Seg, MessageRecv, MessageSending from maim_message import BaseMessageInfo, MessageBase, Seg
from src.chat.message_receive.chat_stream import ChatStream
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.chat.utils.chat_message_builder import ( from src.services.message_service import (
build_readable_messages, build_readable_messages,
get_raw_msg_before_timestamp_with_chat, get_raw_msg_before_timestamp_with_chat,
replace_user_references, replace_user_references,
translate_pid_to_description,
) )
from src.bw_learner.expression_selector import expression_selector from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator # from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ActionInfo, EventType from src.core.types import ActionInfo, EventType
from src.plugin_system.apis import llm_api from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
@@ -45,17 +47,17 @@ logger = get_logger("replyer")
class DefaultReplyer: class DefaultReplyer:
def __init__( def __init__(
self, self,
chat_stream: ChatStream, chat_stream: BotChatSession,
request_type: str = "replyer", request_type: str = "replyer",
): ):
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender() self.heart_fc_sender = UniversalMessageSender()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖 from src.chat.tool_executor import ToolExecutor
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
async def generate_reply_with_context( async def generate_reply_with_context(
self, self,
@@ -66,7 +68,7 @@ class DefaultReplyer:
enable_tool: bool = True, enable_tool: bool = True,
from_plugin: bool = True, from_plugin: bool = True,
stream_id: Optional[str] = None, stream_id: Optional[str] = None,
reply_message: Optional[DatabaseMessages] = None, reply_message: Optional[SessionMessage] = None,
reply_time_point: float = time.time(), reply_time_point: float = time.time(),
think_level: int = 1, think_level: int = 1,
unknown_words: Optional[List[str]] = None, unknown_words: Optional[List[str]] = None,
@@ -132,7 +134,7 @@ class DefaultReplyer:
if log_reply: if log_reply:
try: try:
PlanReplyLogger.log_reply( PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id, chat_id=self.chat_stream.session_id,
prompt="", prompt="",
output=None, output=None,
processed_output=None, processed_output=None,
@@ -146,12 +148,12 @@ class DefaultReplyer:
except Exception: except Exception:
logger.exception("记录reply日志失败") logger.exception("记录reply日志失败")
return False, llm_response return False, llm_response
from src.plugin_system.core.events_manager import events_manager from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message
if not from_plugin: if not from_plugin:
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
EventType.POST_LLM, None, prompt, None, stream_id=stream_id continue_flag, modified_message = await event_bus.emit(EventType.POST_LLM, _event_msg)
)
if not continue_flag: if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成") raise UserWarning("插件于请求前中断了内容生成")
if modified_message and modified_message._modify_flags.modify_llm_prompt: if modified_message and modified_message._modify_flags.modify_llm_prompt:
@@ -202,7 +204,7 @@ class DefaultReplyer:
try: try:
if log_reply: if log_reply:
PlanReplyLogger.log_reply( PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id, chat_id=self.chat_stream.session_id,
prompt=prompt, prompt=prompt,
output=content, output=content,
processed_output=None, processed_output=None,
@@ -214,9 +216,10 @@ class DefaultReplyer:
) )
except Exception: except Exception:
logger.exception("记录reply日志失败") logger.exception("记录reply日志失败")
continue_flag, modified_message = await events_manager.handle_mai_events( _event_msg = build_event_message(
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id
) )
continue_flag, modified_message = await event_bus.emit(EventType.AFTER_LLM, _event_msg)
if not from_plugin and not continue_flag: if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成") raise UserWarning("插件于请求后取消了内容生成")
if modified_message: if modified_message:
@@ -259,7 +262,7 @@ class DefaultReplyer:
if log_reply: if log_reply:
try: try:
PlanReplyLogger.log_reply( PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id, chat_id=self.chat_stream.session_id,
prompt=prompt or "", prompt=prompt or "",
output=None, output=None,
processed_output=None, processed_output=None,
@@ -353,14 +356,14 @@ class DefaultReplyer:
str: 表达习惯信息字符串 str: 表达习惯信息字符串
""" """
# 检查是否允许在此聊天流中使用表达 # 检查是否允许在此聊天流中使用表达
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id) use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
if not use_expression: if not use_expression:
return "", [] return "", []
style_habits = [] style_habits = []
# 使用从处理器传来的选中表达方式 # 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式 # 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, self.chat_stream.session_id,
chat_history, chat_history,
max_num=8, max_num=8,
target_message=target, target_message=target,
@@ -594,7 +597,7 @@ class DefaultReplyer:
async def _build_jargon_explanation( async def _build_jargon_explanation(
self, self,
chat_id: str, chat_id: str,
messages_short: List[DatabaseMessages], messages_short: List[SessionMessage],
chat_talking_prompt_short: str, chat_talking_prompt_short: str,
unknown_words: Optional[List[str]], unknown_words: Optional[List[str]],
) -> str: ) -> str:
@@ -703,9 +706,13 @@ class DefaultReplyer:
is_group = stream_type == "group" is_group = stream_type == "group"
# 使用 ChatManager 提供的接口生成 chat_id避免在此重复实现逻辑 # 使用 ChatManager 提供的接口生成 chat_id避免在此重复实现逻辑
from src.chat.message_receive.chat_stream import get_chat_manager from src.common.utils.utils_session import SessionUtils
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) chat_id = SessionUtils.calculate_session_id(
platform,
group_id=str(id_str) if is_group else None,
user_id=str(id_str) if not is_group else None,
)
return chat_id, prompt_content return chat_id, prompt_content
except (ValueError, IndexError): except (ValueError, IndexError):
@@ -751,7 +758,7 @@ class DefaultReplyer:
async def build_prompt_reply_context( async def build_prompt_reply_context(
self, self,
reply_message: Optional[DatabaseMessages] = None, reply_message: Optional[SessionMessage] = None,
extra_info: str = "", extra_info: str = "",
reply_reason: str = "", reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
@@ -778,7 +785,7 @@ class DefaultReplyer:
if available_actions is None: if available_actions is None:
available_actions = {} available_actions = {}
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.session_id
_is_group_chat = bool(chat_stream.group_info) _is_group_chat = bool(chat_stream.group_info)
platform = chat_stream.platform platform = chat_stream.platform
@@ -1005,7 +1012,7 @@ class DefaultReplyer:
reply_to: str, reply_to: str,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.session_id
sender, target = self._parse_reply_target(reply_to) sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
@@ -1105,31 +1112,29 @@ class DefaultReplyer:
is_emoji: bool, is_emoji: bool,
thinking_start_time: float, thinking_start_time: float,
display_message: str, display_message: str,
anchor_message: Optional[MessageRecv] = None, anchor_message: Optional[MaiMessage] = None,
) -> MessageSending: ) -> SessionMessage:
"""构建单个发送消息""" """构建单个发送消息"""
bot_user_info = UserInfo( maim_message = MessageBase(
user_id=str(global_config.bot.qq_account), message_info=BaseMessageInfo(
user_nickname=global_config.bot.nickname, platform=self.chat_stream.platform,
platform=self.chat_stream.platform, message_id=message_id,
) time=thinking_start_time,
user_info=UserInfo(
# await anchor_message.process() user_id=str(global_config.bot.qq_account),
sender_info = anchor_message.message_info.user_info if anchor_message else None user_nickname=global_config.bot.nickname,
),
return MessageSending( additional_config={},
message_id=message_id, # 使用片段的唯一ID ),
chat_stream=self.chat_stream,
bot_user_info=bot_user_info,
sender_info=sender_info,
message_segment=message_segment, message_segment=message_segment,
reply=anchor_message, # 回复原始锚点
is_head=reply_to,
is_emoji=is_emoji,
thinking_start_time=thinking_start_time, # 传递原始思考开始时间
display_message=display_message,
) )
message = SessionMessage.from_maim_message(maim_message)
message.session_id = self.chat_stream.session_id
message.display_message = display_message
message.reply_to = anchor_message.message_id if reply_to and anchor_message else None
message.is_emoji = is_emoji
return message
async def llm_generate_content(self, prompt: str): async def llm_generate_content(self, prompt: str):
with Timer("LLM生成", {}): # 内部计时器,可选保留 with Timer("LLM生成", {}): # 内部计时器,可选保留

View File

@@ -7,33 +7,31 @@ import re
from typing import List, Optional, Dict, Any, Tuple from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime from datetime import datetime
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from maim_message import Seg from maim_message import BaseMessageInfo, MessageBase, Seg
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message_old import MessageSending from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.chat.utils.common_utils import TempMethodsExpression from src.chat.utils.common_utils import TempMethodsExpression
from src.chat.utils.chat_message_builder import ( from src.services.message_service import (
build_readable_messages, build_readable_messages,
get_raw_msg_before_timestamp_with_chat, get_raw_msg_before_timestamp_with_chat,
replace_user_references, replace_user_references,
translate_pid_to_description,
) )
from src.bw_learner.expression_selector import expression_selector from src.bw_learner.expression_selector import expression_selector
from src.services.message_service import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator # from src.memory_system.memory_activator import MemoryActivator
from src.person_info.person_info import Person, is_person_known from src.person_info.person_info import Person, is_person_known
from src.core.types import ActionInfo, EventType from src.core.types import ActionInfo, EventType
from src.services import llm_service as llm_api
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
from src.bw_learner.jargon_explainer_old import explain_jargon_in_context from src.bw_learner.jargon_explainer_old import explain_jargon_in_context
@@ -69,7 +67,7 @@ class PrivateReplyer:
from_plugin: bool = True, from_plugin: bool = True,
think_level: int = 1, think_level: int = 1,
stream_id: Optional[str] = None, stream_id: Optional[str] = None,
reply_message: Optional[DatabaseMessages] = None, reply_message: Optional[SessionMessage] = None,
reply_time_point: Optional[float] = time.time(), reply_time_point: Optional[float] = time.time(),
unknown_words: Optional[List[str]] = None, unknown_words: Optional[List[str]] = None,
log_reply: bool = True, log_reply: bool = True,
@@ -604,7 +602,7 @@ class PrivateReplyer:
async def build_prompt_reply_context( async def build_prompt_reply_context(
self, self,
reply_message: Optional[DatabaseMessages] = None, reply_message: Optional[SessionMessage] = None,
extra_info: str = "", extra_info: str = "",
reply_reason: str = "", reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
@@ -954,28 +952,29 @@ class PrivateReplyer:
thinking_start_time: float, thinking_start_time: float,
display_message: str, display_message: str,
anchor_message: Optional[MaiMessage] = None, anchor_message: Optional[MaiMessage] = None,
) -> MessageSending: ) -> SessionMessage:
"""构建单个发送消息""" """构建单个发送消息"""
bot_user_info = UserInfo( maim_message = MessageBase(
user_id=str(global_config.bot.qq_account), message_info=BaseMessageInfo(
user_nickname=global_config.bot.nickname, platform=self.chat_stream.platform,
) message_id=message_id,
time=thinking_start_time,
sender_info = anchor_message.message_info.user_info if anchor_message else None user_info=UserInfo(
user_id=str(global_config.bot.qq_account),
return MessageSending( user_nickname=global_config.bot.nickname,
message_id=message_id, ),
session=self.chat_stream, group_info=None,
bot_user_info=bot_user_info, additional_config={},
sender_info=sender_info, ),
message_segment=message_segment, message_segment=message_segment,
reply=anchor_message,
is_head=reply_to,
is_emoji=is_emoji,
thinking_start_time=thinking_start_time,
display_message=display_message,
) )
message = SessionMessage.from_maim_message(maim_message)
message.session_id = self.chat_stream.session_id
message.display_message = display_message
message.reply_to = anchor_message.message_id if reply_to and anchor_message else None
message.is_emoji = is_emoji
return message
async def llm_generate_content(self, prompt: str): async def llm_generate_content(self, prompt: str):
with Timer("LLM生成", {}): # 内部计时器,可选保留 with Timer("LLM生成", {}): # 内部计时器,可选保留
@@ -999,55 +998,9 @@ class PrivateReplyer:
return content, reasoning_content, model_name, tool_calls return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, sender: str, target: str): async def get_prompt_info(self, message: str, sender: str, target: str):
related_info = "" logger.debug(f"已跳过知识库信息获取,元消息:{message[:30]}...,消息长度: {len(message)}")
start_time = time.time() del message, sender, target
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool return ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 从LPMM知识库获取知识
try:
# 检查LPMM知识库是否启用
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return ""
if global_config.lpmm_knowledge.lpmm_mode == "agent":
return ""
prompt_template = prompt_manager.get_prompt("lpmm_get_knowledge")
prompt_template.add_context("bot_name", global_config.bot.nickname)
prompt_template.add_context("time_now", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
prompt_template.add_context("chat_history", message)
prompt_template.add_context("sender", sender)
prompt_template.add_context("target_message", target)
prompt = await prompt_manager.render_prompt(prompt_template)
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
)
if tool_calls:
result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
end_time = time.time()
if not result or not result.get("content"):
logger.debug("从LPMM知识库获取知识失败返回空知识...")
return ""
found_knowledge_from_lpmm = result.get("content", "")
logger.debug(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
related_info += found_knowledge_from_lpmm
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return f"你有以下这些**知识**\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
else:
logger.debug("模型认为不需要使用LPMM知识库")
return ""
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return ""
def weighted_sample_no_replacement(items, weights, k) -> list: def weighted_sample_no_replacement(items, weights, k) -> list:

View File

@@ -1,5 +1,6 @@
import traceback import traceback
from datetime import datetime from datetime import datetime
from types import SimpleNamespace
from typing import Any from typing import Any
import json import json
@@ -9,7 +10,7 @@ from sqlmodel import col, select
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.database.database_model import Messages from src.common.database.database_model import Messages
from src.common.data_models.database_data_model import DatabaseMessages from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -58,53 +59,37 @@ def _normalize_optional_str(value: object) -> str | None:
return str(value) return str(value)
def _message_to_instance(message: Messages) -> DatabaseMessages: def _message_to_instance(message: Messages) -> SessionMessage:
config = _parse_additional_config(message) config = _parse_additional_config(message)
timestamp_value = message.timestamp instance = SessionMessage.from_db_instance(message)
if isinstance(timestamp_value, datetime): instance.interest_value = config.get("interest_value")
time_value = timestamp_value.timestamp() instance.key_words = _normalize_optional_str(config.get("key_words"))
else: instance.key_words_lite = _normalize_optional_str(config.get("key_words_lite"))
time_value = float(timestamp_value) instance.reply_probability_boost = config.get("reply_probability_boost")
selected_expressions = _normalize_optional_str(config.get("selected_expressions")) instance.priority_mode = _normalize_optional_str(config.get("priority_mode"))
priority_info = _normalize_optional_str(config.get("priority_info")) instance.priority_info = _normalize_optional_str(config.get("priority_info"))
return DatabaseMessages( instance.intercept_message_level = config.get("intercept_message_level", 0)
message_id=message.message_id, instance.selected_expressions = _normalize_optional_str(config.get("selected_expressions"))
time=time_value, group_info = instance.message_info.group_info
chat_id=message.session_id, legacy_group_info = None
reply_to=message.reply_to, if group_info:
interest_value=config.get("interest_value"), legacy_group_info = SimpleNamespace(
key_words=_normalize_optional_str(config.get("key_words")), group_id=group_info.group_id,
key_words_lite=_normalize_optional_str(config.get("key_words_lite")), group_name=group_info.group_name,
is_mentioned=message.is_mentioned, )
is_at=message.is_at, instance.user_info = SimpleNamespace(
reply_probability_boost=config.get("reply_probability_boost"), user_id=instance.message_info.user_info.user_id,
processed_plain_text=message.processed_plain_text, user_nickname=instance.message_info.user_info.user_nickname,
display_message=message.display_message, user_cardname=instance.message_info.user_info.user_cardname,
priority_mode=_normalize_optional_str(config.get("priority_mode")), platform=instance.platform,
priority_info=priority_info,
additional_config=message.additional_config,
is_emoji=message.is_emoji,
is_picid=message.is_picture,
is_command=message.is_command,
intercept_message_level=config.get("intercept_message_level", 0),
is_notify=message.is_notify,
selected_expressions=selected_expressions,
user_id=message.user_id,
user_nickname=message.user_nickname,
user_cardname=message.user_cardname,
user_platform=message.platform,
chat_info_group_id=message.group_id,
chat_info_group_name=message.group_name,
chat_info_group_platform=message.platform,
chat_info_user_id=message.user_id,
chat_info_user_nickname=message.user_nickname,
chat_info_user_cardname=message.user_cardname,
chat_info_user_platform=message.platform,
chat_info_stream_id=message.session_id,
chat_info_platform=message.platform,
chat_info_create_time=0.0,
chat_info_last_active_time=0.0,
) )
instance.chat_info = SimpleNamespace(
platform=instance.platform,
stream_id=instance.session_id,
group_info=legacy_group_info,
)
instance.time = instance.timestamp.timestamp()
return instance
def _coerce_datetime(value: Any) -> Any: def _coerce_datetime(value: Any) -> Any:
@@ -147,7 +132,7 @@ def find_messages(
filter_bot: bool = False, filter_bot: bool = False,
filter_command: bool = False, filter_command: bool = False,
filter_intercept_message_level: int | None = None, filter_intercept_message_level: int | None = None,
) -> list[DatabaseMessages]: ) -> list[SessionMessage]:
""" """
根据提供的过滤器、排序和限制条件查找消息。 根据提供的过滤器、排序和限制条件查找消息。

View File

@@ -16,14 +16,14 @@ from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer from src.chat.replyer.private_generator import PrivateReplyer
from src.chat.replyer.replyer_manager import replyer_manager from src.chat.replyer.replyer_manager import replyer_manager
from src.chat.utils.utils import process_llm_response from src.chat.utils.utils import process_llm_response
from src.common.data_models.message_data_model import ReplySetModel from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.logger import get_logger from src.common.logger import get_logger
from src.core.types import ActionInfo from src.core.types import ActionInfo
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.chat.message_receive.message import SessionMessage
install(extra_lines=3) install(extra_lines=3)
@@ -67,7 +67,7 @@ async def generate_reply(
chat_stream: Optional[BotChatSession] = None, chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None, action_data: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["SessionMessage"] = None,
think_level: int = 1, think_level: int = 1,
extra_info: str = "", extra_info: str = "",
reply_reason: str = "", reply_reason: str = "",
@@ -126,15 +126,17 @@ async def generate_reply(
if not success: if not success:
logger.warning("[GeneratorService] 回复生成失败") logger.warning("[GeneratorService] 回复生成失败")
return False, None return False, None
reply_set: Optional[ReplySetModel] = None reply_set: Optional[MessageSequence] = None
if content := llm_response.content: if content := llm_response.content:
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo) processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
llm_response.processed_output = processed_response llm_response.processed_output = processed_response
reply_set = ReplySetModel() reply_set = MessageSequence(components=[])
for text in processed_response: for text in processed_response:
reply_set.add_text_content(text) reply_set.components.append(TextComponent(text))
llm_response.reply_set = reply_set llm_response.reply_set = reply_set
logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") logger.debug(
f"[GeneratorService] 回复生成成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
)
try: try:
PlanReplyLogger.log_reply( PlanReplyLogger.log_reply(
@@ -196,12 +198,14 @@ async def rewrite_reply(
reason=reason, reason=reason,
reply_to=reply_to, reply_to=reply_to,
) )
reply_set: Optional[ReplySetModel] = None reply_set: Optional[MessageSequence] = None
if success and llm_response and (content := llm_response.content): if success and llm_response and (content := llm_response.content):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo) reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
llm_response.reply_set = reply_set llm_response.reply_set = reply_set
if success: if success:
logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") logger.info(
f"[GeneratorService] 重写回复成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
)
else: else:
logger.warning("[GeneratorService] 重写回复失败") logger.warning("[GeneratorService] 重写回复失败")
@@ -215,16 +219,16 @@ async def rewrite_reply(
return False, None return False, None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]: def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[MessageSequence]:
"""将文本处理为更拟人化的文本""" """将文本处理为更拟人化的文本"""
if not isinstance(content, str): if not isinstance(content, str):
raise ValueError("content 必须是字符串类型") raise ValueError("content 必须是字符串类型")
try: try:
reply_set = ReplySetModel() reply_set = MessageSequence(components=[])
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo) processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
for text in processed_response: for text in processed_response:
reply_set.add_text_content(text) reply_set.components.append(TextComponent(text))
return reply_set return reply_set

View File

@@ -1,34 +1,21 @@
""" """消息服务模块。"""
消息服务模块
提供消息查询和构建成字符串的核心功能。
"""
import re
import time import time
from typing import Any, Dict, List, Optional, Tuple from datetime import datetime
from typing import Any, List, Optional, Tuple
from sqlmodel import col, select from sqlmodel import col, select
from src.chat.utils.chat_message_builder import ( from src.chat.message_receive.message import SessionMessage
build_readable_messages, from src.common.data_models.action_record_data_model import MaiActionRecord
build_readable_messages_with_list,
get_person_id_list,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
)
from src.chat.utils.utils import is_bot_self
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType from src.common.database.database_model import ActionRecord, Images, ImageType
from src.common.message_repository import count_messages, find_messages
from src.common.utils.math_utils import translate_timestamp_to_human_readable
from src.common.utils.utils_action import ActionUtils
from src.chat.utils.utils import is_bot_self
from src.config.config import global_config
# ============================================================================= # =============================================================================
@@ -36,16 +23,62 @@ from src.common.database.database_model import Images, ImageType
# ============================================================================= # =============================================================================
def _build_time_range_filter(start_time: float, end_time: float) -> dict[str, Any]:
return {
"time": {
"$gte": start_time,
"$lte": end_time,
}
}
def _build_readable_line(
message: SessionMessage,
*,
replace_bot_name: bool,
timestamp_mode: Optional[str],
show_message_id_prefix: bool,
) -> str:
plain_text = (message.processed_plain_text or "").strip()
if replace_bot_name and global_config.bot.nickname:
plain_text = plain_text.replace(global_config.bot.nickname, "")
user_name = (
message.message_info.user_info.user_cardname
or message.message_info.user_info.user_nickname
or message.message_info.user_info.user_id
)
prefix: List[str] = []
if timestamp_mode:
prefix.append(f"[{translate_timestamp_to_human_readable(message.timestamp.timestamp(), mode=timestamp_mode)}]")
if show_message_id_prefix:
prefix.append(f"[消息ID: {message.message_id}]")
prefix.append(f"{user_name}说:")
return " ".join(prefix) + plain_text
def _normalize_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
normalized: List[SessionMessage] = []
for message in messages:
if not message.processed_plain_text:
message.processed_plain_text = message.display_message or ""
normalized.append(message)
return normalized
def get_messages_by_time( def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]: ) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
raise ValueError("limit 不能为负数") raise ValueError("limit 不能为负数")
if filter_mai: messages = find_messages(
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)) message_filter=_build_time_range_filter(start_time, end_time),
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode) limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
)
return _normalize_messages(messages)
def get_messages_by_time_in_chat( def get_messages_by_time_in_chat(
@@ -57,7 +90,7 @@ def get_messages_by_time_in_chat(
filter_mai: bool = False, filter_mai: bool = False,
filter_command: bool = False, filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None, filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]: ) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
@@ -66,16 +99,18 @@ def get_messages_by_time_in_chat(
raise ValueError("chat_id 不能为空") raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat( messages = find_messages(
chat_id=chat_id, message_filter={
timestamp_start=start_time, "chat_id": chat_id,
timestamp_end=end_time, **_build_time_range_filter(start_time, end_time),
},
limit=limit, limit=limit,
limit_mode=limit_mode, limit_mode=limit_mode,
filter_bot=filter_mai, filter_bot=filter_mai,
filter_command=filter_command, filter_command=filter_command,
filter_intercept_message_level=filter_intercept_message_level, filter_intercept_message_level=filter_intercept_message_level,
) )
return _normalize_messages(messages)
def get_messages_by_time_in_chat_inclusive( def get_messages_by_time_in_chat_inclusive(
@@ -87,7 +122,7 @@ def get_messages_by_time_in_chat_inclusive(
filter_mai: bool = False, filter_mai: bool = False,
filter_command: bool = False, filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None, filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]: ) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
@@ -96,19 +131,21 @@ def get_messages_by_time_in_chat_inclusive(
raise ValueError("chat_id 不能为空") raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
messages = get_raw_msg_by_timestamp_with_chat_inclusive( messages = find_messages(
chat_id=chat_id, message_filter={
timestamp_start=start_time, "chat_id": chat_id,
timestamp_end=end_time, "time": {
"$gte": start_time,
"$lte": end_time,
},
},
limit=limit, limit=limit,
limit_mode=limit_mode, limit_mode=limit_mode,
filter_bot=filter_mai, filter_bot=filter_mai,
filter_command=filter_command, filter_command=filter_command,
filter_intercept_message_level=filter_intercept_message_level, filter_intercept_message_level=filter_intercept_message_level,
) )
if filter_mai: return _normalize_messages(messages)
return filter_mai_messages(messages)
return messages
def get_messages_by_time_in_chat_for_users( def get_messages_by_time_in_chat_for_users(
@@ -118,7 +155,7 @@ def get_messages_by_time_in_chat_for_users(
person_ids: List[str], person_ids: List[str],
limit: int = 0, limit: int = 0,
limit_mode: str = "latest", limit_mode: str = "latest",
) -> List[DatabaseMessages]: ) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
@@ -127,39 +164,64 @@ def get_messages_by_time_in_chat_for_users(
raise ValueError("chat_id 不能为空") raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode) messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {
"$gte": start_time,
"$lte": end_time,
},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode=limit_mode,
)
return _normalize_messages(messages)
def get_random_chat_messages( def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]: ) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
raise ValueError("limit 不能为负数") raise ValueError("limit 不能为负数")
if filter_mai: return get_messages_by_time(start_time, end_time, limit, limit_mode, filter_mai)
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
def get_messages_by_time_for_users( def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]: ) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型") raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0: if limit < 0:
raise ValueError("limit 不能为负数") raise ValueError("limit 不能为负数")
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) messages = find_messages(
message_filter={
"time": {
"$gte": start_time,
"$lte": end_time,
},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode=limit_mode,
)
return _normalize_messages(messages)
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]: def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)): if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型") raise ValueError("timestamp 必须是数字类型")
if limit < 0: if limit < 0:
raise ValueError("limit 不能为负数") raise ValueError("limit 不能为负数")
if filter_mai: messages = find_messages(
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit)) message_filter={"time": {"$lt": timestamp}},
return get_raw_msg_before_timestamp(timestamp, limit) limit=limit,
limit_mode="latest",
filter_bot=filter_mai,
)
return _normalize_messages(messages)
def get_messages_before_time_in_chat( def get_messages_before_time_in_chat(
@@ -168,7 +230,7 @@ def get_messages_before_time_in_chat(
limit: int = 0, limit: int = 0,
filter_mai: bool = False, filter_mai: bool = False,
filter_intercept_message_level: Optional[int] = None, filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]: ) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)): if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型") raise ValueError("timestamp 必须是数字类型")
if limit < 0: if limit < 0:
@@ -177,30 +239,40 @@ def get_messages_before_time_in_chat(
raise ValueError("chat_id 不能为空") raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
messages = get_raw_msg_before_timestamp_with_chat( messages = find_messages(
chat_id=chat_id, message_filter={
timestamp=timestamp, "chat_id": chat_id,
"time": {"$lt": timestamp},
},
limit=limit, limit=limit,
limit_mode="latest",
filter_bot=filter_mai,
filter_intercept_message_level=filter_intercept_message_level, filter_intercept_message_level=filter_intercept_message_level,
) )
if filter_mai: return _normalize_messages(messages)
return filter_mai_messages(messages)
return messages
def get_messages_before_time_for_users( def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0 timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]: ) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)): if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型") raise ValueError("timestamp 必须是数字类型")
if limit < 0: if limit < 0:
raise ValueError("limit 不能为负数") raise ValueError("limit 不能为负数")
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit) messages = find_messages(
message_filter={
"time": {"$lt": timestamp},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode="latest",
)
return _normalize_messages(messages)
def get_recent_messages( def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]: ) -> List[SessionMessage]:
if not isinstance(hours, (int, float)) or hours < 0: if not isinstance(hours, (int, float)) or hours < 0:
raise ValueError("hours 不能是负数") raise ValueError("hours 不能是负数")
if not isinstance(limit, int) or limit < 0: if not isinstance(limit, int) or limit < 0:
@@ -211,9 +283,7 @@ def get_recent_messages(
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
now = time.time() now = time.time()
start_time = now - hours * 3600 start_time = now - hours * 3600
if filter_mai: return get_messages_by_time_in_chat(chat_id, start_time, now, limit, limit_mode, filter_mai)
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode))
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
# ============================================================================= # =============================================================================
@@ -228,7 +298,13 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
raise ValueError("chat_id 不能为空") raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since(chat_id, start_time, end_time) message_filter: dict[str, Any] = {
"chat_id": chat_id,
"time": {"$gt": start_time},
}
if end_time is not None:
message_filter["time"]["$lte"] = end_time
return count_messages(message_filter)
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
@@ -238,7 +314,13 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
raise ValueError("chat_id 不能为空") raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids) return count_messages(
{
"chat_id": chat_id,
"time": {"$gt": start_time, "$lte": end_time},
"user_id": {"$in": person_ids},
}
)
# ============================================================================= # =============================================================================
@@ -246,8 +328,45 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
# ============================================================================= # =============================================================================
def build_readable_messages(
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> str:
normalized_messages = _normalize_messages(messages)
lines: List[str] = []
unread_mark_added = False
for message in normalized_messages:
if read_mark and not unread_mark_added and message.timestamp.timestamp() > read_mark:
lines.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---")
unread_mark_added = True
line = _build_readable_line(
message,
replace_bot_name=replace_bot_name,
timestamp_mode=timestamp_mode,
show_message_id_prefix=False,
)
if truncate and len(line) > 200:
line = f"{line[:200]}......(内容太长了)"
lines.append(line)
if show_actions and normalized_messages:
action_lines = build_readable_actions(
get_actions_by_timestamp_with_chat(
normalized_messages[0].session_id,
normalized_messages[0].timestamp.timestamp(),
normalized_messages[-1].timestamp.timestamp(),
)
)
if action_lines:
lines.append(action_lines)
return "\n".join(lines)
def build_readable_messages_to_str( def build_readable_messages_to_str(
messages: List[DatabaseMessages], messages: List[SessionMessage],
replace_bot_name: bool = True, replace_bot_name: bool = True,
timestamp_mode: str = "relative", timestamp_mode: str = "relative",
read_mark: float = 0.0, read_mark: float = 0.0,
@@ -257,17 +376,71 @@ def build_readable_messages_to_str(
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions) return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
def build_readable_messages_with_id(
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> Tuple[str, List[Tuple[str, SessionMessage]]]:
normalized_messages = _normalize_messages(messages)
lines: List[str] = []
message_id_list: List[Tuple[str, SessionMessage]] = []
unread_mark_added = False
for message in normalized_messages:
if read_mark and not unread_mark_added and message.timestamp.timestamp() > read_mark:
lines.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---")
unread_mark_added = True
line = _build_readable_line(
message,
replace_bot_name=replace_bot_name,
timestamp_mode=timestamp_mode,
show_message_id_prefix=True,
)
if truncate and len(line) > 200:
line = f"{line[:200]}......(内容太长了)"
lines.append(line)
message_id_list.append((message.message_id, message))
if show_actions and normalized_messages:
action_lines = build_readable_actions(
get_actions_by_timestamp_with_chat(
normalized_messages[0].session_id,
normalized_messages[0].timestamp.timestamp(),
normalized_messages[-1].timestamp.timestamp(),
)
)
if action_lines:
lines.append(action_lines)
return "\n".join(lines), message_id_list
async def build_readable_messages_with_details( async def build_readable_messages_with_details(
messages: List[DatabaseMessages], messages: List[SessionMessage],
replace_bot_name: bool = True, replace_bot_name: bool = True,
timestamp_mode: str = "relative", timestamp_mode: str = "relative",
truncate: bool = False, truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]: ) -> Tuple[str, List[Tuple[float, str, str]]]:
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate) normalized_messages = _normalize_messages(messages)
message_list = [
(
message.timestamp.timestamp(),
message.message_info.user_info.user_id,
message.processed_plain_text or "",
)
for message in normalized_messages
]
return build_readable_messages(normalized_messages, replace_bot_name, timestamp_mode, truncate=truncate), message_list
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]: async def get_person_ids_from_messages(messages: List[Any]) -> List[str]:
return await get_person_id_list(messages) person_ids: List[str] = []
for message in messages:
if isinstance(message, SessionMessage):
person_ids.append(message.message_info.user_info.user_id)
elif isinstance(message, dict) and (user_id := message.get("user_id")):
person_ids.append(str(user_id))
return person_ids
# ============================================================================= # =============================================================================
@@ -275,9 +448,145 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
# ============================================================================= # =============================================================================
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]: def filter_mai_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
"""从消息列表中移除麦麦的消息""" """从消息列表中移除麦麦的消息"""
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)] return [
msg
for msg in messages
if not is_bot_self(msg.platform, msg.message_info.user_info.user_id)
]
def get_raw_msg_by_timestamp(
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_messages_by_time(timestamp_start, timestamp_end, limit, limit_mode)
def get_raw_msg_by_timestamp_with_chat(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
filter_bot: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
return get_messages_by_time_in_chat(
chat_id,
timestamp_start,
timestamp_end,
limit,
limit_mode,
filter_bot,
filter_command,
filter_intercept_message_level,
)
def get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
filter_bot: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
return get_messages_by_time_in_chat_inclusive(
chat_id,
timestamp_start,
timestamp_end,
limit,
limit_mode,
filter_bot,
filter_command,
filter_intercept_message_level,
)
def get_raw_msg_by_timestamp_with_chat_users(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_messages_by_time_in_chat_for_users(chat_id, timestamp_start, timestamp_end, person_ids, limit, limit_mode)
def get_raw_msg_by_timestamp_with_users(
timestamp_start: float,
timestamp_end: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_messages_by_time_for_users(timestamp_start, timestamp_end, person_ids, limit, limit_mode)
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[SessionMessage]:
return get_messages_before_time(timestamp, limit)
def get_raw_msg_before_timestamp_with_chat(
chat_id: str,
timestamp: float,
limit: int = 0,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
return get_messages_before_time_in_chat(chat_id, timestamp, limit, False, filter_intercept_message_level)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[SessionMessage]:
return get_messages_before_time_for_users(timestamp, person_ids, limit)
def get_raw_msg_by_timestamp_random(
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_random_chat_messages(timestamp_start, timestamp_end, limit, limit_mode)
def get_actions_by_timestamp_with_chat(chat_id: str, timestamp_start: float, timestamp_end: float) -> List[MaiActionRecord]:
with get_db_session() as session:
statement = (
select(ActionRecord)
.where(col(ActionRecord.session_id) == chat_id)
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
.order_by(col(ActionRecord.timestamp))
)
return [MaiActionRecord.from_db_instance(item) for item in session.exec(statement).all()]
def build_readable_actions(actions: List[MaiActionRecord], timestamp_mode: str = "relative") -> str:
return ActionUtils.build_readable_action_records(actions, timestamp_mode)
def replace_user_references(text: str, platform: str, replace_bot_name: bool = False) -> str:
del platform
if not text:
return text
def _replace(match: re.Match[str]) -> str:
prefix = match.group(1) or ""
user_name = match.group(2)
if replace_bot_name and user_name == global_config.bot.nickname:
user_name = ""
return f"{prefix}{user_name}"
text = re.sub(r"(回复|@)?<([^:<>]+):[^<>]+>", _replace, text)
return text
def translate_pid_to_description(pid: str) -> str: def translate_pid_to_description(pid: str) -> str:

View File

@@ -6,21 +6,20 @@
import traceback import traceback
import time import time
from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple from typing import Optional, Union, Dict, List, TYPE_CHECKING
from maim_message import MessageBase, BaseMessageInfo, Seg from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import MessageSending from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo from src.common.data_models.mai_message_data_model import MaiMessage
from src.common.data_models.message_data_model import ReplyContentType from src.common.data_models.message_component_data_model import DictComponent, MessageSequence
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel
logger = get_logger("send_service") logger = get_logger("send_service")
@@ -36,7 +35,7 @@ async def _send_to_target(
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
set_reply: bool = False, set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["SessionMessage"] = None,
storage_message: bool = True, storage_message: bool = True,
show_log: bool = True, show_log: bool = True,
selected_expressions: Optional[List[int]] = None, selected_expressions: Optional[List[int]] = None,
@@ -60,12 +59,6 @@ async def _send_to_target(
current_time = time.time() current_time = time.time()
message_id = f"send_api_{int(current_time * 1000)}" message_id = f"send_api_{int(current_time * 1000)}"
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
)
reply_to_platform_id = ""
anchor_message: Optional[MaiMessage] = None anchor_message: Optional[MaiMessage] = None
if reply_message: if reply_message:
anchor_message = db_message_to_mai_message(reply_message) anchor_message = db_message_to_mai_message(reply_message)
@@ -73,31 +66,50 @@ async def _send_to_target(
logger.debug( logger.debug(
f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}" f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}"
) )
reply_to_platform_id = f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
sender_info = None group_info = None
if target_stream.context and target_stream.context.message: if target_stream.group_id:
sender_info = target_stream.context.message.message_info.user_info group_name = ""
if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info:
group_name = target_stream.context.message.message_info.group_info.group_name
group_info = MaimGroupInfo(
group_id=target_stream.group_id,
group_name=group_name,
platform=target_stream.platform,
)
bot_message = MessageSending( additional_config: dict[str, object] = {}
message_id=message_id, if selected_expressions is not None:
session=target_stream, additional_config["selected_expressions"] = selected_expressions
bot_user_info=bot_user_info,
sender_info=sender_info, maim_message = MessageBase(
message_info=BaseMessageInfo(
platform=target_stream.platform,
message_id=message_id,
time=current_time,
user_info=MaimUserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
platform=target_stream.platform,
),
group_info=group_info,
additional_config=additional_config,
),
message_segment=message_segment, message_segment=message_segment,
display_message=display_message,
reply=anchor_message,
is_head=True,
is_emoji=(message_segment.type == "emoji"),
thinking_start_time=current_time,
reply_to=reply_to_platform_id,
selected_expressions=selected_expressions,
) )
bot_message = SessionMessage.from_maim_message(maim_message)
bot_message.session_id = target_stream.session_id
bot_message.display_message = display_message
bot_message.reply_to = anchor_message.message_id if anchor_message else None
bot_message.is_emoji = message_segment.type == "emoji"
bot_message.is_picture = message_segment.type == "image"
bot_message.is_command = message_segment.type == "command"
sent_msg = await message_sender.send_message( sent_msg = await message_sender.send_message(
bot_message, bot_message,
typing=typing, typing=typing,
set_reply=set_reply, set_reply=set_reply,
reply_message_id=anchor_message.message_id if anchor_message else None,
storage_message=storage_message, storage_message=storage_message,
show_log=show_log, show_log=show_log,
) )
@@ -115,37 +127,9 @@ async def _send_to_target(
return False return False
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]: def db_message_to_mai_message(message_obj: "SessionMessage") -> Optional[MaiMessage]:
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。""" """将数据库消息重建为 MaiMessage 对象,用于回复引用。"""
from datetime import datetime return message_obj.deepcopy()
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import MessageSequence
user_info = UserInfo(
user_id=message_obj.user_info.user_id or "",
user_nickname=message_obj.user_info.user_nickname or "",
user_cardname=message_obj.user_info.user_cardname,
)
group_info = None
if message_obj.chat_info.group_info:
group_info = GroupInfo(
group_id=message_obj.chat_info.group_info.group_id or "",
group_name=message_obj.chat_info.group_info.group_name or "",
)
msg = MaiMessage(
message_id=message_obj.message_id,
timestamp=datetime.fromtimestamp(message_obj.time) if message_obj.time else datetime.now(),
)
msg.message_info = MessageInfo(user_info=user_info, group_info=group_info)
msg.platform = message_obj.chat_info.platform or ""
msg.session_id = message_obj.chat_info.stream_id or ""
msg.processed_plain_text = message_obj.processed_plain_text
msg.raw_message = MessageSequence(components=[])
msg.initialized = True
return msg
# ============================================================================= # =============================================================================
@@ -158,7 +142,7 @@ async def text_to_stream(
stream_id: str, stream_id: str,
typing: bool = False, typing: bool = False,
set_reply: bool = False, set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["SessionMessage"] = None,
storage_message: bool = True, storage_message: bool = True,
selected_expressions: Optional[List[int]] = None, selected_expressions: Optional[List[int]] = None,
) -> bool: ) -> bool:
@@ -180,7 +164,7 @@ async def emoji_to_stream(
stream_id: str, stream_id: str,
storage_message: bool = True, storage_message: bool = True,
set_reply: bool = False, set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["SessionMessage"] = None,
) -> bool: ) -> bool:
"""向指定流发送表情包""" """向指定流发送表情包"""
return await _send_to_target( return await _send_to_target(
@@ -199,7 +183,7 @@ async def image_to_stream(
stream_id: str, stream_id: str,
storage_message: bool = True, storage_message: bool = True,
set_reply: bool = False, set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["SessionMessage"] = None,
) -> bool: ) -> bool:
"""向指定流发送图片""" """向指定流发送图片"""
return await _send_to_target( return await _send_to_target(
@@ -236,7 +220,7 @@ async def custom_to_stream(
stream_id: str, stream_id: str,
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["SessionMessage"] = None,
set_reply: bool = False, set_reply: bool = False,
storage_message: bool = True, storage_message: bool = True,
show_log: bool = True, show_log: bool = True,
@@ -255,25 +239,27 @@ async def custom_to_stream(
async def custom_reply_set_to_stream( async def custom_reply_set_to_stream(
reply_set: "ReplySetModel", reply_set: MessageSequence,
stream_id: str, stream_id: str,
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["SessionMessage"] = None,
set_reply: bool = False, set_reply: bool = False,
storage_message: bool = True, storage_message: bool = True,
show_log: bool = True, show_log: bool = True,
) -> bool: ) -> bool:
"""向指定流发送混合型消息集""" """向指定流发送消息组件序列。"""
flag: bool = True flag: bool = True
for reply_content in reply_set.reply_data: for component in reply_set.components:
status: bool = False if isinstance(component, DictComponent):
message_seg, need_typing = _parse_content_to_seg(reply_content) message_seg = Seg(type="dict", data=component.data) # type: ignore
else:
message_seg = await component.to_seg()
status = await _send_to_target( status = await _send_to_target(
message_segment=message_seg, message_segment=message_seg,
stream_id=stream_id, stream_id=stream_id,
display_message=display_message, display_message=display_message,
typing=bool(need_typing and typing), typing=typing,
reply_message=reply_message, reply_message=reply_message,
set_reply=set_reply, set_reply=set_reply,
storage_message=storage_message, storage_message=storage_message,
@@ -281,67 +267,7 @@ async def custom_reply_set_to_stream(
) )
if not status: if not status:
flag = False flag = False
logger.error( logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}")
f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}" set_reply = False
)
return flag return flag
def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
"""把 ReplyContent 转换为 Seg 结构"""
content_type = reply_content.content_type
if content_type == ReplyContentType.TEXT:
text_data: str = reply_content.content # type: ignore
return Seg(type="text", data=text_data), True
elif content_type == ReplyContentType.IMAGE:
return Seg(type="image", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.EMOJI:
return Seg(type="emoji", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.COMMAND:
return Seg(type="command", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.VOICE:
return Seg(type="voice", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.HYBRID:
hybrid_message_list_data: List[ReplyContent] = reply_content.content # type: ignore
assert isinstance(hybrid_message_list_data, list), "混合类型内容必须是列表"
sub_seg_list: List[Seg] = []
for sub_content in hybrid_message_list_data:
sub_content_type = sub_content.content_type
sub_content_data = sub_content.content
if sub_content_type == ReplyContentType.TEXT:
sub_seg_list.append(Seg(type="text", data=sub_content_data)) # type: ignore
elif sub_content_type == ReplyContentType.IMAGE:
sub_seg_list.append(Seg(type="image", data=sub_content_data)) # type: ignore
elif sub_content_type == ReplyContentType.EMOJI:
sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore
else:
logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
continue
return Seg(type="seglist", data=sub_seg_list), True
elif content_type == ReplyContentType.FORWARD:
forward_message_list_data: List["ForwardNode"] = reply_content.content # type: ignore
assert isinstance(forward_message_list_data, list), "转发类型内容必须是列表"
forward_message_list: List[Dict] = []
for forward_node in forward_message_list_data:
message_segment = Seg(type="id", data=forward_node.content) # type: ignore
user_info: Optional[UserInfo] = None
if forward_node.user_id and forward_node.user_nickname:
assert isinstance(forward_node.content, list), "转发节点内容必须是列表"
user_info = UserInfo(user_id=forward_node.user_id, user_nickname=forward_node.user_nickname)
single_node_content: List[Seg] = []
for sub_content in forward_node.content:
if sub_content.content_type != ReplyContentType.FORWARD:
sub_seg, _ = _parse_content_to_seg(sub_content)
single_node_content.append(sub_seg)
message_segment = Seg(type="seglist", data=single_node_content)
forward_message_list.append(
MessageBase(
message_segment=message_segment, message_info=BaseMessageInfo(user_info=user_info)
).to_dict()
)
return Seg(type="forward", data=forward_message_list), False # type: ignore
else:
message_type_in_str = content_type.value if isinstance(content_type, ReplyContentType) else str(content_type)
return Seg(type=message_type_in_str, data=reply_content.content), True # type: ignore