Ruff format

This commit is contained in:
墨梓柒
2025-12-13 17:14:09 +08:00
parent ef377bb0cd
commit e680a4d1f5
60 changed files with 1546 additions and 1532 deletions

14
bot.py
View File

@@ -41,6 +41,7 @@ logger = get_logger("main")
# 定义重启退出码 # 定义重启退出码
RESTART_EXIT_CODE = 42 RESTART_EXIT_CODE = 42
def run_runner_process(): def run_runner_process():
""" """
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。 Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
@@ -55,25 +56,25 @@ def run_runner_process():
while True: while True:
logger.info(f"正在启动 {script_file}...") logger.info(f"正在启动 {script_file}...")
# 启动子进程 (Worker) # 启动子进程 (Worker)
# 使用 sys.executable 确保使用相同的 Python 解释器 # 使用 sys.executable 确保使用相同的 Python 解释器
cmd = [python_executable, script_file] + sys.argv[1:] cmd = [python_executable, script_file] + sys.argv[1:]
process = subprocess.Popen(cmd, env=env) process = subprocess.Popen(cmd, env=env)
try: try:
# 等待子进程结束 # 等待子进程结束
return_code = process.wait() return_code = process.wait()
if return_code == RESTART_EXIT_CODE: if return_code == RESTART_EXIT_CODE:
logger.info("检测到重启请求 (退出码 42),正在重启...") logger.info("检测到重启请求 (退出码 42),正在重启...")
time.sleep(1) # 稍作等待 time.sleep(1) # 稍作等待
continue continue
else: else:
logger.info(f"程序已退出 (退出码 {return_code})") logger.info(f"程序已退出 (退出码 {return_code})")
sys.exit(return_code) sys.exit(return_code)
except KeyboardInterrupt: except KeyboardInterrupt:
# 向子进程发送终止信号 # 向子进程发送终止信号
if process.poll() is None: if process.poll() is None:
@@ -87,6 +88,7 @@ def run_runner_process():
process.kill() process.kill()
sys.exit(0) sys.exit(0)
# 检查是否是 Worker 进程 # 检查是否是 Worker 进程
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本, # 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
# 此时应该作为 Runner 运行。 # 此时应该作为 Runner 运行。

View File

