perf:优化webui交互体验,优化统计逻辑,优化log展示
This commit is contained in:
@@ -360,11 +360,8 @@ class ChatConnectionManager:
|
||||
existing.virtual_config = virtual_config
|
||||
existing.sender = sender
|
||||
logger.debug(
|
||||
"WebUI 聊天会话复用: session=%s, connection=%s, client_session=%s, channel=%s",
|
||||
session_id,
|
||||
connection_id,
|
||||
client_session_id,
|
||||
channel_key,
|
||||
f"WebUI 聊天会话复用: session={session_id}, connection={connection_id}, "
|
||||
f"client_session={client_session_id}, channel={channel_key}",
|
||||
)
|
||||
return
|
||||
if existing_session_id is not None:
|
||||
@@ -387,12 +384,8 @@ class ChatConnectionManager:
|
||||
self.user_sessions.setdefault(user_id, set()).add(session_id)
|
||||
self._bind_channel(session_id, channel_key)
|
||||
logger.info(
|
||||
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, channel=%s",
|
||||
session_id,
|
||||
connection_id,
|
||||
client_session_id,
|
||||
user_id,
|
||||
channel_key,
|
||||
f"WebUI 聊天会话已连接: session={session_id}, connection={connection_id}, "
|
||||
f"client_session={client_session_id}, user={user_id}, channel={channel_key}",
|
||||
)
|
||||
|
||||
def disconnect(self, session_id: str) -> None:
|
||||
@@ -420,7 +413,7 @@ class ChatConnectionManager:
|
||||
if not user_session_ids:
|
||||
del self.user_sessions[session_connection.user_id]
|
||||
|
||||
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
def disconnect_connection(self, connection_id: str) -> None:
|
||||
"""断开物理连接下的全部逻辑聊天会话。
|
||||
@@ -495,7 +488,7 @@ class ChatConnectionManager:
|
||||
try:
|
||||
await session_connection.sender(message)
|
||||
except Exception as exc:
|
||||
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
|
||||
logger.error(f"发送聊天消息失败: session={session_id}, error={exc}")
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||
"""向全部逻辑聊天会话广播消息。
|
||||
@@ -659,10 +652,8 @@ def resolve_initial_virtual_identity(
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
|
||||
virtual_config.user_nickname,
|
||||
virtual_config.platform,
|
||||
virtual_group_id,
|
||||
f"虚拟身份模式已通过参数激活: {virtual_config.user_nickname} @ "
|
||||
f"{virtual_config.platform}, group_id={virtual_group_id}",
|
||||
)
|
||||
return virtual_config
|
||||
except Exception as exc:
|
||||
|
||||
@@ -1,349 +1,34 @@
|
||||
"""统计数据 API 路由"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import desc, func, or_
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ModelUsage, OnlineTime
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages
|
||||
from src.services.statistics_service import get_dashboard_statistics, get_model_statistics, get_summary_statistics
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.schemas.statistics import DashboardData
|
||||
|
||||
logger = get_logger("webui.statistics")
|
||||
|
||||
router = APIRouter(prefix="/statistics", tags=["statistics"], dependencies=[Depends(require_auth)])
|
||||
|
||||
|
||||
class StatisticsSummary(BaseModel):
|
||||
"""统计数据摘要"""
|
||||
|
||||
total_requests: int = Field(0, description="总请求数")
|
||||
total_cost: float = Field(0.0, description="总花费")
|
||||
total_tokens: int = Field(0, description="总token数")
|
||||
online_time: float = Field(0.0, description="在线时间(秒)")
|
||||
total_messages: int = Field(0, description="总消息数")
|
||||
total_replies: int = Field(0, description="总回复数")
|
||||
avg_response_time: float = Field(0.0, description="平均响应时间")
|
||||
cost_per_hour: float = Field(0.0, description="每小时花费")
|
||||
tokens_per_hour: float = Field(0.0, description="每小时token数")
|
||||
|
||||
|
||||
class ModelStatistics(BaseModel):
|
||||
"""模型统计"""
|
||||
|
||||
model_name: str
|
||||
request_count: int
|
||||
total_cost: float
|
||||
total_tokens: int
|
||||
avg_response_time: float
|
||||
|
||||
|
||||
class TimeSeriesData(BaseModel):
|
||||
"""时间序列数据"""
|
||||
|
||||
timestamp: str
|
||||
requests: int = 0
|
||||
cost: float = 0.0
|
||||
tokens: int = 0
|
||||
|
||||
|
||||
class DashboardData(BaseModel):
|
||||
"""仪表盘数据"""
|
||||
|
||||
summary: StatisticsSummary
|
||||
model_stats: List[ModelStatistics]
|
||||
hourly_data: List[TimeSeriesData]
|
||||
daily_data: List[TimeSeriesData]
|
||||
recent_activity: List[Dict[str, Any]]
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardData)
|
||||
async def get_dashboard_data(hours: int = 24):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时),默认24小时
|
||||
|
||||
Returns:
|
||||
仪表盘数据
|
||||
"""
|
||||
async def get_dashboard_data(hours: int = 24) -> DashboardData:
|
||||
"""获取仪表盘统计数据。"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
|
||||
# 获取摘要数据
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
|
||||
# 获取模型统计
|
||||
model_stats = await _get_model_statistics(start_time)
|
||||
|
||||
# 获取小时级时间序列数据
|
||||
hourly_data = await _get_hourly_statistics(start_time, now)
|
||||
|
||||
# 获取日级时间序列数据(最近7天)
|
||||
daily_start = now - timedelta(days=7)
|
||||
daily_data = await _get_daily_statistics(daily_start, now)
|
||||
|
||||
# 获取最近活动
|
||||
recent_activity = await _get_recent_activity(limit=10)
|
||||
|
||||
return DashboardData(
|
||||
summary=summary,
|
||||
model_stats=model_stats,
|
||||
hourly_data=hourly_data,
|
||||
daily_data=daily_data,
|
||||
recent_activity=recent_activity,
|
||||
)
|
||||
return await get_dashboard_statistics(hours=hours)
|
||||
except Exception as e:
|
||||
logger.error(f"获取仪表盘数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
|
||||
"""获取摘要统计数据(优化:使用数据库聚合)"""
|
||||
summary = StatisticsSummary(
|
||||
total_requests=0,
|
||||
total_cost=0.0,
|
||||
total_tokens=0,
|
||||
online_time=0.0,
|
||||
total_messages=0,
|
||||
total_replies=0,
|
||||
avg_response_time=0.0,
|
||||
cost_per_hour=0.0,
|
||||
tokens_per_hour=0.0,
|
||||
)
|
||||
|
||||
# 使用聚合查询替代全量加载
|
||||
with get_db_session() as session:
|
||||
statement = select(
|
||||
func.count().label("total_requests"),
|
||||
func.sum(col(ModelUsage.cost)).label("total_cost"),
|
||||
func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
|
||||
func.avg(col(ModelUsage.time_cost)).label("avg_response_time"),
|
||||
).where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
|
||||
result = session.exec(statement).first()
|
||||
|
||||
if result:
|
||||
total_requests, total_cost, total_tokens, avg_response_time = result
|
||||
summary.total_requests = total_requests or 0
|
||||
summary.total_cost = float(total_cost or 0.0)
|
||||
summary.total_tokens = total_tokens or 0
|
||||
summary.avg_response_time = float(avg_response_time or 0.0)
|
||||
|
||||
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
|
||||
with get_db_session() as session:
|
||||
statement = select(OnlineTime).where(
|
||||
or_(
|
||||
col(OnlineTime.start_timestamp) >= start_time,
|
||||
col(OnlineTime.end_timestamp) >= start_time,
|
||||
)
|
||||
)
|
||||
online_records = session.exec(statement).all()
|
||||
|
||||
for record in online_records:
|
||||
start = max(record.start_timestamp, start_time)
|
||||
end = min(record.end_timestamp, end_time)
|
||||
if end > start:
|
||||
summary.online_time += (end - start).total_seconds()
|
||||
|
||||
summary.total_messages = count_messages(start_time=start_time.timestamp(), end_time=end_time.timestamp())
|
||||
summary.total_replies = count_messages(
|
||||
start_time=start_time.timestamp(),
|
||||
end_time=end_time.timestamp(),
|
||||
has_reply_to=True,
|
||||
)
|
||||
|
||||
# 计算派生指标
|
||||
if summary.online_time > 0:
|
||||
online_hours = summary.online_time / 3600.0
|
||||
summary.cost_per_hour = summary.total_cost / online_hours
|
||||
summary.tokens_per_hour = summary.total_tokens / online_hours
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
||||
# 使用GROUP BY聚合,避免全量加载
|
||||
statement = (
|
||||
select(ModelUsage)
|
||||
.where(col(ModelUsage.timestamp) >= start_time)
|
||||
.order_by(desc(col(ModelUsage.timestamp)))
|
||||
.limit(200)
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
rows = session.exec(statement).all()
|
||||
|
||||
aggregates: Dict[str, Dict[str, float | int]] = {}
|
||||
for record in rows:
|
||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||
if model_name not in aggregates:
|
||||
aggregates[model_name] = {
|
||||
"request_count": 0,
|
||||
"total_cost": 0.0,
|
||||
"total_tokens": 0,
|
||||
"total_time_cost": 0.0,
|
||||
"time_cost_count": 0,
|
||||
}
|
||||
bucket = aggregates[model_name]
|
||||
bucket["request_count"] = int(bucket["request_count"]) + 1
|
||||
bucket["total_cost"] = float(bucket["total_cost"]) + float(record.cost or 0.0)
|
||||
bucket["total_tokens"] = int(bucket["total_tokens"]) + int(record.total_tokens or 0)
|
||||
if record.time_cost:
|
||||
bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost)
|
||||
bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1
|
||||
|
||||
result: List[ModelStatistics] = []
|
||||
for model_name, bucket in sorted(
|
||||
aggregates.items(),
|
||||
key=lambda item: float(item[1]["request_count"]),
|
||||
reverse=True,
|
||||
)[:10]:
|
||||
time_cost_count = int(bucket["time_cost_count"])
|
||||
avg_time_cost = float(bucket["total_time_cost"]) / time_cost_count if time_cost_count > 0 else 0.0
|
||||
result.append(
|
||||
ModelStatistics(
|
||||
model_name=model_name,
|
||||
request_count=int(bucket["request_count"]),
|
||||
total_cost=float(bucket["total_cost"]),
|
||||
total_tokens=int(bucket["total_tokens"]),
|
||||
avg_response_time=avg_time_cost,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
||||
# SQLite的日期时间函数进行小时分组
|
||||
# 使用strftime将timestamp格式化为小时级别
|
||||
hour_expr = func.strftime("%Y-%m-%dT%H:00:00", col(ModelUsage.timestamp))
|
||||
statement = (
|
||||
select(
|
||||
hour_expr.label("hour"),
|
||||
func.count().label("requests"),
|
||||
func.sum(col(ModelUsage.cost)).label("cost"),
|
||||
func.sum(col(ModelUsage.total_tokens)).label("tokens"),
|
||||
)
|
||||
.where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
|
||||
.group_by(hour_expr)
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
rows = session.exec(statement).all()
|
||||
|
||||
# 转换为字典以快速查找
|
||||
data_dict = {row[0]: row for row in rows}
|
||||
|
||||
# 填充所有小时(包括没有数据的)
|
||||
result = []
|
||||
current = start_time.replace(minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
hour_str = current.strftime("%Y-%m-%dT%H:00:00")
|
||||
if hour_str in data_dict:
|
||||
row = data_dict[hour_str]
|
||||
result.append(
|
||||
TimeSeriesData(
|
||||
timestamp=hour_str,
|
||||
requests=row[1] or 0,
|
||||
cost=float(row[2] or 0.0),
|
||||
tokens=row[3] or 0,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
|
||||
current += timedelta(hours=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取日级统计数据(优化:使用数据库聚合)"""
|
||||
# 使用strftime按日期分组
|
||||
day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp))
|
||||
statement = (
|
||||
select(
|
||||
day_expr.label("day"),
|
||||
func.count().label("requests"),
|
||||
func.sum(col(ModelUsage.cost)).label("cost"),
|
||||
func.sum(col(ModelUsage.total_tokens)).label("tokens"),
|
||||
)
|
||||
.where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
|
||||
.group_by(day_expr)
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
rows = session.exec(statement).all()
|
||||
|
||||
# 转换为字典
|
||||
data_dict = {row[0]: row for row in rows}
|
||||
|
||||
# 填充所有天
|
||||
result = []
|
||||
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
day_str = current.strftime("%Y-%m-%dT00:00:00")
|
||||
if day_str in data_dict:
|
||||
row = data_dict[day_str]
|
||||
result.append(
|
||||
TimeSeriesData(
|
||||
timestamp=day_str,
|
||||
requests=row[1] or 0,
|
||||
cost=float(row[2] or 0.0),
|
||||
tokens=row[3] or 0,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
|
||||
current += timedelta(days=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取最近活动"""
|
||||
with get_db_session() as session:
|
||||
statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit)
|
||||
records = session.exec(statement).all()
|
||||
|
||||
activities = []
|
||||
for record in records:
|
||||
activities.append(
|
||||
{
|
||||
"timestamp": record.timestamp.isoformat(),
|
||||
"model": record.model_assign_name or record.model_name,
|
||||
"request_type": record.request_type,
|
||||
"tokens": record.total_tokens or 0,
|
||||
"cost": record.cost or 0.0,
|
||||
"time_cost": record.time_cost or 0.0,
|
||||
"status": None,
|
||||
}
|
||||
)
|
||||
|
||||
return activities
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_summary(hours: int = 24):
|
||||
"""
|
||||
获取统计摘要
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
"""获取统计摘要。"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
return summary
|
||||
return await get_summary_statistics(start_time, now)
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计摘要失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
@@ -351,17 +36,10 @@ async def get_summary(hours: int = 24):
|
||||
|
||||
@router.get("/models")
|
||||
async def get_model_stats(hours: int = 24):
|
||||
"""
|
||||
获取模型统计
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
"""获取模型统计。"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
stats = await _get_model_statistics(start_time)
|
||||
return stats
|
||||
start_time = datetime.now() - timedelta(hours=hours)
|
||||
return await get_model_statistics(start_time)
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
@@ -4,13 +4,17 @@
|
||||
提供系统重启、状态查询等功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from importlib.metadata import PackageNotFoundError, version as get_package_version
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
import httpx
|
||||
import os
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.webui.dependencies import require_auth
|
||||
@@ -20,6 +24,10 @@ logger = get_logger("webui_system")
|
||||
|
||||
# 记录启动时间
|
||||
_start_time = time.time()
|
||||
_DASHBOARD_PACKAGE_NAME = "maibot-dashboard"
|
||||
_PYPI_JSON_URL = f"https://pypi.org/pypi/{_DASHBOARD_PACKAGE_NAME}/json"
|
||||
_PYPI_CACHE_TTL_SECONDS = 60 * 60 * 6
|
||||
_pypi_version_cache: Dict[str, Any] = {"checked_at": 0.0, "latest_version": None}
|
||||
|
||||
|
||||
class RestartResponse(BaseModel):
|
||||
@@ -38,6 +46,72 @@ class StatusResponse(BaseModel):
|
||||
start_time: str
|
||||
|
||||
|
||||
class DashboardVersionResponse(BaseModel):
|
||||
"""WebUI 版本检查响应"""
|
||||
|
||||
current_version: str
|
||||
latest_version: Optional[str] = None
|
||||
has_update: bool = False
|
||||
package_name: str = _DASHBOARD_PACKAGE_NAME
|
||||
pypi_url: str = f"https://pypi.org/project/{_DASHBOARD_PACKAGE_NAME}/"
|
||||
|
||||
|
||||
def _get_installed_dashboard_version() -> str:
|
||||
try:
|
||||
return get_package_version(_DASHBOARD_PACKAGE_NAME)
|
||||
except PackageNotFoundError:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _normalize_version(version: str) -> tuple[int, ...]:
|
||||
clean_version = version.strip().lower().removeprefix("v")
|
||||
numeric_part = clean_version.split("-", 1)[0].split("+", 1)[0]
|
||||
parts = []
|
||||
for item in numeric_part.split("."):
|
||||
number = ""
|
||||
for char in item:
|
||||
if not char.isdigit():
|
||||
break
|
||||
number += char
|
||||
parts.append(int(number) if number else 0)
|
||||
return tuple(parts)
|
||||
|
||||
|
||||
def _is_newer_version(latest_version: Optional[str], current_version: str) -> bool:
|
||||
if not latest_version or not current_version or current_version == "unknown":
|
||||
return False
|
||||
|
||||
latest_parts = _normalize_version(latest_version)
|
||||
current_parts = _normalize_version(current_version)
|
||||
width = max(len(latest_parts), len(current_parts))
|
||||
return latest_parts + (0,) * (width - len(latest_parts)) > current_parts + (0,) * (width - len(current_parts))
|
||||
|
||||
|
||||
async def _get_latest_dashboard_version_from_pypi() -> Optional[str]:
|
||||
now = time.time()
|
||||
cached_version = _pypi_version_cache.get("latest_version")
|
||||
checked_at = float(_pypi_version_cache.get("checked_at", 0.0))
|
||||
if cached_version and now - checked_at < _PYPI_CACHE_TTL_SECONDS:
|
||||
return str(cached_version)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(_PYPI_JSON_URL)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
except Exception as e:
|
||||
logger.debug(f"检查 WebUI PyPI 版本失败: {e}")
|
||||
return str(cached_version) if cached_version else None
|
||||
|
||||
latest_version = payload.get("info", {}).get("version")
|
||||
if isinstance(latest_version, str) and latest_version.strip():
|
||||
_pypi_version_cache["checked_at"] = now
|
||||
_pypi_version_cache["latest_version"] = latest_version.strip()
|
||||
return latest_version.strip()
|
||||
|
||||
return str(cached_version) if cached_version else None
|
||||
|
||||
|
||||
@router.post("/restart", response_model=RestartResponse)
|
||||
async def restart_maibot():
|
||||
"""
|
||||
@@ -89,6 +163,19 @@ async def get_maibot_status():
|
||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/dashboard-version", response_model=DashboardVersionResponse)
|
||||
async def get_dashboard_version(current_version: Optional[str] = None):
|
||||
"""获取 WebUI 当前版本和 PyPI 最新版本。"""
|
||||
resolved_current_version = current_version or _get_installed_dashboard_version()
|
||||
latest_version = await _get_latest_dashboard_version_from_pypi()
|
||||
|
||||
return DashboardVersionResponse(
|
||||
current_version=resolved_current_version,
|
||||
latest_version=latest_version,
|
||||
has_update=_is_newer_version(latest_version, resolved_current_version),
|
||||
)
|
||||
|
||||
|
||||
# 可选:添加更多系统控制功能
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ class UnifiedWebSocketManager:
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("统一 WebSocket 发送失败: connection=%s, error=%s", connection.connection_id, exc)
|
||||
logger.error(f"统一 WebSocket 发送失败: connection={connection.connection_id}, error={exc}")
|
||||
|
||||
async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
|
||||
"""注册一个新的物理 WebSocket 连接。
|
||||
@@ -108,7 +108,7 @@ class UnifiedWebSocketManager:
|
||||
try:
|
||||
await self._close_websocket(connection)
|
||||
except Exception as exc:
|
||||
logger.debug("关闭统一 WebSocket 底层连接时出现异常: connection=%s, error=%s", connection_id, exc)
|
||||
logger.debug(f"关闭统一 WebSocket 底层连接时出现异常: connection={connection_id}, error={exc}")
|
||||
|
||||
await connection.send_queue.put(None)
|
||||
if connection.sender_task is not None:
|
||||
@@ -117,7 +117,7 @@ class UnifiedWebSocketManager:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.debug("等待发送协程退出时出现异常: connection=%s, error=%s", connection_id, exc)
|
||||
logger.debug(f"等待发送协程退出时出现异常: connection={connection_id}, error={exc}")
|
||||
|
||||
def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
|
||||
"""获取指定连接上下文。
|
||||
|
||||
@@ -544,7 +544,7 @@ async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(
|
||||
|
||||
connection_id = uuid.uuid4().hex
|
||||
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
|
||||
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
|
||||
logger.info(f"统一 WebSocket 客户端已连接: connection={connection_id}")
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="system",
|
||||
@@ -565,17 +565,15 @@ async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(
|
||||
continue
|
||||
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
|
||||
except WebSocketDisconnect:
|
||||
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
|
||||
logger.info(f"统一 WebSocket 客户端已断开: connection={connection_id}")
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("统一 WebSocket 连接处理被取消: connection=%s", connection_id)
|
||||
logger.warning(f"统一 WebSocket 连接处理被取消: connection={connection_id}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("统一 WebSocket 处理失败: connection=%s, error=%s", connection_id, exc, exc_info=True)
|
||||
logger.error(f"统一 WebSocket 处理失败: connection={connection_id}, error={exc}", exc_info=True)
|
||||
finally:
|
||||
chat_manager.disconnect_connection(connection_id)
|
||||
await websocket_manager.disconnect(connection_id)
|
||||
logger.info(
|
||||
"统一 WebSocket 连接清理完成: connection=%s, 剩余连接=%s",
|
||||
connection_id,
|
||||
len(websocket_manager.connections),
|
||||
f"统一 WebSocket 连接清理完成: connection={connection_id}, 剩余连接={len(websocket_manager.connections)}",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user