移除所有 MongoDB 风格 filter
This commit is contained in:
@@ -98,24 +98,6 @@ def _coerce_datetime(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _cast_value_for_field(field: Any, value: Any) -> Any:
|
||||
if field is Messages.timestamp:
|
||||
return _coerce_datetime(value)
|
||||
return value
|
||||
|
||||
|
||||
def _ensure_list(value: Any) -> list[Any]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
if isinstance(value, tuple):
|
||||
return list(value)
|
||||
if isinstance(value, set):
|
||||
return list(value)
|
||||
return [value]
|
||||
|
||||
|
||||
def _resolve_field(field_name: str) -> Any | None:
|
||||
if field_name in FIELD_MAP:
|
||||
return FIELD_MAP[field_name]
|
||||
@@ -124,8 +106,57 @@ def _resolve_field(field_name: str) -> Any | None:
|
||||
return None
|
||||
|
||||
|
||||
def _build_message_conditions(
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
platform: str | None = None,
|
||||
message_id: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
start_time: float | None = None,
|
||||
end_time: float | None = None,
|
||||
before_time: float | None = None,
|
||||
after_time: float | None = None,
|
||||
) -> list[Any]:
|
||||
conditions: list[Any] = [Messages.message_id != "notice"]
|
||||
|
||||
if session_id is not None:
|
||||
conditions.append(Messages.session_id == session_id)
|
||||
if user_id is not None:
|
||||
conditions.append(Messages.user_id == user_id)
|
||||
if group_id is not None:
|
||||
conditions.append(Messages.group_id == group_id)
|
||||
if platform is not None:
|
||||
conditions.append(Messages.platform == platform)
|
||||
if message_id is not None:
|
||||
conditions.append(Messages.message_id == message_id)
|
||||
if reply_to is not None:
|
||||
conditions.append(Messages.reply_to == reply_to)
|
||||
if start_time is not None:
|
||||
conditions.append(Messages.timestamp >= _coerce_datetime(start_time))
|
||||
if end_time is not None:
|
||||
conditions.append(Messages.timestamp <= _coerce_datetime(end_time))
|
||||
if before_time is not None:
|
||||
conditions.append(Messages.timestamp < _coerce_datetime(before_time))
|
||||
if after_time is not None:
|
||||
conditions.append(Messages.timestamp > _coerce_datetime(after_time))
|
||||
|
||||
return conditions
|
||||
|
||||
|
||||
def find_messages(
|
||||
message_filter: dict[str, Any],
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
platform: str | None = None,
|
||||
message_id: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
start_time: float | None = None,
|
||||
end_time: float | None = None,
|
||||
before_time: float | None = None,
|
||||
after_time: float | None = None,
|
||||
sort: list[tuple[str, int]] | None = None,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
@@ -137,7 +168,16 @@ def find_messages(
|
||||
根据提供的过滤器、排序和限制条件查找消息。
|
||||
|
||||
Args:
|
||||
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
|
||||
session_id: 会话 ID 过滤。
|
||||
user_id: 用户 ID 过滤。
|
||||
group_id: 群 ID 过滤。
|
||||
platform: 平台过滤。
|
||||
message_id: 消息 ID 过滤。
|
||||
reply_to: 回复目标消息 ID 过滤。
|
||||
start_time: 起始时间,闭区间下界。
|
||||
end_time: 结束时间,闭区间上界。
|
||||
before_time: 严格早于该时间。
|
||||
after_time: 严格晚于该时间。
|
||||
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
|
||||
limit: 返回的最大文档数,0表示不限制。
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。
|
||||
@@ -146,37 +186,18 @@ def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
conditions: list[Any] = []
|
||||
if message_filter:
|
||||
for key, value in message_filter.items():
|
||||
field = _resolve_field(key)
|
||||
if field is None:
|
||||
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
for op, op_value in value.items():
|
||||
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
|
||||
if op == "$gt":
|
||||
conditions.append(field > coerced_value)
|
||||
elif op == "$lt":
|
||||
conditions.append(field < coerced_value)
|
||||
elif op == "$gte":
|
||||
conditions.append(field >= coerced_value)
|
||||
elif op == "$lte":
|
||||
conditions.append(field <= coerced_value)
|
||||
elif op == "$ne":
|
||||
conditions.append(field != coerced_value)
|
||||
elif op == "$in":
|
||||
conditions.append(field.in_(_ensure_list(coerced_value)))
|
||||
elif op == "$nin":
|
||||
conditions.append(field.not_in(_ensure_list(coerced_value)))
|
||||
else:
|
||||
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
|
||||
else:
|
||||
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
|
||||
conditions.append(field == coerced_value)
|
||||
|
||||
conditions.append(Messages.message_id != "notice")
|
||||
conditions = _build_message_conditions(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
platform=platform,
|
||||
message_id=message_id,
|
||||
reply_to=reply_to,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
before_time=before_time,
|
||||
after_time=after_time,
|
||||
)
|
||||
if filter_bot:
|
||||
conditions.append(Messages.user_id != global_config.bot.qq_account)
|
||||
if filter_command:
|
||||
@@ -218,60 +239,70 @@ def find_messages(
|
||||
return [_message_to_instance(msg) for msg in results]
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
"使用 SQLModel 查找消息失败 "
|
||||
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
|
||||
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
|
||||
f"before_time={before_time}, after_time={after_time}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
+ traceback.format_exc()
|
||||
)
|
||||
logger.error(log_message)
|
||||
return []
|
||||
|
||||
|
||||
def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
def count_messages(
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
platform: str | None = None,
|
||||
message_id: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
start_time: float | None = None,
|
||||
end_time: float | None = None,
|
||||
before_time: float | None = None,
|
||||
after_time: float | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
根据提供的过滤器计算消息数量。
|
||||
|
||||
Args:
|
||||
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
|
||||
session_id: 会话 ID 过滤。
|
||||
user_id: 用户 ID 过滤。
|
||||
group_id: 群 ID 过滤。
|
||||
platform: 平台过滤。
|
||||
message_id: 消息 ID 过滤。
|
||||
reply_to: 回复目标消息 ID 过滤。
|
||||
start_time: 起始时间,闭区间下界。
|
||||
end_time: 结束时间,闭区间上界。
|
||||
before_time: 严格早于该时间。
|
||||
after_time: 严格晚于该时间。
|
||||
|
||||
Returns:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
conditions: list[Any] = []
|
||||
if message_filter:
|
||||
for key, value in message_filter.items():
|
||||
field = _resolve_field(key)
|
||||
if field is None:
|
||||
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
for op, op_value in value.items():
|
||||
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
|
||||
if op == "$gt":
|
||||
conditions.append(field > coerced_value)
|
||||
elif op == "$lt":
|
||||
conditions.append(field < coerced_value)
|
||||
elif op == "$gte":
|
||||
conditions.append(field >= coerced_value)
|
||||
elif op == "$lte":
|
||||
conditions.append(field <= coerced_value)
|
||||
elif op == "$ne":
|
||||
conditions.append(field != coerced_value)
|
||||
elif op == "$in":
|
||||
conditions.append(field.in_(_ensure_list(coerced_value)))
|
||||
elif op == "$nin":
|
||||
conditions.append(field.not_in(_ensure_list(coerced_value)))
|
||||
else:
|
||||
logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
|
||||
else:
|
||||
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
|
||||
conditions.append(field == coerced_value)
|
||||
|
||||
conditions.append(Messages.message_id != "notice")
|
||||
conditions = _build_message_conditions(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
platform=platform,
|
||||
message_id=message_id,
|
||||
reply_to=reply_to,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
before_time=before_time,
|
||||
after_time=after_time,
|
||||
)
|
||||
statement = select(func.count()).select_from(Messages).where(*conditions)
|
||||
with get_db_session() as session:
|
||||
result = session.exec(statement).one()
|
||||
return int(result or 0)
|
||||
except Exception as e:
|
||||
log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
log_message = (
|
||||
"使用 SQLModel 计数消息失败 "
|
||||
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
|
||||
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
|
||||
f"before_time={before_time}, after_time={after_time}): {e}\n{traceback.format_exc()}"
|
||||
)
|
||||
logger.error(log_message)
|
||||
return 0
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""消息服务模块。"""
|
||||
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -17,20 +16,6 @@ from src.common.utils.utils_action import ActionUtils
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息查询函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _build_time_range_filter(start_time: float, end_time: float) -> dict[str, Any]:
|
||||
return {
|
||||
"time": {
|
||||
"$gte": start_time,
|
||||
"$lte": end_time,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _build_readable_line(
|
||||
message: SessionMessage,
|
||||
*,
|
||||
@@ -72,7 +57,8 @@ def get_messages_by_time(
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
messages = find_messages(
|
||||
message_filter=_build_time_range_filter(start_time, end_time),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_mai,
|
||||
@@ -99,10 +85,9 @@ def get_messages_by_time_in_chat(
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
messages = find_messages(
|
||||
message_filter={
|
||||
"chat_id": chat_id,
|
||||
**_build_time_range_filter(start_time, end_time),
|
||||
},
|
||||
session_id=chat_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
limit_mode=limit_mode,
|
||||
filter_bot=filter_mai,
|
||||
@@ -118,7 +103,7 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
messages = find_messages(
|
||||
message_filter={"time": {"$lt": timestamp}},
|
||||
before_time=timestamp,
|
||||
limit=limit,
|
||||
limit_mode="latest",
|
||||
filter_bot=filter_mai,
|
||||
@@ -142,10 +127,8 @@ def get_messages_before_time_in_chat(
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
messages = find_messages(
|
||||
message_filter={
|
||||
"chat_id": chat_id,
|
||||
"time": {"$lt": timestamp},
|
||||
},
|
||||
session_id=chat_id,
|
||||
before_time=timestamp,
|
||||
limit=limit,
|
||||
limit_mode="latest",
|
||||
filter_bot=filter_mai,
|
||||
@@ -166,13 +149,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
message_filter: dict[str, Any] = {
|
||||
"chat_id": chat_id,
|
||||
"time": {"$gt": start_time},
|
||||
}
|
||||
if end_time is not None:
|
||||
message_filter["time"]["$lte"] = end_time
|
||||
return count_messages(message_filter)
|
||||
return count_messages(session_id=chat_id, after_time=start_time, end_time=end_time)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
Reference in New Issue
Block a user