feat:记忆系统再优化,现在及时构建,并且不会重复构建
This commit is contained in:
@@ -7,24 +7,21 @@ import re
|
||||
import jieba
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Set, Coroutine, Any
|
||||
from typing import List, Tuple, Set, Coroutine, Any, Dict
|
||||
from collections import Counter
|
||||
from itertools import combinations
|
||||
|
||||
import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
|
||||
from src.common.database.database_model import GraphNodes, GraphEdges # Peewee Models导入
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
build_readable_messages,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
) # 导入 build_readable_messages
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||
# 添加cosine_similarity函数
|
||||
def cosine_similarity(v1, v2):
|
||||
"""计算余弦相似度"""
|
||||
@@ -334,6 +331,9 @@ class Hippocampus:
|
||||
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
||||
f"如果确定找不出主题或者没有明显主题,返回<none>。"
|
||||
)
|
||||
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
@@ -417,14 +417,17 @@ class Hippocampus:
|
||||
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
|
||||
text_length = len(text)
|
||||
topic_num: int | list[int] = 0
|
||||
if text_length <= 6:
|
||||
words = jieba.cut(text)
|
||||
keywords = [word for word in words if len(word) > 1]
|
||||
keywords = list(set(keywords))[:3] # 限制最多3个关键词
|
||||
if keywords:
|
||||
logger.debug(f"提取关键词: {keywords}")
|
||||
return keywords
|
||||
elif text_length <= 12:
|
||||
|
||||
|
||||
words = jieba.cut(text)
|
||||
keywords_lite = [word for word in words if len(word) > 1]
|
||||
keywords_lite = list(set(keywords_lite))
|
||||
if keywords_lite:
|
||||
logger.debug(f"提取关键词极简版: {keywords_lite}")
|
||||
|
||||
|
||||
|
||||
if text_length <= 12:
|
||||
topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本)
|
||||
elif text_length <= 20:
|
||||
topic_num = [2, 4] # 11-20字符: 2个关键词 (22.76%的文本)
|
||||
@@ -451,169 +454,7 @@ class Hippocampus:
|
||||
if keywords:
|
||||
logger.debug(f"提取关键词: {keywords}")
|
||||
|
||||
return keywords
|
||||
|
||||
async def get_memory_from_text(
    self,
    text: str,
    max_memory_num: int = 3,
    max_memory_length: int = 2,
    max_depth: int = 3,
    fast_retrieval: bool = False,
) -> list:
    """Extract keywords from ``text`` and return the memories they activate.

    Every keyword that exists as a node in the memory graph seeds a
    breadth-first spreading activation: the seed starts at 1.0 and each hop
    over an edge subtracts ``1 / strength``. Activations from all seeds are
    summed per node, the nodes are ranked by their share of the squared
    activation, and the top ``max_memory_num`` nodes contribute their
    ``memory_items`` text.

    Args:
        text: Input text to extract keywords from.
        max_memory_num: Maximum number of graph nodes whose memory is returned.
        max_memory_length: Accepted for interface compatibility; not used by
            this implementation.
        max_depth: Maximum number of hops the activation may spread from a
            keyword node.
        fast_retrieval: Accepted for interface compatibility; not used by
            this implementation.

    Returns:
        A list of ``(topic, memory_content)`` tuples, de-duplicated by memory
        content. Empty when no keyword matches a graph node.
    """
    # Local import: the queue below needs O(1) pops from the left end.
    from collections import deque

    keywords = await self.get_keywords_from_text(text)

    graph = self.memory_graph.G

    # Only keywords that are nodes of the memory graph can seed a search.
    valid_keywords = [keyword for keyword in keywords if keyword in graph]
    if not valid_keywords:
        logger.debug("没有找到有效的关键词节点")
        return []

    logger.info(f"有效的关键词: {', '.join(valid_keywords)}")

    # Accumulated activation per node, summed over all seed keywords.
    activate_map = {}

    for keyword in valid_keywords:
        logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
        activation_values = {keyword: 1.0}
        visited_nodes = {keyword}
        # Queue entries are (node, activation, depth).
        nodes_to_process = deque([(keyword, 1.0, 0)])

        while nodes_to_process:
            current_node, current_activation, current_depth = nodes_to_process.popleft()

            # Stop spreading once the activation is exhausted or the depth limit is hit.
            if current_activation <= 0 or current_depth >= max_depth:
                continue

            for neighbor in list(graph.neighbors(current_node)):
                if neighbor in visited_nodes:
                    continue

                strength = graph[current_node][neighbor].get("strength", 1)
                # A non-positive strength would make 1 / strength undefined or
                # turn the decay into a gain; such an edge conducts nothing.
                if strength <= 0:
                    continue

                new_activation = current_activation - (1 / strength)
                if new_activation > 0:
                    activation_values[neighbor] = new_activation
                    visited_nodes.add(neighbor)
                    nodes_to_process.append((neighbor, new_activation, current_depth + 1))

        # Merge this seed's activations into the global map.
        for node, activation_value in activation_values.items():
            if activation_value > 0:
                activate_map[node] = activate_map.get(node, 0) + activation_value

    # Rank nodes by their normalized squared activation and keep the best ones.
    remember_map = {}
    total_squared_activation = sum(activation**2 for activation in activate_map.values())
    if total_squared_activation > 0:
        normalized_activations = {
            node: (activation**2) / total_squared_activation for node, activation in activate_map.items()
        }
        sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num]

        for node, normalized_activation in sorted_nodes:
            remember_map[node] = activate_map[node]  # keep the raw (un-normalized) activation
            logger.debug(
                f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})"
            )
    else:
        logger.info("没有有效的激活值")

    # Collect the full memory text stored on each selected node.
    all_memories = []
    for node, activation in remember_map.items():
        logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
        memory_items = graph.nodes[node].get("memory_items", "")
        if memory_items:
            logger.debug("节点包含完整记忆")
            all_memories.append((node, memory_items, activation))
        else:
            logger.info("节点没有记忆")

    # De-duplicate by memory content, keeping the first node that carried it.
    logger.debug("开始记忆去重:")
    seen_memories = set()
    unique_memories = []
    for topic, memory_items, activation_value in all_memories:
        if memory_items not in seen_memories:
            seen_memories.add(memory_items)
            unique_memories.append((topic, memory_items, activation_value))
            logger.debug(f"保留记忆: {memory_items} (来自节点: {topic}, 激活值: {activation_value:.2f})")
        else:
            logger.debug(f"跳过重复记忆: {memory_items} (来自节点: {topic})")

    # Convert to the (keyword, memory) shape callers expect.
    result = []
    for topic, memory_items, _ in unique_memories:
        result.append((topic, memory_items))
        logger.debug(f"选中记忆: {memory_items} (来自节点: {topic})")

    return result
