feat: add 麦麦 (MaiMai) curiosity feature, improve memory building
988  src/memory_system/Hippocampus.py  Normal file
@@ -0,0 +1,988 @@
# -*- coding: utf-8 -*-
import datetime
import math
import random
import time
import re
import jieba
import networkx as nx
import numpy as np
from typing import List, Tuple, Set
from collections import Counter
import traceback

from rich.traceback import install

from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.database.database_model import GraphNodes, GraphEdges  # Peewee model imports
from src.common.logger import get_logger
from src.chat.utils.utils import cut_key_words


# Cosine similarity helper
def cosine_similarity(v1, v2):
    """Compute the cosine similarity of two vectors."""
    dot_product = np.dot(v1, v2)
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
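
A minimal sketch of how this helper behaves on the binary bag-of-words vectors that get_memory_from_keyword builds below (the word sets here are invented):

    keyword_words = {"记忆", "系统"}
    node_words = {"记忆", "图"}
    all_words = keyword_words | node_words
    v1 = [1 if w in keyword_words else 0 for w in all_words]
    v2 = [1 if w in node_words else 0 for w in all_words]
    print(round(cosine_similarity(v1, v2), 2))  # 0.5: one shared word, two words per side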

install(extra_lines=3)


def calculate_information_content(text):
    """Compute the information content (Shannon entropy) of a text."""
    char_count = Counter(text)
    total_chars = len(text)
    if total_chars == 0:
        return 0
    entropy = 0
    for count in char_count.values():
        probability = count / total_chars
        entropy -= probability * math.log2(probability)

    return entropy
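
A worked check of the entropy formula H = -sum(p(c) * log2 p(c)), as a minimal sketch:

    # "aabb": p('a') = p('b') = 0.5, so H = -(0.5 * log2(0.5)) * 2 = 1.0 bit
    assert abs(calculate_information_content("aabb") - 1.0) < 1e-9
    # A single repeated character carries no information
    assert calculate_information_content("aaaa") == 0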

logger = get_logger("memory")


class MemoryGraph:
    def __init__(self):
        self.G = nx.Graph()  # networkx graph structure

    def connect_dot(self, concept1, concept2):
        # Avoid self-loops
        if concept1 == concept2:
            return

        current_time = datetime.datetime.now().timestamp()

        # If the edge already exists, increase its strength
        if self.G.has_edge(concept1, concept2):
            self.G[concept1][concept2]["strength"] = self.G[concept1][concept2].get("strength", 1) + 1
            # Refresh the last-modified timestamp
            self.G[concept1][concept2]["last_modified"] = current_time
        else:
            # New edge: initialize strength to 1
            self.G.add_edge(
                concept1,
                concept2,
                strength=1,
                created_time=current_time,  # creation timestamp
                last_modified=current_time,  # last-modified timestamp
            )

    async def add_dot(self, concept, memory, hippocampus_instance=None):
        current_time = datetime.datetime.now().timestamp()

        if concept in self.G:
            if "memory_items" in self.G.nodes[concept]:
                # The existing memory is already stored as a plain string
                existing_memory = self.G.nodes[concept]["memory_items"]
                # Simply concatenate old and new memories
                new_memory_str = f"{existing_memory} | {memory}"
                self.G.nodes[concept]["memory_items"] = new_memory_str
                logger.info(f"节点 {concept} 记忆内容已简单拼接并更新:{new_memory_str}")
            else:
                self.G.nodes[concept]["memory_items"] = str(memory)
                # Node exists but had no memory_items: this is its first memory, so set created_time
                if "created_time" not in self.G.nodes[concept]:
                    self.G.nodes[concept]["created_time"] = current_time
                logger.info(f"节点 {concept} 创建新记忆:{str(memory)}")
            # Refresh the last-modified timestamp
            self.G.nodes[concept]["last_modified"] = current_time
        else:
            # New node: create a fresh memory string
            self.G.add_node(
                concept,
                memory_items=str(memory),
                weight=1.0,  # new nodes start with weight 1.0
                created_time=current_time,  # creation timestamp
                last_modified=current_time,  # last-modified timestamp
            )
            logger.info(f"新节点 {concept} 已添加,记忆内容已写入:{str(memory)}")

    def get_dot(self, concept):
        # Return the node only if it exists in the graph
        return (concept, self.G.nodes[concept]) if concept in self.G else None

    def get_related_item(self, topic, depth=1):
        if topic not in self.G:
            return [], []

        first_layer_items = []
        second_layer_items = []

        # Collect neighbouring nodes
        neighbors = list(self.G.neighbors(topic))

        # Memory items of the node itself
        node_data = self.get_dot(topic)
        if node_data:
            _, data = node_data
            if "memory_items" in data:
                # Use the full memory content directly
                if memory_items := data["memory_items"]:
                    first_layer_items.append(memory_items)

        # Only fetch the second layer when depth >= 2
        if depth >= 2:
            # Memory items of neighbouring nodes
            for neighbor in neighbors:
                if node_data := self.get_dot(neighbor):
                    _, data = node_data
                    if "memory_items" in data:
                        # Use the full memory content directly
                        if memory_items := data["memory_items"]:
                            second_layer_items.append(memory_items)

        return first_layer_items, second_layer_items

    @property
    def dots(self):
        # Return the (concept, data) pair for every node
        return [self.get_dot(node) for node in self.G.nodes()]

    def forget_topic(self, topic):
        """Remove the topic node from the graph and report the memory it carried, if any."""
        if topic not in self.G:
            return None

        # Grab the node data before removal
        node_data = self.G.nodes[topic]

        # Remove the whole node
        self.G.remove_node(topic)
        # If the node carried memory_items, report what was deleted
        if "memory_items" in node_data:
            if memory_items := node_data["memory_items"]:
                return (
                    f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
                    if len(memory_items) > 50
                    else f"删除了节点 {topic} 的完整记忆: {memory_items}"
                )
        return None
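
A minimal usage sketch of MemoryGraph on its own (the concepts and memories are invented; add_dot is a coroutine, so it needs an event loop):

    import asyncio

    async def _demo():
        g = MemoryGraph()
        await g.add_dot("猫", "猫喜欢晒太阳")
        await g.add_dot("太阳", "太阳从东边升起")
        g.connect_dot("猫", "太阳")
        first, second = g.get_related_item("猫", depth=2)
        print(first)   # memories stored on the "猫" node itself
        print(second)  # memories on direct neighbours such as "太阳"

    asyncio.run(_demo())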


# The hippocampus
class Hippocampus:
    def __init__(self):
        self.memory_graph = MemoryGraph()
        self.entorhinal_cortex: EntorhinalCortex = None  # type: ignore
        self.parahippocampal_gyrus: ParahippocampalGyrus = None  # type: ignore

    def initialize(self):
        # Initialize sub-components
        self.entorhinal_cortex = EntorhinalCortex(self)
        self.parahippocampal_gyrus = ParahippocampalGyrus(self)
        # Load the memory graph from the database
        self.entorhinal_cortex.sync_memory_from_db()

    def get_all_node_names(self) -> list:
        """Return the names of all nodes in the memory graph."""
        return list(self.memory_graph.G.nodes())

    @staticmethod
    def calculate_node_hash(concept, memory_items) -> int:
        """Compute a fingerprint hash for a node."""
        # memory_items is already a plain string; split on the separator
        if memory_items:
            unique_items = {item.strip() for item in memory_items.split(" | ") if item.strip()}
        else:
            unique_items = set()

        # Use a frozenset so the hash is order-independent
        content = f"{concept}:{frozenset(unique_items)}"
        return hash(content)

    @staticmethod
    def calculate_edge_hash(source, target) -> int:
        """Compute a fingerprint hash for an edge."""
        # A plain tuple keeps the (source, target) order stable
        return hash((source, target))

    def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
        """Retrieve memories related to a keyword.

        Args:
            keyword (str): the keyword to match
            max_depth (int, optional): retrieval depth, default 2. 1 returns only directly
                related memories; 2 also returns indirectly related ones.

        Returns:
            list: memory entries, each a tuple (topic, memory_content, similarity)
                - topic: str, the memory topic
                - memory_content: str, the full memory content under that topic
                - similarity: float, similarity to the keyword
        """
        if not keyword:
            return []

        # All nodes in the graph
        all_nodes = list(self.memory_graph.G.nodes())
        memories = []

        # Word set of the keyword
        keyword_words = set(jieba.cut(keyword))

        # Compare the keyword against every node
        for node in all_nodes:
            node_words = set(jieba.cut(node))
            all_words = keyword_words | node_words
            v1 = [1 if word in keyword_words else 0 for word in all_words]
            v2 = [1 if word in node_words else 0 for word in all_words]
            similarity = cosine_similarity(v1, v2)

            # If similarity clears the threshold, collect the node's memories
            if similarity >= 0.3:  # tunable threshold
                node_data = self.memory_graph.G.nodes[node]
                # Use the full memory content directly
                if memory_items := node_data.get("memory_items", ""):
                    memories.append((node, memory_items, similarity))

        # Sort by similarity, descending
        memories.sort(key=lambda x: x[2], reverse=True)
        return memories

    async def get_keywords_from_text(self, text: str) -> Tuple[List[str], List]:
        """Extract keywords from a text.

        Args:
            text (str): input text

        Returns:
            Tuple[List[str], List]: LLM-extracted keywords, plus the lightweight
                jieba-based keywords from cut_key_words.
        """
        if not text:
            return [], []

        # LLM keyword extraction; topic_num is tuned against the observed text-length distribution
        text_length = len(text)
        topic_num: int | list[int] = 0

        keywords_lite = cut_key_words(text)
        if keywords_lite:
            logger.debug(f"提取关键词极简版: {keywords_lite}")

        if text_length <= 12:
            topic_num = [1, 3]  # <=12 chars: 1-3 keywords (27.18% of texts)
        elif text_length <= 20:
            topic_num = [2, 4]  # 13-20 chars: 2-4 keywords (22.76% of texts)
        elif text_length <= 30:
            topic_num = [3, 5]  # 21-30 chars: 3-5 keywords (10.33% of texts)
        elif text_length <= 50:
            topic_num = [4, 5]  # 31-50 chars: 4-5 keywords (9.79% of texts)
        else:
            topic_num = 5  # 51+ chars: 5 keywords (remaining long texts)

        topics_response, _ = await self.model_small.generate_response_async(self.find_topic_llm(text, topic_num))

        # Pull keywords out of the <...> markers
        keywords = re.findall(r"<([^>]+)>", topics_response)
        if not keywords:
            keywords = []
        else:
            keywords = [
                keyword.strip()
                for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
                if keyword.strip()
            ]

        if keywords:
            logger.debug(f"提取关键词: {keywords}")

        return keywords, keywords_lite

    async def get_memory_from_topic(
        self,
        keywords: list[str],
        max_memory_num: int = 3,
        max_memory_length: int = 2,
        max_depth: int = 3,
    ) -> list:
        """Retrieve memories related to a list of keywords.

        Args:
            keywords (list): keywords to search from
            max_memory_num (int, optional): upper bound on returned memory entries. Default 3.
            max_memory_length (int, optional): maximum memory entries returned per topic. Default 2.
            max_depth (int, optional): retrieval depth. Default 3. Larger values widen the
                search and surface more indirectly related memories, at the cost of speed.

        Returns:
            list: memory entries, each a tuple (topic, memory_content)
                - topic: str, the memory topic
                - memory_content: str, the full memory content under that topic
        """
        if not keywords:
            return []

        logger.info(f"提取的关键词: {', '.join(keywords)}")

        # Drop keywords that are not nodes in the memory graph
        valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
        if not valid_keywords:
            logger.debug("没有找到有效的关键词节点")
            return []

        logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")

        # Accumulated activation per node across all keywords
        activate_map = {}

        # Spreading-activation search centred on each keyword
        for keyword in valid_keywords:
            logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
            # Initial activation
            activation_values = {keyword: 1.0}
            # Nodes already visited
            visited_nodes = {keyword}
            # Work queue of (node, activation, depth) entries
            nodes_to_process = [(keyword, 1.0, 0)]

            while nodes_to_process:
                current_node, current_activation, current_depth = nodes_to_process.pop(0)

                # Stop spreading when activation is exhausted or depth exceeds the limit
                if current_activation <= 0 or current_depth >= max_depth:
                    continue

                # All neighbours of the current node
                neighbors = list(self.memory_graph.G.neighbors(current_node))

                for neighbor in neighbors:
                    if neighbor in visited_nodes:
                        continue

                    # Edge strength between the two nodes
                    edge_data = self.memory_graph.G[current_node][neighbor]
                    strength = edge_data.get("strength", 1)

                    # Activation decays by 1/strength per hop
                    new_activation = current_activation - (1 / strength)

                    if new_activation > 0:
                        activation_values[neighbor] = new_activation
                        visited_nodes.add(neighbor)
                        nodes_to_process.append((neighbor, new_activation, current_depth + 1))

            # Fold this keyword's activations into the shared map
            for node, activation_value in activation_values.items():
                if activation_value > 0:
                    if node in activate_map:
                        activate_map[node] += activation_value
                    else:
                        activate_map[node] = activation_value

        # Selection weighted by squared activation
        remember_map = {}

        # Sum of squared activations
        total_squared_activation = sum(activation**2 for activation in activate_map.values())
        if total_squared_activation > 0:
            # Normalized activation per node
            normalized_activations = {
                node: (activation**2) / total_squared_activation for node, activation in activate_map.items()
            }

            # Sort by normalized activation and keep the top max_memory_num nodes
            sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num]

            # Record the selected nodes, keyed by their raw activation values
            for node, normalized_activation in sorted_nodes:
                remember_map[node] = activate_map[node]
                logger.debug(
                    f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})"
                )
        else:
            logger.info("没有有效的激活值")

        # Extract memories from the selected nodes
        all_memories = []
        for node, activation in remember_map.items():
            logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
            node_data = self.memory_graph.G.nodes[node]
            if memory_items := node_data.get("memory_items", ""):
                logger.debug("节点包含完整记忆")
                # Similarity between the memory and the keywords
                memory_words = set(jieba.cut(memory_items))
                text_words = set(keywords)
                if all_words := memory_words | text_words:
                    # Computed for consistency with the keyword path; the value is unused here
                    v1 = [1 if word in memory_words else 0 for word in all_words]
                    v2 = [1 if word in text_words else 0 for word in all_words]
                    _ = cosine_similarity(v1, v2)

                # Keep the full memory
                all_memories.append((node, memory_items, activation))
            else:
                logger.info("节点没有记忆")

        # Deduplicate by memory content
        logger.debug("开始记忆去重:")
        seen_memories = set()
        unique_memories = []
        for topic, memory_items, activation_value in all_memories:
            # memory_items is the full memory string
            memory = memory_items or ""
            if memory not in seen_memories:
                seen_memories.add(memory)
                unique_memories.append((topic, memory_items, activation_value))
                logger.debug(f"保留记忆: {memory} (来自节点: {topic}, 激活值: {activation_value:.2f})")
            else:
                logger.debug(f"跳过重复记忆: {memory} (来自节点: {topic})")

        # Convert to (keyword, memory) pairs
        result = []
        for topic, memory_items, _ in unique_memories:
            memory = memory_items or ""
            result.append((topic, memory))
            logger.debug(f"选中记忆: {memory} (来自节点: {topic})")

        return result
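
The decay rule above means an edge of strength s costs 1/s activation per hop, so stronger associations carry the signal further. A minimal numeric sketch:

    activation = 1.0
    for strength in (4, 4, 4, 4):
        activation -= 1 / strength
        print(round(activation, 2))  # 0.75, 0.5, 0.25, 0.0; the fourth hop exhausts the signal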

    async def get_activate_from_text(
        self, text: str, max_depth: int = 3, fast_retrieval: bool = False
    ) -> tuple[float, list[str], list[str]]:
        """Extract keywords from a text and measure how strongly they activate the memory graph.

        Args:
            text (str): input text
            max_depth (int, optional): retrieval depth. Default 3.
            fast_retrieval (bool, optional): whether to use fast retrieval. Default False.
                True uses jieba/TF-IDF keyword extraction: faster but less accurate.
                False uses the LLM: slower but more accurate.

        Returns:
            float: total activation relative to the total node count (scaled)
            list[str]: extracted keywords
            list[str]: lightweight keywords from cut_key_words
        """
        keywords, keywords_lite = await self.get_keywords_from_text(text)

        # Drop keywords that are not nodes in the memory graph
        valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
        if not valid_keywords:
            return 0, keywords, keywords_lite

        logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")

        # Accumulated activation per node across all keywords
        activate_map = {}

        # Spreading-activation search centred on each keyword
        for keyword in valid_keywords:
            logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
            # Initial activation (the seed node itself counts as 1.5)
            activation_values = {keyword: 1.5}
            # Nodes already visited
            visited_nodes = {keyword}
            # Work queue of (node, activation, depth) entries
            nodes_to_process = [(keyword, 1.0, 0)]

            while nodes_to_process:
                current_node, current_activation, current_depth = nodes_to_process.pop(0)

                # Stop spreading when activation is exhausted or depth exceeds the limit
                if current_activation <= 0 or current_depth >= max_depth:
                    continue

                # All neighbours of the current node
                neighbors = list(self.memory_graph.G.neighbors(current_node))

                for neighbor in neighbors:
                    if neighbor in visited_nodes:
                        continue

                    # Edge strength between the two nodes
                    edge_data = self.memory_graph.G[current_node][neighbor]
                    strength = edge_data.get("strength", 1)

                    # Activation decays by 1/strength per hop
                    new_activation = current_activation - (1 / strength)

                    if new_activation > 0:
                        activation_values[neighbor] = new_activation
                        visited_nodes.add(neighbor)
                        nodes_to_process.append((neighbor, new_activation, current_depth + 1))

            # Fold this keyword's activations into the shared map
            for node, activation_value in activation_values.items():
                if activation_value > 0:
                    if node in activate_map:
                        activate_map[node] += activation_value
                    else:
                        activate_map[node] = activation_value

        # Ratio of total activation to total node count, scaled by 50
        total_activation = sum(activate_map.values())
        total_nodes = len(self.memory_graph.G.nodes())
        activation_ratio = total_activation / total_nodes if total_nodes > 0 else 0
        activation_ratio = activation_ratio * 50
        logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")

        return activation_ratio, keywords, keywords_lite


# Mediates between the hippocampus and the rest of the system
class EntorhinalCortex:
    def __init__(self, hippocampus: Hippocampus):
        self.hippocampus = hippocampus
        self.memory_graph = hippocampus.memory_graph

    async def sync_memory_to_db(self):
        """Synchronize the in-memory graph to the database."""
        start_time = time.time()
        current_time = datetime.datetime.now().timestamp()

        # All nodes in the database and all nodes in memory
        db_nodes = {node.concept: node for node in GraphNodes.select()}
        memory_nodes = list(self.memory_graph.G.nodes(data=True))

        # Batched node payloads
        nodes_to_create = []
        nodes_to_update = []
        nodes_to_delete = set()

        # Process nodes
        for concept, data in memory_nodes:
            if not concept or not isinstance(concept, str):
                self.memory_graph.G.remove_node(concept)
                continue

            memory_items = data.get("memory_items", "")

            # A plain empty-string check suffices; no need to split into a list
            if not memory_items or memory_items.strip() == "":
                self.memory_graph.G.remove_node(concept)
                continue

            # Fingerprint of the in-memory node
            memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items)
            created_time = data.get("created_time", current_time)
            last_modified = data.get("last_modified", current_time)

            # memory_items is stored as a raw string; no JSON serialization needed
            if not memory_items:
                continue

            # Node weight
            weight = data.get("weight", 1.0)

            if concept not in db_nodes:
                nodes_to_create.append(
                    {
                        "concept": concept,
                        "memory_items": memory_items,
                        "weight": weight,
                        "hash": memory_hash,
                        "created_time": created_time,
                        "last_modified": last_modified,
                    }
                )
            else:
                db_node = db_nodes[concept]
                if db_node.hash != memory_hash:
                    nodes_to_update.append(
                        {
                            "concept": concept,
                            "memory_items": memory_items,
                            "weight": weight,
                            "hash": memory_hash,
                            "last_modified": last_modified,
                        }
                    )

        # Nodes present in the DB but gone from memory get deleted
        memory_concepts = {concept for concept, _ in memory_nodes}
        nodes_to_delete = set(db_nodes.keys()) - memory_concepts

        # Apply node changes in batches
        if nodes_to_create:
            batch_size = 100
            for i in range(0, len(nodes_to_create), batch_size):
                batch = nodes_to_create[i : i + batch_size]
                GraphNodes.insert_many(batch).execute()

        if nodes_to_update:
            batch_size = 100
            for i in range(0, len(nodes_to_update), batch_size):
                batch = nodes_to_update[i : i + batch_size]
                for node_data in batch:
                    GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
                        GraphNodes.concept == node_data["concept"]
                    ).execute()

        if nodes_to_delete:
            GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute()  # type: ignore

        # Process edges
        db_edges = list(GraphEdges.select())
        memory_edges = list(self.memory_graph.G.edges(data=True))

        # Hash lookup for edges already in the database
        db_edge_dict = {}
        for edge in db_edges:
            edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
            db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}

        # Batched edge payloads
        edges_to_create = []
        edges_to_update = []

        # Process edges
        for source, target, data in memory_edges:
            edge_hash = self.hippocampus.calculate_edge_hash(source, target)
            edge_key = (source, target)
            strength = data.get("strength", 1)
            created_time = data.get("created_time", current_time)
            last_modified = data.get("last_modified", current_time)

            if edge_key not in db_edge_dict:
                edges_to_create.append(
                    {
                        "source": source,
                        "target": target,
                        "strength": strength,
                        "hash": edge_hash,
                        "created_time": created_time,
                        "last_modified": last_modified,
                    }
                )
            elif db_edge_dict[edge_key]["hash"] != edge_hash:
                edges_to_update.append(
                    {
                        "source": source,
                        "target": target,
                        "strength": strength,
                        "hash": edge_hash,
                        "last_modified": last_modified,
                    }
                )

        # Edges present in the DB but gone from memory get deleted
        memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
        edges_to_delete = set(db_edge_dict.keys()) - memory_edge_keys

        # Apply edge changes in batches
        if edges_to_create:
            batch_size = 100
            for i in range(0, len(edges_to_create), batch_size):
                batch = edges_to_create[i : i + batch_size]
                GraphEdges.insert_many(batch).execute()

        if edges_to_update:
            batch_size = 100
            for i in range(0, len(edges_to_update), batch_size):
                batch = edges_to_update[i : i + batch_size]
                for edge_data in batch:
                    GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
                        (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
                    ).execute()

        if edges_to_delete:
            for source, target in edges_to_delete:
                GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()

        end_time = time.time()
        logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}秒")
        logger.info(
            f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边"
        )

    async def resync_memory_to_db(self):
        """Wipe the database and resynchronize all memory data from scratch."""
        start_time = time.time()
        logger.info("[数据库] 开始重新同步所有记忆数据...")

        # Clear both tables
        clear_start = time.time()
        GraphNodes.delete().execute()
        GraphEdges.delete().execute()
        clear_end = time.time()
        logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")

        # All nodes and edges currently in memory
        memory_nodes = list(self.memory_graph.G.nodes(data=True))
        memory_edges = list(self.memory_graph.G.edges(data=True))
        current_time = datetime.datetime.now().timestamp()

        # Batched node payloads
        nodes_data = []
        for concept, data in memory_nodes:
            memory_items = data.get("memory_items", "")

            # A plain empty-string check suffices; no need to split into a list
            if not memory_items or memory_items.strip() == "":
                self.memory_graph.G.remove_node(concept)
                continue

            # Fingerprint of the in-memory node
            memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items)
            created_time = data.get("created_time", current_time)
            last_modified = data.get("last_modified", current_time)

            # memory_items is stored as a raw string; no JSON serialization needed
            if not memory_items:
                continue

            # Node weight
            weight = data.get("weight", 1.0)

            nodes_data.append(
                {
                    "concept": concept,
                    "memory_items": memory_items,
                    "weight": weight,
                    "hash": memory_hash,
                    "created_time": created_time,
                    "last_modified": last_modified,
                }
            )

        # Bulk-insert nodes
        if nodes_data:
            batch_size = 100
            for i in range(0, len(nodes_data), batch_size):
                batch = nodes_data[i : i + batch_size]
                GraphNodes.insert_many(batch).execute()

        # Batched edge payloads
        edges_data = []
        for source, target, data in memory_edges:
            try:
                edges_data.append(
                    {
                        "source": source,
                        "target": target,
                        "strength": data.get("strength", 1),
                        "hash": self.hippocampus.calculate_edge_hash(source, target),
                        "created_time": data.get("created_time", current_time),
                        "last_modified": data.get("last_modified", current_time),
                    }
                )
            except Exception as e:
                logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
                continue

        # Bulk-insert edges
        if edges_data:
            batch_size = 100
            for i in range(0, len(edges_data), batch_size):
                batch = edges_data[i : i + batch_size]
                GraphEdges.insert_many(batch).execute()

        end_time = time.time()
        logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒")
        logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边")

    def sync_memory_from_db(self):
        """Load the graph structure from the database into memory."""
        current_time = datetime.datetime.now().timestamp()
        need_update = False

        # Reset the in-memory graph
        self.memory_graph.G.clear()

        # Load statistics
        total_nodes = 0
        loaded_nodes = 0
        skipped_nodes = 0

        # Load every node from the database
        nodes = list(GraphNodes.select())
        total_nodes = len(nodes)

        for node in nodes:
            concept = node.concept
            try:
                # Skip nodes whose memory_items is empty or None
                if not node.memory_items or node.memory_items.strip() == "":
                    logger.warning(f"节点 {concept} 的memory_items为空,跳过")
                    skipped_nodes += 1
                    continue

                # Use memory_items as-is
                memory_items = node.memory_items.strip()

                # Backfill missing timestamp fields in the database
                if not node.created_time or not node.last_modified:
                    update_data = {}
                    if not node.created_time:
                        update_data["created_time"] = current_time
                    if not node.last_modified:
                        update_data["last_modified"] = current_time

                    if update_data:
                        GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()

                # Fall back to the current time when timestamps are absent
                created_time = node.created_time or current_time
                last_modified = node.last_modified or current_time

                # Node weight
                weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0

                # Add the node to the graph
                self.memory_graph.G.add_node(
                    concept,
                    memory_items=memory_items,
                    weight=weight,
                    created_time=created_time,
                    last_modified=last_modified,
                )
                loaded_nodes += 1
            except Exception as e:
                logger.error(f"加载节点 {concept} 时发生错误: {e}")
                skipped_nodes += 1
                continue

        # Load every edge from the database
        edges = list(GraphEdges.select())
        for edge in edges:
            source = edge.source
            target = edge.target
            strength = edge.strength

            # Backfill missing timestamp fields in the database
            if not edge.created_time or not edge.last_modified:
                need_update = True
                update_data = {}
                if not edge.created_time:
                    update_data["created_time"] = current_time
                if not edge.last_modified:
                    update_data["last_modified"] = current_time

                GraphEdges.update(**update_data).where(
                    (GraphEdges.source == source) & (GraphEdges.target == target)
                ).execute()

            # Fall back to the current time when timestamps are absent
            created_time = edge.created_time or current_time
            last_modified = edge.last_modified or current_time

            # Only add the edge when both endpoints exist
            if source in self.memory_graph.G and target in self.memory_graph.G:
                self.memory_graph.G.add_edge(
                    source, target, strength=strength, created_time=created_time, last_modified=last_modified
                )

        if need_update:
            logger.info("[数据库] 已为缺失的时间字段进行补充")

        # Report load statistics
        logger.info(
            f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes} 个"
        )


# Memory-management component
class ParahippocampalGyrus:
    def __init__(self, hippocampus: Hippocampus):
        self.hippocampus = hippocampus
        self.memory_graph = hippocampus.memory_graph


class HippocampusManager:
    def __init__(self):
        self._hippocampus: Hippocampus = None  # type: ignore
        self._initialized = False

    def initialize(self):
        """Initialize the hippocampus instance."""
        if self._initialized:
            return self._hippocampus

        self._hippocampus = Hippocampus()
        self._hippocampus.initialize()
        self._initialized = True

        # Report memory-graph statistics
        memory_graph = self._hippocampus.memory_graph.G
        node_count = len(memory_graph.nodes())
        edge_count = len(memory_graph.edges())

        logger.info(f"""
--------------------------------
记忆系统参数配置:
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
--------------------------------""")  # noqa: E501

        return self._hippocampus

    def get_hippocampus(self):
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        return self._hippocampus

    async def get_memory_from_topic(
        self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
    ) -> list:
        """Public interface: retrieve memories related to the given keywords."""
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        try:
            response = await self._hippocampus.get_memory_from_topic(
                valid_keywords, max_memory_num, max_memory_length, max_depth
            )
        except Exception as e:
            logger.error(f"文本激活记忆失败: {e}")
            response = []
        return response

    async def get_activate_from_text(
        self, text: str, max_depth: int = 3, fast_retrieval: bool = False
    ) -> tuple[float, list[str], list[str]]:
        """Public interface: compute the activation produced by a text."""
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        try:
            return await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
        except Exception as e:
            logger.error(f"文本产生激活值失败: {e}")
            logger.error(traceback.format_exc())
            return 0.0, [], []

    def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
        """Public interface: retrieve memories related to a keyword."""
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        return self._hippocampus.get_memory_from_keyword(keyword, max_depth)

    def get_all_node_names(self) -> list:
        """Public interface: list all node names."""
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        return self._hippocampus.get_all_node_names()


# Global singleton instance
hippocampus_manager = HippocampusManager()
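
A sketch of the typical call pattern against the singleton (assumes the database and model configuration are already set up, and an async context for the coroutines):

    hippocampus_manager.initialize()
    ratio, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text("今天聊到了猫")
    memories = await hippocampus_manager.get_memory_from_topic(keywords, max_memory_num=3)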
722  src/memory_system/Memory_chest.py  Normal file
@@ -0,0 +1,722 @@
import asyncio
import json
import re
import time
import random

from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis.message_api import build_readable_messages
from src.plugin_system.apis.message_api import get_raw_msg_by_timestamp_with_chat
from json_repair import repair_json
from .memory_utils import (
    find_best_matching_memory,
    check_title_exists_fuzzy
)

logger = get_logger("memory_chest")

class MemoryChest:
    def __init__(self):

        self.LLMRequest = LLMRequest(
            model_set=model_config.model_task_config.utils_small,
            request_type="memory_chest",
        )

        self.LLMRequest_build = LLMRequest(
            model_set=model_config.model_task_config.utils,
            request_type="memory_chest_build",
        )

        self.memory_build_threshold = 20
        self.memory_size_limit = global_config.memory.max_memory_size

        self.running_content_list = {}  # {chat_id: {"content": running_content, "last_update_time": timestamp, "create_time": timestamp}}
        self.fetched_memory_list = []  # [(chat_id, (question, answer, timestamp)), ...]

    async def build_running_content(self, chat_id: str = None) -> str:
        """
        Build the running memory content for a chat.

        Args:
            chat_id: chat ID whose running content should be updated

        Returns:
            str: the rebuilt running content
        """
        # Decide whether an update is due: enough new messages since the last update
        if chat_id not in self.running_content_list:
            self.running_content_list[chat_id] = {
                "content": "",
                "last_update_time": time.time(),
                "create_time": time.time()
            }

        should_update = True
        if chat_id and chat_id in self.running_content_list:
            last_update_time = self.running_content_list[chat_id]["last_update_time"]
            current_time = time.time()
            # Count the new messages via the message API
            message_list = get_raw_msg_by_timestamp_with_chat(
                timestamp_start=last_update_time,
                timestamp_end=current_time,
                chat_id=chat_id,
                limit=global_config.chat.max_context_size * 2,
            )

            new_messages_count = len(message_list)
            time_diff_minutes = (current_time - last_update_time) / 60

            # Forced-build condition: more than 15 minutes elapsed and at least 5 new messages
            forced_update = time_diff_minutes > 15 and new_messages_count >= 5
            should_update = new_messages_count > self.memory_build_threshold or forced_update

            if forced_update:
                logger.info(f"chat_id {chat_id} 距离上次更新已 {time_diff_minutes:.1f} 分钟,有 {new_messages_count} 条新消息,强制构建")
            else:
                logger.info(f"chat_id {chat_id} 自上次更新后有 {new_messages_count} 条新消息,{'需要' if should_update else '不需要'}更新")

        running_content = ""  # returned as-is when no update is needed
        if should_update:
            # Render the new messages into readable text
            message_str = build_readable_messages(
                message_list,
                replace_bot_name=True,
                timestamp_mode="relative",
                read_mark=0.0,
                show_actions=True,
                remove_emoji_stickers=True,
            )

            current_running_content = ""
            if chat_id and chat_id in self.running_content_list:
                current_running_content = self.running_content_list[chat_id]["content"]

            prompt = f"""
以下是你的记忆内容和新的聊天记录,请你将他们整合和修改:
记忆内容:
<memory_content>
{current_running_content}
</memory_content>

<聊天记录>
{message_str}
</聊天记录>
聊天记录中可能包含有效信息,也可能信息密度很低,请你根据聊天记录中的信息,修改<part1>中的内容与<part2>中的内容
--------------------------------
请将上面的新聊天记录内的有用的信息进行整合到现有的记忆中
请主要关注概念和知识或者时效性较强的信息!!,而不是聊天的琐事
1.不要关注诸如某个用户做了什么,说了什么,不要关注某个用户的行为,而是关注其中的概念性信息
2.概念要求精确,不啰嗦,像科普读物或教育课本那样
3.如果有图片,请只关注图片和文本结合的知识和概念性内容
4.记忆为一段纯文本,逻辑清晰,指出概念的含义,并说明关系

记忆内容的格式,你必须仿照下面的格式,但不一定全部使用:
[概念] 是 [概念的含义(简短描述,不超过十个字)]
[概念] 不是 [对概念的负面含义(简短描述,不超过十个字)]
[概念1] 与 [概念2] 是 [概念1和概念2的关联(简短描述,不超过二十个字)]
[概念1] 包含 [概念2] 和 [概念3]
[概念1] 属于 [概念2]
......(不要包含中括号)

请仿照上述格式输出,每个知识点一句话。输出成一段平文本
现在请你输出,不要输出其他内容,注意一定要直白,白话,口语化不要浮夸,修辞。:
"""

            if global_config.debug.show_prompt:
                logger.info(f"记忆仓库构建运行内容 prompt: {prompt}")
            else:
                logger.debug(f"记忆仓库构建运行内容 prompt: {prompt}")

            running_content, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)

            logger.debug(f"记忆仓库构建运行内容: {running_content}")

            # Store the updated running content for this chat
            if chat_id and running_content:
                current_time = time.time()

                # Preserve the original create_time when present
                create_time = self.running_content_list[chat_id].get("create_time", current_time)

                self.running_content_list[chat_id] = {
                    "content": running_content,
                    "last_update_time": current_time,
                    "create_time": create_time
                }

                # Flush to the database once the running content exceeds the size limit
                if len(running_content) > self.memory_size_limit:
                    await self._save_to_database_and_clear(chat_id, running_content)

                # Forced save: content older than 1800 s that has reached 30% of max_memory_size
                elif (current_time - create_time > 1800 and
                      len(running_content) >= self.memory_size_limit * 0.3):
                    logger.info(f"chat_id {chat_id} 内容创建时间已超过 {(current_time - create_time)/60:.1f} 分钟,"
                                f"内容大小 {len(running_content)} 达到限制的 {int(self.memory_size_limit * 0.3)} 字符,强制保存")
                    await self._save_to_database_and_clear(chat_id, running_content)

        return running_content

    def get_all_titles(self, exclude_locked: bool = False) -> list[str]:
        """
        List every title stored in the memory chest.

        Args:
            exclude_locked: whether to skip locked memories; defaults to False

        Returns:
            list: all titles
        """
        try:
            # Collect the title of every memory record
            titles = []
            for memory in MemoryChestModel.select():
                if memory.title:
                    # Skip locked memories when exclude_locked is set
                    if exclude_locked and memory.locked:
                        continue
                    titles.append(memory.title)
            return titles
        except Exception as e:
            logger.error(f"获取记忆标题时出错: {e}")
            return []

    async def get_answer_by_question(self, chat_id: str = "", question: str = "") -> str:
        """
        Answer a question from the stored memories.
        """
        logger.info(f"正在回忆问题答案: {question}")

        title = await self.select_title_by_question(question)

        if not title:
            return ""

        for memory in MemoryChestModel.select():
            if memory.title == title:
                content = memory.content

                if random.random() < 0.5:
                    type = "要求原文能够较为全面的回答问题"
                else:
                    type = "要求提取简短的内容"

                prompt = f"""
{content}

请根据问题:{question}
在上方内容中,提取相关信息的原文并输出,{type}
请务必提取上面原文,不要输出其他内容:
"""

                if global_config.debug.show_prompt:
                    logger.info(f"记忆仓库获取答案 prompt: {prompt}")
                else:
                    logger.debug(f"记忆仓库获取答案 prompt: {prompt}")

                answer, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)

                logger.info(f"记忆仓库对问题 “{question}” 获取答案: {answer}")

                # Record the question/answer pair in fetched_memory_list
                if chat_id and answer:
                    self.fetched_memory_list.append((chat_id, (question, answer, time.time())))

                # Trim fetched_memory_list
                self._cleanup_fetched_memory_list()

                return answer

        # No stored memory matched the selected title
        return ""

    def get_chat_memories_as_string(self, chat_id: str) -> str:
        """
        Collect all fetched memories for a chat_id and format them as a string.

        Args:
            chat_id: chat ID

        Returns:
            str: formatted memories, one per line: 问题:xxx,答案:xxxxx
        """
        try:
            memories = []

            # Pull every memory for this chat_id out of fetched_memory_list
            for cid, (question, answer, timestamp) in self.fetched_memory_list:
                if cid == chat_id:
                    memories.append(f"问题:{question},答案:{answer}")

            # Sort the formatted lines (lexicographically; timestamps are not part of the string)
            memories.sort()

            # Join all memories with newlines
            result = "\n".join(memories)

            logger.info(f"chat_id {chat_id} 共有 {len(memories)} 条记忆")
            return result

        except Exception as e:
            logger.error(f"获取chat_id {chat_id} 的记忆时出错: {e}")
            return ""

    async def select_title_by_question(self, question: str) -> str:
        """
        Pick the stored title best able to answer a question.

        Args:
            question: the question

        Returns:
            str: the selected title, or None when no match is found
        """
        # Gather all titles (excluding locked memories) into a prompt-ready list
        titles = self.get_all_titles(exclude_locked=True)
        formatted_titles = ""
        for title in titles:
            formatted_titles += f"{title}\n"

        prompt = f"""
所有主题:
{formatted_titles}

请根据以下问题,选择一个能够回答问题的主题:
问题:{question}
请你输出主题,不要输出其他内容,完整输出主题名:
"""

        if global_config.debug.show_prompt:
            logger.info(f"记忆仓库选择标题 prompt: {prompt}")
        else:
            logger.debug(f"记忆仓库选择标题 prompt: {prompt}")

        title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)

        # Fuzzy-match the LLM's answer against the stored titles
        best_match = find_best_matching_memory(title, similarity_threshold=0.8)
        if best_match:
            selected_title = best_match[0]  # the matched title
            logger.info(f"记忆仓库选择标题: {selected_title} (相似度: {best_match[2]:.3f})")
        else:
            logger.warning(f"未找到相似度 >= 0.8 的标题匹配: {title}")
            selected_title = None

        return selected_title

    def _cleanup_fetched_memory_list(self):
        """
        Trim fetched_memory_list: drop entries older than 10 minutes, and when the
        list grows past 10 entries discard the oldest 5.
        """
        try:
            current_time = time.time()
            ten_minutes_ago = current_time - 600  # 10 minutes = 600 seconds

            # Drop memories older than 10 minutes
            self.fetched_memory_list = [
                (chat_id, (question, answer, timestamp))
                for chat_id, (question, answer, timestamp) in self.fetched_memory_list
                if timestamp > ten_minutes_ago
            ]

            # When more than 10 entries remain, drop the oldest 5
            if len(self.fetched_memory_list) > 10:
                # Sort by timestamp, then discard the 5 oldest entries
                self.fetched_memory_list.sort(key=lambda x: x[1][2])
                self.fetched_memory_list = self.fetched_memory_list[5:]

            logger.debug(f"fetched_memory_list清理后,当前有 {len(self.fetched_memory_list)} 条记忆")

        except Exception as e:
            logger.error(f"清理fetched_memory_list时出错: {e}")

    async def _save_to_database_and_clear(self, chat_id: str, content: str):
        """
        Generate a title, persist the content, and clear the chat's running_content.

        Args:
            chat_id: chat ID
            content: content to persist
        """
        try:
            # Generate a title
            title = ""
            title_prompt = f"""
请为以下内容生成一个描述全面的标题,要求描述内容的主要概念和事件:
{content}

标题不要分点,不要换行,不要输出其他内容
请只输出标题,不要输出其他内容:
"""

            if global_config.debug.show_prompt:
                logger.info(f"记忆仓库生成标题 prompt: {title_prompt}")
            else:
                logger.debug(f"记忆仓库生成标题 prompt: {title_prompt}")

            title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(title_prompt)

            await asyncio.sleep(0.5)

            if title:
                # Persist to the database
                MemoryChestModel.create(
                    title=title.strip(),
                    content=content
                )
                logger.info(f"已保存记忆仓库内容,标题: {title.strip()}, chat_id: {chat_id}")

                # Clear this chat's running_content
                if chat_id in self.running_content_list:
                    del self.running_content_list[chat_id]
                    logger.info(f"已清空chat_id {chat_id} 的running_content")
            else:
                logger.warning(f"生成标题失败,chat_id: {chat_id}")

        except Exception as e:
            logger.error(f"保存记忆仓库内容时出错: {e}")

    async def choose_merge_target(self, memory_title: str) -> list[str]:
        """
        Select stored memories related to the given title as merge candidates.

        Args:
            memory_title: the title to match against

        Returns:
            list[str]: contents of the selected memories
        """
        try:
            all_titles = self.get_all_titles(exclude_locked=True)
            content = ""
            for title in all_titles:
                content += f"{title}\n"

            prompt = f"""
所有记忆列表
{content}

请根据以上记忆列表,选择一个与"{memory_title}"相关的记忆,用json输出:
可以选择多个相关的记忆,但最多不超过5个
例如:
{{
    "selected_title": "选择的相关记忆标题"
}},
{{
    "selected_title": "选择的相关记忆标题"
}},
{{
    "selected_title": "选择的相关记忆标题"
}}
...
请输出JSON格式,不要输出其他内容:
"""
            if global_config.debug.show_prompt:
                logger.info(f"选择合并目标 prompt: {prompt}")
            else:
                logger.debug(f"选择合并目标 prompt: {prompt}")

            merge_target_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)

            # Parse the JSON response
            selected_titles = self._parse_merge_target_json(merge_target_response)

            # Resolve the selected titles to their stored contents
            selected_contents = self._get_memories_by_titles(selected_titles)

            logger.info(f"选择合并目标结果: {len(selected_contents)} 条记忆:{selected_titles}")
            return selected_contents

        except Exception as e:
            logger.error(f"选择合并目标时出错: {e}")
            return []

    def _get_memories_by_titles(self, titles: list[str]) -> list[str]:
        """
        Look up memory contents by their titles.

        Args:
            titles: memory titles

        Returns:
            list[str]: memory contents
        """
        try:
            contents = []
            for title in titles:
                if not title or not title.strip():
                    continue

                # Fuzzy-match each title against the stored memories
                try:
                    best_match = find_best_matching_memory(title.strip(), similarity_threshold=0.8)
                    if best_match:
                        memory_title = best_match[0]
                        memory_content = best_match[1]

                        # Check the locked flag in the database before merging.
                        # A flag is needed here: a bare `continue` inside the DB
                        # loop would not skip the append below.
                        is_locked = False
                        for memory in MemoryChestModel.select():
                            if memory.title == memory_title and memory.locked:
                                is_locked = True
                                break
                        if is_locked:
                            logger.warning(f"记忆 '{memory_title}' 已锁定,跳过合并")
                            continue

                        contents.append(memory_content)
                        logger.debug(f"找到记忆: {memory_title} (相似度: {best_match[2]:.3f})")
                    else:
                        logger.warning(f"未找到相似度 >= 0.8 的标题匹配: '{title}'")
                except Exception as e:
                    logger.error(f"查找标题 '{title}' 的记忆时出错: {e}")
                    continue

            logger.info(f"成功找到 {len(contents)} 条记忆内容")
            return contents

        except Exception as e:
            logger.error(f"根据标题查找记忆时出错: {e}")
            return []

    def _parse_merged_parts(self, merged_response: str) -> tuple[str, str]:
        """
        Parse the part1/part2 sections out of a merged-memory response.

        Args:
            merged_response: the LLM's merged-memory response

        Returns:
            tuple[str, str]: (part1_content, part2_content)
        """
        try:
            # Extract part1 and part2 with regular expressions
            # (re is already imported at module level)
            part1_pattern = r'<part1>(.*?)</part1>'
            part1_match = re.search(part1_pattern, merged_response, re.DOTALL)
            part1_content = part1_match.group(1).strip() if part1_match else ""

            part2_pattern = r'<part2>(.*?)</part2>'
            part2_match = re.search(part2_pattern, merged_response, re.DOTALL)
            part2_content = part2_match.group(1).strip() if part2_match else ""

            # Treat a bare "none"/"None" (case-insensitive) as empty
            def is_none_content(content: str) -> bool:
                if not content:
                    return True
                return re.match(r'^\s*none\s*$', content, re.IGNORECASE) is not None

            if is_none_content(part1_content):
                part1_content = ""
                logger.info("part1内容为none,设置为空")

            if is_none_content(part2_content):
                part2_content = ""
                logger.info("part2内容为none,设置为空")

            logger.info(f"解析合并记忆结果: part1={len(part1_content)}字符, part2={len(part2_content)}字符")

            return part1_content, part2_content

        except Exception as e:
            logger.error(f"解析合并记忆part1/part2时出错: {e}")
            return "", ""

    def _parse_merge_target_json(self, json_text: str) -> list[str]:
        """
        Parse the JSON response produced by choose_merge_target.

        Args:
            json_text: the LLM's JSON text

        Returns:
            list[str]: the extracted memory titles
        """
        try:
            # Repair the JSON text first to strip stray content
            repaired_content = repair_json(json_text)

            # First attempt: parse the repaired text directly
            try:
                parsed_data = json.loads(repaired_content)
                if isinstance(parsed_data, list):
                    # A list of objects: collect each selected_title field
                    titles = []
                    for item in parsed_data:
                        if isinstance(item, dict) and "selected_title" in item:
                            titles.append(item["selected_title"])
                    return titles
                elif isinstance(parsed_data, dict) and "selected_title" in parsed_data:
                    # A single object
                    return [parsed_data["selected_title"]]
            except json.JSONDecodeError:
                pass

            # Fallback: extract each JSON object containing selected_title
            pattern = r'\{[^}]*"selected_title"[^}]*\}'
            matches = re.findall(pattern, repaired_content)

            titles = []
            for match in matches:
                try:
                    obj = json.loads(match)
                    if "selected_title" in obj:
                        titles.append(obj["selected_title"])
                except json.JSONDecodeError:
                    continue

            if titles:
                return titles

            logger.warning(f"无法解析JSON响应: {json_text[:200]}...")
            return []

        except Exception as e:
            logger.error(f"解析合并目标JSON时出错: {e}")
            return []
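
The regex fallback above can be exercised on its own; a minimal sketch with an invented malformed response:

    import json, re
    raw = '{"selected_title": "概念A"} {"selected_title": "概念B"}'
    pattern = r'\{[^}]*"selected_title"[^}]*\}'
    titles = [json.loads(m)["selected_title"] for m in re.findall(pattern, raw)]
    print(titles)  # ['概念A', '概念B']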

    async def merge_memory(self, memory_list: list[str]) -> tuple[str, str]:
        """
        Merge several memories into one.
        """
        try:
            content = ""
            for memory in memory_list:
                content += f"{memory}\n"

            prompt = f"""
以下是多段记忆内容,请将它们进行整合和修改:
{content}
--------------------------------
请将上面的多段记忆内容,合并成两部分内容,第一部分是可以整合,不冲突的概念和知识,第二部分是相互有冲突的概念和知识
请主要关注概念和知识,而不是聊天的琐事
重要!!你要关注的概念和知识必须是较为不常见的信息,或者时效性较强的信息!!
不要!!关注常见的只是,或者已经过时的信息!!
1.不要关注诸如某个用户做了什么,说了什么,不要关注某个用户的行为,而是关注其中的概念性信息
2.概念要求精确,不啰嗦,像科普读物或教育课本那样
3.如果有图片,请只关注图片和文本结合的知识和概念性内容
4.记忆为一段纯文本,逻辑清晰,指出概念的含义,并说明关系
**第一部分**
1.如果两个概念在描述同一件事情,且相互之间逻辑不冲突(请你严格判断),且相互之间没有矛盾,请将它们整合成一个概念,并输出到第一部分
2.如果某个概念在时间上更新了另一个概念,请用新概念更新就概念来整合,并输出到第一部分
3.如果没有可整合的概念,请你输出none
**第二部分**
1.如果记忆中有无法整合的地方,例如概念不一致,有逻辑上的冲突,请你输出到第二部分
2.如果两个概念在描述同一件事情,但相互之间逻辑冲突,请将它们输出到第二部分
3.如果没有无法整合的概念,请你输出none

**输出格式要求**
请你按以下格式输出:
<part1>
第一部分内容,整合后的概念,如果第一部分为none,请输出none
</part1>
<part2>
第二部分内容,无法整合,冲突的概念,如果第二部分为none,请输出none
</part2>
不要输出其他内容,现在请你输出,不要输出其他内容,注意一定要直白,白话,口语化不要浮夸,修辞。:
"""

            if global_config.debug.show_prompt:
                logger.info(f"合并记忆 prompt: {prompt}")
            else:
                logger.debug(f"合并记忆 prompt: {prompt}")

            merged_memory, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)

            # Split the response into part1 and part2
            part1_content, part2_content = self._parse_merged_parts(merged_memory)

            # part2: record conflicting content independently (regardless of part1)
            if part2_content and part2_content.strip() != "none":
                logger.info(f"合并记忆part2记录冲突内容: {len(part2_content)} 字符")
                # Conflict tracker from the curiosity subsystem
                from src.curiousity.questions import global_conflict_tracker
                # Persist the conflict
                await global_conflict_tracker.record_memory_merge_conflict(part2_content)

            # part1: generate a title and save
            if part1_content and part1_content.strip() != "none":
                merged_title = await self._generate_title_for_merged_memory(part1_content)

                # Persist part1 to the database
                MemoryChestModel.create(
                    title=merged_title,
                    content=part1_content
                )

                logger.info(f"合并记忆part1已保存: {merged_title}")

                return merged_title, part1_content
            else:
                logger.warning("合并记忆part1为空,跳过保存")
                return "", ""
        except Exception as e:
            logger.error(f"合并记忆时出错: {e}")
            return "", ""

    async def _generate_title_for_merged_memory(self, merged_content: str) -> str:
        """
        Generate a title for a merged memory.

        Args:
            merged_content: the merged memory content

        Returns:
            str: the generated title
        """
        try:
            prompt = f"""
请为以下内容生成一个描述全面的标题,要求描述内容的主要概念和事件:
{merged_content}

标题不要分点,不要换行,不要输出其他内容,不要浮夸,以白话简洁的风格输出标题
请只输出标题,不要输出其他内容:
"""

            if global_config.debug.show_prompt:
                logger.info(f"生成合并记忆标题 prompt: {prompt}")
            else:
                logger.debug(f"生成合并记忆标题 prompt: {prompt}")

            title_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)

            # Strip stray quotes and whitespace from the title
            title = title_response.strip().strip('"').strip("'").strip()

            if title:
                # Deduplicate against existing titles
                if check_title_exists_fuzzy(title, similarity_threshold=0.9):
                    logger.warning(f"生成的标题 '{title}' 与现有标题相似,使用时间戳后缀")
                    title = f"{title}_{int(time.time())}"

                logger.info(f"生成合并记忆标题: {title}")
                return title
            else:
                logger.warning("生成合并记忆标题失败,使用默认标题")
                return f"合并记忆_{int(time.time())}"

        except Exception as e:
            logger.error(f"生成合并记忆标题时出错: {e}")
            return f"合并记忆_{int(time.time())}"


global_memory_chest = MemoryChest()
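
A sketch of the retrieval path through the singleton (assumes the LLM and database are configured, and an async context):

    answer = await global_memory_chest.get_answer_by_question(chat_id="chat-123", question="什么是海马体?")
    recent = global_memory_chest.get_chat_memories_as_string("chat-123")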
224  src/memory_system/hippocampus_to_memory_chest_task.py  Normal file
@@ -0,0 +1,224 @@
# -*- coding: utf-8 -*-
import asyncio
import random
import re
from typing import List

from src.manager.async_task_manager import AsyncTask
from src.memory_system.Hippocampus import hippocampus_manager
from src.common.logger import get_logger

logger = get_logger("hippocampus_to_memory_chest")


class HippocampusToMemoryChestTask(AsyncTask):
    """Task that converts hippocampus nodes into memory-chest entries.

    Runs every 10 seconds, processing up to 50 batches per run with 5 nodes per
    batch, and stops itself once no nodes remain.
    """

    def __init__(self):
        super().__init__(
            task_name="Hippocampus to Memory Chest Task",
            wait_before_start=5,  # wait 5 seconds after startup
            run_interval=10  # run every 10 seconds
        )
        self.task_stopped = False  # whether the task has shut itself down

    async def start_task(self, abort_flag: asyncio.Event):
        """Override start_task so the task can stop itself."""
        if self.wait_before_start > 0:
            # Delay the first run
            await asyncio.sleep(self.wait_before_start)

        while not abort_flag.is_set() and not self.task_stopped:
            await self.run()
            if self.run_interval > 0:
                await asyncio.sleep(self.run_interval)
            else:
                break

        if self.task_stopped:
            logger.info("[海马体转换] 任务已完全停止,不再执行")

    async def run(self):
        """Execute one conversion pass."""
        try:
            # Bail out if the task already stopped itself
            if self.task_stopped:
                logger.info("[海马体转换] 任务已停止,跳过执行")
                return

            logger.info("[海马体转换] 开始执行海马体到记忆仓库的转换任务")

            # The hippocampus manager must be initialized first
            if not hippocampus_manager._initialized:
                logger.warning("[海马体转换] 海马体管理器尚未初始化,跳过本次转换")
                return

            # Grab the hippocampus instance
            hippocampus = hippocampus_manager.get_hippocampus()
            memory_graph = hippocampus.memory_graph.G

            # Run up to 50 conversion batches
            total_processed = 0
            total_success = 0

            for batch_num in range(1, 51):  # up to 50 batches
                logger.info(f"[海马体转换] 开始执行第 {batch_num} 批转换")

                # Check the remaining nodes
                remaining_nodes = list(memory_graph.nodes())
                if len(remaining_nodes) == 0:
                    logger.info(f"[海马体转换] 第 {batch_num} 批:没有剩余节点,停止任务运行")
                    self.task_stopped = True
                    break

                # With fewer than 5 nodes left, take them all
                if len(remaining_nodes) < 5:
                    selected_nodes = remaining_nodes
                    logger.info(f"[海马体转换] 第 {batch_num} 批:剩余节点不足5个({len(remaining_nodes)}个),使用所有剩余节点")
                else:
                    # Otherwise sample 5 nodes at random
                    selected_nodes = random.sample(remaining_nodes, 5)
                    logger.info(f"[海马体转换] 第 {batch_num} 批:选择了 {len(selected_nodes)} 个节点")

                # Concatenate the node contents
                content_parts = []
                valid_nodes = []

                for node in selected_nodes:
                    node_data = memory_graph.nodes[node]
                    memory_items = node_data.get("memory_items", "")

                    if memory_items and memory_items.strip():
                        # Prefix each memory with its node name
                        content_parts.append(f"【{node}】{memory_items}")
                        valid_nodes.append(node)
                    else:
                        logger.debug(f"[海马体转换] 第 {batch_num} 批:节点 {node} 没有记忆内容,跳过")

                if not content_parts:
                    logger.info(f"[海马体转换] 第 {batch_num} 批:没有找到有效的记忆内容,跳过")
                    continue

                # Join everything into one document
                combined_content = "\n\n".join(content_parts)
                logger.info(f"[海马体转换] 第 {batch_num} 批:拼接完成,内容长度: {len(combined_content)} 字符")

                # Generate a title and store the batch in the memory chest
                success = await self._save_to_memory_chest(combined_content, batch_num)

                # On success, remove the converted nodes from the graph
                if success:
                    await self._remove_converted_nodes(valid_nodes)
                    total_success += 1
                    logger.info(f"[海马体转换] 第 {batch_num} 批:转换成功")
                else:
                    logger.warning(f"[海马体转换] 第 {batch_num} 批:转换失败")

                total_processed += 1

                # Brief pause between batches to avoid hammering the database
                if batch_num < 10:
                    await asyncio.sleep(0.1)

            logger.info(f"[海马体转换] 本次执行完成:共处理 {total_processed} 批,成功 {total_success} 批")

            logger.info("[海马体转换] 转换任务完成")

        except Exception as e:
            logger.error(f"[海马体转换] 执行转换任务时发生错误: {e}", exc_info=True)
|
||||
    async def _save_to_memory_chest(self, content: str, batch_num: int = 1) -> bool:
        """Save content to the memory chest.

        Args:
            content: the content to save
            batch_num: batch number

        Returns:
            bool: whether the save succeeded
        """
        try:
            # Derive a title from the node names embedded in the content
            title = self._generate_title_from_content(content, batch_num)

            if title:
                # Persist to the database
                from src.common.database.database_model import MemoryChest as MemoryChestModel

                MemoryChestModel.create(
                    title=title,
                    content=content,
                )

                logger.info(f"[hippocampus conversion] Batch {batch_num}: saved to memory chest with title: {title}")
                return True
            else:
                logger.warning("[hippocampus conversion] Title generation failed; skipping save")
                return False

        except Exception as e:
            logger.error(f"[hippocampus conversion] Error saving to memory chest: {e}", exc_info=True)
            return False
    def _generate_title_from_content(self, content: str, batch_num: int = 1) -> str:
        """Build a title from the node names embedded in the content.

        Args:
            content: the concatenated content
            batch_num: batch number

        Returns:
            str: the generated title
        """
        try:
            # Extract every node name wrapped in 【】 markers
            node_pattern = r'【([^】]+)】'
            nodes = re.findall(node_pattern, content)

            if nodes:
                # Deduplicate while preserving order, and keep at most the first 5
                unique_nodes = list(dict.fromkeys(nodes))[:5]
                title = f"Memories about {','.join(unique_nodes)}"
                return title
            else:
                logger.warning("[hippocampus conversion] Could not extract any node names from the content")
                return ""

        except Exception as e:
            logger.error(f"[hippocampus conversion] Error generating title: {e}", exc_info=True)
            return ""
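# The title derivation above is just regex extraction plus order-preserving
# deduplication. A standalone sketch of the same two steps on a toy string:
import re

combined = "【weather】rain today\n\n【weather】clear tomorrow\n\n【lunch】had ramen"

# Pull out every node name wrapped in 【】 markers
nodes = re.findall(r'【([^】]+)】', combined)   # ['weather', 'weather', 'lunch']

# dict.fromkeys deduplicates while keeping first-seen order (unlike set())
unique_nodes = list(dict.fromkeys(nodes))[:5]  # ['weather', 'lunch']

print(f"Memories about {','.join(unique_nodes)}")  # Memories about weather,lunch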
    async def _remove_converted_nodes(self, nodes_to_remove: List[str]):
        """Delete hippocampus nodes that have been converted.

        Args:
            nodes_to_remove: list of nodes to delete
        """
        try:
            # Fetch the hippocampus instance
            hippocampus = hippocampus_manager.get_hippocampus()
            memory_graph = hippocampus.memory_graph.G

            removed_count = 0
            for node in nodes_to_remove:
                if node in memory_graph:
                    # Removing a node also removes its incident edges
                    memory_graph.remove_node(node)
                    removed_count += 1
                    logger.info(f"[hippocampus conversion] Removed node: {node}")
                else:
                    logger.debug(f"[hippocampus conversion] Node {node} does not exist; skipping")

            # Sync the graph back to the database
            if removed_count > 0:
                await hippocampus.entorhinal_cortex.sync_memory_to_db()
                logger.info(f"[hippocampus conversion] Removed {removed_count} nodes and synced to the database")
            else:
                logger.info("[hippocampus conversion] No nodes to remove")

        except Exception as e:
            logger.error(f"[hippocampus conversion] Error removing nodes: {e}", exc_info=True)
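# For a sense of how the lifecycle fits together, a minimal driver that runs the
# task until it stops itself or is aborted. This is a sketch: in the real project
# the async task manager presumably owns this loop, and only start_task's
# signature is taken from the code above.
import asyncio


async def demo_run_conversion_task():
    abort_flag = asyncio.Event()
    task = HippocampusToMemoryChestTask()

    # start_task loops until abort_flag is set or the task flags itself stopped
    runner = asyncio.create_task(task.start_task(abort_flag))

    await asyncio.sleep(60)  # let it run for a while
    abort_flag.set()         # request shutdown; the loop exits after its current sleep
    await runner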
171
src/memory_system/memory_management_task.py
Normal file
@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
import asyncio
import random
from typing import List

from src.manager.async_task_manager import AsyncTask
from src.memory_system.Memory_chest import global_memory_chest
from src.common.logger import get_logger
from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.config.config import global_config

logger = get_logger("memory_management")


class MemoryManagementTask(AsyncTask):
    """Periodic memory-management task.

    The run interval scales with how full the memory chest is relative to
    max_memory_number:
    - below 50%: every 3600 seconds
    - 50% to 70%: every 1800 seconds
    - 70% to 90%: every 300 seconds
    - 90% to 120%: every 30 seconds
    - 120% and above: every 10 seconds

    Each run picks a random title, calls choose_merge_target and merge_memory,
    then deletes the original memories.
    """

    def __init__(self):
        super().__init__(
            task_name="Memory Management Task",
            wait_before_start=10,  # wait 10 seconds after startup before the first run
            run_interval=300,  # default 300s; adjusted dynamically from the memory count
        )
        self.max_memory_number = global_config.memory.max_memory_number
    async def start_task(self, abort_flag: asyncio.Event):
        """Override start_task to adjust the run interval dynamically."""
        if self.wait_before_start > 0:
            # Wait before the first run
            await asyncio.sleep(self.wait_before_start)

        while not abort_flag.is_set():
            await self.run()

            # Recompute the interval after every run
            current_interval = self._calculate_interval()
            logger.info(f"[memory management] Next run in {current_interval}s")

            if current_interval > 0:
                await asyncio.sleep(current_interval)
            else:
                break
    def _calculate_interval(self) -> int:
        """Compute the run interval from the current memory count."""
        try:
            current_count = self._get_memory_count()
            percentage = current_count / self.max_memory_number

            if percentage < 0.5:
                # below 50%: run every 3600 seconds
                return 3600
            elif percentage < 0.7:
                # 50% to 70%: run every 1800 seconds
                return 1800
            elif percentage < 0.9:
                # 70% to 90%: run every 300 seconds
                return 300
            elif percentage < 1.2:
                # 90% to 120%: run every 30 seconds
                return 30
            else:
                # 120% and above: run every 10 seconds
                return 10

        except Exception as e:
            logger.error(f"[memory management] Error computing run interval: {e}")
            return 300  # default to 300 seconds
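# The tiering above is easy to sanity-check in isolation. A pure-function
# restatement (standalone sketch, no DB lookup) with a few example fill levels:
def interval_for(count: int, max_count: int) -> int:
    """Mirror of _calculate_interval's tiering, without the DB lookup."""
    p = count / max_count
    if p < 0.5:
        return 3600
    elif p < 0.7:
        return 1800
    elif p < 0.9:
        return 300
    elif p < 1.2:
        return 30
    return 10

assert interval_for(40, 100) == 3600   # 40% full -> hourly
assert interval_for(60, 100) == 1800
assert interval_for(85, 100) == 300
assert interval_for(110, 100) == 30    # over capacity -> merge aggressively
assert interval_for(150, 100) == 10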
    def _get_memory_count(self) -> int:
        """Return the current number of stored memories."""
        try:
            count = MemoryChestModel.select().count()
            logger.debug(f"[memory management] Current memory count: {count}")
            return count
        except Exception as e:
            logger.error(f"[memory management] Error fetching memory count: {e}")
            return 0
    async def run(self):
        """Execute one memory-management pass."""
        try:
            logger.info("[memory management] Starting memory management run")

            # Check the current memory count
            current_count = self._get_memory_count()
            percentage = current_count / self.max_memory_number
            logger.info(f"[memory management] Current memory count: {current_count}/{self.max_memory_number} ({percentage:.1%})")

            # Skip the run if there are fewer than 10 memories
            if current_count < 10:
                logger.info("[memory management] Not enough memories; skipping run")
                return

            # Pick a random memory title
            selected_title = self._get_random_memory_title()
            if not selected_title:
                logger.warning("[memory management] Could not pick a random memory title; skipping run")
                return

            logger.info(f"[memory management] Randomly selected title: {selected_title}")

            # choose_merge_target returns the memories related to the selected title
            related_contents_titles = await global_memory_chest.choose_merge_target(selected_title)
            if not related_contents_titles:
                logger.warning("[memory management] No related memories found; skipping merge")
                return

            logger.info(f"[memory management] Found {len(related_contents_titles)} related memories")

            # merge_memory merges the related memories into one
            merged_title, merged_content = await global_memory_chest.merge_memory(related_contents_titles)
            if not merged_title or not merged_content:
                logger.warning("[memory management] Merge failed; skipping deletion")
                return

            logger.info(f"[memory management] Merge succeeded, new title: {merged_title}")

            # Delete the original memories (the selected title and its related entries)
            deleted_count = self._delete_original_memories(related_contents_titles)
            logger.info(f"[memory management] Deleted {deleted_count} original memories")

            logger.info("[memory management] Memory management run complete")

        except Exception as e:
            logger.error(f"[memory management] Error during memory management run: {e}", exc_info=True)
    def _get_random_memory_title(self) -> str:
        """Pick a random memory title."""
        try:
            # Fetch all titles
            all_titles = global_memory_chest.get_all_titles()
            if not all_titles:
                return ""

            # Pick one at random
            return random.choice(all_titles)

        except Exception as e:
            logger.error(f"[memory management] Error picking a random title: {e}")
            return ""
    def _delete_original_memories(self, related_contents: List[str]) -> int:
        """Delete the original memories, matched by content."""
        try:
            deleted_count = 0
            # Delete each related memory by exact content match
            for content in related_contents:
                try:
                    memories_to_delete = MemoryChestModel.select().where(MemoryChestModel.content == content)
                    for memory in memories_to_delete:
                        MemoryChestModel.delete().where(MemoryChestModel.id == memory.id).execute()
                        deleted_count += 1
                        logger.debug(f"[memory management] Deleted related memory: {memory.title}")
                except Exception as e:
                    logger.error(f"[memory management] Error deleting a related memory: {e}")
                    continue

            return deleted_count

        except Exception as e:
            logger.error(f"[memory management] Error deleting original memories: {e}")
            return 0
156
src/memory_system/memory_utils.py
Normal file
@@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
"""
Utility functions for the memory system:
fuzzy lookup, similarity scoring, and text preprocessing.
"""
import re
from difflib import SequenceMatcher
from typing import List, Tuple, Optional

from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.common.logger import get_logger

logger = get_logger("memory_utils")
def calculate_similarity(text1: str, text2: str) -> float:
    """
    Compute the similarity between two texts.

    Args:
        text1: first text
        text2: second text

    Returns:
        float: similarity score in [0, 1]
    """
    try:
        # Normalize both texts first
        text1 = preprocess_text(text1)
        text2 = preprocess_text(text2)

        # Ratio of matching subsequences via difflib's SequenceMatcher
        similarity = SequenceMatcher(None, text1, text2).ratio()

        # If one text contains the other, floor the score at 0.8
        if text1 in text2 or text2 in text1:
            similarity = max(similarity, 0.8)

        return similarity

    except Exception as e:
        logger.error(f"Error computing similarity: {e}")
        return 0.0
def preprocess_text(text: str) -> str:
    """
    Normalize text to improve matching accuracy.

    Args:
        text: raw text

    Returns:
        str: normalized text
    """
    try:
        # Lowercase
        text = text.lower()

        # Strip punctuation and special characters
        text = re.sub(r'[^\w\s]', '', text)

        # Collapse runs of whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        return text

    except Exception as e:
        logger.error(f"Error preprocessing text: {e}")
        return text
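# A quick check of how these two functions interact; note the containment rule
# can dominate the raw SequenceMatcher ratio. Illustrative strings only:
from src.memory_system.memory_utils import calculate_similarity

# Containment floors the score at 0.8: the timestamp-suffixed variant still
# matches, because preprocessing keeps underscores and digits (both are \w).
print(calculate_similarity("Memories about weather", "Memories about weather_1700000000"))  # >= 0.8

# Partially overlapping titles score on shared blocks alone and land
# below the 0.9 default threshold used by the fuzzy lookups below.
print(calculate_similarity("Memories about weather", "Memories about lunch"))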
def fuzzy_find_memory_by_title(target_title: str, similarity_threshold: float = 0.9) -> List[Tuple[str, str, float]]:
    """
    Fuzzily look up memories by title.

    Args:
        target_title: the title to search for
        similarity_threshold: similarity threshold, 0.9 by default

    Returns:
        List[Tuple[str, str, float]]: matches as (title, content, similarity_score)
    """
    try:
        # Scan all memories
        all_memories = MemoryChestModel.select()

        matches = []
        for memory in all_memories:
            similarity = calculate_similarity(target_title, memory.title)
            if similarity >= similarity_threshold:
                matches.append((memory.title, memory.content, similarity))

        # Sort by similarity, best first
        matches.sort(key=lambda x: x[2], reverse=True)

        logger.info(f"Fuzzy lookup for title '{target_title}' found {len(matches)} matches")
        return matches

    except Exception as e:
        logger.error(f"Error during fuzzy memory lookup: {e}")
        return []
def find_best_matching_memory(target_title: str, similarity_threshold: float = 0.9) -> Optional[Tuple[str, str, float]]:
    """
    Find the single best-matching memory.

    Args:
        target_title: the title to search for
        similarity_threshold: similarity threshold

    Returns:
        Optional[Tuple[str, str, float]]: best match as (title, content, similarity), or None
    """
    try:
        matches = fuzzy_find_memory_by_title(target_title, similarity_threshold)

        if matches:
            best_match = matches[0]  # already sorted, first is best
            logger.info(f"Best match: '{best_match[0]}' (similarity: {best_match[2]:.3f})")
            return best_match
        else:
            logger.info(f"No memory with similarity >= {similarity_threshold} found")
            return None

    except Exception as e:
        logger.error(f"Error finding best-matching memory: {e}")
        return None
def check_title_exists_fuzzy(target_title: str, similarity_threshold: float = 0.9) -> bool:
    """
    Check whether a similar title already exists (fuzzy match).

    Args:
        target_title: the title to check
        similarity_threshold: similarity threshold, 0.9 by default (kept high to avoid false positives)

    Returns:
        bool: whether a similar title exists
    """
    try:
        matches = fuzzy_find_memory_by_title(target_title, similarity_threshold)
        exists = len(matches) > 0

        if exists:
            logger.info(f"Found similar title: '{matches[0][0]}' (similarity: {matches[0][2]:.3f})")
        else:
            logger.debug("No similar title found")

        return exists

    except Exception as e:
        logger.error(f"Error checking title existence: {e}")
        return False
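# Putting the three helpers together: a usage sketch. It assumes a populated
# MemoryChest table, so treat it as illustrative rather than a test.
from src.memory_system.memory_utils import (
    check_title_exists_fuzzy,
    find_best_matching_memory,
)

# Before inserting a new memory, see whether a near-duplicate title exists
if check_title_exists_fuzzy("Memories about weather", similarity_threshold=0.9):
    best = find_best_matching_memory("Memories about weather")
    if best:
        title, content, score = best
        print(f"Would merge into '{title}' (similarity {score:.3f})")
else:
    print("Title is unique enough to insert as-is")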