280 lines
10 KiB
Python
280 lines
10 KiB
Python
"""消息服务模块。"""
|
|
|
|
import re
|
|
from datetime import datetime
|
|
from typing import List, Optional, Tuple
|
|
|
|
from sqlmodel import col, select
|
|
|
|
from src.chat.message_receive.message import SessionMessage
|
|
from src.common.data_models.tool_record_data_model import MaiToolRecord
|
|
from src.common.database.database import get_db_session
|
|
from src.common.database.database_model import Images, ImageType, ToolRecord
|
|
from src.common.message_repository import count_messages, find_messages
|
|
from src.common.utils.math_utils import translate_timestamp_to_human_readable
|
|
from src.common.utils.utils_action import ActionUtils
|
|
from src.config.config import global_config
|
|
|
|
|
|
def _build_readable_line(
    message: SessionMessage,
    *,
    replace_bot_name: bool,
    timestamp_mode: Optional[str],
    show_message_id_prefix: bool,
) -> str:
    """Format a single message as one transcript line.

    The line has the shape ``"[<time>] [消息ID: <id>] <speaker>说:<text>"``. The time
    segment is present only when ``timestamp_mode`` is truthy, and the ID segment only
    when ``show_message_id_prefix`` is set.

    Args:
        message: The message to render.
        replace_bot_name: When true and a bot nickname is configured, every occurrence
            of that nickname inside the message text is replaced with "你".
        timestamp_mode: Mode forwarded to ``translate_timestamp_to_human_readable``;
            a falsy value omits the time segment.
        show_message_id_prefix: Whether to include the message-ID segment.

    Returns:
        The formatted line. The message text is stripped of surrounding whitespace.
    """
    body = (message.processed_plain_text or "").strip()
    nickname = global_config.bot.nickname if replace_bot_name else None
    if nickname:
        body = body.replace(nickname, "你")

    user_info = message.message_info.user_info
    # Prefer the card name, then the nickname, and finally the raw user ID.
    speaker = user_info.user_cardname or user_info.user_nickname or user_info.user_id

    time_segment = (
        f"[{translate_timestamp_to_human_readable(message.timestamp.timestamp(), mode=timestamp_mode)}]"
        if timestamp_mode
        else None
    )
    id_segment = f"[消息ID: {message.message_id}]" if show_message_id_prefix else None

    segments = [segment for segment in (time_segment, id_segment) if segment is not None]
    segments.append(f"{speaker}说:")
    return " ".join(segments) + body
|
|
|
|
|
|
def _normalize_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
    """Return the messages as a new list, back-filling missing plain text.

    Any message whose ``processed_plain_text`` is empty or ``None`` gets it set from
    ``display_message``, or from ``""`` when that is empty too. The message objects
    themselves are modified in place; only the returned list is new.
    """
    result = list(messages)
    for msg in result:
        if msg.processed_plain_text:
            continue
        msg.processed_plain_text = msg.display_message or ""
    return result