|
||||
return keywords,keywords_lite
|
||||
|
||||
async def get_memory_from_topic(
|
||||
self,
|
||||
@@ -771,7 +612,7 @@ class Hippocampus:
|
||||
|
||||
return result
|
||||
|
||||
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
|
||||
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str],list[str]]:
|
||||
"""从文本中提取关键词并获取相关记忆。
|
||||
|
||||
Args:
|
||||
@@ -785,13 +626,13 @@ class Hippocampus:
|
||||
float: 激活节点数与总节点数的比值
|
||||
list[str]: 有效的关键词
|
||||
"""
|
||||
keywords = await self.get_keywords_from_text(text)
|
||||
keywords,keywords_lite = await self.get_keywords_from_text(text)
|
||||
|
||||
# 过滤掉不存在于记忆图中的关键词
|
||||
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
||||
if not valid_keywords:
|
||||
# logger.info("没有找到有效的关键词节点")
|
||||
return 0, []
|
||||
return 0, keywords,keywords_lite
|
||||
|
||||
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||
|
||||
@@ -858,7 +699,7 @@ class Hippocampus:
|
||||
activation_ratio = activation_ratio * 50
|
||||
logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
|
||||
|
||||
return activation_ratio, keywords
|
||||
return activation_ratio, keywords,keywords_lite
|
||||
|
||||
|
||||
# 负责海马体与其他部分的交互
|
||||
@@ -867,92 +708,6 @@ class EntorhinalCortex:
|
||||
self.hippocampus = hippocampus
|
||||
self.memory_graph = hippocampus.memory_graph
|
||||
|
||||
def get_memory_sample(self):
    """Sample message snippets from the database for memory building.

    A ``MemoryBuildScheduler`` configured from
    ``global_config.memory.memory_build_distribution`` produces a list of
    timestamps; for each one ``random_get_msg_snippet`` is asked for a
    snippet of messages.

    Returns:
        A list of message lists, one entry per timestamp that produced a
        usable snippet.
    """
    # Hard-coded: how many times a single message may be memorized.
    max_memorized_time_per_msg = 2

    distribution = global_config.memory.memory_build_distribution
    sample_scheduler = MemoryBuildScheduler(
        n_hours1=distribution[0],
        std_hours1=distribution[1],
        weight1=distribution[2],
        n_hours2=distribution[3],
        std_hours2=distribution[4],
        weight2=distribution[5],
        total_samples=global_config.memory.memory_build_sample_num,
    )

    timestamps = sample_scheduler.get_timestamp_array()

    # Log every sampled point in human-readable form ("normal" mode).
    readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
    for readable_timestamp in readable_timestamps:
        logger.debug(f"回忆往事: {readable_timestamp}")

    chat_samples = []
    for timestamp in timestamps:
        messages = self.random_get_msg_snippet(
            timestamp,
            global_config.memory.memory_build_sample_length,
            max_memorized_time_per_msg,
        )
        if not messages:
            logger.debug(f"时间戳 {timestamp} 的消息无需记忆")
            continue

        time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
        logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
        chat_samples.append(messages)

    return chat_samples
