v0.5.2 记忆系统更新
This commit is contained in:
@@ -22,63 +22,6 @@ from src.common.database import Database # 使用正确的导入语法
|
||||
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
|
||||
load_dotenv(env_path)
|
||||
|
||||
class LLMModel:
|
||||
def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
|
||||
self.model_name = model_name
|
||||
self.params = kwargs
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||
|
||||
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||
"""根据输入的提示生成模型的响应"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.5,
|
||||
**self.params
|
||||
}
|
||||
|
||||
# 发送请求到完整的chat/completions端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
|
||||
max_retries = 3
|
||||
base_wait_time = 15
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(api_url, headers=headers, json=data) as response:
|
||||
if response.status == 429:
|
||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
response.raise_for_status() # 检查其他响应状态
|
||||
|
||||
result = await response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
return content, reasoning_content
|
||||
return "没有返回结果", ""
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
return f"请求失败: {str(e)}", ""
|
||||
|
||||
return "达到最大重试次数,请求仍然失败", ""
|
||||
|
||||
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
@@ -232,19 +175,10 @@ def main():
|
||||
)
|
||||
|
||||
memory_graph = Memory_graph()
|
||||
# 创建LLM模型实例
|
||||
|
||||
memory_graph.load_graph_from_db()
|
||||
# 展示两种不同的可视化方式
|
||||
print("\n按连接数量着色的图谱:")
|
||||
# visualize_graph(memory_graph, color_by_memory=False)
|
||||
visualize_graph_lite(memory_graph, color_by_memory=False)
|
||||
|
||||
print("\n按记忆数量着色的图谱:")
|
||||
# visualize_graph(memory_graph, color_by_memory=True)
|
||||
visualize_graph_lite(memory_graph, color_by_memory=True)
|
||||
|
||||
# memory_graph.save_graph_to_db()
|
||||
# 只显示一次优化后的图形
|
||||
visualize_graph_lite(memory_graph)
|
||||
|
||||
while True:
|
||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||
@@ -327,7 +261,7 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
nx.draw(G, pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=2000,
|
||||
node_size=200,
|
||||
font_size=10,
|
||||
font_family='SimHei',
|
||||
font_weight='bold')
|
||||
@@ -353,7 +287,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
degree = H.degree(node)
|
||||
if memory_count <= 2 or degree <= 2:
|
||||
if memory_count < 5 or degree < 2: # 改为小于2而不是小于等于2
|
||||
nodes_to_remove.append(node)
|
||||
|
||||
H.remove_nodes_from(nodes_to_remove)
|
||||
@@ -366,55 +300,55 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
# 保存图到本地
|
||||
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
||||
|
||||
# 根据连接条数或记忆数量设置节点颜色
|
||||
# 计算节点大小和颜色
|
||||
node_colors = []
|
||||
nodes = list(H.nodes()) # 获取图中实际的节点列表
|
||||
node_sizes = []
|
||||
nodes = list(H.nodes())
|
||||
|
||||
if color_by_memory:
|
||||
# 计算每个节点的记忆数量
|
||||
memory_counts = []
|
||||
for node in nodes:
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
if isinstance(memory_items, list):
|
||||
count = len(memory_items)
|
||||
else:
|
||||
count = 1 if memory_items else 0
|
||||
memory_counts.append(count)
|
||||
max_memories = max(memory_counts) if memory_counts else 1
|
||||
# 获取最大记忆数和最大度数用于归一化
|
||||
max_memories = 1
|
||||
max_degree = 1
|
||||
for node in nodes:
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
degree = H.degree(node)
|
||||
max_memories = max(max_memories, memory_count)
|
||||
max_degree = max(max_degree, degree)
|
||||
|
||||
# 计算每个节点的大小和颜色
|
||||
for node in nodes:
|
||||
# 计算节点大小(基于记忆数量)
|
||||
memory_items = H.nodes[node].get('memory_items', [])
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
# 使用指数函数使变化更明显
|
||||
ratio = memory_count / max_memories
|
||||
size = 500 + 5000 * (ratio ** 2) # 使用平方函数使差异更明显
|
||||
node_sizes.append(size)
|
||||
|
||||
for count in memory_counts:
|
||||
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
||||
if max_memories > 0:
|
||||
intensity = min(1.0, count / max_memories)
|
||||
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
||||
else:
|
||||
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
||||
node_colors.append(color)
|
||||
else:
|
||||
# 使用原来的连接数量着色方案
|
||||
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
|
||||
for node in nodes:
|
||||
degree = H.degree(node)
|
||||
if max_degree > 0:
|
||||
red = min(1.0, degree / max_degree)
|
||||
blue = 1.0 - red
|
||||
color = (red, 0, blue)
|
||||
else:
|
||||
color = (0, 0, 1)
|
||||
node_colors.append(color)
|
||||
# 计算节点颜色(基于连接数)
|
||||
degree = H.degree(node)
|
||||
# 红色分量随着度数增加而增加
|
||||
red = min(1.0, degree / max_degree)
|
||||
# 蓝色分量随着度数减少而增加
|
||||
blue = 1.0 - red
|
||||
color = (red, 0, blue)
|
||||
node_colors.append(color)
|
||||
|
||||
# 绘制图形
|
||||
plt.figure(figsize=(12, 8))
|
||||
pos = nx.spring_layout(H, k=1, iterations=50)
|
||||
pos = nx.spring_layout(H, k=1.5, iterations=50) # 增加k值使节点分布更开
|
||||
nx.draw(H, pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=2000,
|
||||
node_size=node_sizes,
|
||||
font_size=10,
|
||||
font_family='SimHei',
|
||||
font_weight='bold')
|
||||
font_weight='bold',
|
||||
edge_color='gray',
|
||||
width=0.5,
|
||||
alpha=0.7)
|
||||
|
||||
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
||||
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
|
||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||
plt.show()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user