Ruff format

This commit is contained in:
墨梓柒
2025-12-13 17:14:09 +08:00
parent ef377bb0cd
commit e680a4d1f5
60 changed files with 1546 additions and 1532 deletions

4
bot.py
View File

@@ -41,6 +41,7 @@ logger = get_logger("main")
# 定义重启退出码
RESTART_EXIT_CODE = 42
def run_runner_process():
"""
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
@@ -68,7 +69,7 @@ def run_runner_process():
if return_code == RESTART_EXIT_CODE:
logger.info("检测到重启请求 (退出码 42),正在重启...")
time.sleep(1) # 稍作等待
time.sleep(1) # 稍作等待
continue
else:
logger.info(f"程序已退出 (退出码 {return_code})")
@@ -87,6 +88,7 @@ def run_runner_process():
process.kill()
sys.exit(0)
# 检查是否是 Worker 进程
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
# 此时应该作为 Runner 运行。

View File

@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Optional, Tuple
@dataclass
class ConversionResult:
"""转换结果"""
success: bool
servers: List[Dict[str, Any]] = field(default_factory=list)
errors: List[str] = field(default_factory=list)
@@ -271,11 +272,7 @@ class ConfigConverter:
return name, result
@classmethod
def from_claude_format(
cls,
config: Dict[str, Any],
existing_names: Optional[set] = None
) -> ConversionResult:
def from_claude_format(cls, config: Dict[str, Any], existing_names: Optional[set] = None) -> ConversionResult:
"""从 Claude Desktop 格式转换为 MaiBot 格式
Args:
@@ -355,11 +352,7 @@ class ConfigConverter:
return {"mcpServers": mcp_servers}
@classmethod
def import_from_string(
cls,
json_str: str,
existing_names: Optional[set] = None
) -> ConversionResult:
def import_from_string(cls, json_str: str, existing_names: Optional[set] = None) -> ConversionResult:
"""从 JSON 字符串导入配置
自动检测格式并转换为 MaiBot 格式
@@ -422,12 +415,7 @@ class ConfigConverter:
return result
@classmethod
def export_to_string(
cls,
servers: List[Dict[str, Any]],
format_type: str = "claude",
pretty: bool = True
) -> str:
def export_to_string(cls, servers: List[Dict[str, Any]], format_type: str = "claude", pretty: bool = True) -> str:
"""导出配置为 JSON 字符串
Args:

View File

