From a4303d9b81927daf0fcff7e25a15c73ab36525f7 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 14 Mar 2026 00:56:40 +0800 Subject: [PATCH] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E6=89=80=E6=9C=89=20MongoDB?= =?UTF-8?q?=20=E9=A3=8E=E6=A0=BC=20filter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/message_repository.py | 203 ++++++++++++++++++------------- src/services/message_service.py | 43 ++----- 2 files changed, 127 insertions(+), 119 deletions(-) diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 7215ffa3..7b35ae07 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -98,24 +98,6 @@ def _coerce_datetime(value: Any) -> Any: return value -def _cast_value_for_field(field: Any, value: Any) -> Any: - if field is Messages.timestamp: - return _coerce_datetime(value) - return value - - -def _ensure_list(value: Any) -> list[Any]: - if value is None: - return [] - if isinstance(value, list): - return value - if isinstance(value, tuple): - return list(value) - if isinstance(value, set): - return list(value) - return [value] - - def _resolve_field(field_name: str) -> Any | None: if field_name in FIELD_MAP: return FIELD_MAP[field_name] @@ -124,8 +106,57 @@ def _resolve_field(field_name: str) -> Any | None: return None +def _build_message_conditions( + *, + session_id: str | None = None, + user_id: str | None = None, + group_id: str | None = None, + platform: str | None = None, + message_id: str | None = None, + reply_to: str | None = None, + start_time: float | None = None, + end_time: float | None = None, + before_time: float | None = None, + after_time: float | None = None, +) -> list[Any]: + conditions: list[Any] = [Messages.message_id != "notice"] + + if session_id is not None: + conditions.append(Messages.session_id == session_id) + if user_id is not None: + conditions.append(Messages.user_id == user_id) + if group_id is not None: + conditions.append(Messages.group_id == group_id) + if platform is not None: + conditions.append(Messages.platform == platform) + if message_id is not None: + conditions.append(Messages.message_id == message_id) + if reply_to is not None: + conditions.append(Messages.reply_to == reply_to) + if start_time is not None: + conditions.append(Messages.timestamp >= _coerce_datetime(start_time)) + if end_time is not None: + conditions.append(Messages.timestamp <= _coerce_datetime(end_time)) + if before_time is not None: + conditions.append(Messages.timestamp < _coerce_datetime(before_time)) + if after_time is not None: + conditions.append(Messages.timestamp > _coerce_datetime(after_time)) + + return conditions + + def find_messages( - message_filter: dict[str, Any], + *, + session_id: str | None = None, + user_id: str | None = None, + group_id: str | None = None, + platform: str | None = None, + message_id: str | None = None, + reply_to: str | None = None, + start_time: float | None = None, + end_time: float | None = None, + before_time: float | None = None, + after_time: float | None = None, sort: list[tuple[str, int]] | None = None, limit: int = 0, limit_mode: str = "latest", @@ -137,7 +168,16 @@ def find_messages( 根据提供的过滤器、排序和限制条件查找消息。 Args: - message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). + session_id: 会话 ID 过滤。 + user_id: 用户 ID 过滤。 + group_id: 群 ID 过滤。 + platform: 平台过滤。 + message_id: 消息 ID 过滤。 + reply_to: 回复目标消息 ID 过滤。 + start_time: 起始时间,闭区间下界。 + end_time: 结束时间,闭区间上界。 + before_time: 严格早于该时间。 + after_time: 严格晚于该时间。 sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。 limit: 返回的最大文档数,0表示不限制。 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。 @@ -146,37 +186,18 @@ def find_messages( 消息字典列表,如果出错则返回空列表。 """ try: - conditions: list[Any] = [] - if message_filter: - for key, value in message_filter.items(): - field = _resolve_field(key) - if field is None: - logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") - continue - if isinstance(value, dict): - for op, op_value in value.items(): - coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value - if op == "$gt": - conditions.append(field > coerced_value) - elif op == "$lt": - conditions.append(field < coerced_value) - elif op == "$gte": - conditions.append(field >= coerced_value) - elif op == "$lte": - conditions.append(field <= coerced_value) - elif op == "$ne": - conditions.append(field != coerced_value) - elif op == "$in": - conditions.append(field.in_(_ensure_list(coerced_value))) - elif op == "$nin": - conditions.append(field.not_in(_ensure_list(coerced_value))) - else: - logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") - else: - coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value - conditions.append(field == coerced_value) - - conditions.append(Messages.message_id != "notice") + conditions = _build_message_conditions( + session_id=session_id, + user_id=user_id, + group_id=group_id, + platform=platform, + message_id=message_id, + reply_to=reply_to, + start_time=start_time, + end_time=end_time, + before_time=before_time, + after_time=after_time, + ) if filter_bot: conditions.append(Messages.user_id != global_config.bot.qq_account) if filter_command: @@ -218,60 +239,70 @@ def find_messages( return [_message_to_instance(msg) for msg in results] except Exception as e: log_message = ( - f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + "使用 SQLModel 查找消息失败 " + f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, " + f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, " + f"before_time={before_time}, after_time={after_time}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + traceback.format_exc() ) logger.error(log_message) return [] -def count_messages(message_filter: dict[str, Any]) -> int: +def count_messages( + *, + session_id: str | None = None, + user_id: str | None = None, + group_id: str | None = None, + platform: str | None = None, + message_id: str | None = None, + reply_to: str | None = None, + start_time: float | None = None, + end_time: float | None = None, + before_time: float | None = None, + after_time: float | None = None, +) -> int: """ 根据提供的过滤器计算消息数量。 Args: - message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). + session_id: 会话 ID 过滤。 + user_id: 用户 ID 过滤。 + group_id: 群 ID 过滤。 + platform: 平台过滤。 + message_id: 消息 ID 过滤。 + reply_to: 回复目标消息 ID 过滤。 + start_time: 起始时间,闭区间下界。 + end_time: 结束时间,闭区间上界。 + before_time: 严格早于该时间。 + after_time: 严格晚于该时间。 Returns: 符合条件的消息数量,如果出错则返回 0。 """ try: - conditions: list[Any] = [] - if message_filter: - for key, value in message_filter.items(): - field = _resolve_field(key) - if field is None: - logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") - continue - if isinstance(value, dict): - for op, op_value in value.items(): - coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value - if op == "$gt": - conditions.append(field > coerced_value) - elif op == "$lt": - conditions.append(field < coerced_value) - elif op == "$gte": - conditions.append(field >= coerced_value) - elif op == "$lte": - conditions.append(field <= coerced_value) - elif op == "$ne": - conditions.append(field != coerced_value) - elif op == "$in": - conditions.append(field.in_(_ensure_list(coerced_value))) - elif op == "$nin": - conditions.append(field.not_in(_ensure_list(coerced_value))) - else: - logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") - else: - coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value - conditions.append(field == coerced_value) - - conditions.append(Messages.message_id != "notice") + conditions = _build_message_conditions( + session_id=session_id, + user_id=user_id, + group_id=group_id, + platform=platform, + message_id=message_id, + reply_to=reply_to, + start_time=start_time, + end_time=end_time, + before_time=before_time, + after_time=after_time, + ) statement = select(func.count()).select_from(Messages).where(*conditions) with get_db_session() as session: result = session.exec(statement).one() return int(result or 0) except Exception as e: - log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" + log_message = ( + "使用 SQLModel 计数消息失败 " + f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, " + f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, " + f"before_time={before_time}, after_time={after_time}): {e}\n{traceback.format_exc()}" + ) logger.error(log_message) return 0 diff --git a/src/services/message_service.py b/src/services/message_service.py index cdf9ff56..d918b177 100644 --- a/src/services/message_service.py +++ b/src/services/message_service.py @@ -1,9 +1,8 @@ """消息服务模块。""" import re -import time from datetime import datetime -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple from sqlmodel import col, select @@ -17,20 +16,6 @@ from src.common.utils.utils_action import ActionUtils from src.config.config import global_config -# ============================================================================= -# 消息查询函数 -# ============================================================================= - - -def _build_time_range_filter(start_time: float, end_time: float) -> dict[str, Any]: - return { - "time": { - "$gte": start_time, - "$lte": end_time, - } - } - - def _build_readable_line( message: SessionMessage, *, @@ -72,7 +57,8 @@ def get_messages_by_time( if limit < 0: raise ValueError("limit 不能为负数") messages = find_messages( - message_filter=_build_time_range_filter(start_time, end_time), + start_time=start_time, + end_time=end_time, limit=limit, limit_mode=limit_mode, filter_bot=filter_mai, @@ -99,10 +85,9 @@ def get_messages_by_time_in_chat( if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") messages = find_messages( - message_filter={ - "chat_id": chat_id, - **_build_time_range_filter(start_time, end_time), - }, + session_id=chat_id, + start_time=start_time, + end_time=end_time, limit=limit, limit_mode=limit_mode, filter_bot=filter_mai, @@ -118,7 +103,7 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool if limit < 0: raise ValueError("limit 不能为负数") messages = find_messages( - message_filter={"time": {"$lt": timestamp}}, + before_time=timestamp, limit=limit, limit_mode="latest", filter_bot=filter_mai, @@ -142,10 +127,8 @@ def get_messages_before_time_in_chat( if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") messages = find_messages( - message_filter={ - "chat_id": chat_id, - "time": {"$lt": timestamp}, - }, + session_id=chat_id, + before_time=timestamp, limit=limit, limit_mode="latest", filter_bot=filter_mai, @@ -166,13 +149,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - message_filter: dict[str, Any] = { - "chat_id": chat_id, - "time": {"$gt": start_time}, - } - if end_time is not None: - message_filter["time"]["$lte"] = end_time - return count_messages(message_filter) + return count_messages(session_id=chat_id, after_time=start_time, end_time=end_time) # =============================================================================