移除所有 MongoDB 风格 filter

This commit is contained in:
DrSmoothl
2026-03-14 00:56:40 +08:00
parent c4a0cc19f8
commit a4303d9b81
2 changed files with 127 additions and 119 deletions

View File

@@ -98,24 +98,6 @@ def _coerce_datetime(value: Any) -> Any:
return value return value
def _cast_value_for_field(field: Any, value: Any) -> Any:
if field is Messages.timestamp:
return _coerce_datetime(value)
return value
def _ensure_list(value: Any) -> list[Any]:
if value is None:
return []
if isinstance(value, list):
return value
if isinstance(value, tuple):
return list(value)
if isinstance(value, set):
return list(value)
return [value]
def _resolve_field(field_name: str) -> Any | None: def _resolve_field(field_name: str) -> Any | None:
if field_name in FIELD_MAP: if field_name in FIELD_MAP:
return FIELD_MAP[field_name] return FIELD_MAP[field_name]
@@ -124,8 +106,57 @@ def _resolve_field(field_name: str) -> Any | None:
return None return None
def _build_message_conditions(
*,
session_id: str | None = None,
user_id: str | None = None,
group_id: str | None = None,
platform: str | None = None,
message_id: str | None = None,
reply_to: str | None = None,
start_time: float | None = None,
end_time: float | None = None,
before_time: float | None = None,
after_time: float | None = None,
) -> list[Any]:
conditions: list[Any] = [Messages.message_id != "notice"]
if session_id is not None:
conditions.append(Messages.session_id == session_id)
if user_id is not None:
conditions.append(Messages.user_id == user_id)
if group_id is not None:
conditions.append(Messages.group_id == group_id)
if platform is not None:
conditions.append(Messages.platform == platform)
if message_id is not None:
conditions.append(Messages.message_id == message_id)
if reply_to is not None:
conditions.append(Messages.reply_to == reply_to)
if start_time is not None:
conditions.append(Messages.timestamp >= _coerce_datetime(start_time))
if end_time is not None:
conditions.append(Messages.timestamp <= _coerce_datetime(end_time))
if before_time is not None:
conditions.append(Messages.timestamp < _coerce_datetime(before_time))
if after_time is not None:
conditions.append(Messages.timestamp > _coerce_datetime(after_time))
return conditions
def find_messages( def find_messages(
message_filter: dict[str, Any], *,
session_id: str | None = None,
user_id: str | None = None,
group_id: str | None = None,
platform: str | None = None,
message_id: str | None = None,
reply_to: str | None = None,
start_time: float | None = None,
end_time: float | None = None,
before_time: float | None = None,
after_time: float | None = None,
sort: list[tuple[str, int]] | None = None, sort: list[tuple[str, int]] | None = None,
limit: int = 0, limit: int = 0,
limit_mode: str = "latest", limit_mode: str = "latest",
@@ -137,7 +168,16 @@ def find_messages(
根据提供的过滤器、排序和限制条件查找消息。 根据提供的过滤器、排序和限制条件查找消息。
Args: Args:
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). session_id: 会话 ID 过滤。
user_id: 用户 ID 过滤。
group_id: 群 ID 过滤。
platform: 平台过滤。
message_id: 消息 ID 过滤。
reply_to: 回复目标消息 ID 过滤。
start_time: 起始时间,闭区间下界。
end_time: 结束时间,闭区间上界。
before_time: 严格早于该时间。
after_time: 严格晚于该时间。
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。 sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
limit: 返回的最大文档数0表示不限制。 limit: 返回的最大文档数0表示不限制。
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest' limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'
@@ -146,37 +186,18 @@ def find_messages(
消息字典列表,如果出错则返回空列表。 消息字典列表,如果出错则返回空列表。
""" """
try: try:
conditions: list[Any] = [] conditions = _build_message_conditions(
if message_filter: session_id=session_id,
for key, value in message_filter.items(): user_id=user_id,
field = _resolve_field(key) group_id=group_id,
if field is None: platform=platform,
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") message_id=message_id,
continue reply_to=reply_to,
if isinstance(value, dict): start_time=start_time,
for op, op_value in value.items(): end_time=end_time,
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value before_time=before_time,
if op == "$gt": after_time=after_time,
conditions.append(field > coerced_value) )
elif op == "$lt":
conditions.append(field < coerced_value)
elif op == "$gte":
conditions.append(field >= coerced_value)
elif op == "$lte":
conditions.append(field <= coerced_value)
elif op == "$ne":
conditions.append(field != coerced_value)
elif op == "$in":
conditions.append(field.in_(_ensure_list(coerced_value)))
elif op == "$nin":
conditions.append(field.not_in(_ensure_list(coerced_value)))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
conditions.append(field == coerced_value)
conditions.append(Messages.message_id != "notice")
if filter_bot: if filter_bot:
conditions.append(Messages.user_id != global_config.bot.qq_account) conditions.append(Messages.user_id != global_config.bot.qq_account)
if filter_command: if filter_command:
@@ -218,60 +239,70 @@ def find_messages(
return [_message_to_instance(msg) for msg in results] return [_message_to_instance(msg) for msg in results]
except Exception as e: except Exception as e:
log_message = ( log_message = (
f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" "使用 SQLModel 查找消息失败 "
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
f"before_time={before_time}, after_time={after_time}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc() + traceback.format_exc()
) )
logger.error(log_message) logger.error(log_message)
return [] return []
def count_messages(message_filter: dict[str, Any]) -> int: def count_messages(
*,
session_id: str | None = None,
user_id: str | None = None,
group_id: str | None = None,
platform: str | None = None,
message_id: str | None = None,
reply_to: str | None = None,
start_time: float | None = None,
end_time: float | None = None,
before_time: float | None = None,
after_time: float | None = None,
) -> int:
""" """
根据提供的过滤器计算消息数量。 根据提供的过滤器计算消息数量。
Args: Args:
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). session_id: 会话 ID 过滤。
user_id: 用户 ID 过滤。
group_id: 群 ID 过滤。
platform: 平台过滤。
message_id: 消息 ID 过滤。
reply_to: 回复目标消息 ID 过滤。
start_time: 起始时间,闭区间下界。
end_time: 结束时间,闭区间上界。
before_time: 严格早于该时间。
after_time: 严格晚于该时间。
Returns: Returns:
符合条件的消息数量,如果出错则返回 0。 符合条件的消息数量,如果出错则返回 0。
""" """
try: try:
conditions: list[Any] = [] conditions = _build_message_conditions(
if message_filter: session_id=session_id,
for key, value in message_filter.items(): user_id=user_id,
field = _resolve_field(key) group_id=group_id,
if field is None: platform=platform,
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") message_id=message_id,
continue reply_to=reply_to,
if isinstance(value, dict): start_time=start_time,
for op, op_value in value.items(): end_time=end_time,
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value before_time=before_time,
if op == "$gt": after_time=after_time,
conditions.append(field > coerced_value) )
elif op == "$lt":
conditions.append(field < coerced_value)
elif op == "$gte":
conditions.append(field >= coerced_value)
elif op == "$lte":
conditions.append(field <= coerced_value)
elif op == "$ne":
conditions.append(field != coerced_value)
elif op == "$in":
conditions.append(field.in_(_ensure_list(coerced_value)))
elif op == "$nin":
conditions.append(field.not_in(_ensure_list(coerced_value)))
else:
logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
conditions.append(field == coerced_value)
conditions.append(Messages.message_id != "notice")
statement = select(func.count()).select_from(Messages).where(*conditions) statement = select(func.count()).select_from(Messages).where(*conditions)
with get_db_session() as session: with get_db_session() as session:
result = session.exec(statement).one() result = session.exec(statement).one()
return int(result or 0) return int(result or 0)
except Exception as e: except Exception as e:
log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" log_message = (
"使用 SQLModel 计数消息失败 "
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
f"before_time={before_time}, after_time={after_time}): {e}\n{traceback.format_exc()}"
)
logger.error(log_message) logger.error(log_message)
return 0 return 0

View File

@@ -1,9 +1,8 @@
"""消息服务模块。""" """消息服务模块。"""
import re import re
import time
from datetime import datetime from datetime import datetime
from typing import Any, List, Optional, Tuple from typing import List, Optional, Tuple
from sqlmodel import col, select from sqlmodel import col, select
@@ -17,20 +16,6 @@ from src.common.utils.utils_action import ActionUtils
from src.config.config import global_config from src.config.config import global_config
# =============================================================================
# 消息查询函数
# =============================================================================
def _build_time_range_filter(start_time: float, end_time: float) -> dict[str, Any]:
return {
"time": {
"$gte": start_time,
"$lte": end_time,
}
}
def _build_readable_line( def _build_readable_line(
message: SessionMessage, message: SessionMessage,
*, *,
@@ -72,7 +57,8 @@ def get_messages_by_time(
if limit < 0: if limit < 0:
raise ValueError("limit 不能为负数") raise ValueError("limit 不能为负数")
messages = find_messages( messages = find_messages(
message_filter=_build_time_range_filter(start_time, end_time), start_time=start_time,
end_time=end_time,
limit=limit, limit=limit,
limit_mode=limit_mode, limit_mode=limit_mode,
filter_bot=filter_mai, filter_bot=filter_mai,
@@ -99,10 +85,9 @@ def get_messages_by_time_in_chat(
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
messages = find_messages( messages = find_messages(
message_filter={ session_id=chat_id,
"chat_id": chat_id, start_time=start_time,
**_build_time_range_filter(start_time, end_time), end_time=end_time,
},
limit=limit, limit=limit,
limit_mode=limit_mode, limit_mode=limit_mode,
filter_bot=filter_mai, filter_bot=filter_mai,
@@ -118,7 +103,7 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
if limit < 0: if limit < 0:
raise ValueError("limit 不能为负数") raise ValueError("limit 不能为负数")
messages = find_messages( messages = find_messages(
message_filter={"time": {"$lt": timestamp}}, before_time=timestamp,
limit=limit, limit=limit,
limit_mode="latest", limit_mode="latest",
filter_bot=filter_mai, filter_bot=filter_mai,
@@ -142,10 +127,8 @@ def get_messages_before_time_in_chat(
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
messages = find_messages( messages = find_messages(
message_filter={ session_id=chat_id,
"chat_id": chat_id, before_time=timestamp,
"time": {"$lt": timestamp},
},
limit=limit, limit=limit,
limit_mode="latest", limit_mode="latest",
filter_bot=filter_mai, filter_bot=filter_mai,
@@ -166,13 +149,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
raise ValueError("chat_id 不能为空") raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str): if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型") raise ValueError("chat_id 必须是字符串类型")
message_filter: dict[str, Any] = { return count_messages(session_id=chat_id, after_time=start_time, end_time=end_time)
"chat_id": chat_id,
"time": {"$gt": start_time},
}
if end_time is not None:
message_filter["time"]["$lte"] = end_time
return count_messages(message_filter)
# ============================================================================= # =============================================================================