Ruff format
This commit is contained in:
4
bot.py
4
bot.py
@@ -41,6 +41,7 @@ logger = get_logger("main")
|
|||||||
# 定义重启退出码
|
# 定义重启退出码
|
||||||
RESTART_EXIT_CODE = 42
|
RESTART_EXIT_CODE = 42
|
||||||
|
|
||||||
|
|
||||||
def run_runner_process():
|
def run_runner_process():
|
||||||
"""
|
"""
|
||||||
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
Runner 进程逻辑:作为守护进程运行,负责启动和监控 Worker 进程。
|
||||||
@@ -68,7 +69,7 @@ def run_runner_process():
|
|||||||
|
|
||||||
if return_code == RESTART_EXIT_CODE:
|
if return_code == RESTART_EXIT_CODE:
|
||||||
logger.info("检测到重启请求 (退出码 42),正在重启...")
|
logger.info("检测到重启请求 (退出码 42),正在重启...")
|
||||||
time.sleep(1) # 稍作等待
|
time.sleep(1) # 稍作等待
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.info(f"程序已退出 (退出码 {return_code})")
|
logger.info(f"程序已退出 (退出码 {return_code})")
|
||||||
@@ -87,6 +88,7 @@ def run_runner_process():
|
|||||||
process.kill()
|
process.kill()
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
# 检查是否是 Worker 进程
|
# 检查是否是 Worker 进程
|
||||||
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
|
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
|
||||||
# 此时应该作为 Runner 运行。
|
# 此时应该作为 Runner 运行。
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ConversionResult:
|
class ConversionResult:
|
||||||
"""转换结果"""
|
"""转换结果"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
servers: List[Dict[str, Any]] = field(default_factory=list)
|
servers: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
errors: List[str] = field(default_factory=list)
|
errors: List[str] = field(default_factory=list)
|
||||||
@@ -271,11 +272,7 @@ class ConfigConverter:
|
|||||||
return name, result
|
return name, result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_claude_format(
|
def from_claude_format(cls, config: Dict[str, Any], existing_names: Optional[set] = None) -> ConversionResult:
|
||||||
cls,
|
|
||||||
config: Dict[str, Any],
|
|
||||||
existing_names: Optional[set] = None
|
|
||||||
) -> ConversionResult:
|
|
||||||
"""从 Claude Desktop 格式转换为 MaiBot 格式
|
"""从 Claude Desktop 格式转换为 MaiBot 格式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -355,11 +352,7 @@ class ConfigConverter:
|
|||||||
return {"mcpServers": mcp_servers}
|
return {"mcpServers": mcp_servers}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def import_from_string(
|
def import_from_string(cls, json_str: str, existing_names: Optional[set] = None) -> ConversionResult:
|
||||||
cls,
|
|
||||||
json_str: str,
|
|
||||||
existing_names: Optional[set] = None
|
|
||||||
) -> ConversionResult:
|
|
||||||
"""从 JSON 字符串导入配置
|
"""从 JSON 字符串导入配置
|
||||||
|
|
||||||
自动检测格式并转换为 MaiBot 格式
|
自动检测格式并转换为 MaiBot 格式
|
||||||
@@ -422,12 +415,7 @@ class ConfigConverter:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def export_to_string(
|
def export_to_string(cls, servers: List[Dict[str, Any]], format_type: str = "claude", pretty: bool = True) -> str:
|
||||||
cls,
|
|
||||||
servers: List[Dict[str, Any]],
|
|
||||||
format_type: str = "claude",
|
|
||||||
pretty: bool = True
|
|
||||||
) -> str:
|
|
||||||
"""导出配置为 JSON 字符串
|
"""导出配置为 JSON 字符串
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -34,28 +34,31 @@ from enum import Enum
|
|||||||
# 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging
|
# 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging
|
||||||
try:
|
try:
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("mcp_client")
|
logger = get_logger("mcp_client")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallback: 使用标准 logging
|
# Fallback: 使用标准 logging
|
||||||
logger = logging.getLogger("mcp_client")
|
logger = logging.getLogger("mcp_client")
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
handler.setFormatter(logging.Formatter('[%(levelname)s] %(name)s: %(message)s'))
|
handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
class TransportType(Enum):
|
class TransportType(Enum):
|
||||||
"""MCP 传输类型"""
|
"""MCP 传输类型"""
|
||||||
STDIO = "stdio" # 本地进程通信
|
|
||||||
SSE = "sse" # Server-Sent Events (旧版 HTTP)
|
STDIO = "stdio" # 本地进程通信
|
||||||
HTTP = "http" # HTTP Streamable (新版,推荐)
|
SSE = "sse" # Server-Sent Events (旧版 HTTP)
|
||||||
|
HTTP = "http" # HTTP Streamable (新版,推荐)
|
||||||
STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名
|
STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MCPToolInfo:
|
class MCPToolInfo:
|
||||||
"""MCP 工具信息"""
|
"""MCP 工具信息"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
input_schema: Dict[str, Any]
|
input_schema: Dict[str, Any]
|
||||||
@@ -65,6 +68,7 @@ class MCPToolInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MCPResourceInfo:
|
class MCPResourceInfo:
|
||||||
"""MCP 资源信息"""
|
"""MCP 资源信息"""
|
||||||
|
|
||||||
uri: str
|
uri: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
@@ -75,6 +79,7 @@ class MCPResourceInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MCPPromptInfo:
|
class MCPPromptInfo:
|
||||||
"""MCP 提示模板信息"""
|
"""MCP 提示模板信息"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
arguments: List[Dict[str, Any]] # [{name, description, required}]
|
arguments: List[Dict[str, Any]] # [{name, description, required}]
|
||||||
@@ -84,6 +89,7 @@ class MCPPromptInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MCPServerConfig:
|
class MCPServerConfig:
|
||||||
"""MCP 服务器配置"""
|
"""MCP 服务器配置"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
transport: TransportType = TransportType.STDIO
|
transport: TransportType = TransportType.STDIO
|
||||||
@@ -99,6 +105,7 @@ class MCPServerConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MCPCallResult:
|
class MCPCallResult:
|
||||||
"""MCP 工具调用结果"""
|
"""MCP 工具调用结果"""
|
||||||
|
|
||||||
success: bool
|
success: bool
|
||||||
content: Any
|
content: Any
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
@@ -108,8 +115,9 @@ class MCPCallResult:
|
|||||||
|
|
||||||
class CircuitState(Enum):
|
class CircuitState(Enum):
|
||||||
"""断路器状态"""
|
"""断路器状态"""
|
||||||
CLOSED = "closed" # 正常状态,允许请求
|
|
||||||
OPEN = "open" # 熔断状态,拒绝请求
|
CLOSED = "closed" # 正常状态,允许请求
|
||||||
|
OPEN = "open" # 熔断状态,拒绝请求
|
||||||
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
|
HALF_OPEN = "half_open" # 半开状态,允许少量试探请求
|
||||||
|
|
||||||
|
|
||||||
@@ -125,9 +133,9 @@ class CircuitBreaker:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 配置
|
# 配置
|
||||||
failure_threshold: int = 5 # 连续失败多少次后熔断
|
failure_threshold: int = 5 # 连续失败多少次后熔断
|
||||||
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
|
recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒)
|
||||||
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
|
half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用
|
||||||
|
|
||||||
# 状态
|
# 状态
|
||||||
state: CircuitState = field(default=CircuitState.CLOSED)
|
state: CircuitState = field(default=CircuitState.CLOSED)
|
||||||
@@ -232,6 +240,7 @@ class CircuitBreaker:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ToolCallStats:
|
class ToolCallStats:
|
||||||
"""工具调用统计"""
|
"""工具调用统计"""
|
||||||
|
|
||||||
tool_key: str
|
tool_key: str
|
||||||
total_calls: int = 0
|
total_calls: int = 0
|
||||||
success_calls: int = 0
|
success_calls: int = 0
|
||||||
@@ -282,6 +291,7 @@ class ToolCallStats:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ServerStats:
|
class ServerStats:
|
||||||
"""服务器统计"""
|
"""服务器统计"""
|
||||||
|
|
||||||
server_name: str
|
server_name: str
|
||||||
connect_count: int = 0 # 连接次数
|
connect_count: int = 0 # 连接次数
|
||||||
disconnect_count: int = 0 # 断开次数
|
disconnect_count: int = 0 # 断开次数
|
||||||
@@ -442,9 +452,7 @@ class MCPClientSession:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
server_params = StdioServerParameters(
|
server_params = StdioServerParameters(
|
||||||
command=self.config.command,
|
command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None
|
||||||
args=self.config.args,
|
|
||||||
env=self.config.env if self.config.env else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._stdio_context = stdio_client(server_params)
|
self._stdio_context = stdio_client(server_params)
|
||||||
@@ -506,6 +514,7 @@ class MCPClientSession:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
|
logger.error(f"[{self.server_name}] SSE 连接失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
return False
|
return False
|
||||||
@@ -551,6 +560,7 @@ class MCPClientSession:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
|
logger.error(f"[{self.server_name}] HTTP 连接失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}")
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
return False
|
return False
|
||||||
@@ -568,8 +578,8 @@ class MCPClientSession:
|
|||||||
tool_info = MCPToolInfo(
|
tool_info = MCPToolInfo(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
description=tool.description or f"MCP tool: {tool.name}",
|
description=tool.description or f"MCP tool: {tool.name}",
|
||||||
input_schema=tool.inputSchema if hasattr(tool, 'inputSchema') else {},
|
input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||||
server_name=self.server_name
|
server_name=self.server_name,
|
||||||
)
|
)
|
||||||
self._tools.append(tool_info)
|
self._tools.append(tool_info)
|
||||||
# 初始化工具统计
|
# 初始化工具统计
|
||||||
@@ -591,10 +601,7 @@ class MCPClientSession:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout)
|
||||||
self._session.list_resources(),
|
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
|
||||||
self._resources = []
|
self._resources = []
|
||||||
|
|
||||||
for resource in result.resources:
|
for resource in result.resources:
|
||||||
@@ -602,8 +609,8 @@ class MCPClientSession:
|
|||||||
uri=str(resource.uri),
|
uri=str(resource.uri),
|
||||||
name=resource.name or str(resource.uri),
|
name=resource.name or str(resource.uri),
|
||||||
description=resource.description or "",
|
description=resource.description or "",
|
||||||
mime_type=resource.mimeType if hasattr(resource, 'mimeType') else None,
|
mime_type=resource.mimeType if hasattr(resource, "mimeType") else None,
|
||||||
server_name=self.server_name
|
server_name=self.server_name,
|
||||||
)
|
)
|
||||||
self._resources.append(resource_info)
|
self._resources.append(resource_info)
|
||||||
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
|
logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}")
|
||||||
@@ -633,28 +640,27 @@ class MCPClientSession:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout)
|
||||||
self._session.list_prompts(),
|
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
|
||||||
self._prompts = []
|
self._prompts = []
|
||||||
|
|
||||||
for prompt in result.prompts:
|
for prompt in result.prompts:
|
||||||
# 解析参数
|
# 解析参数
|
||||||
arguments = []
|
arguments = []
|
||||||
if hasattr(prompt, 'arguments') and prompt.arguments:
|
if hasattr(prompt, "arguments") and prompt.arguments:
|
||||||
for arg in prompt.arguments:
|
for arg in prompt.arguments:
|
||||||
arguments.append({
|
arguments.append(
|
||||||
"name": arg.name,
|
{
|
||||||
"description": arg.description or "",
|
"name": arg.name,
|
||||||
"required": arg.required if hasattr(arg, 'required') else False,
|
"description": arg.description or "",
|
||||||
})
|
"required": arg.required if hasattr(arg, "required") else False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
prompt_info = MCPPromptInfo(
|
prompt_info = MCPPromptInfo(
|
||||||
name=prompt.name,
|
name=prompt.name,
|
||||||
description=prompt.description or f"MCP prompt: {prompt.name}",
|
description=prompt.description or f"MCP prompt: {prompt.name}",
|
||||||
arguments=arguments,
|
arguments=arguments,
|
||||||
server_name=self.server_name
|
server_name=self.server_name,
|
||||||
)
|
)
|
||||||
self._prompts.append(prompt_info)
|
self._prompts.append(prompt_info)
|
||||||
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
|
logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}")
|
||||||
@@ -686,35 +692,25 @@ class MCPClientSession:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if not self._connected or not self._session:
|
if not self._connected or not self._session:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 未连接"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self._supports_resources:
|
if not self._supports_resources:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 不支持 Resources 功能"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout)
|
||||||
self._session.read_resource(uri),
|
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
# 处理返回内容
|
# 处理返回内容
|
||||||
content_parts = []
|
content_parts = []
|
||||||
for content in result.contents:
|
for content in result.contents:
|
||||||
if hasattr(content, 'text'):
|
if hasattr(content, "text"):
|
||||||
content_parts.append(content.text)
|
content_parts.append(content.text)
|
||||||
elif hasattr(content, 'blob'):
|
elif hasattr(content, "blob"):
|
||||||
# 二进制数据,返回 base64 或提示
|
# 二进制数据,返回 base64 或提示
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
blob_data = content.blob
|
blob_data = content.blob
|
||||||
if len(blob_data) < 10000: # 小于 10KB 返回 base64
|
if len(blob_data) < 10000: # 小于 10KB 返回 base64
|
||||||
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
|
content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}")
|
||||||
@@ -724,28 +720,18 @@ class MCPClientSession:
|
|||||||
content_parts.append(str(content))
|
content_parts.append(str(content))
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=True,
|
success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms
|
||||||
content="\n".join(content_parts) if content_parts else "",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=False,
|
success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms
|
||||||
content=None,
|
|
||||||
error=f"读取资源超时({self.call_timeout}秒)",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
|
logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}")
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=str(e),
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
|
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult:
|
||||||
"""v1.2.0: 获取提示模板的内容
|
"""v1.2.0: 获取提示模板的内容
|
||||||
@@ -760,23 +746,14 @@ class MCPClientSession:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if not self._connected or not self._session:
|
if not self._connected or not self._session:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 未连接"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self._supports_prompts:
|
if not self._supports_prompts:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {self.server_name} 不支持 Prompts 功能"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._session.get_prompt(name, arguments=arguments or {}),
|
self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
@@ -784,10 +761,10 @@ class MCPClientSession:
|
|||||||
# 处理返回的消息
|
# 处理返回的消息
|
||||||
messages = []
|
messages = []
|
||||||
for msg in result.messages:
|
for msg in result.messages:
|
||||||
role = msg.role if hasattr(msg, 'role') else "unknown"
|
role = msg.role if hasattr(msg, "role") else "unknown"
|
||||||
content_text = ""
|
content_text = ""
|
||||||
if hasattr(msg, 'content'):
|
if hasattr(msg, "content"):
|
||||||
if hasattr(msg.content, 'text'):
|
if hasattr(msg.content, "text"):
|
||||||
content_text = msg.content.text
|
content_text = msg.content.text
|
||||||
elif isinstance(msg.content, str):
|
elif isinstance(msg.content, str):
|
||||||
content_text = msg.content
|
content_text = msg.content
|
||||||
@@ -796,28 +773,18 @@ class MCPClientSession:
|
|||||||
messages.append(f"[{role}]: {content_text}")
|
messages.append(f"[{role}]: {content_text}")
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=True,
|
success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms
|
||||||
content="\n\n".join(messages) if messages else "",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=False,
|
success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms
|
||||||
content=None,
|
|
||||||
error=f"获取提示模板超时({self.call_timeout}秒)",
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
|
logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}")
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms)
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=str(e),
|
|
||||||
duration_ms=duration_ms
|
|
||||||
)
|
|
||||||
|
|
||||||
async def check_health(self) -> bool:
|
async def check_health(self) -> bool:
|
||||||
"""检查连接健康状态(心跳检测)
|
"""检查连接健康状态(心跳检测)
|
||||||
@@ -829,10 +796,7 @@ class MCPClientSession:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 list_tools 作为心跳检测
|
# 使用 list_tools 作为心跳检测
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(self._session.list_tools(), timeout=10.0)
|
||||||
self._session.list_tools(),
|
|
||||||
timeout=10.0
|
|
||||||
)
|
|
||||||
self.stats.record_heartbeat()
|
self.stats.record_heartbeat()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -849,12 +813,7 @@ class MCPClientSession:
|
|||||||
# v1.7.0: 断路器检查
|
# v1.7.0: 断路器检查
|
||||||
can_execute, reject_reason = self._circuit_breaker.can_execute()
|
can_execute, reject_reason = self._circuit_breaker.can_execute()
|
||||||
if not can_execute:
|
if not can_execute:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"⚡ {reject_reason}", circuit_broken=True)
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"⚡ {reject_reason}",
|
|
||||||
circuit_broken=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 半开状态下增加试探计数
|
# 半开状态下增加试探计数
|
||||||
if self._circuit_breaker.state == CircuitState.HALF_OPEN:
|
if self._circuit_breaker.state == CircuitState.HALF_OPEN:
|
||||||
@@ -870,8 +829,7 @@ class MCPClientSession:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._session.call_tool(tool_name, arguments=arguments),
|
self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout
|
||||||
timeout=self.call_timeout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
@@ -879,9 +837,9 @@ class MCPClientSession:
|
|||||||
# 处理返回内容
|
# 处理返回内容
|
||||||
content_parts = []
|
content_parts = []
|
||||||
for content in result.content:
|
for content in result.content:
|
||||||
if hasattr(content, 'text'):
|
if hasattr(content, "text"):
|
||||||
content_parts.append(content.text)
|
content_parts.append(content.text)
|
||||||
elif hasattr(content, 'data'):
|
elif hasattr(content, "data"):
|
||||||
content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
|
content_parts.append(f"[二进制数据: {len(content.data)} bytes]")
|
||||||
else:
|
else:
|
||||||
content_parts.append(str(content))
|
content_parts.append(str(content))
|
||||||
@@ -896,7 +854,7 @@ class MCPClientSession:
|
|||||||
return MCPCallResult(
|
return MCPCallResult(
|
||||||
success=True,
|
success=True,
|
||||||
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
|
content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)",
|
||||||
duration_ms=duration_ms
|
duration_ms=duration_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@@ -939,25 +897,25 @@ class MCPClientSession:
|
|||||||
self._supports_prompts = False # v1.2.0
|
self._supports_prompts = False # v1.2.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_session_context') and self._session_context:
|
if hasattr(self, "_session_context") and self._session_context:
|
||||||
await self._session_context.__aexit__(None, None, None)
|
await self._session_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_stdio_context') and self._stdio_context:
|
if hasattr(self, "_stdio_context") and self._stdio_context:
|
||||||
await self._stdio_context.__aexit__(None, None, None)
|
await self._stdio_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_http_context') and self._http_context:
|
if hasattr(self, "_http_context") and self._http_context:
|
||||||
await self._http_context.__aexit__(None, None, None)
|
await self._http_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, '_sse_context') and self._sse_context:
|
if hasattr(self, "_sse_context") and self._sse_context:
|
||||||
await self._sse_context.__aexit__(None, None, None)
|
await self._sse_context.__aexit__(None, None, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
|
logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}")
|
||||||
@@ -1082,7 +1040,9 @@ class MCPClientManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
if attempt < retry_attempts:
|
if attempt < retry_attempts:
|
||||||
logger.warning(f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})")
|
logger.warning(
|
||||||
|
f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})"
|
||||||
|
)
|
||||||
await asyncio.sleep(retry_interval)
|
await asyncio.sleep(retry_interval)
|
||||||
|
|
||||||
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
|
logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})")
|
||||||
@@ -1213,11 +1173,7 @@ class MCPClientManager:
|
|||||||
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
|
async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult:
|
||||||
"""调用 MCP 工具"""
|
"""调用 MCP 工具"""
|
||||||
if tool_key not in self._all_tools:
|
if tool_key not in self._all_tools:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"工具 {tool_key} 不存在"
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_info, client = self._all_tools[tool_key]
|
tool_info, client = self._all_tools[tool_key]
|
||||||
|
|
||||||
@@ -1273,11 +1229,7 @@ class MCPClientManager:
|
|||||||
# 如果指定了服务器
|
# 如果指定了服务器
|
||||||
if server_name:
|
if server_name:
|
||||||
if server_name not in self._clients:
|
if server_name not in self._clients:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {server_name} 不存在"
|
|
||||||
)
|
|
||||||
client = self._clients[server_name]
|
client = self._clients[server_name]
|
||||||
return await client.read_resource(uri)
|
return await client.read_resource(uri)
|
||||||
|
|
||||||
@@ -1293,14 +1245,11 @@ class MCPClientManager:
|
|||||||
if result.success:
|
if result.success:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"未找到资源: {uri}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None,
|
async def get_prompt(
|
||||||
server_name: Optional[str] = None) -> MCPCallResult:
|
self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None
|
||||||
|
) -> MCPCallResult:
|
||||||
"""v1.2.0: 获取提示模板内容
|
"""v1.2.0: 获取提示模板内容
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1311,11 +1260,7 @@ class MCPClientManager:
|
|||||||
# 如果指定了服务器
|
# 如果指定了服务器
|
||||||
if server_name:
|
if server_name:
|
||||||
if server_name not in self._clients:
|
if server_name not in self._clients:
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"服务器 {server_name} 不存在"
|
|
||||||
)
|
|
||||||
client = self._clients[server_name]
|
client = self._clients[server_name]
|
||||||
return await client.get_prompt(name, arguments)
|
return await client.get_prompt(name, arguments)
|
||||||
|
|
||||||
@@ -1324,11 +1269,7 @@ class MCPClientManager:
|
|||||||
if prompt_info.name == name:
|
if prompt_info.name == name:
|
||||||
return await client.get_prompt(name, arguments)
|
return await client.get_prompt(name, arguments)
|
||||||
|
|
||||||
return MCPCallResult(
|
return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}")
|
||||||
success=False,
|
|
||||||
content=None,
|
|
||||||
error=f"未找到提示模板: {name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ==================== 心跳检测 ====================
|
# ==================== 心跳检测 ====================
|
||||||
|
|
||||||
@@ -1489,7 +1430,9 @@ class MCPClientManager:
|
|||||||
"global": {
|
"global": {
|
||||||
**self._global_stats,
|
**self._global_stats,
|
||||||
"uptime_seconds": round(uptime, 2),
|
"uptime_seconds": round(uptime, 2),
|
||||||
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) if uptime > 0 else 0,
|
"calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2)
|
||||||
|
if uptime > 0
|
||||||
|
else 0,
|
||||||
},
|
},
|
||||||
"servers": server_stats,
|
"servers": server_stats,
|
||||||
"tools": tool_stats,
|
"tools": tool_stats,
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ from .mcp_client import (
|
|||||||
TransportType,
|
TransportType,
|
||||||
mcp_manager,
|
mcp_manager,
|
||||||
)
|
)
|
||||||
from .config_converter import ConfigConverter, ConversionResult
|
from .config_converter import ConfigConverter
|
||||||
|
|
||||||
logger = get_logger("mcp_bridge_plugin")
|
logger = get_logger("mcp_bridge_plugin")
|
||||||
|
|
||||||
@@ -93,9 +93,11 @@ logger = get_logger("mcp_bridge_plugin")
|
|||||||
# v1.4.0: 调用链路追踪
|
# v1.4.0: 调用链路追踪
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolCallRecord:
|
class ToolCallRecord:
|
||||||
"""工具调用记录"""
|
"""工具调用记录"""
|
||||||
|
|
||||||
call_id: str
|
call_id: str
|
||||||
timestamp: float
|
timestamp: float
|
||||||
tool_name: str
|
tool_name: str
|
||||||
@@ -178,9 +180,11 @@ tool_call_tracer = ToolCallTracer()
|
|||||||
# v1.4.0: 工具调用缓存
|
# v1.4.0: 工具调用缓存
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CacheEntry:
|
class CacheEntry:
|
||||||
"""缓存条目"""
|
"""缓存条目"""
|
||||||
|
|
||||||
tool_name: str
|
tool_name: str
|
||||||
args_hash: str
|
args_hash: str
|
||||||
result: str
|
result: str
|
||||||
@@ -317,6 +321,7 @@ tool_call_cache = ToolCallCache()
|
|||||||
# v1.4.0: 工具权限控制
|
# v1.4.0: 工具权限控制
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class PermissionChecker:
|
class PermissionChecker:
|
||||||
"""工具权限检查器"""
|
"""工具权限检查器"""
|
||||||
|
|
||||||
@@ -449,6 +454,7 @@ permission_checker = PermissionChecker()
|
|||||||
# 工具类型转换
|
# 工具类型转换
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
||||||
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
|
"""将 JSON Schema 类型转换为 MaiBot 的 ToolParamType"""
|
||||||
type_mapping = {
|
type_mapping = {
|
||||||
@@ -462,7 +468,9 @@ def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
|||||||
return type_mapping.get(json_type, ToolParamType.STRING)
|
return type_mapping.get(json_type, ToolParamType.STRING)
|
||||||
|
|
||||||
|
|
||||||
def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
|
def parse_mcp_parameters(
|
||||||
|
input_schema: Dict[str, Any],
|
||||||
|
) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]:
|
||||||
"""解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式"""
|
"""解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式"""
|
||||||
parameters = []
|
parameters = []
|
||||||
|
|
||||||
@@ -497,6 +505,7 @@ def parse_mcp_parameters(input_schema: Dict[str, Any]) -> List[Tuple[str, ToolPa
|
|||||||
# MCP 工具代理
|
# MCP 工具代理
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPToolProxy(BaseTool):
|
class MCPToolProxy(BaseTool):
|
||||||
"""MCP 工具代理基类"""
|
"""MCP 工具代理基类"""
|
||||||
|
|
||||||
@@ -539,10 +548,7 @@ class MCPToolProxy(BaseTool):
|
|||||||
# v1.4.0: 权限检查
|
# v1.4.0: 权限检查
|
||||||
if not permission_checker.check(self.name, chat_id, user_id, is_group):
|
if not permission_checker.check(self.name, chat_id, user_id, is_group):
|
||||||
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
|
logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}")
|
||||||
return {
|
return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"}
|
||||||
"name": self.name,
|
|
||||||
"content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
|
logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}")
|
||||||
|
|
||||||
@@ -726,11 +732,7 @@ class MCPToolProxy(BaseTool):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def _call_post_process_llm(
|
async def _call_post_process_llm(
|
||||||
self,
|
self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]]
|
||||||
prompt: str,
|
|
||||||
max_tokens: int,
|
|
||||||
settings: Dict[str, Any],
|
|
||||||
server_config: Optional[Dict[str, Any]]
|
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""调用 LLM 进行后处理"""
|
"""调用 LLM 进行后处理"""
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
@@ -788,10 +790,7 @@ class MCPToolProxy(BaseTool):
|
|||||||
|
|
||||||
|
|
||||||
def create_mcp_tool_class(
|
def create_mcp_tool_class(
|
||||||
tool_key: str,
|
tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
|
||||||
tool_info: MCPToolInfo,
|
|
||||||
tool_prefix: str,
|
|
||||||
disabled: bool = False
|
|
||||||
) -> Type[MCPToolProxy]:
|
) -> Type[MCPToolProxy]:
|
||||||
"""根据 MCP 工具信息动态创建 BaseTool 子类"""
|
"""根据 MCP 工具信息动态创建 BaseTool 子类"""
|
||||||
parameters = parse_mcp_parameters(tool_info.input_schema)
|
parameters = parse_mcp_parameters(tool_info.input_schema)
|
||||||
@@ -814,7 +813,7 @@ def create_mcp_tool_class(
|
|||||||
"_mcp_tool_key": tool_key,
|
"_mcp_tool_key": tool_key,
|
||||||
"_mcp_original_name": tool_info.name,
|
"_mcp_original_name": tool_info.name,
|
||||||
"_mcp_server_name": tool_info.server_name,
|
"_mcp_server_name": tool_info.server_name,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return tool_class
|
return tool_class
|
||||||
@@ -828,11 +827,7 @@ class MCPToolRegistry:
|
|||||||
self._tool_infos: Dict[str, ToolInfo] = {}
|
self._tool_infos: Dict[str, ToolInfo] = {}
|
||||||
|
|
||||||
def register_tool(
|
def register_tool(
|
||||||
self,
|
self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False
|
||||||
tool_key: str,
|
|
||||||
tool_info: MCPToolInfo,
|
|
||||||
tool_prefix: str,
|
|
||||||
disabled: bool = False
|
|
||||||
) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
|
) -> Tuple[ToolInfo, Type[MCPToolProxy]]:
|
||||||
"""注册 MCP 工具"""
|
"""注册 MCP 工具"""
|
||||||
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
|
tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled)
|
||||||
@@ -879,6 +874,7 @@ _plugin_instance: Optional["MCPBridgePlugin"] = None
|
|||||||
# 内置工具
|
# 内置工具
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPReadResourceTool(BaseTool):
|
class MCPReadResourceTool(BaseTool):
|
||||||
"""v1.2.0: MCP 资源读取工具"""
|
"""v1.2.0: MCP 资源读取工具"""
|
||||||
|
|
||||||
@@ -950,9 +946,17 @@ class MCPStatusTool(BaseTool):
|
|||||||
"""MCP 状态查询工具"""
|
"""MCP 状态查询工具"""
|
||||||
|
|
||||||
name = "mcp_status"
|
name = "mcp_status"
|
||||||
description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
|
description = (
|
||||||
|
"查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、资源列表、提示模板列表、调用统计、追踪记录等信息"
|
||||||
|
)
|
||||||
parameters = [
|
parameters = [
|
||||||
("query_type", ToolParamType.STRING, "查询类型", False, ["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"]),
|
(
|
||||||
|
"query_type",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"查询类型",
|
||||||
|
False,
|
||||||
|
["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"],
|
||||||
|
),
|
||||||
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
|
("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None),
|
||||||
]
|
]
|
||||||
available_for_llm = True
|
available_for_llm = True
|
||||||
@@ -986,10 +990,7 @@ class MCPStatusTool(BaseTool):
|
|||||||
if query_type in ("cache",):
|
if query_type in ("cache",):
|
||||||
result_parts.append(self._format_cache())
|
result_parts.append(self._format_cache())
|
||||||
|
|
||||||
return {
|
return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"}
|
||||||
"name": self.name,
|
|
||||||
"content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"
|
|
||||||
}
|
|
||||||
|
|
||||||
def _format_status(self, server_name: Optional[str] = None) -> str:
|
def _format_status(self, server_name: Optional[str] = None) -> str:
|
||||||
status = mcp_manager.get_status()
|
status = mcp_manager.get_status()
|
||||||
@@ -1001,14 +1002,14 @@ class MCPStatusTool(BaseTool):
|
|||||||
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
|
lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}")
|
||||||
|
|
||||||
lines.append("\n🔌 服务器详情:")
|
lines.append("\n🔌 服务器详情:")
|
||||||
for name, info in status['servers'].items():
|
for name, info in status["servers"].items():
|
||||||
if server_name and name != server_name:
|
if server_name and name != server_name:
|
||||||
continue
|
continue
|
||||||
status_icon = "✅" if info['connected'] else "❌"
|
status_icon = "✅" if info["connected"] else "❌"
|
||||||
enabled_text = "" if info['enabled'] else " (已禁用)"
|
enabled_text = "" if info["enabled"] else " (已禁用)"
|
||||||
lines.append(f" {status_icon} {name}{enabled_text}")
|
lines.append(f" {status_icon} {name}{enabled_text}")
|
||||||
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
|
lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}")
|
||||||
if info['consecutive_failures'] > 0:
|
if info["consecutive_failures"] > 0:
|
||||||
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次")
|
lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次")
|
||||||
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
@@ -1038,11 +1039,11 @@ class MCPStatusTool(BaseTool):
|
|||||||
stats = mcp_manager.get_all_stats()
|
stats = mcp_manager.get_all_stats()
|
||||||
lines = ["📈 调用统计"]
|
lines = ["📈 调用统计"]
|
||||||
|
|
||||||
g = stats['global']
|
g = stats["global"]
|
||||||
lines.append(f" 总调用次数: {g['total_tool_calls']}")
|
lines.append(f" 总调用次数: {g['total_tool_calls']}")
|
||||||
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
|
lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}")
|
||||||
if g['total_tool_calls'] > 0:
|
if g["total_tool_calls"] > 0:
|
||||||
success_rate = (g['successful_calls'] / g['total_tool_calls']) * 100
|
success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100
|
||||||
lines.append(f" 成功率: {success_rate:.1f}%")
|
lines.append(f" 成功率: {success_rate:.1f}%")
|
||||||
lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒")
|
lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒")
|
||||||
|
|
||||||
@@ -1126,6 +1127,7 @@ class MCPStatusTool(BaseTool):
|
|||||||
# 命令处理
|
# 命令处理
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPStatusCommand(BaseCommand):
|
class MCPStatusCommand(BaseCommand):
|
||||||
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
|
"""MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态"""
|
||||||
|
|
||||||
@@ -1340,7 +1342,9 @@ class MCPStatusCommand(BaseCommand):
|
|||||||
try:
|
try:
|
||||||
exported = ConfigConverter.export_to_string(servers, format_type, pretty=True)
|
exported = ConfigConverter.export_to_string(servers, format_type, pretty=True)
|
||||||
|
|
||||||
format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get(format_type, format_type)
|
format_name = {"claude": "Claude Desktop", "kiro": "Kiro MCP", "maibot": "MaiBot"}.get(
|
||||||
|
format_type, format_type
|
||||||
|
)
|
||||||
lines = [f"📤 导出为 {format_name} 格式 ({len(servers)} 个服务器):"]
|
lines = [f"📤 导出为 {format_name} 格式 ({len(servers)} 个服务器):"]
|
||||||
lines.append("")
|
lines.append("")
|
||||||
lines.append(exported)
|
lines.append(exported)
|
||||||
@@ -1446,9 +1450,9 @@ class MCPStatusCommand(BaseCommand):
|
|||||||
cb = info.get("circuit_breaker", {})
|
cb = info.get("circuit_breaker", {})
|
||||||
cb_state = cb.get("state", "closed")
|
cb_state = cb.get("state", "closed")
|
||||||
if cb_state == "open":
|
if cb_state == "open":
|
||||||
lines.append(f" ⚡ 断路器熔断中")
|
lines.append(" ⚡ 断路器熔断中")
|
||||||
elif cb_state == "half_open":
|
elif cb_state == "half_open":
|
||||||
lines.append(f" ⚡ 断路器试探中")
|
lines.append(" ⚡ 断路器试探中")
|
||||||
if info["consecutive_failures"] > 0:
|
if info["consecutive_failures"] > 0:
|
||||||
lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']} 次")
|
lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']} 次")
|
||||||
|
|
||||||
@@ -1634,6 +1638,7 @@ class MCPImportCommand(BaseCommand):
|
|||||||
# 事件处理器
|
# 事件处理器
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MCPStartupHandler(BaseEventHandler):
|
class MCPStartupHandler(BaseEventHandler):
|
||||||
"""MCP 启动事件处理器"""
|
"""MCP 启动事件处理器"""
|
||||||
|
|
||||||
@@ -1692,6 +1697,7 @@ class MCPStopHandler(BaseEventHandler):
|
|||||||
# 主插件类
|
# 主插件类
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
@register_plugin
|
||||||
class MCPBridgePlugin(BasePlugin):
|
class MCPBridgePlugin(BasePlugin):
|
||||||
"""MCP 桥接插件 v1.4.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
|
"""MCP 桥接插件 v1.4.0 - 将 MCP 服务器的工具桥接到 MaiBot"""
|
||||||
@@ -2116,9 +2122,9 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
label="📜 高级权限规则(可选)",
|
label="📜 高级权限规则(可选)",
|
||||||
input_type="textarea",
|
input_type="textarea",
|
||||||
rows=10,
|
rows=10,
|
||||||
placeholder='''[
|
placeholder="""[
|
||||||
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
|
{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
|
||||||
]''',
|
]""",
|
||||||
hint="格式: qq:ID:group/private/user,工具名支持通配符 *",
|
hint="格式: qq:ID:group/private/user,工具名支持通配符 *",
|
||||||
order=10,
|
order=10,
|
||||||
),
|
),
|
||||||
@@ -2261,7 +2267,9 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
value = match1.group(2)
|
value = match1.group(2)
|
||||||
suffix = match1.group(3)
|
suffix = match1.group(3)
|
||||||
# 将转义的换行符还原为实际换行符
|
# 将转义的换行符还原为实际换行符
|
||||||
unescaped = value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
|
unescaped = (
|
||||||
|
value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\")
|
||||||
|
)
|
||||||
fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
|
fixed_line = f'{prefix}"""{unescaped}"""{suffix}'
|
||||||
fixed_lines.append(fixed_line)
|
fixed_lines.append(fixed_line)
|
||||||
modified = True
|
modified = True
|
||||||
@@ -2560,7 +2568,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
import_config_str = str(import_config).strip()
|
import_config_str = str(import_config).strip()
|
||||||
logger.info(f"检测到 WebUI 导入请求,开始处理...")
|
logger.info("检测到 WebUI 导入请求,开始处理...")
|
||||||
|
|
||||||
# 执行导入
|
# 执行导入
|
||||||
await self._execute_webui_import(import_config_str, doc, config_path)
|
await self._execute_webui_import(import_config_str, doc, config_path)
|
||||||
@@ -2773,6 +2781,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
async def _async_connect_servers(self) -> None:
|
async def _async_connect_servers(self) -> None:
|
||||||
"""异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)"""
|
"""异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
settings = self.config.get("settings", {})
|
settings = self.config.get("settings", {})
|
||||||
|
|
||||||
servers_section = self.config.get("servers", [])
|
servers_section = self.config.get("servers", [])
|
||||||
@@ -2854,10 +2863,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
|
|
||||||
# 并行执行所有连接
|
# 并行执行所有连接
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True)
|
||||||
*[connect_single_server(cfg) for cfg in enabled_configs],
|
|
||||||
return_exceptions=True
|
|
||||||
)
|
|
||||||
connect_duration = time.time() - start_time
|
connect_duration = time.time() - start_time
|
||||||
|
|
||||||
# 统计连接结果
|
# 统计连接结果
|
||||||
@@ -2878,15 +2884,14 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
|
|
||||||
# 注册所有工具
|
# 注册所有工具
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
registered_count = 0
|
registered_count = 0
|
||||||
|
|
||||||
for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
|
for tool_key, (tool_info, _) in mcp_manager.all_tools.items():
|
||||||
tool_name = tool_key.replace("-", "_").replace(".", "_")
|
tool_name = tool_key.replace("-", "_").replace(".", "_")
|
||||||
is_disabled = tool_name in disabled_tools
|
is_disabled = tool_name in disabled_tools
|
||||||
|
|
||||||
info, tool_class = mcp_tool_registry.register_tool(
|
info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled)
|
||||||
tool_key, tool_info, tool_prefix, disabled=is_disabled
|
|
||||||
)
|
|
||||||
info.plugin_name = self.plugin_name
|
info.plugin_name = self.plugin_name
|
||||||
|
|
||||||
if component_registry.register_component(info, tool_class):
|
if component_registry.register_component(info, tool_class):
|
||||||
@@ -3004,6 +3009,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
doc["tools"] = tomlkit.table()
|
doc["tools"] = tomlkit.table()
|
||||||
# 使用 tomlkit 多行字符串避免控制字符问题
|
# 使用 tomlkit 多行字符串避免控制字符问题
|
||||||
from tomlkit.items import String, StringType, Trivia
|
from tomlkit.items import String, StringType, Trivia
|
||||||
|
|
||||||
ml_string = String(StringType.MLB, tool_list_text, tool_list_text, Trivia())
|
ml_string = String(StringType.MLB, tool_list_text, tool_list_text, Trivia())
|
||||||
doc["tools"]["tool_list"] = ml_string
|
doc["tools"]["tool_list"] = ml_string
|
||||||
|
|
||||||
@@ -3069,6 +3075,7 @@ class MCPBridgePlugin(BasePlugin):
|
|||||||
doc["status"] = tomlkit.table()
|
doc["status"] = tomlkit.table()
|
||||||
# 使用 tomlkit 多行字符串避免控制字符问题
|
# 使用 tomlkit 多行字符串避免控制字符问题
|
||||||
from tomlkit.items import String, StringType, Trivia
|
from tomlkit.items import String, StringType, Trivia
|
||||||
|
|
||||||
ml_string = String(StringType.MLB, status_text, status_text, Trivia())
|
ml_string = String(StringType.MLB, status_text, status_text, Trivia())
|
||||||
doc["status"]["connection_status"] = ml_string
|
doc["status"]["connection_status"] = ml_string
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ async def test_stats():
|
|||||||
assert stats.total_calls == 3
|
assert stats.total_calls == 3
|
||||||
assert stats.success_calls == 2
|
assert stats.success_calls == 2
|
||||||
assert stats.failed_calls == 1
|
assert stats.failed_calls == 1
|
||||||
assert stats.success_rate == (2/3) * 100
|
assert stats.success_rate == (2 / 3) * 100
|
||||||
assert stats.avg_duration_ms == 150.0
|
assert stats.avg_duration_ms == 150.0
|
||||||
assert stats.last_error == "timeout"
|
assert stats.last_error == "timeout"
|
||||||
|
|
||||||
@@ -66,13 +66,15 @@ async def test_manager_basic():
|
|||||||
manager.__init__()
|
manager.__init__()
|
||||||
|
|
||||||
# 配置
|
# 配置
|
||||||
manager.configure({
|
manager.configure(
|
||||||
"tool_prefix": "mcp",
|
{
|
||||||
"call_timeout": 30.0,
|
"tool_prefix": "mcp",
|
||||||
"retry_attempts": 1,
|
"call_timeout": 30.0,
|
||||||
"retry_interval": 1.0,
|
"retry_attempts": 1,
|
||||||
"heartbeat_enabled": False,
|
"retry_interval": 1.0,
|
||||||
})
|
"heartbeat_enabled": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 测试状态
|
# 测试状态
|
||||||
status = manager.get_status()
|
status = manager.get_status()
|
||||||
@@ -82,10 +84,7 @@ async def test_manager_basic():
|
|||||||
|
|
||||||
# 测试添加禁用的服务器
|
# 测试添加禁用的服务器
|
||||||
config = MCPServerConfig(
|
config = MCPServerConfig(
|
||||||
name="disabled_server",
|
name="disabled_server", enabled=False, transport=TransportType.HTTP, url="https://example.com/mcp"
|
||||||
enabled=False,
|
|
||||||
transport=TransportType.HTTP,
|
|
||||||
url="https://example.com/mcp"
|
|
||||||
)
|
)
|
||||||
result = await manager.add_server(config)
|
result = await manager.add_server(config)
|
||||||
assert result == True
|
assert result == True
|
||||||
@@ -120,27 +119,29 @@ async def test_http_connection():
|
|||||||
manager._initialized = False
|
manager._initialized = False
|
||||||
manager.__init__()
|
manager.__init__()
|
||||||
|
|
||||||
manager.configure({
|
manager.configure(
|
||||||
"tool_prefix": "mcp",
|
{
|
||||||
"call_timeout": 30.0,
|
"tool_prefix": "mcp",
|
||||||
"retry_attempts": 2,
|
"call_timeout": 30.0,
|
||||||
"retry_interval": 2.0,
|
"retry_attempts": 2,
|
||||||
"heartbeat_enabled": False,
|
"retry_interval": 2.0,
|
||||||
})
|
"heartbeat_enabled": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 使用 HowToCook MCP 服务器测试
|
# 使用 HowToCook MCP 服务器测试
|
||||||
config = MCPServerConfig(
|
config = MCPServerConfig(
|
||||||
name="howtocook",
|
name="howtocook",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
transport=TransportType.HTTP,
|
transport=TransportType.HTTP,
|
||||||
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp"
|
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"正在连接 {config.url} ...")
|
print(f"正在连接 {config.url} ...")
|
||||||
result = await manager.add_server(config)
|
result = await manager.add_server(config)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
print(f"✅ 连接成功!")
|
print("✅ 连接成功!")
|
||||||
|
|
||||||
# 检查工具
|
# 检查工具
|
||||||
tools = manager.all_tools
|
tools = manager.all_tools
|
||||||
@@ -159,19 +160,23 @@ async def test_http_connection():
|
|||||||
call_result = await manager.call_tool("mcp_howtocook_whatToEat", {})
|
call_result = await manager.call_tool("mcp_howtocook_whatToEat", {})
|
||||||
if call_result.success:
|
if call_result.success:
|
||||||
print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)")
|
print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)")
|
||||||
print(f" 结果: {call_result.content[:200]}..." if len(str(call_result.content)) > 200 else f" 结果: {call_result.content}")
|
print(
|
||||||
|
f" 结果: {call_result.content[:200]}..."
|
||||||
|
if len(str(call_result.content)) > 200
|
||||||
|
else f" 结果: {call_result.content}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"❌ 工具调用失败: {call_result.error}")
|
print(f"❌ 工具调用失败: {call_result.error}")
|
||||||
|
|
||||||
# 查看统计
|
# 查看统计
|
||||||
stats = manager.get_all_stats()
|
stats = manager.get_all_stats()
|
||||||
print(f"\n📊 统计信息:")
|
print("\n📊 统计信息:")
|
||||||
print(f" 全局调用: {stats['global']['total_tool_calls']}")
|
print(f" 全局调用: {stats['global']['total_tool_calls']}")
|
||||||
print(f" 成功: {stats['global']['successful_calls']}")
|
print(f" 成功: {stats['global']['successful_calls']}")
|
||||||
print(f" 失败: {stats['global']['failed_calls']}")
|
print(f" 失败: {stats['global']['failed_calls']}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f"❌ 连接失败")
|
print("❌ 连接失败")
|
||||||
|
|
||||||
# 清理
|
# 清理
|
||||||
await manager.shutdown()
|
await manager.shutdown()
|
||||||
@@ -187,23 +192,25 @@ async def test_heartbeat():
|
|||||||
manager._initialized = False
|
manager._initialized = False
|
||||||
manager.__init__()
|
manager.__init__()
|
||||||
|
|
||||||
manager.configure({
|
manager.configure(
|
||||||
"tool_prefix": "mcp",
|
{
|
||||||
"call_timeout": 30.0,
|
"tool_prefix": "mcp",
|
||||||
"retry_attempts": 1,
|
"call_timeout": 30.0,
|
||||||
"retry_interval": 1.0,
|
"retry_attempts": 1,
|
||||||
"heartbeat_enabled": True,
|
"retry_interval": 1.0,
|
||||||
"heartbeat_interval": 5.0, # 5秒间隔用于测试
|
"heartbeat_enabled": True,
|
||||||
"auto_reconnect": True,
|
"heartbeat_interval": 5.0, # 5秒间隔用于测试
|
||||||
"max_reconnect_attempts": 2,
|
"auto_reconnect": True,
|
||||||
})
|
"max_reconnect_attempts": 2,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 添加一个测试服务器
|
# 添加一个测试服务器
|
||||||
config = MCPServerConfig(
|
config = MCPServerConfig(
|
||||||
name="heartbeat_test",
|
name="heartbeat_test",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
transport=TransportType.HTTP,
|
transport=TransportType.HTTP,
|
||||||
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp"
|
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("正在连接服务器...")
|
print("正在连接服务器...")
|
||||||
@@ -260,6 +267,7 @@ async def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n❌ 测试失败: {e}")
|
print(f"\n❌ 测试失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -301,4 +301,3 @@ def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,12 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
)
|
)
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.bw_learner.learner_utils import filter_message_content, is_bot_message, build_context_paragraph, contains_bot_self_name
|
from src.bw_learner.learner_utils import (
|
||||||
|
filter_message_content,
|
||||||
|
is_bot_message,
|
||||||
|
build_context_paragraph,
|
||||||
|
contains_bot_self_name,
|
||||||
|
)
|
||||||
from src.bw_learner.jargon_miner import miner_manager
|
from src.bw_learner.jargon_miner import miner_manager
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
@@ -77,8 +82,6 @@ def init_prompt() -> None:
|
|||||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionLearner:
|
class ExpressionLearner:
|
||||||
def __init__(self, chat_id: str) -> None:
|
def __init__(self, chat_id: str) -> None:
|
||||||
self.express_learn_model: LLMRequest = LLMRequest(
|
self.express_learn_model: LLMRequest = LLMRequest(
|
||||||
@@ -186,7 +189,6 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
filtered_expressions.append((situation, style, context))
|
filtered_expressions.append((situation, style, context))
|
||||||
|
|
||||||
|
|
||||||
learnt_expressions = filtered_expressions
|
learnt_expressions = filtered_expressions
|
||||||
|
|
||||||
if learnt_expressions is None:
|
if learnt_expressions is None:
|
||||||
@@ -270,6 +272,7 @@ class ExpressionLearner:
|
|||||||
# 如果解析失败,尝试修复中文引号问题
|
# 如果解析失败,尝试修复中文引号问题
|
||||||
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||||
try:
|
try:
|
||||||
|
|
||||||
def fix_chinese_quotes_in_json(text):
|
def fix_chinese_quotes_in_json(text):
|
||||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||||
result = []
|
result = []
|
||||||
@@ -287,7 +290,7 @@ class ExpressionLearner:
|
|||||||
i += 1
|
i += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if char == '\\':
|
if char == "\\":
|
||||||
# 转义字符
|
# 转义字符
|
||||||
result.append(char)
|
result.append(char)
|
||||||
escape_next = True
|
escape_next = True
|
||||||
@@ -315,7 +318,7 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return ''.join(result)
|
return "".join(result)
|
||||||
|
|
||||||
fixed_raw = fix_chinese_quotes_in_json(raw)
|
fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||||
|
|
||||||
|
|||||||
@@ -82,9 +82,7 @@ class ExpressionReflector:
|
|||||||
# 获取未检查的表达
|
# 获取未检查的表达
|
||||||
try:
|
try:
|
||||||
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||||
expressions = (
|
expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||||
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
|
||||||
)
|
|
||||||
|
|
||||||
expr_list = list(expressions)
|
expr_list = list(expressions)
|
||||||
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
||||||
|
|||||||
@@ -128,9 +128,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的
|
# 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的
|
||||||
style_query = Expression.select().where(
|
style_query = Expression.select().where(
|
||||||
(Expression.chat_id.in_(related_chat_ids))
|
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
|
||||||
& (~Expression.rejected)
|
|
||||||
& (Expression.count > 1)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
style_exprs = [
|
style_exprs = [
|
||||||
@@ -150,12 +148,15 @@ class ExpressionSelector:
|
|||||||
# 要求至少有10个 count > 1 的表达方式才进行选择
|
# 要求至少有10个 count > 1 的表达方式才进行选择
|
||||||
min_required = 10
|
min_required = 10
|
||||||
if len(style_exprs) < min_required:
|
if len(style_exprs) < min_required:
|
||||||
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择")
|
logger.info(
|
||||||
|
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),不进行选择"
|
||||||
|
)
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
# 固定选择5个
|
# 固定选择5个
|
||||||
select_count = 5
|
select_count = 5
|
||||||
import random
|
import random
|
||||||
|
|
||||||
selected_style = random.sample(style_exprs, select_count)
|
selected_style = random.sample(style_exprs, select_count)
|
||||||
|
|
||||||
# 更新last_active_time
|
# 更新last_active_time
|
||||||
@@ -163,7 +164,9 @@ class ExpressionSelector:
|
|||||||
self.update_expressions_last_active_time(selected_style)
|
self.update_expressions_last_active_time(selected_style)
|
||||||
|
|
||||||
selected_ids = [expr["id"] for expr in selected_style]
|
selected_ids = [expr["id"] for expr in selected_style]
|
||||||
logger.debug(f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个")
|
logger.debug(
|
||||||
|
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个"
|
||||||
|
)
|
||||||
return selected_style, selected_ids
|
return selected_style, selected_ids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -186,9 +189,7 @@ class ExpressionSelector:
|
|||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
|
|
||||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||||
style_query = Expression.select().where(
|
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
|
||||||
)
|
|
||||||
|
|
||||||
style_exprs = [
|
style_exprs = [
|
||||||
{
|
{
|
||||||
@@ -246,7 +247,9 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 使用classic模式(随机选择+LLM选择)
|
# 使用classic模式(随机选择+LLM选择)
|
||||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason, think_level)
|
return await self._select_expressions_classic(
|
||||||
|
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||||
|
)
|
||||||
|
|
||||||
async def _select_expressions_classic(
|
async def _select_expressions_classic(
|
||||||
self,
|
self,
|
||||||
@@ -279,9 +282,7 @@ class ExpressionSelector:
|
|||||||
# think_level == 1: 先选高count,再从所有表达方式中随机抽样
|
# think_level == 1: 先选高count,再从所有表达方式中随机抽样
|
||||||
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
|
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
|
||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
style_query = Expression.select().where(
|
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
|
||||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
|
||||||
)
|
|
||||||
|
|
||||||
all_style_exprs = [
|
all_style_exprs = [
|
||||||
{
|
{
|
||||||
@@ -308,11 +309,15 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 检查数量要求
|
# 检查数量要求
|
||||||
if len(high_count_exprs) < min_high_count:
|
if len(high_count_exprs) < min_high_count:
|
||||||
logger.info(f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择")
|
logger.info(
|
||||||
|
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),不进行选择"
|
||||||
|
)
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
if len(all_style_exprs) < min_total_count:
|
if len(all_style_exprs) < min_total_count:
|
||||||
logger.info(f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择")
|
logger.info(
|
||||||
|
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
|
||||||
|
)
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
# 先选取高count的表达方式
|
# 先选取高count的表达方式
|
||||||
@@ -332,6 +337,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 打乱顺序,避免高count的都在前面
|
# 打乱顺序,避免高count的都在前面
|
||||||
import random
|
import random
|
||||||
|
|
||||||
random.shuffle(candidate_exprs)
|
random.shuffle(candidate_exprs)
|
||||||
|
|
||||||
# 2. 构建所有表达方式的索引和情境列表
|
# 2. 构建所有表达方式的索引和情境列表
|
||||||
@@ -351,7 +357,7 @@ class ExpressionSelector:
|
|||||||
all_situations_str = "\n".join(all_situations)
|
all_situations_str = "\n".join(all_situations)
|
||||||
|
|
||||||
if target_message:
|
if target_message:
|
||||||
target_message_str = f",现在你想要对这条消息进行回复:\"{target_message}\""
|
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
|
||||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||||
else:
|
else:
|
||||||
target_message_str = ""
|
target_message_str = ""
|
||||||
|
|||||||
@@ -8,7 +8,12 @@ from src.llm_models.utils_model import LLMRequest
|
|||||||
from src.config.config import model_config, global_config
|
from src.config.config import model_config, global_config
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.bw_learner.jargon_miner import search_jargon
|
from src.bw_learner.jargon_miner import search_jargon
|
||||||
from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
|
from src.bw_learner.learner_utils import (
|
||||||
|
is_bot_message,
|
||||||
|
contains_bot_self_name,
|
||||||
|
parse_chat_id_list,
|
||||||
|
chat_id_list_contains,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger("jargon")
|
logger = get_logger("jargon")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import time
|
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
@@ -14,7 +13,6 @@ from src.config.config import model_config, global_config
|
|||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages_with_id,
|
build_readable_messages_with_id,
|
||||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
|
||||||
)
|
)
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.bw_learner.learner_utils import (
|
from src.bw_learner.learner_utils import (
|
||||||
@@ -46,10 +44,10 @@ def _is_single_char_jargon(content: str) -> bool:
|
|||||||
char = content[0]
|
char = content[0]
|
||||||
# 判断是否是单个汉字、单个英文字母或单个数字
|
# 判断是否是单个汉字、单个英文字母或单个数字
|
||||||
return (
|
return (
|
||||||
'\u4e00' <= char <= '\u9fff' or # 汉字
|
"\u4e00" <= char <= "\u9fff" # 汉字
|
||||||
'a' <= char <= 'z' or # 小写字母
|
or "a" <= char <= "z" # 小写字母
|
||||||
'A' <= char <= 'Z' or # 大写字母
|
or "A" <= char <= "Z" # 大写字母
|
||||||
'0' <= char <= '9' # 数字
|
or "0" <= char <= "9" # 数字
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -305,7 +303,9 @@ class JargonMiner:
|
|||||||
# 计算要保留的数量(至少保留1个)
|
# 计算要保留的数量(至少保留1个)
|
||||||
keep_count = max(1, len(raw_content_list) // 2)
|
keep_count = max(1, len(raw_content_list) // 2)
|
||||||
raw_content_list = random.sample(raw_content_list, keep_count)
|
raw_content_list = random.sample(raw_content_list, keep_count)
|
||||||
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目")
|
logger.info(
|
||||||
|
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
|
||||||
|
)
|
||||||
|
|
||||||
# 步骤1: 基于raw_content和content推断
|
# 步骤1: 基于raw_content和content推断
|
||||||
raw_content_text = "\n".join(raw_content_list)
|
raw_content_text = "\n".join(raw_content_list)
|
||||||
@@ -318,7 +318,9 @@ class JargonMiner:
|
|||||||
**上一次推断的含义(仅供参考)**
|
**上一次推断的含义(仅供参考)**
|
||||||
{previous_meaning}
|
{previous_meaning}
|
||||||
"""
|
"""
|
||||||
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
|
previous_meaning_instruction = (
|
||||||
|
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
|
||||||
|
)
|
||||||
|
|
||||||
prompt1 = await global_prompt_manager.format_prompt(
|
prompt1 = await global_prompt_manager.format_prompt(
|
||||||
"jargon_inference_with_context_prompt",
|
"jargon_inference_with_context_prompt",
|
||||||
@@ -650,7 +652,9 @@ class JargonMiner:
|
|||||||
if obj.raw_content:
|
if obj.raw_content:
|
||||||
try:
|
try:
|
||||||
existing_raw_content = (
|
existing_raw_content = (
|
||||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
json.loads(obj.raw_content)
|
||||||
|
if isinstance(obj.raw_content, str)
|
||||||
|
else obj.raw_content
|
||||||
)
|
)
|
||||||
if not isinstance(existing_raw_content, list):
|
if not isinstance(existing_raw_content, list):
|
||||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||||
@@ -876,8 +880,6 @@ class JargonMinerManager:
|
|||||||
miner_manager = JargonMinerManager()
|
miner_manager = JargonMinerManager()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def search_jargon(
|
def search_jargon(
|
||||||
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
|
|||||||
@@ -116,7 +116,6 @@ class MessageRecorder:
|
|||||||
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
|
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 分别触发 expression_learner 和 jargon_miner 的处理
|
# 分别触发 expression_learner 和 jargon_miner 的处理
|
||||||
# 传递提取的消息,避免它们重复获取
|
# 传递提取的消息,避免它们重复获取
|
||||||
# 触发 expression 学习(如果启用)
|
# 触发 expression 学习(如果启用)
|
||||||
@@ -127,21 +126,19 @@ class MessageRecorder:
|
|||||||
|
|
||||||
# 触发 jargon 提取(如果启用),传递消息
|
# 触发 jargon 提取(如果启用),传递消息
|
||||||
# if self.enable_jargon_learning:
|
# if self.enable_jargon_learning:
|
||||||
# asyncio.create_task(
|
# asyncio.create_task(
|
||||||
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
|
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
|
||||||
# )
|
# )
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# 即使失败也保持时间戳更新,避免频繁重试
|
# 即使失败也保持时间戳更新,避免频繁重试
|
||||||
|
|
||||||
async def _trigger_expression_learning(
|
async def _trigger_expression_learning(
|
||||||
self,
|
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
messages: List[Any]
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
触发 expression 学习,使用指定的消息列表
|
触发 expression 学习,使用指定的消息列表
|
||||||
@@ -162,13 +159,11 @@ class MessageRecorder:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
|
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
async def _trigger_jargon_extraction(
|
async def _trigger_jargon_extraction(
|
||||||
self,
|
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
|
||||||
timestamp_start: float,
|
|
||||||
timestamp_end: float,
|
|
||||||
messages: List[Any]
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
触发 jargon 提取,使用指定的消息列表
|
触发 jargon 提取,使用指定的消息列表
|
||||||
@@ -185,6 +180,7 @@ class MessageRecorder:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
|
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
@@ -214,4 +210,3 @@ async def extract_and_distribute_messages(chat_id: str) -> None:
|
|||||||
"""
|
"""
|
||||||
recorder = recorder_manager.get_recorder(chat_id)
|
recorder = recorder_manager.get_recorder(chat_id)
|
||||||
await recorder.extract_and_distribute()
|
await recorder.extract_and_distribute()
|
||||||
|
|
||||||
|
|||||||
@@ -328,9 +328,7 @@ class BrainChatting:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否有 complete_talk 动作(会停止后续迭代)
|
# 检查是否有 complete_talk 动作(会停止后续迭代)
|
||||||
has_complete_talk = any(
|
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
|
||||||
action.action_type == "complete_talk" for action in action_to_use_info
|
|
||||||
)
|
|
||||||
|
|
||||||
# 并行执行所有动作
|
# 并行执行所有动作
|
||||||
action_tasks = [
|
action_tasks = [
|
||||||
|
|||||||
@@ -204,7 +204,9 @@ class BrainPlanner:
|
|||||||
# 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
|
# 注意:listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
|
||||||
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
|
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}")
|
logger.debug(
|
||||||
|
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
|
||||||
|
)
|
||||||
|
|
||||||
# 将 listening 转换为 wait(向后兼容)
|
# 将 listening 转换为 wait(向后兼容)
|
||||||
if action == "listening":
|
if action == "listening":
|
||||||
@@ -521,7 +523,7 @@ class BrainPlanner:
|
|||||||
if json_objects:
|
if json_objects:
|
||||||
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
|
||||||
for i, json_obj in enumerate(json_objects):
|
for i, json_obj in enumerate(json_objects):
|
||||||
logger.info(f"{self.log_prefix}解析第{i+1}个JSON对象: {json_obj}")
|
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
|
||||||
filtered_actions_list = list(filtered_actions.items())
|
filtered_actions_list = list(filtered_actions.items())
|
||||||
for json_obj in json_objects:
|
for json_obj in json_objects:
|
||||||
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
|
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
|
||||||
@@ -553,7 +555,9 @@ class BrainPlanner:
|
|||||||
|
|
||||||
return extracted_reasoning, actions
|
return extracted_reasoning, actions
|
||||||
|
|
||||||
def _create_complete_talk(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
|
def _create_complete_talk(
|
||||||
|
self, reasoning: str, available_actions: Dict[str, ActionInfo]
|
||||||
|
) -> List[ActionPlannerInfo]:
|
||||||
"""创建complete_talk"""
|
"""创建complete_talk"""
|
||||||
return [
|
return [
|
||||||
ActionPlannerInfo(
|
ActionPlannerInfo(
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
|||||||
|
|
||||||
emoji.description = emoji_data.description
|
emoji.description = emoji_data.description
|
||||||
# Deserialize emotion string from DB to list
|
# Deserialize emotion string from DB to list
|
||||||
emoji.emotion = emoji_data.emotion.replace(",",",").split(",") if emoji_data.emotion else []
|
emoji.emotion = emoji_data.emotion.replace(",", ",").split(",") if emoji_data.emotion else []
|
||||||
emoji.usage_count = emoji_data.usage_count
|
emoji.usage_count = emoji_data.usage_count
|
||||||
|
|
||||||
db_last_used_time = emoji_data.last_used_time
|
db_last_used_time = emoji_data.last_used_time
|
||||||
@@ -732,7 +732,7 @@ class EmojiManager:
|
|||||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||||
if emoji_record and emoji_record.emotion:
|
if emoji_record and emoji_record.emotion:
|
||||||
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
||||||
return emoji_record.emotion.replace(",",",").split(",")
|
return emoji_record.emotion.replace(",", ",").split(",")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
||||||
|
|
||||||
@@ -993,7 +993,7 @@ class EmojiManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 处理情感列表
|
# 处理情感列表
|
||||||
emotions = [e.strip() for e in emotions_text.replace(",",",").split(",") if e.strip()]
|
emotions = [e.strip() for e in emotions_text.replace(",", ",").split(",") if e.strip()]
|
||||||
|
|
||||||
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||||
if len(emotions) > 5:
|
if len(emotions) > 5:
|
||||||
|
|||||||
@@ -123,7 +123,11 @@ class ChatBot:
|
|||||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||||
|
|
||||||
# 根据命令的拦截设置决定是否继续处理消息
|
# 根据命令的拦截设置决定是否继续处理消息
|
||||||
return True, response, not bool(intercept_message_level) # 找到命令,根据intercept_message决定是否继续
|
return (
|
||||||
|
True,
|
||||||
|
response,
|
||||||
|
not bool(intercept_message_level),
|
||||||
|
) # 找到命令,根据intercept_message决定是否继续
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ class MessageRecv(Message):
|
|||||||
desc = segment.data.get("desc", "") # 内容描述
|
desc = segment.data.get("desc", "") # 内容描述
|
||||||
source_url = segment.data.get("source_url", "") # 原始链接
|
source_url = segment.data.get("source_url", "") # 原始链接
|
||||||
url = segment.data.get("url", "") # 小程序链接
|
url = segment.data.get("url", "") # 小程序链接
|
||||||
text = f"[小程序分享"
|
text = "[小程序分享"
|
||||||
if title:
|
if title:
|
||||||
text += f" - {title}"
|
text += f" - {title}"
|
||||||
text += "]"
|
text += "]"
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ def parse_message_segments(segment) -> list:
|
|||||||
Returns:
|
Returns:
|
||||||
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
||||||
"""
|
"""
|
||||||
from maim_message import Seg
|
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
@@ -112,9 +111,13 @@ def parse_message_segments(segment) -> list:
|
|||||||
forward_items = []
|
forward_items = []
|
||||||
if segment.data:
|
if segment.data:
|
||||||
for item in segment.data:
|
for item in segment.data:
|
||||||
forward_items.append({
|
forward_items.append(
|
||||||
"content": parse_message_segments(item.get("message_segment", {})) if isinstance(item, dict) else []
|
{
|
||||||
})
|
"content": parse_message_segments(item.get("message_segment", {}))
|
||||||
|
if isinstance(item, dict)
|
||||||
|
else []
|
||||||
|
}
|
||||||
|
)
|
||||||
result.append({"type": "forward", "data": forward_items})
|
result.append({"type": "forward", "data": forward_items})
|
||||||
else:
|
else:
|
||||||
# 未知类型,尝试作为文本处理
|
# 未知类型,尝试作为文本处理
|
||||||
|
|||||||
@@ -78,7 +78,6 @@ target_message_id为必填,表示触发消息的id
|
|||||||
"planner_prompt",
|
"planner_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
{action_name}
|
{action_name}
|
||||||
|
|||||||
@@ -250,7 +250,12 @@ class DefaultReplyer:
|
|||||||
# 使用从处理器传来的选中表达方式
|
# 使用从处理器传来的选中表达方式
|
||||||
# 使用模型预测选择表达方式
|
# 使用模型预测选择表达方式
|
||||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason, think_level=think_level
|
self.chat_stream.stream_id,
|
||||||
|
chat_history,
|
||||||
|
max_num=8,
|
||||||
|
target_message=target,
|
||||||
|
reply_reason=reply_reason,
|
||||||
|
think_level=think_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
if selected_expressions:
|
if selected_expressions:
|
||||||
@@ -273,7 +278,6 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||||
|
|
||||||
|
|
||||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||||
"""构建工具信息块
|
"""构建工具信息块
|
||||||
|
|
||||||
@@ -788,7 +792,8 @@ class DefaultReplyer:
|
|||||||
# 并行执行八个构建任务(包括黑话解释)
|
# 并行执行八个构建任务(包括黑话解释)
|
||||||
task_results = await asyncio.gather(
|
task_results = await asyncio.gather(
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level), "expression_habits"
|
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
|
||||||
|
"expression_habits",
|
||||||
),
|
),
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
|
||||||
@@ -980,7 +985,6 @@ class DefaultReplyer:
|
|||||||
else:
|
else:
|
||||||
reply_target_block = ""
|
reply_target_block = ""
|
||||||
|
|
||||||
|
|
||||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||||
|
|
||||||
|
|||||||
@@ -287,7 +287,6 @@ class PrivateReplyer:
|
|||||||
|
|
||||||
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||||
|
|
||||||
|
|
||||||
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
|
||||||
"""构建工具信息块
|
"""构建工具信息块
|
||||||
|
|
||||||
@@ -907,16 +906,11 @@ class PrivateReplyer:
|
|||||||
else:
|
else:
|
||||||
reply_target_block = ""
|
reply_target_block = ""
|
||||||
|
|
||||||
|
|
||||||
chat_target_name = "对方"
|
chat_target_name = "对方"
|
||||||
if self.chat_target_info:
|
if self.chat_target_info:
|
||||||
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
|
||||||
chat_target_1 = await global_prompt_manager.format_prompt(
|
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
|
||||||
"chat_target_private1", sender_name=chat_target_name
|
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
|
||||||
)
|
|
||||||
chat_target_2 = await global_prompt_manager.format_prompt(
|
|
||||||
"chat_target_private2", sender_name=chat_target_name
|
|
||||||
)
|
|
||||||
|
|
||||||
template_name = "default_expressor_prompt"
|
template_name = "default_expressor_prompt"
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from src.chat.utils.prompt_builder import Prompt
|
from src.chat.utils.prompt_builder import Prompt
|
||||||
|
|
||||||
|
|
||||||
def init_replyer_private_prompt():
|
def init_replyer_private_prompt():
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||||
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
{expression_habits_block}{memory_retrieval}{jargon_explanation}
|
||||||
|
|
||||||
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
你正在和{sender_name}聊天,这是你们之前聊的内容:
|
||||||
@@ -17,8 +18,8 @@ def init_replyer_private_prompt():
|
|||||||
{reply_style}
|
{reply_style}
|
||||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||||
"private_replyer_prompt",
|
"private_replyer_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
|
||||||
|
|||||||
@@ -44,4 +44,3 @@ def init_replyer_prompt():
|
|||||||
现在,你说:""",
|
现在,你说:""",
|
||||||
"replyer_prompt",
|
"replyer_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
|
|||||||
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
||||||
sort_order = [("time", 1)]
|
sort_order = [("time", 1)]
|
||||||
return find_messages(
|
return find_messages(
|
||||||
message_filter=filter_query, sort=sort_order, limit=limit, filter_intercept_message_level=filter_intercept_message_level
|
message_filter=filter_query,
|
||||||
|
sort=sort_order,
|
||||||
|
limit=limit,
|
||||||
|
filter_intercept_message_level=filter_intercept_message_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -132,9 +132,7 @@ class ImageManager:
|
|||||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
||||||
|
|
||||||
# 清理ImageDescriptions表中type为emoji的记录
|
# 清理ImageDescriptions表中type为emoji的记录
|
||||||
deleted_descriptions = (
|
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||||
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
|
||||||
)
|
|
||||||
|
|
||||||
total_deleted = deleted_images + deleted_descriptions
|
total_deleted = deleted_images + deleted_descriptions
|
||||||
if total_deleted > 0:
|
if total_deleted > 0:
|
||||||
@@ -236,10 +234,14 @@ class ImageManager:
|
|||||||
# 优先使用情感标签,如果没有则使用详细描述
|
# 优先使用情感标签,如果没有则使用详细描述
|
||||||
result_text = ""
|
result_text = ""
|
||||||
if cache_record.emotion_tags:
|
if cache_record.emotion_tags:
|
||||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
|
logger.info(
|
||||||
|
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
||||||
|
)
|
||||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
||||||
elif cache_record.description:
|
elif cache_record.description:
|
||||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
|
logger.info(
|
||||||
|
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
||||||
|
)
|
||||||
result_text = f"[表情包:{cache_record.description}]"
|
result_text = f"[表情包:{cache_record.description}]"
|
||||||
|
|
||||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
||||||
|
|||||||
@@ -609,10 +609,10 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
|||||||
fields = list(model._meta.fields.keys())
|
fields = list(model._meta.fields.keys())
|
||||||
# Peewee 默认使用 'id' 作为主键字段名
|
# Peewee 默认使用 'id' 作为主键字段名
|
||||||
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
|
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
|
||||||
primary_key_name = 'id' # 默认值
|
primary_key_name = "id" # 默认值
|
||||||
try:
|
try:
|
||||||
if hasattr(model._meta, 'primary_key') and model._meta.primary_key:
|
if hasattr(model._meta, "primary_key") and model._meta.primary_key:
|
||||||
if hasattr(model._meta.primary_key, 'name'):
|
if hasattr(model._meta.primary_key, "name"):
|
||||||
primary_key_name = model._meta.primary_key.name
|
primary_key_name = model._meta.primary_key.name
|
||||||
elif isinstance(model._meta.primary_key, str):
|
elif isinstance(model._meta.primary_key, str):
|
||||||
primary_key_name = model._meta.primary_key
|
primary_key_name = model._meta.primary_key
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
# 决定是否多行:仅在顶层且长度超过阈值时
|
# 决定是否多行:仅在顶层且长度超过阈值时
|
||||||
should_multiline = (depth == 0 and len(obj) > threshold)
|
should_multiline = depth == 0 and len(obj) > threshold
|
||||||
|
|
||||||
# 如果已经是 tomlkit Array,原地修改以保留注释
|
# 如果已经是 tomlkit Array,原地修改以保留注释
|
||||||
if isinstance(obj, Array):
|
if isinstance(obj, Array):
|
||||||
@@ -112,7 +112,7 @@ def save_toml_with_format(
|
|||||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||||
output = tomlkit.dumps(formatted)
|
output = tomlkit.dumps(formatted)
|
||||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||||
output = re.sub(r'\n{3,}', '\n\n', output)
|
output = re.sub(r"\n{3,}", "\n\n", output)
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(output)
|
f.write(output)
|
||||||
|
|
||||||
@@ -122,4 +122,4 @@ def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
|
|||||||
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||||
output = tomlkit.dumps(formatted)
|
output = tomlkit.dumps(formatted)
|
||||||
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
|
||||||
return re.sub(r'\n{3,}', '\n\n', output)
|
return re.sub(r"\n{3,}", "\n\n", output)
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from peewee import fn
|
from peewee import fn
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.common.database.database_model import ChatHistory, Jargon
|
from src.common.database.database_model import ChatHistory
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||||
from src.plugin_system.apis import llm_api
|
from src.plugin_system.apis import llm_api
|
||||||
@@ -82,7 +81,6 @@ def init_dream_prompts() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DreamTool:
|
class DreamTool:
|
||||||
"""dream 模块内部使用的简易工具封装"""
|
"""dream 模块内部使用的简易工具封装"""
|
||||||
|
|
||||||
@@ -150,7 +148,13 @@ def init_dream_tools(chat_id: str) -> None:
|
|||||||
"search_chat_history",
|
"search_chat_history",
|
||||||
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
|
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
|
||||||
[
|
[
|
||||||
("keyword", ToolParamType.STRING, "关键词(可选,支持多个关键词,可用空格、逗号等分隔)。", False, None),
|
(
|
||||||
|
"keyword",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
),
|
||||||
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
|
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
|
||||||
],
|
],
|
||||||
search_chat_history,
|
search_chat_history,
|
||||||
@@ -201,8 +205,20 @@ def init_dream_tools(chat_id: str) -> None:
|
|||||||
[
|
[
|
||||||
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
|
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
|
||||||
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
|
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
|
||||||
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。", True, None),
|
(
|
||||||
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。", True, None),
|
"keywords",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
|
||||||
|
True,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"key_point",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
|
||||||
|
True,
|
||||||
|
None,
|
||||||
|
),
|
||||||
("start_time", ToolParamType.STRING, "起始时间戳(秒,Unix 时间,必填)。", True, None),
|
("start_time", ToolParamType.STRING, "起始时间戳(秒,Unix 时间,必填)。", True, None),
|
||||||
("end_time", ToolParamType.STRING, "结束时间戳(秒,Unix 时间,必填)。", True, None),
|
("end_time", ToolParamType.STRING, "结束时间戳(秒,Unix 时间,必填)。", True, None),
|
||||||
],
|
],
|
||||||
@@ -215,7 +231,13 @@ def init_dream_tools(chat_id: str) -> None:
|
|||||||
"finish_maintenance",
|
"finish_maintenance",
|
||||||
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
|
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
|
||||||
[
|
[
|
||||||
("reason", ToolParamType.STRING, "结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。", False, None),
|
(
|
||||||
|
"reason",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"结束维护的原因说明(可选),例如 '已完成所有记录的整理' 或 '当前记录质量良好,无需进一步维护'。",
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
finish_maintenance,
|
finish_maintenance,
|
||||||
)
|
)
|
||||||
@@ -282,9 +304,7 @@ async def run_dream_agent_once(
|
|||||||
else "未知"
|
else "未知"
|
||||||
)
|
)
|
||||||
end_time_str = (
|
end_time_str = (
|
||||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
|
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
|
||||||
if record.end_time
|
|
||||||
else "未知"
|
|
||||||
)
|
)
|
||||||
detail_text = (
|
detail_text = (
|
||||||
f"ID={record.id}\n"
|
f"ID={record.id}\n"
|
||||||
@@ -305,8 +325,7 @@ async def run_dream_agent_once(
|
|||||||
start_detail_builder = MessageBuilder()
|
start_detail_builder = MessageBuilder()
|
||||||
start_detail_builder.set_role(RoleType.User)
|
start_detail_builder.set_role(RoleType.User)
|
||||||
start_detail_builder.add_text_content(
|
start_detail_builder.add_text_content(
|
||||||
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n"
|
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
|
||||||
+ detail_text
|
|
||||||
)
|
)
|
||||||
conversation_messages.append(start_detail_builder.build())
|
conversation_messages.append(start_detail_builder.build())
|
||||||
else:
|
else:
|
||||||
@@ -343,13 +362,17 @@ async def run_dream_agent_once(
|
|||||||
conversation_messages.append(round_info_builder.build())
|
conversation_messages.append(round_info_builder.build())
|
||||||
|
|
||||||
# 调用 LLM 让其决定是否要使用工具
|
# 调用 LLM 让其决定是否要使用工具
|
||||||
success, response, reasoning_content, model_name, tool_calls = (
|
(
|
||||||
await llm_api.generate_with_model_with_tools_by_message_factory(
|
success,
|
||||||
message_factory,
|
response,
|
||||||
model_config=model_config.model_task_config.tool_use,
|
reasoning_content,
|
||||||
tool_options=tool_defs,
|
model_name,
|
||||||
request_type="dream.react",
|
tool_calls,
|
||||||
)
|
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||||
|
message_factory,
|
||||||
|
model_config=model_config.model_task_config.tool_use,
|
||||||
|
tool_options=tool_defs,
|
||||||
|
request_type="dream.react",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
@@ -555,4 +578,3 @@ async def start_dream_scheduler(
|
|||||||
|
|
||||||
# 初始化提示词
|
# 初始化提示词
|
||||||
init_dream_prompts()
|
init_dream_prompts()
|
||||||
|
|
||||||
|
|||||||
@@ -110,13 +110,17 @@ async def generate_dream_summary(
|
|||||||
thought_content = ""
|
thought_content = ""
|
||||||
if msg.content:
|
if msg.content:
|
||||||
if isinstance(msg.content, list) and msg.content:
|
if isinstance(msg.content, list) and msg.content:
|
||||||
thought_content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
thought_content = (
|
||||||
|
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
thought_content = str(msg.content)
|
thought_content = str(msg.content)
|
||||||
|
|
||||||
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
|
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
|
||||||
if thought_content:
|
if thought_content:
|
||||||
logger.info(f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}")
|
logger.info(
|
||||||
|
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
# 记录每个工具调用的详细信息
|
# 记录每个工具调用的详细信息
|
||||||
for idx, tool_call in enumerate(msg.tool_calls, 1):
|
for idx, tool_call in enumerate(msg.tool_calls, 1):
|
||||||
@@ -167,7 +171,7 @@ async def generate_dream_summary(
|
|||||||
|
|
||||||
# 随机选择2个梦境风格
|
# 随机选择2个梦境风格
|
||||||
selected_styles = get_random_dream_styles(2)
|
selected_styles = get_random_dream_styles(2)
|
||||||
dream_styles_text = "\n".join([f"{i+1}. {style}" for i, style in enumerate(selected_styles)])
|
dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
|
||||||
|
|
||||||
# 使用 Prompt 管理器格式化梦境生成 prompt
|
# 使用 Prompt 管理器格式化梦境生成 prompt
|
||||||
dream_prompt = await global_prompt_manager.format_prompt(
|
dream_prompt = await global_prompt_manager.format_prompt(
|
||||||
@@ -195,4 +199,5 @@ async def generate_dream_summary(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
|
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
init_dream_summary_prompt()
|
init_dream_summary_prompt()
|
||||||
@@ -4,8 +4,3 @@ dream agent 工具实现模块。
|
|||||||
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
|
||||||
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -60,8 +60,3 @@ def make_create_chat_history(chat_id: str):
|
|||||||
return f"create_chat_history 执行失败: {e}"
|
return f"create_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return create_chat_history
|
return create_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,3 @@ def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
|||||||
return f"delete_chat_history 执行失败: {e}"
|
return f"delete_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return delete_chat_history
|
return delete_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,3 @@ def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
|||||||
return f"delete_jargon 执行失败: {e}"
|
return f"delete_jargon 执行失败: {e}"
|
||||||
|
|
||||||
return delete_jargon
|
return delete_jargon
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,3 @@ def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
return finish_maintenance
|
return finish_maintenance
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import ChatHistory
|
from src.common.database.database_model import ChatHistory
|
||||||
@@ -20,14 +19,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
|||||||
|
|
||||||
# 将时间戳转换为可读时间格式
|
# 将时间戳转换为可读时间格式
|
||||||
start_time_str = (
|
start_time_str = (
|
||||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time))
|
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
|
||||||
if record.start_time
|
|
||||||
else "未知"
|
|
||||||
)
|
)
|
||||||
end_time_str = (
|
end_time_str = (
|
||||||
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time))
|
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
|
||||||
if record.end_time
|
|
||||||
else "未知"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = (
|
result = (
|
||||||
@@ -40,17 +35,10 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
|
|||||||
f"概括={record.summary or '无'}\n"
|
f"概括={record.summary or '无'}\n"
|
||||||
f"关键信息={record.key_point or '无'}"
|
f"关键信息={record.key_point or '无'}"
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
|
||||||
f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}"
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"get_chat_history_detail 失败: {e}")
|
logger.error(f"get_chat_history_detail 失败: {e}")
|
||||||
return f"get_chat_history_detail 执行失败: {e}"
|
return f"get_chat_history_detail 执行失败: {e}"
|
||||||
|
|
||||||
return get_chat_history_detail
|
return get_chat_history_detail
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -78,9 +78,7 @@ def make_search_chat_history(chat_id: str):
|
|||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = (
|
||||||
json.loads(record.keywords)
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
if isinstance(record.keywords, str)
|
|
||||||
else record.keywords
|
|
||||||
)
|
)
|
||||||
if isinstance(keywords_data, list):
|
if isinstance(keywords_data, list):
|
||||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||||
@@ -125,9 +123,7 @@ def make_search_chat_history(chat_id: str):
|
|||||||
keywords_str = "、".join(keywords_list)
|
keywords_str = "、".join(keywords_list)
|
||||||
if len(keywords_list) > 2:
|
if len(keywords_list) > 2:
|
||||||
required_count = len(keywords_list) - 1
|
required_count = len(keywords_list) - 1
|
||||||
return (
|
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||||
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||||
elif participant:
|
elif participant:
|
||||||
@@ -142,9 +138,7 @@ def make_search_chat_history(chat_id: str):
|
|||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = (
|
||||||
json.loads(record.keywords)
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
if isinstance(record.keywords, str)
|
|
||||||
else record.keywords
|
|
||||||
)
|
)
|
||||||
if isinstance(keywords_data, list):
|
if isinstance(keywords_data, list):
|
||||||
for k in keywords_data:
|
for k in keywords_data:
|
||||||
@@ -160,13 +154,13 @@ def make_search_chat_history(chat_id: str):
|
|||||||
keywords_str = "、".join(sorted(all_keywords_set))
|
keywords_str = "、".join(sorted(all_keywords_set))
|
||||||
response_text = (
|
response_text = (
|
||||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||||
f"有关\"{search_label}\"的关键词:\n"
|
f'有关"{search_label}"的关键词:\n'
|
||||||
f"{keywords_str}"
|
f"{keywords_str}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response_text = (
|
response_text = (
|
||||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||||
f"有关\"{search_label}\"的关键词信息为空"
|
f'有关"{search_label}"的关键词信息为空'
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -192,9 +186,7 @@ def make_search_chat_history(chat_id: str):
|
|||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = (
|
||||||
json.loads(record.keywords)
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
if isinstance(record.keywords, str)
|
|
||||||
else record.keywords
|
|
||||||
)
|
)
|
||||||
if isinstance(keywords_data, list) and keywords_data:
|
if isinstance(keywords_data, list) and keywords_data:
|
||||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||||
@@ -220,8 +212,3 @@ def make_search_chat_history(chat_id: str):
|
|||||||
return f"search_chat_history 执行失败: {e}"
|
return f"search_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return search_chat_history
|
return search_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,7 @@ def make_search_jargon(chat_id: str):
|
|||||||
if not keyword or not keyword.strip():
|
if not keyword or not keyword.strip():
|
||||||
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
|
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
|
||||||
f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 基础条件:只查 is_jargon=True 的记录
|
# 基础条件:只查 is_jargon=True 的记录
|
||||||
query = Jargon.select().where(Jargon.is_jargon)
|
query = Jargon.select().where(Jargon.is_jargon)
|
||||||
@@ -102,5 +100,3 @@ def make_search_jargon(chat_id: str):
|
|||||||
return f"search_jargon 执行失败: {e}"
|
return f"search_jargon 执行失败: {e}"
|
||||||
|
|
||||||
return search_jargon
|
return search_jargon
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,8 +49,3 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
|
|||||||
return f"update_chat_history 执行失败: {e}"
|
return f"update_chat_history 执行失败: {e}"
|
||||||
|
|
||||||
return update_chat_history
|
return update_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,8 +49,3 @@ def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留
|
|||||||
return f"update_jargon 执行失败: {e}"
|
return f"update_jargon 执行失败: {e}"
|
||||||
|
|
||||||
return update_jargon
|
return update_jargon
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -316,7 +316,9 @@ class ChatHistorySummarizer:
|
|||||||
before_count = len(self.current_batch.messages)
|
before_count = len(self.current_batch.messages)
|
||||||
self.current_batch.messages.extend(new_messages)
|
self.current_batch.messages.extend(new_messages)
|
||||||
self.current_batch.end_time = current_time
|
self.current_batch.end_time = current_time
|
||||||
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
logger.info(
|
||||||
|
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
|
||||||
|
)
|
||||||
# 更新批次后持久化
|
# 更新批次后持久化
|
||||||
self._persist_topic_cache()
|
self._persist_topic_cache()
|
||||||
else:
|
else:
|
||||||
@@ -362,9 +364,7 @@ class ChatHistorySummarizer:
|
|||||||
else:
|
else:
|
||||||
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
|
||||||
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查“话题检查”触发条件
|
# 检查“话题检查”触发条件
|
||||||
should_check = False
|
should_check = False
|
||||||
@@ -427,7 +427,9 @@ class ChatHistorySummarizer:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 2. 构造编号后的消息字符串和参与者信息
|
# 2. 构造编号后的消息字符串和参与者信息
|
||||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
|
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
|
||||||
|
self._build_numbered_messages_for_llm(messages)
|
||||||
|
)
|
||||||
|
|
||||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次)
|
# 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次)
|
||||||
existing_topics = list(self.topic_cache.keys())
|
existing_topics = list(self.topic_cache.keys())
|
||||||
@@ -456,9 +458,7 @@ class ChatHistorySummarizer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not success or not topic_to_indices:
|
if not success or not topic_to_indices:
|
||||||
logger.error(
|
logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
|
||||||
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
|
|
||||||
)
|
|
||||||
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks(保持原状)
|
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks(保持原状)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -610,9 +610,7 @@ class ChatHistorySummarizer:
|
|||||||
if not numbered_lines:
|
if not numbered_lines:
|
||||||
return False, {}
|
return False, {}
|
||||||
|
|
||||||
history_topics_block = (
|
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||||
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
|
||||||
)
|
|
||||||
messages_block = "\n".join(numbered_lines)
|
messages_block = "\n".join(numbered_lines)
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
@@ -642,10 +640,10 @@ class ChatHistorySummarizer:
|
|||||||
else:
|
else:
|
||||||
# 如果没有找到代码块,尝试查找JSON数组的开始和结束位置
|
# 如果没有找到代码块,尝试查找JSON数组的开始和结束位置
|
||||||
# 查找第一个 [ 和最后一个 ]
|
# 查找第一个 [ 和最后一个 ]
|
||||||
start_idx = response.find('[')
|
start_idx = response.find("[")
|
||||||
end_idx = response.rfind(']')
|
end_idx = response.rfind("]")
|
||||||
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
||||||
json_str = response[start_idx:end_idx + 1].strip()
|
json_str = response[start_idx : end_idx + 1].strip()
|
||||||
else:
|
else:
|
||||||
# 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记)
|
# 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记)
|
||||||
json_str = response.strip()
|
json_str = response.strip()
|
||||||
@@ -942,4 +940,3 @@ class ChatHistorySummarizer:
|
|||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
||||||
|
|||||||
@@ -366,7 +366,9 @@ class LLMRequest:
|
|||||||
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
|
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
|
||||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||||
|
|
||||||
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}")
|
logger.warning(
|
||||||
|
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
|
||||||
|
)
|
||||||
await asyncio.sleep(api_provider.retry_interval)
|
await asyncio.sleep(api_provider.retry_interval)
|
||||||
|
|
||||||
except NetworkConnectionError as e:
|
except NetworkConnectionError as e:
|
||||||
@@ -394,7 +396,9 @@ class LLMRequest:
|
|||||||
if e.status_code == 429 or e.status_code >= 500:
|
if e.status_code == 429 or e.status_code >= 500:
|
||||||
retry_remain -= 1
|
retry_remain -= 1
|
||||||
if retry_remain <= 0:
|
if retry_remain <= 0:
|
||||||
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}")
|
logger.error(
|
||||||
|
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
|
||||||
|
)
|
||||||
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -540,7 +544,5 @@ class LLMRequest:
|
|||||||
if e.__cause__:
|
if e.__cause__:
|
||||||
original_error_type = type(e.__cause__).__name__
|
original_error_type = type(e.__cause__).__name__
|
||||||
original_error_msg = str(e.__cause__)
|
original_error_msg = str(e.__cause__)
|
||||||
return (
|
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||||
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -113,7 +113,6 @@ class MainSystem:
|
|||||||
get_emoji_manager().initialize()
|
get_emoji_manager().initialize()
|
||||||
logger.info("表情包管理器初始化成功")
|
logger.info("表情包管理器初始化成功")
|
||||||
|
|
||||||
|
|
||||||
# 初始化聊天管理器
|
# 初始化聊天管理器
|
||||||
await get_chat_manager()._initialize()
|
await get_chat_manager()._initialize()
|
||||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||||
|
|||||||
@@ -136,8 +136,6 @@ def init_memory_retrieval_prompt():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _log_conversation_messages(
|
def _log_conversation_messages(
|
||||||
conversation_messages: List[Message],
|
conversation_messages: List[Message],
|
||||||
head_prompt: Optional[str] = None,
|
head_prompt: Optional[str] = None,
|
||||||
@@ -172,7 +170,9 @@ def _log_conversation_messages(
|
|||||||
|
|
||||||
# 构建单条消息的日志信息
|
# 构建单条消息的日志信息
|
||||||
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
|
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
|
||||||
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
|
msg_info = (
|
||||||
|
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
|
||||||
|
)
|
||||||
|
|
||||||
# if full_content:
|
# if full_content:
|
||||||
# msg_info += f"\n{full_content}"
|
# msg_info += f"\n{full_content}"
|
||||||
@@ -185,8 +185,7 @@ def _log_conversation_messages(
|
|||||||
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
|
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
|
||||||
|
|
||||||
# if msg.tool_call_id:
|
# if msg.tool_call_id:
|
||||||
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
|
||||||
|
|
||||||
|
|
||||||
log_lines.append(msg_info)
|
log_lines.append(msg_info)
|
||||||
|
|
||||||
@@ -365,7 +364,7 @@ async def _react_agent_solve_question(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# logger.info(
|
# logger.info(
|
||||||
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||||
# )
|
# )
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
@@ -467,10 +466,17 @@ async def _react_agent_solve_question(
|
|||||||
if parsed_found_answer:
|
if parsed_found_answer:
|
||||||
# 找到了答案
|
# 找到了答案
|
||||||
if parsed_answer:
|
if parsed_answer:
|
||||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": parsed_answer}})
|
step["actions"].append(
|
||||||
|
{
|
||||||
|
"action_type": "finish_search",
|
||||||
|
"action_params": {"found_answer": True, "answer": parsed_answer},
|
||||||
|
}
|
||||||
|
)
|
||||||
step["observations"] = ["检测到finish_search文本格式调用,找到答案"]
|
step["observations"] = ["检测到finish_search文本格式调用,找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
|
||||||
|
)
|
||||||
|
|
||||||
_log_conversation_messages(
|
_log_conversation_messages(
|
||||||
conversation_messages,
|
conversation_messages,
|
||||||
@@ -481,10 +487,14 @@ async def _react_agent_solve_question(
|
|||||||
return True, parsed_answer, thinking_steps, False
|
return True, parsed_answer, thinking_steps, False
|
||||||
else:
|
else:
|
||||||
# found_answer为True但没有提供answer,视为错误,继续迭代
|
# found_answer为True但没有提供answer,视为错误,继续迭代
|
||||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer")
|
logger.warning(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 未找到答案,直接退出查询
|
# 未找到答案,直接退出查询
|
||||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
step["actions"].append(
|
||||||
|
{"action_type": "finish_search", "action_params": {"found_answer": False}}
|
||||||
|
)
|
||||||
step["observations"] = ["检测到finish_search文本格式调用,未找到答案"]
|
step["observations"] = ["检测到finish_search文本格式调用,未找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
|
||||||
@@ -522,10 +532,17 @@ async def _react_agent_solve_question(
|
|||||||
if finish_search_found:
|
if finish_search_found:
|
||||||
# 找到了答案
|
# 找到了答案
|
||||||
if finish_search_answer:
|
if finish_search_answer:
|
||||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": True, "answer": finish_search_answer}})
|
step["actions"].append(
|
||||||
|
{
|
||||||
|
"action_type": "finish_search",
|
||||||
|
"action_params": {"found_answer": True, "answer": finish_search_answer},
|
||||||
|
}
|
||||||
|
)
|
||||||
step["observations"] = ["检测到finish_search工具调用,找到答案"]
|
step["observations"] = ["检测到finish_search工具调用,找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}")
|
logger.info(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
|
||||||
|
)
|
||||||
|
|
||||||
_log_conversation_messages(
|
_log_conversation_messages(
|
||||||
conversation_messages,
|
conversation_messages,
|
||||||
@@ -536,7 +553,9 @@ async def _react_agent_solve_question(
|
|||||||
return True, finish_search_answer, thinking_steps, False
|
return True, finish_search_answer, thinking_steps, False
|
||||||
else:
|
else:
|
||||||
# found_answer为True但没有提供answer,视为错误
|
# found_answer为True但没有提供answer,视为错误
|
||||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer")
|
logger.warning(
|
||||||
|
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 未找到答案,直接退出查询
|
# 未找到答案,直接退出查询
|
||||||
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
|
||||||
@@ -724,7 +743,13 @@ async def _react_agent_solve_question(
|
|||||||
max_iterations=max_iterations,
|
max_iterations=max_iterations,
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
|
(
|
||||||
|
eval_success,
|
||||||
|
eval_response,
|
||||||
|
eval_reasoning_content,
|
||||||
|
eval_model_name,
|
||||||
|
eval_tool_calls,
|
||||||
|
) = await llm_api.generate_with_model_with_tools(
|
||||||
evaluation_prompt,
|
evaluation_prompt,
|
||||||
model_config=model_config.model_task_config.tool_use,
|
model_config=model_config.model_task_config.tool_use,
|
||||||
tool_options=[], # 最终评估阶段不提供工具
|
tool_options=[], # 最终评估阶段不提供工具
|
||||||
@@ -759,7 +784,7 @@ async def _react_agent_solve_question(
|
|||||||
"iteration": current_iteration,
|
"iteration": current_iteration,
|
||||||
"thought": f"[最终评估] {eval_response}",
|
"thought": f"[最终评估] {eval_response}",
|
||||||
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
|
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
|
||||||
"observations": ["最终评估阶段检测到found_answer"]
|
"observations": ["最终评估阶段检测到found_answer"],
|
||||||
}
|
}
|
||||||
thinking_steps.append(eval_step)
|
thinking_steps.append(eval_step)
|
||||||
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
|
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
|
||||||
@@ -778,7 +803,7 @@ async def _react_agent_solve_question(
|
|||||||
"iteration": current_iteration,
|
"iteration": current_iteration,
|
||||||
"thought": f"[最终评估] {eval_response}",
|
"thought": f"[最终评估] {eval_response}",
|
||||||
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
|
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
|
||||||
"observations": ["最终评估阶段检测到not_enough_info"]
|
"observations": ["最终评估阶段检测到not_enough_info"],
|
||||||
}
|
}
|
||||||
thinking_steps.append(eval_step)
|
thinking_steps.append(eval_step)
|
||||||
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
|
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
|
||||||
@@ -795,8 +820,10 @@ async def _react_agent_solve_question(
|
|||||||
eval_step = {
|
eval_step = {
|
||||||
"iteration": current_iteration,
|
"iteration": current_iteration,
|
||||||
"thought": f"[最终评估] {eval_response}",
|
"thought": f"[最终评估] {eval_response}",
|
||||||
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}],
|
"actions": [
|
||||||
"observations": ["已到达最大迭代次数,无法找到答案"]
|
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
|
||||||
|
],
|
||||||
|
"observations": ["已到达最大迭代次数,无法找到答案"],
|
||||||
}
|
}
|
||||||
thinking_steps.append(eval_step)
|
thinking_steps.append(eval_step)
|
||||||
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
|
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
|
||||||
@@ -1129,7 +1156,9 @@ async def build_memory_retrieval_prompt(
|
|||||||
else:
|
else:
|
||||||
max_iterations = base_max_iterations
|
max_iterations = base_max_iterations
|
||||||
timeout_seconds = global_config.memory.agent_timeout_seconds
|
timeout_seconds = global_config.memory.agent_timeout_seconds
|
||||||
logger.debug(f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒")
|
logger.debug(
|
||||||
|
f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒"
|
||||||
|
)
|
||||||
|
|
||||||
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
||||||
question_tasks = [
|
question_tasks = [
|
||||||
@@ -1198,4 +1227,3 @@ async def build_memory_retrieval_prompt(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记忆检索时发生异常: {str(e)}")
|
logger.error(f"记忆检索时发生异常: {str(e)}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from src.common.logger import get_logger
|
|||||||
logger = get_logger("memory_utils")
|
logger = get_logger("memory_utils")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
||||||
"""解析问题JSON,返回概念列表和问题列表
|
"""解析问题JSON,返回概念列表和问题列表
|
||||||
|
|
||||||
@@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
|||||||
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
|
|
||||||
def parse_datetime_to_timestamp(value: str) -> float:
|
def parse_datetime_to_timestamp(value: str) -> float:
|
||||||
"""
|
"""
|
||||||
接受多种常见格式并转换为时间戳(秒)
|
接受多种常见格式并转换为时间戳(秒)
|
||||||
|
|||||||
@@ -47,4 +47,3 @@ def register_tool():
|
|||||||
],
|
],
|
||||||
execute_func=finish_search,
|
execute_func=finish_search,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,7 @@ from .tool_registry import register_memory_retrieval_tool
|
|||||||
logger = get_logger("memory_retrieval_tools")
|
logger = get_logger("memory_retrieval_tools")
|
||||||
|
|
||||||
|
|
||||||
async def search_chat_history(
|
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
|
||||||
chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None
|
|
||||||
) -> str:
|
|
||||||
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords
|
"""根据关键词或参与人查询记忆,返回匹配的记忆id、记忆标题theme和关键词keywords
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -144,7 +142,9 @@ async def search_chat_history(
|
|||||||
keywords_list = parse_keywords_string(keyword)
|
keywords_list = parse_keywords_string(keyword)
|
||||||
if len(keywords_list) > 2:
|
if len(keywords_list) > 2:
|
||||||
required_count = len(keywords_list) - 1
|
required_count = len(keywords_list) - 1
|
||||||
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
return (
|
||||||
|
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
|
||||||
elif participant:
|
elif participant:
|
||||||
@@ -160,9 +160,7 @@ async def search_chat_history(
|
|||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = (
|
||||||
json.loads(record.keywords)
|
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
if isinstance(record.keywords, str)
|
|
||||||
else record.keywords
|
|
||||||
)
|
)
|
||||||
if isinstance(keywords_data, list):
|
if isinstance(keywords_data, list):
|
||||||
for k in keywords_data:
|
for k in keywords_data:
|
||||||
@@ -179,13 +177,12 @@ async def search_chat_history(
|
|||||||
keywords_str = "、".join(sorted(all_keywords_set))
|
keywords_str = "、".join(sorted(all_keywords_set))
|
||||||
return (
|
return (
|
||||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
||||||
f"有关\"{search_label}\"的关键词:\n"
|
f'有关"{search_label}"的关键词:\n'
|
||||||
f"{keywords_str}"
|
f"{keywords_str}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return (
|
return (
|
||||||
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
|
f'包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n有关"{search_label}"的关键词信息为空'
|
||||||
f"有关\"{search_label}\"的关键词信息为空"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建结果文本,返回id、theme和keywords(最多20条)
|
# 构建结果文本,返回id、theme和keywords(最多20条)
|
||||||
|
|||||||
@@ -414,7 +414,9 @@ async def websocket_chat(
|
|||||||
group_id=virtual_group_id,
|
group_id=virtual_group_id,
|
||||||
group_name=group_name or "WebUI虚拟群聊",
|
group_name=group_name or "WebUI虚拟群聊",
|
||||||
)
|
)
|
||||||
logger.info(f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}")
|
logger.info(
|
||||||
|
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
""" 表情包管理 API 路由"""
|
"""表情包管理 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
||||||
from fastapi.responses import FileResponse, JSONResponse
|
from fastapi.responses import FileResponse, JSONResponse
|
||||||
@@ -100,23 +100,23 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
|||||||
try:
|
try:
|
||||||
with Image.open(source_path) as img:
|
with Image.open(source_path) as img:
|
||||||
# GIF 处理:提取第一帧
|
# GIF 处理:提取第一帧
|
||||||
if hasattr(img, 'n_frames') and img.n_frames > 1:
|
if hasattr(img, "n_frames") and img.n_frames > 1:
|
||||||
img.seek(0) # 确保在第一帧
|
img.seek(0) # 确保在第一帧
|
||||||
|
|
||||||
# 转换为 RGB/RGBA(WebP 支持透明度)
|
# 转换为 RGB/RGBA(WebP 支持透明度)
|
||||||
if img.mode in ('P', 'PA'):
|
if img.mode in ("P", "PA"):
|
||||||
# 调色板模式转换为 RGBA 以保留透明度
|
# 调色板模式转换为 RGBA 以保留透明度
|
||||||
img = img.convert('RGBA')
|
img = img.convert("RGBA")
|
||||||
elif img.mode == 'LA':
|
elif img.mode == "LA":
|
||||||
img = img.convert('RGBA')
|
img = img.convert("RGBA")
|
||||||
elif img.mode not in ('RGB', 'RGBA'):
|
elif img.mode not in ("RGB", "RGBA"):
|
||||||
img = img.convert('RGB')
|
img = img.convert("RGB")
|
||||||
|
|
||||||
# 创建缩略图(保持宽高比)
|
# 创建缩略图(保持宽高比)
|
||||||
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
# 保存为 WebP 格式
|
# 保存为 WebP 格式
|
||||||
img.save(cache_path, 'WEBP', quality=THUMBNAIL_QUALITY, method=6)
|
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
|
||||||
|
|
||||||
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
|
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
|
||||||
|
|
||||||
@@ -163,6 +163,7 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
|||||||
|
|
||||||
return cleaned, kept
|
return cleaned, kept
|
||||||
|
|
||||||
|
|
||||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||||
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||||
@@ -365,7 +366,9 @@ async def get_emoji_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||||
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def get_emoji_detail(
|
||||||
|
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取表情包详细信息
|
获取表情包详细信息
|
||||||
|
|
||||||
@@ -394,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def update_emoji(
|
||||||
|
emoji_id: int,
|
||||||
|
request: EmojiUpdateRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
增量更新表情包(只更新提供的字段)
|
增量更新表情包(只更新提供的字段)
|
||||||
|
|
||||||
@@ -446,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_sessio
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||||
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def delete_emoji(
|
||||||
|
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
删除表情包
|
删除表情包
|
||||||
|
|
||||||
@@ -538,7 +548,9 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||||
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def register_emoji(
|
||||||
|
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
注册表情包(快捷操作)
|
注册表情包(快捷操作)
|
||||||
|
|
||||||
@@ -578,7 +590,9 @@ async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(N
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||||
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def ban_emoji(
|
||||||
|
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
禁用表情包(快捷操作)
|
禁用表情包(快捷操作)
|
||||||
|
|
||||||
@@ -680,9 +694,7 @@ async def get_emoji_thumbnail(
|
|||||||
}
|
}
|
||||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
path=emoji.full_path,
|
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||||
media_type=media_type,
|
|
||||||
filename=f"{emoji.emoji_hash}.{emoji.format}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 尝试获取或生成缩略图
|
# 尝试获取或生成缩略图
|
||||||
@@ -692,9 +704,7 @@ async def get_emoji_thumbnail(
|
|||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
# 缓存命中,直接返回
|
# 缓存命中,直接返回
|
||||||
return FileResponse(
|
return FileResponse(
|
||||||
path=str(cache_path),
|
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
|
||||||
media_type="image/webp",
|
|
||||||
filename=f"{emoji.emoji_hash}_thumb.webp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 缓存未命中,触发后台生成并返回 202
|
# 缓存未命中,触发后台生成并返回 202
|
||||||
@@ -703,11 +713,7 @@ async def get_emoji_thumbnail(
|
|||||||
# 标记为正在生成
|
# 标记为正在生成
|
||||||
_generating_thumbnails.add(emoji.emoji_hash)
|
_generating_thumbnails.add(emoji.emoji_hash)
|
||||||
# 提交到线程池后台生成
|
# 提交到线程池后台生成
|
||||||
_thumbnail_executor.submit(
|
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||||
_background_generate_thumbnail,
|
|
||||||
emoji.full_path,
|
|
||||||
emoji.emoji_hash
|
|
||||||
)
|
|
||||||
|
|
||||||
# 返回 202 Accepted,告诉前端缩略图正在生成中
|
# 返回 202 Accepted,告诉前端缩略图正在生成中
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@@ -719,7 +725,7 @@ async def get_emoji_thumbnail(
|
|||||||
},
|
},
|
||||||
headers={
|
headers={
|
||||||
"Retry-After": "1", # 建议 1 秒后重试
|
"Retry-After": "1", # 建议 1 秒后重试
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -730,7 +736,11 @@ async def get_emoji_thumbnail(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||||
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def batch_delete_emojis(
|
||||||
|
request: BatchDeleteRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
批量删除表情包
|
批量删除表情包
|
||||||
|
|
||||||
@@ -1234,12 +1244,7 @@ async def preheat_thumbnail_cache(
|
|||||||
try:
|
try:
|
||||||
# 使用线程池异步生成缩略图,避免阻塞事件循环
|
# 使用线程池异步生成缩略图,避免阻塞事件循环
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||||
_thumbnail_executor,
|
|
||||||
_generate_thumbnail,
|
|
||||||
emoji.full_path,
|
|
||||||
emoji.emoji_hash
|
|
||||||
)
|
|
||||||
generated += 1
|
generated += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
|
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
|
||||||
|
|||||||
@@ -256,7 +256,9 @@ async def get_expression_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||||
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def get_expression_detail(
|
||||||
|
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取表达方式详细信息
|
获取表达方式详细信息
|
||||||
|
|
||||||
@@ -285,7 +287,11 @@ async def get_expression_detail(expression_id: int, maibot_session: Optional[str
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=ExpressionCreateResponse)
|
@router.post("/", response_model=ExpressionCreateResponse)
|
||||||
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def create_expression(
|
||||||
|
request: ExpressionCreateRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
创建新的表达方式
|
创建新的表达方式
|
||||||
|
|
||||||
@@ -326,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
|
|||||||
|
|
||||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||||
async def update_expression(
|
async def update_expression(
|
||||||
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
expression_id: int,
|
||||||
|
request: ExpressionUpdateRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
增量更新表达方式(只更新提供的字段)
|
增量更新表达方式(只更新提供的字段)
|
||||||
@@ -376,7 +385,9 @@ async def update_expression(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||||
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def delete_expression(
|
||||||
|
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
删除表达方式
|
删除表达方式
|
||||||
|
|
||||||
@@ -419,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||||
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def batch_delete_expressions(
|
||||||
|
request: BatchDeleteRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
批量删除表达方式
|
批量删除表达方式
|
||||||
|
|
||||||
@@ -460,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stats/summary")
|
@router.get("/stats/summary")
|
||||||
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def get_expression_stats(
|
||||||
|
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取表达方式统计数据
|
获取表达方式统计数据
|
||||||
|
|
||||||
|
|||||||
@@ -277,11 +277,7 @@ async def get_chat_list():
|
|||||||
"""获取所有有黑话记录的聊天列表"""
|
"""获取所有有黑话记录的聊天列表"""
|
||||||
try:
|
try:
|
||||||
# 获取所有不同的 chat_id
|
# 获取所有不同的 chat_id
|
||||||
chat_ids = (
|
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
|
||||||
Jargon.select(Jargon.chat_id)
|
|
||||||
.distinct()
|
|
||||||
.where(Jargon.chat_id.is_null(False))
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
||||||
|
|
||||||
@@ -346,12 +342,7 @@ async def get_jargon_stats():
|
|||||||
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
||||||
|
|
||||||
# 关联的聊天数量
|
# 关联的聊天数量
|
||||||
chat_count = (
|
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
|
||||||
Jargon.select(Jargon.chat_id)
|
|
||||||
.distinct()
|
|
||||||
.where(Jargon.chat_id.is_null(False))
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 按聊天统计 TOP 5
|
# 按聊天统计 TOP 5
|
||||||
top_chats = (
|
top_chats = (
|
||||||
@@ -403,9 +394,7 @@ async def create_jargon(request: JargonCreateRequest):
|
|||||||
"""创建黑话"""
|
"""创建黑话"""
|
||||||
try:
|
try:
|
||||||
# 检查是否已存在相同内容的黑话
|
# 检查是否已存在相同内容的黑话
|
||||||
existing = Jargon.get_or_none(
|
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
|
||||||
(Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)
|
|
||||||
)
|
|
||||||
if existing:
|
if existing:
|
||||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||||
|
|
||||||
@@ -527,11 +516,7 @@ async def batch_set_jargon_status(
|
|||||||
if not ids:
|
if not ids:
|
||||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||||
|
|
||||||
updated_count = (
|
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
|
||||||
Jargon.update(is_jargon=is_jargon)
|
|
||||||
.where(Jargon.id.in_(ids))
|
|
||||||
.execute()
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||||
|
|
||||||
|
|||||||
@@ -200,7 +200,9 @@ async def get_person_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||||
async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def get_person_detail(
|
||||||
|
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取人物详细信息
|
获取人物详细信息
|
||||||
|
|
||||||
@@ -229,7 +231,12 @@ async def get_person_detail(person_id: str, maibot_session: Optional[str] = Cook
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||||
async def update_person(person_id: str, request: PersonUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def update_person(
|
||||||
|
person_id: str,
|
||||||
|
request: PersonUpdateRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
增量更新人物信息(只更新提供的字段)
|
增量更新人物信息(只更新提供的字段)
|
||||||
|
|
||||||
@@ -278,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, maibot_ses
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||||
async def delete_person(person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def delete_person(
|
||||||
|
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
删除人物信息
|
删除人物信息
|
||||||
|
|
||||||
@@ -348,7 +357,11 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||||
async def batch_delete_persons(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
async def batch_delete_persons(
|
||||||
|
request: BatchDeleteRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
批量删除人物信息
|
批量删除人物信息
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> No
|
|||||||
"""
|
"""
|
||||||
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。
|
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _is_list_type(tp: Any) -> bool:
|
def _is_list_type(tp: Any) -> bool:
|
||||||
origin = get_origin(tp)
|
origin = get_origin(tp)
|
||||||
return tp is list or origin is list
|
return tp is list or origin is list
|
||||||
@@ -313,7 +314,9 @@ async def check_git_status() -> GitStatusResponse:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
||||||
async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
|
async def get_available_mirrors(
|
||||||
|
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> AvailableMirrorsResponse:
|
||||||
"""
|
"""
|
||||||
获取所有可用的镜像源配置
|
获取所有可用的镜像源配置
|
||||||
"""
|
"""
|
||||||
@@ -343,7 +346,9 @@ async def get_available_mirrors(maibot_session: Optional[str] = Cookie(None), au
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
||||||
async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
|
async def add_mirror(
|
||||||
|
request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> MirrorConfigResponse:
|
||||||
"""
|
"""
|
||||||
添加新的镜像源
|
添加新的镜像源
|
||||||
"""
|
"""
|
||||||
@@ -383,7 +388,10 @@ async def add_mirror(request: AddMirrorRequest, maibot_session: Optional[str] =
|
|||||||
|
|
||||||
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
||||||
async def update_mirror(
|
async def update_mirror(
|
||||||
mirror_id: str, request: UpdateMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
mirror_id: str,
|
||||||
|
request: UpdateMirrorRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> MirrorConfigResponse:
|
) -> MirrorConfigResponse:
|
||||||
"""
|
"""
|
||||||
更新镜像源配置
|
更新镜像源配置
|
||||||
@@ -426,7 +434,9 @@ async def update_mirror(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/mirrors/{mirror_id}")
|
@router.delete("/mirrors/{mirror_id}")
|
||||||
async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def delete_mirror(
|
||||||
|
mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
删除镜像源
|
删除镜像源
|
||||||
"""
|
"""
|
||||||
@@ -449,7 +459,9 @@ async def delete_mirror(mirror_id: str, maibot_session: Optional[str] = Cookie(N
|
|||||||
|
|
||||||
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
||||||
async def fetch_raw_file(
|
async def fetch_raw_file(
|
||||||
request: FetchRawFileRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
request: FetchRawFileRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> FetchRawFileResponse:
|
) -> FetchRawFileResponse:
|
||||||
"""
|
"""
|
||||||
获取 GitHub 仓库的 Raw 文件内容
|
获取 GitHub 仓库的 Raw 文件内容
|
||||||
@@ -534,7 +546,9 @@ async def fetch_raw_file(
|
|||||||
|
|
||||||
@router.post("/clone", response_model=CloneRepositoryResponse)
|
@router.post("/clone", response_model=CloneRepositoryResponse)
|
||||||
async def clone_repository(
|
async def clone_repository(
|
||||||
request: CloneRepositoryRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
request: CloneRepositoryRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> CloneRepositoryResponse:
|
) -> CloneRepositoryResponse:
|
||||||
"""
|
"""
|
||||||
克隆 GitHub 仓库到本地
|
克隆 GitHub 仓库到本地
|
||||||
@@ -574,7 +588,11 @@ async def clone_repository(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/install")
|
@router.post("/install")
|
||||||
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def install_plugin(
|
||||||
|
request: InstallPluginRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
安装插件
|
安装插件
|
||||||
|
|
||||||
@@ -778,7 +796,9 @@ async def install_plugin(request: InstallPluginRequest, maibot_session: Optional
|
|||||||
|
|
||||||
@router.post("/uninstall")
|
@router.post("/uninstall")
|
||||||
async def uninstall_plugin(
|
async def uninstall_plugin(
|
||||||
request: UninstallPluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
request: UninstallPluginRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
卸载插件
|
卸载插件
|
||||||
@@ -913,7 +933,11 @@ async def uninstall_plugin(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/update")
|
@router.post("/update")
|
||||||
async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def update_plugin(
|
||||||
|
request: UpdatePluginRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
更新插件
|
更新插件
|
||||||
|
|
||||||
@@ -1132,7 +1156,9 @@ async def update_plugin(request: UpdatePluginRequest, maibot_session: Optional[s
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/installed")
|
@router.get("/installed")
|
||||||
async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def get_installed_plugins(
|
||||||
|
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取已安装的插件列表
|
获取已安装的插件列表
|
||||||
|
|
||||||
@@ -1272,7 +1298,9 @@ class UpdatePluginConfigRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/config/{plugin_id}/schema")
|
@router.get("/config/{plugin_id}/schema")
|
||||||
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def get_plugin_config_schema(
|
||||||
|
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取插件配置 Schema
|
获取插件配置 Schema
|
||||||
|
|
||||||
@@ -1405,7 +1433,9 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/config/{plugin_id}")
|
@router.get("/config/{plugin_id}")
|
||||||
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def get_plugin_config(
|
||||||
|
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取插件当前配置值
|
获取插件当前配置值
|
||||||
|
|
||||||
@@ -1461,7 +1491,10 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
|
|||||||
|
|
||||||
@router.put("/config/{plugin_id}")
|
@router.put("/config/{plugin_id}")
|
||||||
async def update_plugin_config(
|
async def update_plugin_config(
|
||||||
plugin_id: str, request: UpdatePluginConfigRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
plugin_id: str,
|
||||||
|
request: UpdatePluginConfigRequest,
|
||||||
|
maibot_session: Optional[str] = Cookie(None),
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
更新插件配置
|
更新插件配置
|
||||||
@@ -1532,7 +1565,9 @@ async def update_plugin_config(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/config/{plugin_id}/reset")
|
@router.post("/config/{plugin_id}/reset")
|
||||||
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def reset_plugin_config(
|
||||||
|
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
重置插件配置为默认值
|
重置插件配置为默认值
|
||||||
|
|
||||||
@@ -1592,7 +1627,9 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/config/{plugin_id}/toggle")
|
@router.post("/config/{plugin_id}/toggle")
|
||||||
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
async def toggle_plugin(
|
||||||
|
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
切换插件启用状态
|
切换插件启用状态
|
||||||
|
|
||||||
|
|||||||
@@ -110,8 +110,10 @@ class WebUIServer:
|
|||||||
from src.webui.routes import router as webui_router
|
from src.webui.routes import router as webui_router
|
||||||
from src.webui.logs_ws import router as logs_router
|
from src.webui.logs_ws import router as logs_router
|
||||||
from src.webui.knowledge_routes import router as knowledge_router
|
from src.webui.knowledge_routes import router as knowledge_router
|
||||||
|
|
||||||
# 导入本地聊天室路由
|
# 导入本地聊天室路由
|
||||||
from src.webui.chat_routes import router as chat_router
|
from src.webui.chat_routes import router as chat_router
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
self.app.include_router(webui_router)
|
self.app.include_router(webui_router)
|
||||||
self.app.include_router(logs_router)
|
self.app.include_router(logs_router)
|
||||||
@@ -166,6 +168,7 @@ class WebUIServer:
|
|||||||
def _check_port_available(self) -> bool:
|
def _check_port_available(self) -> bool:
|
||||||
"""检查端口是否可用"""
|
"""检查端口是否可用"""
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
s.settimeout(1)
|
s.settimeout(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user