|
|
|
|
|
|
def get_messages_by_time(
    start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[SessionMessage]:
    """Fetch messages from all chats within a time range.

    Args:
        start_time: Start of the range (Unix timestamp).
        end_time: End of the range (Unix timestamp).
        limit: Forwarded to ``find_messages``; must be non-negative
            (presumably ``0`` means "no limit" — confirm in ``find_messages``).
        limit_mode: Forwarded to ``find_messages``.
        filter_mai: Forwarded to ``find_messages`` as ``filter_bot``.

    Returns:
        The matching messages, after ``_normalize_messages`` back-fills empty text.

    Raises:
        ValueError: If a time bound is not numeric or ``limit`` is negative.
    """
    numeric_types = (int, float)
    if not (isinstance(start_time, numeric_types) and isinstance(end_time, numeric_types)):
        raise ValueError("start_time 和 end_time 必须是数字类型")
    if limit < 0:
        raise ValueError("limit 不能为负数")

    found = find_messages(
        start_time=start_time,
        end_time=end_time,
        limit=limit,
        limit_mode=limit_mode,
        filter_bot=filter_mai,
    )
    return _normalize_messages(found)
|
|
|
|
|
|
def get_messages_by_time_in_chat(
    chat_id: str,
    start_time: float,
    end_time: float,
    limit: int = 0,
    limit_mode: str = "latest",
    filter_mai: bool = False,
    filter_command: bool = False,
    filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
    """Fetch messages of one chat within a time range.

    Args:
        chat_id: Session ID of the chat; must be a non-empty string.
        start_time: Start of the range (Unix timestamp).
        end_time: End of the range (Unix timestamp).
        limit: Forwarded to ``find_messages``; must be non-negative.
        limit_mode: Forwarded to ``find_messages``.
        filter_mai: Forwarded to ``find_messages`` as ``filter_bot``.
        filter_command: Forwarded to ``find_messages``.
        filter_intercept_message_level: Forwarded to ``find_messages``.

    Returns:
        The matching messages, after ``_normalize_messages`` back-fills empty text.

    Raises:
        ValueError: If a time bound is not numeric, ``limit`` is negative, or
            ``chat_id`` is empty / not a string (checked in that order).
    """
    numeric_types = (int, float)
    if not (isinstance(start_time, numeric_types) and isinstance(end_time, numeric_types)):
        raise ValueError("start_time 和 end_time 必须是数字类型")
    if limit < 0:
        raise ValueError("limit 不能为负数")
    if not chat_id:
        raise ValueError("chat_id 不能为空")
    if not isinstance(chat_id, str):
        raise ValueError("chat_id 必须是字符串类型")

    found = find_messages(
        session_id=chat_id,
        start_time=start_time,
        end_time=end_time,
        limit=limit,
        limit_mode=limit_mode,
        filter_bot=filter_mai,
        filter_command=filter_command,
        filter_intercept_message_level=filter_intercept_message_level,
    )
    return _normalize_messages(found)
|
|
|
|
|
|
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]:
    """Fetch messages from all chats that precede ``timestamp``.

    The query always uses ``limit_mode="latest"``.

    Args:
        timestamp: Upper time bound (Unix timestamp), forwarded as ``before_time``.
        limit: Forwarded to ``find_messages``; must be non-negative.
        filter_mai: Forwarded to ``find_messages`` as ``filter_bot``.

    Returns:
        The matching messages, after ``_normalize_messages`` back-fills empty text.

    Raises:
        ValueError: If ``timestamp`` is not numeric or ``limit`` is negative.
    """
    if isinstance(timestamp, (int, float)) is False:
        raise ValueError("timestamp 必须是数字类型")
    if limit < 0:
        raise ValueError("limit 不能为负数")

    found = find_messages(
        before_time=timestamp,
        limit=limit,
        limit_mode="latest",
        filter_bot=filter_mai,
    )
    return _normalize_messages(found)
|
|
|
|
|
|
def get_messages_before_time_in_chat(
    chat_id: str,
    timestamp: float,
    limit: int = 0,
    filter_mai: bool = False,
    filter_intercept_message_level: Optional[int] = None,
) -> List[SessionMessage]:
    """Fetch messages of one chat that precede ``timestamp``.

    The query always uses ``limit_mode="latest"``.

    Args:
        chat_id: Session ID of the chat; must be a non-empty string.
        timestamp: Upper time bound (Unix timestamp), forwarded as ``before_time``.
        limit: Forwarded to ``find_messages``; must be non-negative.
        filter_mai: Forwarded to ``find_messages`` as ``filter_bot``.
        filter_intercept_message_level: Forwarded to ``find_messages``.

    Returns:
        The matching messages, after ``_normalize_messages`` back-fills empty text.

    Raises:
        ValueError: If ``timestamp`` is not numeric, ``limit`` is negative, or
            ``chat_id`` is empty / not a string (checked in that order).
    """
    if isinstance(timestamp, (int, float)) is False:
        raise ValueError("timestamp 必须是数字类型")
    if limit < 0:
        raise ValueError("limit 不能为负数")
    if not chat_id:
        raise ValueError("chat_id 不能为空")
    if not isinstance(chat_id, str):
        raise ValueError("chat_id 必须是字符串类型")

    found = find_messages(
        session_id=chat_id,
        before_time=timestamp,
        limit=limit,
        limit_mode="latest",
        filter_bot=filter_mai,
        filter_intercept_message_level=filter_intercept_message_level,
    )
    return _normalize_messages(found)
|
|
|
|
|
|
# =============================================================================
|
|
# 消息计数函数
|
|
# =============================================================================
|
|
|
|
|
|
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
    """Count messages in a chat within a time window.

    Args:
        chat_id: Session ID of the chat; must be a non-empty string.
        start_time: Lower time bound (Unix timestamp), forwarded to ``count_messages``
            as ``after_time``.
        end_time: Optional upper time bound (Unix timestamp); ``None`` is forwarded
            unchanged to ``count_messages``.

    Returns:
        The count returned by ``count_messages``.

    Raises:
        ValueError: If ``start_time`` or a non-``None`` ``end_time`` is not numeric,
            or if ``chat_id`` is empty / not a string.
    """
    if not isinstance(start_time, (int, float)):
        raise ValueError("start_time 必须是数字类型")
    # Validate the upper bound as well, consistent with the other query helpers in
    # this module; otherwise a non-numeric value would reach the database query.
    if end_time is not None and not isinstance(end_time, (int, float)):
        raise ValueError("end_time 必须是数字类型")
    if not chat_id:
        raise ValueError("chat_id 不能为空")
    if not isinstance(chat_id, str):
        raise ValueError("chat_id 必须是字符串类型")

    return count_messages(session_id=chat_id, after_time=start_time, end_time=end_time)
|
|
|
|
|
|
# =============================================================================
|
|
# 消息格式化函数
|
|
# =============================================================================
|
|
|
|
|
|
def build_readable_messages(
    messages: List[SessionMessage],
    replace_bot_name: bool = True,
    timestamp_mode: str = "relative",
    read_mark: float = 0.0,
    truncate: bool = False,
    show_actions: bool = False,
) -> str:
    """Render messages as a newline-joined, human-readable transcript.

    Args:
        messages: Messages to render, in display order. Empty ``processed_plain_text``
            values are back-filled in place by ``_normalize_messages``.
        replace_bot_name: Replace the bot nickname inside message text with "你".
        timestamp_mode: Forwarded to the per-line timestamp formatter; a falsy value
            omits timestamps on message lines.
        read_mark: Unix timestamp. An "unread" divider line is inserted before the
            first message newer than this. ``0`` disables the divider.
        truncate: Cut message lines longer than 200 characters.
        show_actions: Append readable tool/action records of the first message's
            session for the time span covered by the messages. These records are
            always rendered in ``"relative"`` mode.

    Returns:
        The transcript text.
    """
    normalized_messages = _normalize_messages(messages)
    lines: List[str] = []
    unread_mark_added = False

    for message in normalized_messages:
        if read_mark and not unread_mark_added and message.timestamp.timestamp() > read_mark:
            lines.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---")
            unread_mark_added = True

        line = _build_readable_line(
            message,
            replace_bot_name=replace_bot_name,
            timestamp_mode=timestamp_mode,
            show_message_id_prefix=False,
        )
        if truncate and len(line) > 200:
            line = f"{line[:200]}......(内容太长了)"
        lines.append(line)

    if show_actions and normalized_messages:
        # Derive the window from min/max instead of first/last. If the input is not
        # sorted ascending, first/last would give start > end, and the range query
        # would silently return no records. For sorted input both forms agree.
        message_times = [message.timestamp.timestamp() for message in normalized_messages]
        action_records = get_actions_by_timestamp_with_chat(
            normalized_messages[0].session_id,
            min(message_times),
            max(message_times),
        )
        if action_lines := ActionUtils.build_readable_action_records(action_records, "relative"):
            lines.append(action_lines)

    return "\n".join(lines)
|
|
|
|
|
|
def build_readable_messages_with_id(
    messages: List[SessionMessage],
    replace_bot_name: bool = True,
    timestamp_mode: str = "relative",
    read_mark: float = 0.0,
    truncate: bool = False,
    show_actions: bool = False,
) -> Tuple[str, List[Tuple[str, SessionMessage]]]:
    """Render messages as a transcript whose lines carry a ``[消息ID: ...]`` prefix.

    Args:
        messages: Messages to render, in display order. Empty ``processed_plain_text``
            values are back-filled in place by ``_normalize_messages``.
        replace_bot_name: Replace the bot nickname inside message text with "你".
        timestamp_mode: Forwarded to the per-line timestamp formatter; a falsy value
            omits timestamps on message lines.
        read_mark: Unix timestamp. An "unread" divider line is inserted before the
            first message newer than this. ``0`` disables the divider.
        truncate: Cut message lines longer than 200 characters.
        show_actions: Append readable tool/action records of the first message's
            session for the time span covered by the messages. These records are
            always rendered in ``"relative"`` mode.

    Returns:
        A tuple ``(text, id_pairs)``. ``text`` is the transcript and ``id_pairs``
        lists ``(message_id, message)`` for every rendered message, in order.
    """
    normalized_messages = _normalize_messages(messages)
    lines: List[str] = []
    message_id_list: List[Tuple[str, SessionMessage]] = []
    unread_mark_added = False

    for message in normalized_messages:
        if read_mark and not unread_mark_added and message.timestamp.timestamp() > read_mark:
            lines.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---")
            unread_mark_added = True

        line = _build_readable_line(
            message,
            replace_bot_name=replace_bot_name,
            timestamp_mode=timestamp_mode,
            show_message_id_prefix=True,
        )
        if truncate and len(line) > 200:
            line = f"{line[:200]}......(内容太长了)"
        lines.append(line)
        message_id_list.append((message.message_id, message))

    if show_actions and normalized_messages:
        # Derive the window from min/max instead of first/last. If the input is not
        # sorted ascending, first/last would give start > end, and the range query
        # would silently return no records. For sorted input both forms agree.
        message_times = [message.timestamp.timestamp() for message in normalized_messages]
        action_records = get_actions_by_timestamp_with_chat(
            normalized_messages[0].session_id,
            min(message_times),
            max(message_times),
        )
        if action_lines := ActionUtils.build_readable_action_records(action_records, "relative"):
            lines.append(action_lines)

    return "\n".join(lines), message_id_list
|
|
|
|
|
|
def get_actions_by_timestamp_with_chat(
    chat_id: str,
    timestamp_start: float,
    timestamp_end: float,
    limit: Optional[int] = None,
) -> List[MaiToolRecord]:
    """Load the tool records of one chat whose timestamp lies in a closed time window.

    Both bounds are inclusive (``>=`` / ``<=``). They are converted with
    ``datetime.fromtimestamp``, i.e. to naive local-time datetimes. Results are
    ordered by timestamp ascending.

    Args:
        chat_id: Session ID the records must belong to.
        timestamp_start: Lower bound (Unix timestamp), inclusive.
        timestamp_end: Upper bound (Unix timestamp), inclusive.
        limit: Optional maximum number of rows; ``None`` applies no LIMIT.

    Returns:
        The matching rows converted via ``MaiToolRecord.from_db_instance``.
    """
    with get_db_session() as session:
        window_start = datetime.fromtimestamp(timestamp_start)
        window_end = datetime.fromtimestamp(timestamp_end)
        query = select(ToolRecord).where(
            col(ToolRecord.session_id) == chat_id,
            col(ToolRecord.timestamp) >= window_start,
            col(ToolRecord.timestamp) <= window_end,
        ).order_by(col(ToolRecord.timestamp))
        if limit is not None:
            query = query.limit(limit)
        rows = session.exec(query).all()
        return [MaiToolRecord.from_db_instance(row) for row in rows]
|
|
|
|
|
|
def replace_user_references(text: str, platform: str, replace_bot_name: bool = False) -> str:
    """Collapse ``<name:id>`` user references in ``text`` to just ``name``.

    An optional leading "回复" or "@" directly before the reference is kept. For
    example, ``"回复<张三:123>"`` becomes ``"回复张三"`` and ``"@<Alice:42>"`` becomes
    ``"@Alice"``.

    Args:
        text: The text to rewrite. Falsy values are returned unchanged.
        platform: Unused; accepted for call-site compatibility.
        replace_bot_name: When true, a reference whose name equals the configured
            bot nickname is rendered as "你".

    Returns:
        The rewritten text.
    """
    del platform  # not used by the current implementation

    if not text:
        return text

    def _substitute(found: re.Match[str]) -> str:
        leading, name = found.group(1, 2)
        if replace_bot_name and name == global_config.bot.nickname:
            name = "你"
        return (leading or "") + name

    return re.sub(r"(回复|@)?<([^:<>]+):[^<>]+>", _substitute, text)
|
|
|
|
|
|
def translate_pid_to_description(pid: str) -> str:
    """Look up the stored description of an image by its picture ID.

    Args:
        pid: The image's database ID as a string of decimal digits.

    Returns:
        The whitespace-stripped description of the matching ``ImageType.IMAGE`` row.
        Returns ``"[图片]"`` when ``pid`` is not a decimal number, no row matches,
        or the description is empty or whitespace-only.
    """
    fallback = "[图片]"
    # Use isdecimal() rather than isdigit(): characters such as "²" or "①" satisfy
    # isdigit() but make int() raise ValueError. A non-numeric ID can never match a
    # row, so return early without opening a database session.
    if not pid or not pid.isdecimal():
        return fallback

    with get_db_session() as session:
        statement = select(Images).where(
            (col(Images.id) == int(pid)) & (col(Images.image_type) == ImageType.IMAGE)
        )
        image = session.exec(statement).first()
        # Read the loaded row's attribute while the session is still open.
        description = image.description.strip() if image and image.description else ""

    return description or fallback
|