278 lines
11 KiB
Python
278 lines
11 KiB
Python
import traceback
|
||
from datetime import datetime
|
||
from types import SimpleNamespace
|
||
from typing import Any
|
||
|
||
import json
|
||
|
||
from sqlalchemy import func
|
||
from sqlmodel import col, select
|
||
|
||
from src.common.database.database import get_db_session
|
||
from src.common.database.database_model import Messages
|
||
from src.chat.message_receive.message import SessionMessage
|
||
from src.common.logger import get_logger
|
||
from src.config.config import global_config
|
||
|
||
logger = get_logger(__name__)


# Maps the key names accepted in filter/sort arguments to columns of the
# Messages model. Several keys are aliases that resolve to the same column.
FIELD_MAP: dict[str, Any] = {
    # Timestamp column (two spellings).
    "time": Messages.timestamp,
    "timestamp": Messages.timestamp,
    # Session column (two spellings).
    "chat_id": Messages.session_id,
    "session_id": Messages.session_id,
    # Identity columns.
    "user_id": Messages.user_id,
    "message_id": Messages.message_id,
    "group_id": Messages.group_id,
    "platform": Messages.platform,
    # Boolean flag columns.
    "is_command": Messages.is_command,
    "is_mentioned": Messages.is_mentioned,
    "is_at": Messages.is_at,
    "is_emoji": Messages.is_emoji,
    # Picture flag column (two spellings).
    "is_picid": Messages.is_picture,
    "is_picture": Messages.is_picture,
    "reply_to": Messages.reply_to,
}
|
||
|
||
|
||
def _parse_additional_config(message: Messages) -> dict[str, Any]:
    """
    Decode the JSON stored in ``message.additional_config``.

    Args:
        message: Object exposing an ``additional_config`` attribute.

    Returns:
        The decoded JSON object when it is a dict; an empty dict when the
        attribute is falsy, cannot be decoded, or decodes to a non-dict value.
    """
    raw_config = message.additional_config
    if not raw_config:
        return {}

    try:
        decoded = json.loads(raw_config)
    except (json.JSONDecodeError, TypeError):
        # Malformed text or a non-string payload: treat as "no config".
        return {}

    return decoded if isinstance(decoded, dict) else {}
|
||
|
||
|
||
def _normalize_optional_str(value: object) -> str | None:
|
||
if value is None:
|
||
return None
|
||
if isinstance(value, str):
|
||
return value
|
||
try:
|
||
return json.dumps(value, ensure_ascii=False)
|
||
except (TypeError, ValueError):
|
||
return str(value)
|
||
|
||
|
||
def _message_to_instance(message: Messages) -> SessionMessage:
    """
    Build a SessionMessage from a Messages row and decorate it with extra attributes.

    The instance is created by ``SessionMessage.from_db_instance``. Values found in
    the row's ``additional_config`` JSON are then copied onto it, and three extra
    attributes (``user_info``, ``chat_info``, ``time``) are attached.
    NOTE(review): the extra attributes look like a compatibility layer for callers
    using an older attribute layout - confirm against callers.

    Args:
        message: The database row to convert.

    Returns:
        The decorated SessionMessage.
    """
    extras = _parse_additional_config(message)
    instance = SessionMessage.from_db_instance(message)

    # Values stored in additional_config; missing keys become None
    # (intercept_message_level defaults to 0 instead).
    instance.interest_value = extras.get("interest_value")
    instance.key_words = _normalize_optional_str(extras.get("key_words"))
    instance.key_words_lite = _normalize_optional_str(extras.get("key_words_lite"))
    instance.reply_probability_boost = extras.get("reply_probability_boost")
    instance.priority_mode = _normalize_optional_str(extras.get("priority_mode"))
    instance.priority_info = _normalize_optional_str(extras.get("priority_info"))
    instance.intercept_message_level = extras.get("intercept_message_level", 0)
    instance.selected_expressions = _normalize_optional_str(extras.get("selected_expressions"))

    # Flattened group view; stays None when the message carries no group info.
    source_group = instance.message_info.group_info
    group_view = (
        SimpleNamespace(group_id=source_group.group_id, group_name=source_group.group_name)
        if source_group
        else None
    )

    instance.user_info = SimpleNamespace(
        user_id=instance.message_info.user_info.user_id,
        user_nickname=instance.message_info.user_info.user_nickname,
        user_cardname=instance.message_info.user_info.user_cardname,
        platform=instance.platform,
    )
    instance.chat_info = SimpleNamespace(
        platform=instance.platform,
        stream_id=instance.session_id,
        group_info=group_view,
    )

    # Numeric form of the timestamp (assumes instance.timestamp is a datetime - TODO confirm).
    instance.time = instance.timestamp.timestamp()
    return instance
