WebUI 后端类型注解补全,使用全 typing 库类型注解
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
from typing import Tuple
|
||||
|
||||
from .routes import router
|
||||
from .support import ChatConnectionManager, WEBUI_CHAT_PLATFORM, chat_manager
|
||||
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
|
||||
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:
|
||||
"""获取 WebUI 聊天广播器,供外部模块使用。"""
|
||||
return chat_manager, WEBUI_CHAT_PLATFORM
|
||||
|
||||
@@ -15,4 +15,4 @@ __all__ = [
|
||||
"chat_manager",
|
||||
"get_webui_chat_broadcaster",
|
||||
"router",
|
||||
]
|
||||
]
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
|
||||
|
||||
import uuid
|
||||
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from sqlalchemy import case, func
|
||||
@@ -36,7 +35,7 @@ async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
) -> Dict[str, object]:
|
||||
"""获取聊天历史记录。"""
|
||||
del user_id
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
@@ -45,7 +44,7 @@ async def get_chat_history(
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms() -> dict[str, object]:
|
||||
async def get_available_platforms() -> Dict[str, object]:
|
||||
"""获取可用平台列表。"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
@@ -68,7 +67,7 @@ async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
) -> dict[str, object]:
|
||||
) -> Dict[str, object]:
|
||||
"""获取指定平台的用户列表。"""
|
||||
try:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
|
||||
@@ -108,7 +107,7 @@ async def get_persons_by_platform(
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
) -> dict[str, object]:
|
||||
) -> Dict[str, object]:
|
||||
"""清空聊天历史记录。"""
|
||||
deleted = chat_history.clear_history(group_id)
|
||||
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
||||
@@ -164,11 +163,11 @@ async def websocket_chat(
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info() -> dict[str, object]:
|
||||
async def get_chat_info() -> Dict[str, object]:
|
||||
"""获取聊天室信息。"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
"platform": WEBUI_CHAT_PLATFORM,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
"active_sessions": len(chat_manager.active_connections),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""WebUI 聊天路由支持逻辑。"""
|
||||
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
@@ -59,7 +58,7 @@ class ChatHistoryManager:
|
||||
def __init__(self, max_messages: int = 200) -> None:
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> dict[str, Any]:
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
user_info = msg.message_info.user_info
|
||||
user_id = user_info.user_id or ""
|
||||
is_bot = is_bot_self(msg.platform, user_id)
|
||||
@@ -78,7 +77,7 @@ class ChatHistoryManager:
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> list[dict[str, Any]]:
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -114,8 +113,8 @@ class ChatConnectionManager:
|
||||
"""聊天连接管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: dict[str, WebSocket] = {}
|
||||
self.user_sessions: dict[str, str] = {}
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
|
||||
await websocket.accept()
|
||||
@@ -130,14 +129,14 @@ class ChatConnectionManager:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
async def send_message(self, session_id: str, message: dict[str, Any]) -> None:
|
||||
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
async def broadcast(self, message: dict[str, Any]) -> None:
|
||||
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
@@ -224,8 +223,8 @@ def build_session_info_message(
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> dict[str, Any]:
|
||||
session_info_data: dict[str, Any] = {
|
||||
) -> Dict[str, Any]:
|
||||
session_info_data: Dict[str, Any] = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
@@ -314,7 +313,7 @@ def resolve_sender_identity(
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> tuple[str, str]:
|
||||
) -> Tuple[str, str]:
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
|
||||
@@ -328,7 +327,7 @@ def create_message_data(
|
||||
message_id: Optional[str] = None,
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
@@ -385,7 +384,7 @@ def create_message_data(
|
||||
|
||||
async def handle_chat_message(
|
||||
session_id: str,
|
||||
data: dict[str, Any],
|
||||
data: Dict[str, Any],
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
@@ -443,7 +442,7 @@ async def handle_chat_ping(session_id: str) -> None:
|
||||
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
|
||||
|
||||
|
||||
async def handle_nickname_update(session_id: str, data: dict[str, Any], current_user_name: str) -> str:
|
||||
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
|
||||
new_name = str(data.get("user_name", "")).strip()
|
||||
if not new_name:
|
||||
return current_user_name
|
||||
@@ -462,7 +461,7 @@ async def handle_nickname_update(session_id: str, data: dict[str, Any], current_
|
||||
async def enable_virtual_identity(
|
||||
session_id: str,
|
||||
session_prefix: str,
|
||||
virtual_data: dict[str, Any],
|
||||
virtual_data: Dict[str, Any],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
|
||||
@@ -558,7 +557,7 @@ async def disable_virtual_identity(session_id: str) -> None:
|
||||
async def handle_virtual_identity_update(
|
||||
session_id: str,
|
||||
session_id_prefix: str,
|
||||
data: dict[str, Any],
|
||||
data: Dict[str, Any],
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
virtual_data = cast(dict[str, Any], data.get("config", {}))
|
||||
@@ -573,11 +572,11 @@ async def handle_virtual_identity_update(
|
||||
async def dispatch_chat_event(
|
||||
session_id: str,
|
||||
session_id_prefix: str,
|
||||
data: dict[str, Any],
|
||||
data: Dict[str, Any],
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> tuple[str, Optional[VirtualIdentityConfig]]:
|
||||
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
|
||||
event_type = data.get("type")
|
||||
if event_type == "message":
|
||||
next_user_name = await handle_chat_message(
|
||||
|
||||
Reference in New Issue
Block a user