Merge remote-tracking branch 'upstream/debug' into refractor

This commit is contained in:
tcmofashi
2025-03-11 16:45:41 +08:00
20 changed files with 356 additions and 228 deletions

View File

@@ -6,20 +6,44 @@ from pymongo import MongoClient
class Database:
_instance: Optional["Database"] = None
def __init__(
    self,
    host: str,
    port: int,
    db_name: str,
    username: Optional[str] = None,
    password: Optional[str] = None,
    auth_source: Optional[str] = None,
    uri: Optional[str] = None,
):
    """Create the MongoDB client and select the working database.

    Connection precedence:
      1. ``uri`` when it is a MongoDB connection string
         (``mongodb://`` or ``mongodb+srv://``);
      2. ``host``/``port`` with credentials when both ``username`` and
         ``password`` are non-empty;
      3. plain unauthenticated ``host``/``port``.

    Args:
        host: Server host name, used when no usable ``uri`` is given.
        port: Server port, used when no usable ``uri`` is given.
        db_name: Name of the database exposed as ``self.db``.
        username: Optional user name for authenticated connections.
        password: Optional password for authenticated connections.
        auth_source: Optional authentication database. Only forwarded
            to the client when it is non-empty.
        uri: Optional full connection string; takes priority over the
            individual host/port/credential arguments.
    """
    # Accept both connection-string schemes; previously an SRV URI was
    # silently ignored and the host/port fallback was used instead.
    if uri and uri.startswith(("mongodb://", "mongodb+srv://")):
        self.client = MongoClient(uri)
    elif username and password:
        auth_kwargs = {"username": username, "password": password}
        # Callers pass os.getenv(...) results, so auth_source may be None.
        # NOTE(review): PyMongo appears to validate authSource as a string,
        # so None is omitted and the driver default applies - confirm.
        if auth_source:
            auth_kwargs["authSource"] = auth_source
        self.client = MongoClient(host, port, **auth_kwargs)
    else:
        # No credentials supplied: connect without authentication.
        self.client = MongoClient(host, port)
    self.db = self.client[db_name]
@classmethod
def initialize(
    cls,
    host: str,
    port: int,
    db_name: str,
    username: Optional[str] = None,
    password: Optional[str] = None,
    auth_source: Optional[str] = None,
    uri: Optional[str] = None,
) -> "Database":
    """Return the shared instance, creating it on the first call.

    The arguments are forwarded unchanged to the constructor. When an
    instance already exists it is returned as-is and the arguments of
    this call are not used.
    """
    # Fast path: the singleton has already been built.
    if cls._instance is not None:
        return cls._instance

    cls._instance = cls(
        host=host,
        port=port,
        db_name=db_name,
        username=username,
        password=password,
        auth_source=auth_source,
        uri=uri,
    )
    return cls._instance
@classmethod

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from typing import Dict, List
from loguru import logger
from typing import Optional
from pymongo import MongoClient
from ..common.database import Database
import customtkinter as ctk
from dotenv import load_dotenv
@@ -28,38 +28,6 @@ else:
logger.error("未找到环境配置文件")
sys.exit(1)
class Database:
_instance: Optional["Database"] = None
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None,
auth_source: str = None):
if username and password:
self.client = MongoClient(
host=host,
port=port,
username=username,
password=password,
authSource=auth_source or 'admin'
)
else:
self.client = MongoClient(host, port)
self.db = self.client[db_name]
@classmethod
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None,
auth_source: str = None) -> "Database":
if cls._instance is None:
cls._instance = cls(host, port, db_name, username, password, auth_source)
return cls._instance
@classmethod
def get_instance(cls) -> "Database":
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance
class ReasoningGUI:
def __init__(self):
# 记录启动时间戳转换为Unix时间戳
@@ -83,7 +51,15 @@ class ReasoningGUI:
except RuntimeError:
logger.warning("数据库未初始化,正在尝试初始化...")
try:
Database.initialize("127.0.0.1", 27017, "maimai_bot")
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance().db
logger.success("数据库初始化成功")
except Exception:
@@ -359,12 +335,13 @@ class ReasoningGUI:
def main():
"""主函数"""
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
app = ReasoningGUI()

