feat: add unified WebSocket connection manager and routing

- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending.
- Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management.
- Added support for logging and plugin progress subscriptions.
- Enhanced error handling and response structure for WebSocket operations.
This commit is contained in:
DrSmoothl
2026-04-02 22:08:52 +08:00
parent 7d0d429640
commit 1906890b67
28 changed files with 3845 additions and 1137 deletions

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.logs_ws")
router = APIRouter()
@@ -148,24 +149,9 @@ async def broadcast_log(log_data: Dict):
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
await websocket_manager.broadcast_to_topic(
domain="logs",
topic="main",
event="entry",
data={"entry": log_data},
)

View File

@@ -18,12 +18,10 @@ def get_all_routers() -> List[APIRouter]:
from src.webui.api.replier import router as replier_router
from src.webui.routers.chat import router as chat_router
from src.webui.routers.knowledge import router as knowledge_router
from src.webui.routers.websocket.logs import router as logs_router
from src.webui.routes import router as main_router
return [
main_router,
logs_router,
knowledge_router,
chat_router,
planner_router,

View File

@@ -1,7 +1,7 @@
from typing import Tuple
from .routes import router
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
from .service import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:

View File

@@ -1,9 +1,8 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
import uuid
from typing import Dict, Optional
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Query
from sqlalchemy import case, func
from sqlmodel import col, select
@@ -13,16 +12,11 @@ from src.common.logger import get_logger
from src.config.config import global_config
from src.webui.dependencies import require_auth
from .support import (
from .service import (
WEBUI_CHAT_GROUP_ID,
WEBUI_CHAT_PLATFORM,
authenticate_chat_websocket,
chat_history,
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
logger = get_logger("webui.chat")
@@ -113,55 +107,6 @@ async def clear_chat_history(
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
@router.websocket("/ws")
async def websocket_chat(
websocket: WebSocket,
user_id: Optional[str] = Query(default=None),
user_name: Optional[str] = Query(default="WebUI用户"),
platform: Optional[str] = Query(default=None),
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
token: Optional[str] = Query(default=None),
) -> None:
"""WebSocket 聊天端点。"""
if not await authenticate_chat_websocket(websocket, token):
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
session_id = str(uuid.uuid4())
normalized_user_id = normalize_webui_user_id(user_id)
current_user_name = user_name or "WebUI用户"
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
await chat_manager.connect(websocket, session_id, normalized_user_id)
try:
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
)
while True:
data = await websocket.receive_json()
current_user_name, current_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=current_user_name,
normalized_user_id=normalized_user_id,
current_virtual_config=current_virtual_config,
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
except Exception as e:
logger.error(f"WebSocket 错误: {e}")
finally:
chat_manager.disconnect(session_id, normalized_user_id)
@router.get("/info")
async def get_chat_info() -> Dict[str, object]:
"""获取聊天室信息。"""

View File

@@ -1,10 +1,10 @@
"""WebUI 聊天路由支持逻辑"""
"""WebUI 聊天运行时服务"""
from dataclasses import dataclass
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
from fastapi import WebSocket
from pydantic import BaseModel
from sqlmodel import col, delete, select
@@ -17,8 +17,6 @@ from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
@@ -27,6 +25,8 @@ WEBUI_CHAT_PLATFORM = "webui"
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
WEBUI_USER_ID_PREFIX = "webui_user_"
AsyncMessageSender = Callable[[Dict[str, Any]], Awaitable[None]]
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置。"""
@@ -52,13 +52,42 @@ class ChatHistoryMessage(BaseModel):
is_bot: bool = False
@dataclass
class ChatSessionConnection:
"""逻辑聊天会话连接信息。"""
session_id: str
connection_id: str
client_session_id: str
user_id: str
user_name: str
active_group_id: str
virtual_config: Optional[VirtualIdentityConfig]
sender: AsyncMessageSender
class ChatHistoryManager:
"""聊天历史管理器。"""
def __init__(self, max_messages: int = 200) -> None:
"""初始化聊天历史管理器。
Args:
max_messages: 内存中允许处理的最大消息数
"""
self.max_messages = max_messages
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将内部消息对象转换为前端可消费的字典。
Args:
msg: 内部统一消息对象
group_id: 当前会话所属的群组标识
Returns:
Dict[str, Any]: 面向 WebUI 的消息字典
"""
del group_id
user_info = msg.message_info.user_info
user_id = user_info.user_id or ""
is_bot = is_bot_self(msg.platform, user_id)
@@ -74,10 +103,27 @@ class ChatHistoryManager:
}
def _resolve_session_id(self, group_id: Optional[str]) -> str:
"""根据群组标识解析聊天会话 ID。
Args:
group_id: 群组标识
Returns:
str: 内部聊天会话 ID
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取指定会话的历史消息。
Args:
limit: 最大返回条数
group_id: 群组标识
Returns:
List[Dict[str, Any]]: 历史消息列表
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -90,11 +136,19 @@ class ChatHistoryManager:
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
except Exception as exc:
logger.error(f"从数据库加载聊天记录失败: {exc}")
return []
def clear_history(self, group_id: Optional[str] = None) -> int:
"""清空指定会话的历史消息。
Args:
group_id: 群组标识
Returns:
int: 被删除的消息数量
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -104,66 +158,245 @@ class ChatHistoryManager:
deleted = result.rowcount or 0
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
logger.error(f"清空聊天记录失败: {e}")
except Exception as exc:
logger.error(f"清空聊天记录失败: {exc}")
return 0
class ChatConnectionManager:
"""聊天连接管理器。"""
"""统一聊天逻辑会话管理器。"""
def __init__(self) -> None:
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {}
"""初始化聊天逻辑会话管理器。"""
self.active_connections: Dict[str, ChatSessionConnection] = {}
self.client_sessions: Dict[Tuple[str, str], str] = {}
self.connection_sessions: Dict[str, Set[str]] = {}
self.group_sessions: Dict[str, Set[str]] = {}
self.user_sessions: Dict[str, Set[str]] = {}
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def _bind_group(self, session_id: str, group_id: str) -> None:
"""为会话绑定群组索引。
def disconnect(self, session_id: str, user_id: str) -> None:
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.setdefault(group_id, set())
group_session_ids.add(session_id)
def _unbind_group(self, session_id: str, group_id: str) -> None:
"""移除会话与群组的索引关系。
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.get(group_id)
if group_session_ids is None:
return
group_session_ids.discard(session_id)
if not group_session_ids:
del self.group_sessions[group_id]
async def connect(
self,
session_id: str,
connection_id: str,
client_session_id: str,
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
sender: AsyncMessageSender,
) -> None:
"""注册一个新的逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
user_id: 规范化后的用户 ID
user_name: 当前展示昵称
virtual_config: 当前虚拟身份配置
sender: 发送消息到前端的异步回调
"""
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
if existing_session_id is not None:
self.disconnect(existing_session_id)
active_group_id = get_current_group_id(virtual_config)
session_connection = ChatSessionConnection(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=user_id,
user_name=user_name,
active_group_id=active_group_id,
virtual_config=virtual_config,
sender=sender,
)
self.active_connections[session_id] = session_connection
self.client_sessions[(connection_id, client_session_id)] = session_id
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
self.user_sessions.setdefault(user_id, set()).add(session_id)
self._bind_group(session_id, active_group_id)
logger.info(
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
session_id,
connection_id,
client_session_id,
user_id,
active_group_id,
)
def disconnect(self, session_id: str) -> None:
"""断开一个逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
"""
session_connection = self.active_connections.pop(session_id, None)
if session_connection is None:
return
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
self._unbind_group(session_id, session_connection.active_group_id)
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
if connection_session_ids is not None:
connection_session_ids.discard(session_id)
if not connection_session_ids:
del self.connection_sessions[session_connection.connection_id]
user_session_ids = self.user_sessions.get(session_connection.user_id)
if user_session_ids is not None:
user_session_ids.discard(session_id)
if not user_session_ids:
del self.user_sessions[session_connection.user_id]
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
def disconnect_connection(self, connection_id: str) -> None:
"""断开物理连接下的全部逻辑聊天会话。
Args:
connection_id: 物理 WebSocket 连接 ID
"""
session_ids = list(self.connection_sessions.get(connection_id, set()))
for session_id in session_ids:
self.disconnect(session_id)
def get_session(self, session_id: str) -> Optional[ChatSessionConnection]:
"""获取逻辑聊天会话信息。
Args:
session_id: 内部逻辑会话 ID
Returns:
Optional[ChatSessionConnection]: 会话存在时返回对应信息
"""
return self.active_connections.get(session_id)
def get_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""根据连接 ID 和前端会话 ID 查询内部会话 ID。
Args:
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
Returns:
Optional[str]: 找到时返回内部会话 ID
"""
return self.client_sessions.get((connection_id, client_session_id))
def update_session_context(
self,
session_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> None:
"""更新会话上下文信息。
Args:
session_id: 内部逻辑会话 ID
user_name: 最新昵称
virtual_config: 最新虚拟身份配置
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
next_group_id = get_current_group_id(virtual_config)
if next_group_id != session_connection.active_group_id:
self._unbind_group(session_id, session_connection.active_group_id)
self._bind_group(session_id, next_group_id)
session_connection.active_group_id = next_group_id
session_connection.user_name = user_name
session_connection.virtual_config = virtual_config
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
"""向指定逻辑会话发送消息。
Args:
session_id: 内部逻辑会话 ID
message: 发送消息内容
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
try:
await session_connection.sender(message)
except Exception as exc:
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
async def broadcast(self, message: Dict[str, Any]) -> None:
"""向全部逻辑聊天会话广播消息。
Args:
message: 待广播的消息内容
"""
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
"""向指定群组下的全部逻辑会话广播消息。
Args:
group_id: 群组标识
message: 待广播的消息内容
"""
for session_id in list(self.group_sessions.get(group_id, set())):
await self.send_message(session_id, message)
chat_history = ChatHistoryManager()
chat_manager = ChatConnectionManager()
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
"""判断当前是否启用了虚拟身份模式。
Args:
virtual_config: 虚拟身份配置
Returns:
bool: 已启用时返回 ``True``
"""
return bool(virtual_config and virtual_config.enabled)
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
if token and verify_ws_token(token):
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
return True
if cookie_token := websocket.cookies.get("maibot_session"):
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
return True
return False
def normalize_webui_user_id(user_id: Optional[str]) -> str:
"""标准化 WebUI 用户 ID。
Args:
user_id: 原始用户 ID
Returns:
str: 带统一前缀的用户 ID
"""
if not user_id:
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
if user_id.startswith(WEBUI_USER_ID_PREFIX):
@@ -172,12 +405,30 @@ def normalize_webui_user_id(user_id: Optional[str]) -> str:
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
"""根据人物 ID 查询人物信息。
Args:
person_id: 人物 ID
Returns:
Optional[PersonInfo]: 查到时返回人物信息
"""
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
return session.exec(statement).first()
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
"""根据人物信息构建虚拟身份配置。
Args:
person: 人物信息对象
group_id: 逻辑群组 ID
group_name: 逻辑群组名称
Returns:
VirtualIdentityConfig: 虚拟身份配置对象
"""
return VirtualIdentityConfig(
enabled=True,
platform=person.platform,
@@ -195,6 +446,17 @@ def resolve_initial_virtual_identity(
group_name: Optional[str],
group_id: Optional[str],
) -> Optional[VirtualIdentityConfig]:
"""根据初始参数解析虚拟身份配置。
Args:
platform: 平台名称
person_id: 人物 ID
group_name: 群组名称
group_id: 群组 ID
Returns:
Optional[VirtualIdentityConfig]: 解析成功时返回虚拟身份配置
"""
if not (platform and person_id):
return None
@@ -210,11 +472,14 @@ def resolve_initial_virtual_identity(
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
virtual_config.user_nickname,
virtual_config.platform,
virtual_group_id,
)
return virtual_config
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
except Exception as exc:
logger.warning(f"通过参数配置虚拟身份失败: {exc}")
return None
@@ -224,6 +489,17 @@ def build_session_info_message(
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Dict[str, Any]:
"""构建会话信息消息。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 会话信息消息
"""
session_info_data: Dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
@@ -247,13 +523,41 @@ def build_session_info_message(
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
"""获取当前虚拟身份对应的历史群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
Optional[str]: 虚拟身份启用时返回对应群组 ID
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.group_id
return None
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""获取当前会话的有效群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 当前会话应使用的群组 ID
"""
return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""构建欢迎消息。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 欢迎消息文本
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return (
@@ -264,6 +568,12 @@ def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> st
async def send_chat_error(session_id: str, content: str) -> None:
"""向指定会话发送错误消息。
Args:
session_id: 内部逻辑会话 ID
content: 错误消息内容
"""
await chat_manager.send_message(
session_id,
{
@@ -279,7 +589,17 @@ async def send_initial_chat_state(
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
include_welcome: bool = True,
) -> None:
"""向新会话发送初始化状态。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
include_welcome: 是否发送欢迎消息
"""
await chat_manager.send_message(
session_id,
build_session_info_message(
@@ -290,30 +610,43 @@ async def send_initial_chat_state(
),
)
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
history_group_id = get_active_history_group_id(virtual_config)
history = chat_history.get_history(50, history_group_id)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
"type": "history",
"messages": history,
"group_id": get_current_group_id(virtual_config),
},
)
if include_welcome:
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
},
)
def resolve_sender_identity(
current_user_name: str,
normalized_user_id: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, str]:
"""解析当前发送者身份。
Args:
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
virtual_config: 虚拟身份配置
Returns:
Tuple[str, str]: ``(发送者昵称, 发送者用户 ID)``
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
@@ -328,6 +661,19 @@ def create_message_data(
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> Dict[str, Any]:
"""构建发送给聊天核心的消息数据。
Args:
content: 文本内容
user_id: 用户 ID
user_name: 用户昵称
message_id: 消息 ID
is_at_bot: 是否默认艾特机器人
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 聊天核心可处理的消息数据
"""
if message_id is None:
message_id = str(uuid.uuid4())
@@ -389,6 +735,18 @@ async def handle_chat_message(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> str:
"""处理用户发送的聊天消息。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的消息数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
str: 处理后的最新昵称
"""
content = str(data.get("content", "")).strip()
if not content:
return current_user_name
@@ -401,11 +759,14 @@ async def handle_chat_message(
normalized_user_id=normalized_user_id,
virtual_config=current_virtual_config,
)
target_group_id = get_current_group_id(current_virtual_config)
await chat_manager.broadcast(
await chat_manager.broadcast_to_group(
target_group_id,
{
"type": "user_message",
"content": content,
"group_id": target_group_id,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
@@ -414,7 +775,7 @@ async def handle_chat_message(
"is_bot": False,
},
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
}
},
)
message_data = create_message_data(
@@ -427,22 +788,37 @@ async def handle_chat_message(
)
try:
await chat_manager.broadcast({"type": "typing", "is_typing": True})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
except Exception as exc:
logger.error(f"处理消息时出错: {exc}")
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
finally:
await chat_manager.broadcast({"type": "typing", "is_typing": False})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
return next_user_name
async def handle_chat_ping(session_id: str) -> None:
"""处理聊天心跳。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
"""处理昵称更新请求。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的数据
current_user_name: 当前昵称
Returns:
str: 更新后的昵称
"""
new_name = str(data.get("user_name", "")).strip()
if not new_name:
return current_user_name
@@ -463,6 +839,16 @@ async def enable_virtual_identity(
session_prefix: str,
virtual_data: Dict[str, Any],
) -> Optional[VirtualIdentityConfig]:
"""启用虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
session_prefix: 会话前缀用于生成默认群组 ID
virtual_data: 前端提交的虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 启用成功时返回新的虚拟身份配置
"""
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
return None
@@ -470,16 +856,18 @@ async def enable_virtual_identity(
person_id_value = str(virtual_data.get("person_id"))
try:
person = get_person_by_person_id(person_id_value)
if not person:
if person is None:
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
return None
custom_group_id = virtual_data.get("group_id")
current_group_id = (
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
if custom_group_id
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
)
custom_group_id = str(virtual_data.get("group_id") or "").strip()
if custom_group_id:
current_group_id = custom_group_id
if not current_group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{current_group_id}"
else:
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
current_virtual_config = build_virtual_identity_config(
person=person,
group_id=current_group_id,
@@ -521,13 +909,18 @@ async def enable_virtual_identity(
},
)
return current_virtual_config
except Exception as e:
logger.error(f"设置虚拟身份失败: {e}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
except Exception as exc:
logger.error(f"设置虚拟身份失败: {exc}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(exc)}")
return None
async def disable_virtual_identity(session_id: str) -> None:
"""关闭虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(
session_id,
{
@@ -560,7 +953,18 @@ async def handle_virtual_identity_update(
data: Dict[str, Any],
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Optional[VirtualIdentityConfig]:
virtual_data = cast(dict[str, Any], data.get("config", {}))
"""处理虚拟身份切换请求。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_virtual_config: 当前虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置
"""
virtual_data = cast(Dict[str, Any], data.get("config", {}))
if virtual_data.get("enabled"):
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
return next_config if next_config is not None else current_virtual_config
@@ -577,6 +981,19 @@ async def dispatch_chat_event(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
"""分发聊天事件到对应的处理器。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
Tuple[str, Optional[VirtualIdentityConfig]]: ``(最新昵称, 最新虚拟身份配置)``
"""
event_type = data.get("type")
if event_type == "message":
next_user_name = await handle_chat_message(

View File

@@ -1,12 +1,15 @@
"""插件进度实时推送支持。"""
from typing import Any, Dict, Optional, Set
import asyncio
import json
from typing import Any, Dict, Optional, Set
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.plugin_progress")
@@ -25,25 +28,29 @@ current_progress: Dict[str, Any] = {
}
def get_current_progress() -> Dict[str, Any]:
"""获取当前插件进度快照。
Returns:
Dict[str, Any]: 当前插件进度数据副本。
"""
return current_progress.copy()
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
"""向统一连接层广播插件进度更新。
Args:
progress_data: 插件进度数据。
"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected: Set[WebSocket] = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
for websocket in disconnected:
active_connections.discard(websocket)
await websocket_manager.broadcast_to_topic(
domain="plugin_progress",
topic="main",
event="update",
data={"progress": progress_data},
)
async def update_progress(
@@ -56,6 +63,18 @@ async def update_progress(
total_plugins: int = 0,
loaded_plugins: int = 0,
) -> None:
"""更新当前插件进度并广播。
Args:
stage: 当前阶段。
progress: 当前进度百分比。
message: 进度说明消息。
operation: 当前操作类型。
error: 可选的错误信息。
plugin_id: 当前处理的插件 ID。
total_plugins: 总插件数量。
loaded_plugins: 已处理插件数量。
"""
progress_data = {
"operation": operation,
"stage": stage,
@@ -74,6 +93,12 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""旧版插件进度 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
is_authenticated = False
if token and verify_ws_token(token):
@@ -105,17 +130,22 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
data = await websocket.receive_text()
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
except Exception as exc:
logger.error(f"处理客户端消息时出错: {exc}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
except Exception as exc:
logger.error(f"❌ WebSocket 错误: {exc}")
active_connections.discard(websocket)
def get_progress_router() -> APIRouter:
"""获取旧版插件进度路由对象。
Returns:
APIRouter: 插件进度路由对象。
"""
return router

View File

@@ -1,7 +1,9 @@
"""WebSocket 路由聚合导出。"""
from .auth import router as ws_auth_router
from .logs import router as logs_router
from .unified import router as unified_ws_router
__all__ = [
"logs_router",
"unified_ws_router",
"ws_auth_router",
]

View File

@@ -1,11 +0,0 @@
"""WebSocket 日志推送路由兼容导出。"""
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
__all__ = [
"active_connections",
"broadcast_log",
"load_recent_logs",
"router",
"websocket_logs",
]

View File

@@ -0,0 +1,297 @@
"""统一 WebSocket 连接管理器。"""
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set
from fastapi import WebSocket
from src.common.logger import get_logger
logger = get_logger("webui.websocket")
@dataclass
class WebSocketConnection:
"""统一 WebSocket 连接上下文。"""
connection_id: str
websocket: WebSocket
subscriptions: Set[str] = field(default_factory=set)
chat_sessions: Dict[str, str] = field(default_factory=dict)
send_queue: "asyncio.Queue[Optional[Dict[str, Any]]]" = field(default_factory=asyncio.Queue)
sender_task: Optional["asyncio.Task[None]"] = None
class UnifiedWebSocketManager:
"""统一 WebSocket 连接管理器。"""
def __init__(self) -> None:
"""初始化统一 WebSocket 连接管理器。"""
self.connections: Dict[str, WebSocketConnection] = {}
def _build_subscription_key(self, domain: str, topic: str) -> str:
"""构建订阅索引键。
Args:
domain: 业务域名称。
topic: 主题名称。
Returns:
str: 订阅索引键。
"""
return f"{domain}:{topic}"
async def _sender_loop(self, connection: WebSocketConnection) -> None:
"""串行发送指定连接的出站消息。
Args:
connection: 目标连接上下文。
"""
try:
while True:
message = await connection.send_queue.get()
if message is None:
return
await connection.websocket.send_json(message)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.error("统一 WebSocket 发送失败: connection=%s, error=%s", connection.connection_id, exc)
async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
"""注册一个新的物理 WebSocket 连接。
Args:
connection_id: 连接 ID。
websocket: FastAPI WebSocket 对象。
Returns:
WebSocketConnection: 新建的连接上下文。
"""
await websocket.accept()
connection = WebSocketConnection(connection_id=connection_id, websocket=websocket)
connection.sender_task = asyncio.create_task(self._sender_loop(connection))
self.connections[connection_id] = connection
return connection
async def disconnect(self, connection_id: str) -> None:
"""断开并清理指定连接。
Args:
connection_id: 连接 ID。
"""
connection = self.connections.pop(connection_id, None)
if connection is None:
return
await connection.send_queue.put(None)
if connection.sender_task is not None:
try:
await connection.sender_task
except asyncio.CancelledError:
pass
except Exception as exc:
logger.debug("等待发送协程退出时出现异常: connection=%s, error=%s", connection_id, exc)
def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
"""获取指定连接上下文。
Args:
connection_id: 连接 ID。
Returns:
Optional[WebSocketConnection]: 找到时返回连接上下文。
"""
return self.connections.get(connection_id)
def register_chat_session(self, connection_id: str, client_session_id: str, session_id: str) -> None:
"""登记连接下的逻辑聊天会话。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
session_id: 内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions[client_session_id] = session_id
def unregister_chat_session(self, connection_id: str, client_session_id: str) -> None:
"""移除连接下的逻辑聊天会话登记。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions.pop(client_session_id, None)
def get_chat_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""查询连接下的内部聊天会话 ID。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
Returns:
Optional[str]: 找到时返回内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return None
return connection.chat_sessions.get(client_session_id)
def subscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""登记连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.add(self._build_subscription_key(domain, topic))
def unsubscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""移除连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.discard(self._build_subscription_key(domain, topic))
def is_subscribed(self, connection_id: str, domain: str, topic: str) -> bool:
"""判断连接是否订阅了指定主题。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
Returns:
bool: 已订阅时返回 ``True``。
"""
connection = self.connections.get(connection_id)
if connection is None:
return False
return self._build_subscription_key(domain, topic) in connection.subscriptions
async def enqueue(self, connection_id: str, message: Dict[str, Any]) -> None:
"""向指定连接的发送队列压入消息。
Args:
connection_id: 连接 ID。
message: 待发送的消息。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
await connection.send_queue.put(message)
async def send_response(
self,
connection_id: str,
request_id: Optional[str],
ok: bool,
data: Optional[Dict[str, Any]] = None,
error: Optional[Dict[str, Any]] = None,
) -> None:
"""发送统一响应消息。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
ok: 请求是否成功。
data: 成功响应数据。
error: 失败响应数据。
"""
response_message: Dict[str, Any] = {
"op": "response",
"id": request_id,
"ok": ok,
}
if data is not None:
response_message["data"] = data
if error is not None:
response_message["error"] = error
await self.enqueue(connection_id, response_message)
async def send_event(
self,
connection_id: str,
domain: str,
event: str,
data: Dict[str, Any],
session: Optional[str] = None,
topic: Optional[str] = None,
) -> None:
"""发送统一事件消息。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
event: 事件名称。
data: 事件数据。
session: 可选的逻辑会话 ID。
topic: 可选的主题名称。
"""
event_message: Dict[str, Any] = {
"op": "event",
"domain": domain,
"event": event,
"data": data,
}
if session is not None:
event_message["session"] = session
if topic is not None:
event_message["topic"] = topic
await self.enqueue(connection_id, event_message)
async def send_pong(self, connection_id: str, timestamp: float) -> None:
"""发送心跳响应。
Args:
connection_id: 连接 ID。
timestamp: 当前时间戳。
"""
await self.enqueue(
connection_id,
{
"op": "pong",
"ts": timestamp,
},
)
async def broadcast_to_topic(self, domain: str, topic: str, event: str, data: Dict[str, Any]) -> None:
"""向订阅指定主题的全部连接广播事件。
Args:
domain: 业务域名称。
topic: 主题名称。
event: 事件名称。
data: 事件数据。
"""
subscription_key = self._build_subscription_key(domain, topic)
for connection in list(self.connections.values()):
if subscription_key in connection.subscriptions:
await self.send_event(
connection.connection_id,
domain=domain,
event=event,
data=data,
topic=topic,
)
websocket_manager = UnifiedWebSocketManager()

View File

@@ -0,0 +1,548 @@
"""统一 WebSocket 路由。"""
from typing import Any, Dict, Optional, Set, cast
import asyncio
import time
import uuid
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.logs_ws import load_recent_logs
from src.webui.routers.chat.service import (
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
from src.webui.routers.plugin.progress import get_current_progress
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.unified_ws")
router = APIRouter()
_background_tasks: Set["asyncio.Task[None]"] = set()
def _build_error(code: str, message: str) -> Dict[str, Any]:
"""构建统一错误响应体。
Args:
code: 错误码。
message: 错误描述。
Returns:
Dict[str, Any]: 统一错误对象。
"""
return {
"code": code,
"message": message,
}
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
"""从客户端消息中提取数据字段。
Args:
message: 客户端消息。
Returns:
Dict[str, Any]: 标准化后的数据字典。
"""
data = message.get("data", {})
if isinstance(data, dict):
return cast(Dict[str, Any], data)
return {}
def _track_background_task(task: "asyncio.Task[None]") -> None:
"""登记后台任务并在完成后自动清理。
Args:
task: 后台协程任务。
"""
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
"""校验统一 WebSocket 连接的认证状态。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
Returns:
bool: 认证通过时返回 ``True``。
"""
if token and verify_ws_token(token):
logger.debug("统一 WebSocket 使用临时 token 认证成功")
return True
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("统一 WebSocket 使用 Cookie 认证成功")
return True
return False
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
"""处理日志域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
data: 订阅参数。
"""
replay_limit = int(data.get("replay", 100) or 100)
replay_limit = max(0, min(replay_limit, 500))
websocket_manager.subscribe(connection_id, domain="logs", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "logs", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="logs",
event="snapshot",
topic="main",
data={"entries": load_recent_logs(limit=replay_limit)},
)
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
"""处理插件进度域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
"""
websocket_manager.subscribe(connection_id, domain="plugin_progress", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "plugin_progress", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="plugin_progress",
event="snapshot",
topic="main",
data={"progress": get_current_progress()},
)
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题订阅请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
data = _get_request_data(message)
if domain == "logs" and topic == "main":
await _handle_logs_subscribe(connection_id, request_id, data)
return
if domain == "plugin_progress" and topic == "main":
await _handle_plugin_progress_subscribe(connection_id, request_id)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
)
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题退订请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
if not domain or not topic:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
)
return
websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": domain, "topic": topic},
)
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""打开一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
if not client_session_id:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
)
return
data = _get_request_data(message)
normalized_user_id = normalize_webui_user_id(cast(Optional[str], data.get("user_id")))
current_user_name = str(data.get("user_name") or "WebUI用户")
current_virtual_config = resolve_initial_virtual_identity(
platform=cast(Optional[str], data.get("platform")),
person_id=cast(Optional[str], data.get("person_id")),
group_name=cast(Optional[str], data.get("group_name")),
group_id=cast(Optional[str], data.get("group_id")),
)
restore = bool(data.get("restore"))
session_id = f"{connection_id}:{client_session_id}"
async def send_chat_event(chat_message: Dict[str, Any]) -> None:
"""将聊天消息封装为统一事件并发送。
Args:
chat_message: 聊天消息体。
"""
event_name = str(chat_message.get("type") or "message")
await websocket_manager.send_event(
connection_id,
domain="chat",
event=event_name,
session=client_session_id,
data=chat_message,
)
await chat_manager.connect(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
sender=send_chat_event,
)
websocket_manager.register_chat_session(connection_id, client_session_id, session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "session_id": session_id},
)
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
include_welcome=not restore,
)
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""关闭一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
chat_manager.disconnect(session_id)
websocket_manager.unregister_chat_session(connection_id, client_session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id},
)
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
"""在后台处理聊天消息事件。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
data: 客户端提交的消息数据。
"""
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
return
session_state = chat_manager.get_session(session_id)
if session_state is None:
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天消息发送请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
payload = {
"type": "message",
"content": data.get("content", ""),
"user_name": data.get("user_name", ""),
}
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"accepted": True, "session": client_session_id},
)
_track_background_task(asyncio.create_task(_process_chat_message(connection_id, client_session_id, payload)))
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天昵称更新请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
session_state = chat_manager.get_session(session_id)
if session_state is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data={
"type": "update_nickname",
"user_name": data.get("user_name", ""),
},
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "user_name": next_user_name},
)
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天域调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
method = str(message.get("method") or "").strip()
if method == "session.open":
await _open_chat_session(connection_id, message)
return
if method == "session.close":
await _close_chat_session(connection_id, message)
return
if method == "message.send":
await _handle_chat_message_send(connection_id, message)
return
if method == "session.update_nickname":
await _handle_chat_nickname_update(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
)
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
if domain == "chat":
await _handle_chat_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
)
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一 WebSocket 客户端消息。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
operation = str(message.get("op") or "").strip()
request_id = cast(Optional[str], message.get("id"))
if operation == "ping":
await websocket_manager.send_pong(connection_id, time.time())
return
if operation == "subscribe":
await _handle_subscribe(connection_id, message)
return
if operation == "unsubscribe":
await _handle_unsubscribe(connection_id, message)
return
if operation == "call":
await _handle_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""统一 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
if not await authenticate_websocket_connection(websocket, token):
logger.warning("统一 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
connection_id = uuid.uuid4().hex
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
await websocket_manager.send_event(
connection_id,
domain="system",
event="ready",
data={"connection_id": connection_id, "timestamp": time.time()},
)
try:
while True:
raw_message = await websocket.receive_json()
if not isinstance(raw_message, dict):
await websocket_manager.send_response(
connection_id,
request_id=None,
ok=False,
error=_build_error("invalid_message", "消息必须是 JSON 对象"),
)
continue
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
except WebSocketDisconnect:
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
except Exception as exc:
logger.error(f"统一 WebSocket 处理失败: {exc}")
finally:
chat_manager.disconnect_connection(connection_id)
await websocket_manager.disconnect(connection_id)

View File

@@ -18,11 +18,11 @@ from src.webui.routers.expression import router as expression_router
from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.model import router as model_router
from src.webui.routers.person import router as person_router
from src.webui.routers.plugin import get_progress_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.system import router as system_router
from src.webui.routers.websocket.auth import router as ws_auth_router
from src.webui.routers.websocket.unified import router as unified_ws_router
logger = get_logger("webui.api")
@@ -43,14 +43,14 @@ router.include_router(jargon_router)
router.include_router(emoji_router)
# 注册插件管理路由
router.include_router(plugin_router)
# 注册插件进度 WebSocket 路由
router.include_router(get_progress_router())
# 注册系统控制路由
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
# 注册统一 WebSocket 路由
router.include_router(unified_ws_router)
class TokenVerifyRequest(BaseModel):