修复typing问题,保证类型正确

This commit is contained in:
UnCLAS-Prommer
2025-08-21 23:52:44 +08:00
parent ec500f1f5b
commit a55979164e
6 changed files with 230 additions and 267 deletions

View File

@@ -4,44 +4,43 @@ from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
class FocusValueControl: class FocusValueControl:
def __init__(self,chat_id:str): def __init__(self, chat_id: str):
self.chat_id = chat_id self.chat_id = chat_id
self.focus_value_adjust = 1 self.focus_value_adjust: float = 1
def get_current_focus_value(self) -> float: def get_current_focus_value(self) -> float:
return get_current_focus_value(self.chat_id) * self.focus_value_adjust return get_current_focus_value(self.chat_id) * self.focus_value_adjust
class FocusValueControlManager: class FocusValueControlManager:
def __init__(self): def __init__(self):
self.focus_value_controls = {} self.focus_value_controls: dict[str, FocusValueControl] = {}
def get_focus_value_control(self,chat_id:str) -> FocusValueControl: def get_focus_value_control(self, chat_id: str) -> FocusValueControl:
if chat_id not in self.focus_value_controls: if chat_id not in self.focus_value_controls:
self.focus_value_controls[chat_id] = FocusValueControl(chat_id) self.focus_value_controls[chat_id] = FocusValueControl(chat_id)
return self.focus_value_controls[chat_id] return self.focus_value_controls[chat_id]
def get_current_focus_value(chat_id: Optional[str] = None) -> float: def get_current_focus_value(chat_id: Optional[str] = None) -> float:
""" """
根据当前时间和聊天流获取对应的 focus_value 根据当前时间和聊天流获取对应的 focus_value
""" """
if not global_config.chat.focus_value_adjust: if not global_config.chat.focus_value_adjust:
return global_config.chat.focus_value return global_config.chat.focus_value
if chat_id: if chat_id:
stream_focus_value = get_stream_specific_focus_value(chat_id) stream_focus_value = get_stream_specific_focus_value(chat_id)
if stream_focus_value is not None: if stream_focus_value is not None:
return stream_focus_value return stream_focus_value
global_focus_value = get_global_focus_value() global_focus_value = get_global_focus_value()
if global_focus_value is not None: if global_focus_value is not None:
return global_focus_value return global_focus_value
return global_config.chat.focus_value return global_config.chat.focus_value
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]: def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
""" """
获取特定聊天流在当前时间的专注度 获取特定聊天流在当前时间的专注度
@@ -140,4 +139,5 @@ def get_global_focus_value() -> Optional[float]:
return None return None
focus_value_control = FocusValueControlManager()
focus_value_control = FocusValueControlManager()

View File

@@ -2,20 +2,21 @@ from typing import Optional
from src.config.config import global_config from src.config.config import global_config
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
class TalkFrequencyControl: class TalkFrequencyControl:
def __init__(self,chat_id:str): def __init__(self, chat_id: str):
self.chat_id = chat_id self.chat_id = chat_id
self.talk_frequency_adjust = 1 self.talk_frequency_adjust: float = 1
def get_current_talk_frequency(self) -> float: def get_current_talk_frequency(self) -> float:
return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust
class TalkFrequencyControlManager: class TalkFrequencyControlManager:
def __init__(self): def __init__(self):
self.talk_frequency_controls = {} self.talk_frequency_controls = {}
def get_talk_frequency_control(self,chat_id:str) -> TalkFrequencyControl: def get_talk_frequency_control(self, chat_id: str) -> TalkFrequencyControl:
if chat_id not in self.talk_frequency_controls: if chat_id not in self.talk_frequency_controls:
self.talk_frequency_controls[chat_id] = TalkFrequencyControl(chat_id) self.talk_frequency_controls[chat_id] = TalkFrequencyControl(chat_id)
return self.talk_frequency_controls[chat_id] return self.talk_frequency_controls[chat_id]
@@ -44,6 +45,7 @@ def get_current_talk_frequency(chat_id: Optional[str] = None) -> float:
global_frequency = get_global_frequency() global_frequency = get_global_frequency()
return global_config.chat.talk_frequency if global_frequency is None else global_frequency return global_config.chat.talk_frequency if global_frequency is None else global_frequency
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]: def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
""" """
根据时间配置列表获取当前时段的频率 根据时间配置列表获取当前时段的频率
@@ -124,6 +126,7 @@ def get_stream_specific_frequency(chat_stream_id: str):
return None return None
def get_global_frequency() -> Optional[float]: def get_global_frequency() -> Optional[float]:
""" """
获取全局默认频率配置 获取全局默认频率配置
@@ -141,4 +144,5 @@ def get_global_frequency() -> Optional[float]:
return None return None
talk_frequency_control = TalkFrequencyControlManager()
talk_frequency_control = TalkFrequencyControlManager()

View File

