移除所有 MongoDB 风格 filter
This commit is contained in:
@@ -98,24 +98,6 @@ def _coerce_datetime(value: Any) -> Any:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _cast_value_for_field(field: Any, value: Any) -> Any:
|
|
||||||
if field is Messages.timestamp:
|
|
||||||
return _coerce_datetime(value)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_list(value: Any) -> list[Any]:
|
|
||||||
if value is None:
|
|
||||||
return []
|
|
||||||
if isinstance(value, list):
|
|
||||||
return value
|
|
||||||
if isinstance(value, tuple):
|
|
||||||
return list(value)
|
|
||||||
if isinstance(value, set):
|
|
||||||
return list(value)
|
|
||||||
return [value]
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_field(field_name: str) -> Any | None:
|
def _resolve_field(field_name: str) -> Any | None:
|
||||||
if field_name in FIELD_MAP:
|
if field_name in FIELD_MAP:
|
||||||
return FIELD_MAP[field_name]
|
return FIELD_MAP[field_name]
|
||||||
@@ -124,8 +106,57 @@ def _resolve_field(field_name: str) -> Any | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_message_conditions(
|
||||||
|
*,
|
||||||
|
session_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
platform: str | None = None,
|
||||||
|
message_id: str | None = None,
|
||||||
|
reply_to: str | None = None,
|
||||||
|
start_time: float | None = None,
|
||||||
|
end_time: float | None = None,
|
||||||
|
before_time: float | None = None,
|
||||||
|
after_time: float | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
conditions: list[Any] = [Messages.message_id != "notice"]
|
||||||
|
|
||||||
|
if session_id is not None:
|
||||||
|
conditions.append(Messages.session_id == session_id)
|
||||||
|
if user_id is not None:
|
||||||
|
conditions.append(Messages.user_id == user_id)
|
||||||
|
if group_id is not None:
|
||||||
|
conditions.append(Messages.group_id == group_id)
|
||||||
|
if platform is not None:
|
||||||
|
conditions.append(Messages.platform == platform)
|
||||||
|
if message_id is not None:
|
||||||
|
conditions.append(Messages.message_id == message_id)
|
||||||
|
if reply_to is not None:
|
||||||
|
conditions.append(Messages.reply_to == reply_to)
|
||||||
|
if start_time is not None:
|
||||||
|
conditions.append(Messages.timestamp >= _coerce_datetime(start_time))
|
||||||
|
if end_time is not None:
|
||||||
|
conditions.append(Messages.timestamp <= _coerce_datetime(end_time))
|
||||||
|
if before_time is not None:
|
||||||
|
conditions.append(Messages.timestamp < _coerce_datetime(before_time))
|
||||||
|
if after_time is not None:
|
||||||
|
conditions.append(Messages.timestamp > _coerce_datetime(after_time))
|
||||||
|
|
||||||
|
return conditions
|
||||||
|
|
||||||
|
|
||||||
def find_messages(
|
def find_messages(
|
||||||
message_filter: dict[str, Any],
|
*,
|
||||||
|
session_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
platform: str | None = None,
|
||||||
|
message_id: str | None = None,
|
||||||
|
reply_to: str | None = None,
|
||||||
|
start_time: float | None = None,
|
||||||
|
end_time: float | None = None,
|
||||||
|
before_time: float | None = None,
|
||||||
|
after_time: float | None = None,
|
||||||
sort: list[tuple[str, int]] | None = None,
|
sort: list[tuple[str, int]] | None = None,
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
@@ -137,7 +168,16 @@ def find_messages(
|
|||||||
根据提供的过滤器、排序和限制条件查找消息。
|
根据提供的过滤器、排序和限制条件查找消息。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
|
session_id: 会话 ID 过滤。
|
||||||
|
user_id: 用户 ID 过滤。
|
||||||
|
group_id: 群 ID 过滤。
|
||||||
|
platform: 平台过滤。
|
||||||
|
message_id: 消息 ID 过滤。
|
||||||
|
reply_to: 回复目标消息 ID 过滤。
|
||||||
|
start_time: 起始时间,闭区间下界。
|
||||||
|
end_time: 结束时间,闭区间上界。
|
||||||
|
before_time: 严格早于该时间。
|
||||||
|
after_time: 严格晚于该时间。
|
||||||
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
|
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
|
||||||
limit: 返回的最大文档数,0表示不限制。
|
limit: 返回的最大文档数,0表示不限制。
|
||||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。
|
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。
|
||||||
@@ -146,37 +186,18 @@ def find_messages(
|
|||||||
消息字典列表,如果出错则返回空列表。
|
消息字典列表,如果出错则返回空列表。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
conditions: list[Any] = []
|
conditions = _build_message_conditions(
|
||||||
if message_filter:
|
session_id=session_id,
|
||||||
for key, value in message_filter.items():
|
user_id=user_id,
|
||||||
field = _resolve_field(key)
|
group_id=group_id,
|
||||||
if field is None:
|
platform=platform,
|
||||||
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
message_id=message_id,
|
||||||
continue
|
reply_to=reply_to,
|
||||||
if isinstance(value, dict):
|
start_time=start_time,
|
||||||
for op, op_value in value.items():
|
end_time=end_time,
|
||||||
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
|
before_time=before_time,
|
||||||
if op == "$gt":
|
after_time=after_time,
|
||||||
conditions.append(field > coerced_value)
|
)
|
||||||
elif op == "$lt":
|
|
||||||
conditions.append(field < coerced_value)
|
|
||||||
elif op == "$gte":
|
|
||||||
conditions.append(field >= coerced_value)
|
|
||||||
elif op == "$lte":
|
|
||||||
conditions.append(field <= coerced_value)
|
|
||||||
elif op == "$ne":
|
|
||||||
conditions.append(field != coerced_value)
|
|
||||||
elif op == "$in":
|
|
||||||
conditions.append(field.in_(_ensure_list(coerced_value)))
|
|
||||||
elif op == "$nin":
|
|
||||||
conditions.append(field.not_in(_ensure_list(coerced_value)))
|
|
||||||
else:
|
|
||||||
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
|
|
||||||
else:
|
|
||||||
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
|
|
||||||
conditions.append(field == coerced_value)
|
|
||||||
|
|
||||||
conditions.append(Messages.message_id != "notice")
|
|
||||||
if filter_bot:
|
if filter_bot:
|
||||||
conditions.append(Messages.user_id != global_config.bot.qq_account)
|
conditions.append(Messages.user_id != global_config.bot.qq_account)
|
||||||
if filter_command:
|
if filter_command:
|
||||||
@@ -218,60 +239,70 @@ def find_messages(
|
|||||||
return [_message_to_instance(msg) for msg in results]
|
return [_message_to_instance(msg) for msg in results]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = (
|
log_message = (
|
||||||
f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
"使用 SQLModel 查找消息失败 "
|
||||||
|
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
|
||||||
|
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
|
||||||
|
f"before_time={before_time}, after_time={after_time}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||||
+ traceback.format_exc()
|
+ traceback.format_exc()
|
||||||
)
|
)
|
||||||
logger.error(log_message)
|
logger.error(log_message)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def count_messages(message_filter: dict[str, Any]) -> int:
|
def count_messages(
|
||||||
|
*,
|
||||||
|
session_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
platform: str | None = None,
|
||||||
|
message_id: str | None = None,
|
||||||
|
reply_to: str | None = None,
|
||||||
|
start_time: float | None = None,
|
||||||
|
end_time: float | None = None,
|
||||||
|
before_time: float | None = None,
|
||||||
|
after_time: float | None = None,
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
根据提供的过滤器计算消息数量。
|
根据提供的过滤器计算消息数量。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
|
session_id: 会话 ID 过滤。
|
||||||
|
user_id: 用户 ID 过滤。
|
||||||
|
group_id: 群 ID 过滤。
|
||||||
|
platform: 平台过滤。
|
||||||
|
message_id: 消息 ID 过滤。
|
||||||
|
reply_to: 回复目标消息 ID 过滤。
|
||||||
|
start_time: 起始时间,闭区间下界。
|
||||||
|
end_time: 结束时间,闭区间上界。
|
||||||
|
before_time: 严格早于该时间。
|
||||||
|
after_time: 严格晚于该时间。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
符合条件的消息数量,如果出错则返回 0。
|
符合条件的消息数量,如果出错则返回 0。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
conditions: list[Any] = []
|
conditions = _build_message_conditions(
|
||||||
if message_filter:
|
session_id=session_id,
|
||||||
for key, value in message_filter.items():
|
user_id=user_id,
|
||||||
field = _resolve_field(key)
|
group_id=group_id,
|
||||||
if field is None:
|
platform=platform,
|
||||||
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
message_id=message_id,
|
||||||
continue
|
reply_to=reply_to,
|
||||||
if isinstance(value, dict):
|
start_time=start_time,
|
||||||
for op, op_value in value.items():
|
end_time=end_time,
|
||||||
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
|
before_time=before_time,
|
||||||
if op == "$gt":
|
after_time=after_time,
|
||||||
conditions.append(field > coerced_value)
|
)
|
||||||
elif op == "$lt":
|
|
||||||
conditions.append(field < coerced_value)
|
|
||||||
elif op == "$gte":
|
|
||||||
conditions.append(field >= coerced_value)
|
|
||||||
elif op == "$lte":
|
|
||||||
conditions.append(field <= coerced_value)
|
|
||||||
elif op == "$ne":
|
|
||||||
conditions.append(field != coerced_value)
|
|
||||||
elif op == "$in":
|
|
||||||
conditions.append(field.in_(_ensure_list(coerced_value)))
|
|
||||||
elif op == "$nin":
|
|
||||||
conditions.append(field.not_in(_ensure_list(coerced_value)))
|
|
||||||
else:
|
|
||||||
logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
|
|
||||||
else:
|
|
||||||
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
|
|
||||||
conditions.append(field == coerced_value)
|
|
||||||
|
|
||||||
conditions.append(Messages.message_id != "notice")
|
|
||||||
statement = select(func.count()).select_from(Messages).where(*conditions)
|
statement = select(func.count()).select_from(Messages).where(*conditions)
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
result = session.exec(statement).one()
|
result = session.exec(statement).one()
|
||||||
return int(result or 0)
|
return int(result or 0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
log_message = (
|
||||||
|
"使用 SQLModel 计数消息失败 "
|
||||||
|
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
|
||||||
|
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
|
||||||
|
f"before_time={before_time}, after_time={after_time}): {e}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
logger.error(log_message)
|
logger.error(log_message)
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
"""消息服务模块。"""
|
"""消息服务模块。"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
|
|
||||||
@@ -17,20 +16,6 @@ from src.common.utils.utils_action import ActionUtils
|
|||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 消息查询函数
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def _build_time_range_filter(start_time: float, end_time: float) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"time": {
|
|
||||||
"$gte": start_time,
|
|
||||||
"$lte": end_time,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _build_readable_line(
|
def _build_readable_line(
|
||||||
message: SessionMessage,
|
message: SessionMessage,
|
||||||
*,
|
*,
|
||||||
@@ -72,7 +57,8 @@ def get_messages_by_time(
|
|||||||
if limit < 0:
|
if limit < 0:
|
||||||
raise ValueError("limit 不能为负数")
|
raise ValueError("limit 不能为负数")
|
||||||
messages = find_messages(
|
messages = find_messages(
|
||||||
message_filter=_build_time_range_filter(start_time, end_time),
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
limit_mode=limit_mode,
|
limit_mode=limit_mode,
|
||||||
filter_bot=filter_mai,
|
filter_bot=filter_mai,
|
||||||
@@ -99,10 +85,9 @@ def get_messages_by_time_in_chat(
|
|||||||
if not isinstance(chat_id, str):
|
if not isinstance(chat_id, str):
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
raise ValueError("chat_id 必须是字符串类型")
|
||||||
messages = find_messages(
|
messages = find_messages(
|
||||||
message_filter={
|
session_id=chat_id,
|
||||||
"chat_id": chat_id,
|
start_time=start_time,
|
||||||
**_build_time_range_filter(start_time, end_time),
|
end_time=end_time,
|
||||||
},
|
|
||||||
limit=limit,
|
limit=limit,
|
||||||
limit_mode=limit_mode,
|
limit_mode=limit_mode,
|
||||||
filter_bot=filter_mai,
|
filter_bot=filter_mai,
|
||||||
@@ -118,7 +103,7 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
|
|||||||
if limit < 0:
|
if limit < 0:
|
||||||
raise ValueError("limit 不能为负数")
|
raise ValueError("limit 不能为负数")
|
||||||
messages = find_messages(
|
messages = find_messages(
|
||||||
message_filter={"time": {"$lt": timestamp}},
|
before_time=timestamp,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
limit_mode="latest",
|
limit_mode="latest",
|
||||||
filter_bot=filter_mai,
|
filter_bot=filter_mai,
|
||||||
@@ -142,10 +127,8 @@ def get_messages_before_time_in_chat(
|
|||||||
if not isinstance(chat_id, str):
|
if not isinstance(chat_id, str):
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
raise ValueError("chat_id 必须是字符串类型")
|
||||||
messages = find_messages(
|
messages = find_messages(
|
||||||
message_filter={
|
session_id=chat_id,
|
||||||
"chat_id": chat_id,
|
before_time=timestamp,
|
||||||
"time": {"$lt": timestamp},
|
|
||||||
},
|
|
||||||
limit=limit,
|
limit=limit,
|
||||||
limit_mode="latest",
|
limit_mode="latest",
|
||||||
filter_bot=filter_mai,
|
filter_bot=filter_mai,
|
||||||
@@ -166,13 +149,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
|
|||||||
raise ValueError("chat_id 不能为空")
|
raise ValueError("chat_id 不能为空")
|
||||||
if not isinstance(chat_id, str):
|
if not isinstance(chat_id, str):
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
raise ValueError("chat_id 必须是字符串类型")
|
||||||
message_filter: dict[str, Any] = {
|
return count_messages(session_id=chat_id, after_time=start_time, end_time=end_time)
|
||||||
"chat_id": chat_id,
|
|
||||||
"time": {"$gt": start_time},
|
|
||||||
}
|
|
||||||
if end_time is not None:
|
|
||||||
message_filter["time"]["$lte"] = end_time
|
|
||||||
return count_messages(message_filter)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user