v0.3.0 记忆和知识库
beta
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
import jieba
|
||||
from llm_module import LLMModel
|
||||
from .llm_module import LLMModel
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
@@ -9,9 +9,9 @@ from collections import Counter
|
||||
import datetime
|
||||
import random
|
||||
import time
|
||||
|
||||
from ..chat.config import global_config
|
||||
import sys
|
||||
sys.path.append("C:/GitHub/MegMeg-bot") # 添加项目根目录到 Python 路径
|
||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||
from src.common.database import Database # 使用正确的导入语法
|
||||
|
||||
class Memory_graph:
|
||||
@@ -23,44 +23,67 @@ class Memory_graph:
|
||||
self.G.add_edge(concept1, concept2)
|
||||
|
||||
def add_dot(self, concept, memory):
|
||||
self.G.add_node(concept, memory_items=memory)
|
||||
if concept in self.G:
|
||||
# 如果节点已存在,将新记忆添加到现有列表中
|
||||
if 'memory_items' in self.G.nodes[concept]:
|
||||
if not isinstance(self.G.nodes[concept]['memory_items'], list):
|
||||
# 如果当前不是列表,将其转换为列表
|
||||
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
|
||||
self.G.nodes[concept]['memory_items'].append(memory)
|
||||
else:
|
||||
self.G.nodes[concept]['memory_items'] = [memory]
|
||||
else:
|
||||
# 如果是新节点,创建新的记忆列表
|
||||
self.G.add_node(concept, memory_items=[memory])
|
||||
|
||||
def get_dot(self, concept):
|
||||
# 检查节点是否存在于图中
|
||||
if concept in self.G:
|
||||
# 从图中获取节点数据
|
||||
node_data = self.G.nodes[concept]
|
||||
print(node_data)
|
||||
# print(node_data)
|
||||
# 创建新的Memory_dot对象
|
||||
return concept,node_data
|
||||
return None
|
||||
|
||||
def get_related_item(self, topic, depth=1):
|
||||
if topic not in self.G:
|
||||
return set()
|
||||
return [], []
|
||||
|
||||
items_set = set()
|
||||
first_layer_items = []
|
||||
second_layer_items = []
|
||||
|
||||
# 获取相邻节点
|
||||
neighbors = list(self.G.neighbors(topic))
|
||||
print(f"第一层: {topic}")
|
||||
# print(f"第一层: {topic}")
|
||||
|
||||
# 获取当前节点的记忆项
|
||||
node_data = self.get_dot(topic)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
if 'memory_items' in data:
|
||||
items_set.add(data['memory_items'])
|
||||
memory_items = data['memory_items']
|
||||
if isinstance(memory_items, list):
|
||||
first_layer_items.extend(memory_items)
|
||||
else:
|
||||
first_layer_items.append(memory_items)
|
||||
|
||||
# 获取相邻节点的记忆项
|
||||
for neighbor in neighbors:
|
||||
print(f"第二层: {neighbor}")
|
||||
node_data = self.get_dot(neighbor)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
if 'memory_items' in data:
|
||||
items_set.add(data['memory_items'])
|
||||
# 只在depth=2时获取第二层记忆
|
||||
if depth >= 2:
|
||||
# 获取相邻节点的记忆项
|
||||
for neighbor in neighbors:
|
||||
# print(f"第二层: {neighbor}")
|
||||
node_data = self.get_dot(neighbor)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
if 'memory_items' in data:
|
||||
memory_items = data['memory_items']
|
||||
if isinstance(memory_items, list):
|
||||
second_layer_items.extend(memory_items)
|
||||
else:
|
||||
second_layer_items.append(memory_items)
|
||||
|
||||
return items_set
|
||||
return first_layer_items, second_layer_items
|
||||
|
||||
def store_memory(self):
|
||||
for node in self.G.nodes():
|
||||
@@ -100,7 +123,7 @@ class Memory_graph:
|
||||
for node in self.G.nodes(data=True):
|
||||
node_data = {
|
||||
'concept': node[0],
|
||||
'memory_items': node[1].get('memory_items', None)
|
||||
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
||||
}
|
||||
self.db.db.graph_data.nodes.insert_one(node_data)
|
||||
# 保存边
|
||||
@@ -117,7 +140,10 @@ class Memory_graph:
|
||||
# 加载节点
|
||||
nodes = self.db.db.graph_data.nodes.find()
|
||||
for node in nodes:
|
||||
self.G.add_node(node['concept'], memory_items=node['memory_items'])
|
||||
memory_items = node.get('memory_items', [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
self.G.add_node(node['concept'], memory_items=memory_items)
|
||||
# 加载边
|
||||
edges = self.db.db.graph_data.edges.find()
|
||||
for edge in edges:
|
||||
@@ -138,6 +164,26 @@ def calculate_information_content(text):
|
||||
|
||||
return entropy
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
Database.initialize(
|
||||
global_config.MONGODB_HOST,
|
||||
global_config.MONGODB_PORT,
|
||||
global_config.DATABASE_NAME
|
||||
)
|
||||
memory_graph = Memory_graph()
|
||||
|
||||
llm_model = LLMModel()
|
||||
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||
|
||||
memory_graph.load_graph_from_db()
|
||||
|
||||
end_time = time.time()
|
||||
print(f"加载海马体耗时: {end_time - start_time:.2f} 秒")
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
# 初始化数据库
|
||||
Database.initialize(
|
||||
@@ -155,13 +201,14 @@ def main():
|
||||
current_timestamp = datetime.datetime.now().timestamp()
|
||||
chat_text = []
|
||||
|
||||
chat_size =30
|
||||
chat_size =40
|
||||
|
||||
for _ in range(60): # 循环10次
|
||||
random_time = current_timestamp - random.randint(1, 3600*3) # 随机时间
|
||||
for _ in range(100): # 循环10次
|
||||
random_time = current_timestamp - random.randint(1, 3600*39) # 随机时间
|
||||
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
||||
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
||||
chat_text.append(chat_) # 拼接所有text
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
|
||||
@@ -173,7 +220,7 @@ def main():
|
||||
#将记忆加入到图谱中
|
||||
for topic, memory in first_memory:
|
||||
topics = segment_text(topic)
|
||||
print(f"话题: {topic},节点: {topics}, 记忆: {memory}")
|
||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||
for split_topic in topics:
|
||||
memory_graph.add_dot(split_topic,memory)
|
||||
for split_topic in topics:
|
||||
@@ -182,7 +229,13 @@ def main():
|
||||
memory_graph.connect_dot(split_topic, other_split_topic)
|
||||
|
||||
# memory_graph.store_memory()
|
||||
visualize_graph(memory_graph)
|
||||
|
||||
# 展示两种不同的可视化方式
|
||||
print("\n按连接数量着色的图谱:")
|
||||
visualize_graph(memory_graph, color_by_memory=False)
|
||||
|
||||
print("\n按记忆数量着色的图谱:")
|
||||
visualize_graph(memory_graph, color_by_memory=True)
|
||||
|
||||
memory_graph.save_graph_to_db()
|
||||
# memory_graph.load_graph_from_db()
|
||||
@@ -252,45 +305,66 @@ def topic_what(text, topic):
|
||||
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
||||
return prompt
|
||||
|
||||
def visualize_graph(memory_graph: Memory_graph):
|
||||
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||
|
||||
G = memory_graph.G
|
||||
|
||||
|
||||
# 保存图到本地
|
||||
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
|
||||
|
||||
# 根据连接条数设置节点颜色
|
||||
# 根据连接条数或记忆数量设置节点颜色
|
||||
node_colors = []
|
||||
nodes = list(G.nodes()) # 获取图中实际的节点列表
|
||||
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1 # 获取最大连接数
|
||||
|
||||
for node in nodes:
|
||||
degree = G.degree(node) # 获取节点的度
|
||||
# 计算颜色,使用渐变效果
|
||||
if max_degree > 0:
|
||||
red = min(1.0, degree / max_degree) # 红色分量随连接数增加而增加
|
||||
blue = 1.0 - red # 蓝色分量随连接数增加而减少
|
||||
color = (red, 0, blue)
|
||||
else:
|
||||
color = (0, 0, 1) # 如果没有连接,则为蓝色
|
||||
node_colors.append(color)
|
||||
if color_by_memory:
|
||||
# 计算每个节点的记忆数量
|
||||
memory_counts = []
|
||||
for node in nodes:
|
||||
memory_items = G.nodes[node].get('memory_items', [])
|
||||
if isinstance(memory_items, list):
|
||||
count = len(memory_items)
|
||||
else:
|
||||
count = 1 if memory_items else 0
|
||||
memory_counts.append(count)
|
||||
max_memories = max(memory_counts) if memory_counts else 1
|
||||
|
||||
for count in memory_counts:
|
||||
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
||||
if max_memories > 0:
|
||||
intensity = min(1.0, count / max_memories)
|
||||
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
||||
else:
|
||||
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
||||
node_colors.append(color)
|
||||
else:
|
||||
# 使用原来的连接数量着色方案
|
||||
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
|
||||
for node in nodes:
|
||||
degree = G.degree(node)
|
||||
if max_degree > 0:
|
||||
red = min(1.0, degree / max_degree)
|
||||
blue = 1.0 - red
|
||||
color = (red, 0, blue)
|
||||
else:
|
||||
color = (0, 0, 1)
|
||||
node_colors.append(color)
|
||||
|
||||
# 绘制图形
|
||||
plt.figure(figsize=(12, 8))
|
||||
pos = nx.spring_layout(G, k=1, iterations=50) # 使用弹簧布局,调整参数使布局更合理
|
||||
pos = nx.spring_layout(G, k=1, iterations=50)
|
||||
nx.draw(G, pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=2000,
|
||||
font_size=10,
|
||||
font_family='SimHei', # 设置节点标签的字体
|
||||
font_family='SimHei',
|
||||
font_weight='bold')
|
||||
|
||||
plt.title('记忆图谱可视化', fontsize=16, fontfamily='SimHei')
|
||||
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user