|
||||
|
||||
@staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
    # sourcery skip: invert-any-all, use-any, use-named-expression, use-next
    """Fetch a snippet of messages starting near ``target_timestamp``.

    Up to three attempts are made. Each attempt looks at a window that starts
    at the current target timestamp and spans a random 5-30 minutes (the
    window length is drawn once and reused). The earliest message in the
    window selects the chat, then up to ``chat_size`` earliest messages of
    that chat in the same window are loaded. The snippet is accepted only if
    none of its messages has reached ``max_memorized_time_per_msg``; accepted
    messages get their ``memorized_times`` incremented in the database.
    After a failed attempt the target timestamp is moved 120 seconds earlier.

    Returns:
        The raw message list of the accepted snippet, or ``None`` when all
        three attempts fail.
    """
    time_window_seconds = random.randint(300, 1800)

    attempts_left = 3
    while attempts_left > 0:
        attempts_left -= 1

        timestamp_start = target_timestamp
        timestamp_end = target_timestamp + time_window_seconds

        chosen_message = get_raw_msg_by_timestamp(
            timestamp_start=timestamp_start,
            timestamp_end=timestamp_end,
            limit=1,
            limit_mode="earliest",
        )
        if chosen_message:
            chat_id: str = chosen_message[0].get("chat_id")  # type: ignore

            messages = get_raw_msg_by_timestamp_with_chat(
                timestamp_start=timestamp_start,
                timestamp_end=timestamp_end,
                limit=chat_size,
                limit_mode="earliest",
                chat_id=chat_id,
            )
            if messages:
                # Reject the snippet if any message already hit the memorization cap.
                exhausted = any(
                    message.get("memorized_times", 0) >= max_memorized_time_per_msg for message in messages
                )
                if not exhausted:
                    # Bump the stored counter for every message (Peewee update).
                    for message in messages:
                        current_memorized_times = message.get("memorized_times", 0)
                        Messages.update(memorized_times=current_memorized_times + 1).where(
                            Messages.message_id == message["message_id"]
                        ).execute()
                    return messages  # the raw message list, unchanged

        # This attempt failed: shift the window slightly earlier and retry.
        target_timestamp -= 120

    # All three attempts failed.
    return None