@@ -34,28 +34,31 @@ from enum import Enum
# 尝试导入 MaiBot 的 logger如果失败则使用标准 logging
try:
from src.common.logger import get_logger
logger = get_logger("mcp_client")
except ImportError:
# Fallback: 使用标准 logging
logger = logging.getLogger("mcp_client")
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s'))
handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)
class TransportType(Enum):
"""MCP 传输类型"""
STDIO = "stdio" # 本地进程通信
SSE = "sse" # Server-Sent Events (旧版 HTTP)
HTTP = "http" # HTTP Streamable (新版,推荐)
STDIO = "stdio" # 本地进程通信
SSE = "sse" # Server-Sent Events (旧版 HTTP)
HTTP = "http" # HTTP Streamable (新版,推荐)
STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名
@dataclass
class MCPToolInfo:
"""MCP 工具信息"""
name: str
description: str
input_schema: Dict[str, Any]
@@ -65,6 +68,7 @@ class MCPToolInfo:
@dataclass
class MCPResourceInfo:
"""MCP 资源信息"""
uri: str
name: str
description: str
@@ -75,6 +79,7 @@ class MCPResourceInfo:
@dataclass
class MCPPromptInfo:
"""MCP 提示模板信息"""
name: str
description: str
arguments: List[Dict[str, Any]] # [{name, description, required}]
@@ -84,6 +89,7 @@ class MCPPromptInfo:
@dataclass
class MCPServerConfig:
"""MCP 服务器配置"""
name: str
enabled: bool = True
transport: TransportType = TransportType.STDIO
@@ -99,6 +105,7 @@ class MCPServerConfig:
@dataclass
class MCPCallResult:
"""MCP 工具调用结果"""
success: bool
content: Any
error: Optional[str] = None
@@ -108,8 +115,9 @@ class MCPCallResult:
class CircuitState(Enum):
"""断路器状态"""
CLOSED = "closed" # 正常状态,允许请求
OPEN = "open" # 熔断状态,拒绝请求
CLOSED = "closed" # 正常状态,允许请求
OPEN = "open" # 熔断状态,拒绝请求
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
@@ -125,9 +133,9 @@ class CircuitBreaker:
"""
# 配置
failure_threshold: int = 5 # 连续失败多少次后熔断
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
failure_threshold: int = 5 # 连续失败多少次后熔断
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
# 状态
state: CircuitState = field(default=CircuitState.CLOSED)
@@ -232,6 +240,7 @@ class CircuitBreaker:
@dataclass
class ToolCallStats:
"""工具调用统计"""
tool_key: str
total_calls: int = 0
success_calls: int = 0
@@ -282,6 +291,7 @@ class ToolCallStats:
@dataclass
class ServerStats:
"""服务器统计"""
server_name: str
connect_count: int = 0 # 连接次数
disconnect_count: int = 0 # 断开次数
@@ -442,9 +452,7 @@ class MCPClientSession:
return False
server_params = StdioServerParameters(
command=self.config.command,
args=self.config.args,
env=self.config.env if self.config.env else None
command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None
)
self._stdio_context = stdio_client(server_params)
@@ -506,6 +514,7 @@ class MCPClientSession:
except Exception as e:
logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
import traceback
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
await self._cleanup()
return False
@@ -551,6 +560,7 @@ class MCPClientSession:
except Exception as e:
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
import traceback
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
await self._cleanup()
return False
@@ -568,8 +578,8 @@ class MCPClientSession:
tool_info = MCPToolInfo(
name=tool.name,
description=tool.description or f"MCP tool: {tool.name}",
input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {},
server_name=self.server_name
input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
server_name=self.server_name,
)
self._tools.append(tool_info)
# 初始化工具统计
@@ -591,10 +601,7 @@ class MCPClientSession:
return False
try:
result = await asyncio.wait_for(
self._session.list_resources(),
timeout=self.call_timeout
)
result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
self._resources = []
for resource in result.resources:
@@ -602,8 +609,8 @@ class MCPClientSession:
uri=str(resource.uri),
name=resource.name or str(resource.uri),
description=resource.description or "",
mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None,
server_name=self.server_name
mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
server_name=self.server_name,
)
self._resources.append(resource_info)
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
@@ -633,28 +640,27 @@ class MCPClientSession:
return False
try:
result = await asyncio.wait_for(
self._session.list_prompts(),
timeout=self.call_timeout
)
result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
self._prompts = []
for prompt in result.prompts:
# 解析参数
arguments = []
if hasattr(prompt, 'arguments') and prompt.arguments:
if hasattr(prompt, "arguments") and prompt.arguments:
for arg in prompt.arguments:
arguments.append({
"name": arg.name,
"description": arg.description or "",
"required": arg.required if hasattr(arg, 'required') else False,
})
arguments.append(
{
"name": arg.name,
"description": arg.description or "",
"required": arg.required if hasattr(arg, "required") else False,
}
)
prompt_info = MCPPromptInfo(
name=prompt.name,
description=prompt.description or f"MCP prompt: {prompt.name}",
arguments=arguments,
server_name=self.server_name
server_name=self.server_name,
)
self._prompts.append(prompt_info)
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
@@ -686,35 +692,25 @@ class MCPClientSession:
start_time = time.time()
if not self._connected or not self._session:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {self.server_name} 未连接"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
if not self._supports_resources:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {self.server_name} 不支持 Resources 功能"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")
try:
result = await asyncio.wait_for(
self._session.read_resource(uri),
timeout=self.call_timeout
)
result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
duration_ms = (time.time() - start_time) * 1000
# 处理返回内容
content_parts = []
for content in result.contents:
if hasattr(content, 'text'):
if hasattr(content, "text"):
content_parts.append(content.text)
elif hasattr(content, 'blob'):
elif hasattr(content, "blob"):
# 二进制数据,返回 base64 或提示
import base64
blob_data = content.blob
if len(blob_data) < 10000: # 小于 10KB 返回 base64
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
@@ -724,28 +720,18 @@ class MCPClientSession:
content_parts.append(str(content))
return MCPCallResult(
success=True,
content="\n".join(content_parts) if content_parts else "",
duration_ms=duration_ms
success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms
)
except asyncio.TimeoutError:
duration_ms = (time.time() - start_time) * 1000
return MCPCallResult(
success=False,
content=None,
error=f"读取资源超时({self.call_timeout}秒)",
duration_ms=duration_ms
success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
)
except Exception as e:
duration_ms = (time.time() - start_time) * 1000
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
return MCPCallResult(
success=False,
content=None,
error=str(e),
duration_ms=duration_ms
)
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
"""v1.2.0: 获取提示模板的内容
@@ -760,23 +746,14 @@ class MCPClientSession:
start_time = time.time()
if not self._connected or not self._session:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {self.server_name} 未连接"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
if not self._supports_prompts:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {self.server_name} 不支持 Prompts 功能"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")
try:
result = await asyncio.wait_for(
self._session.get_prompt(name, arguments=arguments or {}),
timeout=self.call_timeout
self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
)
duration_ms = (time.time() - start_time) * 1000
@@ -784,10 +761,10 @@ class MCPClientSession:
# 处理返回的消息
messages = []
for msg in result.messages:
role = msg.role if hasattr(msg, 'role') else "unknown"
role = msg.role if hasattr(msg, "role") else "unknown"
content_text = ""
if hasattr(msg, 'content'):
if hasattr(msg.content, 'text'):
if hasattr(msg, "content"):
if hasattr(msg.content, "text"):
content_text = msg.content.text
elif isinstance(msg.content, str):
content_text = msg.content
@@ -796,28 +773,18 @@ class MCPClientSession:
messages.append(f"[{role}]: {content_text}")
return MCPCallResult(
success=True,
content="\n\n".join(messages) if messages else "",
duration_ms=duration_ms
success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms
)
except asyncio.TimeoutError:
duration_ms = (time.time() - start_time) * 1000
return MCPCallResult(
success=False,
content=None,
error=f"获取提示模板超时({self.call_timeout}秒)",
duration_ms=duration_ms
success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
)
except Exception as e:
duration_ms = (time.time() - start_time) * 1000
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
return MCPCallResult(
success=False,
content=None,
error=str(e),
duration_ms=duration_ms
)
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
async def check_health(self) -> bool:
"""检查连接健康状态(心跳检测)
@@ -829,10 +796,7 @@ class MCPClientSession:
try:
# 使用 list_tools 作为心跳检测
await asyncio.wait_for(
self._session.list_tools(),
timeout=10.0
)
await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
self.stats.record_heartbeat()
return True
except Exception as e:
@@ -849,12 +813,7 @@ class MCPClientSession:
# v1.7.0: 断路器检查
can_execute, reject_reason = self._circuit_breaker.can_execute()
if not can_execute:
return MCPCallResult(
success=False,
content=None,
error=f"{reject_reason}",
circuit_broken=True
)
return MCPCallResult(success=False, content=None, error=f"{reject_reason}", circuit_broken=True)
# 半开状态下增加试探计数
if self._circuit_breaker.state == CircuitState.HALF_OPEN:
@@ -870,8 +829,7 @@ class MCPClientSession:
try:
result = await asyncio.wait_for(
self._session.call_tool(tool_name, arguments=arguments),
timeout=self.call_timeout
self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
)
duration_ms = (time.time() - start_time) * 1000
@@ -879,9 +837,9 @@ class MCPClientSession:
# 处理返回内容
content_parts = []
for content in result.content:
if hasattr(content, 'text'):
if hasattr(content, "text"):
content_parts.append(content.text)
elif hasattr(content, 'data'):
elif hasattr(content, "data"):
content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
else:
content_parts.append(str(content))
@@ -896,7 +854,7 @@ class MCPClientSession:
return MCPCallResult(
success=True,
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
duration_ms=duration_ms
duration_ms=duration_ms,
)
except asyncio.TimeoutError:
@@ -939,25 +897,25 @@ class MCPClientSession:
self._supports_prompts = False # v1.2.0
try:
if hasattr(self, '_session_context') and self._session_context:
if hasattr(self, "_session_context") and self._session_context:
await self._session_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
try:
if hasattr(self, '_stdio_context') and self._stdio_context:
if hasattr(self, "_stdio_context") and self._stdio_context:
await self._stdio_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
try:
if hasattr(self, '_http_context') and self._http_context:
if hasattr(self, "_http_context") and self._http_context:
await self._http_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
try:
if hasattr(self, '_sse_context') and self._sse_context:
if hasattr(self, "_sse_context") and self._sse_context:
await self._sse_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
@@ -1082,7 +1040,9 @@ class MCPClientManager:
return True
if attempt < retry_attempts:
logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})")
logger.warning(
f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
)
await asyncio.sleep(retry_interval)
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
@@ -1213,11 +1173,7 @@ class MCPClientManager:
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
"""调用 MCP 工具"""
if tool_key not in self._all_tools:
return MCPCallResult(
success=False,
content=None,
error=f"工具 {tool_key} 不存在"
)
return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")
tool_info, client = self._all_tools[tool_key]
@@ -1273,11 +1229,7 @@ class MCPClientManager:
# 如果指定了服务器
if server_name:
if server_name not in self._clients:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {server_name} 不存在"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
client = self._clients[server_name]
return await client.read_resource(uri)
@@ -1293,14 +1245,11 @@ class MCPClientManager:
if result.success:
return result
return MCPCallResult(
success=False,
content=None,
error=f"未找到资源: {uri}"
)
return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None,
server_name: Optional[str] = None) -> MCPCallResult:
async def get_prompt(
self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
) -> MCPCallResult:
"""v1.2.0: 获取提示模板内容
Args:
@@ -1311,11 +1260,7 @@ class MCPClientManager:
# 如果指定了服务器
if server_name:
if server_name not in self._clients:
return MCPCallResult(
success=False,
content=None,
error=f"服务器 {server_name} 不存在"
)
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
client = self._clients[server_name]
return await client.get_prompt(name, arguments)
@@ -1324,11 +1269,7 @@ class MCPClientManager:
if prompt_info.name == name:
return await client.get_prompt(name, arguments)
return MCPCallResult(
success=False,
content=None,
error=f"未找到提示模板: {name}"
)
return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
# ==================== 心跳检测 ====================
@@ -1489,7 +1430,9 @@ class MCPClientManager:
"global": {
**self._global_stats,
"uptime_seconds": round(uptime, 2),
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0,
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
if uptime > 0
else 0,
},
"servers": server_stats,
"tools": tool_stats,

View File

