Ruff format

This commit is contained in:
墨梓柒
2025-12-13 17:14:09 +08:00
parent ef377bb0cd
commit e680a4d1f5
60 changed files with 1546 additions and 1532 deletions

4
bot.py
View File

@@ -41,6 +41,7 @@ logger = get_logger("main")
# 定义重启退出码 # 定义重启退出码
RESTART_EXIT_CODE = 42 RESTART_EXIT_CODE = 42
def run_runner_process(): def run_runner_process():
""" """
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。 Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
@@ -68,7 +69,7 @@ def run_runner_process():
if return_code == RESTART_EXIT_CODE: if return_code == RESTART_EXIT_CODE:
logger.info("检测到重启请求 (退出码 42),正在重启...") logger.info("检测到重启请求 (退出码 42),正在重启...")
time.sleep(1) # 稍作等待 time.sleep(1) # 稍作等待
continue continue
else: else:
logger.info(f"程序已退出 (退出码 {return_code})") logger.info(f"程序已退出 (退出码 {return_code})")
@@ -87,6 +88,7 @@ def run_runner_process():
process.kill() process.kill()
sys.exit(0) sys.exit(0)
# 检查是否是 Worker 进程 # 检查是否是 Worker 进程
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本, # 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
# 此时应该作为 Runner 运行。 # 此时应该作为 Runner 运行。

View File

@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Optional, Tuple
@dataclass @dataclass
class ConversionResult: class ConversionResult:
"""转换结果""" """转换结果"""
success: bool success: bool
servers: List[Dict[str, Any]] = field(default_factory=list) servers: List[Dict[str, Any]] = field(default_factory=list)
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
@@ -271,11 +272,7 @@ class ConfigConverter:
return name, result return name, result
@classmethod @classmethod
def from_claude_format( def from_claude_format(cls, config: Dict[str, Any], existing_names: Optional[set] = None) -> ConversionResult:
cls,
config: Dict[str, Any],
existing_names: Optional[set] = None
) -> ConversionResult:
"""从 Claude Desktop 格式转换为 MaiBot 格式 """从 Claude Desktop 格式转换为 MaiBot 格式
Args: Args:
@@ -355,11 +352,7 @@ class ConfigConverter:
return {"mcpServers": mcp_servers} return {"mcpServers": mcp_servers}
@classmethod @classmethod
def import_from_string( def import_from_string(cls, json_str: str, existing_names: Optional[set] = None) -> ConversionResult:
cls,
json_str: str,
existing_names: Optional[set] = None
) -> ConversionResult:
"""从 JSON 字符串导入配置 """从 JSON 字符串导入配置
自动检测格式并转换为 MaiBot 格式 自动检测格式并转换为 MaiBot 格式
@@ -422,12 +415,7 @@ class ConfigConverter:
return result return result
@classmethod @classmethod
def export_to_string( def export_to_string(cls, servers: List[Dict[str, Any]], format_type: str = "claude", pretty: bool = True) -> str:
cls,
servers: List[Dict[str, Any]],
format_type: str = "claude",
pretty: bool = True
) -> str:
"""导出配置为 JSON 字符串 """导出配置为 JSON 字符串
Args: Args:

View File

