import traceback from datetime import datetime from types import SimpleNamespace from typing import Any import json from sqlalchemy import func from sqlmodel import col, select from src.common.database.database import get_db_session from src.common.database.database_model import Messages from src.chat.message_receive.message import SessionMessage from src.common.logger import get_logger from src.config.config import global_config logger = get_logger(__name__) FIELD_MAP: dict[str, Any] = { "time": Messages.timestamp, "timestamp": Messages.timestamp, "chat_id": Messages.session_id, "session_id": Messages.session_id, "user_id": Messages.user_id, "message_id": Messages.message_id, "group_id": Messages.group_id, "platform": Messages.platform, "is_command": Messages.is_command, "is_mentioned": Messages.is_mentioned, "is_at": Messages.is_at, "is_emoji": Messages.is_emoji, "is_picid": Messages.is_picture, "is_picture": Messages.is_picture, "reply_to": Messages.reply_to, } def _parse_additional_config(message: Messages) -> dict[str, Any]: if not message.additional_config: return {} try: parsed = json.loads(message.additional_config) except (json.JSONDecodeError, TypeError): return {} if isinstance(parsed, dict): return parsed return {} def _normalize_optional_str(value: object) -> str | None: if value is None: return None if isinstance(value, str): return value try: return json.dumps(value, ensure_ascii=False) except (TypeError, ValueError): return str(value) def _message_to_instance(message: Messages) -> SessionMessage: config = _parse_additional_config(message) instance = SessionMessage.from_db_instance(message) instance.interest_value = config.get("interest_value") instance.key_words = _normalize_optional_str(config.get("key_words")) instance.key_words_lite = _normalize_optional_str(config.get("key_words_lite")) instance.reply_probability_boost = config.get("reply_probability_boost") instance.priority_mode = _normalize_optional_str(config.get("priority_mode")) instance.priority_info = _normalize_optional_str(config.get("priority_info")) instance.intercept_message_level = config.get("intercept_message_level", 0) instance.selected_expressions = _normalize_optional_str(config.get("selected_expressions")) group_info = instance.message_info.group_info legacy_group_info = None if group_info: legacy_group_info = SimpleNamespace( group_id=group_info.group_id, group_name=group_info.group_name, ) instance.user_info = SimpleNamespace( user_id=instance.message_info.user_info.user_id, user_nickname=instance.message_info.user_info.user_nickname, user_cardname=instance.message_info.user_info.user_cardname, platform=instance.platform, ) instance.chat_info = SimpleNamespace( platform=instance.platform, stream_id=instance.session_id, group_info=legacy_group_info, ) instance.time = instance.timestamp.timestamp() return instance def _coerce_datetime(value: Any) -> Any: if isinstance(value, (int, float)): return datetime.fromtimestamp(value) return value def _cast_value_for_field(field: Any, value: Any) -> Any: if field is Messages.timestamp: return _coerce_datetime(value) return value def _ensure_list(value: Any) -> list[Any]: if value is None: return [] if isinstance(value, list): return value if isinstance(value, tuple): return list(value) if isinstance(value, set): return list(value) return [value] def _resolve_field(field_name: str) -> Any | None: if field_name in FIELD_MAP: return FIELD_MAP[field_name] if hasattr(Messages, field_name): return getattr(Messages, field_name) return None def find_messages( message_filter: dict[str, Any], sort: list[tuple[str, int]] | None = None, limit: int = 0, limit_mode: str = "latest", filter_bot: bool = False, filter_command: bool = False, filter_intercept_message_level: int | None = None, ) -> list[SessionMessage]: """ 根据提供的过滤器、排序和限制条件查找消息。 Args: message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。 limit: 返回的最大文档数,0表示不限制。 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。 Returns: 消息字典列表,如果出错则返回空列表。 """ try: conditions: list[Any] = [] if message_filter: for key, value in message_filter.items(): field = _resolve_field(key) if field is None: logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") continue if isinstance(value, dict): for op, op_value in value.items(): coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value if op == "$gt": conditions.append(field > coerced_value) elif op == "$lt": conditions.append(field < coerced_value) elif op == "$gte": conditions.append(field >= coerced_value) elif op == "$lte": conditions.append(field <= coerced_value) elif op == "$ne": conditions.append(field != coerced_value) elif op == "$in": conditions.append(field.in_(_ensure_list(coerced_value))) elif op == "$nin": conditions.append(field.not_in(_ensure_list(coerced_value))) else: logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") else: coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value conditions.append(field == coerced_value) conditions.append(Messages.message_id != "notice") if filter_bot: conditions.append(Messages.user_id != global_config.bot.qq_account) if filter_command: conditions.append(Messages.is_command == False) # noqa: E712 statement = select(Messages).where(*conditions) if limit > 0: if limit_mode == "earliest": statement = statement.order_by(col(Messages.timestamp)).limit(limit) with get_db_session() as session: results = list(session.exec(statement).all()) else: statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit) with get_db_session() as session: results = list(session.exec(statement).all()) results = list(reversed(results)) else: if sort: order_terms: list[Any] = [] for field_name, direction in sort: sort_field = _resolve_field(field_name) if sort_field is None: logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") continue order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc()) if order_terms: statement = statement.order_by(*order_terms) with get_db_session() as session: results = list(session.exec(statement).all()) if filter_intercept_message_level is not None: filtered_results = [] for msg in results: config = _parse_additional_config(msg) if config.get("intercept_message_level", 0) <= filter_intercept_message_level: filtered_results.append(msg) results = filtered_results return [_message_to_instance(msg) for msg in results] except Exception as e: log_message = ( f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + traceback.format_exc() ) logger.error(log_message) return [] def count_messages(message_filter: dict[str, Any]) -> int: """ 根据提供的过滤器计算消息数量。 Args: message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). Returns: 符合条件的消息数量,如果出错则返回 0。 """ try: conditions: list[Any] = [] if message_filter: for key, value in message_filter.items(): field = _resolve_field(key) if field is None: logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") continue if isinstance(value, dict): for op, op_value in value.items(): coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value if op == "$gt": conditions.append(field > coerced_value) elif op == "$lt": conditions.append(field < coerced_value) elif op == "$gte": conditions.append(field >= coerced_value) elif op == "$lte": conditions.append(field <= coerced_value) elif op == "$ne": conditions.append(field != coerced_value) elif op == "$in": conditions.append(field.in_(_ensure_list(coerced_value))) elif op == "$nin": conditions.append(field.not_in(_ensure_list(coerced_value))) else: logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") else: coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value conditions.append(field == coerced_value) conditions.append(Messages.message_id != "notice") statement = select(func.count()).select_from(Messages).where(*conditions) with get_db_session() as session: result = session.exec(statement).one() return int(result or 0) except Exception as e: log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" logger.error(log_message) return 0