部分模块的新数据结构适配

This commit is contained in:
DrSmoothl
2026-03-13 23:36:17 +08:00
parent 6201b862c9
commit 898fab6de9
7 changed files with 580 additions and 399 deletions

View File

@@ -16,14 +16,14 @@ from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer
from src.chat.replyer.replyer_manager import replyer_manager
from src.chat.utils.utils import process_llm_response
from src.common.data_models.message_data_model import ReplySetModel
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.logger import get_logger
from src.core.types import ActionInfo
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.chat.message_receive.message import SessionMessage
install(extra_lines=3)
@@ -67,7 +67,7 @@ async def generate_reply(
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
think_level: int = 1,
extra_info: str = "",
reply_reason: str = "",
@@ -126,15 +126,17 @@ async def generate_reply(
if not success:
logger.warning("[GeneratorService] 回复生成失败")
return False, None
reply_set: Optional[ReplySetModel] = None
reply_set: Optional[MessageSequence] = None
if content := llm_response.content:
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
llm_response.processed_output = processed_response
reply_set = ReplySetModel()
reply_set = MessageSequence(components=[])
for text in processed_response:
reply_set.add_text_content(text)
reply_set.components.append(TextComponent(text))
llm_response.reply_set = reply_set
logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
logger.debug(
f"[GeneratorService] 回复生成成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
)
try:
PlanReplyLogger.log_reply(
@@ -196,12 +198,14 @@ async def rewrite_reply(
reason=reason,
reply_to=reply_to,
)
reply_set: Optional[ReplySetModel] = None
reply_set: Optional[MessageSequence] = None
if success and llm_response and (content := llm_response.content):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
llm_response.reply_set = reply_set
if success:
logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
logger.info(
f"[GeneratorService] 重写回复成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项"
)
else:
logger.warning("[GeneratorService] 重写回复失败")
@@ -215,16 +219,16 @@ async def rewrite_reply(
return False, None
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]:
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[MessageSequence]:
"""将文本处理为更拟人化的文本"""
if not isinstance(content, str):
raise ValueError("content 必须是字符串类型")
try:
reply_set = ReplySetModel()
reply_set = MessageSequence(components=[])
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
for text in processed_response:
reply_set.add_text_content(text)
reply_set.components.append(TextComponent(text))
return reply_set

View File

@@ -1,34 +1,21 @@
"""
消息服务模块
提供消息查询和构建成字符串的核心功能。
"""
"""消息服务模块。"""
import re
import time
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime
from typing import Any, List, Optional, Tuple
from sqlmodel import col, select
from src.chat.utils.chat_message_builder import (
build_readable_messages,
build_readable_messages_with_list,
get_person_id_list,
get_raw_msg_before_timestamp,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_before_timestamp_with_users,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_random,
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_by_timestamp_with_chat_users,
get_raw_msg_by_timestamp_with_users,
num_new_messages_since,
num_new_messages_since_with_users,
)
from src.chat.utils.utils import is_bot_self
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.action_record_data_model import MaiActionRecord
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.common.database.database_model import ActionRecord, Images, ImageType
from src.common.message_repository import count_messages, find_messages
from src.common.utils.math_utils import translate_timestamp_to_human_readable
from src.common.utils.utils_action import ActionUtils
from src.chat.utils.utils import is_bot_self
from src.config.config import global_config
# =============================================================================
@@ -36,16 +23,62 @@ from src.common.database.database_model import Images, ImageType
# =============================================================================
def _build_time_range_filter(start_time: float, end_time: float) -> dict[str, Any]:
return {
"time": {
"$gte": start_time,
"$lte": end_time,
}
}
def _build_readable_line(
message: SessionMessage,
*,
replace_bot_name: bool,
timestamp_mode: Optional[str],
show_message_id_prefix: bool,
) -> str:
plain_text = (message.processed_plain_text or "").strip()
if replace_bot_name and global_config.bot.nickname:
plain_text = plain_text.replace(global_config.bot.nickname, "")
user_name = (
message.message_info.user_info.user_cardname
or message.message_info.user_info.user_nickname
or message.message_info.user_info.user_id
)
prefix: List[str] = []
if timestamp_mode:
prefix.append(f"[{translate_timestamp_to_human_readable(message.timestamp.timestamp(), mode=timestamp_mode)}]")
if show_message_id_prefix:
prefix.append(f"[消息ID: {message.message_id}]")
prefix.append(f"{user_name}说:")
return " ".join(prefix) + plain_text
def _normalize_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
normalized: List[SessionMessage] = []
for message in messages:
if not message.processed_plain_text:
message.processed_plain_text = message.display_message or ""
normalized.append(message)
return normalized
def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
messages = find_messages(
message_filter=_build_time_range_filter(start_time, end_time),
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
)
return _normalize_messages(messages)
def get_messages_by_time_in_chat(
@@ -57,7 +90,7 @@ def get_messages_by_time_in_chat(
filter_mai: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -66,16 +99,18 @@ def get_messages_by_time_in_chat(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat(
chat_id=chat_id,
timestamp_start=start_time,
timestamp_end=end_time,
messages = find_messages(
message_filter={
"chat_id": chat_id,
**_build_time_range_filter(start_time, end_time),
},
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
filter_command=filter_command,
filter_intercept_message_level=filter_intercept_message_level,
)
return _normalize_messages(messages)
def get_messages_by_time_in_chat_inclusive(
@@ -87,7 +122,7 @@ def get_messages_by_time_in_chat_inclusive(
filter_mai: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -96,19 +131,21 @@ def get_messages_by_time_in_chat_inclusive(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=chat_id,
timestamp_start=start_time,
timestamp_end=end_time,
messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {
"$gte": start_time,
"$lte": end_time,
},
},
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
filter_command=filter_command,
filter_intercept_message_level=filter_intercept_message_level,
)
if filter_mai:
return filter_mai_messages(messages)
return messages
return _normalize_messages(messages)
def get_messages_by_time_in_chat_for_users(
@@ -118,7 +155,7 @@ def get_messages_by_time_in_chat_for_users(
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
@@ -127,39 +164,64 @@ def get_messages_by_time_in_chat_for_users(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {
"$gte": start_time,
"$lte": end_time,
},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode=limit_mode,
)
return _normalize_messages(messages)
def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
return get_messages_by_time(start_time, end_time, limit, limit_mode, filter_mai)
def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
messages = find_messages(
message_filter={
"time": {
"$gte": start_time,
"$lte": end_time,
},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode=limit_mode,
)
return _normalize_messages(messages)
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]:
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit))
return get_raw_msg_before_timestamp(timestamp, limit)
messages = find_messages(
message_filter={"time": {"$lt": timestamp}},
limit=limit,
limit_mode="latest",
filter_bot=filter_mai,
)
return _normalize_messages(messages)
def get_messages_before_time_in_chat(
@@ -168,7 +230,7 @@ def get_messages_before_time_in_chat(
limit: int = 0,
filter_mai: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
@@ -177,30 +239,40 @@ def get_messages_before_time_in_chat(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=timestamp,
messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {"$lt": timestamp},
},
limit=limit,
limit_mode="latest",
filter_bot=filter_mai,
filter_intercept_message_level=filter_intercept_message_level,
)
if filter_mai:
return filter_mai_messages(messages)
return messages
return _normalize_messages(messages)
def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(timestamp, (int, float)):
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
messages = find_messages(
message_filter={
"time": {"$lt": timestamp},
"user_id": {"$in": person_ids},
},
limit=limit,
limit_mode="latest",
)
return _normalize_messages(messages)
def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> List[DatabaseMessages]:
) -> List[SessionMessage]:
if not isinstance(hours, (int, float)) or hours < 0:
raise ValueError("hours 不能是负数")
if not isinstance(limit, int) or limit < 0:
@@ -211,9 +283,7 @@ def get_recent_messages(
raise ValueError("chat_id 必须是字符串类型")
now = time.time()
start_time = now - hours * 3600
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode))
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
return get_messages_by_time_in_chat(chat_id, start_time, now, limit, limit_mode, filter_mai)
# =============================================================================
@@ -228,7 +298,13 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since(chat_id, start_time, end_time)
message_filter: dict[str, Any] = {
"chat_id": chat_id,
"time": {"$gt": start_time},
}
if end_time is not None:
message_filter["time"]["$lte"] = end_time
return count_messages(message_filter)
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
@@ -238,7 +314,13 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
return count_messages(
{
"chat_id": chat_id,
"time": {"$gt": start_time, "$lte": end_time},
"user_id": {"$in": person_ids},
}
)
# =============================================================================
@@ -246,8 +328,45 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
# =============================================================================
def build_readable_messages(
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> str:
normalized_messages = _normalize_messages(messages)
lines: List[str] = []
unread_mark_added = False
for message in normalized_messages:
if read_mark and not unread_mark_added and message.timestamp.timestamp() > read_mark:
lines.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---")
unread_mark_added = True
line = _build_readable_line(
message,
replace_bot_name=replace_bot_name,
timestamp_mode=timestamp_mode,
show_message_id_prefix=False,
)
if truncate and len(line) > 200:
line = f"{line[:200]}......(内容太长了)"
lines.append(line)
if show_actions and normalized_messages:
action_lines = build_readable_actions(
get_actions_by_timestamp_with_chat(
normalized_messages[0].session_id,
normalized_messages[0].timestamp.timestamp(),
normalized_messages[-1].timestamp.timestamp(),
)
)
if action_lines:
lines.append(action_lines)
return "\n".join(lines)
def build_readable_messages_to_str(
messages: List[DatabaseMessages],
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
@@ -257,17 +376,71 @@ def build_readable_messages_to_str(
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
def build_readable_messages_with_id(
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> Tuple[str, List[Tuple[str, SessionMessage]]]:
normalized_messages = _normalize_messages(messages)
lines: List[str] = []
message_id_list: List[Tuple[str, SessionMessage]] = []
unread_mark_added = False
for message in normalized_messages:
if read_mark and not unread_mark_added and message.timestamp.timestamp() > read_mark:
lines.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---")
unread_mark_added = True
line = _build_readable_line(
message,
replace_bot_name=replace_bot_name,
timestamp_mode=timestamp_mode,
show_message_id_prefix=True,
)
if truncate and len(line) > 200:
line = f"{line[:200]}......(内容太长了)"
lines.append(line)
message_id_list.append((message.message_id, message))
if show_actions and normalized_messages:
action_lines = build_readable_actions(
get_actions_by_timestamp_with_chat(
normalized_messages[0].session_id,
normalized_messages[0].timestamp.timestamp(),
normalized_messages[-1].timestamp.timestamp(),
)
)
if action_lines:
lines.append(action_lines)
return "\n".join(lines), message_id_list
async def build_readable_messages_with_details(
messages: List[DatabaseMessages],
messages: List[SessionMessage],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate)
normalized_messages = _normalize_messages(messages)
message_list = [
(
message.timestamp.timestamp(),
message.message_info.user_info.user_id,
message.processed_plain_text or "",
)
for message in normalized_messages
]
return build_readable_messages(normalized_messages, replace_bot_name, timestamp_mode, truncate=truncate), message_list
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
return await get_person_id_list(messages)
async def get_person_ids_from_messages(messages: List[Any]) -> List[str]:
person_ids: List[str] = []
for message in messages:
if isinstance(message, SessionMessage):
person_ids.append(message.message_info.user_info.user_id)
elif isinstance(message, dict) and (user_id := message.get("user_id")):
person_ids.append(str(user_id))
return person_ids
# =============================================================================
@@ -275,9 +448,145 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
# =============================================================================
def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
def filter_mai_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
"""从消息列表中移除麦麦的消息"""
return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)]
return [
msg
for msg in messages
if not is_bot_self(msg.platform, msg.message_info.user_info.user_id)
]
def get_raw_msg_by_timestamp(
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_messages_by_time(timestamp_start, timestamp_end, limit, limit_mode)
def get_raw_msg_by_timestamp_with_chat(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
filter_bot: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
return get_messages_by_time_in_chat(
chat_id,
timestamp_start,
timestamp_end,
limit,
limit_mode,
filter_bot,
filter_command,
filter_intercept_message_level,
)
def get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
filter_bot: bool = False,
filter_command: bool = False,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
return get_messages_by_time_in_chat_inclusive(
chat_id,
timestamp_start,
timestamp_end,
limit,
limit_mode,
filter_bot,
filter_command,
filter_intercept_message_level,
)
def get_raw_msg_by_timestamp_with_chat_users(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_messages_by_time_in_chat_for_users(chat_id, timestamp_start, timestamp_end, person_ids, limit, limit_mode)
def get_raw_msg_by_timestamp_with_users(
timestamp_start: float,
timestamp_end: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_messages_by_time_for_users(timestamp_start, timestamp_end, person_ids, limit, limit_mode)
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[SessionMessage]:
return get_messages_before_time(timestamp, limit)
def get_raw_msg_before_timestamp_with_chat(
chat_id: str,
timestamp: float,
limit: int = 0,
filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
return get_messages_before_time_in_chat(chat_id, timestamp, limit, False, filter_intercept_message_level)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[SessionMessage]:
return get_messages_before_time_for_users(timestamp, person_ids, limit)
def get_raw_msg_by_timestamp_random(
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
) -> List[SessionMessage]:
return get_random_chat_messages(timestamp_start, timestamp_end, limit, limit_mode)
def get_actions_by_timestamp_with_chat(chat_id: str, timestamp_start: float, timestamp_end: float) -> List[MaiActionRecord]:
with get_db_session() as session:
statement = (
select(ActionRecord)
.where(col(ActionRecord.session_id) == chat_id)
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
.order_by(col(ActionRecord.timestamp))
)
return [MaiActionRecord.from_db_instance(item) for item in session.exec(statement).all()]
def build_readable_actions(actions: List[MaiActionRecord], timestamp_mode: str = "relative") -> str:
return ActionUtils.build_readable_action_records(actions, timestamp_mode)
def replace_user_references(text: str, platform: str, replace_bot_name: bool = False) -> str:
del platform
if not text:
return text
def _replace(match: re.Match[str]) -> str:
prefix = match.group(1) or ""
user_name = match.group(2)
if replace_bot_name and user_name == global_config.bot.nickname:
user_name = ""
return f"{prefix}{user_name}"
text = re.sub(r"(回复|@)?<([^:<>]+):[^<>]+>", _replace, text)
return text
def translate_pid_to_description(pid: str) -> str:

View File

@@ -6,21 +6,20 @@
import traceback
import time
from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
from typing import Optional, Union, Dict, List, TYPE_CHECKING
from maim_message import MessageBase, BaseMessageInfo, Seg
from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.common.data_models.mai_message_data_model import MaiMessage
from src.common.data_models.message_component_data_model import DictComponent, MessageSequence
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel
from src.chat.message_receive.message import SessionMessage
logger = get_logger("send_service")
@@ -36,7 +35,7 @@ async def _send_to_target(
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
storage_message: bool = True,
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
@@ -60,12 +59,6 @@ async def _send_to_target(
current_time = time.time()
message_id = f"send_api_{int(current_time * 1000)}"
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
)
reply_to_platform_id = ""
anchor_message: Optional[MaiMessage] = None
if reply_message:
anchor_message = db_message_to_mai_message(reply_message)
@@ -73,31 +66,50 @@ async def _send_to_target(
logger.debug(
f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}"
)
reply_to_platform_id = f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
sender_info = None
if target_stream.context and target_stream.context.message:
sender_info = target_stream.context.message.message_info.user_info
group_info = None
if target_stream.group_id:
group_name = ""
if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info:
group_name = target_stream.context.message.message_info.group_info.group_name
group_info = MaimGroupInfo(
group_id=target_stream.group_id,
group_name=group_name,
platform=target_stream.platform,
)
bot_message = MessageSending(
message_id=message_id,
session=target_stream,
bot_user_info=bot_user_info,
sender_info=sender_info,
additional_config: dict[str, object] = {}
if selected_expressions is not None:
additional_config["selected_expressions"] = selected_expressions
maim_message = MessageBase(
message_info=BaseMessageInfo(
platform=target_stream.platform,
message_id=message_id,
time=current_time,
user_info=MaimUserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
platform=target_stream.platform,
),
group_info=group_info,
additional_config=additional_config,
),
message_segment=message_segment,
display_message=display_message,
reply=anchor_message,
is_head=True,
is_emoji=(message_segment.type == "emoji"),
thinking_start_time=current_time,
reply_to=reply_to_platform_id,
selected_expressions=selected_expressions,
)
bot_message = SessionMessage.from_maim_message(maim_message)
bot_message.session_id = target_stream.session_id
bot_message.display_message = display_message
bot_message.reply_to = anchor_message.message_id if anchor_message else None
bot_message.is_emoji = message_segment.type == "emoji"
bot_message.is_picture = message_segment.type == "image"
bot_message.is_command = message_segment.type == "command"
sent_msg = await message_sender.send_message(
bot_message,
typing=typing,
set_reply=set_reply,
reply_message_id=anchor_message.message_id if anchor_message else None,
storage_message=storage_message,
show_log=show_log,
)
@@ -115,37 +127,9 @@ async def _send_to_target(
return False
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
def db_message_to_mai_message(message_obj: "SessionMessage") -> Optional[MaiMessage]:
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。"""
from datetime import datetime
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import MessageSequence
user_info = UserInfo(
user_id=message_obj.user_info.user_id or "",
user_nickname=message_obj.user_info.user_nickname or "",
user_cardname=message_obj.user_info.user_cardname,
)
group_info = None
if message_obj.chat_info.group_info:
group_info = GroupInfo(
group_id=message_obj.chat_info.group_info.group_id or "",
group_name=message_obj.chat_info.group_info.group_name or "",
)
msg = MaiMessage(
message_id=message_obj.message_id,
timestamp=datetime.fromtimestamp(message_obj.time) if message_obj.time else datetime.now(),
)
msg.message_info = MessageInfo(user_info=user_info, group_info=group_info)
msg.platform = message_obj.chat_info.platform or ""
msg.session_id = message_obj.chat_info.stream_id or ""
msg.processed_plain_text = message_obj.processed_plain_text
msg.raw_message = MessageSequence(components=[])
msg.initialized = True
return msg
return message_obj.deepcopy()
# =============================================================================
@@ -158,7 +142,7 @@ async def text_to_stream(
stream_id: str,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
storage_message: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
@@ -180,7 +164,7 @@ async def emoji_to_stream(
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
) -> bool:
"""向指定流发送表情包"""
return await _send_to_target(
@@ -199,7 +183,7 @@ async def image_to_stream(
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
) -> bool:
"""向指定流发送图片"""
return await _send_to_target(
@@ -236,7 +220,7 @@ async def custom_to_stream(
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
@@ -255,25 +239,27 @@ async def custom_to_stream(
async def custom_reply_set_to_stream(
reply_set: "ReplySetModel",
reply_set: MessageSequence,
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
reply_message: Optional["SessionMessage"] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
"""向指定流发送混合型消息集"""
"""向指定流发送消息组件序列。"""
flag: bool = True
for reply_content in reply_set.reply_data:
status: bool = False
message_seg, need_typing = _parse_content_to_seg(reply_content)
for component in reply_set.components:
if isinstance(component, DictComponent):
message_seg = Seg(type="dict", data=component.data) # type: ignore
else:
message_seg = await component.to_seg()
status = await _send_to_target(
message_segment=message_seg,
stream_id=stream_id,
display_message=display_message,
typing=bool(need_typing and typing),
typing=typing,
reply_message=reply_message,
set_reply=set_reply,
storage_message=storage_message,
@@ -281,67 +267,7 @@ async def custom_reply_set_to_stream(
)
if not status:
flag = False
logger.error(
f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}"
)
logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}")
set_reply = False
return flag
def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]:
"""把 ReplyContent 转换为 Seg 结构"""
content_type = reply_content.content_type
if content_type == ReplyContentType.TEXT:
text_data: str = reply_content.content # type: ignore
return Seg(type="text", data=text_data), True
elif content_type == ReplyContentType.IMAGE:
return Seg(type="image", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.EMOJI:
return Seg(type="emoji", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.COMMAND:
return Seg(type="command", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.VOICE:
return Seg(type="voice", data=reply_content.content), False # type: ignore
elif content_type == ReplyContentType.HYBRID:
hybrid_message_list_data: List[ReplyContent] = reply_content.content # type: ignore
assert isinstance(hybrid_message_list_data, list), "混合类型内容必须是列表"
sub_seg_list: List[Seg] = []
for sub_content in hybrid_message_list_data:
sub_content_type = sub_content.content_type
sub_content_data = sub_content.content
if sub_content_type == ReplyContentType.TEXT:
sub_seg_list.append(Seg(type="text", data=sub_content_data)) # type: ignore
elif sub_content_type == ReplyContentType.IMAGE:
sub_seg_list.append(Seg(type="image", data=sub_content_data)) # type: ignore
elif sub_content_type == ReplyContentType.EMOJI:
sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore
else:
logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
continue
return Seg(type="seglist", data=sub_seg_list), True
elif content_type == ReplyContentType.FORWARD:
forward_message_list_data: List["ForwardNode"] = reply_content.content # type: ignore
assert isinstance(forward_message_list_data, list), "转发类型内容必须是列表"
forward_message_list: List[Dict] = []
for forward_node in forward_message_list_data:
message_segment = Seg(type="id", data=forward_node.content) # type: ignore
user_info: Optional[UserInfo] = None
if forward_node.user_id and forward_node.user_nickname:
assert isinstance(forward_node.content, list), "转发节点内容必须是列表"
user_info = UserInfo(user_id=forward_node.user_id, user_nickname=forward_node.user_nickname)
single_node_content: List[Seg] = []
for sub_content in forward_node.content:
if sub_content.content_type != ReplyContentType.FORWARD:
sub_seg, _ = _parse_content_to_seg(sub_content)
single_node_content.append(sub_seg)
message_segment = Seg(type="seglist", data=single_node_content)
forward_message_list.append(
MessageBase(
message_segment=message_segment, message_info=BaseMessageInfo(user_info=user_info)
).to_dict()
)
return Seg(type="forward", data=forward_message_list), False # type: ignore
else:
message_type_in_str = content_type.value if isinstance(content_type, ReplyContentType) else str(content_type)
return Seg(type=message_type_in_str, data=reply_content.content), True # type: ignore