|
||||
|
||||
async def sync_memory_to_db(self):
|
||||
"""将记忆图同步到数据库"""
|
||||
start_time = time.time()
|
||||
@@ -1407,81 +1162,14 @@ class ParahippocampalGyrus:
|
||||
similar_topics.sort(key=lambda x: x[1], reverse=True)
|
||||
similar_topics = similar_topics[:3]
|
||||
similar_topics_dict[topic] = similar_topics
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"prompt: {topic_what_prompt}")
|
||||
logger.info(f"压缩后的记忆: {compressed_memory}")
|
||||
logger.info(f"相似主题: {similar_topics_dict}")
|
||||
|
||||
return compressed_memory, similar_topics_dict
|
||||
|
||||
async def operation_build_memory(self):
    # sourcery skip: merge-list-appends-into-extend
    """Build memories from sampled message snippets and persist the graph.

    For each sampled snippet the messages are compressed into
    ``(topic, memory)`` pairs, every topic is added to the memory graph,
    topics are linked to their similar topics with an edge whose strength is
    ``int(similarity * 10)``, and all topics of the same snippet are
    connected pairwise. A snippet whose compression raises is logged and
    skipped. The graph is synced to the database once at the end.
    """
    logger.info("------------------------------------开始构建记忆--------------------------------------")
    start_time = time.time()
    memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()

    # Collected across all snippets, only for the summary logs at the end.
    all_added_nodes = []
    all_connected_nodes = []
    all_added_edges = []

    for i, messages in enumerate(memory_samples, 1):
        compress_rate = global_config.memory.memory_compress_rate
        try:
            compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
        except Exception as e:
            logger.error(f"压缩记忆时发生错误: {e}")
            continue

        for topic, memory in compressed_memory:
            logger.info(f"取得记忆: {topic} - {memory}")
        for topic, similar_topics in similar_topics_dict.items():
            logger.debug(f"相似话题: {topic} - {similar_topics}")

        current_time = datetime.datetime.now().timestamp()
        logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}")
        all_added_nodes.extend(topic for topic, _ in compressed_memory)

        # Topics of this snippet, in insertion order, for pairwise linking below.
        all_topics = []
        for topic, memory in compressed_memory:
            await self.memory_graph.add_dot(topic, memory, self.hippocampus)
            all_topics.append(topic)

            for similar_topic, similarity in similar_topics_dict.get(topic, ()):
                if topic == similar_topic:
                    continue

                strength = int(similarity * 10)
                logger.debug(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})")
                all_added_edges.append(f"{topic}-{similar_topic}")
                all_connected_nodes.extend((topic, similar_topic))

                self.memory_graph.G.add_edge(
                    topic,
                    similar_topic,
                    strength=strength,
                    created_time=current_time,
                    last_modified=current_time,
                )

        # Connect every pair of topics that came from the same snippet.
        for topic1, topic2 in combinations(all_topics, 2):
            logger.debug(f"连接同批次节点: {topic1} 和 {topic2}")
            all_added_edges.append(f"{topic1}-{topic2}")
            self.memory_graph.connect_dot(topic1, topic2)

        # Textual progress bar for the debug log.
        bar_length = 30
        progress = (i / len(memory_samples)) * 100
        filled_length = int(bar_length * i // len(memory_samples))
        bar = "█" * filled_length + "-" * (bar_length - filled_length)
        logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")

    if all_added_nodes:
        logger.info(f"更新记忆: {', '.join(all_added_nodes)}")
    if all_added_edges:
        logger.debug(f"强化连接: {', '.join(all_added_edges)}")
    if all_connected_nodes:
        logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}")

    await self.hippocampus.entorhinal_cortex.sync_memory_to_db()

    end_time = time.time()
    logger.info(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------")
|
||||
|
||||
async def operation_forget_topic(self, percentage=0.005):
|
||||
start_time = time.time()
|
||||
logger.info("[遗忘] 开始检查数据库...")
|
||||
@@ -1650,8 +1338,7 @@ class HippocampusManager:
|
||||
logger.info(f"""
|
||||
--------------------------------
|
||||
记忆系统参数配置:
|
||||
构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
|
||||
记忆构建分布: {global_config.memory.memory_build_distribution}
|
||||
构建频率: {global_config.memory.memory_build_frequency}秒|压缩率: {global_config.memory.memory_compress_rate}
|
||||
遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后
|
||||
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
|
||||
--------------------------------""") # noqa: E501
|
||||
@@ -1663,39 +1350,60 @@ class HippocampusManager:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
return self._hippocampus
|
||||
|
||||
async def build_memory(self):
    """Public entry point for building memories.

    Delegates to ``operation_build_memory`` of the parahippocampal gyrus and
    returns whatever it returns.

    Raises:
        RuntimeError: If the manager has not been initialized yet.
    """
    if self._initialized:
        gyrus = self._hippocampus.parahippocampal_gyrus
        return await gyrus.operation_build_memory()
    raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
|
||||
async def forget_memory(self, percentage: float = 0.005):
    """Public entry point for forgetting memories.

    Delegates to ``operation_forget_topic`` of the parahippocampal gyrus,
    forwarding ``percentage``, and returns whatever it returns.

    Raises:
        RuntimeError: If the manager has not been initialized yet.
    """
    if self._initialized:
        gyrus = self._hippocampus.parahippocampal_gyrus
        return await gyrus.operation_forget_topic(percentage)
    raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
|
||||
|
||||
|
||||
async def get_memory_from_text(
|
||||
self,
|
||||
text: str,
|
||||
max_memory_num: int = 3,
|
||||
max_memory_length: int = 2,
|
||||
max_depth: int = 3,
|
||||
fast_retrieval: bool = False,
|
||||
) -> list:
|
||||
"""从文本中获取相关记忆的公共接口"""
|
||||
async def build_memory_for_chat(self, chat_id: str):
|
||||
"""为指定chat_id构建记忆(在heartFC_chat.py中调用)"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
|
||||
try:
|
||||
response = await self._hippocampus.get_memory_from_text(
|
||||
text, max_memory_num, max_memory_length, max_depth, fast_retrieval
|
||||
)
|
||||
# 检查是否需要构建记忆
|
||||
logger.info(f"为 {chat_id} 构建记忆")
|
||||
if memory_segment_manager.check_and_build_memory_for_chat(chat_id):
|
||||
logger.info(f"为 {chat_id} 构建记忆,需要构建记忆")
|
||||
messages = memory_segment_manager.get_messages_for_memory_build(chat_id, 30 / global_config.memory.memory_build_frequency)
|
||||
if messages:
|
||||
logger.info(f"为 {chat_id} 构建记忆,消息数量: {len(messages)}")
|
||||
|
||||
# 调用记忆压缩和构建
|
||||
compressed_memory, similar_topics_dict = await self._hippocampus.parahippocampal_gyrus.memory_compress(
|
||||
messages, global_config.memory.memory_compress_rate
|
||||
)
|
||||
|
||||
# 添加记忆节点
|
||||
current_time = time.time()
|
||||
for topic, memory in compressed_memory:
|
||||
await self._hippocampus.memory_graph.add_dot(topic, memory, self._hippocampus)
|
||||
|
||||
# 连接相似主题
|
||||
if topic in similar_topics_dict:
|
||||
similar_topics = similar_topics_dict[topic]
|
||||
for similar_topic, similarity in similar_topics:
|
||||
if topic != similar_topic:
|
||||
strength = int(similarity * 10)
|
||||
self._hippocampus.memory_graph.G.add_edge(
|
||||
topic, similar_topic,
|
||||
strength=strength,
|
||||
created_time=current_time,
|
||||
last_modified=current_time
|
||||
)
|
||||
|
||||
# 同步到数据库
|
||||
await self._hippocampus.entorhinal_cortex.sync_memory_to_db()
|
||||
logger.info(f"为 {chat_id} 构建记忆完成")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文本激活记忆失败: {e}")
|
||||
response = []
|
||||
return response
|
||||
logger.error(f"为 {chat_id} 构建记忆失败: {e}")
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def get_memory_from_topic(
|
||||
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||
@@ -1717,12 +1425,11 @@ class HippocampusManager:
|
||||
if not self._initialized:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
try:
|
||||
response, keywords = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
|
||||
response, keywords,keywords_lite = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
|
||||
except Exception as e:
|
||||
logger.error(f"文本产生激活值失败: {e}")
|
||||
response = 0.0
|
||||
keywords = [] # 在异常情况下初始化 keywords 为空列表
|
||||
return response, keywords
|
||||
logger.error(traceback.format_exc())
|
||||
return 0.0, [],[]
|
||||
|
||||
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
||||
"""从关键词获取相关记忆的公共接口"""
|
||||
@@ -1741,3 +1448,90 @@ class HippocampusManager:
|
||||
hippocampus_manager = HippocampusManager()
|
||||
|
||||
|
||||
# 在Hippocampus类中添加新的记忆构建管理器
|
||||
class MemoryBuilder:
    """Per-chat trigger and message window for memory building.

    Tracks, for one ``chat_id``, the start of the window of messages that
    have not been handed out for memory building yet, and decides when enough
    time has passed and enough messages have arrived to build again.
    """

    def __init__(self, chat_id: str):
        self.chat_id = chat_id
        # Start of the window of messages not yet handed out for building.
        self.last_update_time: float = time.time()
        # Time of the last successful hand-out; 0.0 until one happens.
        self.last_processed_time: float = 0.0

    def should_trigger_memory_build(self) -> bool:
        """Return True when a memory build should run for this chat.

        Both conditions must hold: at least ``600 / memory_build_frequency``
        seconds have passed since the window start, and the window contains
        at least ``30 / memory_build_frequency`` messages.
        """
        frequency = global_config.memory.memory_build_frequency
        # A non-positive frequency would divide by zero (or invert both
        # thresholds); treat it as "never build".
        if frequency <= 0:
            return False

        current_time = time.time()

        # Time-interval check.
        time_diff = current_time - self.last_update_time
        if time_diff < 600 / frequency:
            return False

        # Message-count check. Normalize a falsy result to an empty list
        # before len() is taken for the log line.
        recent_messages = (
            get_raw_msg_by_timestamp_with_chat_inclusive(
                chat_id=self.chat_id,
                timestamp_start=self.last_update_time,
                timestamp_end=current_time,
            )
            or []
        )

        logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}")

        if not recent_messages or len(recent_messages) < 30 / frequency:
            return False

        return True

    def get_messages_for_memory_build(self, threshold: int = 25) -> List[Dict[str, Any]]:
        """Fetch the messages of the current window and advance the window.

        Args:
            threshold: Upper bound on the number of messages fetched. A
                fractional value (e.g. the result of a division) is rounded
                up so the query always receives an integer limit.

        Returns:
            The fetched messages, or an empty list. The window start is only
            moved to "now" when something was fetched.
        """
        import math  # local import: used only to round a fractional threshold

        current_time = time.time()

        messages = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_update_time,
            timestamp_end=current_time,
            limit=math.ceil(threshold),
        )

        if messages:
            # Mark the window as consumed so these messages are not built again.
            self.last_processed_time = current_time
            self.last_update_time = current_time

        return messages or []
|
||||
|
||||
|
||||
|
||||
class MemorySegmentManager:
    """Registry of ``MemoryBuilder`` instances keyed by chat_id.

    Creates builders on demand and forwards the trigger check and the
    message fetch to the builder of the requested chat.
    """

    def __init__(self):
        self.builders: Dict[str, MemoryBuilder] = {}

    def get_or_create_builder(self, chat_id: str) -> "MemoryBuilder":
        """Return the builder for ``chat_id``, creating it on first use."""
        try:
            return self.builders[chat_id]
        except KeyError:
            created = MemoryBuilder(chat_id)
            self.builders[chat_id] = created
            return created

    def check_and_build_memory_for_chat(self, chat_id: str) -> bool:
        """Return True when the chat's builder reports that a build is due."""
        return self.get_or_create_builder(chat_id).should_trigger_memory_build()

    def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[Dict[str, Any]]:
        """Return the messages to build from, or [] for an unknown chat_id."""
        builder = self.builders.get(chat_id)
        if builder is None:
            return []
        return builder.get_messages_for_memory_build(threshold)
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
memory_segment_manager = MemorySegmentManager()
|
||||
|
||||
|
||||
@@ -105,8 +105,8 @@ class MemoryActivator:
|
||||
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
|
||||
)
|
||||
|
||||
logger.info(f"当前记忆关键词: {keywords_list}")
|
||||
logger.info(f"获取到的记忆: {related_memory}")
|
||||
# logger.info(f"当前记忆关键词: {keywords_list}")
|
||||
logger.debug(f"获取到的记忆: {related_memory}")
|
||||
|
||||
if not related_memory:
|
||||
logger.debug("海马体没有返回相关记忆")
|
||||
@@ -141,7 +141,7 @@ class MemoryActivator:
|
||||
|
||||
# 如果只有少量记忆,直接返回
|
||||
if len(candidate_memories) <= 2:
|
||||
logger.info(f"候选记忆较少({len(candidate_memories)}个),直接返回")
|
||||
logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
|
||||
# 转换为 (keyword, content) 格式
|
||||
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
|
||||
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class MemoryBuildScheduler:
    """Sample past points in time from a mixture of two normal distributions.

    Each distribution describes "hours before now"; the weights decide how
    the requested number of samples is split between the two.
    """

    def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
        """Validate and store the mixture parameters.

        Args:
            n_hours1 (float): Mean of the first distribution, in hours before now.
            std_hours1 (float): Standard deviation of the first distribution, in hours.
            weight1 (float): Weight of the first distribution.
            n_hours2 (float): Mean of the second distribution, in hours before now.
            std_hours2 (float): Standard deviation of the second distribution, in hours.
            weight2 (float): Weight of the second distribution.
            total_samples (int): Total number of time points to generate.

        Raises:
            ValueError: If ``total_samples`` is not positive, a weight or a
                standard deviation is negative, or both weights are zero.
        """
        if total_samples <= 0:
            raise ValueError("total_samples 必须大于0")
        if weight1 < 0 or weight2 < 0:
            raise ValueError("权重必须为非负数")
        if std_hours1 < 0 or std_hours2 < 0:
            raise ValueError("标准差必须为非负数")

        # Normalize the weights so they sum to 1.
        total_weight = weight1 + weight2
        if total_weight == 0:
            raise ValueError("权重总和不能为0")
        self.weight1 = weight1 / total_weight
        self.weight2 = weight2 / total_weight

        self.n_hours1 = n_hours1
        self.std_hours1 = std_hours1
        self.n_hours2 = n_hours2
        self.std_hours2 = std_hours2
        self.total_samples = total_samples
        # Reference "now"; fixed at construction time.
        self.base_time = datetime.now()

    def _split_sample_counts(self):
        """Split ``total_samples`` between the two distributions.

        The counts always add up to exactly ``total_samples``. A distribution
        with zero weight gets no samples; when both weights are positive and
        there are at least two samples, each distribution gets at least one.
        """
        total = self.total_samples
        if self.weight2 == 0:
            return total, 0
        if self.weight1 == 0:
            return 0, total
        if total == 1:
            # Only one sample to hand out: give it to the heavier distribution.
            return (1, 0) if self.weight1 >= self.weight2 else (0, 1)

        samples1 = int(total * self.weight1)
        samples1 = min(max(samples1, 1), total - 1)
        return samples1, total - samples1

    def generate_time_samples(self):
        """Return ``total_samples`` datetimes in the past, sorted oldest first."""
        samples1, samples2 = self._split_sample_counts()

        # Hour offsets drawn from each normal distribution.
        hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1)
        hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2)
        hours_offset = np.concatenate([hours_offset1, hours_offset2])

        # abs() folds negative offsets back so every point lies in the past.
        timestamps = [self.base_time - timedelta(hours=abs(float(offset))) for offset in hours_offset]

        return sorted(timestamps)

    def get_timestamp_array(self):
        """Return the sampled time points as integer Unix timestamps."""
        return [int(moment.timestamp()) for moment in self.generate_time_samples()]
