# NOTE: the following lines were captured from a repository web viewer and are kept as comments.
# Files
# mai-bot/src/person_info/person_info.py
# 2026-04-23 16:02:32 +08:00
#
# 590 lines
# 22 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from datetime import datetime
from typing import Optional, Union
import hashlib
import json
import time
from sqlmodel import col, select
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.services.memory_service import memory_service
# Module logger used by the person-info helpers and the `Person` class below.
logger = get_logger("person_info")
def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
    """Convert the stored group-card-name JSON into the list shape used by `Person`.

    Args:
        group_cardname_json: Group card name JSON string as stored in the database.

    Returns:
        list[dict[str, str]]: One ``{"group_id": ..., "group_cardname": ...}`` dict per
        parsed entry; an empty list when the parser yields nothing.

    Raises:
        json.JSONDecodeError: When the JSON text is malformed (raised by the parser).
        TypeError: When the input type is not accepted by ``json.loads()``.
    """
    records: list[dict[str, str]] = []
    # A falsy parse result (None / empty) simply produces an empty list.
    for entry in parse_group_cardname_json(group_cardname_json) or ():
        records.append({"group_id": entry.group_id, "group_cardname": entry.group_cardname})
    return records
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
    """Return the stable person ID for a platform user.

    The ID is the MD5 hex digest of ``"<platform>_<user_id>"``. If `platform`
    contains a dash, only the segment after the first dash is used, e.g.
    ``"webui-qq"`` becomes ``"qq"``.

    Args:
        platform: Platform name, possibly dash-prefixed.
        user_id: User ID on that platform.

    Returns:
        str: 32-character hexadecimal MD5 digest.
    """
    normalized_platform = platform.split("-")[1] if "-" in platform else platform
    raw_key = f"{normalized_platform}_{user_id!s}"
    return hashlib.md5(raw_key.encode()).hexdigest()
def get_person_id_by_person_name(person_name: str) -> str:
    """Look up the `person_id` whose stored `person_name` equals *person_name*.

    Args:
        person_name: Exact person name to search for.

    Returns:
        str: The first matching `person_id`, or an empty string when there is no
        match or the database query fails (the failure is logged).
    """
    try:
        with get_db_session(auto_commit=False) as session:
            query = (
                select(PersonInfo.person_id)
                .where(col(PersonInfo.person_name) == person_name)
                .limit(1)
            )
            found = session.exec(query).first()
    except Exception as e:
        logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
        return ""
    return str(found) if found else ""
def resolve_person_id_for_memory(
    *,
    person_name: str = "",
    platform: str = "",
    user_id: Union[int, str, None] = None,
    strict_known: bool = False,
) -> str:
    """Resolve the person ID used when reading/writing long-term memory.

    Resolution order:
    1. Map ``person_name`` to an existing ``person_id`` in the database.
    2. Fall back to the stable ``person_id`` derived from ``platform + user_id``.
    3. With ``strict_known=True``, the fallback ID from step 2 is only returned if
       that person is already known (a name-mapped ID from step 1 is not checked).

    Args:
        person_name: Person name to look up first (surrounding whitespace ignored).
        platform: Platform used for the fallback ID.
        user_id: User ID used for the fallback ID; ``None`` means "not provided".
        strict_known: Require the fallback person to be known.

    Returns:
        str: The resolved person ID, or an empty string when none can be determined.
    """
    clean_name = str(person_name or "").strip()
    if clean_name:
        by_name = get_person_id_by_person_name(clean_name)
        if by_name:
            return by_name

    clean_platform = str(platform or "").strip()
    # Test for None explicitly: a falsy but valid ID such as the integer 0 must not
    # be treated as missing (the previous `user_id or ""` dropped it).
    clean_user_id = "" if user_id is None else str(user_id).strip()
    if not (clean_platform and clean_user_id):
        return ""

    candidate = get_person_id(clean_platform, clean_user_id)
    if strict_known and not is_person_known(person_id=candidate):
        return ""
    return candidate
