重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject
This commit is contained in:
@@ -6,10 +6,13 @@ import random
|
||||
import math
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import Union, Optional
|
||||
from typing import Union, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -35,24 +38,37 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
def get_person_id_by_person_name(person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
try:
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
|
||||
record = session.exec(statement).first()
|
||||
return record.person_id if record else ""
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore
|
||||
def is_person_known(
|
||||
person_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
platform: Optional[str] = None,
|
||||
person_name: Optional[str] = None,
|
||||
) -> bool:
|
||||
if person_id:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
return person.is_known if person else False
|
||||
elif user_id and platform:
|
||||
person_id = get_person_id(platform, user_id)
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
return person.is_known if person else False
|
||||
elif person_name:
|
||||
person_id = get_person_id_by_person_name(person_name)
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
person = session.exec(statement).first()
|
||||
return person.is_known if person else False
|
||||
else:
|
||||
return False
|
||||
@@ -442,17 +458,18 @@ class Person:
|
||||
def load_from_database(self):
|
||||
"""从数据库加载个人信息数据"""
|
||||
try:
|
||||
# 查询数据库中的记录
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
|
||||
record = session.exec(statement).first()
|
||||
|
||||
if record:
|
||||
self.user_id = record.user_id or ""
|
||||
self.platform = record.platform or ""
|
||||
self.is_known = record.is_known or False
|
||||
self.nickname = record.nickname or ""
|
||||
self.nickname = record.user_nickname or ""
|
||||
self.person_name = record.person_name or self.nickname
|
||||
self.name_reason = record.name_reason or None
|
||||
self.know_times = record.know_times or 0
|
||||
self.know_times = record.know_counts or 0
|
||||
|
||||
# 处理points字段(JSON格式的列表)
|
||||
if record.memory_points:
|
||||
@@ -470,16 +487,16 @@ class Person:
|
||||
self.memory_points = []
|
||||
|
||||
# 处理group_nick_name字段(JSON格式的列表)
|
||||
if record.group_nick_name:
|
||||
if record.group_nickname:
|
||||
try:
|
||||
loaded_group_nick_names = json.loads(record.group_nick_name)
|
||||
loaded_group_nick_names = json.loads(record.group_nickname)
|
||||
# 确保是列表格式
|
||||
if isinstance(loaded_group_nick_names, list):
|
||||
self.group_nick_name = loaded_group_nick_names
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"解析用户 {self.person_id} 的group_nick_name字段失败,使用默认值")
|
||||
logger.warning(f"解析用户 {self.person_id} 的group_nickname字段失败,使用默认值")
|
||||
self.group_nick_name = []
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
@@ -498,42 +515,55 @@ class Person:
|
||||
if not self.is_known:
|
||||
return
|
||||
try:
|
||||
# 准备数据
|
||||
data = {
|
||||
"person_id": self.person_id,
|
||||
"is_known": self.is_known,
|
||||
"platform": self.platform,
|
||||
"user_id": self.user_id,
|
||||
"nickname": self.nickname,
|
||||
"person_name": self.person_name,
|
||||
"name_reason": self.name_reason,
|
||||
"know_times": self.know_times,
|
||||
"know_since": self.know_since,
|
||||
"last_know": self.last_know,
|
||||
"memory_points": json.dumps(
|
||||
[point for point in self.memory_points if point is not None], ensure_ascii=False
|
||||
)
|
||||
memory_points_value = (
|
||||
json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False)
|
||||
if self.memory_points
|
||||
else json.dumps([], ensure_ascii=False),
|
||||
"group_nick_name": json.dumps(self.group_nick_name, ensure_ascii=False)
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
group_nickname_value = (
|
||||
json.dumps(self.group_nick_name, ensure_ascii=False)
|
||||
if self.group_nick_name
|
||||
else json.dumps([], ensure_ascii=False),
|
||||
}
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
|
||||
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
|
||||
|
||||
# 检查记录是否存在
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
|
||||
record = session.exec(statement).first()
|
||||
|
||||
if record:
|
||||
# 更新现有记录
|
||||
for field, value in data.items():
|
||||
if hasattr(record, field):
|
||||
setattr(record, field, value)
|
||||
record.save()
|
||||
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
||||
else:
|
||||
# 创建新记录
|
||||
PersonInfo.create(**data)
|
||||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||||
if record:
|
||||
record.person_id = self.person_id
|
||||
record.is_known = self.is_known
|
||||
record.platform = self.platform
|
||||
record.user_id = self.user_id
|
||||
record.user_nickname = self.nickname
|
||||
record.person_name = self.person_name
|
||||
record.name_reason = self.name_reason
|
||||
record.know_counts = self.know_times
|
||||
record.first_known_time = first_known_time
|
||||
record.last_known_time = last_known_time
|
||||
record.memory_points = memory_points_value
|
||||
record.group_nickname = group_nickname_value
|
||||
session.add(record)
|
||||
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
||||
else:
|
||||
record = PersonInfo(
|
||||
person_id=self.person_id,
|
||||
is_known=self.is_known,
|
||||
platform=self.platform,
|
||||
user_id=self.user_id,
|
||||
user_nickname=self.nickname,
|
||||
person_name=self.person_name,
|
||||
name_reason=self.name_reason,
|
||||
know_counts=self.know_times,
|
||||
first_known_time=first_known_time,
|
||||
last_known_time=last_known_time,
|
||||
memory_points=memory_points_value,
|
||||
group_nickname=group_nickname_value,
|
||||
)
|
||||
session.add(record)
|
||||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
|
||||
@@ -621,30 +651,26 @@ class PersonInfoManager:
|
||||
self.person_name_list = {}
|
||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 设置连接池参数
|
||||
if hasattr(db, "execute_sql"):
|
||||
# 设置SQLite优化参数
|
||||
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
|
||||
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
|
||||
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
|
||||
db.create_tables([PersonInfo], safe=True)
|
||||
with get_db_session() as _:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
|
||||
|
||||
# 初始化时读取所有person_name
|
||||
try:
|
||||
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||
PersonInfo.person_name.is_null(False)
|
||||
):
|
||||
if record.person_name:
|
||||
self.person_name_list[record.person_id] = record.person_name
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||
col(PersonInfo.person_name).is_not(None)
|
||||
)
|
||||
for person_id, person_name in session.exec(statement).all():
|
||||
if person_name:
|
||||
self.person_name_list[person_id] = person_name
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
||||
logger.error(f"加载 person_name_list 失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_from_text(text: str) -> dict:
|
||||
def _extract_json_from_text(text: str) -> Dict[str, str]:
|
||||
"""从文本中提取JSON数据的高容错方法"""
|
||||
try:
|
||||
fixed_json = repair_json(text)
|
||||
@@ -744,7 +770,9 @@ class PersonInfoManager:
|
||||
else:
|
||||
|
||||
def _db_check_name_exists_sync(name_to_check):
|
||||
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo.person_id).where(col(PersonInfo.person_name) == name_to_check)
|
||||
return session.exec(statement).first() is not None
|
||||
|
||||
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
|
||||
is_duplicate = True
|
||||
@@ -804,7 +832,7 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
|
||||
if not person_id:
|
||||
# 如果通过person_name找不到,尝试从chat_stream获取user_info
|
||||
if chat_stream.user_info:
|
||||
if platform and chat_stream.user_info and chat_stream.user_info.user_id:
|
||||
user_id = chat_stream.user_info.user_id
|
||||
person_id = get_person_id(platform, user_id)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user