@@ -30,9 +30,7 @@ def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2) dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1) norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2) norm2 = np.linalg.norm(v2)
if norm1 == 0 or norm2 == 0: return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
return 0
return dot_product / (norm1 * norm2)
install(extra_lines=3) install(extra_lines=3)
@@ -142,11 +140,10 @@ class MemoryGraph:
# 获取当前节点的记忆项 # 获取当前节点的记忆项
node_data = self.get_dot(topic) node_data = self.get_dot(topic)
if node_data: if node_data:
concept, data = node_data _, data = node_data
if "memory_items" in data: if "memory_items" in data:
memory_items = data["memory_items"]
# 直接使用完整的记忆内容 # 直接使用完整的记忆内容
if memory_items: if memory_items := data["memory_items"]:
first_layer_items.append(memory_items) first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆 # 只在depth=2时获取第二层记忆
@@ -154,11 +151,10 @@ class MemoryGraph:
# 获取相邻节点的记忆项 # 获取相邻节点的记忆项
for neighbor in neighbors: for neighbor in neighbors:
if node_data := self.get_dot(neighbor): if node_data := self.get_dot(neighbor):
concept, data = node_data _, data = node_data
if "memory_items" in data: if "memory_items" in data:
memory_items = data["memory_items"]
# 直接使用完整的记忆内容 # 直接使用完整的记忆内容
if memory_items: if memory_items := data["memory_items"]:
second_layer_items.append(memory_items) second_layer_items.append(memory_items)
return first_layer_items, second_layer_items return first_layer_items, second_layer_items
@@ -224,27 +220,17 @@ class MemoryGraph:
# 获取话题节点数据 # 获取话题节点数据
node_data = self.G.nodes[topic] node_data = self.G.nodes[topic]
# 删除整个节点
self.G.remove_node(topic)
# 如果节点存在memory_items # 如果节点存在memory_items
if "memory_items" in node_data: if "memory_items" in node_data:
memory_items = node_data["memory_items"] if memory_items := node_data["memory_items"]:
# 既然每个节点现在是一个完整的记忆内容,直接删除整个节点
if memory_items:
# 删除整个节点
self.G.remove_node(topic)
return ( return (
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
if len(memory_items) > 50 if len(memory_items) > 50
else f"删除了节点 {topic} 的完整记忆: {memory_items}" else f"删除了节点 {topic} 的完整记忆: {memory_items}"
) )
else: return None
# 如果没有记忆项,删除该节点
self.G.remove_node(topic)
return None
else:
# 如果没有memory_items字段删除该节点
self.G.remove_node(topic)
return None
# 海马体 # 海马体
@@ -392,9 +378,8 @@ class Hippocampus:
# 如果相似度超过阈值,获取该节点的记忆 # 如果相似度超过阈值,获取该节点的记忆
if similarity >= 0.3: # 可以调整这个阈值 if similarity >= 0.3: # 可以调整这个阈值
node_data = self.memory_graph.G.nodes[node] node_data = self.memory_graph.G.nodes[node]
memory_items = node_data.get("memory_items", "")
# 直接使用完整的记忆内容 # 直接使用完整的记忆内容
if memory_items: if memory_items := node_data.get("memory_items", ""):
memories.append((node, memory_items, similarity)) memories.append((node, memory_items, similarity))
# 按相似度降序排序 # 按相似度降序排序
@@ -587,7 +572,7 @@ class Hippocampus:
unique_memories = [] unique_memories = []
for topic, memory_items, activation_value in all_memories: for topic, memory_items, activation_value in all_memories:
# memory_items现在是完整的字符串格式 # memory_items现在是完整的字符串格式
memory = memory_items if memory_items else "" memory = memory_items or ""
if memory not in seen_memories: if memory not in seen_memories:
seen_memories.add(memory) seen_memories.add(memory)
unique_memories.append((topic, memory_items, activation_value)) unique_memories.append((topic, memory_items, activation_value))
@@ -599,7 +584,7 @@ class Hippocampus:
result = [] result = []
for topic, memory_items, _ in unique_memories: for topic, memory_items, _ in unique_memories:
# memory_items现在是完整的字符串格式 # memory_items现在是完整的字符串格式
memory = memory_items if memory_items else "" memory = memory_items or ""
result.append((topic, memory)) result.append((topic, memory))
logger.debug(f"选中记忆: {memory} (来自节点: {topic})") logger.debug(f"选中记忆: {memory} (来自节点: {topic})")
@@ -1471,6 +1456,7 @@ class MemoryBuilder:
self.last_processed_time: float = 0.0 self.last_processed_time: float = 0.0
def should_trigger_memory_build(self) -> bool: def should_trigger_memory_build(self) -> bool:
# sourcery skip: assign-if-exp, boolean-if-exp-identity, reintroduce-else
"""检查是否应该触发记忆构建""" """检查是否应该触发记忆构建"""
current_time = time.time() current_time = time.time()

View File

@@ -3,6 +3,7 @@ import asyncio
import json import json
import time import time
import random import random
import math
from json_repair import repair_json from json_repair import repair_json
from typing import Union, Optional from typing import Union, Optional
@@ -16,6 +17,7 @@ from src.config.config import global_config, model_config
logger = get_logger("person_info") logger = get_logger("person_info")
def get_person_id(platform: str, user_id: Union[int, str]) -> str: def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id""" """获取唯一id"""
if "-" in platform: if "-" in platform:
@@ -24,6 +26,7 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
key = "_".join(components) key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
def get_person_id_by_person_name(person_name: str) -> str: def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID""" """根据用户名获取用户ID"""
try: try:
@@ -33,7 +36,8 @@ def get_person_id_by_person_name(person_name: str) -> str:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}") logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
return "" return ""
def is_person_known(person_id: str = None,user_id: str = None,platform: str = None,person_name: str = None) -> bool:
def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore
if person_id: if person_id:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
return person.is_known if person else False return person.is_known if person else False
@@ -47,89 +51,84 @@ def is_person_known(person_id: str = None,user_id: str = None,platform: str = No
return person.is_known if person else False return person.is_known if person else False
else: else:
return False return False
def get_catagory_from_memory(memory_point:str) -> str: def get_category_from_memory(memory_point: str) -> Optional[str]:
"""从记忆点中获取分类""" """从记忆点中获取分类"""
# 按照最左边的:符号进行分割,返回分割后的第一个部分作为分类 # 按照最左边的:符号进行分割,返回分割后的第一个部分作为分类
if not isinstance(memory_point, str): if not isinstance(memory_point, str):
return None return None
parts = memory_point.split(":", 1) parts = memory_point.split(":", 1)
if len(parts) > 1: return parts[0].strip() if len(parts) > 1 else None
return parts[0].strip()
else:
return None def get_weight_from_memory(memory_point: str) -> float:
def get_weight_from_memory(memory_point:str) -> float:
"""从记忆点中获取权重""" """从记忆点中获取权重"""
# 按照最右边的:符号进行分割,返回分割后的最后一个部分作为权重 # 按照最右边的:符号进行分割,返回分割后的最后一个部分作为权重
if not isinstance(memory_point, str): if not isinstance(memory_point, str):
return None return -math.inf
parts = memory_point.rsplit(":", 1) parts = memory_point.rsplit(":", 1)
if len(parts) > 1: if len(parts) <= 1:
try: return -math.inf
return float(parts[-1].strip()) try:
except Exception: return float(parts[-1].strip())
return None except Exception:
else: return -math.inf
return None
def get_memory_content_from_memory(memory_point:str) -> str: def get_memory_content_from_memory(memory_point: str) -> str:
"""从记忆点中获取记忆内容""" """从记忆点中获取记忆内容"""
# 按:进行分割,去掉第一段和最后一段,返回中间部分作为记忆内容 # 按:进行分割,去掉第一段和最后一段,返回中间部分作为记忆内容
if not isinstance(memory_point, str): if not isinstance(memory_point, str):
return None return ""
parts = memory_point.split(":") parts = memory_point.split(":")
if len(parts) > 2: return ":".join(parts[1:-1]).strip() if len(parts) > 2 else ""
return ":".join(parts[1:-1]).strip()
else:
return None
def calculate_string_similarity(s1: str, s2: str) -> float: def calculate_string_similarity(s1: str, s2: str) -> float:
""" """
计算两个字符串的相似度 计算两个字符串的相似度
Args: Args:
s1: 第一个字符串 s1: 第一个字符串
s2: 第二个字符串 s2: 第二个字符串
Returns: Returns:
float: 相似度范围0-11表示完全相同 float: 相似度范围0-11表示完全相同
""" """
if s1 == s2: if s1 == s2:
return 1.0 return 1.0
if not s1 or not s2: if not s1 or not s2:
return 0.0 return 0.0
# 计算Levenshtein距离 # 计算Levenshtein距离
distance = levenshtein_distance(s1, s2) distance = levenshtein_distance(s1, s2)
max_len = max(len(s1), len(s2)) max_len = max(len(s1), len(s2))
# 计算相似度1 - (编辑距离 / 最大长度) # 计算相似度1 - (编辑距离 / 最大长度)
similarity = 1 - (distance / max_len if max_len > 0 else 0) similarity = 1 - (distance / max_len if max_len > 0 else 0)
return similarity return similarity
def levenshtein_distance(s1: str, s2: str) -> int: def levenshtein_distance(s1: str, s2: str) -> int:
""" """
计算两个字符串的编辑距离 计算两个字符串的编辑距离
Args: Args:
s1: 第一个字符串 s1: 第一个字符串
s2: 第二个字符串 s2: 第二个字符串
Returns: Returns:
int: 编辑距离 int: 编辑距离
""" """
if len(s1) < len(s2): if len(s1) < len(s2):
return levenshtein_distance(s2, s1) return levenshtein_distance(s2, s1)
if len(s2) == 0: if len(s2) == 0:
return len(s1) return len(s1)
previous_row = range(len(s2) + 1) previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1): for i, c1 in enumerate(s1):
current_row = [i + 1] current_row = [i + 1]
@@ -139,44 +138,45 @@ def levenshtein_distance(s1: str, s2: str) -> int:
substitutions = previous_row[j] + (c1 != c2) substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions)) current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row previous_row = current_row
return previous_row[-1] return previous_row[-1]
class Person: class Person:
@classmethod @classmethod
def register_person(cls, platform: str, user_id: str, nickname: str): def register_person(cls, platform: str, user_id: str, nickname: str):
""" """
注册新用户的类方法 注册新用户的类方法
必须输入 platform、user_id 和 nickname 参数 必须输入 platform、user_id 和 nickname 参数
Args: Args:
platform: 平台名称 platform: 平台名称
user_id: 用户ID user_id: 用户ID
nickname: 用户昵称 nickname: 用户昵称
Returns: Returns:
Person: 新注册的Person实例 Person: 新注册的Person实例
""" """
if not platform or not user_id or not nickname: if not platform or not user_id or not nickname:
logger.error("注册用户失败platform、user_id 和 nickname 都是必需参数") logger.error("注册用户失败platform、user_id 和 nickname 都是必需参数")
return None return None
# 生成唯一的person_id # 生成唯一的person_id
person_id = get_person_id(platform, user_id) person_id = get_person_id(platform, user_id)
if is_person_known(person_id=person_id): if is_person_known(person_id=person_id):
logger.debug(f"用户 {nickname} 已存在") logger.debug(f"用户 {nickname} 已存在")
return Person(person_id=person_id) return Person(person_id=person_id)
# 创建Person实例 # 创建Person实例
person = cls.__new__(cls) person = cls.__new__(cls)
# 设置基本属性 # 设置基本属性
person.person_id = person_id person.person_id = person_id
person.platform = platform person.platform = platform
person.user_id = user_id person.user_id = user_id
person.nickname = nickname person.nickname = nickname
# 初始化默认值 # 初始化默认值
person.is_known = True # 注册后立即标记为已认识 person.is_known = True # 注册后立即标记为已认识
person.person_name = nickname # 使用nickname作为初始person_name person.person_name = nickname # 使用nickname作为初始person_name
@@ -185,34 +185,34 @@ class Person:
person.know_since = time.time() person.know_since = time.time()
person.last_know = time.time() person.last_know = time.time()
person.memory_points = [] person.memory_points = []
# 初始化性格特征相关字段 # 初始化性格特征相关字段
person.attitude_to_me = 0 person.attitude_to_me = 0
person.attitude_to_me_confidence = 1 person.attitude_to_me_confidence = 1
person.neuroticism = 5 person.neuroticism = 5
person.neuroticism_confidence = 1 person.neuroticism_confidence = 1
person.friendly_value = 50 person.friendly_value = 50
person.friendly_value_confidence = 1 person.friendly_value_confidence = 1
person.rudeness = 50 person.rudeness = 50
person.rudeness_confidence = 1 person.rudeness_confidence = 1
person.conscientiousness = 50 person.conscientiousness = 50
person.conscientiousness_confidence = 1 person.conscientiousness_confidence = 1
person.likeness = 50 person.likeness = 50
person.likeness_confidence = 1 person.likeness_confidence = 1
# 同步到数据库 # 同步到数据库
person.sync_to_database() person.sync_to_database()
logger.info(f"成功注册新用户:{person_id},平台:{platform},昵称:{nickname}") logger.info(f"成功注册新用户:{person_id},平台:{platform},昵称:{nickname}")
return person return person
def __init__(self, platform: str = "", user_id: str = "",person_id: str = "",person_name: str = ""): def __init__(self, platform: str = "", user_id: str = "", person_id: str = "", person_name: str = ""):
if platform == global_config.bot.platform and user_id == global_config.bot.qq_account: if platform == global_config.bot.platform and user_id == global_config.bot.qq_account:
self.is_known = True self.is_known = True
self.person_id = get_person_id(platform, user_id) self.person_id = get_person_id(platform, user_id)
@@ -221,10 +221,10 @@ class Person:
self.nickname = global_config.bot.nickname self.nickname = global_config.bot.nickname
self.person_name = global_config.bot.nickname self.person_name = global_config.bot.nickname
return return
self.user_id = "" self.user_id = ""
self.platform = "" self.platform = ""
if person_id: if person_id:
self.person_id = person_id self.person_id = person_id
elif person_name: elif person_name:
@@ -232,7 +232,7 @@ class Person:
if not self.person_id: if not self.person_id:
self.is_known = False self.is_known = False
logger.warning(f"根据用户名 {person_name} 获取用户ID时不存在用户{person_name}") logger.warning(f"根据用户名 {person_name} 获取用户ID时不存在用户{person_name}")
return return
elif platform and user_id: elif platform and user_id:
self.person_id = get_person_id(platform, user_id) self.person_id = get_person_id(platform, user_id)
self.user_id = user_id self.user_id = user_id
@@ -240,17 +240,16 @@ class Person:
else: else:
logger.error("Person 初始化失败,缺少必要参数") logger.error("Person 初始化失败,缺少必要参数")
raise ValueError("Person 初始化失败,缺少必要参数") raise ValueError("Person 初始化失败,缺少必要参数")
if not is_person_known(person_id=self.person_id): if not is_person_known(person_id=self.person_id):
self.is_known = False self.is_known = False
logger.debug(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识") logger.debug(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
self.person_name = f"未知用户{self.person_id[:4]}" self.person_name = f"未知用户{self.person_id[:4]}"
return return
# raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识") # raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
self.is_known = False self.is_known = False
# 初始化默认值 # 初始化默认值
self.nickname = "" self.nickname = ""
self.person_name: Optional[str] = None self.person_name: Optional[str] = None
@@ -259,47 +258,47 @@ class Person:
self.know_since = None self.know_since = None
self.last_know = None self.last_know = None
self.memory_points = [] self.memory_points = []
# 初始化性格特征相关字段 # 初始化性格特征相关字段
self.attitude_to_me:float = 0 self.attitude_to_me: float = 0
self.attitude_to_me_confidence:float = 1 self.attitude_to_me_confidence: float = 1
self.neuroticism:float = 5 self.neuroticism: float = 5
self.neuroticism_confidence:float = 1 self.neuroticism_confidence: float = 1
self.friendly_value:float = 50 self.friendly_value: float = 50
self.friendly_value_confidence:float = 1 self.friendly_value_confidence: float = 1
self.rudeness:float = 50 self.rudeness: float = 50
self.rudeness_confidence:float = 1 self.rudeness_confidence: float = 1
self.conscientiousness:float = 50 self.conscientiousness: float = 50
self.conscientiousness_confidence:float = 1 self.conscientiousness_confidence: float = 1
self.likeness:float = 50 self.likeness: float = 50
self.likeness_confidence:float = 1 self.likeness_confidence: float = 1
# 从数据库加载数据 # 从数据库加载数据
self.load_from_database() self.load_from_database()
def del_memory(self, category: str, memory_content: str, similarity_threshold: float = 0.95): def del_memory(self, category: str, memory_content: str, similarity_threshold: float = 0.95):
""" """
删除指定分类和记忆内容的记忆点 删除指定分类和记忆内容的记忆点
Args: Args:
category: 记忆分类 category: 记忆分类
memory_content: 要删除的记忆内容 memory_content: 要删除的记忆内容
similarity_threshold: 相似度阈值默认0.9595% similarity_threshold: 相似度阈值默认0.9595%
Returns: Returns:
int: 删除的记忆点数量 int: 删除的记忆点数量
""" """
if not self.memory_points: if not self.memory_points:
return 0 return 0
deleted_count = 0 deleted_count = 0
memory_points_to_keep = [] memory_points_to_keep = []
for memory_point in self.memory_points: for memory_point in self.memory_points:
# 跳过None值 # 跳过None值
if memory_point is None: if memory_point is None:
@@ -310,80 +309,76 @@ class Person:
# 格式不正确,保留原样 # 格式不正确,保留原样
memory_points_to_keep.append(memory_point) memory_points_to_keep.append(memory_point)
continue continue
memory_category = parts[0].strip() memory_category = parts[0].strip()
memory_text = parts[1].strip() memory_text = parts[1].strip()
memory_weight = parts[2].strip() memory_weight = parts[2].strip()
# 检查分类是否匹配 # 检查分类是否匹配
if memory_category != category: if memory_category != category:
memory_points_to_keep.append(memory_point) memory_points_to_keep.append(memory_point)
continue continue
# 计算记忆内容的相似度 # 计算记忆内容的相似度
similarity = calculate_string_similarity(memory_content, memory_text) similarity = calculate_string_similarity(memory_content, memory_text)
# 如果相似度达到阈值,则删除(不添加到保留列表) # 如果相似度达到阈值,则删除(不添加到保留列表)
if similarity >= similarity_threshold: if similarity >= similarity_threshold:
deleted_count += 1 deleted_count += 1
logger.debug(f"删除记忆点: {memory_point} (相似度: {similarity:.4f})") logger.debug(f"删除记忆点: {memory_point} (相似度: {similarity:.4f})")
else: else:
memory_points_to_keep.append(memory_point) memory_points_to_keep.append(memory_point)
# 更新memory_points # 更新memory_points
self.memory_points = memory_points_to_keep self.memory_points = memory_points_to_keep
# 同步到数据库 # 同步到数据库
if deleted_count > 0: if deleted_count > 0:
self.sync_to_database() self.sync_to_database()
logger.info(f"成功删除 {deleted_count} 个记忆点,分类: {category}") logger.info(f"成功删除 {deleted_count} 个记忆点,分类: {category}")
return deleted_count return deleted_count
def get_all_category(self): def get_all_category(self):
category_list = [] category_list = []
for memory in self.memory_points: for memory in self.memory_points:
if memory is None: if memory is None:
continue continue
category = get_catagory_from_memory(memory) category = get_category_from_memory(memory)
if category and category not in category_list: if category and category not in category_list:
category_list.append(category) category_list.append(category)
return category_list return category_list
def get_memory_list_by_category(self, category: str):
def get_memory_list_by_category(self,category:str):
memory_list = [] memory_list = []
for memory in self.memory_points: for memory in self.memory_points:
if memory is None: if memory is None:
continue continue
if get_catagory_from_memory(memory) == category: if get_category_from_memory(memory) == category:
memory_list.append(memory) memory_list.append(memory)
return memory_list return memory_list
def get_random_memory_by_category(self,category:str,num:int=1): def get_random_memory_by_category(self, category: str, num: int = 1):
memory_list = self.get_memory_list_by_category(category) memory_list = self.get_memory_list_by_category(category)
if len(memory_list) < num: if len(memory_list) < num:
return memory_list return memory_list
return random.sample(memory_list, num) return random.sample(memory_list, num)
def load_from_database(self): def load_from_database(self):
"""从数据库加载个人信息数据""" """从数据库加载个人信息数据"""
try: try:
# 查询数据库中的记录 # 查询数据库中的记录
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id) record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
if record: if record:
self.user_id = record.user_id if record.user_id else "" self.user_id = record.user_id or ""
self.platform = record.platform if record.platform else "" self.platform = record.platform or ""
self.is_known = record.is_known if record.is_known else False self.is_known = record.is_known or False
self.nickname = record.nickname if record.nickname else "" self.nickname = record.nickname or ""
self.person_name = record.person_name if record.person_name else self.nickname self.person_name = record.person_name or self.nickname
self.name_reason = record.name_reason if record.name_reason else None self.name_reason = record.name_reason or None
self.know_times = record.know_times if record.know_times else 0 self.know_times = record.know_times or 0
# 处理points字段JSON格式的列表 # 处理points字段JSON格式的列表
if record.memory_points: if record.memory_points:
try: try:
@@ -398,53 +393,53 @@ class Person:
self.memory_points = [] self.memory_points = []
else: else:
self.memory_points = [] self.memory_points = []
# 加载性格特征相关字段 # 加载性格特征相关字段
if record.attitude_to_me and not isinstance(record.attitude_to_me, str): if record.attitude_to_me and not isinstance(record.attitude_to_me, str):
self.attitude_to_me = record.attitude_to_me self.attitude_to_me = record.attitude_to_me
if record.attitude_to_me_confidence is not None: if record.attitude_to_me_confidence is not None:
self.attitude_to_me_confidence = float(record.attitude_to_me_confidence) self.attitude_to_me_confidence = float(record.attitude_to_me_confidence)
if record.friendly_value is not None: if record.friendly_value is not None:
self.friendly_value = float(record.friendly_value) self.friendly_value = float(record.friendly_value)
if record.friendly_value_confidence is not None: if record.friendly_value_confidence is not None:
self.friendly_value_confidence = float(record.friendly_value_confidence) self.friendly_value_confidence = float(record.friendly_value_confidence)
if record.rudeness is not None: if record.rudeness is not None:
self.rudeness = float(record.rudeness) self.rudeness = float(record.rudeness)
if record.rudeness_confidence is not None: if record.rudeness_confidence is not None:
self.rudeness_confidence = float(record.rudeness_confidence) self.rudeness_confidence = float(record.rudeness_confidence)
if record.neuroticism and not isinstance(record.neuroticism, str): if record.neuroticism and not isinstance(record.neuroticism, str):
self.neuroticism = float(record.neuroticism) self.neuroticism = float(record.neuroticism)
if record.neuroticism_confidence is not None: if record.neuroticism_confidence is not None:
self.neuroticism_confidence = float(record.neuroticism_confidence) self.neuroticism_confidence = float(record.neuroticism_confidence)
if record.conscientiousness is not None: if record.conscientiousness is not None:
self.conscientiousness = float(record.conscientiousness) self.conscientiousness = float(record.conscientiousness)
if record.conscientiousness_confidence is not None: if record.conscientiousness_confidence is not None:
self.conscientiousness_confidence = float(record.conscientiousness_confidence) self.conscientiousness_confidence = float(record.conscientiousness_confidence)
if record.likeness is not None: if record.likeness is not None:
self.likeness = float(record.likeness) self.likeness = float(record.likeness)
if record.likeness_confidence is not None: if record.likeness_confidence is not None:
self.likeness_confidence = float(record.likeness_confidence) self.likeness_confidence = float(record.likeness_confidence)
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息") logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
else: else:
self.sync_to_database() self.sync_to_database()
logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建") logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建")
except Exception as e: except Exception as e:
logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}") logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}")
# 出错时保持默认值 # 出错时保持默认值
def sync_to_database(self): def sync_to_database(self):
"""将所有属性同步回数据库""" """将所有属性同步回数据库"""
if not self.is_known: if not self.is_known:
@@ -452,34 +447,38 @@ class Person:
try: try:
# 准备数据 # 准备数据
data = { data = {
'person_id': self.person_id, "person_id": self.person_id,
'is_known': self.is_known, "is_known": self.is_known,
'platform': self.platform, "platform": self.platform,
'user_id': self.user_id, "user_id": self.user_id,
'nickname': self.nickname, "nickname": self.nickname,
'person_name': self.person_name, "person_name": self.person_name,
'name_reason': self.name_reason, "name_reason": self.name_reason,
'know_times': self.know_times, "know_times": self.know_times,
'know_since': self.know_since, "know_since": self.know_since,
'last_know': self.last_know, "last_know": self.last_know,
'memory_points': json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False) if self.memory_points else json.dumps([], ensure_ascii=False), "memory_points": json.dumps(
'attitude_to_me': self.attitude_to_me, [point for point in self.memory_points if point is not None], ensure_ascii=False
'attitude_to_me_confidence': self.attitude_to_me_confidence, )
'friendly_value': self.friendly_value, if self.memory_points
'friendly_value_confidence': self.friendly_value_confidence, else json.dumps([], ensure_ascii=False),
'rudeness': self.rudeness, "attitude_to_me": self.attitude_to_me,
'rudeness_confidence': self.rudeness_confidence, "attitude_to_me_confidence": self.attitude_to_me_confidence,
'neuroticism': self.neuroticism, "friendly_value": self.friendly_value,
'neuroticism_confidence': self.neuroticism_confidence, "friendly_value_confidence": self.friendly_value_confidence,
'conscientiousness': self.conscientiousness, "rudeness": self.rudeness,
'conscientiousness_confidence': self.conscientiousness_confidence, "rudeness_confidence": self.rudeness_confidence,
'likeness': self.likeness, "neuroticism": self.neuroticism,
'likeness_confidence': self.likeness_confidence, "neuroticism_confidence": self.neuroticism_confidence,
"conscientiousness": self.conscientiousness,
"conscientiousness_confidence": self.conscientiousness_confidence,
"likeness": self.likeness,
"likeness_confidence": self.likeness_confidence,
} }
# 检查记录是否存在 # 检查记录是否存在
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id) record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
if record: if record:
# 更新现有记录 # 更新现有记录
for field, value in data.items(): for field, value in data.items():
@@ -491,10 +490,10 @@ class Person:
# 创建新记录 # 创建新记录
PersonInfo.create(**data) PersonInfo.create(**data)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库") logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
except Exception as e: except Exception as e:
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}") logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
def build_relationship(self): def build_relationship(self):
if not self.is_known: if not self.is_known:
return "" return ""
@@ -505,22 +504,21 @@ class Person:
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})" nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
relation_info = "" relation_info = ""
attitude_info = "" attitude_info = ""
if self.attitude_to_me: if self.attitude_to_me:
if self.attitude_to_me > 8: if self.attitude_to_me > 8:
attitude_info = f"{self.person_name}对你的态度十分好," attitude_info = f"{self.person_name}对你的态度十分好,"
elif self.attitude_to_me > 5: elif self.attitude_to_me > 5:
attitude_info = f"{self.person_name}对你的态度较好," attitude_info = f"{self.person_name}对你的态度较好,"
if self.attitude_to_me < -8: if self.attitude_to_me < -8:
attitude_info = f"{self.person_name}对你的态度十分恶劣," attitude_info = f"{self.person_name}对你的态度十分恶劣,"
elif self.attitude_to_me < -4: elif self.attitude_to_me < -4:
attitude_info = f"{self.person_name}对你的态度不好," attitude_info = f"{self.person_name}对你的态度不好,"
elif self.attitude_to_me < 0: elif self.attitude_to_me < 0:
attitude_info = f"{self.person_name}对你的态度一般," attitude_info = f"{self.person_name}对你的态度一般,"
neuroticism_info = "" neuroticism_info = ""
if self.neuroticism: if self.neuroticism:
if self.neuroticism > 8: if self.neuroticism > 8:
@@ -533,29 +531,28 @@ class Person:
neuroticism_info = f"{self.person_name}的情绪比较稳定," neuroticism_info = f"{self.person_name}的情绪比较稳定,"
else: else:
neuroticism_info = f"{self.person_name}的情绪非常稳定,毫无波动" neuroticism_info = f"{self.person_name}的情绪非常稳定,毫无波动"
points_text = "" points_text = ""
category_list = self.get_all_category() category_list = self.get_all_category()
for category in category_list: for category in category_list:
random_memory = self.get_random_memory_by_category(category,1)[0] random_memory = self.get_random_memory_by_category(category, 1)[0]
if random_memory: if random_memory:
points_text = f"有关 {category} 的记忆:{get_memory_content_from_memory(random_memory)}" points_text = f"有关 {category} 的记忆:{get_memory_content_from_memory(random_memory)}"
break break
points_info = "" points_info = ""
if points_text: if points_text:
points_info = f"你还记得有关{self.person_name}的最近记忆:{points_text}" points_info = f"你还记得有关{self.person_name}的最近记忆:{points_text}"
if not (nickname_str or attitude_info or neuroticism_info or points_info): if not (nickname_str or attitude_info or neuroticism_info or points_info):
return "" return ""
relation_info = f"{self.person_name}:{nickname_str}{attitude_info}{neuroticism_info}{points_info}" relation_info = f"{self.person_name}:{nickname_str}{attitude_info}{neuroticism_info}{points_info}"
return relation_info return relation_info
class PersonInfoManager: class PersonInfoManager:
def __init__(self): def __init__(self):
self.person_name_list = {} self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try: try:
@@ -580,8 +577,6 @@ class PersonInfoManager:
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)") logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
except Exception as e: except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod @staticmethod
def _extract_json_from_text(text: str) -> dict: def _extract_json_from_text(text: str) -> dict:
@@ -717,6 +712,6 @@ class PersonInfoManager:
person.sync_to_database() person.sync_to_database()
self.person_name_list[person_id] = unique_nickname self.person_name_list[person_id] = unique_nickname
return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"} return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"}
person_info_manager = PersonInfoManager() person_info_manager = PersonInfoManager()

