feat: add unified WebSocket connection manager and routing
- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending. - Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management. - Added support for logging and plugin progress subscriptions. - Enhanced error handling and response structure for WebSocket operations.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Tuple
|
||||
|
||||
from .routes import router
|
||||
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
from .service import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
|
||||
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -13,16 +12,11 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
from .support import (
|
||||
from .service import (
|
||||
WEBUI_CHAT_GROUP_ID,
|
||||
WEBUI_CHAT_PLATFORM,
|
||||
authenticate_chat_websocket,
|
||||
chat_history,
|
||||
chat_manager,
|
||||
dispatch_chat_event,
|
||||
normalize_webui_user_id,
|
||||
resolve_initial_virtual_identity,
|
||||
send_initial_chat_state,
|
||||
)
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
@@ -113,55 +107,6 @@ async def clear_chat_history(
|
||||
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
token: Optional[str] = Query(default=None),
|
||||
) -> None:
|
||||
"""WebSocket 聊天端点。"""
|
||||
if not await authenticate_chat_websocket(websocket, token):
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
normalized_user_id = normalize_webui_user_id(user_id)
|
||||
current_user_name = user_name or "WebUI用户"
|
||||
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
|
||||
|
||||
await chat_manager.connect(websocket, session_id, normalized_user_id)
|
||||
try:
|
||||
await send_initial_chat_state(
|
||||
session_id=session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
current_user_name, current_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data=data,
|
||||
current_user_name=current_user_name,
|
||||
normalized_user_id=normalized_user_id,
|
||||
current_virtual_config=current_virtual_config,
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, normalized_user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info() -> Dict[str, object]:
|
||||
"""获取聊天室信息。"""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""WebUI 聊天路由支持逻辑。"""
|
||||
"""WebUI 聊天运行时服务。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, delete, select
|
||||
|
||||
@@ -17,8 +17,6 @@ from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
@@ -27,6 +25,8 @@ WEBUI_CHAT_PLATFORM = "webui"
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
AsyncMessageSender = Callable[[Dict[str, Any]], Awaitable[None]]
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置。"""
|
||||
@@ -52,13 +52,42 @@ class ChatHistoryMessage(BaseModel):
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSessionConnection:
|
||||
"""逻辑聊天会话连接信息。"""
|
||||
|
||||
session_id: str
|
||||
connection_id: str
|
||||
client_session_id: str
|
||||
user_id: str
|
||||
user_name: str
|
||||
active_group_id: str
|
||||
virtual_config: Optional[VirtualIdentityConfig]
|
||||
sender: AsyncMessageSender
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器。"""
|
||||
|
||||
def __init__(self, max_messages: int = 200) -> None:
|
||||
"""初始化聊天历史管理器。
|
||||
|
||||
Args:
|
||||
max_messages: 内存中允许处理的最大消息数。
|
||||
"""
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将内部消息对象转换为前端可消费的字典。
|
||||
|
||||
Args:
|
||||
msg: 内部统一消息对象。
|
||||
group_id: 当前会话所属的群组标识。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 面向 WebUI 的消息字典。
|
||||
"""
|
||||
del group_id
|
||||
user_info = msg.message_info.user_info
|
||||
user_id = user_info.user_id or ""
|
||||
is_bot = is_bot_self(msg.platform, user_id)
|
||||
@@ -74,10 +103,27 @@ class ChatHistoryManager:
|
||||
}
|
||||
|
||||
def _resolve_session_id(self, group_id: Optional[str]) -> str:
|
||||
"""根据群组标识解析聊天会话 ID。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
str: 内部聊天会话 ID。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""获取指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
limit: 最大返回条数。
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 历史消息列表。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -90,11 +136,19 @@ class ChatHistoryManager:
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"从数据库加载聊天记录失败: {exc}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
"""清空指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
int: 被删除的消息数量。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -104,66 +158,245 @@ class ChatHistoryManager:
|
||||
deleted = result.rowcount or 0
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"清空聊天记录失败: {exc}")
|
||||
return 0
|
||||
|
||||
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器。"""
|
||||
"""统一聊天逻辑会话管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {}
|
||||
"""初始化聊天逻辑会话管理器。"""
|
||||
self.active_connections: Dict[str, ChatSessionConnection] = {}
|
||||
self.client_sessions: Dict[Tuple[str, str], str] = {}
|
||||
self.connection_sessions: Dict[str, Set[str]] = {}
|
||||
self.group_sessions: Dict[str, Set[str]] = {}
|
||||
self.user_sessions: Dict[str, Set[str]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
def _bind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""为会话绑定群组索引。
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str) -> None:
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.setdefault(group_id, set())
|
||||
group_session_ids.add(session_id)
|
||||
|
||||
def _unbind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""移除会话与群组的索引关系。
|
||||
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.get(group_id)
|
||||
if group_session_ids is None:
|
||||
return
|
||||
|
||||
group_session_ids.discard(session_id)
|
||||
if not group_session_ids:
|
||||
del self.group_sessions[group_id]
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
session_id: str,
|
||||
connection_id: str,
|
||||
client_session_id: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
sender: AsyncMessageSender,
|
||||
) -> None:
|
||||
"""注册一个新的逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
client_session_id: 前端标签页使用的会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前展示昵称。
|
||||
virtual_config: 当前虚拟身份配置。
|
||||
sender: 发送消息到前端的异步回调。
|
||||
"""
|
||||
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
|
||||
if existing_session_id is not None:
|
||||
self.disconnect(existing_session_id)
|
||||
|
||||
active_group_id = get_current_group_id(virtual_config)
|
||||
session_connection = ChatSessionConnection(
|
||||
session_id=session_id,
|
||||
connection_id=connection_id,
|
||||
client_session_id=client_session_id,
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
active_group_id=active_group_id,
|
||||
virtual_config=virtual_config,
|
||||
sender=sender,
|
||||
)
|
||||
|
||||
self.active_connections[session_id] = session_connection
|
||||
self.client_sessions[(connection_id, client_session_id)] = session_id
|
||||
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
|
||||
self.user_sessions.setdefault(user_id, set()).add(session_id)
|
||||
self._bind_group(session_id, active_group_id)
|
||||
logger.info(
|
||||
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
|
||||
session_id,
|
||||
connection_id,
|
||||
client_session_id,
|
||||
user_id,
|
||||
active_group_id,
|
||||
)
|
||||
|
||||
def disconnect(self, session_id: str) -> None:
|
||||
"""断开一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
session_connection = self.active_connections.pop(session_id, None)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
|
||||
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
|
||||
if connection_session_ids is not None:
|
||||
connection_session_ids.discard(session_id)
|
||||
if not connection_session_ids:
|
||||
del self.connection_sessions[session_connection.connection_id]
|
||||
|
||||
user_session_ids = self.user_sessions.get(session_connection.user_id)
|
||||
if user_session_ids is not None:
|
||||
user_session_ids.discard(session_id)
|
||||
if not user_session_ids:
|
||||
del self.user_sessions[session_connection.user_id]
|
||||
|
||||
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
|
||||
|
||||
def disconnect_connection(self, connection_id: str) -> None:
|
||||
"""断开物理连接下的全部逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
"""
|
||||
session_ids = list(self.connection_sessions.get(connection_id, set()))
|
||||
for session_id in session_ids:
|
||||
self.disconnect(session_id)
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[ChatSessionConnection]:
|
||||
"""获取逻辑聊天会话信息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[ChatSessionConnection]: 会话存在时返回对应信息。
|
||||
"""
|
||||
return self.active_connections.get(session_id)
|
||||
|
||||
def get_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
|
||||
"""根据连接 ID 和前端会话 ID 查询内部会话 ID。
|
||||
|
||||
Args:
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
client_session_id: 前端标签页使用的会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 找到时返回内部会话 ID。
|
||||
"""
|
||||
return self.client_sessions.get((connection_id, client_session_id))
|
||||
|
||||
def update_session_context(
|
||||
self,
|
||||
session_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> None:
|
||||
"""更新会话上下文信息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_name: 最新昵称。
|
||||
virtual_config: 最新虚拟身份配置。
|
||||
"""
|
||||
session_connection = self.active_connections.get(session_id)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
next_group_id = get_current_group_id(virtual_config)
|
||||
if next_group_id != session_connection.active_group_id:
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
self._bind_group(session_id, next_group_id)
|
||||
session_connection.active_group_id = next_group_id
|
||||
|
||||
session_connection.user_name = user_name
|
||||
session_connection.virtual_config = virtual_config
|
||||
|
||||
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
"""向指定逻辑会话发送消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
message: 待发送的消息内容。
|
||||
"""
|
||||
session_connection = self.active_connections.get(session_id)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await session_connection.sender(message)
|
||||
except Exception as exc:
|
||||
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||
"""向全部逻辑聊天会话广播消息。
|
||||
|
||||
Args:
|
||||
message: 待广播的消息内容。
|
||||
"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
|
||||
"""向指定群组下的全部逻辑会话广播消息。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
message: 待广播的消息内容。
|
||||
"""
|
||||
for session_id in list(self.group_sessions.get(group_id, set())):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_history = ChatHistoryManager()
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
|
||||
"""判断当前是否启用了虚拟身份模式。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
bool: 已启用时返回 ``True``。
|
||||
"""
|
||||
return bool(virtual_config and virtual_config.enabled)
|
||||
|
||||
|
||||
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
|
||||
if token and verify_ws_token(token):
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
return True
|
||||
|
||||
if cookie_token := websocket.cookies.get("maibot_session"):
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
"""标准化 WebUI 用户 ID。
|
||||
|
||||
Args:
|
||||
user_id: 原始用户 ID。
|
||||
|
||||
Returns:
|
||||
str: 带统一前缀的用户 ID。
|
||||
"""
|
||||
if not user_id:
|
||||
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
if user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
@@ -172,12 +405,30 @@ def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
|
||||
|
||||
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
|
||||
"""根据人物 ID 查询人物信息。
|
||||
|
||||
Args:
|
||||
person_id: 人物 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PersonInfo]: 查到时返回人物信息。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
|
||||
"""根据人物信息构建虚拟身份配置。
|
||||
|
||||
Args:
|
||||
person: 人物信息对象。
|
||||
group_id: 逻辑群组 ID。
|
||||
group_name: 逻辑群组名称。
|
||||
|
||||
Returns:
|
||||
VirtualIdentityConfig: 虚拟身份配置对象。
|
||||
"""
|
||||
return VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
@@ -195,6 +446,17 @@ def resolve_initial_virtual_identity(
|
||||
group_name: Optional[str],
|
||||
group_id: Optional[str],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
"""根据初始参数解析虚拟身份配置。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
person_id: 人物 ID。
|
||||
group_name: 群组名称。
|
||||
group_id: 群组 ID。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 解析成功时返回虚拟身份配置。
|
||||
"""
|
||||
if not (platform and person_id):
|
||||
return None
|
||||
|
||||
@@ -210,11 +472,14 @@ def resolve_initial_virtual_identity(
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
|
||||
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
|
||||
virtual_config.user_nickname,
|
||||
virtual_config.platform,
|
||||
virtual_group_id,
|
||||
)
|
||||
return virtual_config
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"通过参数配置虚拟身份失败: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -224,6 +489,17 @@ def build_session_info_message(
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Dict[str, Any]:
|
||||
"""构建会话信息消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前昵称。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 会话信息消息。
|
||||
"""
|
||||
session_info_data: Dict[str, Any] = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
@@ -247,13 +523,41 @@ def build_session_info_message(
|
||||
|
||||
|
||||
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
|
||||
"""获取当前虚拟身份对应的历史群组 ID。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 虚拟身份启用时返回对应群组 ID。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.group_id
|
||||
return None
|
||||
|
||||
|
||||
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
"""获取当前会话的有效群组 ID。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 当前会话应使用的群组 ID。
|
||||
"""
|
||||
return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
|
||||
|
||||
|
||||
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
"""构建欢迎消息。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 欢迎消息文本。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return (
|
||||
@@ -264,6 +568,12 @@ def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> st
|
||||
|
||||
|
||||
async def send_chat_error(session_id: str, content: str) -> None:
|
||||
"""向指定会话发送错误消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
content: 错误消息内容。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
@@ -279,7 +589,17 @@ async def send_initial_chat_state(
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
include_welcome: bool = True,
|
||||
) -> None:
|
||||
"""向新会话发送初始化状态。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前昵称。
|
||||
virtual_config: 虚拟身份配置。
|
||||
include_welcome: 是否发送欢迎消息。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
build_session_info_message(
|
||||
@@ -290,30 +610,43 @@ async def send_initial_chat_state(
|
||||
),
|
||||
)
|
||||
|
||||
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
history_group_id = get_active_history_group_id(virtual_config)
|
||||
history = chat_history.get_history(50, history_group_id)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": build_welcome_message(virtual_config),
|
||||
"timestamp": time.time(),
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
"group_id": get_current_group_id(virtual_config),
|
||||
},
|
||||
)
|
||||
|
||||
if include_welcome:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": build_welcome_message(virtual_config),
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resolve_sender_identity(
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Tuple[str, str]:
|
||||
"""解析当前发送者身份。
|
||||
|
||||
Args:
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: ``(发送者昵称, 发送者用户 ID)``。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
|
||||
@@ -328,6 +661,19 @@ def create_message_data(
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建发送给聊天核心的消息数据。
|
||||
|
||||
Args:
|
||||
content: 文本内容。
|
||||
user_id: 用户 ID。
|
||||
user_name: 用户昵称。
|
||||
message_id: 消息 ID。
|
||||
is_at_bot: 是否默认艾特机器人。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 聊天核心可处理的消息数据。
|
||||
"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
@@ -389,6 +735,18 @@ async def handle_chat_message(
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> str:
|
||||
"""处理用户发送的聊天消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
data: 前端提交的消息数据。
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 处理后的最新昵称。
|
||||
"""
|
||||
content = str(data.get("content", "")).strip()
|
||||
if not content:
|
||||
return current_user_name
|
||||
@@ -401,11 +759,14 @@ async def handle_chat_message(
|
||||
normalized_user_id=normalized_user_id,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
target_group_id = get_current_group_id(current_virtual_config)
|
||||
|
||||
await chat_manager.broadcast(
|
||||
await chat_manager.broadcast_to_group(
|
||||
target_group_id,
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"group_id": target_group_id,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
@@ -414,7 +775,7 @@ async def handle_chat_message(
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
message_data = create_message_data(
|
||||
@@ -427,22 +788,37 @@ async def handle_chat_message(
|
||||
)
|
||||
|
||||
try:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": True})
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
|
||||
await chat_bot.message_process(message_data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
|
||||
except Exception as exc:
|
||||
logger.error(f"处理消息时出错: {exc}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
|
||||
finally:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": False})
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
|
||||
|
||||
return next_user_name
|
||||
|
||||
|
||||
async def handle_chat_ping(session_id: str) -> None:
|
||||
"""处理聊天心跳。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
|
||||
|
||||
|
||||
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
|
||||
"""处理昵称更新请求。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
data: 前端提交的数据。
|
||||
current_user_name: 当前昵称。
|
||||
|
||||
Returns:
|
||||
str: 更新后的昵称。
|
||||
"""
|
||||
new_name = str(data.get("user_name", "")).strip()
|
||||
if not new_name:
|
||||
return current_user_name
|
||||
@@ -463,6 +839,16 @@ async def enable_virtual_identity(
|
||||
session_prefix: str,
|
||||
virtual_data: Dict[str, Any],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
"""启用虚拟身份模式。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_prefix: 会话前缀,用于生成默认群组 ID。
|
||||
virtual_data: 前端提交的虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 启用成功时返回新的虚拟身份配置。
|
||||
"""
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
|
||||
return None
|
||||
@@ -470,16 +856,18 @@ async def enable_virtual_identity(
|
||||
person_id_value = str(virtual_data.get("person_id"))
|
||||
try:
|
||||
person = get_person_by_person_id(person_id_value)
|
||||
if not person:
|
||||
if person is None:
|
||||
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
|
||||
return None
|
||||
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
current_group_id = (
|
||||
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
if custom_group_id
|
||||
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
|
||||
)
|
||||
custom_group_id = str(virtual_data.get("group_id") or "").strip()
|
||||
if custom_group_id:
|
||||
current_group_id = custom_group_id
|
||||
if not current_group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{current_group_id}"
|
||||
else:
|
||||
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
|
||||
|
||||
current_virtual_config = build_virtual_identity_config(
|
||||
person=person,
|
||||
group_id=current_group_id,
|
||||
@@ -521,13 +909,18 @@ async def enable_virtual_identity(
|
||||
},
|
||||
)
|
||||
return current_virtual_config
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
|
||||
except Exception as exc:
|
||||
logger.error(f"设置虚拟身份失败: {exc}")
|
||||
await send_chat_error(session_id, f"设置虚拟身份失败: {str(exc)}")
|
||||
return None
|
||||
|
||||
|
||||
async def disable_virtual_identity(session_id: str) -> None:
|
||||
"""关闭虚拟身份模式。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
@@ -560,7 +953,18 @@ async def handle_virtual_identity_update(
|
||||
data: Dict[str, Any],
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
virtual_data = cast(dict[str, Any], data.get("config", {}))
|
||||
"""处理虚拟身份切换请求。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_id_prefix: 会话前缀。
|
||||
data: 前端提交的数据。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置。
|
||||
"""
|
||||
virtual_data = cast(Dict[str, Any], data.get("config", {}))
|
||||
if virtual_data.get("enabled"):
|
||||
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
|
||||
return next_config if next_config is not None else current_virtual_config
|
||||
@@ -577,6 +981,19 @@ async def dispatch_chat_event(
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
|
||||
"""分发聊天事件到对应的处理器。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_id_prefix: 会话前缀。
|
||||
data: 前端提交的数据。
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Tuple[str, Optional[VirtualIdentityConfig]]: ``(最新昵称, 最新虚拟身份配置)``。
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
if event_type == "message":
|
||||
next_user_name = await handle_chat_message(
|
||||
Reference in New Issue
Block a user