def is_person_known(
    person_id: Optional[str] = None,
    user_id: Optional[str] = None,
    platform: Optional[str] = None,
    person_name: Optional[str] = None,
) -> bool:
    """Return whether the identified person exists and is marked as known.

    The person is identified by the first usable option:
    1. ``person_id``;
    2. ``platform`` + ``user_id`` (converted with `get_person_id`);
    3. ``person_name`` (looked up with `get_person_id_by_person_name`).

    Args:
        person_id: Person ID to check directly.
        user_id: User ID (used together with ``platform``).
        platform: Platform name (used together with ``user_id``).
        person_name: Person name to resolve to an ID.

    Returns:
        bool: ``True`` only if a matching `PersonInfo` row exists with a truthy
        ``is_known``; ``False`` otherwise, including when no ID can be resolved.
    """
    if not person_id:
        if user_id and platform:
            person_id = get_person_id(platform, user_id)
        elif person_name:
            person_id = get_person_id_by_person_name(person_name)

    # Nothing to look up: either no identifying arguments, or the name is unknown.
    if not person_id:
        return False

    with get_db_session() as session:
        statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
        person = session.exec(statement).first()
        # Read the flag while the session is still open; coerce to a real bool.
        return bool(person.is_known) if person else False
def calculate_string_similarity(s1: str, s2: str) -> float:
    """Return a similarity score between two strings.

    The score is ``1 - levenshtein_distance(s1, s2) / max(len(s1), len(s2))``.
    Identical strings (including two empty strings) score ``1.0``; a comparison in
    which exactly one string is empty scores ``0.0``.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        float: Similarity in the range 0-1, where 1 means identical.
    """
    if s1 == s2:
        return 1.0
    if not (s1 and s2):
        return 0.0
    # Both strings are non-empty here, so the longer length is always positive.
    longest = max(len(s1), len(s2))
    edits = levenshtein_distance(s1, s2)
    return 1 - edits / longest
def levenshtein_distance(s1: str, s2: str) -> int:
    """Return the edit (Levenshtein) distance between two strings.

    Uses the dynamic-programming recurrence while keeping only the previous and
    current rows. The shorter string indexes the columns, so memory is
    proportional to its length.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        int: Minimum number of single-character insertions, deletions and
        substitutions needed to turn one string into the other.
    """
    longer, shorter = (s2, s1) if len(s1) < len(s2) else (s1, s2)
    if not shorter:
        return len(longer)

    prev_row = list(range(len(shorter) + 1))
    for row_idx, ch_long in enumerate(longer, start=1):
        curr_row = [row_idx]
        for col_idx, ch_short in enumerate(shorter, start=1):
            cost = 0 if ch_long == ch_short else 1
            curr_row.append(
                min(
                    prev_row[col_idx] + 1,  # insertion
                    curr_row[col_idx - 1] + 1,  # deletion
                    prev_row[col_idx - 1] + cost,  # substitution / match
                )
            )
        prev_row = curr_row
    return prev_row[-1]
