重构数据库访问,替换为统一的数据库实例引用

This commit is contained in:
晴猫
2025-03-12 22:27:59 +09:00
parent 49082267bb
commit 8be087dcad
19 changed files with 138 additions and 284 deletions

View File

@@ -10,7 +10,7 @@ import networkx as nx
from loguru import logger
from nonebot import get_driver
from ...common.database import Database # 使用正确的导入语法
from ...common.database import db # 使用正确的导入语法
from ..chat.config import global_config
from ..chat.utils import (
calculate_information_content,
@@ -23,7 +23,6 @@ from ..models.utils_model import LLM_request
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
# 避免自连接
@@ -191,19 +190,19 @@ class Hippocampus:
# 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600, 3600 * 4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
@@ -349,7 +348,7 @@ class Hippocampus:
def sync_memory_to_db(self):
"""检查并同步内存中的图结构与数据库"""
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
@@ -377,7 +376,7 @@ class Hippocampus:
'created_time': created_time,
'last_modified': last_modified
}
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
db.graph_data.nodes.insert_one(node_data)
else:
# 获取数据库中节点的特征值
db_node = db_nodes_dict[concept]
@@ -385,7 +384,7 @@ class Hippocampus:
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
self.memory_graph.db.graph_data.nodes.update_one(
db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
@@ -396,7 +395,7 @@ class Hippocampus:
)
# 处理边的信息
db_edges = list(self.memory_graph.db.graph_data.edges.find())
db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
@@ -428,11 +427,11 @@ class Hippocampus:
'created_time': created_time,
'last_modified': last_modified
}
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
db.graph_data.edges.insert_one(edge_data)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]['hash'] != edge_hash:
self.memory_graph.db.graph_data.edges.update_one(
db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {
'hash': edge_hash,
@@ -451,7 +450,7 @@ class Hippocampus:
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = list(self.memory_graph.db.graph_data.nodes.find())
nodes = list(db.graph_data.nodes.find())
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
@@ -468,7 +467,7 @@ class Hippocampus:
if 'last_modified' not in node:
update_data['last_modified'] = current_time
self.memory_graph.db.graph_data.nodes.update_one(
db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': update_data}
)
@@ -485,7 +484,7 @@ class Hippocampus:
last_modified=last_modified)
# 从数据库加载所有边
edges = list(self.memory_graph.db.graph_data.edges.find())
edges = list(db.graph_data.edges.find())
for edge in edges:
source = edge['source']
target = edge['target']
@@ -501,7 +500,7 @@ class Hippocampus:
if 'last_modified' not in edge:
update_data['last_modified'] = current_time
self.memory_graph.db.graph_data.edges.update_one(
db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': update_data}
)