View File

@@ -1,20 +1,13 @@
import random import json
from json_repair import repair_json
from typing import Tuple from typing import Tuple
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
# 导入依赖的系统组件
from src.common.logger import get_logger from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式
from src.plugin_system.apis import emoji_api, llm_api, message_api
# NoReplyAction已集成到heartFC_chat.py中不再需要导入
from src.config.config import global_config from src.config.config import global_config
from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import json from src.plugin_system import BaseAction, ActionActivationType
from json_repair import repair_json from src.plugin_system.apis import llm_api
logger = get_logger("relation") logger = get_logger("relation")
@@ -39,10 +32,9 @@ def init_prompt():
{{ {{
"category": "分类名称" "category": "分类名称"
}} """, }} """,
"relation_category" "relation_category",
) )
Prompt( Prompt(
""" """
以下是有关{category}的现有记忆: 以下是有关{category}的现有记忆:
@@ -73,7 +65,7 @@ def init_prompt():
现在请你根据情况选出合适的修改方式并输出json不要输出其他内容 现在请你根据情况选出合适的修改方式并输出json不要输出其他内容
""", """,
"relation_category_update" "relation_category_update",
) )
@@ -98,17 +90,14 @@ class BuildRelationAction(BaseAction):
""" """
# 动作参数定义 # 动作参数定义
action_parameters = { action_parameters = {"person_name": "需要了解或记忆的人的名称", "impression": "需要了解的对某人的记忆或印象"}
"person_name":"需要了解或记忆的人的名称",
"impression":"需要了解的对某人的记忆或印象"
}
# 动作使用场景 # 动作使用场景
action_require = [ action_require = [
"了解对于某人的记忆,并添加到你对对方的印象中", "了解对于某人的记忆,并添加到你对对方的印象中",
"对方与有明确提到有关其自身的事件", "对方与有明确提到有关其自身的事件",
"对方有提到其个人信息,包括喜好,身份,等等", "对方有提到其个人信息,包括喜好,身份,等等",
"对方希望你记住对方的信息" "对方希望你记住对方的信息",
] ]
# 关联类型 # 关联类型
@@ -129,9 +118,7 @@ class BuildRelationAction(BaseAction):
if not person.is_known: if not person.is_known:
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆") logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
return False, f"用户 {person_name} 不存在,跳过添加记忆" return False, f"用户 {person_name} 不存在,跳过添加记忆"
category_list = person.get_all_category() category_list = person.get_all_category()
if not category_list: if not category_list:
category_list_str = "无分类" category_list_str = "无分类"
@@ -142,9 +129,8 @@ class BuildRelationAction(BaseAction):
"relation_category", "relation_category",
category_list=category_list_str, category_list=category_list_str,
memory_point=impression, memory_point=impression,
person_name=person.person_name person_name=person.person_name,
) )
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
@@ -161,84 +147,76 @@ class BuildRelationAction(BaseAction):
success, category, _, _ = await llm_api.generate_with_model( success, category, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="relation.category" prompt, model_config=chat_model_config, request_type="relation.category"
) )
category_data = json.loads(repair_json(category)) category_data = json.loads(repair_json(category))
category = category_data.get("category", "") category = category_data.get("category", "")
if not category: if not category:
logger.warning(f"{self.log_prefix} LLM未给出分类跳过添加记忆") logger.warning(f"{self.log_prefix} LLM未给出分类跳过添加记忆")
return False, "LLM未给出分类跳过添加记忆" return False, "LLM未给出分类跳过添加记忆"
# 第二部分:更新记忆 # 第二部分:更新记忆
memory_list = person.get_memory_list_by_category(category) memory_list = person.get_memory_list_by_category(category)
if not memory_list: if not memory_list:
logger.info(f"{self.log_prefix} {person.person_name}{category} 的记忆为空,进行创建") logger.info(f"{self.log_prefix} {person.person_name}{category} 的记忆为空,进行创建")
person.memory_points.append(f"{category}:{impression}:1.0") person.memory_points.append(f"{category}:{impression}:1.0")
person.sync_to_database() person.sync_to_database()
return True, f"未找到分类为{category}的记忆点,进行添加" return True, f"未找到分类为{category}的记忆点,进行添加"
memory_list_str = "" memory_list_str = ""
memory_list_id = {} memory_list_id = {}
id = 1 for id, memory in enumerate(memory_list, start=1):
for memory in memory_list:
memory_content = get_memory_content_from_memory(memory) memory_content = get_memory_content_from_memory(memory)
memory_list_str += f"{id}. {memory_content}\n" memory_list_str += f"{id}. {memory_content}\n"
memory_list_id[id] = memory memory_list_id[id] = memory
id += 1
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"relation_category_update", "relation_category_update",
category=category, category=category,
memory_list=memory_list_str, memory_list=memory_list_str,
memory_point=impression, memory_point=impression,
person_name=person.person_name person_name=person.person_name,
) )
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
else: else:
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
chat_model_config = models.get("utils") chat_model_config = models.get("utils")
success, update_memory, _, _ = await llm_api.generate_with_model( success, update_memory, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="relation.category.update" prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore
) )
update_memory_data = json.loads(repair_json(update_memory)) update_memory_data = json.loads(repair_json(update_memory))
new_memory = update_memory_data.get("new_memory", "") new_memory = update_memory_data.get("new_memory", "")
memory_id = update_memory_data.get("memory_id", "") memory_id = update_memory_data.get("memory_id", "")
integrate_memory = update_memory_data.get("integrate_memory", "") integrate_memory = update_memory_data.get("integrate_memory", "")
if new_memory: if new_memory:
# 新记忆 # 新记忆
person.memory_points.append(f"{category}:{new_memory}:1.0") person.memory_points.append(f"{category}:{new_memory}:1.0")
person.sync_to_database() person.sync_to_database()
return True, f"{person.person_name}新增记忆点: {new_memory}" return True, f"{person.person_name}新增记忆点: {new_memory}"
elif memory_id and integrate_memory: elif memory_id and integrate_memory:
# 现存或冲突记忆 # 现存或冲突记忆
memory = memory_list_id[memory_id] memory = memory_list_id[memory_id]
memory_content = get_memory_content_from_memory(memory) memory_content = get_memory_content_from_memory(memory)
del_count = person.del_memory(category,memory_content) del_count = person.del_memory(category, memory_content)
if del_count > 0: if del_count > 0:
logger.info(f"{self.log_prefix} 删除记忆点: {memory_content}") logger.info(f"{self.log_prefix} 删除记忆点: {memory_content}")
memory_weight = get_weight_from_memory(memory) memory_weight = get_weight_from_memory(memory)
person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}") person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}")
person.sync_to_database() person.sync_to_database()
return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}" return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
else: else:
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}") logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
return False, f"删除{person.person_name}的记忆点失败: {memory_content}" return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
return True, "关系动作执行成功" return True, "关系动作执行成功"
@@ -248,4 +226,4 @@ class BuildRelationAction(BaseAction):
# 还缺一个关系的太多遗忘和对应的提取 # 还缺一个关系的太多遗忘和对应的提取
init_prompt() init_prompt()

View File

@@ -2,7 +2,7 @@ from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo from src.plugin_system.base.component_types import ComponentInfo
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode from src.plugin_system.base.base_action import BaseAction, ActionActivationType
from src.plugin_system.base.config_types import ConfigField from src.plugin_system.base.config_types import ConfigField
from typing import Tuple, List, Type from typing import Tuple, List, Type