重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject

This commit is contained in:
DrSmoothl
2026-02-13 20:39:11 +08:00
parent c14736ffca
commit 16b16d2ca6
29 changed files with 2459 additions and 1737 deletions

View File

@@ -6,10 +6,13 @@ import random
import math
from json_repair import repair_json
from typing import Union, Optional
from typing import Union, Optional, Dict
from datetime import datetime
from sqlmodel import col, select
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
@@ -35,24 +38,37 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
record = session.exec(statement).first()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
return ""
def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore
def is_person_known(
person_id: Optional[str] = None,
user_id: Optional[str] = None,
platform: Optional[str] = None,
person_name: Optional[str] = None,
) -> bool:
if person_id:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
return person.is_known if person else False
elif user_id and platform:
person_id = get_person_id(platform, user_id)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
return person.is_known if person else False
elif person_name:
person_id = get_person_id_by_person_name(person_name)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
return person.is_known if person else False
else:
return False
@@ -442,17 +458,18 @@ class Person:
def load_from_database(self):
"""从数据库加载个人信息数据"""
try:
# 查询数据库中的记录
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
record = session.exec(statement).first()
if record:
self.user_id = record.user_id or ""
self.platform = record.platform or ""
self.is_known = record.is_known or False
self.nickname = record.nickname or ""
self.nickname = record.user_nickname or ""
self.person_name = record.person_name or self.nickname
self.name_reason = record.name_reason or None
self.know_times = record.know_times or 0
self.know_times = record.know_counts or 0
# 处理points字段JSON格式的列表
if record.memory_points:
@@ -470,16 +487,16 @@ class Person:
self.memory_points = []
# 处理group_nick_name字段JSON格式的列表
if record.group_nick_name:
if record.group_nickname:
try:
loaded_group_nick_names = json.loads(record.group_nick_name)
loaded_group_nick_names = json.loads(record.group_nickname)
# 确保是列表格式
if isinstance(loaded_group_nick_names, list):
self.group_nick_name = loaded_group_nick_names
else:
self.group_nick_name = []
except (json.JSONDecodeError, TypeError):
logger.warning(f"解析用户 {self.person_id} 的group_nick_name字段失败使用默认值")
logger.warning(f"解析用户 {self.person_id} 的group_nickname字段失败使用默认值")
self.group_nick_name = []
else:
self.group_nick_name = []
@@ -498,42 +515,55 @@ class Person:
if not self.is_known:
return
try:
# 准备数据
data = {
"person_id": self.person_id,
"is_known": self.is_known,
"platform": self.platform,
"user_id": self.user_id,
"nickname": self.nickname,
"person_name": self.person_name,
"name_reason": self.name_reason,
"know_times": self.know_times,
"know_since": self.know_since,
"last_know": self.last_know,
"memory_points": json.dumps(
[point for point in self.memory_points if point is not None], ensure_ascii=False
)
memory_points_value = (
json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False)
if self.memory_points
else json.dumps([], ensure_ascii=False),
"group_nick_name": json.dumps(self.group_nick_name, ensure_ascii=False)
else json.dumps([], ensure_ascii=False)
)
group_nickname_value = (
json.dumps(self.group_nick_name, ensure_ascii=False)
if self.group_nick_name
else json.dumps([], ensure_ascii=False),
}
else json.dumps([], ensure_ascii=False)
)
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
# 检查记录是否存在
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
record = session.exec(statement).first()
if record:
# 更新现有记录
for field, value in data.items():
if hasattr(record, field):
setattr(record, field, value)
record.save()
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
else:
# 创建新记录
PersonInfo.create(**data)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
if record:
record.person_id = self.person_id
record.is_known = self.is_known
record.platform = self.platform
record.user_id = self.user_id
record.user_nickname = self.nickname
record.person_name = self.person_name
record.name_reason = self.name_reason
record.know_counts = self.know_times
record.first_known_time = first_known_time
record.last_known_time = last_known_time
record.memory_points = memory_points_value
record.group_nickname = group_nickname_value
session.add(record)
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
else:
record = PersonInfo(
person_id=self.person_id,
is_known=self.is_known,
platform=self.platform,
user_id=self.user_id,
user_nickname=self.nickname,
person_name=self.person_name,
name_reason=self.name_reason,
know_counts=self.know_times,
first_known_time=first_known_time,
last_known_time=last_known_time,
memory_points=memory_points_value,
group_nickname=group_nickname_value,
)
session.add(record)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
except Exception as e:
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
@@ -621,30 +651,26 @@ class PersonInfoManager:
self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try:
db.connect(reuse_if_open=True)
# 设置连接池参数
if hasattr(db, "execute_sql"):
# 设置SQLite优化参数
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
db.create_tables([PersonInfo], safe=True)
with get_db_session() as _:
pass
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name
try:
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_null(False)
):
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
with get_db_session() as session:
statement = select(PersonInfo.person_id, PersonInfo.person_name).where(
col(PersonInfo.person_name).is_not(None)
)
for person_id, person_name in session.exec(statement).all():
if person_name:
self.person_name_list[person_id] = person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
logger.error(f"加载 person_name_list 失败: {e}")
@staticmethod
def _extract_json_from_text(text: str) -> dict:
def _extract_json_from_text(text: str) -> Dict[str, str]:
"""从文本中提取JSON数据的高容错方法"""
try:
fixed_json = repair_json(text)
@@ -744,7 +770,9 @@ class PersonInfoManager:
else:
def _db_check_name_exists_sync(name_to_check):
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
with get_db_session() as session:
statement = select(PersonInfo.person_id).where(col(PersonInfo.person_name) == name_to_check)
return session.exec(statement).first() is not None
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
@@ -804,7 +832,7 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
if not person_id:
# 如果通过person_name找不到尝试从chat_stream获取user_info
if chat_stream.user_info:
if platform and chat_stream.user_info and chat_stream.user_info.user_id:
user_id = chat_stream.user_info.user_id
person_id = get_person_id(platform, user_id)
else: