- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending. - Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management. - Added support for logging and plugin progress subscriptions. - Enhanced error handling and response structure for WebSocket operations.
549 lines
17 KiB
Python
549 lines
17 KiB
Python
"""统一 WebSocket 路由。"""
|
|
|
|
from typing import Any, Dict, Optional, Set, cast
|
|
import asyncio
|
|
import time
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
|
|
|
from src.common.logger import get_logger
|
|
from src.webui.core import get_token_manager
|
|
from src.webui.logs_ws import load_recent_logs
|
|
from src.webui.routers.chat.service import (
|
|
chat_manager,
|
|
dispatch_chat_event,
|
|
normalize_webui_user_id,
|
|
resolve_initial_virtual_identity,
|
|
send_initial_chat_state,
|
|
)
|
|
from src.webui.routers.plugin.progress import get_current_progress
|
|
from src.webui.routers.websocket.auth import verify_ws_token
|
|
from src.webui.routers.websocket.manager import websocket_manager
|
|
|
|
router = APIRouter()
logger = get_logger("webui.unified_ws")

# Strong references to tasks registered via ``_track_background_task``; each
# entry is discarded again by the done callback attached there.
_background_tasks: Set["asyncio.Task[None]"] = set()
|
|
|
|
|
|
def _build_error(code: str, message: str) -> Dict[str, Any]:
|
|
"""构建统一错误响应体。
|
|
|
|
Args:
|
|
code: 错误码。
|
|
message: 错误描述。
|
|
|
|
Returns:
|
|
Dict[str, Any]: 统一错误对象。
|
|
"""
|
|
return {
|
|
"code": code,
|
|
"message": message,
|
|
}
|
|
|
|
|
|
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""从客户端消息中提取数据字段。
|
|
|
|
Args:
|
|
message: 客户端消息。
|
|
|
|
Returns:
|
|
Dict[str, Any]: 标准化后的数据字典。
|
|
"""
|
|
data = message.get("data", {})
|
|
if isinstance(data, dict):
|
|
return cast(Dict[str, Any], data)
|
|
return {}
|
|
|
|
|
|
def _track_background_task(task: "asyncio.Task[None]") -> None:
    """Register a fire-and-forget task and clean it up once it finishes.

    A strong reference is held in ``_background_tasks`` until the task is done.
    The done callback then drops that reference and, if the task raised,
    logs the exception through this module's logger. Retrieving the exception
    there also stops it from going unobserved.

    Args:
        task: Background coroutine task.
    """
    _background_tasks.add(task)

    def _on_task_done(finished: "asyncio.Task[None]") -> None:
        _background_tasks.discard(finished)
        # ``exception()`` raises CancelledError on a cancelled task, so check first.
        if finished.cancelled():
            return
        error = finished.exception()
        if error is not None:
            logger.error(f"统一 WebSocket 后台任务执行失败: {error!r}")

    task.add_done_callback(_on_task_done)
|
|
|
|
|
|
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
    """Decide whether an incoming unified WebSocket connection is authenticated.

    The one-time handshake token (if given and valid) wins. Otherwise the
    ``maibot_session`` cookie is checked with the token manager.

    Args:
        websocket: FastAPI WebSocket object (only its cookies are read).
        token: Optional one-time handshake token.

    Returns:
        bool: ``True`` if either credential verifies, ``False`` otherwise.
    """
    if token and verify_ws_token(token):
        logger.debug("统一 WebSocket 使用临时 token 认证成功")
        return True

    session_cookie = websocket.cookies.get("maibot_session")
    if not session_cookie:
        return False
    if not get_token_manager().verify_token(session_cookie):
        return False

    logger.debug("统一 WebSocket 使用 Cookie 认证成功")
    return True
|
|
|
|
|
|
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
    """Handle a subscription request for the ``logs`` domain.

    Subscribes the connection to ``logs:main``, acknowledges the request and
    then sends a ``snapshot`` event with the most recent log entries.

    Args:
        connection_id: Connection ID.
        request_id: Request ID echoed back in the response.
        data: Subscription parameters. ``replay`` optionally sets how many
            recent entries go into the snapshot; it is clamped to ``0..500``.
            A missing or unparseable value falls back to 100.
    """
    default_replay = 100
    max_replay = 500

    raw_replay = data.get("replay")
    # An explicit 0 must be honoured (``value or 100`` would turn it into 100),
    # and a malformed value falls back to the default instead of raising out
    # of the handler.
    if raw_replay is None:
        replay_limit = default_replay
    else:
        try:
            replay_limit = int(raw_replay)
        except (TypeError, ValueError, OverflowError):
            replay_limit = default_replay
    replay_limit = max(0, min(replay_limit, max_replay))

    websocket_manager.subscribe(connection_id, domain="logs", topic="main")
    await websocket_manager.send_response(
        connection_id,
        request_id=request_id,
        ok=True,
        data={"domain": "logs", "topic": "main"},
    )
    await websocket_manager.send_event(
        connection_id,
        domain="logs",
        event="snapshot",
        topic="main",
        data={"entries": load_recent_logs(limit=replay_limit)},
    )
|
|
|
|
|
|
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
    """Handle a subscription request for the ``plugin_progress`` domain.

    Subscribes the connection to ``plugin_progress:main``, acknowledges the
    request, then sends a ``snapshot`` event carrying the current progress.

    Args:
        connection_id: Connection ID.
        request_id: Request ID echoed back in the response.
    """
    domain_name, topic_name = "plugin_progress", "main"

    websocket_manager.subscribe(connection_id, domain=domain_name, topic=topic_name)
    await websocket_manager.send_response(
        connection_id,
        request_id=request_id,
        ok=True,
        data={"domain": domain_name, "topic": topic_name},
    )

    # Progress is read only after the acknowledgement has been sent.
    snapshot = {"progress": get_current_progress()}
    await websocket_manager.send_event(
        connection_id,
        domain=domain_name,
        event="snapshot",
        topic=topic_name,
        data=snapshot,
    )
|
|
|
|
|
|
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
    """Route a ``subscribe`` operation to the matching domain handler.

    Supported targets are ``logs:main`` and ``plugin_progress:main``. Any other
    ``domain:topic`` pair gets an ``unsupported_subscription`` error response.

    Args:
        connection_id: Connection ID.
        message: Client message; reads ``id``, ``domain``, ``topic`` and ``data``.
    """
    request_id = cast(Optional[str], message.get("id"))
    domain = str(message.get("domain") or "").strip()
    topic = str(message.get("topic") or "").strip()
    target = (domain, topic)

    if target == ("logs", "main"):
        await _handle_logs_subscribe(connection_id, request_id, _get_request_data(message))
    elif target == ("plugin_progress", "main"):
        await _handle_plugin_progress_subscribe(connection_id, request_id)
    else:
        await websocket_manager.send_response(
            connection_id,
            request_id=request_id,
            ok=False,
            error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
        )
|
|
|
|
|
|
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
    """Handle an ``unsubscribe`` operation.

    Both ``domain`` and ``topic`` are required. When either is blank, an
    ``invalid_unsubscribe`` error is returned. Otherwise the connection is
    unsubscribed and the pair is echoed back.

    Args:
        connection_id: Connection ID.
        message: Client message; reads ``id``, ``domain`` and ``topic``.
    """
    request_id = cast(Optional[str], message.get("id"))
    domain = str(message.get("domain") or "").strip()
    topic = str(message.get("topic") or "").strip()

    if domain and topic:
        websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
        await websocket_manager.send_response(
            connection_id,
            request_id=request_id,
            ok=True,
            data={"domain": domain, "topic": topic},
        )
    else:
        await websocket_manager.send_response(
            connection_id,
            request_id=request_id,
            ok=False,
            error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
        )
|
|
|
|
|
|
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
    """Open a logical chat session multiplexed over an existing connection.

    The client picks the ``session`` identifier. The server-side key is
    ``"<connection_id>:<session>"``, so identical client ids on different
    connections map to different keys. Chat messages produced for the session
    are forwarded as ``chat`` domain events tagged with the client session id.

    Args:
        connection_id: Connection ID.
        message: Client message; reads ``id``, ``session`` and ``data``
            (``user_id``, ``user_name``, ``platform``, ``person_id``,
            ``group_name``, ``group_id``, ``restore``).
    """
    request_id = cast(Optional[str], message.get("id"))
    client_session_id = str(message.get("session") or "").strip()
    if not client_session_id:
        await websocket_manager.send_response(
            connection_id,
            request_id=request_id,
            ok=False,
            error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
        )
        return

    params = _get_request_data(message)
    user_id = normalize_webui_user_id(cast(Optional[str], params.get("user_id")))
    user_name = str(params.get("user_name") or "WebUI用户")
    virtual_config = resolve_initial_virtual_identity(
        platform=cast(Optional[str], params.get("platform")),
        person_id=cast(Optional[str], params.get("person_id")),
        group_name=cast(Optional[str], params.get("group_name")),
        group_id=cast(Optional[str], params.get("group_id")),
    )
    is_restore = bool(params.get("restore"))
    server_session_id = f"{connection_id}:{client_session_id}"

    async def forward_to_client(chat_message: Dict[str, Any]) -> None:
        """Wrap one chat message as a ``chat`` event and send it to this connection."""
        await websocket_manager.send_event(
            connection_id,
            domain="chat",
            event=str(chat_message.get("type") or "message"),
            session=client_session_id,
            data=chat_message,
        )

    await chat_manager.connect(
        session_id=server_session_id,
        connection_id=connection_id,
        client_session_id=client_session_id,
        user_id=user_id,
        user_name=user_name,
        virtual_config=virtual_config,
        sender=forward_to_client,
    )
    websocket_manager.register_chat_session(connection_id, client_session_id, server_session_id)

    await websocket_manager.send_response(
        connection_id,
        request_id=request_id,
        ok=True,
        data={"session": client_session_id, "session_id": server_session_id},
    )
    # Restored sessions skip the welcome message.
    await send_initial_chat_state(
        session_id=server_session_id,
        user_id=user_id,
        user_name=user_name,
        virtual_config=virtual_config,
        include_welcome=not is_restore,
    )
|
|
|
|
|
|
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
    """Close a logical chat session previously opened on this connection.

    Args:
        connection_id: Connection ID.
        message: Client message; reads ``id`` and ``session``.
    """
    request_id = cast(Optional[str], message.get("id"))
    client_session_id = str(message.get("session") or "").strip()
    session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)

    if session_id is not None:
        chat_manager.disconnect(session_id)
        websocket_manager.unregister_chat_session(connection_id, client_session_id)
        await websocket_manager.send_response(
            connection_id,
            request_id=request_id,
            ok=True,
            data={"session": client_session_id},
        )
        return

    await websocket_manager.send_response(
        connection_id,
        request_id=request_id,
        ok=False,
        error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
    )
|
|
|
|
|
|
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
    """Dispatch one chat message and store the resulting session context.

    Does nothing if the client session id is no longer registered for the
    connection, or if ``chat_manager`` has no state for it.

    Args:
        connection_id: Connection ID.
        client_session_id: Client-side session ID.
        data: Chat event payload passed to ``dispatch_chat_event``.
    """
    session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
    session_state = chat_manager.get_session(session_id) if session_id is not None else None
    if session_id is None or session_state is None:
        return

    updated_user_name, updated_virtual_config = await dispatch_chat_event(
        session_id=session_id,
        session_id_prefix=session_id[:8],
        data=data,
        current_user_name=session_state.user_name,
        normalized_user_id=session_state.user_id,
        current_virtual_config=session_state.virtual_config,
    )
    chat_manager.update_session_context(
        session_id=session_id,
        user_name=updated_user_name,
        virtual_config=updated_virtual_config,
    )
|
|
|
|
|
|
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
    """Accept a chat message and process it in a background task.

    The request is acknowledged with ``accepted: True`` before processing
    starts. The actual dispatch runs in ``_process_chat_message``, scheduled
    with ``asyncio.create_task`` and registered via ``_track_background_task``.

    Args:
        connection_id: Connection ID.
        message: Client message; reads ``id``, ``session`` and ``data``
            (``content``, ``user_name``).
    """
    request_id = cast(Optional[str], message.get("id"))
    client_session_id = str(message.get("session") or "").strip()

    if websocket_manager.get_chat_session_id(connection_id, client_session_id) is None:
        await websocket_manager.send_response(
            connection_id,
            request_id=request_id,
            ok=False,
            error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
        )
        return

    request_data = _get_request_data(message)
    chat_payload = dict(
        type="message",
        content=request_data.get("content", ""),
        user_name=request_data.get("user_name", ""),
    )

    await websocket_manager.send_response(
        connection_id,
        request_id=request_id,
        ok=True,
        data={"accepted": True, "session": client_session_id},
    )
    background = asyncio.create_task(_process_chat_message(connection_id, client_session_id, chat_payload))
    _track_background_task(background)
|
|
|
|
|
|
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
    """Update the user's nickname for a chat session.

    The change is run through ``dispatch_chat_event`` as an ``update_nickname``
    event. The session context is then updated, and the request is answered
    with the resulting user name.

    Args:
        connection_id: Connection ID.
        message: Client message; reads ``id``, ``session`` and ``data.user_name``.
    """
    request_id = cast(Optional[str], message.get("id"))
    client_session_id = str(message.get("session") or "").strip()
    session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
    session_state = chat_manager.get_session(session_id) if session_id is not None else None

    # Unknown client id and missing chat_manager state are reported identically.
    if session_id is None or session_state is None:
        await websocket_manager.send_response(
            connection_id,
            request_id=request_id,
            ok=False,
            error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
        )
        return

    request_data = _get_request_data(message)
    nickname_event = {
        "type": "update_nickname",
        "user_name": request_data.get("user_name", ""),
    }
    updated_user_name, updated_virtual_config = await dispatch_chat_event(
        session_id=session_id,
        session_id_prefix=session_id[:8],
        data=nickname_event,
        current_user_name=session_state.user_name,
        normalized_user_id=session_state.user_id,
        current_virtual_config=session_state.virtual_config,
    )
    chat_manager.update_session_context(
        session_id=session_id,
        user_name=updated_user_name,
        virtual_config=updated_virtual_config,
    )

    await websocket_manager.send_response(
        connection_id,
        request_id=request_id,
        ok=True,
        data={"session": client_session_id, "user_name": updated_user_name},
    )
|
|
|
|
|
|
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
    """Dispatch a ``call`` in the ``chat`` domain by its ``method`` field.

    Supported methods: ``session.open``, ``session.close``, ``message.send``
    and ``session.update_nickname``. Any other method gets an
    ``unsupported_method`` error response.

    Args:
        connection_id: Connection ID.
        message: Client message; reads ``id`` and ``method``.
    """
    method = str(message.get("method") or "").strip()
    method_handlers = {
        "session.open": _open_chat_session,
        "session.close": _close_chat_session,
        "message.send": _handle_chat_message_send,
        "session.update_nickname": _handle_chat_nickname_update,
    }

    handler = method_handlers.get(method)
    if handler is None:
        await websocket_manager.send_response(
            connection_id,
            request_id=cast(Optional[str], message.get("id")),
            ok=False,
            error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
        )
        return

    await handler(connection_id, message)
|
|
|
|
|
|
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
    """Route a ``call`` operation by its ``domain`` field.

    Only the ``chat`` domain is supported. Anything else gets an
    ``unsupported_domain`` error response.

    Args:
        connection_id: Connection ID.
        message: Client message; reads ``id`` and ``domain``.
    """
    domain = str(message.get("domain") or "").strip()

    if domain != "chat":
        await websocket_manager.send_response(
            connection_id,
            request_id=cast(Optional[str], message.get("id")),
            ok=False,
            error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
        )
        return

    await _handle_chat_call(connection_id, message)
|
|
|
|
|
|
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
    """Dispatch one unified-WebSocket client message by its ``op`` field.

    ``ping`` is answered with a pong carrying the current ``time.time()``.
    ``subscribe``, ``unsubscribe`` and ``call`` go to their handlers. Any
    other operation gets an ``unsupported_operation`` error response.

    Args:
        connection_id: Connection ID.
        message: Client message; reads ``op`` and ``id``.
    """
    operation = str(message.get("op") or "").strip()

    if operation == "ping":
        await websocket_manager.send_pong(connection_id, time.time())
        return

    operation_handlers = {
        "subscribe": _handle_subscribe,
        "unsubscribe": _handle_unsubscribe,
        "call": _handle_call,
    }
    handler = operation_handlers.get(operation)
    if handler is not None:
        await handler(connection_id, message)
        return

    await websocket_manager.send_response(
        connection_id,
        request_id=cast(Optional[str], message.get("id")),
        ok=False,
        error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
    )
|
|
|
|
|
|
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
    """Unified WebSocket entry point.

    Authenticates the connection and registers it with ``websocket_manager``
    under a fresh ``uuid4`` hex id. It then sends a ``system:ready`` event and
    dispatches every incoming JSON object to ``handle_client_message`` until
    the client disconnects. Malformed frames and failing requests are answered
    with error responses instead of terminating the connection. Chat sessions
    and manager state for the connection are always cleaned up on exit.

    Args:
        websocket: FastAPI WebSocket object.
        token: Optional one-time handshake token (query parameter).
    """
    if not await authenticate_websocket_connection(websocket, token):
        logger.warning("统一 WebSocket 连接被拒绝:认证失败")
        # NOTE(review): this closes before the handshake is accepted. Per the ASGI
        # WebSocket spec the server then rejects with HTTP 403, so the client
        # may never observe close code 4001 or the reason text. Confirm that is
        # intended.
        await websocket.close(code=4001, reason="认证失败,请重新登录")
        return

    connection_id = uuid.uuid4().hex
    await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
    logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)

    try:
        # Inside the try so that a failed send still runs the cleanup below.
        await websocket_manager.send_event(
            connection_id,
            domain="system",
            event="ready",
            data={"connection_id": connection_id, "timestamp": time.time()},
        )

        while True:
            try:
                raw_message = await websocket.receive_json()
            except ValueError:
                # Undecodable JSON text (json.JSONDecodeError is a ValueError):
                # report it and keep the connection open.
                await websocket_manager.send_response(
                    connection_id,
                    request_id=None,
                    ok=False,
                    error=_build_error("invalid_message", "消息必须是 JSON 对象"),
                )
                continue

            if not isinstance(raw_message, dict):
                await websocket_manager.send_response(
                    connection_id,
                    request_id=None,
                    ok=False,
                    error=_build_error("invalid_message", "消息必须是 JSON 对象"),
                )
                continue

            client_message = cast(Dict[str, Any], raw_message)
            try:
                await handle_client_message(connection_id, client_message)
            except WebSocketDisconnect:
                raise
            except Exception as exc:
                # One failing request must not tear down the connection and every
                # chat session multiplexed on it; report the failure and carry on.
                logger.error(f"统一 WebSocket 请求处理失败: connection={connection_id}, error={exc}")
                await websocket_manager.send_response(
                    connection_id,
                    request_id=cast(Optional[str], client_message.get("id")),
                    ok=False,
                    error=_build_error("internal_error", "服务器处理请求时发生错误"),
                )
    except WebSocketDisconnect:
        logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
    except Exception as exc:
        logger.error(f"统一 WebSocket 处理失败: {exc}")
    finally:
        # Nested finally: the manager entry is released even if chat cleanup raises.
        try:
            chat_manager.disconnect_connection(connection_id)
        finally:
            await websocket_manager.disconnect(connection_id)
|