@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Optional, Tuple
@dataclass @dataclass
class ConversionResult: class ConversionResult:
"""转换结果""" """转换结果"""
success: bool success: bool
servers: List[Dict[str, Any]] = field(default_factory=list) servers: List[Dict[str, Any]] = field(default_factory=list)
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
@@ -53,7 +54,7 @@ class ConfigConverter:
@classmethod @classmethod
def detect_format(cls, config: Dict[str, Any]) -> Optional[str]: def detect_format(cls, config: Dict[str, Any]) -> Optional[str]:
"""检测配置格式类型 """检测配置格式类型
Returns: Returns:
"claude": Claude Desktop 格式 (mcpServers 对象) "claude": Claude Desktop 格式 (mcpServers 对象)
"kiro": Kiro MCP 格式 (mcpServers 对象,与 Claude 相同) "kiro": Kiro MCP 格式 (mcpServers 对象,与 Claude 相同)
@@ -82,7 +83,7 @@ class ConfigConverter:
@classmethod @classmethod
def parse_json_safe(cls, json_str: str) -> Tuple[Optional[Any], Optional[str]]: def parse_json_safe(cls, json_str: str) -> Tuple[Optional[Any], Optional[str]]:
"""安全解析 JSON 字符串 """安全解析 JSON 字符串
Returns: Returns:
(解析结果, 错误信息) (解析结果, 错误信息)
""" """
@@ -102,11 +103,11 @@ class ConfigConverter:
@classmethod @classmethod
def validate_server_config(cls, name: str, config: Dict[str, Any]) -> Tuple[bool, Optional[str], List[str]]: def validate_server_config(cls, name: str, config: Dict[str, Any]) -> Tuple[bool, Optional[str], List[str]]:
"""验证单个服务器配置 """验证单个服务器配置
Args: Args:
name: 服务器名称 name: 服务器名称
config: 服务器配置字典 config: 服务器配置字典
Returns: Returns:
(是否有效, 错误信息, 警告列表) (是否有效, 错误信息, 警告列表)
""" """
@@ -177,11 +178,11 @@ class ConfigConverter:
@classmethod @classmethod
def convert_claude_server(cls, name: str, config: Dict[str, Any]) -> Dict[str, Any]: def convert_claude_server(cls, name: str, config: Dict[str, Any]) -> Dict[str, Any]:
"""将单个 Claude 格式服务器配置转换为 MaiBot 格式 """将单个 Claude 格式服务器配置转换为 MaiBot 格式
Args: Args:
name: 服务器名称 name: 服务器名称
config: Claude 格式的服务器配置 config: Claude 格式的服务器配置
Returns: Returns:
MaiBot 格式的服务器配置 MaiBot 格式的服务器配置
""" """
@@ -231,10 +232,10 @@ class ConfigConverter:
@classmethod @classmethod
def convert_maibot_server(cls, config: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: def convert_maibot_server(cls, config: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""将单个 MaiBot 格式服务器配置转换为 Claude 格式 """将单个 MaiBot 格式服务器配置转换为 Claude 格式
Args: Args:
config: MaiBot 格式的服务器配置 config: MaiBot 格式的服务器配置
Returns: Returns:
(服务器名称, Claude 格式的服务器配置) (服务器名称, Claude 格式的服务器配置)
""" """
@@ -271,17 +272,13 @@ class ConfigConverter:
return name, result return name, result
@classmethod @classmethod
def from_claude_format( def from_claude_format(cls, config: Dict[str, Any], existing_names: Optional[set] = None) -> ConversionResult:
cls,
config: Dict[str, Any],
existing_names: Optional[set] = None
) -> ConversionResult:
"""从 Claude Desktop 格式转换为 MaiBot 格式 """从 Claude Desktop 格式转换为 MaiBot 格式
Args: Args:
config: Claude Desktop 配置 (包含 mcpServers 字段) config: Claude Desktop 配置 (包含 mcpServers 字段)
existing_names: 已存在的服务器名称集合,用于跳过重复 existing_names: 已存在的服务器名称集合,用于跳过重复
Returns: Returns:
ConversionResult ConversionResult
""" """
@@ -336,10 +333,10 @@ class ConfigConverter:
@classmethod @classmethod
def to_claude_format(cls, servers: List[Dict[str, Any]]) -> Dict[str, Any]: def to_claude_format(cls, servers: List[Dict[str, Any]]) -> Dict[str, Any]:
"""将 MaiBot 格式转换为 Claude Desktop 格式 """将 MaiBot 格式转换为 Claude Desktop 格式
Args: Args:
servers: MaiBot 格式的服务器列表 servers: MaiBot 格式的服务器列表
Returns: Returns:
Claude Desktop 格式的配置 Claude Desktop 格式的配置
""" """
@@ -355,19 +352,15 @@ class ConfigConverter:
return {"mcpServers": mcp_servers} return {"mcpServers": mcp_servers}
@classmethod @classmethod
def import_from_string( def import_from_string(cls, json_str: str, existing_names: Optional[set] = None) -> ConversionResult:
cls,
json_str: str,
existing_names: Optional[set] = None
) -> ConversionResult:
"""从 JSON 字符串导入配置 """从 JSON 字符串导入配置
自动检测格式并转换为 MaiBot 格式 自动检测格式并转换为 MaiBot 格式
Args: Args:
json_str: JSON 字符串 json_str: JSON 字符串
existing_names: 已存在的服务器名称集合 existing_names: 已存在的服务器名称集合
Returns: Returns:
ConversionResult ConversionResult
""" """
@@ -422,19 +415,14 @@ class ConfigConverter:
return result return result
@classmethod @classmethod
def export_to_string( def export_to_string(cls, servers: List[Dict[str, Any]], format_type: str = "claude", pretty: bool = True) -> str:
cls,
servers: List[Dict[str, Any]],
format_type: str = "claude",
pretty: bool = True
) -> str:
"""导出配置为 JSON 字符串 """导出配置为 JSON 字符串
Args: Args:
servers: MaiBot 格式的服务器列表 servers: MaiBot 格式的服务器列表
format_type: 导出格式 ("claude", "kiro", "maibot") format_type: 导出格式 ("claude", "kiro", "maibot")
pretty: 是否格式化输出 pretty: 是否格式化输出
Returns: Returns:
JSON 字符串 JSON 字符串
""" """

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -23,22 +23,22 @@ from mcp_client import (
async def test_stats(): async def test_stats():
"""测试统计类""" """测试统计类"""
print("\n=== 测试统计类 ===") print("\n=== 测试统计类 ===")
# 测试 ToolCallStats # 测试 ToolCallStats
stats = ToolCallStats(tool_key="test_tool") stats = ToolCallStats(tool_key="test_tool")
stats.record_call(True, 100.0) stats.record_call(True, 100.0)
stats.record_call(True, 200.0) stats.record_call(True, 200.0)
stats.record_call(False, 50.0, "timeout") stats.record_call(False, 50.0, "timeout")
assert stats.total_calls == 3 assert stats.total_calls == 3
assert stats.success_calls == 2 assert stats.success_calls == 2
assert stats.failed_calls == 1 assert stats.failed_calls == 1
assert stats.success_rate == (2/3) * 100 assert stats.success_rate == (2 / 3) * 100
assert stats.avg_duration_ms == 150.0 assert stats.avg_duration_ms == 150.0
assert stats.last_error == "timeout" assert stats.last_error == "timeout"
print(f"✅ ToolCallStats: {stats.to_dict()}") print(f"✅ ToolCallStats: {stats.to_dict()}")
# 测试 ServerStats # 测试 ServerStats
server_stats = ServerStats(server_name="test_server") server_stats = ServerStats(server_name="test_server")
server_stats.record_connect() server_stats.record_connect()
@@ -46,133 +46,138 @@ async def test_stats():
server_stats.record_disconnect() server_stats.record_disconnect()
server_stats.record_failure() server_stats.record_failure()
server_stats.record_failure() server_stats.record_failure()
assert server_stats.connect_count == 1 assert server_stats.connect_count == 1
assert server_stats.disconnect_count == 1 assert server_stats.disconnect_count == 1
assert server_stats.consecutive_failures == 2 assert server_stats.consecutive_failures == 2
print(f"✅ ServerStats: {server_stats.to_dict()}") print(f"✅ ServerStats: {server_stats.to_dict()}")
return True return True
async def test_manager_basic(): async def test_manager_basic():
"""测试管理器基本功能""" """测试管理器基本功能"""
print("\n=== 测试管理器基本功能 ===") print("\n=== 测试管理器基本功能 ===")
# 创建新的管理器实例(绕过单例) # 创建新的管理器实例(绕过单例)
manager = MCPClientManager.__new__(MCPClientManager) manager = MCPClientManager.__new__(MCPClientManager)
manager._initialized = False manager._initialized = False
manager.__init__() manager.__init__()
# 配置 # 配置
manager.configure({ manager.configure(
"tool_prefix": "mcp", {
"call_timeout": 30.0, "tool_prefix": "mcp",
"retry_attempts": 1, "call_timeout": 30.0,
"retry_interval": 1.0, "retry_attempts": 1,
"heartbeat_enabled": False, "retry_interval": 1.0,
}) "heartbeat_enabled": False,
}
)
# 测试状态 # 测试状态
status = manager.get_status() status = manager.get_status()
assert status["total_servers"] == 0 assert status["total_servers"] == 0
assert status["connected_servers"] == 0 assert status["connected_servers"] == 0
print(f"✅ 初始状态: {status}") print(f"✅ 初始状态: {status}")
# 测试添加禁用的服务器 # 测试添加禁用的服务器
config = MCPServerConfig( config = MCPServerConfig(
name="disabled_server", name="disabled_server", enabled=False, transport=TransportType.HTTP, url="https://example.com/mcp"
enabled=False,
transport=TransportType.HTTP,
url="https://example.com/mcp"
) )
result = await manager.add_server(config) result = await manager.add_server(config)
assert result == True assert result == True
assert "disabled_server" in manager._clients assert "disabled_server" in manager._clients
assert manager._clients["disabled_server"].is_connected == False assert manager._clients["disabled_server"].is_connected == False
print("✅ 添加禁用服务器成功") print("✅ 添加禁用服务器成功")
# 测试重复添加 # 测试重复添加
result = await manager.add_server(config) result = await manager.add_server(config)
assert result == False assert result == False
print("✅ 重复添加被拒绝") print("✅ 重复添加被拒绝")
# 测试移除 # 测试移除
result = await manager.remove_server("disabled_server") result = await manager.remove_server("disabled_server")
assert result == True assert result == True
assert "disabled_server" not in manager._clients assert "disabled_server" not in manager._clients
print("✅ 移除服务器成功") print("✅ 移除服务器成功")
# 清理 # 清理
await manager.shutdown() await manager.shutdown()
print("✅ 管理器关闭成功") print("✅ 管理器关闭成功")
return True return True
async def test_http_connection(): async def test_http_connection():
"""测试 HTTP 连接(使用真实的 MCP 服务器)""" """测试 HTTP 连接(使用真实的 MCP 服务器)"""
print("\n=== 测试 HTTP 连接 ===") print("\n=== 测试 HTTP 连接 ===")
# 创建新的管理器实例 # 创建新的管理器实例
manager = MCPClientManager.__new__(MCPClientManager) manager = MCPClientManager.__new__(MCPClientManager)
manager._initialized = False manager._initialized = False
manager.__init__() manager.__init__()
manager.configure({ manager.configure(
"tool_prefix": "mcp", {
"call_timeout": 30.0, "tool_prefix": "mcp",
"retry_attempts": 2, "call_timeout": 30.0,
"retry_interval": 2.0, "retry_attempts": 2,
"heartbeat_enabled": False, "retry_interval": 2.0,
}) "heartbeat_enabled": False,
}
)
# 使用 HowToCook MCP 服务器测试 # 使用 HowToCook MCP 服务器测试
config = MCPServerConfig( config = MCPServerConfig(
name="howtocook", name="howtocook",
enabled=True, enabled=True,
transport=TransportType.HTTP, transport=TransportType.HTTP,
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp" url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
) )
print(f"正在连接 {config.url} ...") print(f"正在连接 {config.url} ...")
result = await manager.add_server(config) result = await manager.add_server(config)
if result: if result:
print(f"✅ 连接成功!") print("✅ 连接成功!")
# 检查工具 # 检查工具
tools = manager.all_tools tools = manager.all_tools
print(f"✅ 发现 {len(tools)} 个工具:") print(f"✅ 发现 {len(tools)} 个工具:")
for tool_key in tools: for tool_key in tools:
print(f" - {tool_key}") print(f" - {tool_key}")
# 测试心跳 # 测试心跳
client = manager._clients["howtocook"] client = manager._clients["howtocook"]
healthy = await client.check_health() healthy = await client.check_health()
print(f"✅ 心跳检测: {'健康' if healthy else '异常'}") print(f"✅ 心跳检测: {'健康' if healthy else '异常'}")
# 测试工具调用 # 测试工具调用
if "mcp_howtocook_whatToEat" in tools: if "mcp_howtocook_whatToEat" in tools:
print("\n正在调用 whatToEat 工具...") print("\n正在调用 whatToEat 工具...")
call_result = await manager.call_tool("mcp_howtocook_whatToEat", {}) call_result = await manager.call_tool("mcp_howtocook_whatToEat", {})
if call_result.success: if call_result.success:
print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)") print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)")
print(f" 结果: {call_result.content[:200]}..." if len(str(call_result.content)) > 200 else f" 结果: {call_result.content}") print(
f" 结果: {call_result.content[:200]}..."
if len(str(call_result.content)) > 200
else f" 结果: {call_result.content}"
)
else: else:
print(f"❌ 工具调用失败: {call_result.error}") print(f"❌ 工具调用失败: {call_result.error}")
# 查看统计 # 查看统计
stats = manager.get_all_stats() stats = manager.get_all_stats()
print(f"\n📊 统计信息:") print("\n📊 统计信息:")
print(f" 全局调用: {stats['global']['total_tool_calls']}") print(f" 全局调用: {stats['global']['total_tool_calls']}")
print(f" 成功: {stats['global']['successful_calls']}") print(f" 成功: {stats['global']['successful_calls']}")
print(f" 失败: {stats['global']['failed_calls']}") print(f" 失败: {stats['global']['failed_calls']}")
else: else:
print(f"❌ 连接失败") print("❌ 连接失败")
# 清理 # 清理
await manager.shutdown() await manager.shutdown()
return result return result
@@ -181,55 +186,57 @@ async def test_http_connection():
async def test_heartbeat(): async def test_heartbeat():
"""测试心跳检测功能""" """测试心跳检测功能"""
print("\n=== 测试心跳检测 ===") print("\n=== 测试心跳检测 ===")
# 创建新的管理器实例 # 创建新的管理器实例
manager = MCPClientManager.__new__(MCPClientManager) manager = MCPClientManager.__new__(MCPClientManager)
manager._initialized = False manager._initialized = False
manager.__init__() manager.__init__()
manager.configure({ manager.configure(
"tool_prefix": "mcp", {
"call_timeout": 30.0, "tool_prefix": "mcp",
"retry_attempts": 1, "call_timeout": 30.0,
"retry_interval": 1.0, "retry_attempts": 1,
"heartbeat_enabled": True, "retry_interval": 1.0,
"heartbeat_interval": 5.0, # 5秒间隔用于测试 "heartbeat_enabled": True,
"auto_reconnect": True, "heartbeat_interval": 5.0, # 5秒间隔用于测试
"max_reconnect_attempts": 2, "auto_reconnect": True,
}) "max_reconnect_attempts": 2,
}
)
# 添加一个测试服务器 # 添加一个测试服务器
config = MCPServerConfig( config = MCPServerConfig(
name="heartbeat_test", name="heartbeat_test",
enabled=True, enabled=True,
transport=TransportType.HTTP, transport=TransportType.HTTP,
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp" url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
) )
print("正在连接服务器...") print("正在连接服务器...")
result = await manager.add_server(config) result = await manager.add_server(config)
if result: if result:
print("✅ 服务器连接成功") print("✅ 服务器连接成功")
# 启动心跳检测 # 启动心跳检测
await manager.start_heartbeat() await manager.start_heartbeat()
print("✅ 心跳检测已启动") print("✅ 心跳检测已启动")
# 等待一个心跳周期 # 等待一个心跳周期
print("等待心跳检测...") print("等待心跳检测...")
await asyncio.sleep(2) await asyncio.sleep(2)
# 检查状态 # 检查状态
status = manager.get_status() status = manager.get_status()
print(f"✅ 心跳运行状态: {status['heartbeat_running']}") print(f"✅ 心跳运行状态: {status['heartbeat_running']}")
# 停止心跳 # 停止心跳
await manager.stop_heartbeat() await manager.stop_heartbeat()
print("✅ 心跳检测已停止") print("✅ 心跳检测已停止")
else: else:
print("❌ 服务器连接失败,跳过心跳测试") print("❌ 服务器连接失败,跳过心跳测试")
await manager.shutdown() await manager.shutdown()
return True return True
@@ -239,30 +246,31 @@ async def main():
print("=" * 50) print("=" * 50)
print("MCP 客户端测试") print("MCP 客户端测试")
print("=" * 50) print("=" * 50)
try: try:
# 基础测试 # 基础测试
await test_stats() await test_stats()
await test_manager_basic() await test_manager_basic()
# 网络测试 # 网络测试
print("\n是否进行网络连接测试? (需要网络) [y/N]: ", end="") print("\n是否进行网络连接测试? (需要网络) [y/N]: ", end="")
# 自动进行网络测试 # 自动进行网络测试
await test_http_connection() await test_http_connection()
# 心跳测试 # 心跳测试
await test_heartbeat() await test_heartbeat()
print("\n" + "=" * 50) print("\n" + "=" * 50)
print("✅ 所有测试通过!") print("✅ 所有测试通过!")
print("=" * 50) print("=" * 50)
except Exception as e: except Exception as e:
print(f"\n❌ 测试失败: {e}") print(f"\n❌ 测试失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False
return True return True

View File

@@ -35,13 +35,13 @@ def get_chat_name(chat_id: str) -> str:
return f"{chat_stream.group_name}" return f"{chat_stream.group_name}"
elif chat_stream.user_nickname: elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊" return f"{chat_stream.user_nickname}的私聊"
if get_chat_manager: if get_chat_manager:
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
stream_name = chat_manager.get_stream_name(chat_id) stream_name = chat_manager.get_stream_name(chat_id)
if stream_name: if stream_name:
return stream_name return stream_name
return f"未知聊天 ({chat_id[:8]}...)" return f"未知聊天 ({chat_id[:8]}...)"
except Exception: except Exception:
return f"查询失败 ({chat_id[:8]}...)" return f"查询失败 ({chat_id[:8]}...)"
@@ -51,11 +51,11 @@ def load_records(temp_dir: str = "data/temp") -> List[Dict[str, Any]]:
"""加载所有 replyer 动作记录""" """加载所有 replyer 动作记录"""
records = [] records = []
temp_path = Path(temp_dir) temp_path = Path(temp_dir)
if not temp_path.exists(): if not temp_path.exists():
print(f"目录不存在: {temp_dir}") print(f"目录不存在: {temp_dir}")
return records return records
# 查找所有 replyer_action_*.json 文件 # 查找所有 replyer_action_*.json 文件
pattern = "replyer_action_*.json" pattern = "replyer_action_*.json"
for file_path in temp_path.glob(pattern): for file_path in temp_path.glob(pattern):
@@ -65,7 +65,7 @@ def load_records(temp_dir: str = "data/temp") -> List[Dict[str, Any]]:
records.append(data) records.append(data)
except Exception as e: except Exception as e:
print(f"读取文件失败 {file_path}: {e}") print(f"读取文件失败 {file_path}: {e}")
# 按时间戳排序 # 按时间戳排序
records.sort(key=lambda x: x.get("timestamp", "")) records.sort(key=lambda x: x.get("timestamp", ""))
return records return records
@@ -91,7 +91,7 @@ def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int]
"30天内": 0, "30天内": 0,
"更早": 0, "更早": 0,
} }
for record in records: for record in records:
try: try:
ts = record.get("timestamp", "") ts = record.get("timestamp", "")
@@ -99,7 +99,7 @@ def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int]
continue continue
dt = datetime.fromisoformat(ts) dt = datetime.fromisoformat(ts)
diff = (now - dt).days diff = (now - dt).days
if diff == 0: if diff == 0:
distribution["今天"] += 1 distribution["今天"] += 1
elif diff == 1: elif diff == 1:
@@ -114,7 +114,7 @@ def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int]
distribution["更早"] += 1 distribution["更早"] += 1
except Exception: except Exception:
pass pass
return distribution return distribution
@@ -123,17 +123,17 @@ def print_statistics(records: List[Dict[str, Any]]):
if not records: if not records:
print("没有找到任何记录") print("没有找到任何记录")
return return
print("=" * 80) print("=" * 80)
print("Replyer 动作选择记录统计") print("Replyer 动作选择记录统计")
print("=" * 80) print("=" * 80)
print() print()
# 总记录数 # 总记录数
total_count = len(records) total_count = len(records)
print(f"📊 总记录数: {total_count}") print(f"📊 总记录数: {total_count}")
print() print()
# 时间范围 # 时间范围
timestamps = [r.get("timestamp", "") for r in records if r.get("timestamp")] timestamps = [r.get("timestamp", "") for r in records if r.get("timestamp")]
if timestamps: if timestamps:
@@ -141,7 +141,7 @@ def print_statistics(records: List[Dict[str, Any]]):
last_time = format_timestamp(max(timestamps)) last_time = format_timestamp(max(timestamps))
print(f"📅 时间范围: {first_time} ~ {last_time}") print(f"📅 时间范围: {first_time} ~ {last_time}")
print() print()
# 按 think_level 统计 # 按 think_level 统计
think_levels = [r.get("think_level", 0) for r in records] think_levels = [r.get("think_level", 0) for r in records]
think_level_counter = Counter(think_levels) think_level_counter = Counter(think_levels)
@@ -152,7 +152,7 @@ def print_statistics(records: List[Dict[str, Any]]):
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})") level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
print(f" Level {level} ({level_name}): {count} 次 ({percentage:.1f}%)") print(f" Level {level} ({level_name}): {count} 次 ({percentage:.1f}%)")
print() print()
# 按 chat_id 统计(总体) # 按 chat_id 统计(总体)
chat_counter = Counter([r.get("chat_id", "未知") for r in records]) chat_counter = Counter([r.get("chat_id", "未知") for r in records])
print(f"💬 聊天分布 (共 {len(chat_counter)} 个聊天):") print(f"💬 聊天分布 (共 {len(chat_counter)} 个聊天):")
@@ -164,30 +164,30 @@ def print_statistics(records: List[Dict[str, Any]]):
if len(chat_counter) > 10: if len(chat_counter) > 10:
print(f" ... 还有 {len(chat_counter) - 10} 个聊天") print(f" ... 还有 {len(chat_counter) - 10} 个聊天")
print() print()
# 每个 chat_id 的详细统计 # 每个 chat_id 的详细统计
print("=" * 80) print("=" * 80)
print("每个聊天的详细统计") print("每个聊天的详细统计")
print("=" * 80) print("=" * 80)
print() print()
# 按 chat_id 分组记录 # 按 chat_id 分组记录
records_by_chat = defaultdict(list) records_by_chat = defaultdict(list)
for record in records: for record in records:
chat_id = record.get("chat_id", "未知") chat_id = record.get("chat_id", "未知")
records_by_chat[chat_id].append(record) records_by_chat[chat_id].append(record)
# 按记录数排序 # 按记录数排序
sorted_chats = sorted(records_by_chat.items(), key=lambda x: len(x[1]), reverse=True) sorted_chats = sorted(records_by_chat.items(), key=lambda x: len(x[1]), reverse=True)
for chat_id, chat_records in sorted_chats: for chat_id, chat_records in sorted_chats:
chat_name = get_chat_name(chat_id) chat_name = get_chat_name(chat_id)
chat_count = len(chat_records) chat_count = len(chat_records)
chat_percentage = (chat_count / total_count) * 100 chat_percentage = (chat_count / total_count) * 100
print(f"📱 {chat_name} ({chat_id[:8]}...)") print(f"📱 {chat_name} ({chat_id[:8]}...)")
print(f" 总记录数: {chat_count} ({chat_percentage:.1f}%)") print(f" 总记录数: {chat_count} ({chat_percentage:.1f}%)")
# 该聊天的 think_level 分布 # 该聊天的 think_level 分布
chat_think_levels = [r.get("think_level", 0) for r in chat_records] chat_think_levels = [r.get("think_level", 0) for r in chat_records]
chat_think_counter = Counter(chat_think_levels) chat_think_counter = Counter(chat_think_levels)
@@ -197,14 +197,14 @@ def print_statistics(records: List[Dict[str, Any]]):
level_percentage = (level_count / chat_count) * 100 level_percentage = (level_count / chat_count) * 100
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})") level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
print(f" Level {level} ({level_name}): {level_count} 次 ({level_percentage:.1f}%)") print(f" Level {level} ({level_name}): {level_count} 次 ({level_percentage:.1f}%)")
# 该聊天的时间范围 # 该聊天的时间范围
chat_timestamps = [r.get("timestamp", "") for r in chat_records if r.get("timestamp")] chat_timestamps = [r.get("timestamp", "") for r in chat_records if r.get("timestamp")]
if chat_timestamps: if chat_timestamps:
first_time = format_timestamp(min(chat_timestamps)) first_time = format_timestamp(min(chat_timestamps))
last_time = format_timestamp(max(chat_timestamps)) last_time = format_timestamp(max(chat_timestamps))
print(f" 时间范围: {first_time} ~ {last_time}") print(f" 时间范围: {first_time} ~ {last_time}")
# 该聊天的时间分布 # 该聊天的时间分布
chat_time_dist = calculate_time_distribution(chat_records) chat_time_dist = calculate_time_distribution(chat_records)
print(" 时间分布:") print(" 时间分布:")
@@ -212,7 +212,7 @@ def print_statistics(records: List[Dict[str, Any]]):
if count > 0: if count > 0:
period_percentage = (count / chat_count) * 100 period_percentage = (count / chat_count) * 100
print(f" {period}: {count} 次 ({period_percentage:.1f}%)") print(f" {period}: {count} 次 ({period_percentage:.1f}%)")
# 显示该聊天最近的一条理由示例 # 显示该聊天最近的一条理由示例
if chat_records: if chat_records:
latest_record = chat_records[-1] latest_record = chat_records[-1]
@@ -222,9 +222,9 @@ def print_statistics(records: List[Dict[str, Any]]):
timestamp = format_timestamp(latest_record.get("timestamp", "")) timestamp = format_timestamp(latest_record.get("timestamp", ""))
think_level = latest_record.get("think_level", 0) think_level = latest_record.get("think_level", 0)
print(f" 最新记录 [{timestamp}] (Level {think_level}): {reason}") print(f" 最新记录 [{timestamp}] (Level {think_level}): {reason}")
print() print()
# 时间分布 # 时间分布
time_dist = calculate_time_distribution(records) time_dist = calculate_time_distribution(records)
print("⏰ 时间分布:") print("⏰ 时间分布:")
@@ -233,7 +233,7 @@ def print_statistics(records: List[Dict[str, Any]]):
percentage = (count / total_count) * 100 percentage = (count / total_count) * 100
print(f" {period}: {count} 次 ({percentage:.1f}%)") print(f" {period}: {count} 次 ({percentage:.1f}%)")
print() print()
# 显示一些示例理由 # 显示一些示例理由
print("📝 示例理由 (最近5条):") print("📝 示例理由 (最近5条):")
recent_records = records[-5:] recent_records = records[-5:]
@@ -243,29 +243,29 @@ def print_statistics(records: List[Dict[str, Any]]):
timestamp = format_timestamp(record.get("timestamp", "")) timestamp = format_timestamp(record.get("timestamp", ""))
chat_id = record.get("chat_id", "未知") chat_id = record.get("chat_id", "未知")
chat_name = get_chat_name(chat_id) chat_name = get_chat_name(chat_id)
# 截断过长的理由 # 截断过长的理由
if len(reason) > 100: if len(reason) > 100:
reason = reason[:100] + "..." reason = reason[:100] + "..."
print(f" {i}. [{timestamp}] {chat_name} (Level {think_level})") print(f" {i}. [{timestamp}] {chat_name} (Level {think_level})")
print(f" {reason}") print(f" {reason}")
print() print()
# 按 think_level 分组显示理由示例 # 按 think_level 分组显示理由示例
print("=" * 80) print("=" * 80)
print("按思考深度分类的示例理由") print("按思考深度分类的示例理由")
print("=" * 80) print("=" * 80)
print() print()
for level in [0, 1, 2]: for level in [0, 1, 2]:
level_records = [r for r in records if r.get("think_level") == level] level_records = [r for r in records if r.get("think_level") == level]
if not level_records: if not level_records:
continue continue
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})") level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
print(f"Level {level} ({level_name}) - 共 {len(level_records)} 条:") print(f"Level {level} ({level_name}) - 共 {len(level_records)} 条:")
# 显示3个示例选择最近的 # 显示3个示例选择最近的
examples = level_records[-3:] if len(level_records) >= 3 else level_records examples = level_records[-3:] if len(level_records) >= 3 else level_records
for i, record in enumerate(examples, 1): for i, record in enumerate(examples, 1):
@@ -278,7 +278,7 @@ def print_statistics(records: List[Dict[str, Any]]):
print(f" {i}. [{timestamp}] {chat_name}") print(f" {i}. [{timestamp}] {chat_name}")
print(f" {reason}") print(f" {reason}")
print() print()
# 统计信息汇总 # 统计信息汇总
print("=" * 80) print("=" * 80)
print("统计汇总") print("统计汇总")
@@ -301,4 +301,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -13,7 +13,12 @@ from src.chat.utils.chat_message_builder import (
) )
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.bw_learner.learner_utils import filter_message_content, is_bot_message, build_context_paragraph, contains_bot_self_name from src.bw_learner.learner_utils import (
filter_message_content,
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
)
from src.bw_learner.jargon_miner import miner_manager from src.bw_learner.jargon_miner import miner_manager
from json_repair import repair_json from json_repair import repair_json
@@ -77,8 +82,6 @@ def init_prompt() -> None:
Prompt(learn_style_prompt, "learn_style_prompt") Prompt(learn_style_prompt, "learn_style_prompt")
class ExpressionLearner: class ExpressionLearner:
def __init__(self, chat_id: str) -> None: def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest( self.express_learn_model: LLMRequest = LLMRequest(
@@ -95,12 +98,12 @@ class ExpressionLearner:
self._learning_lock = asyncio.Lock() self._learning_lock = asyncio.Lock()
async def learn_and_store( async def learn_and_store(
self, self,
messages: List[Any], messages: List[Any],
) -> List[Tuple[str, str, str]]: ) -> List[Tuple[str, str, str]]:
""" """
学习并存储表达方式 学习并存储表达方式
Args: Args:
messages: 外部传入的消息列表(必需) messages: 外部传入的消息列表(必需)
num: 学习数量 num: 学习数量
@@ -108,7 +111,7 @@ class ExpressionLearner:
""" """
if not messages: if not messages:
return None return None
random_msg = messages random_msg = messages
# 学习用(开启行编号,便于溯源) # 学习用(开启行编号,便于溯源)
@@ -134,26 +137,26 @@ class ExpressionLearner:
jargon_entries: List[Tuple[str, str]] # (content, source_id) jargon_entries: List[Tuple[str, str]] # (content, source_id)
expressions, jargon_entries = self.parse_expression_response(response) expressions, jargon_entries = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions) expressions = self._filter_self_reference_styles(expressions)
# 检查表达方式数量如果超过10个则放弃本次表达学习 # 检查表达方式数量如果超过10个则放弃本次表达学习
if len(expressions) > 10: if len(expressions) > 10:
logger.info(f"表达方式提取数量超过10个实际{len(expressions)}个),放弃本次表达学习") logger.info(f"表达方式提取数量超过10个实际{len(expressions)}个),放弃本次表达学习")
expressions = [] expressions = []
# 检查黑话数量如果超过30个则放弃本次黑话学习 # 检查黑话数量如果超过30个则放弃本次黑话学习
if len(jargon_entries) > 30: if len(jargon_entries) > 30:
logger.info(f"黑话提取数量超过30个实际{len(jargon_entries)}个),放弃本次黑话学习") logger.info(f"黑话提取数量超过30个实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = [] jargon_entries = []
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话 # 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
if jargon_entries: if jargon_entries:
await self._process_jargon_entries(jargon_entries, random_msg) await self._process_jargon_entries(jargon_entries, random_msg)
# 如果没有表达方式,直接返回 # 如果没有表达方式,直接返回
if not expressions: if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)") logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
return [] return []
logger.info(f"学习的prompt: {prompt}") logger.info(f"学习的prompt: {prompt}")
logger.info(f"学习的expressions: {expressions}") logger.info(f"学习的expressions: {expressions}")
logger.info(f"学习的jargon_entries: {jargon_entries}") logger.info(f"学习的jargon_entries: {jargon_entries}")
@@ -175,18 +178,17 @@ class ExpressionLearner:
# 当前行的原始内容 # 当前行的原始内容
current_msg = random_msg[line_index] current_msg = random_msg[line_index]
# 过滤掉从bot自己发言中提取到的表达方式 # 过滤掉从bot自己发言中提取到的表达方式
if is_bot_message(current_msg): if is_bot_message(current_msg):
continue continue
context = filter_message_content(current_msg.processed_plain_text or "") context = filter_message_content(current_msg.processed_plain_text or "")
if not context: if not context:
continue continue
filtered_expressions.append((situation, style, context)) filtered_expressions.append((situation, style, context))
learnt_expressions = filtered_expressions learnt_expressions = filtered_expressions
if learnt_expressions is None: if learnt_expressions is None:
@@ -270,37 +272,38 @@ class ExpressionLearner:
# 如果解析失败,尝试修复中文引号问题 # 如果解析失败,尝试修复中文引号问题
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
try: try:
def fix_chinese_quotes_in_json(text): def fix_chinese_quotes_in_json(text):
"""使用状态机修复 JSON 字符串值中的中文引号""" """使用状态机修复 JSON 字符串值中的中文引号"""
result = [] result = []
i = 0 i = 0
in_string = False in_string = False
escape_next = False escape_next = False
while i < len(text): while i < len(text):
char = text[i] char = text[i]
if escape_next: if escape_next:
# 当前字符是转义字符后的字符,直接添加 # 当前字符是转义字符后的字符,直接添加
result.append(char) result.append(char)
escape_next = False escape_next = False
i += 1 i += 1
continue continue
if char == '\\': if char == "\\":
# 转义字符 # 转义字符
result.append(char) result.append(char)
escape_next = True escape_next = True
i += 1 i += 1
continue continue
if char == '"' and not escape_next: if char == '"' and not escape_next:
# 遇到英文引号,切换字符串状态 # 遇到英文引号,切换字符串状态
in_string = not in_string in_string = not in_string
result.append(char) result.append(char)
i += 1 i += 1
continue continue
if in_string: if in_string:
# 在字符串值内部,将中文引号替换为转义的英文引号 # 在字符串值内部,将中文引号替换为转义的英文引号
if char == '"': # 中文左引号 U+201C if char == '"': # 中文左引号 U+201C
@@ -312,13 +315,13 @@ class ExpressionLearner:
else: else:
# 不在字符串内,直接添加 # 不在字符串内,直接添加
result.append(char) result.append(char)
i += 1 i += 1
return ''.join(result) return "".join(result)
fixed_raw = fix_chinese_quotes_in_json(raw) fixed_raw = fix_chinese_quotes_in_json(raw)
# 再次尝试解析 # 再次尝试解析
if fixed_raw.startswith("[") and fixed_raw.endswith("]"): if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
parsed = json.loads(fixed_raw) parsed = json.loads(fixed_raw)
@@ -346,12 +349,12 @@ class ExpressionLearner:
for item in parsed_list: for item in parsed_list:
if not isinstance(item, dict): if not isinstance(item, dict):
continue continue
# 检查是否是表达方式条目(有 situation 和 style # 检查是否是表达方式条目(有 situation 和 style
situation = str(item.get("situation", "")).strip() situation = str(item.get("situation", "")).strip()
style = str(item.get("style", "")).strip() style = str(item.get("style", "")).strip()
source_id = str(item.get("source_id", "")).strip() source_id = str(item.get("source_id", "")).strip()
if situation and style and source_id: if situation and style and source_id:
# 表达方式条目 # 表达方式条目
expressions.append((situation, style, source_id)) expressions.append((situation, style, source_id))
@@ -503,59 +506,59 @@ class ExpressionLearner:
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None: async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
""" """
处理从 expression learner 提取的黑话条目,路由到 jargon_miner 处理从 expression learner 提取的黑话条目,路由到 jargon_miner
Args: Args:
jargon_entries: 黑话条目列表,每个元素是 (content, source_id) jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
messages: 消息列表,用于构建上下文 messages: 消息列表,用于构建上下文
""" """
if not jargon_entries or not messages: if not jargon_entries or not messages:
return return
# 获取 jargon_miner 实例 # 获取 jargon_miner 实例
jargon_miner = miner_manager.get_miner(self.chat_id) jargon_miner = miner_manager.get_miner(self.chat_id)
# 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致 # 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致
entries: List[Dict[str, List[str]]] = [] entries: List[Dict[str, List[str]]] = []
for content, source_id in jargon_entries: for content, source_id in jargon_entries:
content = content.strip() content = content.strip()
if not content: if not content:
continue continue
# 检查是否包含机器人名称 # 检查是否包含机器人名称
if contains_bot_self_name(content): if contains_bot_self_name(content):
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}") logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
continue continue
# 解析 source_id # 解析 source_id
source_id_str = (source_id or "").strip() source_id_str = (source_id or "").strip()
if not source_id_str.isdigit(): if not source_id_str.isdigit():
logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}") logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}")
continue continue
# build_anonymous_messages 的编号从 1 开始 # build_anonymous_messages 的编号从 1 开始
line_index = int(source_id_str) - 1 line_index = int(source_id_str) - 1
if line_index < 0 or line_index >= len(messages): if line_index < 0 or line_index >= len(messages):
logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}") logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}")
continue continue
# 检查是否是机器人自己的消息 # 检查是否是机器人自己的消息
target_msg = messages[line_index] target_msg = messages[line_index]
if is_bot_message(target_msg): if is_bot_message(target_msg):
logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}") logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}")
continue continue
# 构建上下文段落 # 构建上下文段落
context_paragraph = build_context_paragraph(messages, line_index) context_paragraph = build_context_paragraph(messages, line_index)
if not context_paragraph: if not context_paragraph:
logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}") logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}")
continue continue
entries.append({"content": content, "raw_content": [context_paragraph]}) entries.append({"content": content, "raw_content": [context_paragraph]})
if not entries: if not entries:
return return
# 调用 jargon_miner 处理这些条目 # 调用 jargon_miner 处理这些条目
await jargon_miner.process_extracted_entries(entries) await jargon_miner.process_extracted_entries(entries)

View File

@@ -82,9 +82,7 @@ class ExpressionReflector:
# 获取未检查的表达 # 获取未检查的表达
try: try:
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达") logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
expressions = ( expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
)
expr_list = list(expressions) expr_list = list(expressions)
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达") logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")

View File

@@ -128,9 +128,7 @@ class ExpressionSelector:
# 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的 # 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的
style_query = Expression.select().where( style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
& (~Expression.rejected)
& (Expression.count > 1)
) )
style_exprs = [ style_exprs = [
@@ -150,12 +148,15 @@ class ExpressionSelector:
# 要求至少有10个 count > 1 的表达方式才进行选择 # 要求至少有10个 count > 1 的表达方式才进行选择
min_required = 10 min_required = 10
if len(style_exprs) < min_required: if len(style_exprs) < min_required:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择") logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择"
)
return [], [] return [], []
# 固定选择5个 # 固定选择5个
select_count = 5 select_count = 5
import random import random
selected_style = random.sample(style_exprs, select_count) selected_style = random.sample(style_exprs, select_count)
# 更新last_active_time # 更新last_active_time
@@ -163,7 +164,9 @@ class ExpressionSelector:
self.update_expressions_last_active_time(selected_style) self.update_expressions_last_active_time(selected_style)
selected_ids = [expr["id"] for expr in selected_style] selected_ids = [expr["id"] for expr in selected_style]
logger.debug(f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}") logger.debug(
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}"
)
return selected_style, selected_ids return selected_style, selected_ids
except Exception as e: except Exception as e:
@@ -186,9 +189,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达 # 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达
style_query = Expression.select().where( style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
style_exprs = [ style_exprs = [
{ {
@@ -246,7 +247,9 @@ class ExpressionSelector:
# 使用classic模式随机选择+LLM选择 # 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}") logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason, think_level) return await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
async def _select_expressions_classic( async def _select_expressions_classic(
self, self,
@@ -275,14 +278,12 @@ class ExpressionSelector:
# think_level == 0: 只选择 count > 1 的项目随机选10个不进行LLM选择 # think_level == 0: 只选择 count > 1 的项目随机选10个不进行LLM选择
if think_level == 0: if think_level == 0:
return self._select_expressions_simple(chat_id, max_num) return self._select_expressions_simple(chat_id, max_num)
# think_level == 1: 先选高count再从所有表达方式中随机抽样 # think_level == 1: 先选高count再从所有表达方式中随机抽样
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的 # 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
style_query = Expression.select().where( style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
all_style_exprs = [ all_style_exprs = [
{ {
"id": expr.id, "id": expr.id,
@@ -299,29 +300,33 @@ class ExpressionSelector:
# 分离 count > 1 和 count <= 1 的表达方式 # 分离 count > 1 和 count <= 1 的表达方式
high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1] high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]
# 根据 think_level 设置要求(仅支持 0/10 已在上方返回) # 根据 think_level 设置要求(仅支持 0/10 已在上方返回)
min_high_count = 10 min_high_count = 10
min_total_count = 10 min_total_count = 10
select_high_count = 5 select_high_count = 5
select_random_count = 5 select_random_count = 5
# 检查数量要求 # 检查数量要求
if len(high_count_exprs) < min_high_count: if len(high_count_exprs) < min_high_count:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择") logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择"
)
return [], [] return [], []
if len(all_style_exprs) < min_total_count: if len(all_style_exprs) < min_total_count:
logger.info(f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择") logger.info(
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
)
return [], [] return [], []
# 先选取高count的表达方式 # 先选取高count的表达方式
selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count)) selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
# 然后从所有表达方式中随机抽样(使用加权抽样) # 然后从所有表达方式中随机抽样(使用加权抽样)
remaining_num = select_random_count remaining_num = select_random_count
selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num)) selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))
# 合并候选池(去重,避免重复) # 合并候选池(去重,避免重复)
candidate_exprs = selected_high.copy() candidate_exprs = selected_high.copy()
candidate_ids = {expr["id"] for expr in candidate_exprs} candidate_ids = {expr["id"] for expr in candidate_exprs}
@@ -329,9 +334,10 @@ class ExpressionSelector:
if expr["id"] not in candidate_ids: if expr["id"] not in candidate_ids:
candidate_exprs.append(expr) candidate_exprs.append(expr)
candidate_ids.add(expr["id"]) candidate_ids.add(expr["id"])
# 打乱顺序避免高count的都在前面 # 打乱顺序避免高count的都在前面
import random import random
random.shuffle(candidate_exprs) random.shuffle(candidate_exprs)
# 2. 构建所有表达方式的索引和情境列表 # 2. 构建所有表达方式的索引和情境列表
@@ -351,7 +357,7 @@ class ExpressionSelector:
all_situations_str = "\n".join(all_situations) all_situations_str = "\n".join(all_situations)
if target_message: if target_message:
target_message_str = f",现在你想要对这条消息进行回复:\"{target_message}\"" target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
target_message_extra_block = "4.考虑你要回复的目标消息" target_message_extra_block = "4.考虑你要回复的目标消息"
else: else:
target_message_str = "" target_message_str = ""

View File

@@ -8,7 +8,12 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config from src.config.config import model_config, global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.jargon_miner import search_jargon from src.bw_learner.jargon_miner import search_jargon
from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains from src.bw_learner.learner_utils import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,
chat_id_list_contains,
)
logger = get_logger("jargon") logger = get_logger("jargon")
@@ -357,4 +362,4 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
if results: if results:
return "【概念检索结果】\n" + "\n".join(results) + "\n" return "【概念检索结果】\n" + "\n".join(results) + "\n"
return "" return ""

View File

@@ -1,4 +1,3 @@
import time
import json import json
import asyncio import asyncio
import random import random
@@ -14,7 +13,6 @@ from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id, build_readable_messages_with_id,
get_raw_msg_by_timestamp_with_chat_inclusive,
) )
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.learner_utils import ( from src.bw_learner.learner_utils import (
@@ -33,23 +31,23 @@ logger = get_logger("jargon")
def _is_single_char_jargon(content: str) -> bool: def _is_single_char_jargon(content: str) -> bool:
""" """
判断是否是单字黑话(单个汉字、英文或数字) 判断是否是单字黑话(单个汉字、英文或数字)
Args: Args:
content: 词条内容 content: 词条内容
Returns: Returns:
bool: 如果是单字黑话返回True否则返回False bool: 如果是单字黑话返回True否则返回False
""" """
if not content or len(content) != 1: if not content or len(content) != 1:
return False return False
char = content[0] char = content[0]
# 判断是否是单个汉字、单个英文字母或单个数字 # 判断是否是单个汉字、单个英文字母或单个数字
return ( return (
'\u4e00' <= char <= '\u9fff' or # 汉字 "\u4e00" <= char <= "\u9fff" # 汉字
'a' <= char <= 'z' or # 小写字母 or "a" <= char <= "z" # 小写字母
'A' <= char <= 'Z' or # 大写字母 or "A" <= char <= "Z" # 大写字母
'0' <= char <= '9' # 数字 or "0" <= char <= "9" # 数字
) )
@@ -195,7 +193,7 @@ class JargonMiner:
model_set=model_config.model_task_config.utils, model_set=model_config.model_task_config.utils,
request_type="jargon.extract", request_type="jargon.extract",
) )
self.llm_inference = LLMRequest( self.llm_inference = LLMRequest(
model_set=model_config.model_task_config.utils, model_set=model_config.model_task_config.utils,
request_type="jargon.inference", request_type="jargon.inference",
@@ -207,7 +205,7 @@ class JargonMiner:
self.stream_name = stream_name if stream_name else self.chat_id self.stream_name = stream_name if stream_name else self.chat_id
self.cache_limit = 50 self.cache_limit = 50
self.cache: OrderedDict[str, None] = OrderedDict() self.cache: OrderedDict[str, None] = OrderedDict()
# 黑话提取锁,防止并发执行 # 黑话提取锁,防止并发执行
self._extraction_lock = asyncio.Lock() self._extraction_lock = asyncio.Lock()
@@ -299,17 +297,19 @@ class JargonMiner:
# 获取当前count和上一次的meaning # 获取当前count和上一次的meaning
current_count = jargon_obj.count or 0 current_count = jargon_obj.count or 0
previous_meaning = jargon_obj.meaning or "" previous_meaning = jargon_obj.meaning or ""
# 当count为24, 60时随机移除一半的raw_content项目 # 当count为24, 60时随机移除一半的raw_content项目
if current_count in [24, 60] and len(raw_content_list) > 1: if current_count in [24, 60] and len(raw_content_list) > 1:
# 计算要保留的数量至少保留1个 # 计算要保留的数量至少保留1个
keep_count = max(1, len(raw_content_list) // 2) keep_count = max(1, len(raw_content_list) // 2)
raw_content_list = random.sample(raw_content_list, keep_count) raw_content_list = random.sample(raw_content_list, keep_count)
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目") logger.info(
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
)
# 步骤1: 基于raw_content和content推断 # 步骤1: 基于raw_content和content推断
raw_content_text = "\n".join(raw_content_list) raw_content_text = "\n".join(raw_content_list)
# 当count为24, 60, 100时在prompt中放入上一次推断出的meaning作为参考 # 当count为24, 60, 100时在prompt中放入上一次推断出的meaning作为参考
previous_meaning_section = "" previous_meaning_section = ""
previous_meaning_instruction = "" previous_meaning_instruction = ""
@@ -318,8 +318,10 @@ class JargonMiner:
**上一次推断的含义(仅供参考)** **上一次推断的含义(仅供参考)**
{previous_meaning} {previous_meaning}
""" """
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果" previous_meaning_instruction = (
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
)
prompt1 = await global_prompt_manager.format_prompt( prompt1 = await global_prompt_manager.format_prompt(
"jargon_inference_with_context_prompt", "jargon_inference_with_context_prompt",
content=content, content=content,
@@ -481,7 +483,7 @@ class JargonMiner:
async def run_once(self, messages: List[Any]) -> None: async def run_once(self, messages: List[Any]) -> None:
""" """
运行一次黑话提取 运行一次黑话提取
Args: Args:
messages: 外部传入的消息列表(必需) messages: 外部传入的消息列表(必需)
""" """
@@ -650,7 +652,9 @@ class JargonMiner:
if obj.raw_content: if obj.raw_content:
try: try:
existing_raw_content = ( existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content json.loads(obj.raw_content)
if isinstance(obj.raw_content, str)
else obj.raw_content
) )
if not isinstance(existing_raw_content, list): if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else [] existing_raw_content = [existing_raw_content] if existing_raw_content else []
@@ -726,13 +730,13 @@ class JargonMiner:
async def process_extracted_entries(self, entries: List[Dict[str, List[str]]]) -> None: async def process_extracted_entries(self, entries: List[Dict[str, List[str]]]) -> None:
""" """
处理已提取的黑话条目(从 expression_learner 路由过来的) 处理已提取的黑话条目(从 expression_learner 路由过来的)
Args: Args:
entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]} entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]}
""" """
if not entries: if not entries:
return return
try: try:
# 去重并合并raw_content按 content 聚合) # 去重并合并raw_content按 content 聚合)
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict() merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
@@ -876,8 +880,6 @@ class JargonMinerManager:
miner_manager = JargonMinerManager() miner_manager = JargonMinerManager()
def search_jargon( def search_jargon(
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:

View File

@@ -15,25 +15,25 @@ class MessageRecorder:
""" """
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner 统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
""" """
def __init__(self, chat_id: str) -> None:
    """
    Set up the recorder state for one chat.

    Args:
        chat_id: ID of the chat this recorder is bound to.
    """
    self.chat_id = chat_id
    self.chat_stream = get_chat_manager().get_stream(chat_id)
    # Fall back to the raw chat_id when no stream name is available.
    stream_name = get_chat_manager().get_stream_name(chat_id)
    self.chat_name = stream_name if stream_name else chat_id

    # Time of the last extraction for this chat; starts at construction time.
    self.last_extraction_time: float = time.time()

    # Extraction lock, prevents concurrent runs.
    self._extraction_lock = asyncio.Lock()

    # Load the expression / jargon configuration parameters.
    self._init_parameters()

    # Obtain the expression_learner and jargon_miner instances for this chat.
    self.expression_learner = expression_learner_manager.get_expression_learner(chat_id)
    self.jargon_miner = miner_manager.get_miner(chat_id)
def _init_parameters(self) -> None: def _init_parameters(self) -> None:
"""初始化提取参数""" """初始化提取参数"""
# 获取 expression 配置 # 获取 expression 配置
@@ -42,17 +42,17 @@ class MessageRecorder:
) )
self.min_messages_for_extraction = 30 self.min_messages_for_extraction = 30
self.min_extraction_interval = 60 self.min_extraction_interval = 60
logger.debug( logger.debug(
f"MessageRecorder 初始化: chat_id={self.chat_id}, " f"MessageRecorder 初始化: chat_id={self.chat_id}, "
f"min_messages={self.min_messages_for_extraction}, " f"min_messages={self.min_messages_for_extraction}, "
f"min_interval={self.min_extraction_interval}" f"min_interval={self.min_extraction_interval}"
) )
def should_trigger_extraction(self) -> bool: def should_trigger_extraction(self) -> bool:
""" """
检查是否应该触发消息提取 检查是否应该触发消息提取
Returns: Returns:
bool: 是否应该触发提取 bool: 是否应该触发提取
""" """
@@ -60,19 +60,19 @@ class MessageRecorder:
time_diff = time.time() - self.last_extraction_time time_diff = time.time() - self.last_extraction_time
if time_diff < self.min_extraction_interval: if time_diff < self.min_extraction_interval:
return False return False
# 检查消息数量 # 检查消息数量
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=self.last_extraction_time, timestamp_start=self.last_extraction_time,
timestamp_end=time.time(), timestamp_end=time.time(),
) )
if not recent_messages or len(recent_messages) < self.min_messages_for_extraction: if not recent_messages or len(recent_messages) < self.min_messages_for_extraction:
return False return False
return True return True
async def extract_and_distribute(self) -> None: async def extract_and_distribute(self) -> None:
""" """
提取消息并分发给 expression_learner 和 jargon_miner 提取消息并分发给 expression_learner 和 jargon_miner
@@ -82,41 +82,40 @@ class MessageRecorder:
# 在锁内检查,避免并发触发 # 在锁内检查,避免并发触发
if not self.should_trigger_extraction(): if not self.should_trigger_extraction():
return return
# 检查 chat_stream 是否存在 # 检查 chat_stream 是否存在
if not self.chat_stream: if not self.chat_stream:
return return
# 记录本次提取的时间窗口,避免重复提取 # 记录本次提取的时间窗口,避免重复提取
extraction_start_time = self.last_extraction_time extraction_start_time = self.last_extraction_time
extraction_end_time = time.time() extraction_end_time = time.time()
# 立即更新提取时间,防止并发触发 # 立即更新提取时间,防止并发触发
self.last_extraction_time = extraction_end_time self.last_extraction_time = extraction_end_time
try: try:
logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发") logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发")
# 拉取提取窗口内的消息 # 拉取提取窗口内的消息
messages = get_raw_msg_by_timestamp_with_chat_inclusive( messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=extraction_start_time, timestamp_start=extraction_start_time,
timestamp_end=extraction_end_time, timestamp_end=extraction_end_time,
) )
if not messages: if not messages:
logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取") logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取")
return return
# 按时间排序,确保顺序一致 # 按时间排序,确保顺序一致
messages = sorted(messages, key=lambda msg: msg.time or 0) messages = sorted(messages, key=lambda msg: msg.time or 0)
logger.info( logger.info(
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息," f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}" f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
) )
# 分别触发 expression_learner 和 jargon_miner 的处理 # 分别触发 expression_learner 和 jargon_miner 的处理
# 传递提取的消息,避免它们重复获取 # 传递提取的消息,避免它们重复获取
# 触发 expression 学习(如果启用) # 触发 expression 学习(如果启用)
@@ -124,28 +123,26 @@ class MessageRecorder:
asyncio.create_task( asyncio.create_task(
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages) self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
) )
# 触发 jargon 提取(如果启用),传递消息 # 触发 jargon 提取(如果启用),传递消息
# if self.enable_jargon_learning: # if self.enable_jargon_learning:
# asyncio.create_task( # asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages) # self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# ) # )
except Exception as e: except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试 # 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning( async def _trigger_expression_learning(
self, self, timestamp_start: float, timestamp_end: float, messages: List[Any]
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
) -> None: ) -> None:
""" """
触发 expression 学习,使用指定的消息列表 触发 expression 学习,使用指定的消息列表
Args: Args:
timestamp_start: 开始时间戳 timestamp_start: 开始时间戳
timestamp_end: 结束时间戳 timestamp_end: 结束时间戳
@@ -154,7 +151,7 @@ class MessageRecorder:
try: try:
# 传递消息给 ExpressionLearner必需参数 # 传递消息给 ExpressionLearner必需参数
learnt_style = await self.expression_learner.learn_and_store(messages=messages) learnt_style = await self.expression_learner.learn_and_store(messages=messages)
if learnt_style: if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成") logger.info(f"聊天流 {self.chat_name} 表达学习完成")
else: else:
@@ -162,17 +159,15 @@ class MessageRecorder:
except Exception as e: except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}") logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
async def _trigger_jargon_extraction( async def _trigger_jargon_extraction(
self, self, timestamp_start: float, timestamp_end: float, messages: List[Any]
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
) -> None: ) -> None:
""" """
触发 jargon 提取,使用指定的消息列表 触发 jargon 提取,使用指定的消息列表
Args: Args:
timestamp_start: 开始时间戳 timestamp_start: 开始时间戳
timestamp_end: 结束时间戳 timestamp_end: 结束时间戳
@@ -181,19 +176,20 @@ class MessageRecorder:
try: try:
# 传递消息给 JargonMiner避免它重复获取 # 传递消息给 JargonMiner避免它重复获取
await self.jargon_miner.run_once(messages=messages) await self.jargon_miner.run_once(messages=messages)
except Exception as e: except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}") logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
class MessageRecorderManager: class MessageRecorderManager:
"""MessageRecorder 管理器""" """MessageRecorder 管理器"""
def __init__(self) -> None: def __init__(self) -> None:
self._recorders: dict[str, MessageRecorder] = {} self._recorders: dict[str, MessageRecorder] = {}
def get_recorder(self, chat_id: str) -> MessageRecorder: def get_recorder(self, chat_id: str) -> MessageRecorder:
"""获取或创建指定 chat_id 的 MessageRecorder""" """获取或创建指定 chat_id 的 MessageRecorder"""
if chat_id not in self._recorders: if chat_id not in self._recorders:
@@ -208,10 +204,9 @@ recorder_manager = MessageRecorderManager()
async def extract_and_distribute_messages(chat_id: str) -> None:
    """
    Unified entry point for message extraction and distribution.

    Looks up the MessageRecorder for the given chat through the module-level
    recorder manager and awaits its extract_and_distribute().

    Args:
        chat_id: ID of the chat stream.
    """
    await recorder_manager.get_recorder(chat_id).extract_and_distribute()

View File

@@ -176,19 +176,19 @@ class BrainChatting:
# 如果有新消息,更新 last_read_time # 如果有新消息,更新 last_read_time
if len(recent_messages_list) >= 1: if len(recent_messages_list) >= 1:
self.last_read_time = time.time() self.last_read_time = time.time()
# 总是执行一次思考迭代(不管有没有新消息) # 总是执行一次思考迭代(不管有没有新消息)
# wait 动作会在其内部等待,不需要在这里处理 # wait 动作会在其内部等待,不需要在这里处理
should_continue = await self._observe(recent_messages_list=recent_messages_list) should_continue = await self._observe(recent_messages_list=recent_messages_list)
if not should_continue: if not should_continue:
# 选择了 complete_talk返回 False 表示需要等待新消息 # 选择了 complete_talk返回 False 表示需要等待新消息
return False return False
# 继续下一次迭代(除非选择了 complete_talk # 继续下一次迭代(除非选择了 complete_talk
# 短暂等待后再继续,避免过于频繁的循环 # 短暂等待后再继续,避免过于频繁的循环
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
return True return True
async def _send_and_store_reply( async def _send_and_store_reply(
@@ -328,9 +328,7 @@ class BrainChatting:
) )
# 检查是否有 complete_talk 动作(会停止后续迭代) # 检查是否有 complete_talk 动作(会停止后续迭代)
has_complete_talk = any( has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
action.action_type == "complete_talk" for action in action_to_use_info
)
# 并行执行所有动作 # 并行执行所有动作
action_tasks = [ action_tasks = [
@@ -430,12 +428,12 @@ class BrainChatting:
await asyncio.sleep(3) await asyncio.sleep(3)
self._loop_task = asyncio.create_task(self._main_chat_loop()) self._loop_task = asyncio.create_task(self._main_chat_loop())
logger.error(f"{self.log_prefix} 结束了当前聊天循环") logger.error(f"{self.log_prefix} 结束了当前聊天循环")
async def _wait_for_new_message(self): async def _wait_for_new_message(self):
"""等待新消息到达""" """等待新消息到达"""
last_check_time = self.last_read_time last_check_time = self.last_read_time
check_interval = 1.0 # 每秒检查一次 check_interval = 1.0 # 每秒检查一次
while self.running: while self.running:
# 检查是否有新消息 # 检查是否有新消息
recent_messages_list = message_api.get_messages_by_time_in_chat( recent_messages_list = message_api.get_messages_by_time_in_chat(
@@ -448,13 +446,13 @@ class BrainChatting:
filter_command=False, filter_command=False,
filter_intercept_message_level=1, filter_intercept_message_level=1,
) )
# 如果有新消息,更新 last_read_time 并返回 # 如果有新消息,更新 last_read_time 并返回
if len(recent_messages_list) >= 1: if len(recent_messages_list) >= 1:
self.last_read_time = time.time() self.last_read_time = time.time()
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环") logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
return return
# 等待一段时间后再次检查 # 等待一段时间后再次检查
await asyncio.sleep(check_interval) await asyncio.sleep(check_interval)
@@ -660,9 +658,9 @@ class BrainChatting:
except (ValueError, TypeError): except (ValueError, TypeError):
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒") logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
wait_seconds = 5 wait_seconds = 5
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds}") logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds}")
# 记录动作信息 # 记录动作信息
await database_api.store_action_info( await database_api.store_action_info(
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
@@ -673,12 +671,12 @@ class BrainChatting:
action_data={"reason": reason, "wait_seconds": wait_seconds}, action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="wait", action_name="wait",
) )
# 等待指定时间 # 等待指定时间
await asyncio.sleep(wait_seconds) await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考") logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复 # 这些动作本身不产生文本回复
self._last_successful_reply = False self._last_successful_reply = False
return { return {
@@ -693,9 +691,9 @@ class BrainChatting:
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait自动转换") logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait自动转换")
# 使用默认等待时间 # 使用默认等待时间
wait_seconds = 3 wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds}") logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds}")
# 记录动作信息 # 记录动作信息
await database_api.store_action_info( await database_api.store_action_info(
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
@@ -706,12 +704,12 @@ class BrainChatting:
action_data={"reason": reason, "wait_seconds": wait_seconds}, action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="listening", action_name="listening",
) )
# 等待指定时间 # 等待指定时间
await asyncio.sleep(wait_seconds) await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考") logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复 # 这些动作本身不产生文本回复
self._last_successful_reply = False self._last_successful_reply = False
return { return {

View File

@@ -147,7 +147,7 @@ class BrainPlanner:
) # 用于动作规划 ) # 用于动作规划
self.last_obs_time_mark = 0.0 self.last_obs_time_mark = 0.0
# 计划日志记录 # 计划日志记录
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = [] self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
@@ -203,9 +203,11 @@ class BrainPlanner:
# 内部保留动作(不依赖插件系统) # 内部保留动作(不依赖插件系统)
# 注意listening 已合并到 wait 中,如果遇到 listening 则转换为 wait # 注意listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"] internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}") logger.debug(
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
)
# 将 listening 转换为 wait向后兼容 # 将 listening 转换为 wait向后兼容
if action == "listening": if action == "listening":
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait自动转换") logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait自动转换")
@@ -521,7 +523,7 @@ class BrainPlanner:
if json_objects: if json_objects:
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象") logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
for i, json_obj in enumerate(json_objects): for i, json_obj in enumerate(json_objects):
logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}") logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
filtered_actions_list = list(filtered_actions.items()) filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects: for json_obj in json_objects:
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list) parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
@@ -553,7 +555,9 @@ class BrainPlanner:
return extracted_reasoning, actions return extracted_reasoning, actions
def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]: def _create_complete_talk(
self, reasoning: str, available_actions: Dict[str, ActionInfo]
) -> List[ActionPlannerInfo]:
"""创建complete_talk""" """创建complete_talk"""
return [ return [
ActionPlannerInfo( ActionPlannerInfo(
@@ -564,7 +568,7 @@ class BrainPlanner:
available_actions=available_actions, available_actions=available_actions,
) )
] ]
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]): def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
"""添加计划日志""" """添加计划日志"""
self.plan_log.append((reasoning, time.time(), actions)) self.plan_log.append((reasoning, time.time(), actions))

View File

@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji.description = emoji_data.description emoji.description = emoji_data.description
# Deserialize emotion string from DB to list # Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.replace("",",").split(",") if emoji_data.emotion else [] emoji.emotion = emoji_data.emotion.replace("", ",").split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count emoji.usage_count = emoji_data.usage_count
db_last_used_time = emoji_data.last_used_time db_last_used_time = emoji_data.last_used_time
@@ -732,7 +732,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion: if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...") logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.replace("",",").split(",") return emoji_record.emotion.replace("", ",").split(",")
except Exception as e: except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}") logger.error(f"从数据库查询表情包情感标签时出错: {e}")
@@ -993,7 +993,7 @@ class EmojiManager:
) )
# 处理情感列表 # 处理情感列表
emotions = [e.strip() for e in emotions_text.replace("",",").split(",") if e.strip()] emotions = [e.strip() for e in emotions_text.replace("", ",").split(",") if e.strip()]
# 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个 # 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个
if len(emotions) > 5: if len(emotions) > 5:

View File

@@ -619,13 +619,13 @@ class HeartFChatting:
think_level = 0 think_level = 0
# 使用 action_reasoningplanner 的整体思考理由)作为 reply_reason # 使用 action_reasoningplanner 的整体思考理由)作为 reply_reason
planner_reasoning = action_planner_info.action_reasoning or reason planner_reasoning = action_planner_info.action_reasoning or reason
record_replyer_action_temp( record_replyer_action_temp(
chat_id=self.stream_id, chat_id=self.stream_id,
reason=reason, reason=reason,
think_level=think_level, think_level=think_level,
) )
await database_api.store_action_info( await database_api.store_action_info(
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
action_build_into_prompt=False, action_build_into_prompt=False,

View File

@@ -123,7 +123,11 @@ class ChatBot:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}") logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息 # 根据命令的拦截设置决定是否继续处理消息
return True, response, not bool(intercept_message_level) # 找到命令根据intercept_message决定是否继续 return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message决定是否继续
except Exception as e: except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}") logger.error(f"执行命令时出错: {command_class.__name__} - {e}")

View File

@@ -263,7 +263,7 @@ class MessageRecv(Message):
desc = segment.data.get("desc", "") # 内容描述 desc = segment.data.get("desc", "") # 内容描述
source_url = segment.data.get("source_url", "") # 原始链接 source_url = segment.data.get("source_url", "") # 原始链接
url = segment.data.get("url", "") # 小程序链接 url = segment.data.get("url", "") # 小程序链接
text = f"[小程序分享" text = "[小程序分享"
if title: if title:
text += f" - {title}" text += f" - {title}"
text += "]" text += "]"

View File

@@ -42,22 +42,21 @@ def is_webui_virtual_group(group_id: str) -> bool:
def parse_message_segments(segment) -> list: def parse_message_segments(segment) -> list:
"""解析消息段,转换为 WebUI 可用的格式 """解析消息段,转换为 WebUI 可用的格式
参考 NapCat 适配器的消息解析逻辑 参考 NapCat 适配器的消息解析逻辑
Args: Args:
segment: Seg 消息段对象 segment: Seg 消息段对象
Returns: Returns:
list: 消息段列表,每个元素为 {"type": "...", "data": ...} list: 消息段列表,每个元素为 {"type": "...", "data": ...}
""" """
from maim_message import Seg
result = [] result = []
if segment is None: if segment is None:
return result return result
if segment.type == "seglist": if segment.type == "seglist":
# 处理消息段列表 # 处理消息段列表
if segment.data: if segment.data:
@@ -112,15 +111,19 @@ def parse_message_segments(segment) -> list:
forward_items = [] forward_items = []
if segment.data: if segment.data:
for item in segment.data: for item in segment.data:
forward_items.append({ forward_items.append(
"content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else [] {
}) "content": parse_message_segments(item.get("message_segment", {}))
if isinstance(item, dict)
else []
}
)
result.append({"type": "forward", "data": forward_items}) result.append({"type": "forward", "data": forward_items})
else: else:
# 未知类型,尝试作为文本处理 # 未知类型,尝试作为文本处理
if segment.data: if segment.data:
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)}) result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
return result return result
@@ -134,7 +137,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息 # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster() chat_manager, webui_platform = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id) is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None: if is_webui_message and chat_manager is not None:
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播 # WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
import time import time
@@ -142,7 +145,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 解析消息段,获取富文本内容 # 解析消息段,获取富文本内容
message_segments = parse_message_segments(message.message_segment) message_segments = parse_message_segments(message.message_segment)
# 判断消息类型 # 判断消息类型
# 如果只有一个文本段,使用简单的 text 类型 # 如果只有一个文本段,使用简单的 text 类型
# 否则使用 rich 类型,包含完整的消息段 # 否则使用 rich 类型,包含完整的消息段

View File

@@ -77,8 +77,7 @@ target_message_id为必填表示触发消息的id
```""", ```""",
"planner_prompt", "planner_prompt",
) )
Prompt( Prompt(
""" """
{action_name} {action_name}

View File

@@ -250,7 +250,12 @@ class DefaultReplyer:
# 使用从处理器传来的选中表达方式 # 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式 # 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level self.chat_stream.stream_id,
chat_history,
max_num=8,
target_message=target,
reply_reason=reply_reason,
think_level=think_level,
) )
if selected_expressions: if selected_expressions:
@@ -273,7 +278,6 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块 """构建工具信息块
@@ -788,7 +792,8 @@ class DefaultReplyer:
# 并行执行八个构建任务(包括黑话解释) # 并行执行八个构建任务(包括黑话解释)
task_results = await asyncio.gather( task_results = await asyncio.gather(
self._time_and_run_task( self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits" self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
"expression_habits",
), ),
self._time_and_run_task( self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info" self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
@@ -980,7 +985,6 @@ class DefaultReplyer:
else: else:
reply_target_block = "" reply_target_block = ""
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2") chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")

View File

@@ -287,7 +287,6 @@ class PrivateReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块 """构建工具信息块
@@ -907,16 +906,11 @@ class PrivateReplyer:
else: else:
reply_target_block = "" reply_target_block = ""
chat_target_name = "对方" chat_target_name = "对方"
if self.chat_target_info: if self.chat_target_info:
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方" chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
chat_target_1 = await global_prompt_manager.format_prompt( chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
"chat_target_private1", sender_name=chat_target_name chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
template_name = "default_expressor_prompt" template_name = "default_expressor_prompt"

View File

@@ -1,8 +1,9 @@
from src.chat.utils.prompt_builder import Prompt from src.chat.utils.prompt_builder import Prompt
def init_replyer_private_prompt(): def init_replyer_private_prompt():
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation} {expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天,这是你们之前聊的内容: 你正在和{sender_name}聊天,这是你们之前聊的内容:
@@ -17,9 +18,9 @@ def init_replyer_private_prompt():
{reply_style} {reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""", {moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"private_replyer_prompt", "private_replyer_prompt",
) )
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation} {expression_habits_block}{memory_retrieval}{jargon_explanation}
@@ -37,4 +38,4 @@ def init_replyer_private_prompt():
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )。 {moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )。
""", """,
"private_replyer_self_prompt", "private_replyer_self_prompt",
) )

View File

@@ -23,7 +23,7 @@ def init_replyer_prompt():
现在,你说:""", 现在,你说:""",
"replyer_prompt_0", "replyer_prompt_0",
) )
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation} {expression_habits_block}{memory_retrieval}{jargon_explanation}
@@ -44,4 +44,3 @@ def init_replyer_prompt():
现在,你说:""", 现在,你说:""",
"replyer_prompt", "replyer_prompt",
) )

View File

@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}} filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)] sort_order = [("time", 1)]
return find_messages( return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level message_filter=filter_query,
sort=sort_order,
limit=limit,
filter_intercept_message_level=filter_intercept_message_level,
) )

View File

@@ -746,7 +746,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}" data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0) total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [ output = [
"按模型分类统计:", "按模型分类统计:",
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数", " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@@ -759,11 +759,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name] cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name] avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name] std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
# 计算每次回复平均值 # 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0 avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0 avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字 # 格式化大数字
formatted_count = _format_large_number(count) formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens) formatted_in_tokens = _format_large_number(in_tokens)
@@ -771,7 +771,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens) formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A" formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A" formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append( output.append(
data_fmt.format( data_fmt.format(
name, name,
@@ -800,7 +800,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}" data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0) total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [ output = [
"按模块分类统计:", "按模块分类统计:",
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数", " 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@@ -813,11 +813,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODULE][module_name] cost = stats[COST_BY_MODULE][module_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name] avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name] std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
# 计算每次回复平均值 # 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0 avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0 avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字 # 格式化大数字
formatted_count = _format_large_number(count) formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens) formatted_in_tokens = _format_large_number(in_tokens)
@@ -825,7 +825,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens) formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A" formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A" formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append( output.append(
data_fmt.format( data_fmt.format(
name, name,

View File

@@ -646,7 +646,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None: def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None:
""" """
临时记录replyer动作被选择的信息仅群聊 临时记录replyer动作被选择的信息仅群聊
Args: Args:
chat_id: 聊天ID chat_id: 聊天ID
reason: 选择理由 reason: 选择理由
@@ -656,7 +656,7 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
# 确保data/temp目录存在 # 确保data/temp目录存在
temp_dir = "data/temp" temp_dir = "data/temp"
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
# 创建记录数据 # 创建记录数据
record_data = { record_data = {
"chat_id": chat_id, "chat_id": chat_id,
@@ -664,16 +664,16 @@ def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> N
"think_level": think_level, "think_level": think_level,
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
} }
# 生成文件名(使用时间戳避免冲突) # 生成文件名(使用时间戳避免冲突)
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f") timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"replyer_action_{timestamp_str}.json" filename = f"replyer_action_{timestamp_str}.json"
filepath = os.path.join(temp_dir, filename) filepath = os.path.join(temp_dir, filename)
# 写入文件 # 写入文件
with open(filepath, "w", encoding="utf-8") as f: with open(filepath, "w", encoding="utf-8") as f:
json.dump(record_data, f, ensure_ascii=False, indent=2) json.dump(record_data, f, ensure_ascii=False, indent=2)
logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}") logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}")
except Exception as e: except Exception as e:
logger.warning(f"记录replyer动作选择失败: {e}") logger.warning(f"记录replyer动作选择失败: {e}")

View File

@@ -130,12 +130,10 @@ class ImageManager:
try: try:
# 清理Images表中type为emoji的记录 # 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute() deleted_images = Images.delete().where(Images.type == "emoji").execute()
# 清理ImageDescriptions表中type为emoji的记录 # 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = ( deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
)
total_deleted = deleted_images + deleted_descriptions total_deleted = deleted_images + deleted_descriptions
if total_deleted > 0: if total_deleted > 0:
logger.info( logger.info(
@@ -166,7 +164,7 @@ class ImageManager:
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None: async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
"""如果启用了steal_emoji且表情包未注册保存文件到data/emoji目录 """如果启用了steal_emoji且表情包未注册保存文件到data/emoji目录
Args: Args:
image_base64: 图片的base64编码 image_base64: 图片的base64编码
image_hash: 图片的MD5哈希值 image_hash: 图片的MD5哈希值
@@ -174,7 +172,7 @@ class ImageManager:
""" """
if not global_config.emoji.steal_emoji: if not global_config.emoji.steal_emoji:
return return
try: try:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR from src.chat.emoji_system.emoji_manager import EMOJI_DIR
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -236,12 +234,16 @@ class ImageManager:
# 优先使用情感标签,如果没有则使用详细描述 # 优先使用情感标签,如果没有则使用详细描述
result_text = "" result_text = ""
if cache_record.emotion_tags: if cache_record.emotion_tags:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...") logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]" result_text = f"[表情包:{cache_record.emotion_tags}]"
elif cache_record.description: elif cache_record.description:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...") logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
result_text = f"[表情包:{cache_record.description}]" result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件 # 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件
if result_text: if result_text:
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format) await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)

View File

@@ -609,23 +609,23 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
fields = list(model._meta.fields.keys()) fields = list(model._meta.fields.keys())
# Peewee 默认使用 'id' 作为主键字段名 # Peewee 默认使用 'id' 作为主键字段名
# 尝试获取主键字段名,如果获取失败则默认使用 'id' # 尝试获取主键字段名,如果获取失败则默认使用 'id'
primary_key_name = 'id' # 默认值 primary_key_name = "id" # 默认值
try: try:
if hasattr(model._meta, 'primary_key') and model._meta.primary_key: if hasattr(model._meta, "primary_key") and model._meta.primary_key:
if hasattr(model._meta.primary_key, 'name'): if hasattr(model._meta.primary_key, "name"):
primary_key_name = model._meta.primary_key.name primary_key_name = model._meta.primary_key.name
elif isinstance(model._meta.primary_key, str): elif isinstance(model._meta.primary_key, str):
primary_key_name = model._meta.primary_key primary_key_name = model._meta.primary_key
except Exception: except Exception:
pass # 如果获取失败,使用默认值 'id' pass # 如果获取失败,使用默认值 'id'
# 如果字段列表包含主键,则排除它 # 如果字段列表包含主键,则排除它
if primary_key_name in fields: if primary_key_name in fields:
fields_without_pk = [f for f in fields if f != primary_key_name] fields_without_pk = [f for f in fields if f != primary_key_name]
logger.info(f"排除主键字段 '{primary_key_name}',让数据库自动生成新的主键") logger.info(f"排除主键字段 '{primary_key_name}',让数据库自动生成新的主键")
else: else:
fields_without_pk = fields fields_without_pk = fields
fields_str = ", ".join(fields_without_pk) fields_str = ", ".join(fields_without_pk)
# 检查是否有字段需要从 NULL 改为 NOT NULL # 检查是否有字段需要从 NULL 改为 NOT NULL

View File

@@ -34,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
return obj return obj
# 决定是否多行:仅在顶层且长度超过阈值时 # 决定是否多行:仅在顶层且长度超过阈值时
should_multiline = (depth == 0 and len(obj) > threshold) should_multiline = depth == 0 and len(obj) > threshold
# 如果已经是 tomlkit Array原地修改以保留注释 # 如果已经是 tomlkit Array原地修改以保留注释
if isinstance(obj, Array): if isinstance(obj, Array):
@@ -46,7 +46,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
# 普通 list转换为 tomlkit 数组 # 普通 list转换为 tomlkit 数组
arr = tomlkit.array() arr = tomlkit.array()
arr.multiline(should_multiline) arr.multiline(should_multiline)
for item in obj: for item in obj:
arr.append(_format_toml_value(item, threshold, depth + 1)) arr.append(_format_toml_value(item, threshold, depth + 1))
return arr return arr
@@ -112,7 +112,7 @@ def save_toml_with_format(
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted) output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积 # 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
output = re.sub(r'\n{3,}', '\n\n', output) output = re.sub(r"\n{3,}", "\n\n", output)
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding="utf-8") as f:
f.write(output) f.write(output)
@@ -122,4 +122,4 @@ def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted) output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积 # 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
return re.sub(r'\n{3,}', '\n\n', output) return re.sub(r"\n{3,}", "\n\n", output)

View File

@@ -778,9 +778,9 @@ class DreamConfig(ConfigBase):
""" """
if not self.dream_time_ranges: if not self.dream_time_ranges:
return True return True
now_min = self._now_minutes() now_min = self._now_minutes()
for time_range in self.dream_time_ranges: for time_range in self.dream_time_ranges:
if not isinstance(time_range, str): if not isinstance(time_range, str):
continue continue
@@ -790,7 +790,7 @@ class DreamConfig(ConfigBase):
start_min, end_min = parsed start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min): if self._in_range(now_min, start_min, end_min):
return True return True
return False return False
def __post_init__(self): def __post_init__(self):
@@ -800,4 +800,4 @@ class DreamConfig(ConfigBase):
if self.max_iterations < 1: if self.max_iterations < 1:
raise ValueError(f"max_iterations 必须至少为1当前值: {self.max_iterations}") raise ValueError(f"max_iterations 必须至少为1当前值: {self.max_iterations}")
if self.first_delay_seconds < 0: if self.first_delay_seconds < 0:
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}") raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")

View File

@@ -1,14 +1,13 @@
import asyncio import asyncio
import random import random
import time import time
import json
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from peewee import fn from peewee import fn
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.common.database.database_model import ChatHistory, Jargon from src.common.database.database_model import ChatHistory
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.plugin_system.apis import llm_api from src.plugin_system.apis import llm_api
@@ -82,7 +81,6 @@ def init_dream_prompts() -> None:
) )
class DreamTool: class DreamTool:
"""dream 模块内部使用的简易工具封装""" """dream 模块内部使用的简易工具封装"""
@@ -150,7 +148,13 @@ def init_dream_tools(chat_id: str) -> None:
"search_chat_history", "search_chat_history",
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。", "根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
[ [
("keyword", ToolParamType.STRING, "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", False, None), (
"keyword",
ToolParamType.STRING,
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
False,
None,
),
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None), ("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
], ],
search_chat_history, search_chat_history,
@@ -201,8 +205,20 @@ def init_dream_tools(chat_id: str) -> None:
[ [
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None), ("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None), ("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", True, None), (
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", True, None), "keywords",
ToolParamType.STRING,
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
True,
None,
),
(
"key_point",
ToolParamType.STRING,
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
True,
None,
),
("start_time", ToolParamType.STRING, "起始时间戳Unix 时间,必填)。", True, None), ("start_time", ToolParamType.STRING, "起始时间戳Unix 时间,必填)。", True, None),
("end_time", ToolParamType.STRING, "结束时间戳Unix 时间,必填)。", True, None), ("end_time", ToolParamType.STRING, "结束时间戳Unix 时间,必填)。", True, None),
], ],
@@ -215,7 +231,13 @@ def init_dream_tools(chat_id: str) -> None:
"finish_maintenance", "finish_maintenance",
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。", "结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
[ [
("reason", ToolParamType.STRING, "结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'", False, None), (
"reason",
ToolParamType.STRING,
"结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'",
False,
None,
),
], ],
finish_maintenance, finish_maintenance,
) )
@@ -246,7 +268,7 @@ async def run_dream_agent_once(
""" """
if max_iterations is None: if max_iterations is None:
max_iterations = global_config.dream.max_iterations max_iterations = global_config.dream.max_iterations
start_ts = time.time() start_ts = time.time()
logger.info(f"[dream] 开始对 chat_id={chat_id} 进行 dream 维护,最多迭代 {max_iterations}") logger.info(f"[dream] 开始对 chat_id={chat_id} 进行 dream 维护,最多迭代 {max_iterations}")
@@ -282,9 +304,7 @@ async def run_dream_agent_once(
else "未知" else "未知"
) )
end_time_str = ( end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
if record.end_time
else "未知"
) )
detail_text = ( detail_text = (
f"ID={record.id}\n" f"ID={record.id}\n"
@@ -305,8 +325,7 @@ async def run_dream_agent_once(
start_detail_builder = MessageBuilder() start_detail_builder = MessageBuilder()
start_detail_builder.set_role(RoleType.User) start_detail_builder.set_role(RoleType.User)
start_detail_builder.add_text_content( start_detail_builder.add_text_content(
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" "【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
+ detail_text
) )
conversation_messages.append(start_detail_builder.build()) conversation_messages.append(start_detail_builder.build())
else: else:
@@ -343,13 +362,17 @@ async def run_dream_agent_once(
conversation_messages.append(round_info_builder.build()) conversation_messages.append(round_info_builder.build())
# 调用 LLM 让其决定是否要使用工具 # 调用 LLM 让其决定是否要使用工具
success, response, reasoning_content, model_name, tool_calls = ( (
await llm_api.generate_with_model_with_tools_by_message_factory( success,
message_factory, response,
model_config=model_config.model_task_config.tool_use, reasoning_content,
tool_options=tool_defs, model_name,
request_type="dream.react", tool_calls,
) ) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_defs,
request_type="dream.react",
) )
if not success: if not success:
@@ -522,7 +545,7 @@ async def start_dream_scheduler(
if interval_seconds is None: if interval_seconds is None:
interval_seconds = global_config.dream.interval_minutes * 60 interval_seconds = global_config.dream.interval_minutes * 60
logger.info( logger.info(
f"[dream] dream 调度器启动:首次延迟 {first_delay_seconds}s之后每隔 {interval_seconds}s ({interval_seconds // 60} 分钟) 运行一次 dream agent" f"[dream] dream 调度器启动:首次延迟 {first_delay_seconds}s之后每隔 {interval_seconds}s ({interval_seconds // 60} 分钟) 运行一次 dream agent"
) )
@@ -555,4 +578,3 @@ async def start_dream_scheduler(
# 初始化提示词 # 初始化提示词
init_dream_prompts() init_dream_prompts()

View File

@@ -86,7 +86,7 @@ async def generate_dream_summary(
try: try:
import json import json
from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.prompt_builder import global_prompt_manager
# 第一步:建立工具调用结果映射 (call_id -> result) # 第一步:建立工具调用结果映射 (call_id -> result)
tool_results_map: dict[str, str] = {} tool_results_map: dict[str, str] = {}
for msg in conversation_messages: for msg in conversation_messages:
@@ -98,11 +98,11 @@ async def generate_dream_summary(
else: else:
content = str(msg.content) content = str(msg.content)
tool_results_map[msg.tool_call_id] = content tool_results_map[msg.tool_call_id] = content
# 第二步:详细记录所有工具调用操作和结果到日志 # 第二步:详细记录所有工具调用操作和结果到日志
tool_call_count = 0 tool_call_count = 0
logger.info(f"[dream][工具调用详情] 开始记录 chat_id={chat_id} 的所有工具调用操作:") logger.info(f"[dream][工具调用详情] 开始记录 chat_id={chat_id} 的所有工具调用操作:")
for msg in conversation_messages: for msg in conversation_messages:
if msg.role == RoleType.Assistant and msg.tool_calls: if msg.role == RoleType.Assistant and msg.tool_calls:
tool_call_count += 1 tool_call_count += 1
@@ -110,34 +110,38 @@ async def generate_dream_summary(
thought_content = "" thought_content = ""
if msg.content: if msg.content:
if isinstance(msg.content, list) and msg.content: if isinstance(msg.content, list) and msg.content:
thought_content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0]) thought_content = (
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
)
else: else:
thought_content = str(msg.content) thought_content = str(msg.content)
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===") logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
if thought_content: if thought_content:
logger.info(f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}") logger.info(
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
)
# 记录每个工具调用的详细信息 # 记录每个工具调用的详细信息
for idx, tool_call in enumerate(msg.tool_calls, 1): for idx, tool_call in enumerate(msg.tool_calls, 1):
tool_name = tool_call.func_name tool_name = tool_call.func_name
tool_args = tool_call.args or {} tool_args = tool_call.args or {}
tool_call_id = tool_call.call_id tool_call_id = tool_call.call_id
tool_result = tool_results_map.get(tool_call_id, "未找到执行结果") tool_result = tool_results_map.get(tool_call_id, "未找到执行结果")
# 格式化参数 # 格式化参数
try: try:
args_str = json.dumps(tool_args, ensure_ascii=False, indent=2) if tool_args else "无参数" args_str = json.dumps(tool_args, ensure_ascii=False, indent=2) if tool_args else "无参数"
except Exception: except Exception:
args_str = str(tool_args) args_str = str(tool_args)
logger.info(f"[dream][工具调用详情] --- 工具 {idx}: {tool_name} ---") logger.info(f"[dream][工具调用详情] --- 工具 {idx}: {tool_name} ---")
logger.info(f"[dream][工具调用详情] 调用参数:\n{args_str}") logger.info(f"[dream][工具调用详情] 调用参数:\n{args_str}")
logger.info(f"[dream][工具调用详情] 执行结果:\n{tool_result}") logger.info(f"[dream][工具调用详情] 执行结果:\n{tool_result}")
logger.info(f"[dream][工具调用详情] {'-' * 60}") logger.info(f"[dream][工具调用详情] {'-' * 60}")
logger.info(f"[dream][工具调用详情] 共记录了 {tool_call_count} 组工具调用操作") logger.info(f"[dream][工具调用详情] 共记录了 {tool_call_count} 组工具调用操作")
# 第三步:构建对话历史摘要(用于生成梦境) # 第三步:构建对话历史摘要(用于生成梦境)
conversation_summary = [] conversation_summary = []
for msg in conversation_messages: for msg in conversation_messages:
@@ -145,11 +149,11 @@ async def generate_dream_summary(
content = "" content = ""
if msg.content: if msg.content:
content = msg.content[0].text if isinstance(msg.content, list) and msg.content else str(msg.content) content = msg.content[0].text if isinstance(msg.content, list) and msg.content else str(msg.content)
if role == "user" and "轮次信息" in content: if role == "user" and "轮次信息" in content:
# 跳过轮次信息消息 # 跳过轮次信息消息
continue continue
if role == "assistant": if role == "assistant":
# 只保留思考内容,简化工具调用信息 # 只保留思考内容,简化工具调用信息
if content: if content:
@@ -162,13 +166,13 @@ async def generate_dream_summary(
# 截取前300字符 # 截取前300字符
content_preview = content[:300] + ("..." if len(content) > 300 else "") content_preview = content[:300] + ("..." if len(content) > 300 else "")
conversation_summary.append(f"[工具执行] {content_preview}") conversation_summary.append(f"[工具执行] {content_preview}")
conversation_text = "\n".join(conversation_summary[-20:]) # 只保留最后20条消息 conversation_text = "\n".join(conversation_summary[-20:]) # 只保留最后20条消息
# 随机选择2个梦境风格 # 随机选择2个梦境风格
selected_styles = get_random_dream_styles(2) selected_styles = get_random_dream_styles(2)
dream_styles_text = "\n".join([f"{i+1}. {style}" for i, style in enumerate(selected_styles)]) dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
# 使用 Prompt 管理器格式化梦境生成 prompt # 使用 Prompt 管理器格式化梦境生成 prompt
dream_prompt = await global_prompt_manager.format_prompt( dream_prompt = await global_prompt_manager.format_prompt(
"dream_summary_prompt", "dream_summary_prompt",
@@ -186,13 +190,14 @@ async def generate_dream_summary(
max_tokens=512, max_tokens=512,
temperature=0.8, temperature=0.8,
) )
if dream_content: if dream_content:
logger.info(f"[dream][梦境总结] 对 chat_id={chat_id} 的整理过程梦境:\n{dream_content}") logger.info(f"[dream][梦境总结] 对 chat_id={chat_id} 的整理过程梦境:\n{dream_content}")
else: else:
logger.warning("[dream][梦境总结] 未能生成梦境总结") logger.warning("[dream][梦境总结] 未能生成梦境总结")
except Exception as e: except Exception as e:
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True) logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
init_dream_summary_prompt()
init_dream_summary_prompt()

View File

@@ -4,8 +4,3 @@ dream agent 工具实现模块。
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数 每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。 生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
""" """

View File

@@ -60,8 +60,3 @@ def make_create_chat_history(chat_id: str):
return f"create_chat_history 执行失败: {e}" return f"create_chat_history 执行失败: {e}"
return create_chat_history return create_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"delete_chat_history 执行失败: {e}" return f"delete_chat_history 执行失败: {e}"
return delete_chat_history return delete_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"delete_jargon 执行失败: {e}" return f"delete_jargon 执行失败: {e}"
return delete_jargon return delete_jargon

View File

@@ -14,8 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
return msg return msg
return finish_maintenance return finish_maintenance

View File

@@ -1,5 +1,4 @@
import time import time
from typing import Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory from src.common.database.database_model import ChatHistory
@@ -20,14 +19,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
# 将时间戳转换为可读时间格式 # 将时间戳转换为可读时间格式
start_time_str = ( start_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
if record.start_time
else "未知"
) )
end_time_str = ( end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
if record.end_time
else "未知"
) )
result = ( result = (
@@ -40,17 +35,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
f"概括={record.summary or ''}\n" f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}" f"关键信息={record.key_point or ''}"
) )
logger.debug( logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}"
)
return result return result
except Exception as e: except Exception as e:
logger.error(f"get_chat_history_detail 失败: {e}") logger.error(f"get_chat_history_detail 失败: {e}")
return f"get_chat_history_detail 执行失败: {e}" return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail return get_chat_history_detail

View File

@@ -78,9 +78,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = (
json.loads(record.keywords) json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(record.keywords, str)
else record.keywords
) )
if isinstance(keywords_data, list): if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data] record_keywords_list = [str(k).lower() for k in keywords_data]
@@ -125,9 +123,7 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(keywords_list) keywords_str = "".join(keywords_list)
if len(keywords_list) > 2: if len(keywords_list) > 2:
required_count = len(keywords_list) - 1 required_count = len(keywords_list) - 1
return ( return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
else: else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录" return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant: elif participant:
@@ -142,9 +138,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = (
json.loads(record.keywords) json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(record.keywords, str)
else record.keywords
) )
if isinstance(keywords_data, list): if isinstance(keywords_data, list):
for k in keywords_data: for k in keywords_data:
@@ -160,13 +154,13 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(sorted(all_keywords_set)) keywords_str = "".join(sorted(all_keywords_set))
response_text = ( response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n" f'有关"{search_label}"的关键词:\n'
f"{keywords_str}" f"{keywords_str}"
) )
else: else:
response_text = ( response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词信息为空" f'有关"{search_label}"的关键词信息为空'
) )
logger.info( logger.info(
@@ -192,9 +186,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = (
json.loads(record.keywords) json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(record.keywords, str)
else record.keywords
) )
if isinstance(keywords_data, list) and keywords_data: if isinstance(keywords_data, list) and keywords_data:
keywords_str = "".join([str(k) for k in keywords_data]) keywords_str = "".join([str(k) for k in keywords_data])
@@ -220,8 +212,3 @@ def make_search_chat_history(chat_id: str):
return f"search_chat_history 执行失败: {e}" return f"search_chat_history 执行失败: {e}"
return search_chat_history return search_chat_history

View File

@@ -16,9 +16,7 @@ def make_search_jargon(chat_id: str):
if not keyword or not keyword.strip(): if not keyword or not keyword.strip():
return "未指定查询关键词(参数 keyword 为必填,且不能为空)" return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
logger.info( logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})"
)
# 基础条件:只查 is_jargon=True 的记录 # 基础条件:只查 is_jargon=True 的记录
query = Jargon.select().where(Jargon.is_jargon) query = Jargon.select().where(Jargon.is_jargon)
@@ -102,5 +100,3 @@ def make_search_jargon(chat_id: str):
return f"search_jargon 执行失败: {e}" return f"search_jargon 执行失败: {e}"
return search_jargon return search_jargon

View File

@@ -49,8 +49,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"update_chat_history 执行失败: {e}" return f"update_chat_history 执行失败: {e}"
return update_chat_history return update_chat_history

View File

@@ -49,8 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"update_jargon 执行失败: {e}" return f"update_jargon 执行失败: {e}"
return update_jargon return update_jargon

View File

@@ -316,7 +316,9 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages) before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages) self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time self.current_batch.end_time = current_time
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息") logger.info(
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
# 更新批次后持久化 # 更新批次后持久化
self._persist_topic_cache() self._persist_topic_cache()
else: else:
@@ -362,9 +364,7 @@ class ChatHistorySummarizer:
else: else:
time_str = f"{time_since_last_check / 3600:.1f}小时" time_str = f"{time_since_last_check / 3600:.1f}小时"
logger.debug( logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
)
# 检查“话题检查”触发条件 # 检查“话题检查”触发条件
should_check = False should_check = False
@@ -414,7 +414,7 @@ class ChatHistorySummarizer:
# 说明 bot 没有参与这段对话,不应该记录 # 说明 bot 没有参与这段对话,不应该记录
bot_user_id = str(global_config.bot.qq_account) bot_user_id = str(global_config.bot.qq_account)
has_bot_message = False has_bot_message = False
for msg in messages: for msg in messages:
if msg.user_info.user_id == bot_user_id: if msg.user_info.user_id == bot_user_id:
has_bot_message = True has_bot_message = True
@@ -427,7 +427,9 @@ class ChatHistorySummarizer:
return return
# 2. 构造编号后的消息字符串和参与者信息 # 2. 构造编号后的消息字符串和参与者信息
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages) numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
self._build_numbered_messages_for_llm(messages)
)
# 3. 调用 LLM 识别话题,并得到 topic -> indices失败时最多重试 3 次) # 3. 调用 LLM 识别话题,并得到 topic -> indices失败时最多重试 3 次)
existing_topics = list(self.topic_cache.keys()) existing_topics = list(self.topic_cache.keys())
@@ -456,9 +458,7 @@ class ChatHistorySummarizer:
) )
if not success or not topic_to_indices: if not success or not topic_to_indices:
logger.error( logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
)
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状 # 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状
return return
@@ -610,9 +610,7 @@ class ChatHistorySummarizer:
if not numbered_lines: if not numbered_lines:
return False, {} return False, {}
history_topics_block = ( history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
)
messages_block = "\n".join(numbered_lines) messages_block = "\n".join(numbered_lines)
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
@@ -635,17 +633,17 @@ class ChatHistorySummarizer:
json_str = None json_str = None
json_pattern = r"```json\s*(.*?)\s*```" json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL) matches = re.findall(json_pattern, response, re.DOTALL)
if matches: if matches:
# 找到JSON代码块使用第一个匹配 # 找到JSON代码块使用第一个匹配
json_str = matches[0].strip() json_str = matches[0].strip()
else: else:
# 如果没有找到代码块尝试查找JSON数组的开始和结束位置 # 如果没有找到代码块尝试查找JSON数组的开始和结束位置
# 查找第一个 [ 和最后一个 ] # 查找第一个 [ 和最后一个 ]
start_idx = response.find('[') start_idx = response.find("[")
end_idx = response.rfind(']') end_idx = response.rfind("]")
if start_idx != -1 and end_idx != -1 and end_idx > start_idx: if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
json_str = response[start_idx:end_idx + 1].strip() json_str = response[start_idx : end_idx + 1].strip()
else: else:
# 如果还是找不到尝试直接使用整个响应移除可能的markdown标记 # 如果还是找不到尝试直接使用整个响应移除可能的markdown标记
json_str = response.strip() json_str = response.strip()
@@ -942,4 +940,3 @@ class ChatHistorySummarizer:
init_prompt() init_prompt()

View File

@@ -49,7 +49,7 @@ class LLMRequest:
def _check_slow_request(self, time_cost: float, model_name: str) -> None: def _check_slow_request(self, time_cost: float, model_name: str) -> None:
"""检查请求是否过慢并输出警告日志 """检查请求是否过慢并输出警告日志
Args: Args:
time_cost: 请求耗时(秒) time_cost: 请求耗时(秒)
model_name: 使用的模型名称 model_name: 使用的模型名称
@@ -323,7 +323,7 @@ class LLMRequest:
effective_temperature = (model_info.extra_params or {}).get("temperature") effective_temperature = (model_info.extra_params or {}).get("temperature")
if effective_temperature is None: if effective_temperature is None:
effective_temperature = self.model_for_task.temperature effective_temperature = self.model_for_task.temperature
# max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置 # max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置
effective_max_tokens = max_tokens effective_max_tokens = max_tokens
if effective_max_tokens is None: if effective_max_tokens is None:
@@ -332,7 +332,7 @@ class LLMRequest:
effective_max_tokens = (model_info.extra_params or {}).get("max_tokens") effective_max_tokens = (model_info.extra_params or {}).get("max_tokens")
if effective_max_tokens is None: if effective_max_tokens is None:
effective_max_tokens = self.model_for_task.max_tokens effective_max_tokens = self.model_for_task.max_tokens
return await client.get_response( return await client.get_response(
model_info=model_info, model_info=model_info,
message_list=(compressed_messages or message_list), message_list=(compressed_messages or message_list),
@@ -366,7 +366,9 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}") logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}") logger.warning(
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval) await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e: except NetworkConnectionError as e:
@@ -394,7 +396,9 @@ class LLMRequest:
if e.status_code == 429 or e.status_code >= 500: if e.status_code == 429 or e.status_code >= 500:
retry_remain -= 1 retry_remain -= 1
if retry_remain <= 0: if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}") logger.error(
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning( logger.warning(
@@ -540,7 +544,5 @@ class LLMRequest:
if e.__cause__: if e.__cause__:
original_error_type = type(e.__cause__).__name__ original_error_type = type(e.__cause__).__name__
original_error_msg = str(e.__cause__) original_error_msg = str(e.__cause__)
return ( return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
)
return "" return ""

View File

@@ -113,7 +113,6 @@ class MainSystem:
get_emoji_manager().initialize() get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功") logger.info("表情包管理器初始化成功")
# 初始化聊天管理器 # 初始化聊天管理器
await get_chat_manager()._initialize() await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task()) asyncio.create_task(get_chat_manager()._auto_save_task())

View File

@@ -136,8 +136,6 @@ def init_memory_retrieval_prompt():
) )
def _log_conversation_messages( def _log_conversation_messages(
conversation_messages: List[Message], conversation_messages: List[Message],
head_prompt: Optional[str] = None, head_prompt: Optional[str] = None,
@@ -172,7 +170,9 @@ def _log_conversation_messages(
# 构建单条消息的日志信息 # 构建单条消息的日志信息
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------" # msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------" msg_info = (
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
)
# if full_content: # if full_content:
# msg_info += f"\n{full_content}" # msg_info += f"\n{full_content}"
@@ -185,8 +185,7 @@ def _log_conversation_messages(
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}" msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
# if msg.tool_call_id: # if msg.tool_call_id:
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}" # msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
log_lines.append(msg_info) log_lines.append(msg_info)
@@ -330,7 +329,7 @@ async def _react_agent_solve_question(
remaining_iterations=remaining_iterations, remaining_iterations=remaining_iterations,
max_iterations=max_iterations, max_iterations=max_iterations,
) )
# 后续迭代都复用第一次构建的head_prompt # 后续迭代都复用第一次构建的head_prompt
head_prompt = first_head_prompt head_prompt = first_head_prompt
@@ -365,7 +364,7 @@ async def _react_agent_solve_question(
) )
# logger.info( # logger.info(
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}" # f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
# ) # )
if not success: if not success:
@@ -409,20 +408,20 @@ async def _react_agent_solve_question(
"""从文本中解析finish_search函数调用返回(found_answer, answer)元组,如果未找到则返回(None, None)""" """从文本中解析finish_search函数调用返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
if not text: if not text:
return None, None return None, None
# 查找finish_search函数调用位置不区分大小写 # 查找finish_search函数调用位置不区分大小写
func_pattern = "finish_search" func_pattern = "finish_search"
text_lower = text.lower() text_lower = text.lower()
func_pos = text_lower.find(func_pattern) func_pos = text_lower.find(func_pattern)
if func_pos == -1: if func_pos == -1:
return None, None return None, None
# 查找函数调用的开始和结束位置 # 查找函数调用的开始和结束位置
# 从func_pos开始向后查找左括号 # 从func_pos开始向后查找左括号
start_pos = text.find("(", func_pos) start_pos = text.find("(", func_pos)
if start_pos == -1: if start_pos == -1:
return None, None return None, None
# 查找匹配的右括号(考虑嵌套) # 查找匹配的右括号(考虑嵌套)
paren_count = 0 paren_count = 0
end_pos = start_pos end_pos = start_pos
@@ -437,10 +436,10 @@ async def _react_agent_solve_question(
else: else:
# 没有找到匹配的右括号 # 没有找到匹配的右括号
return None, None return None, None
# 提取函数参数部分 # 提取函数参数部分
params_text = text[start_pos + 1 : end_pos] params_text = text[start_pos + 1 : end_pos]
# 解析found_answer参数布尔值可能是true/false/True/False # 解析found_answer参数布尔值可能是true/false/True/False
found_answer = None found_answer = None
found_answer_patterns = [ found_answer_patterns = [
@@ -454,49 +453,60 @@ async def _react_agent_solve_question(
if match: if match:
found_answer = "true" in match.group(0).lower() found_answer = "true" in match.group(0).lower()
break break
# 解析answer参数字符串使用extract_quoted_content # 解析answer参数字符串使用extract_quoted_content
answer = extract_quoted_content(text, "finish_search", "answer") answer = extract_quoted_content(text, "finish_search", "answer")
return found_answer, answer return found_answer, answer
parsed_found_answer, parsed_answer = parse_finish_search_from_text(response) parsed_found_answer, parsed_answer = parse_finish_search_from_text(response)
if parsed_found_answer is not None: if parsed_found_answer is not None:
# 检测到finish_search函数调用格式 # 检测到finish_search函数调用格式
if parsed_found_answer: if parsed_found_answer:
# 找到了答案 # 找到了答案
if parsed_answer: if parsed_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": parsed_answer}}) step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": parsed_answer},
}
)
step["observations"] = ["检测到finish_search文本格式调用找到答案"] step["observations"] = ["检测到finish_search文本格式调用找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
)
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status=f"找到答案:{parsed_answer}", final_status=f"找到答案:{parsed_answer}",
) )
return True, parsed_answer, thinking_steps, False return True, parsed_answer, thinking_steps, False
else: else:
# found_answer为True但没有提供answer视为错误继续迭代 # found_answer为True但没有提供answer视为错误继续迭代
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer") logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
)
else: else:
# 未找到答案,直接退出查询 # 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}}) step["actions"].append(
{"action_type": "finish_search", "action_params": {"found_answer": False}}
)
step["observations"] = ["检测到finish_search文本格式调用未找到答案"] step["observations"] = ["检测到finish_search文本格式调用未找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案") logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search文本格式判断未找到答案", final_status="未找到答案通过finish_search文本格式判断未找到答案",
) )
return False, "", thinking_steps, False return False, "", thinking_steps, False
# 如果没有检测到finish_search格式记录思考过程继续下一轮迭代 # 如果没有检测到finish_search格式记录思考过程继续下一轮迭代
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"] step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}") logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
@@ -514,44 +524,53 @@ async def _react_agent_solve_question(
for tool_call in tool_calls: for tool_call in tool_calls:
tool_name = tool_call.func_name tool_name = tool_call.func_name
tool_args = tool_call.args or {} tool_args = tool_call.args or {}
if tool_name == "finish_search": if tool_name == "finish_search":
finish_search_found = tool_args.get("found_answer", False) finish_search_found = tool_args.get("found_answer", False)
finish_search_answer = tool_args.get("answer", "") finish_search_answer = tool_args.get("answer", "")
if finish_search_found: if finish_search_found:
# 找到了答案 # 找到了答案
if finish_search_answer: if finish_search_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": finish_search_answer}}) step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": finish_search_answer},
}
)
step["observations"] = ["检测到finish_search工具调用找到答案"] step["observations"] = ["检测到finish_search工具调用找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
)
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status=f"找到答案:{finish_search_answer}", final_status=f"找到答案:{finish_search_answer}",
) )
return True, finish_search_answer, thinking_steps, False return True, finish_search_answer, thinking_steps, False
else: else:
# found_answer为True但没有提供answer视为错误 # found_answer为True但没有提供answer视为错误
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer") logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
)
else: else:
# 未找到答案,直接退出查询 # 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}}) step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["observations"] = ["检测到finish_search工具调用未找到答案"] step["observations"] = ["检测到finish_search工具调用未找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案") logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案")
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search工具判断未找到答案", final_status="未找到答案通过finish_search工具判断未找到答案",
) )
return False, "", thinking_steps, False return False, "", thinking_steps, False
# 如果没有finish_search工具调用继续处理其他工具 # 如果没有finish_search工具调用继续处理其他工具
tool_tasks = [] tool_tasks = []
for i, tool_call in enumerate(tool_calls): for i, tool_call in enumerate(tool_calls):
@@ -627,7 +646,7 @@ async def _react_agent_solve_question(
observation_text += f"\n\n{jargon_info}" observation_text += f"\n\n{jargon_info}"
collected_info += f"\n{jargon_info}\n" collected_info += f"\n{jargon_info}\n"
logger.info(f"工具输出触发黑话解析: {new_concepts}") logger.info(f"工具输出触发黑话解析: {new_concepts}")
tool_builder = MessageBuilder() tool_builder = MessageBuilder()
tool_builder.set_role(RoleType.Tool) tool_builder.set_role(RoleType.Tool)
tool_builder.add_text_content(observation_text) tool_builder.add_text_content(observation_text)
@@ -645,7 +664,7 @@ async def _react_agent_solve_question(
elif iteration + 1 >= max_iterations: elif iteration + 1 >= max_iterations:
should_do_final_evaluation = True should_do_final_evaluation = True
logger.info(f"ReAct Agent达到最大迭代次数已迭代{iteration + 1}次),进入最终评估") logger.info(f"ReAct Agent达到最大迭代次数已迭代{iteration + 1}次),进入最终评估")
if should_do_final_evaluation: if should_do_final_evaluation:
# 获取必要变量用于最终评估 # 获取必要变量用于最终评估
tool_registry = get_tool_registry() tool_registry = get_tool_registry()
@@ -653,7 +672,7 @@ async def _react_agent_solve_question(
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
current_iteration = iteration + 1 current_iteration = iteration + 1
remaining_iterations = 0 remaining_iterations = 0
# 提取函数调用中参数的值,支持单引号和双引号 # 提取函数调用中参数的值,支持单引号和双引号
def extract_quoted_content(text, func_name, param_name): def extract_quoted_content(text, func_name, param_name):
"""从文本中提取函数调用中参数的值,支持单引号和双引号 """从文本中提取函数调用中参数的值,支持单引号和双引号
@@ -724,7 +743,13 @@ async def _react_agent_solve_question(
max_iterations=max_iterations, max_iterations=max_iterations,
) )
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools( (
eval_success,
eval_response,
eval_reasoning_content,
eval_model_name,
eval_tool_calls,
) = await llm_api.generate_with_model_with_tools(
evaluation_prompt, evaluation_prompt,
model_config=model_config.model_task_config.tool_use, model_config=model_config.model_task_config.tool_use,
tool_options=[], # 最终评估阶段不提供工具 tool_options=[], # 最终评估阶段不提供工具
@@ -739,7 +764,7 @@ async def _react_agent_solve_question(
final_status="未找到答案最终评估阶段LLM调用失败", final_status="未找到答案最终评估阶段LLM调用失败",
) )
return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout
if global_config.debug.show_memory_prompt: if global_config.debug.show_memory_prompt:
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}") logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
logger.info(f"ReAct Agent 最终评估响应: {eval_response}") logger.info(f"ReAct Agent 最终评估响应: {eval_response}")
@@ -759,17 +784,17 @@ async def _react_agent_solve_question(
"iteration": current_iteration, "iteration": current_iteration,
"thought": f"[最终评估] {eval_response}", "thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}], "actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
"observations": ["最终评估阶段检测到found_answer"] "observations": ["最终评估阶段检测到found_answer"],
} }
thinking_steps.append(eval_step) thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}") logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status=f"找到答案:{found_answer_content}", final_status=f"找到答案:{found_answer_content}",
) )
return True, found_answer_content, thinking_steps, False return True, found_answer_content, thinking_steps, False
# 如果评估为not_enough_info返回空字符串不返回任何信息 # 如果评估为not_enough_info返回空字符串不返回任何信息
@@ -778,35 +803,37 @@ async def _react_agent_solve_question(
"iteration": current_iteration, "iteration": current_iteration,
"thought": f"[最终评估] {eval_response}", "thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}], "actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
"observations": ["最终评估阶段检测到not_enough_info"] "observations": ["最终评估阶段检测到not_enough_info"],
} }
thinking_steps.append(eval_step) thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}") logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status=f"未找到答案:{not_enough_info_reason}", final_status=f"未找到答案:{not_enough_info_reason}",
) )
return False, "", thinking_steps, is_timeout return False, "", thinking_steps, is_timeout
# 如果没有明确判断视为not_enough_info返回空字符串不返回任何信息 # 如果没有明确判断视为not_enough_info返回空字符串不返回任何信息
eval_step = { eval_step = {
"iteration": current_iteration, "iteration": current_iteration,
"thought": f"[最终评估] {eval_response}", "thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}], "actions": [
"observations": ["已到达最大迭代次数,无法找到答案"] {"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
],
"observations": ["已到达最大迭代次数,无法找到答案"],
} }
thinking_steps.append(eval_step) thinking_steps.append(eval_step)
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案") logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status="未找到答案:已到达最大迭代次数,无法找到答案", final_status="未找到答案:已到达最大迭代次数,无法找到答案",
) )
return False, "", thinking_steps, is_timeout return False, "", thinking_steps, is_timeout
# 如果正常迭代过程中提前找到答案返回,不会到达这里 # 如果正常迭代过程中提前找到答案返回,不会到达这里
@@ -817,7 +844,7 @@ async def _react_agent_solve_question(
head_prompt=first_head_prompt, head_prompt=first_head_prompt,
final_status="未找到答案:正常迭代结束", final_status="未找到答案:正常迭代结束",
) )
return False, "", thinking_steps, is_timeout return False, "", thinking_steps, is_timeout
@@ -1129,7 +1156,9 @@ async def build_memory_retrieval_prompt(
else: else:
max_iterations = base_max_iterations max_iterations = base_max_iterations
timeout_seconds = global_config.memory.agent_timeout_seconds timeout_seconds = global_config.memory.agent_timeout_seconds
logger.debug(f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}") logger.debug(
f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
)
# 并行处理所有问题,将概念检索结果作为初始信息传递 # 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [ question_tasks = [
@@ -1157,10 +1186,10 @@ async def build_memory_retrieval_prompt(
# 获取最近10分钟内已找到答案的缓存记录 # 获取最近10分钟内已找到答案的缓存记录
cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0) cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
# 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果) # 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果)
all_results = [] all_results = []
# 先添加当前查询的结果 # 先添加当前查询的结果
current_questions = set() current_questions = set()
for result in question_results: for result in question_results:
@@ -1170,7 +1199,7 @@ async def build_memory_retrieval_prompt(
if question_end != -1: if question_end != -1:
current_questions.add(result[4:question_end]) current_questions.add(result[4:question_end])
all_results.append(result) all_results.append(result)
# 添加缓存答案(排除当前查询中已存在的问题) # 添加缓存答案(排除当前查询中已存在的问题)
for cached_answer in cached_answers: for cached_answer in cached_answers:
if cached_answer.startswith("问题:"): if cached_answer.startswith("问题:"):
@@ -1198,4 +1227,3 @@ async def build_memory_retrieval_prompt(
except Exception as e: except Exception as e:
logger.error(f"记忆检索时发生异常: {str(e)}") logger.error(f"记忆检索时发生异常: {str(e)}")
return "" return ""

View File

@@ -17,7 +17,6 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils") logger = get_logger("memory_utils")
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]: def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
"""解析问题JSON返回概念列表和问题列表 """解析问题JSON返回概念列表和问题列表
@@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...") logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return [], [] return [], []
def parse_datetime_to_timestamp(value: str) -> float: def parse_datetime_to_timestamp(value: str) -> float:
""" """
接受多种常见格式并转换为时间戳(秒) 接受多种常见格式并转换为时间戳(秒)

View File

@@ -47,4 +47,3 @@ def register_tool():
], ],
execute_func=finish_search, execute_func=finish_search,
) )

View File

@@ -16,9 +16,7 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools") logger = get_logger("memory_retrieval_tools")
async def search_chat_history( async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
) -> str:
"""根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords """根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords
Args: Args:
@@ -117,7 +115,7 @@ async def search_chat_history(
) )
if kw_matched: if kw_matched:
matched_count += 1 matched_count += 1
# 计算需要匹配的关键词数量 # 计算需要匹配的关键词数量
total_keywords = len(keywords_lower) total_keywords = len(keywords_lower)
if total_keywords > 2: if total_keywords > 2:
@@ -126,7 +124,7 @@ async def search_chat_history(
else: else:
# 关键词数量<=2必须全部匹配 # 关键词数量<=2必须全部匹配
required_matches = total_keywords required_matches = total_keywords
keyword_matched = matched_count >= required_matches keyword_matched = matched_count >= required_matches
# 两者都匹配如果同时有participant和keyword需要两者都匹配如果只有一个条件只需要该条件匹配 # 两者都匹配如果同时有participant和keyword需要两者都匹配如果只有一个条件只需要该条件匹配
@@ -144,7 +142,9 @@ async def search_chat_history(
keywords_list = parse_keywords_string(keyword) keywords_list = parse_keywords_string(keyword)
if len(keywords_list) > 2: if len(keywords_list) > 2:
required_count = len(keywords_list) - 1 required_count = len(keywords_list) - 1
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录" return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
else: else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录" return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant: elif participant:
@@ -160,9 +160,7 @@ async def search_chat_history(
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = (
json.loads(record.keywords) json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(record.keywords, str)
else record.keywords
) )
if isinstance(keywords_data, list): if isinstance(keywords_data, list):
for k in keywords_data: for k in keywords_data:
@@ -179,13 +177,12 @@ async def search_chat_history(
keywords_str = "".join(sorted(all_keywords_set)) keywords_str = "".join(sorted(all_keywords_set))
return ( return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n" f'有关"{search_label}"的关键词:\n'
f"{keywords_str}" f"{keywords_str}"
) )
else: else:
return ( return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空'
f"有关\"{search_label}\"的关键词信息为空"
) )
# 构建结果文本返回id、theme和keywords最多20条 # 构建结果文本返回id、theme和keywords最多20条

View File

@@ -22,42 +22,42 @@ def get_current_token(
) -> str: ) -> str:
""" """
获取当前请求的 token优先从 Cookie 获取,其次从 Header 获取 获取当前请求的 token优先从 Cookie 获取,其次从 Header 获取
Args: Args:
request: FastAPI Request 对象 request: FastAPI Request 对象
maibot_session: Cookie 中的 token maibot_session: Cookie 中的 token
authorization: Authorization Header (Bearer token) authorization: Authorization Header (Bearer token)
Returns: Returns:
验证通过的 token 验证通过的 token
Raises: Raises:
HTTPException: 认证失败时抛出 401 错误 HTTPException: 认证失败时抛出 401 错误
""" """
token = None token = None
# 优先从 Cookie 获取 # 优先从 Cookie 获取
if maibot_session: if maibot_session:
token = maibot_session token = maibot_session
# 其次从 Header 获取(兼容旧版本) # 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "): elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "") token = authorization.replace("Bearer ", "")
if not token: if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息") raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token # 验证 token
token_manager = get_token_manager() token_manager = get_token_manager()
if not token_manager.verify_token(token): if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期") raise HTTPException(status_code=401, detail="Token 无效或已过期")
return token return token
def set_auth_cookie(response: Response, token: str) -> None: def set_auth_cookie(response: Response, token: str) -> None:
""" """
设置认证 Cookie 设置认证 Cookie
Args: Args:
response: FastAPI Response 对象 response: FastAPI Response 对象
token: 要设置的 token token: 要设置的 token
@@ -77,7 +77,7 @@ def set_auth_cookie(response: Response, token: str) -> None:
def clear_auth_cookie(response: Response) -> None: def clear_auth_cookie(response: Response) -> None:
""" """
清除认证 Cookie 清除认证 Cookie
Args: Args:
response: FastAPI Response 对象 response: FastAPI Response 对象
""" """
@@ -96,32 +96,32 @@ def verify_auth_token_from_cookie_or_header(
) -> bool: ) -> bool:
""" """
验证认证 Token支持从 Cookie 或 Header 获取 验证认证 Token支持从 Cookie 或 Header 获取
Args: Args:
maibot_session: Cookie 中的 token maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token) authorization: Authorization header (Bearer token)
Returns: Returns:
验证成功返回 True 验证成功返回 True
Raises: Raises:
HTTPException: 认证失败时抛出 401 错误 HTTPException: 认证失败时抛出 401 错误
""" """
token = None token = None
# 优先从 Cookie 获取 # 优先从 Cookie 获取
if maibot_session: if maibot_session:
token = maibot_session token = maibot_session
# 其次从 Header 获取(兼容旧版本) # 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "): elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "") token = authorization.replace("Bearer ", "")
if not token: if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息") raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token # 验证 token
token_manager = get_token_manager() token_manager = get_token_manager()
if not token_manager.verify_token(token): if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期") raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True return True

View File

@@ -63,14 +63,14 @@ class ChatHistoryManager:
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]: def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将数据库消息转换为前端格式 """将数据库消息转换为前端格式
Args: Args:
msg: 数据库消息对象 msg: 数据库消息对象
group_id: 群 ID用于判断是否是虚拟群 group_id: 群 ID用于判断是否是虚拟群
""" """
# 判断是否是机器人消息 # 判断是否是机器人消息
user_id = msg.user_id or "" user_id = msg.user_id or ""
# 对于虚拟群,通过比较机器人 QQ 账号来判断 # 对于虚拟群,通过比较机器人 QQ 账号来判断
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头 # 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX): if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
@@ -414,7 +414,9 @@ async def websocket_chat(
group_id=virtual_group_id, group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊", group_name=group_name or "WebUI虚拟群聊",
) )
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}") logger.info(
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
)
except Exception as e: except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}") logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")

View File

@@ -1,4 +1,4 @@
""" 表情包管理 API 路由""" """表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
from fastapi.responses import FileResponse, JSONResponse from fastapi.responses import FileResponse, JSONResponse
@@ -48,7 +48,7 @@ def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None: def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
""" """
后台生成缩略图(在线程池中执行) 后台生成缩略图(在线程池中执行)
生成完成后自动从 generating 集合中移除 生成完成后自动从 generating 集合中移除
""" """
try: try:
@@ -74,14 +74,14 @@ def _get_thumbnail_cache_path(file_hash: str) -> Path:
def _generate_thumbnail(source_path: str, file_hash: str) -> Path: def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
""" """
生成缩略图并保存到缓存目录 生成缩略图并保存到缓存目录
Args: Args:
source_path: 原图路径 source_path: 原图路径
file_hash: 文件哈希值,用作缓存文件名 file_hash: 文件哈希值,用作缓存文件名
Returns: Returns:
缩略图路径 缩略图路径
Features: Features:
- GIF: 提取第一帧作为缩略图 - GIF: 提取第一帧作为缩略图
- 所有格式统一转为 WebP - 所有格式统一转为 WebP
@@ -89,63 +89,63 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
""" """
_ensure_thumbnail_cache_dir() _ensure_thumbnail_cache_dir()
cache_path = _get_thumbnail_cache_path(file_hash) cache_path = _get_thumbnail_cache_path(file_hash)
# 使用锁防止并发生成同一缩略图 # 使用锁防止并发生成同一缩略图
lock = _get_thumbnail_lock(file_hash) lock = _get_thumbnail_lock(file_hash)
with lock: with lock:
# 双重检查,可能在等待锁时已被其他线程生成 # 双重检查,可能在等待锁时已被其他线程生成
if cache_path.exists(): if cache_path.exists():
return cache_path return cache_path
try: try:
with Image.open(source_path) as img: with Image.open(source_path) as img:
# GIF 处理:提取第一帧 # GIF 处理:提取第一帧
if hasattr(img, 'n_frames') and img.n_frames > 1: if hasattr(img, "n_frames") and img.n_frames > 1:
img.seek(0) # 确保在第一帧 img.seek(0) # 确保在第一帧
# 转换为 RGB/RGBAWebP 支持透明度) # 转换为 RGB/RGBAWebP 支持透明度)
if img.mode in ('P', 'PA'): if img.mode in ("P", "PA"):
# 调色板模式转换为 RGBA 以保留透明度 # 调色板模式转换为 RGBA 以保留透明度
img = img.convert('RGBA') img = img.convert("RGBA")
elif img.mode == 'LA': elif img.mode == "LA":
img = img.convert('RGBA') img = img.convert("RGBA")
elif img.mode not in ('RGB', 'RGBA'): elif img.mode not in ("RGB", "RGBA"):
img = img.convert('RGB') img = img.convert("RGB")
# 创建缩略图(保持宽高比) # 创建缩略图(保持宽高比)
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS) img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
# 保存为 WebP 格式 # 保存为 WebP 格式
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6) img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}") logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
except Exception as e: except Exception as e:
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图") logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
# 生成失败时不创建缓存文件,下次会重试 # 生成失败时不创建缓存文件,下次会重试
raise raise
return cache_path return cache_path
def cleanup_orphaned_thumbnails() -> tuple[int, int]: def cleanup_orphaned_thumbnails() -> tuple[int, int]:
""" """
清理孤立的缩略图缓存(原图已不存在的缩略图) 清理孤立的缩略图缓存(原图已不存在的缩略图)
Returns: Returns:
(清理数量, 保留数量) (清理数量, 保留数量)
""" """
if not THUMBNAIL_CACHE_DIR.exists(): if not THUMBNAIL_CACHE_DIR.exists():
return 0, 0 return 0, 0
# 获取所有表情包的哈希值 # 获取所有表情包的哈希值
valid_hashes = set() valid_hashes = set()
for emoji in Emoji.select(Emoji.emoji_hash): for emoji in Emoji.select(Emoji.emoji_hash):
valid_hashes.add(emoji.emoji_hash) valid_hashes.add(emoji.emoji_hash)
cleaned = 0 cleaned = 0
kept = 0 kept = 0
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"): for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
file_hash = cache_file.stem file_hash = cache_file.stem
if file_hash not in valid_hashes: if file_hash not in valid_hashes:
@@ -157,12 +157,13 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}") logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
else: else:
kept += 1 kept += 1
if cleaned > 0: if cleaned > 0:
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept}") logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept}")
return cleaned, kept return cleaned, kept
# 模块级别的类型别名(解决 B008 ruff 错误) # 模块级别的类型别名(解决 B008 ruff 错误)
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")] EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")] EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
@@ -365,7 +366,9 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse) @router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def get_emoji_detail(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
获取表情包详细信息 获取表情包详细信息
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse) @router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
增量更新表情包(只更新提供的字段) 增量更新表情包(只更新提供的字段)
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse) @router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def delete_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
删除表情包 删除表情包
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse) @router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def register_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
注册表情包(快捷操作) 注册表情包(快捷操作)
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse) @router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def ban_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
禁用表情包(快捷操作) 禁用表情包(快捷操作)
@@ -633,7 +647,7 @@ async def get_emoji_thumbnail(
Returns: Returns:
表情包缩略图WebP 格式)或原图 表情包缩略图WebP 格式)或原图
Features: Features:
- 懒加载:首次请求时生成缩略图 - 懒加载:首次请求时生成缩略图
- 缓存:后续请求直接返回缓存 - 缓存:后续请求直接返回缓存
@@ -643,7 +657,7 @@ async def get_emoji_thumbnail(
try: try:
token_manager = get_token_manager() token_manager = get_token_manager()
is_valid = False is_valid = False
# 1. 优先使用 Cookie # 1. 优先使用 Cookie
if maibot_session and token_manager.verify_token(maibot_session): if maibot_session and token_manager.verify_token(maibot_session):
is_valid = True is_valid = True
@@ -655,7 +669,7 @@ async def get_emoji_thumbnail(
auth_token = authorization.replace("Bearer ", "") auth_token = authorization.replace("Bearer ", "")
if token_manager.verify_token(auth_token): if token_manager.verify_token(auth_token):
is_valid = True is_valid = True
if not is_valid: if not is_valid:
raise HTTPException(status_code=401, detail="Token 无效或已过期") raise HTTPException(status_code=401, detail="Token 无效或已过期")
@@ -680,35 +694,27 @@ async def get_emoji_thumbnail(
} }
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream") media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse( return FileResponse(
path=emoji.full_path, path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
media_type=media_type,
filename=f"{emoji.emoji_hash}.{emoji.format}"
) )
# 尝试获取或生成缩略图 # 尝试获取或生成缩略图
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash) cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
# 检查缓存是否存在 # 检查缓存是否存在
if cache_path.exists(): if cache_path.exists():
# 缓存命中,直接返回 # 缓存命中,直接返回
return FileResponse( return FileResponse(
path=str(cache_path), path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
media_type="image/webp",
filename=f"{emoji.emoji_hash}_thumb.webp"
) )
# 缓存未命中,触发后台生成并返回 202 # 缓存未命中,触发后台生成并返回 202
with _generating_lock: with _generating_lock:
if emoji.emoji_hash not in _generating_thumbnails: if emoji.emoji_hash not in _generating_thumbnails:
# 标记为正在生成 # 标记为正在生成
_generating_thumbnails.add(emoji.emoji_hash) _generating_thumbnails.add(emoji.emoji_hash)
# 提交到线程池后台生成 # 提交到线程池后台生成
_thumbnail_executor.submit( _thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
_background_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
# 返回 202 Accepted告诉前端缩略图正在生成中 # 返回 202 Accepted告诉前端缩略图正在生成中
return JSONResponse( return JSONResponse(
status_code=202, status_code=202,
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
}, },
headers={ headers={
"Retry-After": "1", # 建议 1 秒后重试 "Retry-After": "1", # 建议 1 秒后重试
} },
) )
except HTTPException: except HTTPException:
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
@router.post("/batch/delete", response_model=BatchDeleteResponse) @router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def batch_delete_emojis(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
批量删除表情包 批量删除表情包
@@ -1079,7 +1089,7 @@ async def batch_upload_emoji(
class ThumbnailCacheStatsResponse(BaseModel): class ThumbnailCacheStatsResponse(BaseModel):
"""缩略图缓存统计响应""" """缩略图缓存统计响应"""
success: bool success: bool
cache_dir: str cache_dir: str
total_count: int total_count: int
@@ -1090,7 +1100,7 @@ class ThumbnailCacheStatsResponse(BaseModel):
class ThumbnailCleanupResponse(BaseModel): class ThumbnailCleanupResponse(BaseModel):
"""缩略图清理响应""" """缩略图清理响应"""
success: bool success: bool
message: str message: str
cleaned_count: int cleaned_count: int
@@ -1099,7 +1109,7 @@ class ThumbnailCleanupResponse(BaseModel):
class ThumbnailPreheatResponse(BaseModel): class ThumbnailPreheatResponse(BaseModel):
"""缩略图预热响应""" """缩略图预热响应"""
success: bool success: bool
message: str message: str
generated_count: int generated_count: int
@@ -1114,27 +1124,27 @@ async def get_thumbnail_cache_stats(
): ):
""" """
获取缩略图缓存统计信息 获取缩略图缓存统计信息
Returns: Returns:
缓存目录、缓存数量、总大小、覆盖率等统计信息 缓存目录、缓存数量、总大小、覆盖率等统计信息
""" """
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
_ensure_thumbnail_cache_dir() _ensure_thumbnail_cache_dir()
# 统计缓存文件 # 统计缓存文件
cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp")) cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp"))
total_count = len(cache_files) total_count = len(cache_files)
total_size = sum(f.stat().st_size for f in cache_files) total_size = sum(f.stat().st_size for f in cache_files)
total_size_mb = round(total_size / (1024 * 1024), 2) total_size_mb = round(total_size / (1024 * 1024), 2)
# 统计表情包总数 # 统计表情包总数
emoji_count = Emoji.select().count() emoji_count = Emoji.select().count()
# 计算覆盖率 # 计算覆盖率
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1) coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
return ThumbnailCacheStatsResponse( return ThumbnailCacheStatsResponse(
success=True, success=True,
cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()), cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()),
@@ -1143,7 +1153,7 @@ async def get_thumbnail_cache_stats(
emoji_count=emoji_count, emoji_count=emoji_count,
coverage_percent=coverage_percent, coverage_percent=coverage_percent,
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -1158,22 +1168,22 @@ async def cleanup_thumbnail_cache(
): ):
""" """
清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图) 清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图)
Returns: Returns:
清理结果 清理结果
""" """
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
cleaned, kept = cleanup_orphaned_thumbnails() cleaned, kept = cleanup_orphaned_thumbnails()
return ThumbnailCleanupResponse( return ThumbnailCleanupResponse(
success=True, success=True,
message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存", message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存",
cleaned_count=cleaned, cleaned_count=cleaned,
kept_count=kept, kept_count=kept,
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -1189,20 +1199,20 @@ async def preheat_thumbnail_cache(
): ):
""" """
预热缩略图缓存(提前生成未缓存的缩略图) 预热缩略图缓存(提前生成未缓存的缩略图)
优先处理使用次数高的表情包 优先处理使用次数高的表情包
Args: Args:
limit: 最多预热数量 (1-1000) limit: 最多预热数量 (1-1000)
Returns: Returns:
预热结果 预热结果
""" """
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
_ensure_thumbnail_cache_dir() _ensure_thumbnail_cache_dir()
# 获取使用次数最高的表情包(未缓存的优先) # 获取使用次数最高的表情包(未缓存的优先)
emojis = ( emojis = (
Emoji.select() Emoji.select()
@@ -1210,41 +1220,36 @@ async def preheat_thumbnail_cache(
.order_by(Emoji.usage_count.desc()) .order_by(Emoji.usage_count.desc())
.limit(limit * 2) # 多查一些,因为有些可能已缓存 .limit(limit * 2) # 多查一些,因为有些可能已缓存
) )
generated = 0 generated = 0
skipped = 0 skipped = 0
failed = 0 failed = 0
for emoji in emojis: for emoji in emojis:
if generated >= limit: if generated >= limit:
break break
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash) cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
# 已缓存,跳过 # 已缓存,跳过
if cache_path.exists(): if cache_path.exists():
skipped += 1 skipped += 1
continue continue
# 原文件不存在,跳过 # 原文件不存在,跳过
if not os.path.exists(emoji.full_path): if not os.path.exists(emoji.full_path):
failed += 1 failed += 1
continue continue
try: try:
# 使用线程池异步生成缩略图,避免阻塞事件循环 # 使用线程池异步生成缩略图,避免阻塞事件循环
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor( await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
_thumbnail_executor,
_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
generated += 1 generated += 1
except Exception as e: except Exception as e:
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}") logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
failed += 1 failed += 1
return ThumbnailPreheatResponse( return ThumbnailPreheatResponse(
success=True, success=True,
message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed}", message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed}",
@@ -1252,7 +1257,7 @@ async def preheat_thumbnail_cache(
skipped_count=skipped, skipped_count=skipped,
failed_count=failed, failed_count=failed,
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -1267,13 +1272,13 @@ async def clear_all_thumbnail_cache(
): ):
""" """
清空所有缩略图缓存(下次访问时会重新生成) 清空所有缩略图缓存(下次访问时会重新生成)
Returns: Returns:
清理结果 清理结果
""" """
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
if not THUMBNAIL_CACHE_DIR.exists(): if not THUMBNAIL_CACHE_DIR.exists():
return ThumbnailCleanupResponse( return ThumbnailCleanupResponse(
success=True, success=True,
@@ -1281,7 +1286,7 @@ async def clear_all_thumbnail_cache(
cleaned_count=0, cleaned_count=0,
kept_count=0, kept_count=0,
) )
cleaned = 0 cleaned = 0
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"): for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
try: try:
@@ -1289,16 +1294,16 @@ async def clear_all_thumbnail_cache(
cleaned += 1 cleaned += 1
except Exception as e: except Exception as e:
logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}") logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}")
logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件") logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件")
return ThumbnailCleanupResponse( return ThumbnailCleanupResponse(
success=True, success=True,
message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件", message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件",
cleaned_count=cleaned, cleaned_count=cleaned,
kept_count=0, kept_count=0,
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:

View File

@@ -256,7 +256,9 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse) @router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
获取表达方式详细信息 获取表达方式详细信息
@@ -285,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
@router.post("/", response_model=ExpressionCreateResponse) @router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
创建新的表达方式 创建新的表达方式
@@ -326,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse) @router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression( async def update_expression(
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
): ):
""" """
增量更新表达方式(只更新提供的字段) 增量更新表达方式(只更新提供的字段)
@@ -376,7 +385,9 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse) @router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
删除表达方式 删除表达方式
@@ -419,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
@router.post("/batch/delete", response_model=ExpressionDeleteResponse) @router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
批量删除表达方式 批量删除表达方式
@@ -460,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
@router.get("/stats/summary") @router.get("/stats/summary")
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
获取表达方式统计数据 获取表达方式统计数据

View File

@@ -24,7 +24,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
""" """
if not chat_id_str: if not chat_id_str:
return [] return []
try: try:
# 尝试解析为 JSON # 尝试解析为 JSON
parsed = json.loads(chat_id_str) parsed = json.loads(chat_id_str)
@@ -49,10 +49,10 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称 尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
""" """
stream_ids = parse_chat_id_to_stream_ids(chat_id_str) stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids: if not stream_ids:
return chat_id_str return chat_id_str
# 查询所有 stream_id 对应的名称 # 查询所有 stream_id 对应的名称
names = [] names = []
for stream_id in stream_ids: for stream_id in stream_ids:
@@ -62,7 +62,7 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
else: else:
# 如果没找到,显示截断的 stream_id # 如果没找到,显示截断的 stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id) names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
return ", ".join(names) if names else chat_id_str return ", ".join(names) if names else chat_id_str
@@ -187,7 +187,7 @@ def jargon_to_dict(jargon: Jargon) -> dict:
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else [] stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None stream_id = stream_ids[0] if stream_ids else None
return { return {
"id": jargon.id, "id": jargon.id,
"content": jargon.content, "content": jargon.content,
@@ -277,17 +277,13 @@ async def get_chat_list():
"""获取所有有黑话记录的聊天列表""" """获取所有有黑话记录的聊天列表"""
try: try:
# 获取所有不同的 chat_id # 获取所有不同的 chat_id
chat_ids = ( chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
)
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id] chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
# 用于按 stream_id 去重 # 用于按 stream_id 去重
seen_stream_ids: set[str] = set() seen_stream_ids: set[str] = set()
for chat_id in chat_id_list: for chat_id in chat_id_list:
stream_ids = parse_chat_id_to_stream_ids(chat_id) stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids: if stream_ids:
@@ -346,12 +342,7 @@ async def get_jargon_stats():
complete_count = Jargon.select().where(Jargon.is_complete).count() complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量 # 关联的聊天数量
chat_count = ( chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
.count()
)
# 按聊天统计 TOP 5 # 按聊天统计 TOP 5
top_chats = ( top_chats = (
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
"""创建黑话""" """创建黑话"""
try: try:
# 检查是否已存在相同内容的黑话 # 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none( existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
)
if existing: if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话") raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
if not ids: if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空") raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = ( updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
Jargon.update(is_jargon=is_jargon)
.where(Jargon.id.in_(ids))
.execute()
)
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}") logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")

View File

@@ -200,7 +200,9 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse) @router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
获取人物详细信息 获取人物详细信息
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
@router.patch("/{person_id}", response_model=PersonUpdateResponse) @router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
增量更新人物信息(只更新提供的字段) 增量更新人物信息(只更新提供的字段)
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
@router.delete("/{person_id}", response_model=PersonDeleteResponse) @router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
删除人物信息 删除人物信息
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
@router.post("/batch/delete", response_model=BatchDeleteResponse) @router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
批量删除人物信息 批量删除人物信息

View File

@@ -125,6 +125,7 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
""" """
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str) 根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)
""" """
def _is_list_type(tp: Any) -> bool: def _is_list_type(tp: Any) -> bool:
origin = get_origin(tp) origin = get_origin(tp)
return tp is list or origin is list return tp is list or origin is list
@@ -313,7 +314,9 @@ async def check_git_status() -> GitStatusResponse:
@router.get("/mirrors", response_model=AvailableMirrorsResponse) @router.get("/mirrors", response_model=AvailableMirrorsResponse)
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse: async def get_available_mirrors(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> AvailableMirrorsResponse:
""" """
获取所有可用的镜像源配置 获取所有可用的镜像源配置
""" """
@@ -343,7 +346,9 @@ async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), au
@router.post("/mirrors", response_model=MirrorConfigResponse) @router.post("/mirrors", response_model=MirrorConfigResponse)
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse: async def add_mirror(
request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> MirrorConfigResponse:
""" """
添加新的镜像源 添加新的镜像源
""" """
@@ -383,7 +388,10 @@ async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] =
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse) @router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
async def update_mirror( async def update_mirror(
mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) mirror_id: str,
request: UpdateMirrorRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> MirrorConfigResponse: ) -> MirrorConfigResponse:
""" """
更新镜像源配置 更新镜像源配置
@@ -426,7 +434,9 @@ async def update_mirror(
@router.delete("/mirrors/{mirror_id}") @router.delete("/mirrors/{mirror_id}")
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def delete_mirror(
mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
删除镜像源 删除镜像源
""" """
@@ -449,7 +459,9 @@ async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(N
@router.post("/fetch-raw", response_model=FetchRawFileResponse) @router.post("/fetch-raw", response_model=FetchRawFileResponse)
async def fetch_raw_file( async def fetch_raw_file(
request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) request: FetchRawFileRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> FetchRawFileResponse: ) -> FetchRawFileResponse:
""" """
获取 GitHub 仓库的 Raw 文件内容 获取 GitHub 仓库的 Raw 文件内容
@@ -534,7 +546,9 @@ async def fetch_raw_file(
@router.post("/clone", response_model=CloneRepositoryResponse) @router.post("/clone", response_model=CloneRepositoryResponse)
async def clone_repository( async def clone_repository(
request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) request: CloneRepositoryRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> CloneRepositoryResponse: ) -> CloneRepositoryResponse:
""" """
克隆 GitHub 仓库到本地 克隆 GitHub 仓库到本地
@@ -574,7 +588,11 @@ async def clone_repository(
@router.post("/install") @router.post("/install")
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def install_plugin(
request: InstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
""" """
安装插件 安装插件
@@ -778,7 +796,9 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
@router.post("/uninstall") @router.post("/uninstall")
async def uninstall_plugin( async def uninstall_plugin(
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) request: UninstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
卸载插件 卸载插件
@@ -913,7 +933,11 @@ async def uninstall_plugin(
@router.post("/update") @router.post("/update")
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def update_plugin(
request: UpdatePluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
""" """
更新插件 更新插件
@@ -1132,7 +1156,9 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
@router.get("/installed") @router.get("/installed")
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def get_installed_plugins(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
获取已安装的插件列表 获取已安装的插件列表
@@ -1272,7 +1298,9 @@ class UpdatePluginConfigRequest(BaseModel):
@router.get("/config/{plugin_id}/schema") @router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def get_plugin_config_schema(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
获取插件配置 Schema 获取插件配置 Schema
@@ -1405,7 +1433,9 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
@router.get("/config/{plugin_id}") @router.get("/config/{plugin_id}")
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def get_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
获取插件当前配置值 获取插件当前配置值
@@ -1461,7 +1491,10 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
@router.put("/config/{plugin_id}") @router.put("/config/{plugin_id}")
async def update_plugin_config( async def update_plugin_config(
plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) plugin_id: str,
request: UpdatePluginConfigRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
更新插件配置 更新插件配置
@@ -1532,7 +1565,9 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset") @router.post("/config/{plugin_id}/reset")
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def reset_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
重置插件配置为默认值 重置插件配置为默认值
@@ -1592,7 +1627,9 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
@router.post("/config/{plugin_id}/toggle") @router.post("/config/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def toggle_plugin(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
切换插件启用状态 切换插件启用状态

View File

@@ -139,10 +139,10 @@ async def verify_token(request: TokenVerifyRequest, response: Response):
async def logout(response: Response): async def logout(response: Response):
""" """
登出并清除认证 Cookie 登出并清除认证 Cookie
Args: Args:
response: FastAPI Response 对象 response: FastAPI Response 对象
Returns: Returns:
登出结果 登出结果
""" """
@@ -158,23 +158,23 @@ async def check_auth_status(
): ):
""" """
检查当前认证状态(用于前端判断是否已登录) 检查当前认证状态(用于前端判断是否已登录)
Returns: Returns:
认证状态 认证状态
""" """
try: try:
token = None token = None
# 优先从 Cookie 获取 # 优先从 Cookie 获取
if maibot_session: if maibot_session:
token = maibot_session token = maibot_session
# 其次从 Header 获取 # 其次从 Header 获取
elif authorization and authorization.startswith("Bearer "): elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "") token = authorization.replace("Bearer ", "")
if not token: if not token:
return {"authenticated": False} return {"authenticated": False}
token_manager = get_token_manager() token_manager = get_token_manager()
if token_manager.verify_token(token): if token_manager.verify_token(token):
return {"authenticated": True} return {"authenticated": True}
@@ -211,7 +211,7 @@ async def update_token(
current_token = maibot_session current_token = maibot_session
elif authorization and authorization.startswith("Bearer "): elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "") current_token = authorization.replace("Bearer ", "")
if not current_token: if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息") raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -222,7 +222,7 @@ async def update_token(
# 更新 token # 更新 token
success, message = token_manager.update_token(request.new_token) success, message = token_manager.update_token(request.new_token)
# 如果更新成功,清除 Cookie,要求用户重新登录 # 如果更新成功,清除 Cookie,要求用户重新登录
if success: if success:
clear_auth_cookie(response) clear_auth_cookie(response)
@@ -263,7 +263,7 @@ async def regenerate_token(
if not current_token: if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息") raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager() token_manager = get_token_manager()
if not token_manager.verify_token(current_token): if not token_manager.verify_token(current_token):
@@ -271,7 +271,7 @@ async def regenerate_token(
# 重新生成 token # 重新生成 token
new_token = token_manager.regenerate_token() new_token = token_manager.regenerate_token()
# 清除 Cookie,要求用户重新登录 # 清除 Cookie,要求用户重新登录
clear_auth_cookie(response) clear_auth_cookie(response)
@@ -306,7 +306,7 @@ async def get_setup_status(
current_token = maibot_session current_token = maibot_session
elif authorization and authorization.startswith("Bearer "): elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "") current_token = authorization.replace("Bearer ", "")
if not current_token: if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息") raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -349,7 +349,7 @@ async def complete_setup(
current_token = maibot_session current_token = maibot_session
elif authorization and authorization.startswith("Bearer "): elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "") current_token = authorization.replace("Bearer ", "")
if not current_token: if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息") raise HTTPException(status_code=401, detail="未提供有效的认证信息")
@@ -392,7 +392,7 @@ async def reset_setup(
current_token = maibot_session current_token = maibot_session
elif authorization and authorization.startswith("Bearer "): elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "") current_token = authorization.replace("Bearer ", "")
if not current_token: if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息") raise HTTPException(status_code=401, detail="未提供有效的认证信息")

View File

@@ -166,22 +166,22 @@ class TokenManager:
str: 新生成的 token str: 新生成的 token
""" """
logger.info("正在重新生成 WebUI Token...") logger.info("正在重新生成 WebUI Token...")
# 生成新的 64 位十六进制字符串 # 生成新的 64 位十六进制字符串
new_token = secrets.token_hex(32) new_token = secrets.token_hex(32)
# 加载现有配置,保留 first_setup_completed 状态 # 加载现有配置,保留 first_setup_completed 状态
config = self._load_config() config = self._load_config()
old_token = config.get("access_token", "")[:8] if config.get("access_token") else "" old_token = config.get("access_token", "")[:8] if config.get("access_token") else ""
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True,表示已完成配置 first_setup_completed = config.get("first_setup_completed", True) # 默认为 True,表示已完成配置
config["access_token"] = new_token config["access_token"] = new_token
config["updated_at"] = self._get_current_timestamp() config["updated_at"] = self._get_current_timestamp()
config["first_setup_completed"] = first_setup_completed # 保留原来的状态 config["first_setup_completed"] = first_setup_completed # 保留原来的状态
self._save_config(config) self._save_config(config)
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...") logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
return new_token return new_token
def _validate_token_format(self, token: str) -> bool: def _validate_token_format(self, token: str) -> bool:

View File

@@ -110,8 +110,10 @@ class WebUIServer:
from src.webui.routes import router as webui_router from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router from src.webui.logs_ws import router as logs_router
from src.webui.knowledge_routes import router as knowledge_router from src.webui.knowledge_routes import router as knowledge_router
# 导入本地聊天室路由 # 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router from src.webui.chat_routes import router as chat_router
# 注册路由 # 注册路由
self.app.include_router(webui_router) self.app.include_router(webui_router)
self.app.include_router(logs_router) self.app.include_router(logs_router)
@@ -166,6 +168,7 @@ class WebUIServer:
def _check_port_available(self) -> bool: def _check_port_available(self) -> bool:
"""检查端口是否可用""" """检查端口是否可用"""
import socket import socket
try: try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1) s.settimeout(1)