重构数据库访问,替换为统一的数据库实例引用
This commit is contained in:
@@ -10,7 +10,7 @@ import networkx as nx
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
from ...common.database import Database # 使用正确的导入语法
|
||||
from ...common.database import db # 使用正确的导入语法
|
||||
from ..chat.config import global_config
|
||||
from ..chat.utils import (
|
||||
calculate_information_content,
|
||||
@@ -23,7 +23,6 @@ from ..models.utils_model import LLM_request
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
self.db = Database.get_instance()
|
||||
|
||||
def connect_dot(self, concept1, concept2):
|
||||
# 避免自连接
|
||||
@@ -191,19 +190,19 @@ class Hippocampus:
|
||||
# 短期:1h 中期:4h 长期:24h
|
||||
for _ in range(time_frequency.get('near')):
|
||||
random_time = current_timestamp - random.randint(1, 3600)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
for _ in range(time_frequency.get('mid')):
|
||||
random_time = current_timestamp - random.randint(3600, 3600 * 4)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
for _ in range(time_frequency.get('far')):
|
||||
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
@@ -349,7 +348,7 @@ class Hippocampus:
|
||||
def sync_memory_to_db(self):
|
||||
"""检查并同步内存中的图结构与数据库"""
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
db_nodes = list(db.graph_data.nodes.find())
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 转换数据库节点为字典格式,方便查找
|
||||
@@ -377,7 +376,7 @@ class Hippocampus:
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
}
|
||||
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
else:
|
||||
# 获取数据库中节点的特征值
|
||||
db_node = db_nodes_dict[concept]
|
||||
@@ -385,7 +384,7 @@ class Hippocampus:
|
||||
|
||||
# 如果特征值不同,则更新节点
|
||||
if db_hash != memory_hash:
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'memory_items': memory_items,
|
||||
@@ -396,7 +395,7 @@ class Hippocampus:
|
||||
)
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
db_edges = list(db.graph_data.edges.find())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -428,11 +427,11 @@ class Hippocampus:
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
}
|
||||
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
else:
|
||||
# 检查边的特征值是否变化
|
||||
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {
|
||||
'hash': edge_hash,
|
||||
@@ -451,7 +450,7 @@ class Hippocampus:
|
||||
self.memory_graph.G.clear()
|
||||
|
||||
# 从数据库加载所有节点
|
||||
nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
nodes = list(db.graph_data.nodes.find())
|
||||
for node in nodes:
|
||||
concept = node['concept']
|
||||
memory_items = node.get('memory_items', [])
|
||||
@@ -468,7 +467,7 @@ class Hippocampus:
|
||||
if 'last_modified' not in node:
|
||||
update_data['last_modified'] = current_time
|
||||
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': update_data}
|
||||
)
|
||||
@@ -485,7 +484,7 @@ class Hippocampus:
|
||||
last_modified=last_modified)
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
edges = list(db.graph_data.edges.find())
|
||||
for edge in edges:
|
||||
source = edge['source']
|
||||
target = edge['target']
|
||||
@@ -501,7 +500,7 @@ class Hippocampus:
|
||||
if 'last_modified' not in edge:
|
||||
update_data['last_modified'] = current_time
|
||||
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': update_data}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user