数据库的信息重构为dataclass
This commit is contained in:
@@ -3,13 +3,15 @@ import re
|
||||
import string
|
||||
import time
|
||||
import jieba
|
||||
import json
|
||||
import ast
|
||||
import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
from maim_message import UserInfo
|
||||
from typing import Optional, Tuple, Dict, List, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
@@ -130,22 +132,29 @@ def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> li
|
||||
return []
|
||||
|
||||
who_chat_in_group = []
|
||||
for msg_db_data in recent_messages:
|
||||
user_info = UserInfo.from_dict(
|
||||
{
|
||||
"platform": msg_db_data["user_platform"],
|
||||
"user_id": msg_db_data["user_id"],
|
||||
"user_nickname": msg_db_data["user_nickname"],
|
||||
"user_cardname": msg_db_data.get("user_cardname", ""),
|
||||
}
|
||||
)
|
||||
for db_msg in recent_messages:
|
||||
# user_info = UserInfo.from_dict(
|
||||
# {
|
||||
# "platform": msg_db_data["user_platform"],
|
||||
# "user_id": msg_db_data["user_id"],
|
||||
# "user_nickname": msg_db_data["user_nickname"],
|
||||
# "user_cardname": msg_db_data.get("user_cardname", ""),
|
||||
# }
|
||||
# )
|
||||
# if (
|
||||
# (user_info.platform, user_info.user_id) != sender
|
||||
# and user_info.user_id != global_config.bot.qq_account
|
||||
# and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||
# and len(who_chat_in_group) < 5
|
||||
# ): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||
# who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
|
||||
if (
|
||||
(user_info.platform, user_info.user_id) != sender
|
||||
and user_info.user_id != global_config.bot.qq_account
|
||||
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
|
||||
and db_msg.user_info.user_id != global_config.bot.qq_account
|
||||
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) not in who_chat_in_group
|
||||
and len(who_chat_in_group) < 5
|
||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||
who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
|
||||
who_chat_in_group.append((db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname))
|
||||
|
||||
return who_chat_in_group
|
||||
|
||||
@@ -555,7 +564,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
|
||||
|
||||
# 获取消息内容计算总长度
|
||||
messages = find_messages(message_filter=filter_query)
|
||||
total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
|
||||
total_length = sum(len(msg.processed_plain_text or "") for msg in messages)
|
||||
|
||||
return count, total_length
|
||||
|
||||
@@ -628,41 +637,34 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
user_id: str = user_info.user_id # type: ignore
|
||||
|
||||
# Initialize target_info with basic info
|
||||
target_info = {
|
||||
"platform": platform,
|
||||
"user_id": user_id,
|
||||
"user_nickname": user_info.user_nickname,
|
||||
"person_id": None,
|
||||
"person_name": None,
|
||||
}
|
||||
target_info = TargetPersonInfo(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
user_nickname=user_info.user_nickname, # type: ignore
|
||||
person_id=None,
|
||||
person_name=None
|
||||
)
|
||||
|
||||
# Try to fetch person info
|
||||
try:
|
||||
# Assume get_person_id is sync (as per original code), keep using to_thread
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
if not person.is_known:
|
||||
logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
|
||||
# 如果用户尚未认识,则返回False和None
|
||||
return False, None
|
||||
person_id = person.person_id
|
||||
person_name = None
|
||||
if person_id:
|
||||
# get_value is async, so await it directly
|
||||
person_name = person.person_name
|
||||
|
||||
target_info["person_id"] = person_id
|
||||
target_info["person_name"] = person_name
|
||||
if person.person_id:
|
||||
target_info.person_id = person.person_id
|
||||
target_info.person_name = person.person_name
|
||||
except Exception as person_e:
|
||||
logger.warning(
|
||||
f"获取 person_id 或 person_name 时出错 for {platform}:{user_id} in utils: {person_e}"
|
||||
)
|
||||
|
||||
chat_target_info = target_info
|
||||
chat_target_info = target_info.__dict__
|
||||
else:
|
||||
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
||||
# Keep defaults on error
|
||||
|
||||
return is_group_chat, chat_target_info
|
||||
|
||||
@@ -771,6 +773,7 @@ def assign_message_ids_flexible(
|
||||
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
||||
|
||||
def parse_keywords_string(keywords_input) -> list[str]:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
统一的关键词解析函数,支持多种格式的关键词字符串解析
|
||||
|
||||
@@ -802,7 +805,6 @@ def parse_keywords_string(keywords_input) -> list[str]:
|
||||
|
||||
try:
|
||||
# 尝试作为JSON对象解析(支持 {"keywords": [...]} 格式)
|
||||
import json
|
||||
json_data = json.loads(keywords_str)
|
||||
if isinstance(json_data, dict) and "keywords" in json_data:
|
||||
keywords_list = json_data["keywords"]
|
||||
@@ -816,7 +818,6 @@ def parse_keywords_string(keywords_input) -> list[str]:
|
||||
|
||||
try:
|
||||
# 尝试使用 ast.literal_eval 解析(支持Python字面量格式)
|
||||
import ast
|
||||
parsed = ast.literal_eval(keywords_str)
|
||||
if isinstance(parsed, list):
|
||||
return [str(k).strip() for k in parsed if str(k).strip()]
|
||||
|
||||
Reference in New Issue
Block a user