重构数据库访问,替换为统一的数据库实例引用

This commit is contained in:
晴猫
2025-03-12 22:27:59 +09:00
parent 49082267bb
commit 8be087dcad
19 changed files with 138 additions and 284 deletions

View File

@@ -38,7 +38,7 @@ import jieba
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database
from src.common.database import db
from src.plugins.memory_system.offline_llm import LLMModel
# 获取当前文件的目录
@@ -56,45 +56,6 @@ else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# 测试连接
client.server_info()
logger.success("MongoDB连接成功!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
@@ -108,7 +69,7 @@ def calculate_information_content(text):
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
def get_cloest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
Returns:
@@ -163,7 +124,7 @@ class Memory_cortex:
default_time = datetime.datetime.now().timestamp()
# 从数据库加载所有节点
nodes = self.memory_graph.db.graph_data.nodes.find()
nodes = db.graph_data.nodes.find()
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
@@ -180,7 +141,7 @@ class Memory_cortex:
created_time = default_time
last_modified = default_time
# 更新数据库中的节点
self.memory_graph.db.graph_data.nodes.update_one(
db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'created_time': created_time,
@@ -196,7 +157,7 @@ class Memory_cortex:
last_modified=last_modified)
# 从数据库加载所有边
edges = self.memory_graph.db.graph_data.edges.find()
edges = db.graph_data.edges.find()
for edge in edges:
source = edge['source']
target = edge['target']
@@ -212,7 +173,7 @@ class Memory_cortex:
created_time = default_time
last_modified = default_time
# 更新数据库中的边
self.memory_graph.db.graph_data.edges.update_one(
db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {
'created_time': created_time,
@@ -256,7 +217,7 @@ class Memory_cortex:
current_time = datetime.datetime.now().timestamp()
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
@@ -280,7 +241,7 @@ class Memory_cortex:
'created_time': data.get('created_time', current_time),
'last_modified': data.get('last_modified', current_time)
}
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
db.graph_data.nodes.insert_one(node_data)
else:
# 获取数据库中节点的特征值
db_node = db_nodes_dict[concept]
@@ -288,7 +249,7 @@ class Memory_cortex:
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
self.memory_graph.db.graph_data.nodes.update_one(
db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
@@ -301,10 +262,10 @@ class Memory_cortex:
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node['concept'] not in memory_concepts:
self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# 处理边的信息
db_edges = list(self.memory_graph.db.graph_data.edges.find())
db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
@@ -332,11 +293,11 @@ class Memory_cortex:
'created_time': data.get('created_time', current_time),
'last_modified': data.get('last_modified', current_time)
}
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
db.graph_data.edges.insert_one(edge_data)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]['hash'] != edge_hash:
self.memory_graph.db.graph_data.edges.update_one(
db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {
'hash': edge_hash,
@@ -350,7 +311,7 @@ class Memory_cortex:
for edge_key in db_edge_dict:
if edge_key not in memory_edge_set:
source, target = edge_key
self.memory_graph.db.graph_data.edges.delete_one({
db.graph_data.edges.delete_one({
'source': source,
'target': target
})
@@ -365,9 +326,9 @@ class Memory_cortex:
topic: 要删除的节点概念
"""
# 删除节点
self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic})
db.graph_data.nodes.delete_one({'concept': topic})
# 删除所有涉及该节点的边
self.memory_graph.db.graph_data.edges.delete_many({
db.graph_data.edges.delete_many({
'$or': [
{'source': topic},
{'target': topic}
@@ -377,7 +338,6 @@ class Memory_cortex:
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
# 避免自连接
@@ -492,19 +452,19 @@ class Hippocampus:
# 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600*4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600*4, 3600*24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600*24, 3600*24*7)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
@@ -1134,7 +1094,6 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
async def main():
# 初始化数据库
logger.info("正在初始化数据库连接...")
db = Database.get_instance()
start_time = time.time()
test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}