class Person:
    """In-memory view of one person's `PersonInfo` record.

    Instances are created either through `register_person` (which marks the person
    as known and persists them) or through the constructor (which resolves an
    existing person by `person_id`, `person_name`, or `platform` + `user_id`).
    Changes are written back with `sync_to_database`.
    """

    @classmethod
    def register_person(
        cls,
        platform: str,
        user_id: str,
        nickname: str,
        group_id: Optional[str] = None,
        group_nick_name: Optional[str] = None,
    ) -> Optional["Person"]:
        """Register a new user, or return the already-known user.

        `platform`, `user_id` and `nickname` are all required.

        Args:
            platform: Platform name.
            user_id: User ID on that platform.
            nickname: User nickname; also used as the initial `person_name`.
            group_id: Group ID (optional, only provided in group chats).
            group_nick_name: Group card name (optional, only provided in group chats).

        Returns:
            Optional[Person]: The registered (or already known) person, or ``None``
            when a required argument is missing.
        """
        if not platform or not user_id or not nickname:
            logger.error("注册用户失败platform、user_id 和 nickname 都是必需参数")
            return None

        person_id = get_person_id(platform, user_id)

        if is_person_known(person_id=person_id):
            logger.debug(f"用户 {nickname} 已存在")
            person = cls(person_id=person_id)
            # In group chats, keep the group card name up to date
            if group_id and group_nick_name:
                person.add_group_nick_name(group_id, group_nick_name)
            return person

        # Bypass __init__, which would treat this not-yet-known person as unknown.
        # NOTE(review): if a PersonInfo row already exists for this person_id with
        # is_known false, the sync below overwrites it (names, memory points, group
        # card names) with these fresh values — confirm this is intended.
        person = cls.__new__(cls)
        person._reset_to_defaults()
        person.person_id = person_id
        person.platform = platform
        person.user_id = user_id
        person.nickname = nickname
        person.is_known = True  # registration marks the person as known immediately
        person.person_name = nickname  # nickname doubles as the initial person_name
        person.name_reason = "用户注册时设置的昵称"
        person.know_times = 1
        # One timestamp for both fields so they agree exactly at registration time
        registered_at = time.time()
        person.know_since = registered_at
        person.last_know = registered_at

        if group_id and group_nick_name:
            person.add_group_nick_name(group_id, group_nick_name)

        person.sync_to_database()
        logger.info(f"成功注册新用户:{person_id},平台:{platform},昵称:{nickname}")
        return person

    def _is_bot_self(self, platform: str, user_id: str) -> bool:
        """Return True if `platform` + `user_id` identify the bot itself.

        Delegates to `src.chat.utils.utils.is_bot_self`, the shared bot check for all
        platforms (QQ, Telegram, WebUI, ...).

        Args:
            platform: Message platform (e.g. "qq", "telegram", "webui").
            user_id: User ID.

        Returns:
            bool: True for the bot itself, False otherwise.
        """
        # Imported lazily — presumably to avoid a circular import; confirm.
        from src.chat.utils.utils import is_bot_self

        return is_bot_self(platform, user_id)

    def _reset_to_defaults(self) -> None:
        """Set every instance attribute to its "unknown person" default."""
        self.person_id: str = ""
        self.user_id: str = ""
        self.platform: str = ""
        self.is_known: bool = False
        self.nickname: str = ""
        self.person_name: Optional[str] = None
        self.name_reason: Optional[str] = None
        self.know_times: int = 0
        self.know_since: Optional[float] = None  # POSIX timestamp
        self.last_know: Optional[float] = None  # POSIX timestamp
        self.memory_points: list = []
        # Group card names, each stored as {"group_id": str, "group_cardname": str}
        self.group_cardname_list: list[dict[str, str]] = []

    @staticmethod
    def _to_timestamp(value: object) -> Optional[float]:
        """Convert a stored time value back into a POSIX timestamp.

        `sync_to_database` stores ``datetime.fromtimestamp(...)`` values; this is the
        inverse. Plain numbers are accepted as-is; anything else yields ``None``.
        """
        if isinstance(value, datetime):
            return value.timestamp()
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return float(value)
        return None

    @staticmethod
    def _split_memory_point(memory_point: str) -> Optional[tuple[str, str, str]]:
        """Split a ``category:content:weight`` memory point into its parts.

        The category runs up to the first colon and the weight follows the last
        colon, so colons inside the content are preserved.

        Returns:
            ``(category, content, weight)`` with surrounding whitespace stripped, or
            ``None`` when the point contains fewer than two colons.
        """
        category, sep, remainder = memory_point.partition(":")
        if not sep:
            return None
        content, sep, weight = remainder.rpartition(":")
        if not sep:
            return None
        return category.strip(), content.strip(), weight.strip()

    def __init__(self, platform: str = "", user_id: str = "", person_id: str = "", person_name: str = ""):
        """Resolve a person by `person_id`, `person_name`, or `platform` + `user_id`.

        The bot itself is recognised first and is not loaded from the database.
        Otherwise lookup precedence is `person_id`, then `person_name`, then
        `platform` + `user_id`. If `person_name` cannot be resolved, the instance
        stays unknown with an empty `person_id`. Unknown people get
        ``is_known = False`` and a placeholder `person_name`. Known people are loaded
        from the database.

        Raises:
            ValueError: If none of the lookup arguments are provided.
        """
        # Populate every attribute first so early-return paths still produce a
        # complete object (previously most attributes were missing on those paths).
        self._reset_to_defaults()

        # Unified bot-self check (multi-platform, including WebUI)
        if self._is_bot_self(platform, user_id):
            self.is_known = True
            self.person_id = get_person_id(platform, user_id)
            self.user_id = user_id
            self.platform = platform
            self.nickname = global_config.bot.nickname
            self.person_name = global_config.bot.nickname
            return

        if person_id:
            self.person_id = person_id
        elif person_name:
            self.person_id = get_person_id_by_person_name(person_name)
            if not self.person_id:
                logger.warning(f"根据用户名 {person_name} 获取用户ID时不存在用户{person_name}")
                return
        elif platform and user_id:
            self.person_id = get_person_id(platform, user_id)
            self.user_id = user_id
            self.platform = platform
        else:
            logger.error("Person 初始化失败,缺少必要参数")
            raise ValueError("Person 初始化失败,缺少必要参数")

        if not is_person_known(person_id=self.person_id):
            logger.debug(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
            self.person_name = f"未知用户{self.person_id[:4]}"
            return

        self.load_from_database()

    def del_memory(self, category: str, memory_content: str, similarity_threshold: float = 0.95):
        """Delete memory points of `category` whose content matches `memory_content`.

        `None` entries are dropped. Entries that are not strings, or do not have the
        ``category:content:weight`` shape, are kept unchanged.

        Args:
            category: Memory category to match exactly.
            memory_content: Content to compare against each point's content.
            similarity_threshold: Minimum `calculate_string_similarity` score for a
                point to be deleted (default 0.95, i.e. 95%).

        Returns:
            int: Number of memory points deleted.
        """
        if not self.memory_points:
            return 0

        deleted_count = 0
        memory_points_to_keep = []
        for memory_point in self.memory_points:
            if memory_point is None:
                continue  # drop None entries

            parsed = self._split_memory_point(memory_point) if isinstance(memory_point, str) else None
            if parsed is None:
                # Malformed entry: keep it untouched
                memory_points_to_keep.append(memory_point)
                continue

            memory_category, memory_text, _memory_weight = parsed
            if memory_category != category:
                memory_points_to_keep.append(memory_point)
                continue

            similarity = calculate_string_similarity(memory_content, memory_text)
            if similarity >= similarity_threshold:
                deleted_count += 1
                logger.debug(f"删除记忆点: {memory_point} (相似度: {similarity:.4f})")
            else:
                memory_points_to_keep.append(memory_point)

        self.memory_points = memory_points_to_keep

        if deleted_count > 0:
            self.sync_to_database()
            logger.info(f"成功删除 {deleted_count} 个记忆点,分类: {category}")

        return deleted_count

    def add_group_nick_name(self, group_id: str, group_nick_name: str):
        """Add or update this person's card name in a group, then persist it.

        Does nothing if either argument is empty, or if the stored card name for
        the group is already `group_nick_name`.

        Args:
            group_id: Group ID.
            group_nick_name: Card name in that group.
        """
        if not group_id or not group_nick_name:
            return

        for item in self.group_cardname_list:
            if item.get("group_id") == group_id:
                if item.get("group_cardname") == group_nick_name:
                    # Unchanged: skip a redundant database write
                    return
                item["group_cardname"] = group_nick_name
                self.sync_to_database()
                logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
                return

        self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
        self.sync_to_database()
        logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")

    def load_from_database(self):
        """Populate this instance from its `PersonInfo` row.

        If no row exists, `sync_to_database` is called (it only writes when
        `is_known` is true). Errors are logged, and the current attribute values are
        kept.
        """
        try:
            with get_db_session() as session:
                statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
                record = session.exec(statement).first()

                if record:
                    self.user_id = record.user_id or ""
                    self.platform = record.platform or ""
                    self.is_known = record.is_known or False
                    self.nickname = record.user_nickname or ""
                    self.person_name = record.person_name or self.nickname
                    self.name_reason = record.name_reason or None
                    self.know_times = record.know_counts or 0
                    # Restore the timestamps. Without this, know_since/last_know stay
                    # None and the next sync_to_database() would overwrite the stored
                    # first_known_time/last_known_time with NULL.
                    self.know_since = self._to_timestamp(record.first_known_time)
                    self.last_know = self._to_timestamp(record.last_known_time)

                    # memory_points is stored as a JSON list; None entries are dropped
                    self.memory_points = []
                    if record.memory_points:
                        try:
                            loaded_points = json.loads(record.memory_points)
                        except (json.JSONDecodeError, TypeError):
                            logger.warning(f"解析用户 {self.person_id} 的points字段失败使用默认值")
                        else:
                            if isinstance(loaded_points, list):
                                self.memory_points = [point for point in loaded_points if point is not None]

                    # group_cardname is stored as a JSON list
                    self.group_cardname_list = []
                    if record.group_cardname:
                        try:
                            self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
                        except (json.JSONDecodeError, TypeError):
                            logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败使用默认值")

                    logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
                else:
                    self.sync_to_database()
                    logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建")
        except Exception as e:
            logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}")
            # On error, keep the current (default) values

    def sync_to_database(self):
        """Write all attributes back to this person's `PersonInfo` row.

        Does nothing for people who are not known. Creates the row if it does not
        exist. Errors are logged, not raised.
        """
        if not self.is_known:
            return
        try:
            memory_points_value = json.dumps(
                [point for point in (self.memory_points or []) if point is not None],
                ensure_ascii=False,
            )
            group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
            first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
            last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None

            with get_db_session() as session:
                statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
                record = session.exec(statement).first()

                if record:
                    record.person_id = self.person_id
                    record.is_known = self.is_known
                    record.platform = self.platform
                    record.user_id = self.user_id
                    record.user_nickname = self.nickname
                    record.person_name = self.person_name
                    record.name_reason = self.name_reason
                    record.know_counts = self.know_times
                    record.first_known_time = first_known_time
                    record.last_known_time = last_known_time
                    record.memory_points = memory_points_value
                    record.group_cardname = group_cardname_value
                    session.add(record)
                    logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
                else:
                    record = PersonInfo(
                        person_id=self.person_id,
                        is_known=self.is_known,
                        platform=self.platform,
                        user_id=self.user_id,
                        user_nickname=self.nickname,
                        person_name=self.person_name,
                        name_reason=self.name_reason,
                        know_counts=self.know_times,
                        first_known_time=first_known_time,
                        last_known_time=last_known_time,
                        memory_points=memory_points_value,
                        group_cardname=group_cardname_value,
                    )
                    session.add(record)
                    logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
        except Exception as e:
            logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
    """Write a fact about a person into the long-term memory system.

    The target person is resolved with `resolve_person_id_for_memory`: first from
    ``person_name``, then from the platform/user of the chat session identified by
    ``chat_id``. People who are not known are skipped. Errors raised while
    resolving the session or writing the memory are logged, not propagated.

    Args:
        person_name: Name of the person the fact is about.
        memory_content: Fact text to store.
        chat_id: Session ID of the chat the fact came from.
    """
    content_text = str(memory_content or "").strip()
    if not content_text:
        logger.debug("人物事实写回跳过memory_content 为空")
        return

    chat_key = str(chat_id or "").strip()
    if not chat_key:
        logger.warning("人物事实写回失败chat_id 为空")
        return

    name_hint = str(person_name or "").strip()

    try:
        # Look up the chat session for this chat_id
        chat_session = _chat_manager.get_session_by_session_id(chat_key)
        if not chat_session:
            logger.warning(f"无法获取session for chat_id: {chat_key}")
            return

        def _session_field(field: str) -> str:
            # Missing or falsy session attributes collapse to ""
            return str(getattr(chat_session, field, "") or "").strip()

        origin_platform = _session_field("platform")
        origin_user_id = _session_field("user_id")
        origin_group_id = _session_field("group_id")

        target_person_id = resolve_person_id_for_memory(
            person_name=name_hint,
            platform=origin_platform,
            user_id=origin_user_id,
        )
        if not target_person_id:
            logger.warning(f"无法确定person_id for person_name: {name_hint}, chat_id: {chat_key}")
            return

        target = Person(person_id=target_person_id)
        if not target.is_known:
            logger.warning(f"用户 {name_hint or target_person_id} (person_id: {target_person_id}) 尚未认识,跳过写回")
            return

        display_name = str(getattr(target, "person_name", "") or getattr(target, "nickname", "") or "").strip()
        display_name = display_name or name_hint or target_person_id

        # Deterministic external id: the same person/chat/content maps to the same id
        fingerprint = hashlib.md5(f"{target_person_id}|{chat_key}|{content_text}".encode()).hexdigest()

        result = await memory_service.ingest_text(
            external_id=f"person_fact:{target_person_id}:{fingerprint}",
            source_type="person_fact",
            text=content_text,
            chat_id=chat_key,
            person_ids=[target_person_id],
            participants=[display_name],
            tags=["person_fact"],
            metadata={
                "person_id": target_person_id,
                "person_name": display_name,
                "writeback_source": "memory_flow_service",
            },
            respect_filter=True,
            user_id=origin_user_id,
            group_id=origin_group_id,
        )

        if not getattr(result, "success", False):
            logger.warning(
                f"人物事实写回长期记忆失败: person={display_name} person_id={target_person_id} "
                f"chat_id={chat_key} detail={getattr(result, 'detail', '')}"
            )
            return

        logger.info(
            f"成功写回人物事实到长期记忆: person={display_name} person_id={target_person_id} chat_id={chat_key}"
        )
    except Exception as err:
        logger.error(f"存储人物记忆失败: {err}")