@@ -34,28 +34,31 @@ from enum import Enum
# 尝试导入 MaiBot 的 logger如果失败则使用标准 logging # 尝试导入 MaiBot 的 logger如果失败则使用标准 logging
try: try:
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("mcp_client") logger = get_logger("mcp_client")
except ImportError: except ImportError:
# Fallback: 使用标准 logging # Fallback: 使用标准 logging
logger = logging.getLogger("mcp_client") logger = logging.getLogger("mcp_client")
if not logger.handlers: if not logger.handlers:
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s')) handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
logger.addHandler(handler) logger.addHandler(handler)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class TransportType(Enum): class TransportType(Enum):
"""MCP 传输类型""" """MCP 传输类型"""
STDIO = "stdio" # 本地进程通信
SSE = "sse" # Server-Sent Events (旧版 HTTP) STDIO = "stdio" # 本地进程通信
HTTP = "http" # HTTP Streamable (新版,推荐) SSE = "sse" # Server-Sent Events (旧版 HTTP)
HTTP = "http" # HTTP Streamable (新版,推荐)
STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名 STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名
@dataclass @dataclass
class MCPToolInfo: class MCPToolInfo:
"""MCP 工具信息""" """MCP 工具信息"""
name: str name: str
description: str description: str
input_schema: Dict[str, Any] input_schema: Dict[str, Any]
@@ -65,6 +68,7 @@ class MCPToolInfo:
@dataclass @dataclass
class MCPResourceInfo: class MCPResourceInfo:
"""MCP 资源信息""" """MCP 资源信息"""
uri: str uri: str
name: str name: str
description: str description: str
@@ -75,6 +79,7 @@ class MCPResourceInfo:
@dataclass @dataclass
class MCPPromptInfo: class MCPPromptInfo:
"""MCP 提示模板信息""" """MCP 提示模板信息"""
name: str name: str
description: str description: str
arguments: List[Dict[str, Any]] # [{name, description, required}] arguments: List[Dict[str, Any]] # [{name, description, required}]
@@ -84,6 +89,7 @@ class MCPPromptInfo:
@dataclass @dataclass
class MCPServerConfig: class MCPServerConfig:
"""MCP 服务器配置""" """MCP 服务器配置"""
name: str name: str
enabled: bool = True enabled: bool = True
transport: TransportType = TransportType.STDIO transport: TransportType = TransportType.STDIO
@@ -99,6 +105,7 @@ class MCPServerConfig:
@dataclass @dataclass
class MCPCallResult: class MCPCallResult:
"""MCP 工具调用结果""" """MCP 工具调用结果"""
success: bool success: bool
content: Any content: Any
error: Optional[str] = None error: Optional[str] = None
@@ -108,8 +115,9 @@ class MCPCallResult:
class CircuitState(Enum): class CircuitState(Enum):
"""断路器状态""" """断路器状态"""
CLOSED = "closed" # 正常状态,允许请求
OPEN = "open" # 熔断状态,拒绝请求 CLOSED = "closed" # 正常状态,允许请求
OPEN = "open" # 熔断状态,拒绝请求
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求 HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
@@ -125,9 +133,9 @@ class CircuitBreaker:
""" """
# 配置 # 配置
failure_threshold: int = 5 # 连续失败多少次后熔断 failure_threshold: int = 5 # 连续失败多少次后熔断
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒) recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用 half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
# 状态 # 状态
state: CircuitState = field(default=CircuitState.CLOSED) state: CircuitState = field(default=CircuitState.CLOSED)
@@ -232,6 +240,7 @@ class CircuitBreaker:
@dataclass @dataclass
class ToolCallStats: class ToolCallStats:
"""工具调用统计""" """工具调用统计"""
tool_key: str tool_key: str
total_calls: int = 0 total_calls: int = 0
success_calls: int = 0 success_calls: int = 0
@@ -282,6 +291,7 @@ class ToolCallStats:
@dataclass @dataclass
class ServerStats: class ServerStats:
"""服务器统计""" """服务器统计"""
server_name: str server_name: str
connect_count: int = 0 # 连接次数 connect_count: int = 0 # 连接次数
disconnect_count: int = 0 # 断开次数 disconnect_count: int = 0 # 断开次数
@@ -442,9 +452,7 @@ class MCPClientSession:
return False return False
server_params = StdioServerParameters( server_params = StdioServerParameters(
command=self.config.command, command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None
args=self.config.args,
env=self.config.env if self.config.env else None
) )
self._stdio_context = stdio_client(server_params) self._stdio_context = stdio_client(server_params)
@@ -506,6 +514,7 @@ class MCPClientSession:
except Exception as e: except Exception as e:
logger.error(f"[{self.server_name}] SSE 连接失败: {e}") logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
import traceback import traceback
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
await self._cleanup() await self._cleanup()
return False return False
@@ -551,6 +560,7 @@ class MCPClientSession:
except Exception as e: except Exception as e:
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}") logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
import traceback import traceback
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
await self._cleanup() await self._cleanup()
return False return False
@@ -568,8 +578,8 @@ class MCPClientSession:
tool_info = MCPToolInfo( tool_info = MCPToolInfo(
name=tool.name, name=tool.name,
description=tool.description or f"MCP tool: {tool.name}", description=tool.description or f"MCP tool: {tool.name}",
input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {}, input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
server_name=self.server_name server_name=self.server_name,
) )
self._tools.append(tool_info) self._tools.append(tool_info)
# 初始化工具统计 # 初始化工具统计
@@ -591,10 +601,7 @@ class MCPClientSession:
return False return False
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
self._session.list_resources(),
timeout=self.call_timeout
)
self._resources = [] self._resources = []
for resource in result.resources: for resource in result.resources:
@@ -602,8 +609,8 @@ class MCPClientSession:
uri=str(resource.uri), uri=str(resource.uri),
name=resource.name or str(resource.uri), name=resource.name or str(resource.uri),
description=resource.description or "", description=resource.description or "",
mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None, mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
server_name=self.server_name server_name=self.server_name,
) )
self._resources.append(resource_info) self._resources.append(resource_info)
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}") logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
@@ -633,28 +640,27 @@ class MCPClientSession:
return False return False
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
self._session.list_prompts(),
timeout=self.call_timeout
)
self._prompts = [] self._prompts = []
for prompt in result.prompts: for prompt in result.prompts:
# 解析参数 # 解析参数
arguments = [] arguments = []
if hasattr(prompt, 'arguments') and prompt.arguments: if hasattr(prompt, "arguments") and prompt.arguments:
for arg in prompt.arguments: for arg in prompt.arguments:
arguments.append({ arguments.append(
"name": arg.name, {
"description": arg.description or "", "name": arg.name,
"required": arg.required if hasattr(arg, 'required') else False, "description": arg.description or "",
}) "required": arg.required if hasattr(arg, "required") else False,
}
)
prompt_info = MCPPromptInfo( prompt_info = MCPPromptInfo(
name=prompt.name, name=prompt.name,
description=prompt.description or f"MCP prompt: {prompt.name}", description=prompt.description or f"MCP prompt: {prompt.name}",
arguments=arguments, arguments=arguments,
server_name=self.server_name server_name=self.server_name,
) )
self._prompts.append(prompt_info) self._prompts.append(prompt_info)
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}") logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
@@ -686,35 +692,25 @@ class MCPClientSession:
start_time = time.time() start_time = time.time()
if not self._connected or not self._session: if not self._connected or not self._session:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
success=False,
content=None,
error=f"服务器 {self.server_name} 未连接"
)
if not self._supports_resources: if not self._supports_resources:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")
success=False,
content=None,
error=f"服务器 {self.server_name} 不支持 Resources 功能"
)
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
self._session.read_resource(uri),
timeout=self.call_timeout
)
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
# 处理返回内容 # 处理返回内容
content_parts = [] content_parts = []
for content in result.contents: for content in result.contents:
if hasattr(content, 'text'): if hasattr(content, "text"):
content_parts.append(content.text) content_parts.append(content.text)
elif hasattr(content, 'blob'): elif hasattr(content, "blob"):
# 二进制数据,返回 base64 或提示 # 二进制数据,返回 base64 或提示
import base64 import base64
blob_data = content.blob blob_data = content.blob
if len(blob_data) < 10000: # 小于 10KB 返回 base64 if len(blob_data) < 10000: # 小于 10KB 返回 base64
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}") content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
@@ -724,28 +720,18 @@ class MCPClientSession:
content_parts.append(str(content)) content_parts.append(str(content))
return MCPCallResult( return MCPCallResult(
success=True, success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms
content="\n".join(content_parts) if content_parts else "",
duration_ms=duration_ms
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
return MCPCallResult( return MCPCallResult(
success=False, success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
content=None,
error=f"读取资源超时({self.call_timeout}秒)",
duration_ms=duration_ms
) )
except Exception as e: except Exception as e:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}") logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
return MCPCallResult( return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
success=False,
content=None,
error=str(e),
duration_ms=duration_ms
)
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult: async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
"""v1.2.0: 获取提示模板的内容 """v1.2.0: 获取提示模板的内容
@@ -760,23 +746,14 @@ class MCPClientSession:
start_time = time.time() start_time = time.time()
if not self._connected or not self._session: if not self._connected or not self._session:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
success=False,
content=None,
error=f"服务器 {self.server_name} 未连接"
)
if not self._supports_prompts: if not self._supports_prompts:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")
success=False,
content=None,
error=f"服务器 {self.server_name} 不支持 Prompts 功能"
)
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
self._session.get_prompt(name, arguments=arguments or {}), self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
timeout=self.call_timeout
) )
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
@@ -784,10 +761,10 @@ class MCPClientSession:
# 处理返回的消息 # 处理返回的消息
messages = [] messages = []
for msg in result.messages: for msg in result.messages:
role = msg.role if hasattr(msg, 'role') else "unknown" role = msg.role if hasattr(msg, "role") else "unknown"
content_text = "" content_text = ""
if hasattr(msg, 'content'): if hasattr(msg, "content"):
if hasattr(msg.content, 'text'): if hasattr(msg.content, "text"):
content_text = msg.content.text content_text = msg.content.text
elif isinstance(msg.content, str): elif isinstance(msg.content, str):
content_text = msg.content content_text = msg.content
@@ -796,28 +773,18 @@ class MCPClientSession:
messages.append(f"[{role}]: {content_text}") messages.append(f"[{role}]: {content_text}")
return MCPCallResult( return MCPCallResult(
success=True, success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms
content="\n\n".join(messages) if messages else "",
duration_ms=duration_ms
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
return MCPCallResult( return MCPCallResult(
success=False, success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
content=None,
error=f"获取提示模板超时({self.call_timeout}秒)",
duration_ms=duration_ms
) )
except Exception as e: except Exception as e:
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}") logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
return MCPCallResult( return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
success=False,
content=None,
error=str(e),
duration_ms=duration_ms
)
async def check_health(self) -> bool: async def check_health(self) -> bool:
"""检查连接健康状态(心跳检测) """检查连接健康状态(心跳检测)
@@ -829,10 +796,7 @@ class MCPClientSession:
try: try:
# 使用 list_tools 作为心跳检测 # 使用 list_tools 作为心跳检测
await asyncio.wait_for( await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
self._session.list_tools(),
timeout=10.0
)
self.stats.record_heartbeat() self.stats.record_heartbeat()
return True return True
except Exception as e: except Exception as e:
@@ -849,12 +813,7 @@ class MCPClientSession:
# v1.7.0: 断路器检查 # v1.7.0: 断路器检查
can_execute, reject_reason = self._circuit_breaker.can_execute() can_execute, reject_reason = self._circuit_breaker.can_execute()
if not can_execute: if not can_execute:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"{reject_reason}", circuit_broken=True)
success=False,
content=None,
error=f"{reject_reason}",
circuit_broken=True
)
# 半开状态下增加试探计数 # 半开状态下增加试探计数
if self._circuit_breaker.state == CircuitState.HALF_OPEN: if self._circuit_breaker.state == CircuitState.HALF_OPEN:
@@ -870,8 +829,7 @@ class MCPClientSession:
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
self._session.call_tool(tool_name, arguments=arguments), self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
timeout=self.call_timeout
) )
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
@@ -879,9 +837,9 @@ class MCPClientSession:
# 处理返回内容 # 处理返回内容
content_parts = [] content_parts = []
for content in result.content: for content in result.content:
if hasattr(content, 'text'): if hasattr(content, "text"):
content_parts.append(content.text) content_parts.append(content.text)
elif hasattr(content, 'data'): elif hasattr(content, "data"):
content_parts.append(f"[二进制数据: {len(content.data)} bytes]") content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
else: else:
content_parts.append(str(content)) content_parts.append(str(content))
@@ -896,7 +854,7 @@ class MCPClientSession:
return MCPCallResult( return MCPCallResult(
success=True, success=True,
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)", content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
duration_ms=duration_ms duration_ms=duration_ms,
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
@@ -939,25 +897,25 @@ class MCPClientSession:
self._supports_prompts = False # v1.2.0 self._supports_prompts = False # v1.2.0
try: try:
if hasattr(self, '_session_context') and self._session_context: if hasattr(self, "_session_context") and self._session_context:
await self._session_context.__aexit__(None, None, None) await self._session_context.__aexit__(None, None, None)
except Exception as e: except Exception as e:
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}") logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
try: try:
if hasattr(self, '_stdio_context') and self._stdio_context: if hasattr(self, "_stdio_context") and self._stdio_context:
await self._stdio_context.__aexit__(None, None, None) await self._stdio_context.__aexit__(None, None, None)
except Exception as e: except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}") logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
try: try:
if hasattr(self, '_http_context') and self._http_context: if hasattr(self, "_http_context") and self._http_context:
await self._http_context.__aexit__(None, None, None) await self._http_context.__aexit__(None, None, None)
except Exception as e: except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}") logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
try: try:
if hasattr(self, '_sse_context') and self._sse_context: if hasattr(self, "_sse_context") and self._sse_context:
await self._sse_context.__aexit__(None, None, None) await self._sse_context.__aexit__(None, None, None)
except Exception as e: except Exception as e:
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}") logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
@@ -1082,7 +1040,9 @@ class MCPClientManager:
return True return True
if attempt < retry_attempts: if attempt < retry_attempts:
logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})") logger.warning(
f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
)
await asyncio.sleep(retry_interval) await asyncio.sleep(retry_interval)
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})") logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
@@ -1213,11 +1173,7 @@ class MCPClientManager:
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult: async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
"""调用 MCP 工具""" """调用 MCP 工具"""
if tool_key not in self._all_tools: if tool_key not in self._all_tools:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")
success=False,
content=None,
error=f"工具 {tool_key} 不存在"
)
tool_info, client = self._all_tools[tool_key] tool_info, client = self._all_tools[tool_key]
@@ -1273,11 +1229,7 @@ class MCPClientManager:
# 如果指定了服务器 # 如果指定了服务器
if server_name: if server_name:
if server_name not in self._clients: if server_name not in self._clients:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
success=False,
content=None,
error=f"服务器 {server_name} 不存在"
)
client = self._clients[server_name] client = self._clients[server_name]
return await client.read_resource(uri) return await client.read_resource(uri)
@@ -1293,14 +1245,11 @@ class MCPClientManager:
if result.success: if result.success:
return result return result
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
success=False,
content=None,
error=f"未找到资源: {uri}"
)
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None, async def get_prompt(
server_name: Optional[str] = None) -> MCPCallResult: self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
) -> MCPCallResult:
"""v1.2.0: 获取提示模板内容 """v1.2.0: 获取提示模板内容
Args: Args:
@@ -1311,11 +1260,7 @@ class MCPClientManager:
# 如果指定了服务器 # 如果指定了服务器
if server_name: if server_name:
if server_name not in self._clients: if server_name not in self._clients:
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
success=False,
content=None,
error=f"服务器 {server_name} 不存在"
)
client = self._clients[server_name] client = self._clients[server_name]
return await client.get_prompt(name, arguments) return await client.get_prompt(name, arguments)
@@ -1324,11 +1269,7 @@ class MCPClientManager:
if prompt_info.name == name: if prompt_info.name == name:
return await client.get_prompt(name, arguments) return await client.get_prompt(name, arguments)
return MCPCallResult( return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
success=False,
content=None,
error=f"未找到提示模板: {name}"
)
# ==================== 心跳检测 ==================== # ==================== 心跳检测 ====================
@@ -1489,7 +1430,9 @@ class MCPClientManager:
"global": { "global": {
**self._global_stats, **self._global_stats,
"uptime_seconds": round(uptime, 2), "uptime_seconds": round(uptime, 2),
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0, "calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
if uptime > 0
else 0,
}, },
"servers": server_stats, "servers": server_stats,
"tools": tool_stats, "tools": tool_stats,

View File