View File

@@ -1,5 +1,6 @@
import asyncio
import time
import os
from loguru import logger
from nonebot import get_driver, on_message, require
@@ -32,12 +33,13 @@ driver = get_driver()
config = driver.config
Database.initialize(
host=config.MONGODB_HOST,
port=int(config.MONGODB_PORT),
db_name=config.DATABASE_NAME,
username=config.MONGODB_USERNAME,
password=config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
logger.success("初始化数据库成功")

View File

@@ -1,6 +1,5 @@
import base64
import html
import os
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

View File

@@ -36,7 +36,7 @@ class EmojiManager:
self.db = Database.get_instance()
self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60,
self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60,
temperature=0.8) # 更高的温度更少的token后续可以根据情绪来调整温度

View File

@@ -19,12 +19,13 @@ from src.common.database import Database
# 从环境变量获取配置
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "maimai"),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "admin")
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class KnowledgeLibrary:

View File

@@ -9,7 +9,10 @@ import networkx as nx
from dotenv import load_dotenv
from loguru import logger
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import Database # 使用正确的导入语法
# 加载.env.dev文件
@@ -162,12 +165,13 @@ class Memory_graph:
def main():
# 初始化数据库
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""),
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
memory_graph = Memory_graph()

View File

@@ -3,11 +3,13 @@ import datetime
import math
import random
import time
import os
import jieba
import networkx as nx
from loguru import logger
from nonebot import get_driver
from ...common.database import Database # 使用正确的导入语法
from ..chat.config import global_config
from ..chat.utils import (
@@ -18,7 +20,6 @@ from ..chat.utils import (
)
from ..models.utils_model import LLM_request
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
@@ -150,7 +151,7 @@ class Memory_graph:
return None
# 海马体
# 海马体
class Hippocampus:
def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph
@@ -318,6 +319,8 @@ class Hippocampus:
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
current_time = datetime.datetime.now().timestamp()
for topic, memory in compressed_memory:
logger.info(f"添加节点: {topic}")
self.memory_graph.add_dot(topic, memory)
@@ -330,7 +333,10 @@ class Hippocampus:
if topic != similar_topic:
strength = int(similarity * 10)
logger.info(f"连接相似节点: {topic}{similar_topic} (强度: {strength})")
self.memory_graph.G.add_edge(topic, similar_topic, strength=strength)
self.memory_graph.G.add_edge(topic, similar_topic,
strength=strength,
created_time=current_time,
last_modified=current_time)
# 连接同批次的相关话题
for i in range(len(all_topics)):
@@ -438,21 +444,39 @@ class Hippocampus:
def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
current_time = datetime.datetime.now().timestamp()
need_update = False
# 清空当前图
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = self.memory_graph.db.db.graph_data.nodes.find()
nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 获取时间信息
created_time = node.get('created_time', datetime.datetime.now().timestamp())
last_modified = node.get('last_modified', datetime.datetime.now().timestamp())
# 检查时间字段是否存在
if 'created_time' not in node or 'last_modified' not in node:
need_update = True
# 更新数据库中的节点
update_data = {}
if 'created_time' not in node:
update_data['created_time'] = current_time
if 'last_modified' not in node:
update_data['last_modified'] = current_time
self.memory_graph.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': update_data}
)
logger.info(f"为节点 {concept} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.get('created_time', current_time)
last_modified = node.get('last_modified', current_time)
# 添加节点到图中
self.memory_graph.G.add_node(concept,
@@ -461,15 +485,31 @@ class Hippocampus:
last_modified=last_modified)
# 从数据库加载所有边
edges = self.memory_graph.db.db.graph_data.edges.find()
edges = list(self.memory_graph.db.db.graph_data.edges.find())
for edge in edges:
source = edge['source']
target = edge['target']
strength = edge.get('strength', 1) # 获取 strength,默认为 1
strength = edge.get('strength', 1)
# 获取时间信息
created_time = edge.get('created_time', datetime.datetime.now().timestamp())
last_modified = edge.get('last_modified', datetime.datetime.now().timestamp())
# 检查时间字段是否存在
if 'created_time' not in edge or 'last_modified' not in edge:
need_update = True
# 更新数据库中的边
update_data = {}
if 'created_time' not in edge:
update_data['created_time'] = current_time
if 'last_modified' not in edge:
update_data['last_modified'] = current_time
self.memory_graph.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': update_data}
)
logger.info(f"为边 {source} - {target} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间)
created_time = edge.get('created_time', current_time)
last_modified = edge.get('last_modified', current_time)
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
@@ -477,6 +517,9 @@ class Hippocampus:
strength=strength,
created_time=created_time,
last_modified=last_modified)
if need_update:
logger.success("已为缺失的时间字段进行补充")
async def operation_forget_topic(self, percentage=0.1):
"""随机选择图中一定比例的节点和边进行检查,根据时间条件决定是否遗忘"""
@@ -839,21 +882,19 @@ def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
from nonebot import get_driver
driver = get_driver()
config = driver.config
start_time = time.time()
Database.initialize(
host=config.MONGODB_HOST,
port=config.MONGODB_PORT,
db_name=config.DATABASE_NAME,
username=config.MONGODB_USERNAME,
password=config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
# 创建记忆图
memory_graph = Memory_graph()

View File

@@ -16,7 +16,10 @@ from loguru import logger
import jieba
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import Database
from src.plugins.memory_system.offline_llm import LLMModel
@@ -35,45 +38,6 @@ else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# 测试连接
client.server_info()
logger.success("MongoDB连接成功!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
@@ -202,7 +166,7 @@ class Memory_graph:
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
# 海马体
# 海马体
class Hippocampus:
def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph
@@ -941,59 +905,67 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
async def main():
# 初始化数据库
logger.info("正在初始化数据库连接...")
db = Database.get_instance()
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
start_time = time.time()
test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
# 创建记忆图
memory_graph = Memory_graph()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
# 从数据库同步数据
hippocampus.sync_memory_from_db()
end_time = time.time()
logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
if test_pare['do_build_memory']:
logger.info("开始构建记忆...")
chat_size = 20
await hippocampus.operation_build_memory(chat_size=chat_size)
end_time = time.time()
logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m")
if test_pare['do_forget_topic']:
logger.info("开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_merge_memory']:
logger.info("开始合并记忆...")
await hippocampus.operation_merge_memory(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_visualize_graph']:
# 展示优化后的图形
logger.info("生成记忆图谱可视化...")
print("\n生成优化后的记忆图谱:")
visualize_graph_lite(memory_graph)
if test_pare['do_query']:
# 交互式查询
while True:
query = input("\n请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
first_layer, second_layer = items_list
@@ -1008,9 +980,6 @@ async def main():
else:
print("未找到相关记忆。")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -69,13 +69,14 @@ class Database:
def __init__(self):
if not Database.db:
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):

View File

@@ -1,3 +1,4 @@
import os
import datetime
import json
from typing import Dict, Union
@@ -14,15 +15,15 @@ driver = get_driver()
config = driver.config
Database.initialize(
host=config.MONGODB_HOST,
port=int(config.MONGODB_PORT),
db_name=config.DATABASE_NAME,
username=config.MONGODB_USERNAME,
password=config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class ScheduleGenerator:
def __init__(self):
# 根据global_config.llm_normal这一字典配置指定模型
@@ -176,6 +177,6 @@ class ScheduleGenerator:
# print(scheduler.tomorrow_schedule)
# if __name__ == "__main__":
# main()
# main()
bot_schedule = ScheduleGenerator()