|
||
|
||
|
||
def _coerce_datetime(value: Any) -> Any:
    """
    Convert numeric epoch values into ``datetime`` objects.

    Args:
        value: A filter operand. May be a number, a collection of operands
            (as used with ``$in`` / ``$nin``), or anything else.

    Returns:
        - ``datetime.fromtimestamp(value)`` for an int or float;
        - a new list with every element coerced the same way for a
          list, tuple, set or frozenset;
        - the value unchanged otherwise (e.g. an existing datetime or None).
    """
    if isinstance(value, (int, float)):
        return datetime.fromtimestamp(value)
    if isinstance(value, (list, tuple, set, frozenset)):
        # Collection operands must be coerced element-wise, otherwise raw
        # numbers would be compared against the datetime column.
        return [_coerce_datetime(item) for item in value]
    return value
|
||
|
||
|
||
def _cast_value_for_field(field: Any, value: Any) -> Any:
    """
    Adapt a filter operand to the column it will be compared with.

    Args:
        field: The resolved model column.
        value: The operand supplied by the caller.

    Returns:
        ``_coerce_datetime(value)`` when ``field`` is ``Messages.timestamp``;
        otherwise ``value`` unchanged.
    """
    return _coerce_datetime(value) if field is Messages.timestamp else value
|
||
|
||
|
||
def _ensure_list(value: Any) -> list[Any]:
    """
    Normalise an operand into a list.

    Args:
        value: None, a list, a tuple/set/frozenset, or a single value.

    Returns:
        - ``[]`` for None;
        - the same list object when ``value`` is already a list (no copy);
        - a new list with the elements of a tuple, set or frozenset;
        - a single-element list wrapping any other value.
    """
    if value is None:
        return []
    if isinstance(value, list):
        return value
    if isinstance(value, (tuple, set, frozenset)):
        # frozenset is expanded like the other collections instead of being
        # wrapped as a single element.
        return list(value)
    return [value]
|
||
|
||
|
||
def _resolve_field(field_name: str) -> Any | None:
    """
    Look up the Messages attribute addressed by a filter/sort key.

    Args:
        field_name: A key from ``FIELD_MAP`` or an attribute name of ``Messages``.

    Returns:
        The ``FIELD_MAP`` entry when the key is listed there; otherwise the
        attribute of that name on ``Messages``; ``None`` when neither exists.
    """
    try:
        return FIELD_MAP[field_name]
    except KeyError:
        # Not an alias: fall back to a direct attribute lookup on the model.
        return getattr(Messages, field_name, None)
|
||
|
||
|
||
def _build_filter_conditions(message_filter: dict[str, Any] | None) -> list[Any]:
    """Translate a filter dict into a list of SQL conditions; unknown keys/operators are logged and skipped."""
    conditions: list[Any] = []
    if not message_filter:
        return conditions

    for key, value in message_filter.items():
        field = _resolve_field(key)
        if field is None:
            logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
            continue

        if not isinstance(value, dict):
            # A bare value means equality.
            conditions.append(field == _cast_value_for_field(field, value))
            continue

        for op, op_value in value.items():
            operand = _cast_value_for_field(field, op_value)
            if op == "$gt":
                conditions.append(field > operand)
            elif op == "$lt":
                conditions.append(field < operand)
            elif op == "$gte":
                conditions.append(field >= operand)
            elif op == "$lte":
                conditions.append(field <= operand)
            elif op == "$ne":
                conditions.append(field != operand)
            elif op == "$in":
                conditions.append(field.in_(_ensure_list(operand)))
            elif op == "$nin":
                conditions.append(field.not_in(_ensure_list(operand)))
            else:
                logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
    return conditions


def _stored_intercept_level(message: Messages) -> int | float:
    """Return the intercept level stored in additional_config, or 0 when it is missing or not numeric."""
    raw_level = _parse_additional_config(message).get("intercept_message_level", 0)
    if isinstance(raw_level, (int, float)):
        return raw_level
    try:
        return int(raw_level)
    except (TypeError, ValueError):
        # None or a non-numeric value: fall back to the default level so a
        # single malformed row cannot abort the whole query.
        return 0


