feat: add unified WebSocket connection manager and routing

- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending.
- Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management.
- Added support for logging and plugin progress subscriptions.
- Enhanced error handling and response structure for WebSocket operations.
This commit is contained in:
DrSmoothl
2026-04-02 22:08:52 +08:00
parent 7d0d429640
commit 1906890b67
28 changed files with 3845 additions and 1137 deletions

View File

@@ -0,0 +1,548 @@
"""统一 WebSocket 路由。"""
from typing import Any, Dict, Optional, Set, cast
import asyncio
import time
import uuid
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.logs_ws import load_recent_logs
from src.webui.routers.chat.service import (
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
from src.webui.routers.plugin.progress import get_current_progress
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.unified_ws")
router = APIRouter()
_background_tasks: Set["asyncio.Task[None]"] = set()
def _build_error(code: str, message: str) -> Dict[str, Any]:
"""构建统一错误响应体。
Args:
code: 错误码。
message: 错误描述。
Returns:
Dict[str, Any]: 统一错误对象。
"""
return {
"code": code,
"message": message,
}
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
"""从客户端消息中提取数据字段。
Args:
message: 客户端消息。
Returns:
Dict[str, Any]: 标准化后的数据字典。
"""
data = message.get("data", {})
if isinstance(data, dict):
return cast(Dict[str, Any], data)
return {}
def _track_background_task(task: "asyncio.Task[None]") -> None:
"""登记后台任务并在完成后自动清理。
Args:
task: 后台协程任务。
"""
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
"""校验统一 WebSocket 连接的认证状态。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
Returns:
bool: 认证通过时返回 ``True``。
"""
if token and verify_ws_token(token):
logger.debug("统一 WebSocket 使用临时 token 认证成功")
return True
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("统一 WebSocket 使用 Cookie 认证成功")
return True
return False
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
"""处理日志域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
data: 订阅参数。
"""
replay_limit = int(data.get("replay", 100) or 100)
replay_limit = max(0, min(replay_limit, 500))
websocket_manager.subscribe(connection_id, domain="logs", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "logs", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="logs",
event="snapshot",
topic="main",
data={"entries": load_recent_logs(limit=replay_limit)},
)
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
"""处理插件进度域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
"""
websocket_manager.subscribe(connection_id, domain="plugin_progress", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "plugin_progress", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="plugin_progress",
event="snapshot",
topic="main",
data={"progress": get_current_progress()},
)
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题订阅请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
data = _get_request_data(message)
if domain == "logs" and topic == "main":
await _handle_logs_subscribe(connection_id, request_id, data)
return
if domain == "plugin_progress" and topic == "main":
await _handle_plugin_progress_subscribe(connection_id, request_id)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
)
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题退订请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
if not domain or not topic:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
)
return
websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": domain, "topic": topic},
)
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""打开一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
if not client_session_id:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
)
return
data = _get_request_data(message)
normalized_user_id = normalize_webui_user_id(cast(Optional[str], data.get("user_id")))
current_user_name = str(data.get("user_name") or "WebUI用户")
current_virtual_config = resolve_initial_virtual_identity(
platform=cast(Optional[str], data.get("platform")),
person_id=cast(Optional[str], data.get("person_id")),
group_name=cast(Optional[str], data.get("group_name")),
group_id=cast(Optional[str], data.get("group_id")),
)
restore = bool(data.get("restore"))
session_id = f"{connection_id}:{client_session_id}"
async def send_chat_event(chat_message: Dict[str, Any]) -> None:
"""将聊天消息封装为统一事件并发送。
Args:
chat_message: 聊天消息体。
"""
event_name = str(chat_message.get("type") or "message")
await websocket_manager.send_event(
connection_id,
domain="chat",
event=event_name,
session=client_session_id,
data=chat_message,
)
await chat_manager.connect(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
sender=send_chat_event,
)
websocket_manager.register_chat_session(connection_id, client_session_id, session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "session_id": session_id},
)
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
include_welcome=not restore,
)
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""关闭一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
chat_manager.disconnect(session_id)
websocket_manager.unregister_chat_session(connection_id, client_session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id},
)
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
"""在后台处理聊天消息事件。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
data: 客户端提交的消息数据。
"""
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
return
session_state = chat_manager.get_session(session_id)
if session_state is None:
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天消息发送请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
payload = {
"type": "message",
"content": data.get("content", ""),
"user_name": data.get("user_name", ""),
}
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"accepted": True, "session": client_session_id},
)
_track_background_task(asyncio.create_task(_process_chat_message(connection_id, client_session_id, payload)))
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天昵称更新请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
session_state = chat_manager.get_session(session_id)
if session_state is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data={
"type": "update_nickname",
"user_name": data.get("user_name", ""),
},
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "user_name": next_user_name},
)
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天域调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
method = str(message.get("method") or "").strip()
if method == "session.open":
await _open_chat_session(connection_id, message)
return
if method == "session.close":
await _close_chat_session(connection_id, message)
return
if method == "message.send":
await _handle_chat_message_send(connection_id, message)
return
if method == "session.update_nickname":
await _handle_chat_nickname_update(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
)
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
if domain == "chat":
await _handle_chat_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
)
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一 WebSocket 客户端消息。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
operation = str(message.get("op") or "").strip()
request_id = cast(Optional[str], message.get("id"))
if operation == "ping":
await websocket_manager.send_pong(connection_id, time.time())
return
if operation == "subscribe":
await _handle_subscribe(connection_id, message)
return
if operation == "unsubscribe":
await _handle_unsubscribe(connection_id, message)
return
if operation == "call":
await _handle_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""统一 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
if not await authenticate_websocket_connection(websocket, token):
logger.warning("统一 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
connection_id = uuid.uuid4().hex
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
await websocket_manager.send_event(
connection_id,
domain="system",
event="ready",
data={"connection_id": connection_id, "timestamp": time.time()},
)
try:
while True:
raw_message = await websocket.receive_json()
if not isinstance(raw_message, dict):
await websocket_manager.send_response(
connection_id,
request_id=None,
ok=False,
error=_build_error("invalid_message", "消息必须是 JSON 对象"),
)
continue
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
except WebSocketDisconnect:
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
except Exception as exc:
logger.error(f"统一 WebSocket 处理失败: {exc}")
finally:
chat_manager.disconnect_connection(connection_id)
await websocket_manager.disconnect(connection_id)