@@ -84,7 +84,7 @@ from .mcp_client import (
TransportType,
mcp_manager,
)
from .config_converter import ConfigConverter, ConversionResult
from .config_converter import ConfigConverter
logger = get_logger("mcp_bridge_plugin")
@@ -93,9 +93,11 @@ logger = get_logger("mcp_bridge_plugin")
# v1.4.0: 调用链路追踪
# ============================================================================
@dataclass
class ToolCallRecord:
"""工具调用记录"""
call_id: str
timestamp: float
tool_name: str
@@ -178,9 +180,11 @@ tool_call_tracer = ToolCallTracer()
# v1.4.0: 工具调用缓存
# ============================================================================
@dataclass
class CacheEntry:
"""缓存条目"""
tool_name: str
args_hash: str
result: str
@@ -317,6 +321,7 @@ tool_call_cache = ToolCallCache()
# v1.4.0: 工具权限控制
# ============================================================================
class PermissionChecker:
"""工具权限检查器"""
@@ -449,6 +454,7 @@ permission_checker = PermissionChecker()
# 工具类型转换
# ============================================================================
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
type_mapping = {
@@ -462,7 +468,9 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
return type_mapping.get(json_type, ToolParamType.STRING)
def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
def parse_mcp_parameters(
input_schema: Dict[str, Any],
) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
"""解析 MCP 工具的参数 schema转换为 MaiBot 的参数格式"""
parameters = []
@@ -497,6 +505,7 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa
# MCP 工具代理
# ============================================================================
class MCPToolProxy(BaseTool):
"""MCP 工具代理基类"""
@@ -539,10 +548,7 @@ class MCPToolProxy(BaseTool):
# v1.4.0: 权限检查
if not permission_checker.check(self.name, chat_id, user_id, is_group):
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
return {
"name": self.name,
"content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"
}
return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"}
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
@@ -726,11 +732,7 @@ class MCPToolProxy(BaseTool):
return None
async def _call_post_process_llm(
self,
prompt: str,
max_tokens: int,
settings: Dict[str, Any],
server_config: Optional[Dict[str, Any]]
self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]]
) -> Optional[str]:
"""调用 LLM 进行后处理"""
from src.config.config import model_config
@@ -788,10 +790,7 @@ class MCPToolProxy(BaseTool):
def create_mcp_tool_class(
tool_key: str,
tool_info: MCPToolInfo,
tool_prefix: str,
disabled: bool = False
tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
) -> Type[MCPToolProxy]:
"""根据 MCP 工具信息动态创建 BaseTool 子类"""
parameters = parse_mcp_parameters(tool_info.input_schema)
@@ -814,7 +813,7 @@ def create_mcp_tool_class(
"_mcp_tool_key": tool_key,
"_mcp_original_name": tool_info.name,
"_mcp_server_name": tool_info.server_name,
}
},
)
return tool_class
@@ -828,11 +827,7 @@ class MCPToolRegistry:
self._tool_infos: Dict[str, ToolInfo] = {}
def register_tool(
self,
tool_key: str,
tool_info: MCPToolInfo,
tool_prefix: str,
disabled: bool = False
self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
"""注册 MCP 工具"""
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
@@ -879,6 +874,7 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None
# 内置工具
# ============================================================================
class MCPReadResourceTool(BaseTool):
"""v1.2.0: MCP 资源读取工具"""
@@ -950,9 +946,17 @@ class MCPStatusTool(BaseTool):
"""MCP 状态查询工具"""
name = "mcp_status"
description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
description = (
"查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
)
parameters = [
("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"]),
(
"query_type",
ToolParamType.STRING,
"查询类型",
False,
["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"],
),
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
]
available_for_llm = True
@@ -986,10 +990,7 @@ class MCPStatusTool(BaseTool):
if query_type in ("cache",):
result_parts.append(self._format_cache())
return {
"name": self.name,
"content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"
}
return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"}
def _format_status(self, server_name: Optional[str] = None) -> str:
status = mcp_manager.get_status()
@@ -1001,14 +1002,14 @@ class MCPStatusTool(BaseTool):
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
lines.append("\n🔌 服务器详情:")
for name, info in status['servers'].items():
for name, info in status["servers"].items():
if server_name and name != server_name:
continue
status_icon = "" if info['connected'] else ""
enabled_text = "" if info['enabled'] else " (已禁用)"
status_icon = "" if info["connected"] else ""
enabled_text = "" if info["enabled"] else " (已禁用)"
lines.append(f" {status_icon} {name}{enabled_text}")
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
if info['consecutive_failures'] > 0:
if info["consecutive_failures"] > 0:
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']}")
return "\n".join(lines)
@@ -1038,11 +1039,11 @@ class MCPStatusTool(BaseTool):
stats = mcp_manager.get_all_stats()
lines = ["📈 调用统计"]
g = stats['global']
g = stats["global"]
lines.append(f" 总调用次数: {g['total_tool_calls']}")
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
if g['total_tool_calls'] > 0:
success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100
if g["total_tool_calls"] > 0:
success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100
lines.append(f" 成功率: {success_rate:.1f}%")
lines.append(f" 运行时间: {g['uptime_seconds']:.0f}")
@@ -1126,6 +1127,7 @@ class MCPStatusTool(BaseTool):
# 命令处理
# ============================================================================
class MCPStatusCommand(BaseCommand):
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
@@ -1340,7 +1342,9 @@ class MCPStatusCommand(BaseCommand):
try:
exported = ConfigConverter.export_to_string(servers, format_type, pretty=True)
format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get(format_type, format_type)
format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get(
format_type, format_type
)
lines = [f"📤 导出为 {format_name} 格式 ({len(servers)} 个服务器):"]
lines.append("")
lines.append(exported)
@@ -1446,9 +1450,9 @@ class MCPStatusCommand(BaseCommand):
cb = info.get("circuit_breaker", {})
cb_state = cb.get("state", "closed")
if cb_state == "open":
lines.append(f" ⚡ 断路器熔断中")
lines.append(" ⚡ 断路器熔断中")
elif cb_state == "half_open":
lines.append(f" ⚡ 断路器试探中")
lines.append(" ⚡ 断路器试探中")
if info["consecutive_failures"] > 0:
lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']}")
@@ -1634,6 +1638,7 @@ class MCPImportCommand(BaseCommand):
# 事件处理器
# ============================================================================
class MCPStartupHandler(BaseEventHandler):
"""MCP 启动事件处理器"""
@@ -1692,6 +1697,7 @@ class MCPStopHandler(BaseEventHandler):
# 主插件类
# ============================================================================
@register_plugin
class MCPBridgePlugin(BasePlugin):
"""MCP 桥接插件 v1.4.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
@@ -2116,9 +2122,9 @@ class MCPBridgePlugin(BasePlugin):
label="📜 高级权限规则(可选)",
input_type="textarea",
rows=10,
placeholder='''[
placeholder="""[
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
]''',
]""",
hint="格式: qq:ID:group/private/user工具名支持通配符 *",
order=10,
),
@@ -2261,7 +2267,9 @@ class MCPBridgePlugin(BasePlugin):
value = match1.group(2)
suffix = match1.group(3)
# 将转义的换行符还原为实际换行符
unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
unescaped = (
value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
)
fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
fixed_lines.append(fixed_line)
modified = True
@@ -2560,7 +2568,7 @@ class MCPBridgePlugin(BasePlugin):
continue
import_config_str = str(import_config).strip()
logger.info(f"检测到 WebUI 导入请求,开始处理...")
logger.info("检测到 WebUI 导入请求,开始处理...")
# 执行导入
await self._execute_webui_import(import_config_str, doc, config_path)
@@ -2773,6 +2781,7 @@ class MCPBridgePlugin(BasePlugin):
async def _async_connect_servers(self) -> None:
"""异步连接所有配置的 MCP 服务器v1.5.0: 并行连接优化)"""
import asyncio
settings = self.config.get("settings", {})
servers_section = self.config.get("servers", [])
@@ -2854,10 +2863,7 @@ class MCPBridgePlugin(BasePlugin):
# 并行执行所有连接
start_time = time.time()
results = await asyncio.gather(
*[connect_single_server(cfg) for cfg in enabled_configs],
return_exceptions=True
)
results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True)
connect_duration = time.time() - start_time
# 统计连接结果
@@ -2878,15 +2884,14 @@ class MCPBridgePlugin(BasePlugin):
# 注册所有工具
from src.plugin_system.core.component_registry import component_registry
registered_count = 0
for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
tool_name = tool_key.replace("-", "_").replace(".", "_")
is_disabled = tool_name in disabled_tools
info, tool_class = mcp_tool_registry.register_tool(
tool_key, tool_info, tool_prefix, disabled=is_disabled
)
info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled)
info.plugin_name = self.plugin_name
if component_registry.register_component(info, tool_class):
@@ -3004,6 +3009,7 @@ class MCPBridgePlugin(BasePlugin):
doc["tools"] = tomlkit.table()
# 使用 tomlkit 多行字符串避免控制字符问题
from tomlkit.items import String, StringType, Trivia
ml_string = String(StringType.MLB, tool_list_text, tool_list_text, Trivia())
doc["tools"]["tool_list"] = ml_string
@@ -3069,6 +3075,7 @@ class MCPBridgePlugin(BasePlugin):
doc["status"] = tomlkit.table()
# 使用 tomlkit 多行字符串避免控制字符问题
from tomlkit.items import String, StringType, Trivia
ml_string = String(StringType.MLB, status_text, status_text, Trivia())
doc["status"]["connection_status"] = ml_string

View File

@@ -33,7 +33,7 @@ async def test_stats():
assert stats.total_calls == 3
assert stats.success_calls == 2
assert stats.failed_calls == 1
assert stats.success_rate == (2/3) * 100
assert stats.success_rate == (2 / 3) * 100
assert stats.avg_duration_ms == 150.0
assert stats.last_error == "timeout"
@@ -66,13 +66,15 @@ async def test_manager_basic():
manager.__init__()
# 配置
manager.configure({
"tool_prefix": "mcp",
"call_timeout": 30.0,
"retry_attempts": 1,
"retry_interval": 1.0,
"heartbeat_enabled": False,
})
manager.configure(
{
"tool_prefix": "mcp",
"call_timeout": 30.0,
"retry_attempts": 1,
"retry_interval": 1.0,
"heartbeat_enabled": False,
}
)
# 测试状态
status = manager.get_status()
@@ -82,10 +84,7 @@ async def test_manager_basic():
# 测试添加禁用的服务器
config = MCPServerConfig(
name="disabled_server",
enabled=False,
transport=TransportType.HTTP,
url="https://example.com/mcp"
name="disabled_server", enabled=False, transport=TransportType.HTTP, url="https://example.com/mcp"
)
result = await manager.add_server(config)
assert result == True
@@ -120,27 +119,29 @@ async def test_http_connection():
manager._initialized = False
manager.__init__()
manager.configure({
"tool_prefix": "mcp",
"call_timeout": 30.0,
"retry_attempts": 2,
"retry_interval": 2.0,
"heartbeat_enabled": False,
})
manager.configure(
{
"tool_prefix": "mcp",
"call_timeout": 30.0,
"retry_attempts": 2,
"retry_interval": 2.0,
"heartbeat_enabled": False,
}
)
# 使用 HowToCook MCP 服务器测试
config = MCPServerConfig(
name="howtocook",
enabled=True,
transport=TransportType.HTTP,
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp"
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
)
print(f"正在连接 {config.url} ...")
result = await manager.add_server(config)
if result:
print(f"✅ 连接成功!")
print("✅ 连接成功!")
# 检查工具
tools = manager.all_tools
@@ -159,19 +160,23 @@ async def test_http_connection():
call_result = await manager.call_tool("mcp_howtocook_whatToEat", {})
if call_result.success:
print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)")
print(f" 结果: {call_result.content[:200]}..." if len(str(call_result.content)) > 200 else f" 结果: {call_result.content}")
print(
f" 结果: {call_result.content[:200]}..."
if len(str(call_result.content)) > 200
else f" 结果: {call_result.content}"
)
else:
print(f"❌ 工具调用失败: {call_result.error}")
# 查看统计
stats = manager.get_all_stats()
print(f"\n📊 统计信息:")
print("\n📊 统计信息:")
print(f" 全局调用: {stats['global']['total_tool_calls']}")
print(f" 成功: {stats['global']['successful_calls']}")
print(f" 失败: {stats['global']['failed_calls']}")
else:
print(f"❌ 连接失败")
print("❌ 连接失败")
# 清理
await manager.shutdown()
@@ -187,23 +192,25 @@ async def test_heartbeat():
manager._initialized = False
manager.__init__()
manager.configure({
"tool_prefix": "mcp",
"call_timeout": 30.0,
"retry_attempts": 1,
"retry_interval": 1.0,
"heartbeat_enabled": True,
"heartbeat_interval": 5.0, # 5秒间隔用于测试
"auto_reconnect": True,
"max_reconnect_attempts": 2,
})
manager.configure(
{
"tool_prefix": "mcp",
"call_timeout": 30.0,
"retry_attempts": 1,
"retry_interval": 1.0,
"heartbeat_enabled": True,
"heartbeat_interval": 5.0, # 5秒间隔用于测试
"auto_reconnect": True,
"max_reconnect_attempts": 2,
}
)
# 添加一个测试服务器
config = MCPServerConfig(
name="heartbeat_test",
enabled=True,
transport=TransportType.HTTP,
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp"
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
)
print("正在连接服务器...")
@@ -260,6 +267,7 @@ async def main():
except Exception as e:
print(f"\n❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return False

View File

@@ -301,4 +301,3 @@ def main():
if __name__ == "__main__":
main()

View File

@@ -13,7 +13,12 @@ from src.chat.utils.chat_message_builder import (
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.bw_learner.learner_utils import filter_message_content, is_bot_message, build_context_paragraph, contains_bot_self_name
from src.bw_learner.learner_utils import (
filter_message_content,
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
)
from src.bw_learner.jargon_miner import miner_manager
from json_repair import repair_json
@@ -77,8 +82,6 @@ def init_prompt() -> None:
Prompt(learn_style_prompt, "learn_style_prompt")
class ExpressionLearner:
def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest(
@@ -186,7 +189,6 @@ class ExpressionLearner:
filtered_expressions.append((situation, style, context))
learnt_expressions = filtered_expressions
if learnt_expressions is None:
@@ -270,6 +272,7 @@ class ExpressionLearner:
# 如果解析失败,尝试修复中文引号问题
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
try:
def fix_chinese_quotes_in_json(text):
"""使用状态机修复 JSON 字符串值中的中文引号"""
result = []
@@ -287,7 +290,7 @@ class ExpressionLearner:
i += 1
continue
if char == '\\':
if char == "\\":
# 转义字符
result.append(char)
escape_next = True
@@ -315,7 +318,7 @@ class ExpressionLearner:
i += 1
return ''.join(result)
return "".join(result)
fixed_raw = fix_chinese_quotes_in_json(raw)

View File

@@ -82,9 +82,7 @@ class ExpressionReflector:
# 获取未检查的表达
try:
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
expressions = (
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
)
expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
expr_list = list(expressions)
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")

View File

@@ -128,9 +128,7 @@ class ExpressionSelector:
# 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids))
& (~Expression.rejected)
& (Expression.count > 1)
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
)
style_exprs = [
@@ -150,12 +148,15 @@ class ExpressionSelector:
# 要求至少有10个 count > 1 的表达方式才进行选择
min_required = 10
if len(style_exprs) < min_required:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择")
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择"
)
return [], []
# 固定选择5个
select_count = 5
import random
selected_style = random.sample(style_exprs, select_count)
# 更新last_active_time
@@ -163,7 +164,9 @@ class ExpressionSelector:
self.update_expressions_last_active_time(selected_style)
selected_ids = [expr["id"] for expr in selected_style]
logger.debug(f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}")
logger.debug(
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}"
)
return selected_style, selected_ids
except Exception as e:
@@ -186,9 +189,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
style_exprs = [
{
@@ -246,7 +247,9 @@ class ExpressionSelector:
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason, think_level)
return await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
async def _select_expressions_classic(
self,
@@ -279,9 +282,7 @@ class ExpressionSelector:
# think_level == 1: 先选高count再从所有表达方式中随机抽样
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
related_chat_ids = self.get_related_chat_ids(chat_id)
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
all_style_exprs = [
{
@@ -308,11 +309,15 @@ class ExpressionSelector:
# 检查数量要求
if len(high_count_exprs) < min_high_count:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择")
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择"
)
return [], []
if len(all_style_exprs) < min_total_count:
logger.info(f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择")
logger.info(
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
)
return [], []
# 先选取高count的表达方式
@@ -332,6 +337,7 @@ class ExpressionSelector:
# 打乱顺序避免高count的都在前面
import random
random.shuffle(candidate_exprs)
# 2. 构建所有表达方式的索引和情境列表
@@ -351,7 +357,7 @@ class ExpressionSelector:
all_situations_str = "\n".join(all_situations)
if target_message:
target_message_str = f",现在你想要对这条消息进行回复:\"{target_message}\""
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
target_message_extra_block = "4.考虑你要回复的目标消息"
else:
target_message_str = ""

View File

@@ -8,7 +8,12 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.jargon_miner import search_jargon
from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
from src.bw_learner.learner_utils import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,
chat_id_list_contains,
)
logger = get_logger("jargon")

View File

@@ -1,4 +1,3 @@
import time
import json
import asyncio
import random
@@ -14,7 +13,6 @@ from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_by_timestamp_with_chat_inclusive,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.learner_utils import (
@@ -46,10 +44,10 @@ def _is_single_char_jargon(content: str) -> bool:
char = content[0]
# 判断是否是单个汉字、单个英文字母或单个数字
return (
'\u4e00' <= char <= '\u9fff' or # 汉字
'a' <= char <= 'z' or # 小写字母
'A' <= char <= 'Z' or # 大写字母
'0' <= char <= '9' # 数字
"\u4e00" <= char <= "\u9fff" # 汉字
or "a" <= char <= "z" # 小写字母
or "A" <= char <= "Z" # 大写字母
or "0" <= char <= "9" # 数字
)
@@ -305,7 +303,9 @@ class JargonMiner:
# 计算要保留的数量至少保留1个
keep_count = max(1, len(raw_content_list) // 2)
raw_content_list = random.sample(raw_content_list, keep_count)
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目")
logger.info(
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
)
# 步骤1: 基于raw_content和content推断
raw_content_text = "\n".join(raw_content_list)
@@ -318,7 +318,9 @@ class JargonMiner:
**上一次推断的含义(仅供参考)**
{previous_meaning}
"""
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
previous_meaning_instruction = (
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
)
prompt1 = await global_prompt_manager.format_prompt(
"jargon_inference_with_context_prompt",
@@ -650,7 +652,9 @@ class JargonMiner:
if obj.raw_content:
try:
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
json.loads(obj.raw_content)
if isinstance(obj.raw_content, str)
else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
@@ -876,8 +880,6 @@ class JargonMinerManager:
miner_manager = JargonMinerManager()
def search_jargon(
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]:

View File

@@ -116,7 +116,6 @@ class MessageRecorder:
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
)
# 分别触发 expression_learner 和 jargon_miner 的处理
# 传递提取的消息,避免它们重复获取
# 触发 expression 学习(如果启用)
@@ -127,21 +126,19 @@ class MessageRecorder:
# 触发 jargon 提取(如果启用),传递消息
# if self.enable_jargon_learning:
# asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# )
# asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# )
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
import traceback
traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning(
self,
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
) -> None:
"""
触发 expression 学习,使用指定的消息列表
@@ -162,13 +159,11 @@ class MessageRecorder:
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
import traceback
traceback.print_exc()
async def _trigger_jargon_extraction(
self,
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
) -> None:
"""
触发 jargon 提取,使用指定的消息列表
@@ -185,6 +180,7 @@ class MessageRecorder:
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
import traceback
traceback.print_exc()
@@ -214,4 +210,3 @@ async def extract_and_distribute_messages(chat_id: str) -> None:
"""
recorder = recorder_manager.get_recorder(chat_id)
await recorder.extract_and_distribute()

View File

@@ -328,9 +328,7 @@ class BrainChatting:
)
# 检查是否有 complete_talk 动作(会停止后续迭代)
has_complete_talk = any(
action.action_type == "complete_talk" for action in action_to_use_info
)
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
# 并行执行所有动作
action_tasks = [

View File

@@ -204,7 +204,9 @@ class BrainPlanner:
# 注意listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}")
logger.debug(
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
)
# 将 listening 转换为 wait向后兼容
if action == "listening":
@@ -521,7 +523,7 @@ class BrainPlanner:
if json_objects:
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
for i, json_obj in enumerate(json_objects):
logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}")
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects:
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
@@ -553,7 +555,9 @@ class BrainPlanner:
return extracted_reasoning, actions
def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
def _create_complete_talk(
self, reasoning: str, available_actions: Dict[str, ActionInfo]
) -> List[ActionPlannerInfo]:
"""创建complete_talk"""
return [
ActionPlannerInfo(

View File

@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji.description = emoji_data.description
# Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.replace("",",").split(",") if emoji_data.emotion else []
emoji.emotion = emoji_data.emotion.replace("", ",").split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count
db_last_used_time = emoji_data.last_used_time
@@ -732,7 +732,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.replace("",",").split(",")
return emoji_record.emotion.replace("", ",").split(",")
except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
@@ -993,7 +993,7 @@ class EmojiManager:
)
# 处理情感列表
emotions = [e.strip() for e in emotions_text.replace("",",").split(",") if e.strip()]
emotions = [e.strip() for e in emotions_text.replace("", ",").split(",") if e.strip()]
# 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个
if len(emotions) > 5:

View File

@@ -123,7 +123,11 @@ class ChatBot:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return True, response, not bool(intercept_message_level) # 找到命令根据intercept_message决定是否继续
return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")

View File

@@ -263,7 +263,7 @@ class MessageRecv(Message):
desc = segment.data.get("desc", "") # 内容描述
source_url = segment.data.get("source_url", "") # 原始链接
url = segment.data.get("url", "") # 小程序链接
text = f"[小程序分享"
text = "[小程序分享"
if title:
text += f" - {title}"
text += "]"

View File

@@ -51,7 +51,6 @@ def parse_message_segments(segment) -> list:
Returns:
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
"""
from maim_message import Seg
result = []
@@ -112,9 +111,13 @@ def parse_message_segments(segment) -> list:
forward_items = []
if segment.data:
for item in segment.data:
forward_items.append({
"content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else []
})
forward_items.append(
{
"content": parse_message_segments(item.get("message_segment", {}))
if isinstance(item, dict)
else []
}
)
result.append({"type": "forward", "data": forward_items})
else:
# 未知类型,尝试作为文本处理

View File

@@ -78,7 +78,6 @@ target_message_id为必填表示触发消息的id
"planner_prompt",
)
Prompt(
"""
{action_name}

View File

@@ -250,7 +250,12 @@ class DefaultReplyer:
# 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level
self.chat_stream.stream_id,
chat_history,
max_num=8,
target_message=target,
reply_reason=reply_reason,
think_level=think_level,
)
if selected_expressions:
@@ -273,7 +278,6 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@@ -788,7 +792,8 @@ class DefaultReplyer:
# 并行执行八个构建任务(包括黑话解释)
task_results = await asyncio.gather(
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits"
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
"expression_habits",
),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
@@ -980,7 +985,6 @@ class DefaultReplyer:
else:
reply_target_block = ""
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")

View File

@@ -287,7 +287,6 @@ class PrivateReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@@ -907,16 +906,11 @@ class PrivateReplyer:
else:
reply_target_block = ""
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
chat_target_1 = await global_prompt_manager.format_prompt(
"chat_target_private1", sender_name=chat_target_name
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
template_name = "default_expressor_prompt"

View File

@@ -1,8 +1,9 @@
from src.chat.utils.prompt_builder import Prompt
def init_replyer_private_prompt():
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天,这是你们之前聊的内容:
@@ -17,8 +18,8 @@ def init_replyer_private_prompt():
{reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"private_replyer_prompt",
)
"private_replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}

View File

@@ -44,4 +44,3 @@ def init_replyer_prompt():
现在,你说:""",
"replyer_prompt",
)

View File

@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level
message_filter=filter_query,
sort=sort_order,
limit=limit,
filter_intercept_message_level=filter_intercept_message_level,
)

View File

@@ -132,9 +132,7 @@ class ImageManager:
deleted_images = Images.delete().where(Images.type == "emoji").execute()
# 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = (
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
)
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
total_deleted = deleted_images + deleted_descriptions
if total_deleted > 0:
@@ -236,10 +234,14 @@ class ImageManager:
# 优先使用情感标签,如果没有则使用详细描述
result_text = ""
if cache_record.emotion_tags:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]"
elif cache_record.description:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件

View File

@@ -609,10 +609,10 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
fields = list(model._meta.fields.keys())
# Peewee 默认使用 'id' 作为主键字段名
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
primary_key_name = 'id' # 默认值
primary_key_name = "id" # 默认值
try:
if hasattr(model._meta, 'primary_key') and model._meta.primary_key:
if hasattr(model._meta.primary_key, 'name'):
if hasattr(model._meta, "primary_key") and model._meta.primary_key:
if hasattr(model._meta.primary_key, "name"):
primary_key_name = model._meta.primary_key.name
elif isinstance(model._meta.primary_key, str):
primary_key_name = model._meta.primary_key

View File

@@ -34,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
return obj
# 决定是否多行:仅在顶层且长度超过阈值时
should_multiline = (depth == 0 and len(obj) > threshold)
should_multiline = depth == 0 and len(obj) > threshold
# 如果已经是 tomlkit Array原地修改以保留注释
if isinstance(obj, Array):
@@ -112,7 +112,7 @@ def save_toml_with_format(
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
output = re.sub(r'\n{3,}', '\n\n', output)
output = re.sub(r"\n{3,}", "\n\n", output)
with open(file_path, "w", encoding="utf-8") as f:
f.write(output)
@@ -122,4 +122,4 @@ def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
return re.sub(r'\n{3,}', '\n\n', output)
return re.sub(r"\n{3,}", "\n\n", output)

View File

@@ -1,14 +1,13 @@
import asyncio
import random
import time
import json
from typing import Any, Dict, List, Optional, Tuple
from peewee import fn
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.common.database.database_model import ChatHistory, Jargon
from src.common.database.database_model import ChatHistory
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.plugin_system.apis import llm_api
@@ -82,7 +81,6 @@ def init_dream_prompts() -> None:
)
class DreamTool:
"""dream 模块内部使用的简易工具封装"""
@@ -150,7 +148,13 @@ def init_dream_tools(chat_id: str) -> None:
"search_chat_history",
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
[
("keyword", ToolParamType.STRING, "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", False, None),
(
"keyword",
ToolParamType.STRING,
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
False,
None,
),
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
],
search_chat_history,
@@ -201,8 +205,20 @@ def init_dream_tools(chat_id: str) -> None:
[
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", True, None),
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", True, None),
(
"keywords",
ToolParamType.STRING,
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
True,
None,
),
(
"key_point",
ToolParamType.STRING,
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
True,
None,
),
("start_time", ToolParamType.STRING, "起始时间戳Unix 时间,必填)。", True, None),
("end_time", ToolParamType.STRING, "结束时间戳Unix 时间,必填)。", True, None),
],
@@ -215,7 +231,13 @@ def init_dream_tools(chat_id: str) -> None:
"finish_maintenance",
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
[
("reason", ToolParamType.STRING, "结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'", False, None),
(
"reason",
ToolParamType.STRING,
"结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'",
False,
None,
),
],
finish_maintenance,
)
@@ -282,9 +304,7 @@ async def run_dream_agent_once(
else "未知"
)
end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
if record.end_time
else "未知"
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
)
detail_text = (
f"ID={record.id}\n"
@@ -305,8 +325,7 @@ async def run_dream_agent_once(
start_detail_builder = MessageBuilder()
start_detail_builder.set_role(RoleType.User)
start_detail_builder.add_text_content(
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n"
+ detail_text
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
)
conversation_messages.append(start_detail_builder.build())
else:
@@ -343,13 +362,17 @@ async def run_dream_agent_once(
conversation_messages.append(round_info_builder.build())
# 调用 LLM 让其决定是否要使用工具
success, response, reasoning_content, model_name, tool_calls = (
await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_defs,
request_type="dream.react",
)
(
success,
response,
reasoning_content,
model_name,
tool_calls,
) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_defs,
request_type="dream.react",
)
if not success:
@@ -555,4 +578,3 @@ async def start_dream_scheduler(
# 初始化提示词
init_dream_prompts()

View File

@@ -110,13 +110,17 @@ async def generate_dream_summary(
thought_content = ""
if msg.content:
if isinstance(msg.content, list) and msg.content:
thought_content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
thought_content = (
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
)
else:
thought_content = str(msg.content)
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
if thought_content:
logger.info(f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}")
logger.info(
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
)
# 记录每个工具调用的详细信息
for idx, tool_call in enumerate(msg.tool_calls, 1):
@@ -167,7 +171,7 @@ async def generate_dream_summary(
# 随机选择2个梦境风格
selected_styles = get_random_dream_styles(2)
dream_styles_text = "\n".join([f"{i+1}. {style}" for i, style in enumerate(selected_styles)])
dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
# 使用 Prompt 管理器格式化梦境生成 prompt
dream_prompt = await global_prompt_manager.format_prompt(
@@ -195,4 +199,5 @@ async def generate_dream_summary(
except Exception as e:
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
init_dream_summary_prompt()

View File

@@ -4,8 +4,3 @@ dream agent 工具实现模块。
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
"""

View File

@@ -60,8 +60,3 @@ def make_create_chat_history(chat_id: str):
return f"create_chat_history 执行失败: {e}"
return create_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"delete_chat_history 执行失败: {e}"
return delete_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"delete_jargon 执行失败: {e}"
return delete_jargon

View File

@@ -14,8 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
return msg
return finish_maintenance

View File

@@ -1,5 +1,4 @@
import time
from typing import Optional
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
@@ -20,14 +19,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
# 将时间戳转换为可读时间格式
start_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time))
if record.start_time
else "未知"
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
)
end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
if record.end_time
else "未知"
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
)
result = (
@@ -40,17 +35,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}"
)
logger.debug(
f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}"
)
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
return result
except Exception as e:
logger.error(f"get_chat_history_detail 失败: {e}")
return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail

View File

@@ -78,9 +78,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
@@ -125,9 +123,7 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(keywords_list)
if len(keywords_list) > 2:
required_count = len(keywords_list) - 1
return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant:
@@ -142,9 +138,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
for k in keywords_data:
@@ -160,13 +154,13 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(sorted(all_keywords_set))
response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n"
f'有关"{search_label}"的关键词:\n'
f"{keywords_str}"
)
else:
response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词信息为空"
f'有关"{search_label}"的关键词信息为空'
)
logger.info(
@@ -192,9 +186,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list) and keywords_data:
keywords_str = "".join([str(k) for k in keywords_data])
@@ -220,8 +212,3 @@ def make_search_chat_history(chat_id: str):
return f"search_chat_history 执行失败: {e}"
return search_chat_history

View File

@@ -16,9 +16,7 @@ def make_search_jargon(chat_id: str):
if not keyword or not keyword.strip():
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
logger.info(
f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})"
)
logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
# 基础条件:只查 is_jargon=True 的记录
query = Jargon.select().where(Jargon.is_jargon)
@@ -102,5 +100,3 @@ def make_search_jargon(chat_id: str):
return f"search_jargon 执行失败: {e}"
return search_jargon

View File

@@ -49,8 +49,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"update_chat_history 执行失败: {e}"
return update_chat_history

View File

@@ -49,8 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"update_jargon 执行失败: {e}"
return update_jargon

View File

@@ -316,7 +316,9 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
logger.info(
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
# 更新批次后持久化
self._persist_topic_cache()
else:
@@ -362,9 +364,7 @@ class ChatHistorySummarizer:
else:
time_str = f"{time_since_last_check / 3600:.1f}小时"
logger.debug(
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
)
logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
# 检查“话题检查”触发条件
should_check = False
@@ -427,7 +427,9 @@ class ChatHistorySummarizer:
return
# 2. 构造编号后的消息字符串和参与者信息
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
self._build_numbered_messages_for_llm(messages)
)
# 3. 调用 LLM 识别话题,并得到 topic -> indices失败时最多重试 3 次)
existing_topics = list(self.topic_cache.keys())
@@ -456,9 +458,7 @@ class ChatHistorySummarizer:
)
if not success or not topic_to_indices:
logger.error(
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
)
logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状
return
@@ -610,9 +610,7 @@ class ChatHistorySummarizer:
if not numbered_lines:
return False, {}
history_topics_block = (
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
)
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
messages_block = "\n".join(numbered_lines)
prompt = await global_prompt_manager.format_prompt(
@@ -642,10 +640,10 @@ class ChatHistorySummarizer:
else:
# 如果没有找到代码块尝试查找JSON数组的开始和结束位置
# 查找第一个 [ 和最后一个 ]
start_idx = response.find('[')
end_idx = response.rfind(']')
start_idx = response.find("[")
end_idx = response.rfind("]")
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
json_str = response[start_idx:end_idx + 1].strip()
json_str = response[start_idx : end_idx + 1].strip()
else:
# 如果还是找不到尝试直接使用整个响应移除可能的markdown标记
json_str = response.strip()
@@ -942,4 +940,3 @@ class ChatHistorySummarizer:
init_prompt()

View File

@@ -366,7 +366,9 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}")
logger.warning(
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e:
@@ -394,7 +396,9 @@ class LLMRequest:
if e.status_code == 429 or e.status_code >= 500:
retry_remain -= 1
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}")
logger.error(
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
@@ -540,7 +544,5 @@ class LLMRequest:
if e.__cause__:
original_error_type = type(e.__cause__).__name__
original_error_msg = str(e.__cause__)
return (
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
)
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
return ""

View File

@@ -113,7 +113,6 @@ class MainSystem:
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")
# 初始化聊天管理器
await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task())

View File

@@ -136,8 +136,6 @@ def init_memory_retrieval_prompt():
)
def _log_conversation_messages(
conversation_messages: List[Message],
head_prompt: Optional[str] = None,
@@ -172,7 +170,9 @@ def _log_conversation_messages(
# 构建单条消息的日志信息
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
msg_info = (
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
)
# if full_content:
# msg_info += f"\n{full_content}"
@@ -185,8 +185,7 @@ def _log_conversation_messages(
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
# if msg.tool_call_id:
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
log_lines.append(msg_info)
@@ -365,7 +364,7 @@ async def _react_agent_solve_question(
)
# logger.info(
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
# )
if not success:
@@ -467,10 +466,17 @@ async def _react_agent_solve_question(
if parsed_found_answer:
# 找到了答案
if parsed_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": parsed_answer}})
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": parsed_answer},
}
)
step["observations"] = ["检测到finish_search文本格式调用找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
)
_log_conversation_messages(
conversation_messages,
@@ -481,10 +487,14 @@ async def _react_agent_solve_question(
return True, parsed_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误继续迭代
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer")
logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["actions"].append(
{"action_type": "finish_search", "action_params": {"found_answer": False}}
)
step["observations"] = ["检测到finish_search文本格式调用未找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
@@ -522,10 +532,17 @@ async def _react_agent_solve_question(
if finish_search_found:
# 找到了答案
if finish_search_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": finish_search_answer}})
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": finish_search_answer},
}
)
step["observations"] = ["检测到finish_search工具调用找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
)
_log_conversation_messages(
conversation_messages,
@@ -536,7 +553,9 @@ async def _react_agent_solve_question(
return True, finish_search_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer")
logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
@@ -724,7 +743,13 @@ async def _react_agent_solve_question(
max_iterations=max_iterations,
)
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
(
eval_success,
eval_response,
eval_reasoning_content,
eval_model_name,
eval_tool_calls,
) = await llm_api.generate_with_model_with_tools(
evaluation_prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[], # 最终评估阶段不提供工具
@@ -759,7 +784,7 @@ async def _react_agent_solve_question(
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
"observations": ["最终评估阶段检测到found_answer"]
"observations": ["最终评估阶段检测到found_answer"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
@@ -778,7 +803,7 @@ async def _react_agent_solve_question(
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
"observations": ["最终评估阶段检测到not_enough_info"]
"observations": ["最终评估阶段检测到not_enough_info"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
@@ -795,8 +820,10 @@ async def _react_agent_solve_question(
eval_step = {
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}],
"observations": ["已到达最大迭代次数,无法找到答案"]
"actions": [
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
],
"observations": ["已到达最大迭代次数,无法找到答案"],
}
thinking_steps.append(eval_step)
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
@@ -1129,7 +1156,9 @@ async def build_memory_retrieval_prompt(
else:
max_iterations = base_max_iterations
timeout_seconds = global_config.memory.agent_timeout_seconds
logger.debug(f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}")
logger.debug(
f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
)
# 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [
@@ -1198,4 +1227,3 @@ async def build_memory_retrieval_prompt(
except Exception as e:
logger.error(f"记忆检索时发生异常: {str(e)}")
return ""

View File

@@ -17,7 +17,6 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
"""解析问题JSON返回概念列表和问题列表
@@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return [], []
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳(秒)

View File

@@ -47,4 +47,3 @@ def register_tool():
],
execute_func=finish_search,
)

View File

@@ -16,9 +16,7 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def search_chat_history(
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
) -> str:
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
"""根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords
Args:
@@ -144,7 +142,9 @@ async def search_chat_history(
keywords_list = parse_keywords_string(keyword)
if len(keywords_list) > 2:
required_count = len(keywords_list) - 1
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant:
@@ -160,9 +160,7 @@ async def search_chat_history(
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords)
if isinstance(record.keywords, str)
else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
for k in keywords_data:
@@ -179,13 +177,12 @@ async def search_chat_history(
keywords_str = "".join(sorted(all_keywords_set))
return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n"
f'有关"{search_label}"的关键词:\n'
f"{keywords_str}"
)
else:
return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词信息为空"
f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空'
)
# 构建结果文本返回id、theme和keywords最多20条

View File

@@ -414,7 +414,9 @@ async def websocket_chat(
group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
)
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")

View File

@@ -1,4 +1,4 @@
""" 表情包管理 API 路由"""
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
from fastapi.responses import FileResponse, JSONResponse
@@ -100,23 +100,23 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
try:
with Image.open(source_path) as img:
# GIF 处理:提取第一帧
if hasattr(img, 'n_frames') and img.n_frames > 1:
if hasattr(img, "n_frames") and img.n_frames > 1:
img.seek(0) # 确保在第一帧
# 转换为 RGB/RGBAWebP 支持透明度)
if img.mode in ('P', 'PA'):
if img.mode in ("P", "PA"):
# 调色板模式转换为 RGBA 以保留透明度
img = img.convert('RGBA')
elif img.mode == 'LA':
img = img.convert('RGBA')
elif img.mode not in ('RGB', 'RGBA'):
img = img.convert('RGB')
img = img.convert("RGBA")
elif img.mode == "LA":
img = img.convert("RGBA")
elif img.mode not in ("RGB", "RGBA"):
img = img.convert("RGB")
# 创建缩略图(保持宽高比)
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
# 保存为 WebP 格式
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6)
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
@@ -163,6 +163,7 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
return cleaned, kept
# 模块级别的类型别名(解决 B008 ruff 错误)
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
@@ -365,7 +366,9 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_emoji_detail(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表情包详细信息
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表情包(只更新提供的字段)
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表情包
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def register_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
注册表情包(快捷操作)
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def ban_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
禁用表情包(快捷操作)
@@ -680,9 +694,7 @@ async def get_emoji_thumbnail(
}
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(
path=emoji.full_path,
media_type=media_type,
filename=f"{emoji.emoji_hash}.{emoji.format}"
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
)
# 尝试获取或生成缩略图
@@ -692,9 +704,7 @@ async def get_emoji_thumbnail(
if cache_path.exists():
# 缓存命中,直接返回
return FileResponse(
path=str(cache_path),
media_type="image/webp",
filename=f"{emoji.emoji_hash}_thumb.webp"
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
)
# 缓存未命中,触发后台生成并返回 202
@@ -703,11 +713,7 @@ async def get_emoji_thumbnail(
# 标记为正在生成
_generating_thumbnails.add(emoji.emoji_hash)
# 提交到线程池后台生成
_thumbnail_executor.submit(
_background_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
# 返回 202 Accepted告诉前端缩略图正在生成中
return JSONResponse(
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
},
headers={
"Retry-After": "1", # 建议 1 秒后重试
}
},
)
except HTTPException:
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_emojis(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表情包
@@ -1234,12 +1244,7 @@ async def preheat_thumbnail_cache(
try:
# 使用线程池异步生成缩略图,避免阻塞事件循环
loop = asyncio.get_event_loop()
await loop.run_in_executor(
_thumbnail_executor,
_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
generated += 1
except Exception as e:
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")

View File

@@ -256,7 +256,9 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式详细信息
@@ -285,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
创建新的表达方式
@@ -326,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表达方式(只更新提供的字段)
@@ -376,7 +385,9 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表达方式
@@ -419,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表达方式
@@ -460,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
@router.get("/stats/summary")
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式统计数据

View File

@@ -277,11 +277,7 @@ async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
# 获取所有不同的 chat_id
chat_ids = (
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
)
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
@@ -346,12 +342,7 @@ async def get_jargon_stats():
complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量
chat_count = (
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
.count()
)
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
# 按聊天统计 TOP 5
top_chats = (
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
# 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none(
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
)
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = (
Jargon.update(is_jargon=is_jargon)
.where(Jargon.id.in_(ids))
.execute()
)
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")

View File

@@ -200,7 +200,9 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取人物详细信息
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新人物信息(只更新提供的字段)
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除人物信息
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除人物信息

View File

@@ -125,6 +125,7 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
"""
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str
"""
def _is_list_type(tp: Any) -> bool:
origin = get_origin(tp)
return tp is list or origin is list
@@ -313,7 +314,9 @@ async def check_git_status() -> GitStatusResponse:
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
async def get_available_mirrors(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> AvailableMirrorsResponse:
"""
获取所有可用的镜像源配置
"""
@@ -343,7 +346,9 @@ async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), au
@router.post("/mirrors", response_model=MirrorConfigResponse)
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
async def add_mirror(
request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> MirrorConfigResponse:
"""
添加新的镜像源
"""
@@ -383,7 +388,10 @@ async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] =
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
async def update_mirror(
mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
mirror_id: str,
request: UpdateMirrorRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> MirrorConfigResponse:
"""
更新镜像源配置
@@ -426,7 +434,9 @@ async def update_mirror(
@router.delete("/mirrors/{mirror_id}")
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def delete_mirror(
mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
删除镜像源
"""
@@ -449,7 +459,9 @@ async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(N
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
async def fetch_raw_file(
request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
request: FetchRawFileRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> FetchRawFileResponse:
"""
获取 GitHub 仓库的 Raw 文件内容
@@ -534,7 +546,9 @@ async def fetch_raw_file(
@router.post("/clone", response_model=CloneRepositoryResponse)
async def clone_repository(
request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
request: CloneRepositoryRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> CloneRepositoryResponse:
"""
克隆 GitHub 仓库到本地
@@ -574,7 +588,11 @@ async def clone_repository(
@router.post("/install")
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def install_plugin(
request: InstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
安装插件
@@ -778,7 +796,9 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
@router.post("/uninstall")
async def uninstall_plugin(
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
request: UninstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
卸载插件
@@ -913,7 +933,11 @@ async def uninstall_plugin(
@router.post("/update")
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def update_plugin(
request: UpdatePluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
更新插件
@@ -1132,7 +1156,9 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
@router.get("/installed")
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def get_installed_plugins(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
获取已安装的插件列表
@@ -1272,7 +1298,9 @@ class UpdatePluginConfigRequest(BaseModel):
@router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def get_plugin_config_schema(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
获取插件配置 Schema
@@ -1405,7 +1433,9 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
@router.get("/config/{plugin_id}")
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def get_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
获取插件当前配置值
@@ -1461,7 +1491,10 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
@router.put("/config/{plugin_id}")
async def update_plugin_config(
plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
plugin_id: str,
request: UpdatePluginConfigRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
"""
更新插件配置
@@ -1532,7 +1565,9 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset")
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def reset_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
重置插件配置为默认值
@@ -1592,7 +1627,9 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
@router.post("/config/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
async def toggle_plugin(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
切换插件启用状态

View File

@@ -110,8 +110,10 @@ class WebUIServer:
from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router
from src.webui.knowledge_routes import router as knowledge_router
# 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router
# 注册路由
self.app.include_router(webui_router)
self.app.include_router(logs_router)
@@ -166,6 +168,7 @@ class WebUIServer:
def _check_port_available(self) -> bool:
"""检查端口是否可用"""
import socket
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)