重构数据库访问,替换为统一的数据库实例引用
This commit is contained in:
@@ -13,7 +13,7 @@ from loguru import logger
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
from src.common.database import Database # 使用正确的导入语法
|
||||
from src.common.database import db # 使用正确的导入语法
|
||||
|
||||
# 加载.env.dev文件
|
||||
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
|
||||
@@ -23,7 +23,6 @@ load_dotenv(env_path)
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
self.db = Database.get_instance()
|
||||
|
||||
def connect_dot(self, concept1, concept2):
|
||||
self.G.add_edge(concept1, concept2)
|
||||
@@ -96,7 +95,7 @@ class Memory_graph:
|
||||
dot_data = {
|
||||
"concept": node
|
||||
}
|
||||
self.db.store_memory_dots.insert_one(dot_data)
|
||||
db.store_memory_dots.insert_one(dot_data)
|
||||
|
||||
@property
|
||||
def dots(self):
|
||||
@@ -106,7 +105,7 @@ class Memory_graph:
|
||||
def get_random_chat_from_db(self, length: int, timestamp: str):
|
||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||
chat_text = ''
|
||||
closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||
logger.info(
|
||||
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||
|
||||
@@ -115,7 +114,7 @@ class Memory_graph:
|
||||
group_id = closest_record['group_id'] # 获取groupid
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
chat_record = list(
|
||||
self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
|
||||
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
|
||||
length))
|
||||
for record in chat_record:
|
||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||
@@ -130,50 +129,39 @@ class Memory_graph:
|
||||
|
||||
def save_graph_to_db(self):
|
||||
# 清空现有的图数据
|
||||
self.db.graph_data.delete_many({})
|
||||
db.graph_data.delete_many({})
|
||||
# 保存节点
|
||||
for node in self.G.nodes(data=True):
|
||||
node_data = {
|
||||
'concept': node[0],
|
||||
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
||||
}
|
||||
self.db.graph_data.nodes.insert_one(node_data)
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
# 保存边
|
||||
for edge in self.G.edges():
|
||||
edge_data = {
|
||||
'source': edge[0],
|
||||
'target': edge[1]
|
||||
}
|
||||
self.db.graph_data.edges.insert_one(edge_data)
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
|
||||
def load_graph_from_db(self):
|
||||
# 清空当前图
|
||||
self.G.clear()
|
||||
# 加载节点
|
||||
nodes = self.db.graph_data.nodes.find()
|
||||
nodes = db.graph_data.nodes.find()
|
||||
for node in nodes:
|
||||
memory_items = node.get('memory_items', [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
self.G.add_node(node['concept'], memory_items=memory_items)
|
||||
# 加载边
|
||||
edges = self.db.graph_data.edges.find()
|
||||
edges = db.graph_data.edges.find()
|
||||
for edge in edges:
|
||||
self.G.add_edge(edge['source'], edge['target'])
|
||||
|
||||
|
||||
def main():
|
||||
# 初始化数据库
|
||||
Database.initialize(
|
||||
uri=os.getenv("MONGODB_URI"),
|
||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||
)
|
||||
|
||||
memory_graph = Memory_graph()
|
||||
memory_graph.load_graph_from_db()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user