数据库的信息重构为dataclass

This commit is contained in:
UnCLAS-Prommer
2025-08-17 17:11:32 +08:00
parent d74beef4b4
commit 3481234d2b
18 changed files with 243 additions and 206 deletions

View File

@@ -3,13 +3,15 @@ import re
import string
import time
import jieba
import json
import ast
import numpy as np
from collections import Counter
from maim_message import UserInfo
from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger
from src.common.data_models.info_data_model import TargetPersonInfo
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
@@ -130,22 +132,29 @@ def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> li
return []
who_chat_in_group = []
for msg_db_data in recent_messages:
user_info = UserInfo.from_dict(
{
"platform": msg_db_data["user_platform"],
"user_id": msg_db_data["user_id"],
"user_nickname": msg_db_data["user_nickname"],
"user_cardname": msg_db_data.get("user_cardname", ""),
}
)
for db_msg in recent_messages:
# user_info = UserInfo.from_dict(
# {
# "platform": msg_db_data["user_platform"],
# "user_id": msg_db_data["user_id"],
# "user_nickname": msg_db_data["user_nickname"],
# "user_cardname": msg_db_data.get("user_cardname", ""),
# }
# )
# if (
# (user_info.platform, user_info.user_id) != sender
# and user_info.user_id != global_config.bot.qq_account
# and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
# and len(who_chat_in_group) < 5
# ): # 排除重复排除消息发送者排除bot限制加载的关系数目
# who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
if (
(user_info.platform, user_info.user_id) != sender
and user_info.user_id != global_config.bot.qq_account
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
and db_msg.user_info.user_id != global_config.bot.qq_account
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目
who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
who_chat_in_group.append((db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname))
return who_chat_in_group
@@ -555,7 +564,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
# 获取消息内容计算总长度
messages = find_messages(message_filter=filter_query)
total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
total_length = sum(len(msg.processed_plain_text or "") for msg in messages)
return count, total_length
@@ -628,41 +637,34 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
user_id: str = user_info.user_id # type: ignore
# Initialize target_info with basic info
target_info = {
"platform": platform,
"user_id": user_id,
"user_nickname": user_info.user_nickname,
"person_id": None,
"person_name": None,
}
target_info = TargetPersonInfo(
platform=platform,
user_id=user_id,
user_nickname=user_info.user_nickname, # type: ignore
person_id=None,
person_name=None
)
# Try to fetch person info
try:
# Build the Person record directly from platform and user_id (a plain synchronous call; no to_thread wrapper is used here)
person = Person(platform=platform, user_id=user_id)
if not person.is_known:
logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
# 如果用户尚未认识则返回False和None
return False, None
person_id = person.person_id
person_name = None
if person_id:
# get_value is async, so await it directly
person_name = person.person_name
target_info["person_id"] = person_id
target_info["person_name"] = person_name
if person.person_id:
target_info.person_id = person.person_id
target_info.person_name = person.person_name
except Exception as person_e:
logger.warning(
f"获取 person_id 或 person_name 时出错 for {platform}:{user_id} in utils: {person_e}"
)
chat_target_info = target_info
chat_target_info = target_info.__dict__
else:
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
except Exception as e:
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
# Keep defaults on error
return is_group_chat, chat_target_info
@@ -771,6 +773,7 @@ def assign_message_ids_flexible(
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
def parse_keywords_string(keywords_input) -> list[str]:
# sourcery skip: use-contextlib-suppress
"""
统一的关键词解析函数,支持多种格式的关键词字符串解析
@@ -802,7 +805,6 @@ def parse_keywords_string(keywords_input) -> list[str]:
try:
# Try parsing as a JSON object (supports the {"keywords": [...]} format)
import json
json_data = json.loads(keywords_str)
if isinstance(json_data, dict) and "keywords" in json_data:
keywords_list = json_data["keywords"]
@@ -816,7 +818,6 @@ def parse_keywords_string(keywords_input) -> list[str]:
try:
# Try parsing with ast.literal_eval (supports Python literal syntax)
import ast
parsed = ast.literal_eval(keywords_str)
if isinstance(parsed, list):
return [str(k).strip() for k in parsed if str(k).strip()]