@@ -84,7 +84,7 @@ from .mcp_client import (
TransportType, TransportType,
mcp_manager, mcp_manager,
) )
from .config_converter import ConfigConverter, ConversionResult from .config_converter import ConfigConverter
logger = get_logger("mcp_bridge_plugin") logger = get_logger("mcp_bridge_plugin")
@@ -93,9 +93,11 @@ logger = get_logger("mcp_bridge_plugin")
# v1.4.0: 调用链路追踪 # v1.4.0: 调用链路追踪
# ============================================================================ # ============================================================================
@dataclass @dataclass
class ToolCallRecord: class ToolCallRecord:
"""工具调用记录""" """工具调用记录"""
call_id: str call_id: str
timestamp: float timestamp: float
tool_name: str tool_name: str
@@ -178,9 +180,11 @@ tool_call_tracer = ToolCallTracer()
# v1.4.0: 工具调用缓存 # v1.4.0: 工具调用缓存
# ============================================================================ # ============================================================================
@dataclass @dataclass
class CacheEntry: class CacheEntry:
"""缓存条目""" """缓存条目"""
tool_name: str tool_name: str
args_hash: str args_hash: str
result: str result: str
@@ -317,6 +321,7 @@ tool_call_cache = ToolCallCache()
# v1.4.0: 工具权限控制 # v1.4.0: 工具权限控制
# ============================================================================ # ============================================================================
class PermissionChecker: class PermissionChecker:
"""工具权限检查器""" """工具权限检查器"""
@@ -449,6 +454,7 @@ permission_checker = PermissionChecker()
# 工具类型转换 # 工具类型转换
# ============================================================================ # ============================================================================
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType: def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType""" """将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
type_mapping = { type_mapping = {
@@ -462,7 +468,9 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
return type_mapping.get(json_type, ToolParamType.STRING) return type_mapping.get(json_type, ToolParamType.STRING)
def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]: def parse_mcp_parameters(
input_schema: Dict[str, Any],
) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
"""解析 MCP 工具的参数 schema转换为 MaiBot 的参数格式""" """解析 MCP 工具的参数 schema转换为 MaiBot 的参数格式"""
parameters = [] parameters = []
@@ -497,6 +505,7 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa
# MCP 工具代理 # MCP 工具代理
# ============================================================================ # ============================================================================
class MCPToolProxy(BaseTool): class MCPToolProxy(BaseTool):
"""MCP 工具代理基类""" """MCP 工具代理基类"""
@@ -539,10 +548,7 @@ class MCPToolProxy(BaseTool):
# v1.4.0: 权限检查 # v1.4.0: 权限检查
if not permission_checker.check(self.name, chat_id, user_id, is_group): if not permission_checker.check(self.name, chat_id, user_id, is_group):
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}") logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
return { return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"}
"name": self.name,
"content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"
}
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}") logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
@@ -726,11 +732,7 @@ class MCPToolProxy(BaseTool):
return None return None
async def _call_post_process_llm( async def _call_post_process_llm(
self, self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]]
prompt: str,
max_tokens: int,
settings: Dict[str, Any],
server_config: Optional[Dict[str, Any]]
) -> Optional[str]: ) -> Optional[str]:
"""调用 LLM 进行后处理""" """调用 LLM 进行后处理"""
from src.config.config import model_config from src.config.config import model_config
@@ -788,10 +790,7 @@ class MCPToolProxy(BaseTool):
def create_mcp_tool_class( def create_mcp_tool_class(
tool_key: str, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
tool_info: MCPToolInfo,
tool_prefix: str,
disabled: bool = False
) -> Type[MCPToolProxy]: ) -> Type[MCPToolProxy]:
"""根据 MCP 工具信息动态创建 BaseTool 子类""" """根据 MCP 工具信息动态创建 BaseTool 子类"""
parameters = parse_mcp_parameters(tool_info.input_schema) parameters = parse_mcp_parameters(tool_info.input_schema)
@@ -814,7 +813,7 @@ def create_mcp_tool_class(
"_mcp_tool_key": tool_key, "_mcp_tool_key": tool_key,
"_mcp_original_name": tool_info.name, "_mcp_original_name": tool_info.name,
"_mcp_server_name": tool_info.server_name, "_mcp_server_name": tool_info.server_name,
} },
) )
return tool_class return tool_class
@@ -828,11 +827,7 @@ class MCPToolRegistry:
self._tool_infos: Dict[str, ToolInfo] = {} self._tool_infos: Dict[str, ToolInfo] = {}
def register_tool( def register_tool(
self, self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
tool_key: str,
tool_info: MCPToolInfo,
tool_prefix: str,
disabled: bool = False
) -> Tuple[ToolInfo, Type[MCPToolProxy]]: ) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
"""注册 MCP 工具""" """注册 MCP 工具"""
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled) tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
@@ -879,6 +874,7 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None
# 内置工具 # 内置工具
# ============================================================================ # ============================================================================
class MCPReadResourceTool(BaseTool): class MCPReadResourceTool(BaseTool):
"""v1.2.0: MCP 资源读取工具""" """v1.2.0: MCP 资源读取工具"""
@@ -950,9 +946,17 @@ class MCPStatusTool(BaseTool):
"""MCP 状态查询工具""" """MCP 状态查询工具"""
name = "mcp_status" name = "mcp_status"
description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息" description = (
"查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
)
parameters = [ parameters = [
("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"]), (
"query_type",
ToolParamType.STRING,
"查询类型",
False,
["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"],
),
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
] ]
available_for_llm = True available_for_llm = True
@@ -986,10 +990,7 @@ class MCPStatusTool(BaseTool):
if query_type in ("cache",): if query_type in ("cache",):
result_parts.append(self._format_cache()) result_parts.append(self._format_cache())
return { return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"}
"name": self.name,
"content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"
}
def _format_status(self, server_name: Optional[str] = None) -> str: def _format_status(self, server_name: Optional[str] = None) -> str:
status = mcp_manager.get_status() status = mcp_manager.get_status()
@@ -1001,14 +1002,14 @@ class MCPStatusTool(BaseTool):
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}") lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
lines.append("\n🔌 服务器详情:") lines.append("\n🔌 服务器详情:")
for name, info in status['servers'].items(): for name, info in status["servers"].items():
if server_name and name != server_name: if server_name and name != server_name:
continue continue
status_icon = "" if info['connected'] else "" status_icon = "" if info["connected"] else ""
enabled_text = "" if info['enabled'] else " (已禁用)" enabled_text = "" if info["enabled"] else " (已禁用)"
lines.append(f" {status_icon} {name}{enabled_text}") lines.append(f" {status_icon} {name}{enabled_text}")
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}") lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
if info['consecutive_failures'] > 0: if info["consecutive_failures"] > 0:
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']}") lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']}")
return "\n".join(lines) return "\n".join(lines)
@@ -1038,11 +1039,11 @@ class MCPStatusTool(BaseTool):
stats = mcp_manager.get_all_stats() stats = mcp_manager.get_all_stats()
lines = ["📈 调用统计"] lines = ["📈 调用统计"]
g = stats['global'] g = stats["global"]
lines.append(f" 总调用次数: {g['total_tool_calls']}") lines.append(f" 总调用次数: {g['total_tool_calls']}")
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}") lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
if g['total_tool_calls'] > 0: if g["total_tool_calls"] > 0:
success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100 success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100
lines.append(f" 成功率: {success_rate:.1f}%") lines.append(f" 成功率: {success_rate:.1f}%")
lines.append(f" 运行时间: {g['uptime_seconds']:.0f}") lines.append(f" 运行时间: {g['uptime_seconds']:.0f}")
@@ -1126,6 +1127,7 @@ class MCPStatusTool(BaseTool):
# 命令处理 # 命令处理
# ============================================================================ # ============================================================================
class MCPStatusCommand(BaseCommand): class MCPStatusCommand(BaseCommand):
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态""" """MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
@@ -1340,7 +1342,9 @@ class MCPStatusCommand(BaseCommand):
try: try:
exported = ConfigConverter.export_to_string(servers, format_type, pretty=True) exported = ConfigConverter.export_to_string(servers, format_type, pretty=True)
format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get(format_type, format_type) format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get(
format_type, format_type
)
lines = [f"📤 导出为 {format_name} 格式 ({len(servers)} 个服务器):"] lines = [f"📤 导出为 {format_name} 格式 ({len(servers)} 个服务器):"]
lines.append("") lines.append("")
lines.append(exported) lines.append(exported)
@@ -1446,9 +1450,9 @@ class MCPStatusCommand(BaseCommand):
cb = info.get("circuit_breaker", {}) cb = info.get("circuit_breaker", {})
cb_state = cb.get("state", "closed") cb_state = cb.get("state", "closed")
if cb_state == "open": if cb_state == "open":
lines.append(f" ⚡ 断路器熔断中") lines.append(" ⚡ 断路器熔断中")
elif cb_state == "half_open": elif cb_state == "half_open":
lines.append(f" ⚡ 断路器试探中") lines.append(" ⚡ 断路器试探中")
if info["consecutive_failures"] > 0: if info["consecutive_failures"] > 0:
lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']}") lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']}")
@@ -1634,6 +1638,7 @@ class MCPImportCommand(BaseCommand):
# 事件处理器 # 事件处理器
# ============================================================================ # ============================================================================
class MCPStartupHandler(BaseEventHandler): class MCPStartupHandler(BaseEventHandler):
"""MCP 启动事件处理器""" """MCP 启动事件处理器"""
@@ -1692,6 +1697,7 @@ class MCPStopHandler(BaseEventHandler):
# 主插件类 # 主插件类
# ============================================================================ # ============================================================================
@register_plugin @register_plugin
class MCPBridgePlugin(BasePlugin): class MCPBridgePlugin(BasePlugin):
"""MCP 桥接插件 v1.4.0 - 将 MCP 服务器的工具桥接到 MaiBot""" """MCP 桥接插件 v1.4.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
@@ -2116,9 +2122,9 @@ class MCPBridgePlugin(BasePlugin):
label="📜 高级权限规则(可选)", label="📜 高级权限规则(可选)",
input_type="textarea", input_type="textarea",
rows=10, rows=10,
placeholder='''[ placeholder="""[
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]} {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
]''', ]""",
hint="格式: qq:ID:group/private/user工具名支持通配符 *", hint="格式: qq:ID:group/private/user工具名支持通配符 *",
order=10, order=10,
), ),
@@ -2261,7 +2267,9 @@ class MCPBridgePlugin(BasePlugin):
value = match1.group(2) value = match1.group(2)
suffix = match1.group(3) suffix = match1.group(3)
# 将转义的换行符还原为实际换行符 # 将转义的换行符还原为实际换行符
unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\") unescaped = (
value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
)
fixed_line = f'{prefix}"""{unescaped}"""{suffix}' fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
fixed_lines.append(fixed_line) fixed_lines.append(fixed_line)
modified = True modified = True
@@ -2560,7 +2568,7 @@ class MCPBridgePlugin(BasePlugin):
continue continue
import_config_str = str(import_config).strip() import_config_str = str(import_config).strip()
logger.info(f"检测到 WebUI 导入请求,开始处理...") logger.info("检测到 WebUI 导入请求,开始处理...")
# 执行导入 # 执行导入
await self._execute_webui_import(import_config_str, doc, config_path) await self._execute_webui_import(import_config_str, doc, config_path)
@@ -2773,6 +2781,7 @@ class MCPBridgePlugin(BasePlugin):
async def _async_connect_servers(self) -> None: async def _async_connect_servers(self) -> None:
"""异步连接所有配置的 MCP 服务器v1.5.0: 并行连接优化)""" """异步连接所有配置的 MCP 服务器v1.5.0: 并行连接优化)"""
import asyncio import asyncio
settings = self.config.get("settings", {}) settings = self.config.get("settings", {})
servers_section = self.config.get("servers", []) servers_section = self.config.get("servers", [])
@@ -2854,10 +2863,7 @@ class MCPBridgePlugin(BasePlugin):
# 并行执行所有连接 # 并行执行所有连接
start_time = time.time() start_time = time.time()
results = await asyncio.gather( results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True)
*[connect_single_server(cfg) for cfg in enabled_configs],
return_exceptions=True
)
connect_duration = time.time() - start_time connect_duration = time.time() - start_time
# 统计连接结果 # 统计连接结果
@@ -2878,15 +2884,14 @@ class MCPBridgePlugin(BasePlugin):
# 注册所有工具 # 注册所有工具
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
registered_count = 0 registered_count = 0
for tool_key, (tool_info, _) in mcp_manager.all_tools.items(): for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
tool_name = tool_key.replace("-", "_").replace(".", "_") tool_name = tool_key.replace("-", "_").replace(".", "_")
is_disabled = tool_name in disabled_tools is_disabled = tool_name in disabled_tools
info, tool_class = mcp_tool_registry.register_tool( info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled)
tool_key, tool_info, tool_prefix, disabled=is_disabled
)
info.plugin_name = self.plugin_name info.plugin_name = self.plugin_name
if component_registry.register_component(info, tool_class): if component_registry.register_component(info, tool_class):
@@ -3004,6 +3009,7 @@ class MCPBridgePlugin(BasePlugin):
doc["tools"] = tomlkit.table() doc["tools"] = tomlkit.table()
# 使用 tomlkit 多行字符串避免控制字符问题 # 使用 tomlkit 多行字符串避免控制字符问题
from tomlkit.items import String, StringType, Trivia from tomlkit.items import String, StringType, Trivia
ml_string = String(StringType.MLB, tool_list_text, tool_list_text, Trivia()) ml_string = String(StringType.MLB, tool_list_text, tool_list_text, Trivia())
doc["tools"]["tool_list"] = ml_string doc["tools"]["tool_list"] = ml_string
@@ -3069,6 +3075,7 @@ class MCPBridgePlugin(BasePlugin):
doc["status"] = tomlkit.table() doc["status"] = tomlkit.table()
# 使用 tomlkit 多行字符串避免控制字符问题 # 使用 tomlkit 多行字符串避免控制字符问题
from tomlkit.items import String, StringType, Trivia from tomlkit.items import String, StringType, Trivia
ml_string = String(StringType.MLB, status_text, status_text, Trivia()) ml_string = String(StringType.MLB, status_text, status_text, Trivia())
doc["status"]["connection_status"] = ml_string doc["status"]["connection_status"] = ml_string

View File

@@ -33,7 +33,7 @@ async def test_stats():
assert stats.total_calls == 3 assert stats.total_calls == 3
assert stats.success_calls == 2 assert stats.success_calls == 2
assert stats.failed_calls == 1 assert stats.failed_calls == 1
assert stats.success_rate == (2/3) * 100 assert stats.success_rate == (2 / 3) * 100
assert stats.avg_duration_ms == 150.0 assert stats.avg_duration_ms == 150.0
assert stats.last_error == "timeout" assert stats.last_error == "timeout"
@@ -66,13 +66,15 @@ async def test_manager_basic():
manager.__init__() manager.__init__()
# 配置 # 配置
manager.configure({ manager.configure(
"tool_prefix": "mcp", {
"call_timeout": 30.0, "tool_prefix": "mcp",
"retry_attempts": 1, "call_timeout": 30.0,
"retry_interval": 1.0, "retry_attempts": 1,
"heartbeat_enabled": False, "retry_interval": 1.0,
}) "heartbeat_enabled": False,
}
)
# 测试状态 # 测试状态
status = manager.get_status() status = manager.get_status()
@@ -82,10 +84,7 @@ async def test_manager_basic():
# 测试添加禁用的服务器 # 测试添加禁用的服务器
config = MCPServerConfig( config = MCPServerConfig(
name="disabled_server", name="disabled_server", enabled=False, transport=TransportType.HTTP, url="https://example.com/mcp"
enabled=False,
transport=TransportType.HTTP,
url="https://example.com/mcp"
) )
result = await manager.add_server(config) result = await manager.add_server(config)
assert result == True assert result == True
@@ -120,27 +119,29 @@ async def test_http_connection():
manager._initialized = False manager._initialized = False
manager.__init__() manager.__init__()
manager.configure({ manager.configure(
"tool_prefix": "mcp", {
"call_timeout": 30.0, "tool_prefix": "mcp",
"retry_attempts": 2, "call_timeout": 30.0,
"retry_interval": 2.0, "retry_attempts": 2,
"heartbeat_enabled": False, "retry_interval": 2.0,
}) "heartbeat_enabled": False,
}
)
# 使用 HowToCook MCP 服务器测试 # 使用 HowToCook MCP 服务器测试
config = MCPServerConfig( config = MCPServerConfig(
name="howtocook", name="howtocook",
enabled=True, enabled=True,
transport=TransportType.HTTP, transport=TransportType.HTTP,
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp" url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
) )
print(f"正在连接 {config.url} ...") print(f"正在连接 {config.url} ...")
result = await manager.add_server(config) result = await manager.add_server(config)
if result: if result:
print(f"✅ 连接成功!") print("✅ 连接成功!")
# 检查工具 # 检查工具
tools = manager.all_tools tools = manager.all_tools
@@ -159,19 +160,23 @@ async def test_http_connection():
call_result = await manager.call_tool("mcp_howtocook_whatToEat", {}) call_result = await manager.call_tool("mcp_howtocook_whatToEat", {})
if call_result.success: if call_result.success:
print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)") print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)")
print(f" 结果: {call_result.content[:200]}..." if len(str(call_result.content)) > 200 else f" 结果: {call_result.content}") print(
f" 结果: {call_result.content[:200]}..."
if len(str(call_result.content)) > 200
else f" 结果: {call_result.content}"
)
else: else:
print(f"❌ 工具调用失败: {call_result.error}") print(f"❌ 工具调用失败: {call_result.error}")
# 查看统计 # 查看统计
stats = manager.get_all_stats() stats = manager.get_all_stats()
print(f"\n📊 统计信息:") print("\n📊 统计信息:")
print(f" 全局调用: {stats['global']['total_tool_calls']}") print(f" 全局调用: {stats['global']['total_tool_calls']}")
print(f" 成功: {stats['global']['successful_calls']}") print(f" 成功: {stats['global']['successful_calls']}")
print(f" 失败: {stats['global']['failed_calls']}") print(f" 失败: {stats['global']['failed_calls']}")
else: else:
print(f"❌ 连接失败") print("❌ 连接失败")
# 清理 # 清理
await manager.shutdown() await manager.shutdown()
@@ -187,23 +192,25 @@ async def test_heartbeat():
manager._initialized = False manager._initialized = False
manager.__init__() manager.__init__()
manager.configure({ manager.configure(
"tool_prefix": "mcp", {
"call_timeout": 30.0, "tool_prefix": "mcp",
"retry_attempts": 1, "call_timeout": 30.0,
"retry_interval": 1.0, "retry_attempts": 1,
"heartbeat_enabled": True, "retry_interval": 1.0,
"heartbeat_interval": 5.0, # 5秒间隔用于测试 "heartbeat_enabled": True,
"auto_reconnect": True, "heartbeat_interval": 5.0, # 5秒间隔用于测试
"max_reconnect_attempts": 2, "auto_reconnect": True,
}) "max_reconnect_attempts": 2,
}
)
# 添加一个测试服务器 # 添加一个测试服务器
config = MCPServerConfig( config = MCPServerConfig(
name="heartbeat_test", name="heartbeat_test",
enabled=True, enabled=True,
transport=TransportType.HTTP, transport=TransportType.HTTP,
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp" url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
) )
print("正在连接服务器...") print("正在连接服务器...")
@@ -260,6 +267,7 @@ async def main():
except Exception as e: except Exception as e:
print(f"\n❌ 测试失败: {e}") print(f"\n❌ 测试失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False

View File

@@ -301,4 +301,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -13,7 +13,12 @@ from src.chat.utils.chat_message_builder import (
) )
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.bw_learner.learner_utils import filter_message_content, is_bot_message, build_context_paragraph, contains_bot_self_name from src.bw_learner.learner_utils import (
filter_message_content,
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
)
from src.bw_learner.jargon_miner import miner_manager from src.bw_learner.jargon_miner import miner_manager
from json_repair import repair_json from json_repair import repair_json
@@ -77,8 +82,6 @@ def init_prompt() -> None:
Prompt(learn_style_prompt, "learn_style_prompt") Prompt(learn_style_prompt, "learn_style_prompt")
class ExpressionLearner: class ExpressionLearner:
def __init__(self, chat_id: str) -> None: def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest( self.express_learn_model: LLMRequest = LLMRequest(
@@ -186,7 +189,6 @@ class ExpressionLearner:
filtered_expressions.append((situation, style, context)) filtered_expressions.append((situation, style, context))
learnt_expressions = filtered_expressions learnt_expressions = filtered_expressions
if learnt_expressions is None: if learnt_expressions is None:
@@ -270,6 +272,7 @@ class ExpressionLearner:
# 如果解析失败,尝试修复中文引号问题 # 如果解析失败,尝试修复中文引号问题
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
try: try:
def fix_chinese_quotes_in_json(text): def fix_chinese_quotes_in_json(text):
"""使用状态机修复 JSON 字符串值中的中文引号""" """使用状态机修复 JSON 字符串值中的中文引号"""
result = [] result = []
@@ -287,7 +290,7 @@ class ExpressionLearner:
i += 1 i += 1
continue continue
if char == '\\': if char == "\\":
# 转义字符 # 转义字符
result.append(char) result.append(char)
escape_next = True escape_next = True
@@ -315,7 +318,7 @@ class ExpressionLearner:
i += 1 i += 1
return ''.join(result) return "".join(result)
fixed_raw = fix_chinese_quotes_in_json(raw) fixed_raw = fix_chinese_quotes_in_json(raw)

View File

@@ -82,9 +82,7 @@ class ExpressionReflector:
# 获取未检查的表达 # 获取未检查的表达
try: try:
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达") logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
expressions = ( expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
)
expr_list = list(expressions) expr_list = list(expressions)
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达") logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")

View File

@@ -128,9 +128,7 @@ class ExpressionSelector:
# 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的 # 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的
style_query = Expression.select().where( style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
& (~Expression.rejected)
& (Expression.count > 1)
) )
style_exprs = [ style_exprs = [
@@ -150,12 +148,15 @@ class ExpressionSelector:
# 要求至少有10个 count > 1 的表达方式才进行选择 # 要求至少有10个 count > 1 的表达方式才进行选择
min_required = 10 min_required = 10
if len(style_exprs) < min_required: if len(style_exprs) < min_required:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择") logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择"
)
return [], [] return [], []
# 固定选择5个 # 固定选择5个
select_count = 5 select_count = 5
import random import random
selected_style = random.sample(style_exprs, select_count) selected_style = random.sample(style_exprs, select_count)
# 更新last_active_time # 更新last_active_time
@@ -163,7 +164,9 @@ class ExpressionSelector:
self.update_expressions_last_active_time(selected_style) self.update_expressions_last_active_time(selected_style)
selected_ids = [expr["id"] for expr in selected_style] selected_ids = [expr["id"] for expr in selected_style]
logger.debug(f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}") logger.debug(
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}"
)
return selected_style, selected_ids return selected_style, selected_ids
except Exception as e: except Exception as e:
@@ -186,9 +189,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达 # 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达
style_query = Expression.select().where( style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
style_exprs = [ style_exprs = [
{ {
@@ -246,7 +247,9 @@ class ExpressionSelector:
# 使用classic模式随机选择+LLM选择 # 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}") logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason, think_level) return await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
async def _select_expressions_classic( async def _select_expressions_classic(
self, self,
@@ -279,9 +282,7 @@ class ExpressionSelector:
# think_level == 1: 先选高count再从所有表达方式中随机抽样 # think_level == 1: 先选高count再从所有表达方式中随机抽样
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的 # 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
style_query = Expression.select().where( style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
all_style_exprs = [ all_style_exprs = [
{ {
@@ -308,11 +309,15 @@ class ExpressionSelector:
# 检查数量要求 # 检查数量要求
if len(high_count_exprs) < min_high_count: if len(high_count_exprs) < min_high_count:
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择") logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择"
)
return [], [] return [], []
if len(all_style_exprs) < min_total_count: if len(all_style_exprs) < min_total_count:
logger.info(f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择") logger.info(
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
)
return [], [] return [], []
# 先选取高count的表达方式 # 先选取高count的表达方式
@@ -332,6 +337,7 @@ class ExpressionSelector:
# 打乱顺序避免高count的都在前面 # 打乱顺序避免高count的都在前面
import random import random
random.shuffle(candidate_exprs) random.shuffle(candidate_exprs)
# 2. 构建所有表达方式的索引和情境列表 # 2. 构建所有表达方式的索引和情境列表
@@ -351,7 +357,7 @@ class ExpressionSelector:
all_situations_str = "\n".join(all_situations) all_situations_str = "\n".join(all_situations)
if target_message: if target_message:
target_message_str = f",现在你想要对这条消息进行回复:\"{target_message}\"" target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
target_message_extra_block = "4.考虑你要回复的目标消息" target_message_extra_block = "4.考虑你要回复的目标消息"
else: else:
target_message_str = "" target_message_str = ""

View File

@@ -8,7 +8,12 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config from src.config.config import model_config, global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.jargon_miner import search_jargon from src.bw_learner.jargon_miner import search_jargon
from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains from src.bw_learner.learner_utils import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,
chat_id_list_contains,
)
logger = get_logger("jargon") logger = get_logger("jargon")

View File

@@ -1,4 +1,3 @@
import time
import json import json
import asyncio import asyncio
import random import random
@@ -14,7 +13,6 @@ from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id, build_readable_messages_with_id,
get_raw_msg_by_timestamp_with_chat_inclusive,
) )
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.learner_utils import ( from src.bw_learner.learner_utils import (
@@ -46,10 +44,10 @@ def _is_single_char_jargon(content: str) -> bool:
char = content[0] char = content[0]
# 判断是否是单个汉字、单个英文字母或单个数字 # 判断是否是单个汉字、单个英文字母或单个数字
return ( return (
'\u4e00' <= char <= '\u9fff' or # 汉字 "\u4e00" <= char <= "\u9fff" # 汉字
'a' <= char <= 'z' or # 小写字母 or "a" <= char <= "z" # 小写字母
'A' <= char <= 'Z' or # 大写字母 or "A" <= char <= "Z" # 大写字母
'0' <= char <= '9' # 数字 or "0" <= char <= "9" # 数字
) )
@@ -305,7 +303,9 @@ class JargonMiner:
# 计算要保留的数量至少保留1个 # 计算要保留的数量至少保留1个
keep_count = max(1, len(raw_content_list) // 2) keep_count = max(1, len(raw_content_list) // 2)
raw_content_list = random.sample(raw_content_list, keep_count) raw_content_list = random.sample(raw_content_list, keep_count)
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目") logger.info(
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
)
# 步骤1: 基于raw_content和content推断 # 步骤1: 基于raw_content和content推断
raw_content_text = "\n".join(raw_content_list) raw_content_text = "\n".join(raw_content_list)
@@ -318,7 +318,9 @@ class JargonMiner:
**上一次推断的含义(仅供参考)** **上一次推断的含义(仅供参考)**
{previous_meaning} {previous_meaning}
""" """
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果" previous_meaning_instruction = (
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
)
prompt1 = await global_prompt_manager.format_prompt( prompt1 = await global_prompt_manager.format_prompt(
"jargon_inference_with_context_prompt", "jargon_inference_with_context_prompt",
@@ -650,7 +652,9 @@ class JargonMiner:
if obj.raw_content: if obj.raw_content:
try: try:
existing_raw_content = ( existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content json.loads(obj.raw_content)
if isinstance(obj.raw_content, str)
else obj.raw_content
) )
if not isinstance(existing_raw_content, list): if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else [] existing_raw_content = [existing_raw_content] if existing_raw_content else []
@@ -876,8 +880,6 @@ class JargonMinerManager:
miner_manager = JargonMinerManager() miner_manager = JargonMinerManager()
def search_jargon( def search_jargon(
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:

View File

@@ -116,7 +116,6 @@ class MessageRecorder:
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}" f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
) )
# 分别触发 expression_learner 和 jargon_miner 的处理 # 分别触发 expression_learner 和 jargon_miner 的处理
# 传递提取的消息,避免它们重复获取 # 传递提取的消息,避免它们重复获取
# 触发 expression 学习(如果启用) # 触发 expression 学习(如果启用)
@@ -127,21 +126,19 @@ class MessageRecorder:
# 触发 jargon 提取(如果启用),传递消息 # 触发 jargon 提取(如果启用),传递消息
# if self.enable_jargon_learning: # if self.enable_jargon_learning:
# asyncio.create_task( # asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages) # self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# ) # )
except Exception as e: except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试 # 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning( async def _trigger_expression_learning(
self, self, timestamp_start: float, timestamp_end: float, messages: List[Any]
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
) -> None: ) -> None:
""" """
触发 expression 学习,使用指定的消息列表 触发 expression 学习,使用指定的消息列表
@@ -162,13 +159,11 @@ class MessageRecorder:
except Exception as e: except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}") logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
async def _trigger_jargon_extraction( async def _trigger_jargon_extraction(
self, self, timestamp_start: float, timestamp_end: float, messages: List[Any]
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
) -> None: ) -> None:
""" """
触发 jargon 提取,使用指定的消息列表 触发 jargon 提取,使用指定的消息列表
@@ -185,6 +180,7 @@ class MessageRecorder:
except Exception as e: except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}") logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@@ -214,4 +210,3 @@ async def extract_and_distribute_messages(chat_id: str) -> None:
""" """
recorder = recorder_manager.get_recorder(chat_id) recorder = recorder_manager.get_recorder(chat_id)
await recorder.extract_and_distribute() await recorder.extract_and_distribute()

View File

@@ -328,9 +328,7 @@ class BrainChatting:
) )
# 检查是否有 complete_talk 动作(会停止后续迭代) # 检查是否有 complete_talk 动作(会停止后续迭代)
has_complete_talk = any( has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
action.action_type == "complete_talk" for action in action_to_use_info
)
# 并行执行所有动作 # 并行执行所有动作
action_tasks = [ action_tasks = [

View File

@@ -204,7 +204,9 @@ class BrainPlanner:
# 注意listening 已合并到 wait 中,如果遇到 listening 则转换为 wait # 注意listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"] internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}") logger.debug(
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
)
# 将 listening 转换为 wait向后兼容 # 将 listening 转换为 wait向后兼容
if action == "listening": if action == "listening":
@@ -521,7 +523,7 @@ class BrainPlanner:
if json_objects: if json_objects:
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象") logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
for i, json_obj in enumerate(json_objects): for i, json_obj in enumerate(json_objects):
logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}") logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
filtered_actions_list = list(filtered_actions.items()) filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects: for json_obj in json_objects:
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list) parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
@@ -553,7 +555,9 @@ class BrainPlanner:
return extracted_reasoning, actions return extracted_reasoning, actions
def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]: def _create_complete_talk(
self, reasoning: str, available_actions: Dict[str, ActionInfo]
) -> List[ActionPlannerInfo]:
"""创建complete_talk""" """创建complete_talk"""
return [ return [
ActionPlannerInfo( ActionPlannerInfo(

View File

@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji.description = emoji_data.description emoji.description = emoji_data.description
# Deserialize emotion string from DB to list # Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.replace("",",").split(",") if emoji_data.emotion else [] emoji.emotion = emoji_data.emotion.replace("", ",").split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count emoji.usage_count = emoji_data.usage_count
db_last_used_time = emoji_data.last_used_time db_last_used_time = emoji_data.last_used_time
@@ -732,7 +732,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion: if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...") logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.replace("",",").split(",") return emoji_record.emotion.replace("", ",").split(",")
except Exception as e: except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}") logger.error(f"从数据库查询表情包情感标签时出错: {e}")
@@ -993,7 +993,7 @@ class EmojiManager:
) )
# 处理情感列表 # 处理情感列表
emotions = [e.strip() for e in emotions_text.replace("",",").split(",") if e.strip()] emotions = [e.strip() for e in emotions_text.replace("", ",").split(",") if e.strip()]
# 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个 # 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个
if len(emotions) > 5: if len(emotions) > 5:

View File

@@ -123,7 +123,11 @@ class ChatBot:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}") logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息 # 根据命令的拦截设置决定是否继续处理消息
return True, response, not bool(intercept_message_level) # 找到命令根据intercept_message决定是否继续 return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message决定是否继续
except Exception as e: except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}") logger.error(f"执行命令时出错: {command_class.__name__} - {e}")

View File

@@ -263,7 +263,7 @@ class MessageRecv(Message):
desc = segment.data.get("desc", "") # 内容描述 desc = segment.data.get("desc", "") # 内容描述
source_url = segment.data.get("source_url", "") # 原始链接 source_url = segment.data.get("source_url", "") # 原始链接
url = segment.data.get("url", "") # 小程序链接 url = segment.data.get("url", "") # 小程序链接
text = f"[小程序分享" text = "[小程序分享"
if title: if title:
text += f" - {title}" text += f" - {title}"
text += "]" text += "]"

View File

@@ -51,7 +51,6 @@ def parse_message_segments(segment) -> list:
Returns: Returns:
list: 消息段列表,每个元素为 {"type": "...", "data": ...} list: 消息段列表,每个元素为 {"type": "...", "data": ...}
""" """
from maim_message import Seg
result = [] result = []
@@ -112,9 +111,13 @@ def parse_message_segments(segment) -> list:
forward_items = [] forward_items = []
if segment.data: if segment.data:
for item in segment.data: for item in segment.data:
forward_items.append({ forward_items.append(
"content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else [] {
}) "content": parse_message_segments(item.get("message_segment", {}))
if isinstance(item, dict)
else []
}
)
result.append({"type": "forward", "data": forward_items}) result.append({"type": "forward", "data": forward_items})
else: else:
# 未知类型,尝试作为文本处理 # 未知类型,尝试作为文本处理

View File

@@ -78,7 +78,6 @@ target_message_id为必填表示触发消息的id
"planner_prompt", "planner_prompt",
) )
Prompt( Prompt(
""" """
{action_name} {action_name}

View File

@@ -250,7 +250,12 @@ class DefaultReplyer:
# 使用从处理器传来的选中表达方式 # 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式 # 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level self.chat_stream.stream_id,
chat_history,
max_num=8,
target_message=target,
reply_reason=reply_reason,
think_level=think_level,
) )
if selected_expressions: if selected_expressions:
@@ -273,7 +278,6 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块 """构建工具信息块
@@ -788,7 +792,8 @@ class DefaultReplyer:
# 并行执行八个构建任务(包括黑话解释) # 并行执行八个构建任务(包括黑话解释)
task_results = await asyncio.gather( task_results = await asyncio.gather(
self._time_and_run_task( self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits" self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
"expression_habits",
), ),
self._time_and_run_task( self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info" self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
@@ -980,7 +985,6 @@ class DefaultReplyer:
else: else:
reply_target_block = "" reply_target_block = ""
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2") chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")

View File

@@ -287,7 +287,6 @@ class PrivateReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块 """构建工具信息块
@@ -907,16 +906,11 @@ class PrivateReplyer:
else: else:
reply_target_block = "" reply_target_block = ""
chat_target_name = "对方" chat_target_name = "对方"
if self.chat_target_info: if self.chat_target_info:
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方" chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
chat_target_1 = await global_prompt_manager.format_prompt( chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
"chat_target_private1", sender_name=chat_target_name chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
template_name = "default_expressor_prompt" template_name = "default_expressor_prompt"

View File

@@ -1,8 +1,9 @@
from src.chat.utils.prompt_builder import Prompt from src.chat.utils.prompt_builder import Prompt
def init_replyer_private_prompt(): def init_replyer_private_prompt():
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation} {expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天,这是你们之前聊的内容: 你正在和{sender_name}聊天,这是你们之前聊的内容:
@@ -17,8 +18,8 @@ def init_replyer_private_prompt():
{reply_style} {reply_style}
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""", {moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"private_replyer_prompt", "private_replyer_prompt",
) )
Prompt( Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block} """{knowledge_prompt}{tool_info_block}{extra_info_block}

View File

@@ -44,4 +44,3 @@ def init_replyer_prompt():
现在,你说:""", 现在,你说:""",
"replyer_prompt", "replyer_prompt",
) )

View File

@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}} filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)] sort_order = [("time", 1)]
return find_messages( return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level message_filter=filter_query,
sort=sort_order,
limit=limit,
filter_intercept_message_level=filter_intercept_message_level,
) )

View File

@@ -132,9 +132,7 @@ class ImageManager:
deleted_images = Images.delete().where(Images.type == "emoji").execute() deleted_images = Images.delete().where(Images.type == "emoji").execute()
# 清理ImageDescriptions表中type为emoji的记录 # 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = ( deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
)
total_deleted = deleted_images + deleted_descriptions total_deleted = deleted_images + deleted_descriptions
if total_deleted > 0: if total_deleted > 0:
@@ -236,10 +234,14 @@ class ImageManager:
# 优先使用情感标签,如果没有则使用详细描述 # 优先使用情感标签,如果没有则使用详细描述
result_text = "" result_text = ""
if cache_record.emotion_tags: if cache_record.emotion_tags:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...") logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]" result_text = f"[表情包:{cache_record.emotion_tags}]"
elif cache_record.description: elif cache_record.description:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...") logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
result_text = f"[表情包:{cache_record.description}]" result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件 # 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件

View File

@@ -609,10 +609,10 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
fields = list(model._meta.fields.keys()) fields = list(model._meta.fields.keys())
# Peewee 默认使用 'id' 作为主键字段名 # Peewee 默认使用 'id' 作为主键字段名
# 尝试获取主键字段名,如果获取失败则默认使用 'id' # 尝试获取主键字段名,如果获取失败则默认使用 'id'
primary_key_name = 'id' # 默认值 primary_key_name = "id" # 默认值
try: try:
if hasattr(model._meta, 'primary_key') and model._meta.primary_key: if hasattr(model._meta, "primary_key") and model._meta.primary_key:
if hasattr(model._meta.primary_key, 'name'): if hasattr(model._meta.primary_key, "name"):
primary_key_name = model._meta.primary_key.name primary_key_name = model._meta.primary_key.name
elif isinstance(model._meta.primary_key, str): elif isinstance(model._meta.primary_key, str):
primary_key_name = model._meta.primary_key primary_key_name = model._meta.primary_key

View File

@@ -34,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
return obj return obj
# 决定是否多行:仅在顶层且长度超过阈值时 # 决定是否多行:仅在顶层且长度超过阈值时
should_multiline = (depth == 0 and len(obj) > threshold) should_multiline = depth == 0 and len(obj) > threshold
# 如果已经是 tomlkit Array原地修改以保留注释 # 如果已经是 tomlkit Array原地修改以保留注释
if isinstance(obj, Array): if isinstance(obj, Array):
@@ -112,7 +112,7 @@ def save_toml_with_format(
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted) output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积 # 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
output = re.sub(r'\n{3,}', '\n\n', output) output = re.sub(r"\n{3,}", "\n\n", output)
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding="utf-8") as f:
f.write(output) f.write(output)
@@ -122,4 +122,4 @@ def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted) output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积 # 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
return re.sub(r'\n{3,}', '\n\n', output) return re.sub(r"\n{3,}", "\n\n", output)

View File

@@ -1,14 +1,13 @@
import asyncio import asyncio
import random import random
import time import time
import json
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from peewee import fn from peewee import fn
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.common.database.database_model import ChatHistory, Jargon from src.common.database.database_model import ChatHistory
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.plugin_system.apis import llm_api from src.plugin_system.apis import llm_api
@@ -82,7 +81,6 @@ def init_dream_prompts() -> None:
) )
class DreamTool: class DreamTool:
"""dream 模块内部使用的简易工具封装""" """dream 模块内部使用的简易工具封装"""
@@ -150,7 +148,13 @@ def init_dream_tools(chat_id: str) -> None:
"search_chat_history", "search_chat_history",
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。", "根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
[ [
("keyword", ToolParamType.STRING, "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", False, None), (
"keyword",
ToolParamType.STRING,
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
False,
None,
),
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None), ("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
], ],
search_chat_history, search_chat_history,
@@ -201,8 +205,20 @@ def init_dream_tools(chat_id: str) -> None:
[ [
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None), ("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None), ("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", True, None), (
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", True, None), "keywords",
ToolParamType.STRING,
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
True,
None,
),
(
"key_point",
ToolParamType.STRING,
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
True,
None,
),
("start_time", ToolParamType.STRING, "起始时间戳Unix 时间,必填)。", True, None), ("start_time", ToolParamType.STRING, "起始时间戳Unix 时间,必填)。", True, None),
("end_time", ToolParamType.STRING, "结束时间戳Unix 时间,必填)。", True, None), ("end_time", ToolParamType.STRING, "结束时间戳Unix 时间,必填)。", True, None),
], ],
@@ -215,7 +231,13 @@ def init_dream_tools(chat_id: str) -> None:
"finish_maintenance", "finish_maintenance",
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。", "结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
[ [
("reason", ToolParamType.STRING, "结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'", False, None), (
"reason",
ToolParamType.STRING,
"结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'",
False,
None,
),
], ],
finish_maintenance, finish_maintenance,
) )
@@ -282,9 +304,7 @@ async def run_dream_agent_once(
else "未知" else "未知"
) )
end_time_str = ( end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
if record.end_time
else "未知"
) )
detail_text = ( detail_text = (
f"ID={record.id}\n" f"ID={record.id}\n"
@@ -305,8 +325,7 @@ async def run_dream_agent_once(
start_detail_builder = MessageBuilder() start_detail_builder = MessageBuilder()
start_detail_builder.set_role(RoleType.User) start_detail_builder.set_role(RoleType.User)
start_detail_builder.add_text_content( start_detail_builder.add_text_content(
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" "【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
+ detail_text
) )
conversation_messages.append(start_detail_builder.build()) conversation_messages.append(start_detail_builder.build())
else: else:
@@ -343,13 +362,17 @@ async def run_dream_agent_once(
conversation_messages.append(round_info_builder.build()) conversation_messages.append(round_info_builder.build())
# 调用 LLM 让其决定是否要使用工具 # 调用 LLM 让其决定是否要使用工具
success, response, reasoning_content, model_name, tool_calls = ( (
await llm_api.generate_with_model_with_tools_by_message_factory( success,
message_factory, response,
model_config=model_config.model_task_config.tool_use, reasoning_content,
tool_options=tool_defs, model_name,
request_type="dream.react", tool_calls,
) ) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_defs,
request_type="dream.react",
) )
if not success: if not success:
@@ -555,4 +578,3 @@ async def start_dream_scheduler(
# 初始化提示词 # 初始化提示词
init_dream_prompts() init_dream_prompts()

View File

@@ -110,13 +110,17 @@ async def generate_dream_summary(
thought_content = "" thought_content = ""
if msg.content: if msg.content:
if isinstance(msg.content, list) and msg.content: if isinstance(msg.content, list) and msg.content:
thought_content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0]) thought_content = (
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
)
else: else:
thought_content = str(msg.content) thought_content = str(msg.content)
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===") logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
if thought_content: if thought_content:
logger.info(f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}") logger.info(
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
)
# 记录每个工具调用的详细信息 # 记录每个工具调用的详细信息
for idx, tool_call in enumerate(msg.tool_calls, 1): for idx, tool_call in enumerate(msg.tool_calls, 1):
@@ -167,7 +171,7 @@ async def generate_dream_summary(
# 随机选择2个梦境风格 # 随机选择2个梦境风格
selected_styles = get_random_dream_styles(2) selected_styles = get_random_dream_styles(2)
dream_styles_text = "\n".join([f"{i+1}. {style}" for i, style in enumerate(selected_styles)]) dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
# 使用 Prompt 管理器格式化梦境生成 prompt # 使用 Prompt 管理器格式化梦境生成 prompt
dream_prompt = await global_prompt_manager.format_prompt( dream_prompt = await global_prompt_manager.format_prompt(
@@ -195,4 +199,5 @@ async def generate_dream_summary(
except Exception as e: except Exception as e:
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True) logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
init_dream_summary_prompt() init_dream_summary_prompt()

View File

@@ -4,8 +4,3 @@ dream agent 工具实现模块。
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数 每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。 生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
""" """

View File

@@ -60,8 +60,3 @@ def make_create_chat_history(chat_id: str):
return f"create_chat_history 执行失败: {e}" return f"create_chat_history 执行失败: {e}"
return create_chat_history return create_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"delete_chat_history 执行失败: {e}" return f"delete_chat_history 执行失败: {e}"
return delete_chat_history return delete_chat_history

View File

@@ -23,8 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"delete_jargon 执行失败: {e}" return f"delete_jargon 执行失败: {e}"
return delete_jargon return delete_jargon

View File

@@ -14,8 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
return msg return msg
return finish_maintenance return finish_maintenance

View File

@@ -1,5 +1,4 @@
import time import time
from typing import Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory from src.common.database.database_model import ChatHistory
@@ -20,14 +19,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
# 将时间戳转换为可读时间格式 # 将时间戳转换为可读时间格式
start_time_str = ( start_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
if record.start_time
else "未知"
) )
end_time_str = ( end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
if record.end_time
else "未知"
) )
result = ( result = (
@@ -40,17 +35,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
f"概括={record.summary or ''}\n" f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}" f"关键信息={record.key_point or ''}"
) )
logger.debug( logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}"
)
return result return result
except Exception as e: except Exception as e:
logger.error(f"get_chat_history_detail 失败: {e}") logger.error(f"get_chat_history_detail 失败: {e}")
return f"get_chat_history_detail 执行失败: {e}" return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail return get_chat_history_detail

View File

@@ -78,9 +78,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = (
json.loads(record.keywords) json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(record.keywords, str)
else record.keywords
) )
if isinstance(keywords_data, list): if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data] record_keywords_list = [str(k).lower() for k in keywords_data]
@@ -125,9 +123,7 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(keywords_list) keywords_str = "".join(keywords_list)
if len(keywords_list) > 2: if len(keywords_list) > 2:
required_count = len(keywords_list) - 1 required_count = len(keywords_list) - 1
return ( return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
else: else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录" return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant: elif participant:
@@ -142,9 +138,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = (
json.loads(record.keywords) json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(record.keywords, str)
else record.keywords
) )
if isinstance(keywords_data, list): if isinstance(keywords_data, list):
for k in keywords_data: for k in keywords_data:
@@ -160,13 +154,13 @@ def make_search_chat_history(chat_id: str):
keywords_str = "".join(sorted(all_keywords_set)) keywords_str = "".join(sorted(all_keywords_set))
response_text = ( response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n" f'有关"{search_label}"的关键词:\n'
f"{keywords_str}" f"{keywords_str}"
) )
else: else:
response_text = ( response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词信息为空" f'有关"{search_label}"的关键词信息为空'
) )
logger.info( logger.info(
@@ -192,9 +186,7 @@ def make_search_chat_history(chat_id: str):
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = (
json.loads(record.keywords) json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(record.keywords, str)
else record.keywords
) )
if isinstance(keywords_data, list) and keywords_data: if isinstance(keywords_data, list) and keywords_data:
keywords_str = "".join([str(k) for k in keywords_data]) keywords_str = "".join([str(k) for k in keywords_data])
@@ -220,8 +212,3 @@ def make_search_chat_history(chat_id: str):
return f"search_chat_history 执行失败: {e}" return f"search_chat_history 执行失败: {e}"
return search_chat_history return search_chat_history

View File

@@ -16,9 +16,7 @@ def make_search_jargon(chat_id: str):
if not keyword or not keyword.strip(): if not keyword or not keyword.strip():
return "未指定查询关键词(参数 keyword 为必填,且不能为空)" return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
logger.info( logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})"
)
# 基础条件:只查 is_jargon=True 的记录 # 基础条件:只查 is_jargon=True 的记录
query = Jargon.select().where(Jargon.is_jargon) query = Jargon.select().where(Jargon.is_jargon)
@@ -102,5 +100,3 @@ def make_search_jargon(chat_id: str):
return f"search_jargon 执行失败: {e}" return f"search_jargon 执行失败: {e}"
return search_jargon return search_jargon

View File

@@ -49,8 +49,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
return f"update_chat_history 执行失败: {e}" return f"update_chat_history 执行失败: {e}"
return update_chat_history return update_chat_history

View File

@@ -49,8 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
return f"update_jargon 执行失败: {e}" return f"update_jargon 执行失败: {e}"
return update_jargon return update_jargon

View File

@@ -316,7 +316,9 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages) before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages) self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time self.current_batch.end_time = current_time
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息") logger.info(
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
# 更新批次后持久化 # 更新批次后持久化
self._persist_topic_cache() self._persist_topic_cache()
else: else:
@@ -362,9 +364,7 @@ class ChatHistorySummarizer:
else: else:
time_str = f"{time_since_last_check / 3600:.1f}小时" time_str = f"{time_since_last_check / 3600:.1f}小时"
logger.debug( logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
)
# 检查“话题检查”触发条件 # 检查“话题检查”触发条件
should_check = False should_check = False
@@ -427,7 +427,9 @@ class ChatHistorySummarizer:
return return
# 2. 构造编号后的消息字符串和参与者信息 # 2. 构造编号后的消息字符串和参与者信息
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages) numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
self._build_numbered_messages_for_llm(messages)
)
# 3. 调用 LLM 识别话题,并得到 topic -> indices失败时最多重试 3 次) # 3. 调用 LLM 识别话题,并得到 topic -> indices失败时最多重试 3 次)
existing_topics = list(self.topic_cache.keys()) existing_topics = list(self.topic_cache.keys())
@@ -456,9 +458,7 @@ class ChatHistorySummarizer:
) )
if not success or not topic_to_indices: if not success or not topic_to_indices:
logger.error( logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
)
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状 # 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状
return return
@@ -610,9 +610,7 @@ class ChatHistorySummarizer:
if not numbered_lines: if not numbered_lines:
return False, {} return False, {}
history_topics_block = ( history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
)
messages_block = "\n".join(numbered_lines) messages_block = "\n".join(numbered_lines)
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
@@ -642,10 +640,10 @@ class ChatHistorySummarizer:
else: else:
# 如果没有找到代码块尝试查找JSON数组的开始和结束位置 # 如果没有找到代码块尝试查找JSON数组的开始和结束位置
# 查找第一个 [ 和最后一个 ] # 查找第一个 [ 和最后一个 ]
start_idx = response.find('[') start_idx = response.find("[")
end_idx = response.rfind(']') end_idx = response.rfind("]")
if start_idx != -1 and end_idx != -1 and end_idx > start_idx: if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
json_str = response[start_idx:end_idx + 1].strip() json_str = response[start_idx : end_idx + 1].strip()
else: else:
# 如果还是找不到尝试直接使用整个响应移除可能的markdown标记 # 如果还是找不到尝试直接使用整个响应移除可能的markdown标记
json_str = response.strip() json_str = response.strip()
@@ -942,4 +940,3 @@ class ChatHistorySummarizer:
init_prompt() init_prompt()

View File

@@ -366,7 +366,9 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}") logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}") logger.warning(
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval) await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e: except NetworkConnectionError as e:
@@ -394,7 +396,9 @@ class LLMRequest:
if e.status_code == 429 or e.status_code >= 500: if e.status_code == 429 or e.status_code >= 500:
retry_remain -= 1 retry_remain -= 1
if retry_remain <= 0: if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}") logger.error(
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning( logger.warning(
@@ -540,7 +544,5 @@ class LLMRequest:
if e.__cause__: if e.__cause__:
original_error_type = type(e.__cause__).__name__ original_error_type = type(e.__cause__).__name__
original_error_msg = str(e.__cause__) original_error_msg = str(e.__cause__)
return ( return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
)
return "" return ""

View File

@@ -113,7 +113,6 @@ class MainSystem:
get_emoji_manager().initialize() get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功") logger.info("表情包管理器初始化成功")
# 初始化聊天管理器 # 初始化聊天管理器
await get_chat_manager()._initialize() await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task()) asyncio.create_task(get_chat_manager()._auto_save_task())

View File

@@ -136,8 +136,6 @@ def init_memory_retrieval_prompt():
) )
def _log_conversation_messages( def _log_conversation_messages(
conversation_messages: List[Message], conversation_messages: List[Message],
head_prompt: Optional[str] = None, head_prompt: Optional[str] = None,
@@ -172,7 +170,9 @@ def _log_conversation_messages(
# 构建单条消息的日志信息 # 构建单条消息的日志信息
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------" # msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------" msg_info = (
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
)
# if full_content: # if full_content:
# msg_info += f"\n{full_content}" # msg_info += f"\n{full_content}"
@@ -185,8 +185,7 @@ def _log_conversation_messages(
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}" msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
# if msg.tool_call_id: # if msg.tool_call_id:
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}" # msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
log_lines.append(msg_info) log_lines.append(msg_info)
@@ -365,7 +364,7 @@ async def _react_agent_solve_question(
) )
# logger.info( # logger.info(
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}" # f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
# ) # )
if not success: if not success:
@@ -467,10 +466,17 @@ async def _react_agent_solve_question(
if parsed_found_answer: if parsed_found_answer:
# 找到了答案 # 找到了答案
if parsed_answer: if parsed_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": parsed_answer}}) step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": parsed_answer},
}
)
step["observations"] = ["检测到finish_search文本格式调用找到答案"] step["observations"] = ["检测到finish_search文本格式调用找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
)
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
@@ -481,10 +487,14 @@ async def _react_agent_solve_question(
return True, parsed_answer, thinking_steps, False return True, parsed_answer, thinking_steps, False
else: else:
# found_answer为True但没有提供answer视为错误继续迭代 # found_answer为True但没有提供answer视为错误继续迭代
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer") logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
)
else: else:
# 未找到答案,直接退出查询 # 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}}) step["actions"].append(
{"action_type": "finish_search", "action_params": {"found_answer": False}}
)
step["observations"] = ["检测到finish_search文本格式调用未找到答案"] step["observations"] = ["检测到finish_search文本格式调用未找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案") logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
@@ -522,10 +532,17 @@ async def _react_agent_solve_question(
if finish_search_found: if finish_search_found:
# 找到了答案 # 找到了答案
if finish_search_answer: if finish_search_answer:
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": finish_search_answer}}) step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": finish_search_answer},
}
)
step["observations"] = ["检测到finish_search工具调用找到答案"] step["observations"] = ["检测到finish_search工具调用找到答案"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}") logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
)
_log_conversation_messages( _log_conversation_messages(
conversation_messages, conversation_messages,
@@ -536,7 +553,9 @@ async def _react_agent_solve_question(
return True, finish_search_answer, thinking_steps, False return True, finish_search_answer, thinking_steps, False
else: else:
# found_answer为True但没有提供answer视为错误 # found_answer为True但没有提供answer视为错误
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer") logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
)
else: else:
# 未找到答案,直接退出查询 # 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}}) step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
@@ -724,7 +743,13 @@ async def _react_agent_solve_question(
max_iterations=max_iterations, max_iterations=max_iterations,
) )
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools( (
eval_success,
eval_response,
eval_reasoning_content,
eval_model_name,
eval_tool_calls,
) = await llm_api.generate_with_model_with_tools(
evaluation_prompt, evaluation_prompt,
model_config=model_config.model_task_config.tool_use, model_config=model_config.model_task_config.tool_use,
tool_options=[], # 最终评估阶段不提供工具 tool_options=[], # 最终评估阶段不提供工具
@@ -759,7 +784,7 @@ async def _react_agent_solve_question(
"iteration": current_iteration, "iteration": current_iteration,
"thought": f"[最终评估] {eval_response}", "thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}], "actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
"observations": ["最终评估阶段检测到found_answer"] "observations": ["最终评估阶段检测到found_answer"],
} }
thinking_steps.append(eval_step) thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}") logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
@@ -778,7 +803,7 @@ async def _react_agent_solve_question(
"iteration": current_iteration, "iteration": current_iteration,
"thought": f"[最终评估] {eval_response}", "thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}], "actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
"observations": ["最终评估阶段检测到not_enough_info"] "observations": ["最终评估阶段检测到not_enough_info"],
} }
thinking_steps.append(eval_step) thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}") logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
@@ -795,8 +820,10 @@ async def _react_agent_solve_question(
eval_step = { eval_step = {
"iteration": current_iteration, "iteration": current_iteration,
"thought": f"[最终评估] {eval_response}", "thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}], "actions": [
"observations": ["已到达最大迭代次数,无法找到答案"] {"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
],
"observations": ["已到达最大迭代次数,无法找到答案"],
} }
thinking_steps.append(eval_step) thinking_steps.append(eval_step)
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案") logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
@@ -1129,7 +1156,9 @@ async def build_memory_retrieval_prompt(
else: else:
max_iterations = base_max_iterations max_iterations = base_max_iterations
timeout_seconds = global_config.memory.agent_timeout_seconds timeout_seconds = global_config.memory.agent_timeout_seconds
logger.debug(f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}") logger.debug(
f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
)
# 并行处理所有问题,将概念检索结果作为初始信息传递 # 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [ question_tasks = [
@@ -1198,4 +1227,3 @@ async def build_memory_retrieval_prompt(
except Exception as e: except Exception as e:
logger.error(f"记忆检索时发生异常: {str(e)}") logger.error(f"记忆检索时发生异常: {str(e)}")
return "" return ""

View File

@@ -17,7 +17,6 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils") logger = get_logger("memory_utils")
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]: def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
"""解析问题JSON返回概念列表和问题列表 """解析问题JSON返回概念列表和问题列表
@@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...") logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return [], [] return [], []
def parse_datetime_to_timestamp(value: str) -> float: def parse_datetime_to_timestamp(value: str) -> float:
""" """
接受多种常见格式并转换为时间戳(秒) 接受多种常见格式并转换为时间戳(秒)

View File

@@ -47,4 +47,3 @@ def register_tool():
], ],
execute_func=finish_search, execute_func=finish_search,
) )

View File

@@ -16,9 +16,7 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools") logger = get_logger("memory_retrieval_tools")
async def search_chat_history( async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
) -> str:
"""根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords """根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords
Args: Args:
@@ -144,7 +142,9 @@ async def search_chat_history(
keywords_list = parse_keywords_string(keyword) keywords_list = parse_keywords_string(keyword)
if len(keywords_list) > 2: if len(keywords_list) > 2:
required_count = len(keywords_list) - 1 required_count = len(keywords_list) - 1
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录" return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
else: else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录" return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant: elif participant:
@@ -160,9 +160,7 @@ async def search_chat_history(
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = (
json.loads(record.keywords) json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(record.keywords, str)
else record.keywords
) )
if isinstance(keywords_data, list): if isinstance(keywords_data, list):
for k in keywords_data: for k in keywords_data:
@@ -179,13 +177,12 @@ async def search_chat_history(
keywords_str = "".join(sorted(all_keywords_set)) keywords_str = "".join(sorted(all_keywords_set))
return ( return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f"有关\"{search_label}\"的关键词:\n" f'有关"{search_label}"的关键词:\n'
f"{keywords_str}" f"{keywords_str}"
) )
else: else:
return ( return (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n" f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空'
f"有关\"{search_label}\"的关键词信息为空"
) )
# 构建结果文本返回id、theme和keywords最多20条 # 构建结果文本返回id、theme和keywords最多20条

View File

@@ -414,7 +414,9 @@ async def websocket_chat(
group_id=virtual_group_id, group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊", group_name=group_name or "WebUI虚拟群聊",
) )
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}") logger.info(
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
)
except Exception as e: except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}") logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")

View File

@@ -1,4 +1,4 @@
""" 表情包管理 API 路由""" """表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
from fastapi.responses import FileResponse, JSONResponse from fastapi.responses import FileResponse, JSONResponse
@@ -100,23 +100,23 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
try: try:
with Image.open(source_path) as img: with Image.open(source_path) as img:
# GIF 处理:提取第一帧 # GIF 处理:提取第一帧
if hasattr(img, 'n_frames') and img.n_frames > 1: if hasattr(img, "n_frames") and img.n_frames > 1:
img.seek(0) # 确保在第一帧 img.seek(0) # 确保在第一帧
# 转换为 RGB/RGBAWebP 支持透明度) # 转换为 RGB/RGBAWebP 支持透明度)
if img.mode in ('P', 'PA'): if img.mode in ("P", "PA"):
# 调色板模式转换为 RGBA 以保留透明度 # 调色板模式转换为 RGBA 以保留透明度
img = img.convert('RGBA') img = img.convert("RGBA")
elif img.mode == 'LA': elif img.mode == "LA":
img = img.convert('RGBA') img = img.convert("RGBA")
elif img.mode not in ('RGB', 'RGBA'): elif img.mode not in ("RGB", "RGBA"):
img = img.convert('RGB') img = img.convert("RGB")
# 创建缩略图(保持宽高比) # 创建缩略图(保持宽高比)
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS) img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
# 保存为 WebP 格式 # 保存为 WebP 格式
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6) img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}") logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
@@ -163,6 +163,7 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
return cleaned, kept return cleaned, kept
# 模块级别的类型别名(解决 B008 ruff 错误) # 模块级别的类型别名(解决 B008 ruff 错误)
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")] EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")] EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
@@ -365,7 +366,9 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse) @router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def get_emoji_detail(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
获取表情包详细信息 获取表情包详细信息
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse) @router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
增量更新表情包(只更新提供的字段) 增量更新表情包(只更新提供的字段)
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse) @router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def delete_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
删除表情包 删除表情包
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse) @router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def register_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
注册表情包(快捷操作) 注册表情包(快捷操作)
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse) @router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def ban_emoji(
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
禁用表情包(快捷操作) 禁用表情包(快捷操作)
@@ -680,9 +694,7 @@ async def get_emoji_thumbnail(
} }
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream") media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse( return FileResponse(
path=emoji.full_path, path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
media_type=media_type,
filename=f"{emoji.emoji_hash}.{emoji.format}"
) )
# 尝试获取或生成缩略图 # 尝试获取或生成缩略图
@@ -692,9 +704,7 @@ async def get_emoji_thumbnail(
if cache_path.exists(): if cache_path.exists():
# 缓存命中,直接返回 # 缓存命中,直接返回
return FileResponse( return FileResponse(
path=str(cache_path), path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
media_type="image/webp",
filename=f"{emoji.emoji_hash}_thumb.webp"
) )
# 缓存未命中,触发后台生成并返回 202 # 缓存未命中,触发后台生成并返回 202
@@ -703,11 +713,7 @@ async def get_emoji_thumbnail(
# 标记为正在生成 # 标记为正在生成
_generating_thumbnails.add(emoji.emoji_hash) _generating_thumbnails.add(emoji.emoji_hash)
# 提交到线程池后台生成 # 提交到线程池后台生成
_thumbnail_executor.submit( _thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
_background_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
# 返回 202 Accepted告诉前端缩略图正在生成中 # 返回 202 Accepted告诉前端缩略图正在生成中
return JSONResponse( return JSONResponse(
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
}, },
headers={ headers={
"Retry-After": "1", # 建议 1 秒后重试 "Retry-After": "1", # 建议 1 秒后重试
} },
) )
except HTTPException: except HTTPException:
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
@router.post("/batch/delete", response_model=BatchDeleteResponse) @router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def batch_delete_emojis(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
批量删除表情包 批量删除表情包
@@ -1234,12 +1244,7 @@ async def preheat_thumbnail_cache(
try: try:
# 使用线程池异步生成缩略图,避免阻塞事件循环 # 使用线程池异步生成缩略图,避免阻塞事件循环
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor( await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
_thumbnail_executor,
_generate_thumbnail,
emoji.full_path,
emoji.emoji_hash
)
generated += 1 generated += 1
except Exception as e: except Exception as e:
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}") logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")

View File

@@ -256,7 +256,9 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse) @router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
获取表达方式详细信息 获取表达方式详细信息
@@ -285,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
@router.post("/", response_model=ExpressionCreateResponse) @router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
创建新的表达方式 创建新的表达方式
@@ -326,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse) @router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression( async def update_expression(
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
): ):
""" """
增量更新表达方式(只更新提供的字段) 增量更新表达方式(只更新提供的字段)
@@ -376,7 +385,9 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse) @router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
删除表达方式 删除表达方式
@@ -419,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
@router.post("/batch/delete", response_model=ExpressionDeleteResponse) @router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
批量删除表达方式 批量删除表达方式
@@ -460,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
@router.get("/stats/summary") @router.get("/stats/summary")
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
获取表达方式统计数据 获取表达方式统计数据

View File

@@ -277,11 +277,7 @@ async def get_chat_list():
"""获取所有有黑话记录的聊天列表""" """获取所有有黑话记录的聊天列表"""
try: try:
# 获取所有不同的 chat_id # 获取所有不同的 chat_id
chat_ids = ( chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
)
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id] chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
@@ -346,12 +342,7 @@ async def get_jargon_stats():
complete_count = Jargon.select().where(Jargon.is_complete).count() complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量 # 关联的聊天数量
chat_count = ( chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
Jargon.select(Jargon.chat_id)
.distinct()
.where(Jargon.chat_id.is_null(False))
.count()
)
# 按聊天统计 TOP 5 # 按聊天统计 TOP 5
top_chats = ( top_chats = (
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
"""创建黑话""" """创建黑话"""
try: try:
# 检查是否已存在相同内容的黑话 # 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none( existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
)
if existing: if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话") raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
if not ids: if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空") raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = ( updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
Jargon.update(is_jargon=is_jargon)
.where(Jargon.id.in_(ids))
.execute()
)
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}") logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")

View File

@@ -200,7 +200,9 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse) @router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
获取人物详细信息 获取人物详细信息
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
@router.patch("/{person_id}", response_model=PersonUpdateResponse) @router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
增量更新人物信息(只更新提供的字段) 增量更新人物信息(只更新提供的字段)
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
@router.delete("/{person_id}", response_model=PersonDeleteResponse) @router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
""" """
删除人物信息 删除人物信息
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
@router.post("/batch/delete", response_model=BatchDeleteResponse) @router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)): async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
""" """
批量删除人物信息 批量删除人物信息

View File

@@ -125,6 +125,7 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
""" """
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str 根据 schema 将配置中的类型纠正(目前只纠正 list-from-str
""" """
def _is_list_type(tp: Any) -> bool: def _is_list_type(tp: Any) -> bool:
origin = get_origin(tp) origin = get_origin(tp)
return tp is list or origin is list return tp is list or origin is list
@@ -313,7 +314,9 @@ async def check_git_status() -> GitStatusResponse:
@router.get("/mirrors", response_model=AvailableMirrorsResponse) @router.get("/mirrors", response_model=AvailableMirrorsResponse)
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse: async def get_available_mirrors(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> AvailableMirrorsResponse:
""" """
获取所有可用的镜像源配置 获取所有可用的镜像源配置
""" """
@@ -343,7 +346,9 @@ async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), au
@router.post("/mirrors", response_model=MirrorConfigResponse) @router.post("/mirrors", response_model=MirrorConfigResponse)
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse: async def add_mirror(
request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> MirrorConfigResponse:
""" """
添加新的镜像源 添加新的镜像源
""" """
@@ -383,7 +388,10 @@ async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] =
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse) @router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
async def update_mirror( async def update_mirror(
mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) mirror_id: str,
request: UpdateMirrorRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> MirrorConfigResponse: ) -> MirrorConfigResponse:
""" """
更新镜像源配置 更新镜像源配置
@@ -426,7 +434,9 @@ async def update_mirror(
@router.delete("/mirrors/{mirror_id}") @router.delete("/mirrors/{mirror_id}")
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def delete_mirror(
mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
删除镜像源 删除镜像源
""" """
@@ -449,7 +459,9 @@ async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(N
@router.post("/fetch-raw", response_model=FetchRawFileResponse) @router.post("/fetch-raw", response_model=FetchRawFileResponse)
async def fetch_raw_file( async def fetch_raw_file(
request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) request: FetchRawFileRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> FetchRawFileResponse: ) -> FetchRawFileResponse:
""" """
获取 GitHub 仓库的 Raw 文件内容 获取 GitHub 仓库的 Raw 文件内容
@@ -534,7 +546,9 @@ async def fetch_raw_file(
@router.post("/clone", response_model=CloneRepositoryResponse) @router.post("/clone", response_model=CloneRepositoryResponse)
async def clone_repository( async def clone_repository(
request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) request: CloneRepositoryRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> CloneRepositoryResponse: ) -> CloneRepositoryResponse:
""" """
克隆 GitHub 仓库到本地 克隆 GitHub 仓库到本地
@@ -574,7 +588,11 @@ async def clone_repository(
@router.post("/install") @router.post("/install")
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def install_plugin(
request: InstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
""" """
安装插件 安装插件
@@ -778,7 +796,9 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
@router.post("/uninstall") @router.post("/uninstall")
async def uninstall_plugin( async def uninstall_plugin(
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) request: UninstallPluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
卸载插件 卸载插件
@@ -913,7 +933,11 @@ async def uninstall_plugin(
@router.post("/update") @router.post("/update")
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def update_plugin(
request: UpdatePluginRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]:
""" """
更新插件 更新插件
@@ -1132,7 +1156,9 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
@router.get("/installed") @router.get("/installed")
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def get_installed_plugins(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
获取已安装的插件列表 获取已安装的插件列表
@@ -1272,7 +1298,9 @@ class UpdatePluginConfigRequest(BaseModel):
@router.get("/config/{plugin_id}/schema") @router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def get_plugin_config_schema(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
获取插件配置 Schema 获取插件配置 Schema
@@ -1405,7 +1433,9 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
@router.get("/config/{plugin_id}") @router.get("/config/{plugin_id}")
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def get_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
获取插件当前配置值 获取插件当前配置值
@@ -1461,7 +1491,10 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
@router.put("/config/{plugin_id}") @router.put("/config/{plugin_id}")
async def update_plugin_config( async def update_plugin_config(
plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None) plugin_id: str,
request: UpdatePluginConfigRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
更新插件配置 更新插件配置
@@ -1532,7 +1565,9 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset") @router.post("/config/{plugin_id}/reset")
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def reset_plugin_config(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
重置插件配置为默认值 重置插件配置为默认值
@@ -1592,7 +1627,9 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
@router.post("/config/{plugin_id}/toggle") @router.post("/config/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]: async def toggle_plugin(
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
切换插件启用状态 切换插件启用状态

View File

@@ -110,8 +110,10 @@ class WebUIServer:
from src.webui.routes import router as webui_router from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router from src.webui.logs_ws import router as logs_router
from src.webui.knowledge_routes import router as knowledge_router from src.webui.knowledge_routes import router as knowledge_router
# 导入本地聊天室路由 # 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router from src.webui.chat_routes import router as chat_router
# 注册路由 # 注册路由
self.app.include_router(webui_router) self.app.include_router(webui_router)
self.app.include_router(logs_router) self.app.include_router(logs_router)
@@ -166,6 +168,7 @@ class WebUIServer:
def _check_port_available(self) -> bool: def _check_port_available(self) -> bool:
"""检查端口是否可用""" """检查端口是否可用"""
import socket import socket
try: try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1) s.settimeout(1)