优化代码格式和异常处理

- 修复异常处理链,使用from语法保留原始异常
- 格式化代码以符合项目规范
- 优化导入模块的顺序

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
春河晴
2025-03-19 20:27:34 +09:00
parent a829dfdb77
commit fdc098d0db
52 changed files with 3156 additions and 2778 deletions

View File

@@ -15,10 +15,10 @@ logger = get_module_logger("draw_memory")
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import db # 使用正确的导入语法
from src.common.database import db # noqa: E402
# 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), ".env.dev")
load_dotenv(env_path)
@@ -32,13 +32,13 @@ class Memory_graph:
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list):
if "memory_items" in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]["memory_items"], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
self.G.nodes[concept]['memory_items'].append(memory)
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
self.G.nodes[concept]["memory_items"].append(memory)
else:
self.G.nodes[concept]['memory_items'] = [memory]
self.G.nodes[concept]["memory_items"] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
@@ -68,8 +68,8 @@ class Memory_graph:
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
@@ -83,8 +83,8 @@ class Memory_graph:
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
@@ -94,9 +94,7 @@ class Memory_graph:
def store_memory(self):
for node in self.G.nodes():
dot_data = {
"concept": node
}
dot_data = {"concept": node}
db.store_memory_dots.insert_one(dot_data)
@property
@@ -106,25 +104,27 @@ class Memory_graph:
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
chat_text = ""
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)]) # 调试输出
logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}"
)
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
closest_time = closest_record["time"]
group_id = closest_record["group_id"] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length))
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort("time", 1).limit(length)
)
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(record["time"])))
try:
displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
except:
displayname = record["user_nickname"] or "用户" + str(record["user_id"])
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
except (KeyError, TypeError):
# 处理缺少键或类型错误的情况
displayname = record.get("user_nickname", "") or "用户" + str(record.get("user_id", "未知"))
chat_text += f"[{time_str}] {displayname}: {record['processed_plain_text']}\n" # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
@@ -135,16 +135,13 @@ class Memory_graph:
# 保存节点
for node in self.G.nodes(data=True):
node_data = {
'concept': node[0],
'memory_items': node[1].get('memory_items', []) # 默认为空列表
"concept": node[0],
"memory_items": node[1].get("memory_items", []), # 默认为空列表
}
db.graph_data.nodes.insert_one(node_data)
# 保存边
for edge in self.G.edges():
edge_data = {
'source': edge[0],
'target': edge[1]
}
edge_data = {"source": edge[0], "target": edge[1]}
db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
@@ -153,14 +150,14 @@ class Memory_graph:
# 加载节点
nodes = db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get('memory_items', [])
memory_items = node.get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items)
self.G.add_node(node["concept"], memory_items=memory_items)
# 加载边
edges = db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge['source'], edge['target'])
self.G.add_edge(edge["source"], edge["target"])
def main():
@@ -172,7 +169,7 @@ def main():
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
if query.lower() == "退出":
break
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
@@ -192,19 +189,25 @@ def segment_text(text):
def find_topic(text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
prompt = (
f"这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。"
f"只需要列举{topic_num}个话题就好,不要告诉我其他内容。"
)
return prompt
def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
prompt = (
f"这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。"
f"只输出这句话就好"
)
return prompt
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams["font.sans-serif"] = ["SimHei"] # 用来正常显示中文标签
plt.rcParams["axes.unicode_minus"] = False # 用来正常显示负号
G = memory_graph.G
@@ -214,7 +217,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get('memory_items', [])
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count < 3 or degree < 2: # 改为小于2而不是小于等于2
@@ -239,7 +242,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
max_memories = 1
max_degree = 1
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
max_memories = max(max_memories, memory_count)
@@ -248,7 +251,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
# 计算每个节点的大小和颜色
for node in nodes:
# 计算节点大小(基于记忆数量)
memory_items = H.nodes[node].get('memory_items', [])
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显
ratio = memory_count / max_memories
@@ -269,19 +272,22 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=10,
font_family='SimHei',
font_weight='bold',
edge_color='gray',
width=0.5,
alpha=0.9)
nx.draw(
H,
pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=10,
font_family="SimHei",
font_weight="bold",
edge_color="gray",
width=0.5,
alpha=0.9,
)
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
plt.title(title, fontsize=16, fontfamily='SimHei')
title = "记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数"
plt.title(title, fontsize=16, fontfamily="SimHei")
plt.show()