|
||||
|
||||
|
||||
# def print_time_samples(timestamps, show_distribution=True):
|
||||
# """打印时间样本和分布信息"""
|
||||
# print(f"\n生成的{len(timestamps)}个时间点分布:")
|
||||
# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
|
||||
# print("-" * 50)
|
||||
|
||||
# now = datetime.now()
|
||||
# time_diffs = []
|
||||
|
||||
# for i, timestamp in enumerate(timestamps, 1):
|
||||
# hours_diff = (now - timestamp).total_seconds() / 3600
|
||||
# time_diffs.append(hours_diff)
|
||||
# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
|
||||
|
||||
# # 打印统计信息
|
||||
# print("\n统计信息:")
|
||||
# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
|
||||
# print(f"标准差:{np.std(time_diffs):.2f}小时")
|
||||
# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
|
||||
# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
|
||||
|
||||
# if show_distribution:
|
||||
# # 计算时间分布的直方图
|
||||
# hist, bins = np.histogram(time_diffs, bins=40)
|
||||
# print("\n时间分布(每个*代表一个时间点):")
|
||||
# for i in range(len(hist)):
|
||||
# if hist[i] > 0:
|
||||
# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
|
||||
|
||||
|
||||
# # 使用示例
|
||||
# if __name__ == "__main__":
|
||||
# # 创建一个双峰分布的记忆调度器
|
||||
# scheduler = MemoryBuildScheduler(
|
||||
# n_hours1=12, # 第一个分布均值(12小时前)
|
||||
# std_hours1=8, # 第一个分布标准差
|
||||
# weight1=0.7, # 第一个分布权重 70%
|
||||
# n_hours2=36, # 第二个分布均值(36小时前)
|
||||
# std_hours2=24, # 第二个分布标准差
|
||||
# weight2=0.3, # 第二个分布权重 30%
|
||||
# total_samples=50, # 总共生成50个时间点
|
||||
# )
|
||||
|
||||
# # 生成时间分布
|
||||
# timestamps = scheduler.generate_time_samples()
|
||||
|
||||
# # 打印结果,包含分布可视化
|
||||
# print_time_samples(timestamps, show_distribution=True)
|
||||
|
||||
# # 打印时间戳数组
|
||||
# timestamp_array = scheduler.get_timestamp_array()
|
||||
# print("\n时间戳数组(Unix时间戳):")
|
||||
# print("[", end="")
|
||||
# for i, ts in enumerate(timestamp_array):
|
||||
# if i > 0:
|
||||
# print(", ", end="")
|
||||
# print(ts, end="")
|
||||
# print("]")
|
||||
Reference in New Issue
Block a user