- Remove get_bot_account("qq") fallback from all 4 sender paths
(plan L108/L208/L449: unknown platform = no account, never substitute QQ)
- Sender paths now error immediately if platform bot account is not configured
- Add detailed comments on filter_bot legacy fallback explaining why
global user_id match is needed (plan contingency L528 insufficient for
platform-tagged legacy rows like telegram+qq_account)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
297 lines
11 KiB
Python
297 lines
11 KiB
Python
from datetime import datetime
|
||
from typing import Any
|
||
|
||
import json
|
||
import traceback
|
||
|
||
from sqlalchemy import and_, func, not_, or_
|
||
from sqlmodel import col, select
|
||
|
||
from src.chat.message_receive.message import SessionMessage
|
||
from src.common.database.database import get_db_session
|
||
from src.common.database.database_model import Messages
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
FIELD_MAP: dict[str, Any] = {
|
||
"time": Messages.timestamp,
|
||
"timestamp": Messages.timestamp,
|
||
"chat_id": Messages.session_id,
|
||
"session_id": Messages.session_id,
|
||
"user_id": Messages.user_id,
|
||
"message_id": Messages.message_id,
|
||
"group_id": Messages.group_id,
|
||
"platform": Messages.platform,
|
||
"is_command": Messages.is_command,
|
||
"is_mentioned": Messages.is_mentioned,
|
||
"is_at": Messages.is_at,
|
||
"is_emoji": Messages.is_emoji,
|
||
"is_picid": Messages.is_picture,
|
||
"is_picture": Messages.is_picture,
|
||
"reply_to": Messages.reply_to,
|
||
}
|
||
|
||
|
||
def _parse_additional_config(message: Messages) -> dict[str, Any]:
|
||
if not message.additional_config:
|
||
return {}
|
||
try:
|
||
parsed = json.loads(message.additional_config)
|
||
except (json.JSONDecodeError, TypeError):
|
||
return {}
|
||
if isinstance(parsed, dict):
|
||
return parsed
|
||
return {}
|
||
|
||
|
||
def _message_to_instance(message: Messages) -> SessionMessage:
|
||
return SessionMessage.from_db_instance(message)
|
||
|
||
|
||
def _coerce_datetime(value: Any) -> Any:
|
||
if isinstance(value, (int, float)):
|
||
return datetime.fromtimestamp(value)
|
||
return value
|
||
|
||
|
||
def _resolve_field(field_name: str) -> Any | None:
|
||
if field_name in FIELD_MAP:
|
||
return FIELD_MAP[field_name]
|
||
if hasattr(Messages, field_name):
|
||
return getattr(Messages, field_name)
|
||
return None
|
||
|
||
|
||
def _build_message_conditions(
|
||
*,
|
||
session_id: str | None = None,
|
||
user_id: str | None = None,
|
||
group_id: str | None = None,
|
||
platform: str | None = None,
|
||
message_id: str | None = None,
|
||
reply_to: str | None = None,
|
||
start_time: float | None = None,
|
||
end_time: float | None = None,
|
||
before_time: float | None = None,
|
||
after_time: float | None = None,
|
||
has_reply_to: bool | None = None,
|
||
) -> list[Any]:
|
||
conditions: list[Any] = [Messages.message_id != "notice"]
|
||
|
||
if session_id is not None:
|
||
conditions.append(Messages.session_id == session_id)
|
||
if user_id is not None:
|
||
conditions.append(Messages.user_id == user_id)
|
||
if group_id is not None:
|
||
conditions.append(Messages.group_id == group_id)
|
||
if platform is not None:
|
||
conditions.append(Messages.platform == platform)
|
||
if message_id is not None:
|
||
conditions.append(Messages.message_id == message_id)
|
||
if reply_to is not None:
|
||
conditions.append(Messages.reply_to == reply_to)
|
||
if start_time is not None:
|
||
conditions.append(Messages.timestamp >= _coerce_datetime(start_time))
|
||
if end_time is not None:
|
||
conditions.append(Messages.timestamp <= _coerce_datetime(end_time))
|
||
if before_time is not None:
|
||
conditions.append(Messages.timestamp < _coerce_datetime(before_time))
|
||
if after_time is not None:
|
||
conditions.append(Messages.timestamp > _coerce_datetime(after_time))
|
||
if has_reply_to is True:
|
||
conditions.append(col(Messages.reply_to).is_not(None))
|
||
elif has_reply_to is False:
|
||
conditions.append(col(Messages.reply_to).is_(None))
|
||
|
||
return conditions
|
||
|
||
|
||
def find_messages(
|
||
*,
|
||
session_id: str | None = None,
|
||
user_id: str | None = None,
|
||
group_id: str | None = None,
|
||
platform: str | None = None,
|
||
message_id: str | None = None,
|
||
reply_to: str | None = None,
|
||
start_time: float | None = None,
|
||
end_time: float | None = None,
|
||
before_time: float | None = None,
|
||
after_time: float | None = None,
|
||
sort: list[tuple[str, int]] | None = None,
|
||
limit: int = 0,
|
||
limit_mode: str = "latest",
|
||
filter_bot: bool = False,
|
||
filter_command: bool = False,
|
||
filter_intercept_message_level: int | None = None,
|
||
) -> list[SessionMessage]:
|
||
"""
|
||
根据提供的过滤器、排序和限制条件查找消息。
|
||
|
||
Args:
|
||
session_id: 会话 ID 过滤。
|
||
user_id: 用户 ID 过滤。
|
||
group_id: 群 ID 过滤。
|
||
platform: 平台过滤。
|
||
message_id: 消息 ID 过滤。
|
||
reply_to: 回复目标消息 ID 过滤。
|
||
start_time: 起始时间,闭区间下界。
|
||
end_time: 结束时间,闭区间上界。
|
||
before_time: 严格早于该时间。
|
||
after_time: 严格晚于该时间。
|
||
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
|
||
limit: 返回的最大文档数,0表示不限制。
|
||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。
|
||
|
||
Returns:
|
||
消息字典列表,如果出错则返回空列表。
|
||
"""
|
||
try:
|
||
conditions = _build_message_conditions(
|
||
session_id=session_id,
|
||
user_id=user_id,
|
||
group_id=group_id,
|
||
platform=platform,
|
||
message_id=message_id,
|
||
reply_to=reply_to,
|
||
start_time=start_time,
|
||
end_time=end_time,
|
||
before_time=before_time,
|
||
after_time=after_time,
|
||
)
|
||
if filter_bot:
|
||
from src.chat.utils.utils import get_all_bot_accounts, get_bot_account
|
||
|
||
bot_accounts = get_all_bot_accounts()
|
||
exclusion_conditions: list[Any] = []
|
||
if bot_accounts:
|
||
exclusion_conditions.append(
|
||
or_(
|
||
*[
|
||
and_(Messages.platform == platform_name, Messages.user_id == account)
|
||
for platform_name, account in bot_accounts.items()
|
||
]
|
||
)
|
||
)
|
||
|
||
# 兼容旧数据:历史机器人消息在所有平台上都使用 QQ 账号作为 user_id 存储,
|
||
# 例如旧 Telegram bot 消息的 (platform="telegram", user_id=qq_account)。
|
||
# plan 建议的 ("", qq_account) pair 只能覆盖空 platform 行,无法覆盖这种情况。
|
||
# 因此这里使用全局 user_id 匹配作为临时方案,待 DB 迁移后应移除此兜底。
|
||
if qq_fallback := get_bot_account("qq"):
|
||
exclusion_conditions.append(Messages.user_id == qq_fallback)
|
||
|
||
if exclusion_conditions:
|
||
conditions.append(not_(or_(*exclusion_conditions)))
|
||
if filter_command:
|
||
conditions.append(Messages.is_command == False) # noqa: E712
|
||
|
||
statement = select(Messages).where(*conditions)
|
||
if limit > 0:
|
||
if limit_mode == "earliest":
|
||
statement = statement.order_by(col(Messages.timestamp)).limit(limit)
|
||
with get_db_session() as session:
|
||
results = list(session.exec(statement).all())
|
||
else:
|
||
statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit)
|
||
with get_db_session() as session:
|
||
results = list(session.exec(statement).all())
|
||
results = list(reversed(results))
|
||
else:
|
||
if sort:
|
||
order_terms: list[Any] = []
|
||
for field_name, direction in sort:
|
||
sort_field = _resolve_field(field_name)
|
||
if sort_field is None:
|
||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||
continue
|
||
order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc())
|
||
if order_terms:
|
||
statement = statement.order_by(*order_terms)
|
||
with get_db_session() as session:
|
||
results = list(session.exec(statement).all())
|
||
|
||
if filter_intercept_message_level is not None:
|
||
filtered_results = []
|
||
for msg in results:
|
||
config = _parse_additional_config(msg)
|
||
if config.get("intercept_message_level", 0) <= filter_intercept_message_level:
|
||
filtered_results.append(msg)
|
||
results = filtered_results
|
||
|
||
return [_message_to_instance(msg) for msg in results]
|
||
except Exception as e:
|
||
log_message = (
|
||
"使用 SQLModel 查找消息失败 "
|
||
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
|
||
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
|
||
f"before_time={before_time}, after_time={after_time}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||
+ traceback.format_exc()
|
||
)
|
||
logger.error(log_message)
|
||
return []
|
||
|
||
|
||
def count_messages(
|
||
*,
|
||
session_id: str | None = None,
|
||
user_id: str | None = None,
|
||
group_id: str | None = None,
|
||
platform: str | None = None,
|
||
message_id: str | None = None,
|
||
reply_to: str | None = None,
|
||
start_time: float | None = None,
|
||
end_time: float | None = None,
|
||
before_time: float | None = None,
|
||
after_time: float | None = None,
|
||
has_reply_to: bool | None = None,
|
||
) -> int:
|
||
"""
|
||
根据提供的过滤器计算消息数量。
|
||
|
||
Args:
|
||
session_id: 会话 ID 过滤。
|
||
user_id: 用户 ID 过滤。
|
||
group_id: 群 ID 过滤。
|
||
platform: 平台过滤。
|
||
message_id: 消息 ID 过滤。
|
||
reply_to: 回复目标消息 ID 过滤。
|
||
start_time: 起始时间,闭区间下界。
|
||
end_time: 结束时间,闭区间上界。
|
||
before_time: 严格早于该时间。
|
||
after_time: 严格晚于该时间。
|
||
has_reply_to: 是否要求存在 reply_to 字段。
|
||
|
||
Returns:
|
||
符合条件的消息数量,如果出错则返回 0。
|
||
"""
|
||
try:
|
||
conditions = _build_message_conditions(
|
||
session_id=session_id,
|
||
user_id=user_id,
|
||
group_id=group_id,
|
||
platform=platform,
|
||
message_id=message_id,
|
||
reply_to=reply_to,
|
||
start_time=start_time,
|
||
end_time=end_time,
|
||
before_time=before_time,
|
||
after_time=after_time,
|
||
has_reply_to=has_reply_to,
|
||
)
|
||
statement = select(func.count()).select_from(Messages).where(*conditions)
|
||
with get_db_session() as session:
|
||
result = session.exec(statement).one()
|
||
return int(result or 0)
|
||
except Exception as e:
|
||
log_message = (
|
||
"使用 SQLModel 计数消息失败 "
|
||
f"(session_id={session_id}, user_id={user_id}, group_id={group_id}, platform={platform}, "
|
||
f"message_id={message_id}, reply_to={reply_to}, start_time={start_time}, end_time={end_time}, "
|
||
f"before_time={before_time}, after_time={after_time}, has_reply_to={has_reply_to}): {e}\n{traceback.format_exc()}"
|
||
)
|
||
logger.error(log_message)
|
||
return 0
|