def find_messages(
    message_filter: dict[str, Any],
    sort: list[tuple[str, int]] | None = None,
    limit: int = 0,
    limit_mode: str = "latest",
    filter_bot: bool = False,
    filter_command: bool = False,
    filter_intercept_message_level: int | None = None,
) -> list[SessionMessage]:
    """
    Find messages matching a filter, with optional ordering and limiting.

    Args:
        message_filter: Mapping of field name to either an expected value (equality)
            or a dict of operators to operands, e.g. {'$gt': value}. Supported
            operators: $gt, $lt, $gte, $lte, $ne, $in, $nin.
        sort: Ordering terms, e.g. [('time', 1)] (1 = ascending, anything else =
            descending). Only applied when limit is not positive.
        limit: Maximum number of rows to fetch; 0 (or less) means no limit.
        limit_mode: Only used when limit > 0. 'earliest' fetches the oldest rows;
            any other value fetches the newest rows. Either way the result is in
            ascending timestamp order. Defaults to 'latest'.
        filter_bot: When true, exclude rows whose user_id equals the configured
            bot account.
        filter_command: When true, exclude rows flagged as commands.
        filter_intercept_message_level: When not None, keep only rows whose stored
            intercept_message_level (default 0) is <= this value.
            NOTE: this filter runs after the query, so fewer than ``limit`` rows
            may be returned.

    Returns:
        A list of SessionMessage objects; an empty list if any error occurs.
        Rows whose message_id is "notice" are always excluded.
    """
    try:
        conditions = _build_filter_conditions(message_filter)
        conditions.append(Messages.message_id != "notice")
        if filter_bot:
            conditions.append(Messages.user_id != global_config.bot.qq_account)
        if filter_command:
            conditions.append(Messages.is_command == False)  # noqa: E712

        statement = select(Messages).where(*conditions)
        fetched_newest_first = False
        if limit > 0:
            if limit_mode == "earliest":
                statement = statement.order_by(col(Messages.timestamp)).limit(limit)
            else:
                # Fetch the newest rows first, then flip them back to ascending order below.
                statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit)
                fetched_newest_first = True
        elif sort:
            order_terms: list[Any] = []
            for field_name, direction in sort:
                sort_field = _resolve_field(field_name)
                if sort_field is None:
                    logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
                    continue
                order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc())
            if order_terms:
                statement = statement.order_by(*order_terms)

        with get_db_session() as session:
            results = list(session.exec(statement).all())
        if fetched_newest_first:
            results.reverse()

        if filter_intercept_message_level is not None:
            results = [
                msg for msg in results if _stored_intercept_level(msg) <= filter_intercept_message_level
            ]

        return [_message_to_instance(msg) for msg in results]
    except Exception as e:
        log_message = (
            f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
            + traceback.format_exc()
        )
        logger.error(log_message)
        return []
|
||
|
||
|
||
def count_messages(message_filter: dict[str, Any]) -> int:
    """
    Count the messages matching a filter.

    Args:
        message_filter: Mapping of field name to either an expected value (equality)
            or a dict of operators to operands, e.g. {'$gt': value}. Supported
            operators: $gt, $lt, $gte, $lte, $ne, $in, $nin. Unknown keys and
            operators are logged and skipped.

    Returns:
        The number of matching rows (rows whose message_id is "notice" are always
        excluded); 0 if any error occurs.
    """
    # Operator token -> function building the SQL expression from (column, operand).
    builders = {
        "$gt": lambda column, operand: column > operand,
        "$lt": lambda column, operand: column < operand,
        "$gte": lambda column, operand: column >= operand,
        "$lte": lambda column, operand: column <= operand,
        "$ne": lambda column, operand: column != operand,
        "$in": lambda column, operand: column.in_(_ensure_list(operand)),
        "$nin": lambda column, operand: column.not_in(_ensure_list(operand)),
    }

    try:
        clauses: list[Any] = []
        for key, expected in (message_filter or {}).items():
            column = _resolve_field(key)
            if column is None:
                logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
                continue

            is_time_column = column is Messages.timestamp
            if not isinstance(expected, dict):
                # A bare value means equality.
                operand = _coerce_datetime(expected) if is_time_column else expected
                clauses.append(column == operand)
                continue

            for op, raw_operand in expected.items():
                # Coerce before the operator lookup, as the operand is converted
                # regardless of whether the operator is known.
                operand = _coerce_datetime(raw_operand) if is_time_column else raw_operand
                build = builders.get(op)
                if build is None:
                    logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
                    continue
                clauses.append(build(column, operand))

        clauses.append(Messages.message_id != "notice")
        counting = select(func.count()).select_from(Messages).where(*clauses)
        with get_db_session() as session:
            total = session.exec(counting).one()
            return int(total or 0)
    except Exception as e:
        log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
        logger.error(log_message)
        return 0
|