重构数据库访问,替换为统一的数据库实例引用
This commit is contained in:
@@ -13,7 +13,7 @@ from loguru import logger
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
from src.common.database import Database # 使用正确的导入语法
|
||||
from src.common.database import db # 使用正确的导入语法
|
||||
|
||||
# 加载.env.dev文件
|
||||
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
|
||||
@@ -23,7 +23,6 @@ load_dotenv(env_path)
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
self.db = Database.get_instance()
|
||||
|
||||
def connect_dot(self, concept1, concept2):
|
||||
self.G.add_edge(concept1, concept2)
|
||||
@@ -96,7 +95,7 @@ class Memory_graph:
|
||||
dot_data = {
|
||||
"concept": node
|
||||
}
|
||||
self.db.store_memory_dots.insert_one(dot_data)
|
||||
db.store_memory_dots.insert_one(dot_data)
|
||||
|
||||
@property
|
||||
def dots(self):
|
||||
@@ -106,7 +105,7 @@ class Memory_graph:
|
||||
def get_random_chat_from_db(self, length: int, timestamp: str):
|
||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||
chat_text = ''
|
||||
closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||
logger.info(
|
||||
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||
|
||||
@@ -115,7 +114,7 @@ class Memory_graph:
|
||||
group_id = closest_record['group_id'] # 获取groupid
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
chat_record = list(
|
||||
self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
|
||||
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
|
||||
length))
|
||||
for record in chat_record:
|
||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||
@@ -130,50 +129,39 @@ class Memory_graph:
|
||||
|
||||
def save_graph_to_db(self):
|
||||
# 清空现有的图数据
|
||||
self.db.graph_data.delete_many({})
|
||||
db.graph_data.delete_many({})
|
||||
# 保存节点
|
||||
for node in self.G.nodes(data=True):
|
||||
node_data = {
|
||||
'concept': node[0],
|
||||
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
||||
}
|
||||
self.db.graph_data.nodes.insert_one(node_data)
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
# 保存边
|
||||
for edge in self.G.edges():
|
||||
edge_data = {
|
||||
'source': edge[0],
|
||||
'target': edge[1]
|
||||
}
|
||||
self.db.graph_data.edges.insert_one(edge_data)
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
|
||||
def load_graph_from_db(self):
|
||||
# 清空当前图
|
||||
self.G.clear()
|
||||
# 加载节点
|
||||
nodes = self.db.graph_data.nodes.find()
|
||||
nodes = db.graph_data.nodes.find()
|
||||
for node in nodes:
|
||||
memory_items = node.get('memory_items', [])
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
self.G.add_node(node['concept'], memory_items=memory_items)
|
||||
# 加载边
|
||||
edges = self.db.graph_data.edges.find()
|
||||
edges = db.graph_data.edges.find()
|
||||
for edge in edges:
|
||||
self.G.add_edge(edge['source'], edge['target'])
|
||||
|
||||
|
||||
def main():
|
||||
# 初始化数据库
|
||||
Database.initialize(
|
||||
uri=os.getenv("MONGODB_URI"),
|
||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||
)
|
||||
|
||||
memory_graph = Memory_graph()
|
||||
memory_graph.load_graph_from_db()
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import networkx as nx
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
from ...common.database import Database # 使用正确的导入语法
|
||||
from ...common.database import db # 使用正确的导入语法
|
||||
from ..chat.config import global_config
|
||||
from ..chat.utils import (
|
||||
calculate_information_content,
|
||||
@@ -23,7 +23,6 @@ from ..models.utils_model import LLM_request
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
self.db = Database.get_instance()
|
||||
|
||||
def connect_dot(self, concept1, concept2):
|
||||
# 避免自连接
|
||||
@@ -191,19 +190,19 @@ class Hippocampus:
|
||||
# 短期:1h 中期:4h 长期:24h
|
||||
for _ in range(time_frequency.get('near')):
|
||||
random_time = current_timestamp - random.randint(1, 3600)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
for _ in range(time_frequency.get('mid')):
|
||||
random_time = current_timestamp - random.randint(3600, 3600 * 4)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
for _ in range(time_frequency.get('far')):
|
||||
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
@@ -349,7 +348,7 @@ class Hippocampus:
|
||||
def sync_memory_to_db(self):
|
||||
"""检查并同步内存中的图结构与数据库"""
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
db_nodes = list(db.graph_data.nodes.find())
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 转换数据库节点为字典格式,方便查找
|
||||
@@ -377,7 +376,7 @@ class Hippocampus:
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
}
|
||||
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
else:
|
||||
# 获取数据库中节点的特征值
|
||||
db_node = db_nodes_dict[concept]
|
||||
@@ -385,7 +384,7 @@ class Hippocampus:
|
||||
|
||||
# 如果特征值不同,则更新节点
|
||||
if db_hash != memory_hash:
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'memory_items': memory_items,
|
||||
@@ -396,7 +395,7 @@ class Hippocampus:
|
||||
)
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
db_edges = list(db.graph_data.edges.find())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -428,11 +427,11 @@ class Hippocampus:
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
}
|
||||
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
else:
|
||||
# 检查边的特征值是否变化
|
||||
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {
|
||||
'hash': edge_hash,
|
||||
@@ -451,7 +450,7 @@ class Hippocampus:
|
||||
self.memory_graph.G.clear()
|
||||
|
||||
# 从数据库加载所有节点
|
||||
nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
nodes = list(db.graph_data.nodes.find())
|
||||
for node in nodes:
|
||||
concept = node['concept']
|
||||
memory_items = node.get('memory_items', [])
|
||||
@@ -468,7 +467,7 @@ class Hippocampus:
|
||||
if 'last_modified' not in node:
|
||||
update_data['last_modified'] = current_time
|
||||
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': update_data}
|
||||
)
|
||||
@@ -485,7 +484,7 @@ class Hippocampus:
|
||||
last_modified=last_modified)
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
edges = list(db.graph_data.edges.find())
|
||||
for edge in edges:
|
||||
source = edge['source']
|
||||
target = edge['target']
|
||||
@@ -501,7 +500,7 @@ class Hippocampus:
|
||||
if 'last_modified' not in edge:
|
||||
update_data['last_modified'] = current_time
|
||||
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': update_data}
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ import jieba
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
from src.common.database import Database
|
||||
from src.common.database import db
|
||||
from src.plugins.memory_system.offline_llm import LLMModel
|
||||
|
||||
# 获取当前文件的目录
|
||||
@@ -49,7 +49,7 @@ def calculate_information_content(text):
|
||||
|
||||
return entropy
|
||||
|
||||
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
def get_cloest_chat_from_db(length: int, timestamp: str):
|
||||
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
|
||||
|
||||
Returns:
|
||||
@@ -91,7 +91,6 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
self.db = Database.get_instance()
|
||||
|
||||
def connect_dot(self, concept1, concept2):
|
||||
# 如果边已存在,增加 strength
|
||||
@@ -186,19 +185,19 @@ class Hippocampus:
|
||||
# 短期:1h 中期:4h 长期:24h
|
||||
for _ in range(time_frequency.get('near')):
|
||||
random_time = current_timestamp - random.randint(1, 3600*4)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
for _ in range(time_frequency.get('mid')):
|
||||
random_time = current_timestamp - random.randint(3600*4, 3600*24)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
for _ in range(time_frequency.get('far')):
|
||||
random_time = current_timestamp - random.randint(3600*24, 3600*24*7)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
@@ -323,7 +322,7 @@ class Hippocampus:
|
||||
self.memory_graph.G.clear()
|
||||
|
||||
# 从数据库加载所有节点
|
||||
nodes = self.memory_graph.db.graph_data.nodes.find()
|
||||
nodes = db.graph_data.nodes.find()
|
||||
for node in nodes:
|
||||
concept = node['concept']
|
||||
memory_items = node.get('memory_items', [])
|
||||
@@ -334,7 +333,7 @@ class Hippocampus:
|
||||
self.memory_graph.G.add_node(concept, memory_items=memory_items)
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = self.memory_graph.db.graph_data.edges.find()
|
||||
edges = db.graph_data.edges.find()
|
||||
for edge in edges:
|
||||
source = edge['source']
|
||||
target = edge['target']
|
||||
@@ -371,7 +370,7 @@ class Hippocampus:
|
||||
使用特征值(哈希值)快速判断是否需要更新
|
||||
"""
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
db_nodes = list(db.graph_data.nodes.find())
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 转换数据库节点为字典格式,方便查找
|
||||
@@ -394,7 +393,7 @@ class Hippocampus:
|
||||
'memory_items': memory_items,
|
||||
'hash': memory_hash
|
||||
}
|
||||
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
else:
|
||||
# 获取数据库中节点的特征值
|
||||
db_node = db_nodes_dict[concept]
|
||||
@@ -403,7 +402,7 @@ class Hippocampus:
|
||||
# 如果特征值不同,则更新节点
|
||||
if db_hash != memory_hash:
|
||||
# logger.info(f"更新节点内容: {concept}")
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'memory_items': memory_items,
|
||||
@@ -416,10 +415,10 @@ class Hippocampus:
|
||||
for db_node in db_nodes:
|
||||
if db_node['concept'] not in memory_concepts:
|
||||
# logger.info(f"删除多余节点: {db_node['concept']}")
|
||||
self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||
db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
db_edges = list(db.graph_data.edges.find())
|
||||
memory_edges = list(self.memory_graph.G.edges())
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -445,12 +444,12 @@ class Hippocampus:
|
||||
'num': 1,
|
||||
'hash': edge_hash
|
||||
}
|
||||
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
else:
|
||||
# 检查边的特征值是否变化
|
||||
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||
logger.info(f"更新边: {source} - {target}")
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {'hash': edge_hash}}
|
||||
)
|
||||
@@ -461,7 +460,7 @@ class Hippocampus:
|
||||
if edge_key not in memory_edge_set:
|
||||
source, target = edge_key
|
||||
logger.info(f"删除多余边: {source} - {target}")
|
||||
self.memory_graph.db.graph_data.edges.delete_one({
|
||||
db.graph_data.edges.delete_one({
|
||||
'source': source,
|
||||
'target': target
|
||||
})
|
||||
@@ -487,9 +486,9 @@ class Hippocampus:
|
||||
topic: 要删除的节点概念
|
||||
"""
|
||||
# 删除节点
|
||||
self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic})
|
||||
db.graph_data.nodes.delete_one({'concept': topic})
|
||||
# 删除所有涉及该节点的边
|
||||
self.memory_graph.db.graph_data.edges.delete_many({
|
||||
db.graph_data.edges.delete_many({
|
||||
'$or': [
|
||||
{'source': topic},
|
||||
{'target': topic}
|
||||
@@ -902,17 +901,6 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
plt.show()
|
||||
|
||||
async def main():
|
||||
# 初始化数据库
|
||||
logger.info("正在初始化数据库连接...")
|
||||
Database.initialize(
|
||||
uri=os.getenv("MONGODB_URI"),
|
||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
|
||||
|
||||
@@ -38,7 +38,7 @@ import jieba
|
||||
|
||||
# from chat.config import global_config
|
||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||
from src.common.database import Database
|
||||
from src.common.database import db
|
||||
from src.plugins.memory_system.offline_llm import LLMModel
|
||||
|
||||
# 获取当前文件的目录
|
||||
@@ -56,45 +56,6 @@ else:
|
||||
logger.warning(f"未找到环境变量文件: {env_path}")
|
||||
logger.info("将使用默认配置")
|
||||
|
||||
class Database:
|
||||
_instance = None
|
||||
db = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not Database.db:
|
||||
Database.initialize(
|
||||
uri=os.getenv("MONGODB_URI"),
|
||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
|
||||
try:
|
||||
if username and password:
|
||||
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
|
||||
else:
|
||||
uri = f"mongodb://{host}:{port}"
|
||||
|
||||
client = pymongo.MongoClient(uri)
|
||||
cls.db = client[db_name]
|
||||
# 测试连接
|
||||
client.server_info()
|
||||
logger.success("MongoDB连接成功!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化MongoDB失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def calculate_information_content(text):
|
||||
"""计算文本的信息量(熵)"""
|
||||
@@ -108,7 +69,7 @@ def calculate_information_content(text):
|
||||
|
||||
return entropy
|
||||
|
||||
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
def get_cloest_chat_from_db(length: int, timestamp: str):
|
||||
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
|
||||
|
||||
Returns:
|
||||
@@ -163,7 +124,7 @@ class Memory_cortex:
|
||||
default_time = datetime.datetime.now().timestamp()
|
||||
|
||||
# 从数据库加载所有节点
|
||||
nodes = self.memory_graph.db.graph_data.nodes.find()
|
||||
nodes = db.graph_data.nodes.find()
|
||||
for node in nodes:
|
||||
concept = node['concept']
|
||||
memory_items = node.get('memory_items', [])
|
||||
@@ -180,7 +141,7 @@ class Memory_cortex:
|
||||
created_time = default_time
|
||||
last_modified = default_time
|
||||
# 更新数据库中的节点
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'created_time': created_time,
|
||||
@@ -196,7 +157,7 @@ class Memory_cortex:
|
||||
last_modified=last_modified)
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = self.memory_graph.db.graph_data.edges.find()
|
||||
edges = db.graph_data.edges.find()
|
||||
for edge in edges:
|
||||
source = edge['source']
|
||||
target = edge['target']
|
||||
@@ -212,7 +173,7 @@ class Memory_cortex:
|
||||
created_time = default_time
|
||||
last_modified = default_time
|
||||
# 更新数据库中的边
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {
|
||||
'created_time': created_time,
|
||||
@@ -256,7 +217,7 @@ class Memory_cortex:
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
db_nodes = list(db.graph_data.nodes.find())
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 转换数据库节点为字典格式,方便查找
|
||||
@@ -280,7 +241,7 @@ class Memory_cortex:
|
||||
'created_time': data.get('created_time', current_time),
|
||||
'last_modified': data.get('last_modified', current_time)
|
||||
}
|
||||
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
|
||||
db.graph_data.nodes.insert_one(node_data)
|
||||
else:
|
||||
# 获取数据库中节点的特征值
|
||||
db_node = db_nodes_dict[concept]
|
||||
@@ -288,7 +249,7 @@ class Memory_cortex:
|
||||
|
||||
# 如果特征值不同,则更新节点
|
||||
if db_hash != memory_hash:
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'memory_items': memory_items,
|
||||
@@ -301,10 +262,10 @@ class Memory_cortex:
|
||||
memory_concepts = set(node[0] for node in memory_nodes)
|
||||
for db_node in db_nodes:
|
||||
if db_node['concept'] not in memory_concepts:
|
||||
self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||
db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
db_edges = list(db.graph_data.edges.find())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -332,11 +293,11 @@ class Memory_cortex:
|
||||
'created_time': data.get('created_time', current_time),
|
||||
'last_modified': data.get('last_modified', current_time)
|
||||
}
|
||||
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
|
||||
db.graph_data.edges.insert_one(edge_data)
|
||||
else:
|
||||
# 检查边的特征值是否变化
|
||||
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {
|
||||
'hash': edge_hash,
|
||||
@@ -350,7 +311,7 @@ class Memory_cortex:
|
||||
for edge_key in db_edge_dict:
|
||||
if edge_key not in memory_edge_set:
|
||||
source, target = edge_key
|
||||
self.memory_graph.db.graph_data.edges.delete_one({
|
||||
db.graph_data.edges.delete_one({
|
||||
'source': source,
|
||||
'target': target
|
||||
})
|
||||
@@ -365,9 +326,9 @@ class Memory_cortex:
|
||||
topic: 要删除的节点概念
|
||||
"""
|
||||
# 删除节点
|
||||
self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic})
|
||||
db.graph_data.nodes.delete_one({'concept': topic})
|
||||
# 删除所有涉及该节点的边
|
||||
self.memory_graph.db.graph_data.edges.delete_many({
|
||||
db.graph_data.edges.delete_many({
|
||||
'$or': [
|
||||
{'source': topic},
|
||||
{'target': topic}
|
||||
@@ -377,7 +338,6 @@ class Memory_cortex:
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
self.db = Database.get_instance()
|
||||
|
||||
def connect_dot(self, concept1, concept2):
|
||||
# 避免自连接
|
||||
@@ -492,19 +452,19 @@ class Hippocampus:
|
||||
# 短期:1h 中期:4h 长期:24h
|
||||
for _ in range(time_frequency.get('near')):
|
||||
random_time = current_timestamp - random.randint(1, 3600*4)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
for _ in range(time_frequency.get('mid')):
|
||||
random_time = current_timestamp - random.randint(3600*4, 3600*24)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
for _ in range(time_frequency.get('far')):
|
||||
random_time = current_timestamp - random.randint(3600*24, 3600*24*7)
|
||||
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||
messages = get_cloest_chat_from_db(length=chat_size, timestamp=random_time)
|
||||
if messages:
|
||||
chat_samples.append(messages)
|
||||
|
||||
@@ -1134,7 +1094,6 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
async def main():
|
||||
# 初始化数据库
|
||||
logger.info("正在初始化数据库连接...")
|
||||
db = Database.get_instance()
|
||||
start_time = time.time()
|
||||
|
||||
test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
|
||||
|
||||
Reference in New Issue
Block a user