移除所有 MongoDB 风格 filter

This commit is contained in:
DrSmoothl
2026-03-14 00:56:40 +08:00
parent c4a0cc19f8
commit a4303d9b81
2 changed files with 127 additions and 119 deletions

View File

@@ -98,24 +98,6 @@ def _coerce_datetime(value: Any) -> Any:
return value
def _cast_value_for_field(field: Any, value: Any) -> Any:
if field is Messages.timestamp:
return _coerce_datetime(value)
return value
def _ensure_list(value: Any) -> list[Any]:
if value is None:
return []
if isinstance(value, list):
return value
if isinstance(value, tuple):
return list(value)
if isinstance(value, set):
return list(value)
return [value]
def _resolve_field(field_name: str) -> Any | None:
if field_name in FIELD_MAP:
return FIELD_MAP[field_name]
@@ -124,8 +106,57 @@ def _resolve_field(field_name: str) -> Any | None:
return None
def _build_message_conditions(
*,
session_id: str | None = None,
user_id: str | None = None,
group_id: str | None = None,
platform: str | None = None,
message_id: str | None = None,
reply_to: str | None = None,
start_time: float | None = None,
end_time: float | None = None,
before_time: float | None = None,
after_time: float | None = None,
) -> list[Any]:
conditions: list[Any] = [Messages.message_id != "notice"]
if session_id is not None:
conditions.append(Messages.session_id == session_id)
if user_id is not None:
conditions.append(Messages.user_id == user_id)
if group_id is not None:
conditions.append(Messages.group_id == group_id)
if platform is not None:
conditions.append(Messages.platform == platform)
if message_id is not None:
conditions.append(Messages.message_id == message_id)
if reply_to is not None:
conditions.append(Messages.reply_to == reply_to)
if start_time is not None:
conditions.append(Messages.timestamp >= _coerce_datetime(start_time))
if end_time is not None:
conditions.append(Messages.timestamp <= _coerce_datetime(end_time))
if before_time is not None:
conditions.append(Messages.timestamp < _coerce_datetime(before_time))
if after_time is not None:
conditions.append(Messages.timestamp > _coerce_datetime(after_time))
return conditions
def find_messages(
message_filter: dict[str, Any],
*,
session_id: str | None = None,
user_id: str | None = None,
group_id: str | None = None,
platform: str | None = None,
message_id: str | None = None,
reply_to: str | None = None,
start_time: float | None = None,
end_time: float | None = None,
before_time: float | None = None,
after_time: float | None = None,
sort: list[tuple[str, int]] | None = None,
limit: int = 0,
limit_mode: str = "latest",
@@ -137,7 +168,16 @@ def find_messages(
根据提供的过滤器、排序和限制条件查找消息。
Args:
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
session_id: 会话 ID 过滤。
user_id: 用户 ID 过滤。
group_id: 群 ID 过滤。
platform: 平台过滤。
message_id: 消息 ID 过滤。
reply_to: 回复目标消息 ID 过滤。
start_time: 起始时间,闭区间下界。
end_time: 结束时间,闭区间上界。
before_time: 严格早于该时间。
after_time: 严格晚于该时间。
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
limit: 返回的最大文档数0表示不限制。
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'
@@ -146,37 +186,18 @@ def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
conditions: list[Any] = []
if message_filter:
for key, value in message_filter.items():
field = _resolve_field(key)
if field is None:
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
continue
if isinstance(value, dict):
for op, op_value in value.items():
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
if op == "$gt":
conditions.append(field > coerced_value)
elif op == "$lt":
conditions.append(field < coerced_value)
elif op == "$gte":
conditions.append(field >= coerced_value)
elif op == "$lte":
conditions.append(field <= coerced_value)
elif op == "$ne":
conditions.append(field != coerced_value)
elif op == "$in":
conditions.append(field.in_(_ensure_list(coerced_value)))
elif op == "$nin":
conditions.append(field.not_in(_ensure_list(coerced_value)))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
conditions.append(field == coerced_value)
conditions.append(Messages.message_id != "notice")
conditions = _build_message_conditions(
session_id=session_id,
user_id=user_id,
group_id=group_id,
platform=platform,
message_id=message_id,
reply_to=reply_to,
start_time=start_time,
end_time=end_time,
before_time=before_time,
after_time=after_time,
)
if filter_bot:
conditions.append(Messages.user_id != global_config.bot.qq_account)
if filter_command:
@@ -218,60 +239,70 @@ def find_messages(
return [_message_to_instance(msg) for msg in results]
except Exception as e:
log_message = (
f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
"使用 SQLModel 查找消息失败 "
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
f"before_time={before_time}, after_time={after_time}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
return []
def count_messages(message_filter: dict[str, Any]) -> int:
def count_messages(
*,
session_id: str | None = None,
user_id: str | None = None,
group_id: str | None = None,
platform: str | None = None,
message_id: str | None = None,
reply_to: str | None = None,
start_time: float | None = None,
end_time: float | None = None,
before_time: float | None = None,
after_time: float | None = None,
) -> int:
"""
根据提供的过滤器计算消息数量。
Args:
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
session_id: 会话 ID 过滤。
user_id: 用户 ID 过滤。
group_id: 群 ID 过滤。
platform: 平台过滤。
message_id: 消息 ID 过滤。
reply_to: 回复目标消息 ID 过滤。
start_time: 起始时间,闭区间下界。
end_time: 结束时间,闭区间上界。
before_time: 严格早于该时间。
after_time: 严格晚于该时间。
Returns:
符合条件的消息数量,如果出错则返回 0。
"""
try:
conditions: list[Any] = []
if message_filter:
for key, value in message_filter.items():
field = _resolve_field(key)
if field is None:
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
continue
if isinstance(value, dict):
for op, op_value in value.items():
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
if op == "$gt":
conditions.append(field > coerced_value)
elif op == "$lt":
conditions.append(field < coerced_value)
elif op == "$gte":
conditions.append(field >= coerced_value)
elif op == "$lte":
conditions.append(field <= coerced_value)
elif op == "$ne":
conditions.append(field != coerced_value)
elif op == "$in":
conditions.append(field.in_(_ensure_list(coerced_value)))
elif op == "$nin":
conditions.append(field.not_in(_ensure_list(coerced_value)))
else:
logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
conditions.append(field == coerced_value)
conditions.append(Messages.message_id != "notice")
conditions = _build_message_conditions(
session_id=session_id,
user_id=user_id,
group_id=group_id,
platform=platform,
message_id=message_id,
reply_to=reply_to,
start_time=start_time,
end_time=end_time,
before_time=before_time,
after_time=after_time,
)
statement = select(func.count()).select_from(Messages).where(*conditions)
with get_db_session() as session:
result = session.exec(statement).one()
return int(result or 0)
except Exception as e:
log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
log_message = (
"使用 SQLModel 计数消息失败 "
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
f"before_time={before_time}, after_time={after_time}): {e}\n{traceback.format_exc()}"
)
logger.error(log_message)
return 0

View File

@@ -1,9 +1,8 @@
"""消息服务模块。"""
import re
import time
from datetime import datetime
from typing import Any, List, Optional, Tuple
from typing import List, Optional, Tuple
from sqlmodel import col, select
@@ -17,20 +16,6 @@ from src.common.utils.utils_action import ActionUtils
from src.config.config import global_config
# =============================================================================
# 消息查询函数
# =============================================================================
def _build_time_range_filter(start_time: float, end_time: float) -> dict[str, Any]:
return {
"time": {
"$gte": start_time,
"$lte": end_time,
}
}
def _build_readable_line(
message: SessionMessage,
*,
@@ -72,7 +57,8 @@ def get_messages_by_time(
if limit < 0:
raise ValueError("limit 不能为负数")
messages = find_messages(
message_filter=_build_time_range_filter(start_time, end_time),
start_time=start_time,
end_time=end_time,
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
@@ -99,10 +85,9 @@ def get_messages_by_time_in_chat(
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = find_messages(
message_filter={
"chat_id": chat_id,
**_build_time_range_filter(start_time, end_time),
},
session_id=chat_id,
start_time=start_time,
end_time=end_time,
limit=limit,
limit_mode=limit_mode,
filter_bot=filter_mai,
@@ -118,7 +103,7 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
if limit < 0:
raise ValueError("limit 不能为负数")
messages = find_messages(
message_filter={"time": {"$lt": timestamp}},
before_time=timestamp,
limit=limit,
limit_mode="latest",
filter_bot=filter_mai,
@@ -142,10 +127,8 @@ def get_messages_before_time_in_chat(
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
messages = find_messages(
message_filter={
"chat_id": chat_id,
"time": {"$lt": timestamp},
},
session_id=chat_id,
before_time=timestamp,
limit=limit,
limit_mode="latest",
filter_bot=filter_mai,
@@ -166,13 +149,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
message_filter: dict[str, Any] = {
"chat_id": chat_id,
"time": {"$gt": start_time},
}
if end_time is not None:
message_filter["time"]["$lte"] = end_time
return count_messages(message_filter)
return count_messages(session_id=chat_id, after_time=start_time, end_time=end_time)
# =============================================================================