feat: add unified WebSocket connection manager and routing
- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending. - Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management. - Added support for logging and plugin progress subscriptions. - Enhanced error handling and response structure for WebSocket operations.
This commit is contained in:
@@ -1,25 +1,28 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import heapq
|
||||
import Levenshtein
|
||||
import random
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
|
||||
import Levenshtein
|
||||
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("emoji")
|
||||
@@ -33,6 +36,171 @@ EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
|
||||
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册表情包系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
emoji_schema = {
|
||||
"type": "object",
|
||||
"description": "当前表情包的序列化信息,主要包含 file_hash、description、emotions 等字段。",
|
||||
}
|
||||
string_array_schema = {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="emoji.maisaka.before_select",
|
||||
description="Maisaka 表情发送工具选择表情前触发,可改写情绪、上下文和采样参数,或中止本次选择。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"stream_id": {"type": "string", "description": "目标会话 ID。"},
|
||||
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
|
||||
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
|
||||
"context_texts": {
|
||||
**string_array_schema,
|
||||
"description": "最近聊天上下文文本列表。",
|
||||
},
|
||||
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
|
||||
"abort_message": {
|
||||
"type": "string",
|
||||
"description": "当 Hook 主动中止时可附带的失败提示。",
|
||||
},
|
||||
},
|
||||
required=["stream_id", "requested_emotion", "reasoning", "context_texts", "sample_size"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.maisaka.after_select",
|
||||
description="Maisaka 已选出表情后触发,可替换选中的表情哈希、补充匹配情绪,或中止发送。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"stream_id": {"type": "string", "description": "目标会话 ID。"},
|
||||
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
|
||||
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
|
||||
"context_texts": {
|
||||
**string_array_schema,
|
||||
"description": "最近聊天上下文文本列表。",
|
||||
},
|
||||
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
|
||||
"selected_emoji": emoji_schema,
|
||||
"selected_emoji_hash": {"type": "string", "description": "选中的表情哈希。"},
|
||||
"matched_emotion": {"type": "string", "description": "最终命中的情绪标签。"},
|
||||
"abort_message": {
|
||||
"type": "string",
|
||||
"description": "当 Hook 主动中止时可附带的失败提示。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"stream_id",
|
||||
"requested_emotion",
|
||||
"reasoning",
|
||||
"context_texts",
|
||||
"sample_size",
|
||||
"matched_emotion",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.register.after_build_description",
|
||||
description="表情包描述生成并通过内容审查后触发,可改写描述文本或拒绝本次注册。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"emoji": emoji_schema,
|
||||
"description": {"type": "string", "description": "当前生成出的表情包描述。"},
|
||||
"image_format": {"type": "string", "description": "表情图片格式。"},
|
||||
},
|
||||
required=["emoji", "description", "image_format"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.register.after_build_emotion",
|
||||
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"emoji": emoji_schema,
|
||||
"description": {"type": "string", "description": "当前表情包描述。"},
|
||||
"emotions": {
|
||||
**string_array_schema,
|
||||
"description": "当前生成出的情绪标签列表。",
|
||||
},
|
||||
},
|
||||
required=["emoji", "description", "emotions"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
|
||||
"""将表情包对象序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
emoji: 待序列化的表情包对象。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 序列化后的字典;当表情为空时返回 ``None``。
|
||||
"""
|
||||
|
||||
if emoji is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"file_hash": str(emoji.file_hash or "").strip(),
|
||||
"file_name": emoji.file_name,
|
||||
"full_path": str(emoji.full_path),
|
||||
"description": emoji.description,
|
||||
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
|
||||
"query_count": int(emoji.query_count),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_string_list(raw_values: Any) -> List[str]:
|
||||
"""将任意列表值规范化为字符串列表。
|
||||
|
||||
Args:
|
||||
raw_values: 待规范化的原始值。
|
||||
|
||||
Returns:
|
||||
List[str]: 去空白后的字符串列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_values, list):
|
||||
return []
|
||||
return [str(item).strip() for item in raw_values if str(item).strip()]
|
||||
|
||||
|
||||
def _ensure_directories() -> None:
|
||||
"""确保表情包相关目录存在"""
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
@@ -642,6 +810,22 @@ class EmojiManager:
|
||||
if "否" in llm_response:
|
||||
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_description",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=description,
|
||||
image_format=image_format,
|
||||
)
|
||||
if hook_result.aborted:
|
||||
logger.info(f"[构建描述] 表情包描述被 Hook 中止注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
normalized_description = str(hook_result.kwargs.get("description", description) or "").strip()
|
||||
if not normalized_description:
|
||||
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
description = normalized_description
|
||||
target_emoji.description = description
|
||||
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
|
||||
return True, target_emoji
|
||||
@@ -687,6 +871,23 @@ class EmojiManager:
|
||||
elif len(emotions) > 2:
|
||||
emotions = random.sample(emotions, 2)
|
||||
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_emotion",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=target_emoji.description,
|
||||
emotions=list(emotions),
|
||||
)
|
||||
if hook_result.aborted:
|
||||
logger.info(f"[构建情感标签] 表情包情感标签被 Hook 中止注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
raw_emotions = hook_result.kwargs.get("emotions")
|
||||
if raw_emotions is not None:
|
||||
emotions = _normalize_string_list(raw_emotions)
|
||||
if not emotions:
|
||||
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}")
|
||||
target_emoji.emotion = emotions
|
||||
return True, target_emoji
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Maisaka 表情工具内置能力。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Sequence
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import random
|
||||
|
||||
@@ -11,7 +11,7 @@ from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.services import send_service
|
||||
|
||||
from .emoji_manager import emoji_manager, emoji_manager_emotion_judge_llm
|
||||
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
|
||||
|
||||
logger = get_logger("emoji_maisaka_tool")
|
||||
|
||||
@@ -29,6 +29,76 @@ class MaisakaEmojiSendResult:
|
||||
matched_emotion: str = ""
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _coerce_positive_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为正整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时使用的默认值。
|
||||
|
||||
Returns:
|
||||
int: 规范化后的正整数。
|
||||
"""
|
||||
|
||||
try:
|
||||
normalized_value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return normalized_value if normalized_value > 0 else default
|
||||
|
||||
|
||||
def _normalize_context_texts(context_texts: Sequence[str] | None) -> list[str]:
|
||||
"""清洗 Hook 和调用链传入的上下文文本列表。
|
||||
|
||||
Args:
|
||||
context_texts: 原始上下文文本序列。
|
||||
|
||||
Returns:
|
||||
list[str]: 过滤空白后的上下文文本列表。
|
||||
"""
|
||||
|
||||
if not context_texts:
|
||||
return []
|
||||
return [str(item).strip() for item in context_texts if str(item).strip()]
|
||||
|
||||
|
||||
def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
|
||||
"""根据 Hook 返回值解析目标表情包对象。
|
||||
|
||||
Args:
|
||||
raw_value: Hook 返回的 ``selected_emoji`` 或 ``selected_emoji_hash``。
|
||||
|
||||
Returns:
|
||||
Optional[MaiEmoji]: 命中的表情包对象;未命中时返回 ``None``。
|
||||
"""
|
||||
|
||||
raw_hash: str = ""
|
||||
if isinstance(raw_value, dict):
|
||||
raw_hash = str(raw_value.get("file_hash") or raw_value.get("hash") or "").strip()
|
||||
elif isinstance(raw_value, str):
|
||||
raw_hash = raw_value.strip()
|
||||
|
||||
if not raw_hash:
|
||||
return None
|
||||
|
||||
for emoji in emoji_manager.emojis:
|
||||
if emoji.file_hash == raw_hash:
|
||||
return emoji
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
|
||||
"""提取并清洗单个表情的情绪标签。"""
|
||||
|
||||
@@ -129,16 +199,81 @@ async def send_emoji_for_maisaka(
|
||||
) -> MaisakaEmojiSendResult:
|
||||
"""为 Maisaka 选择并发送一个表情。"""
|
||||
|
||||
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
|
||||
requested_emotion=requested_emotion,
|
||||
reasoning=reasoning,
|
||||
context_texts=context_texts,
|
||||
normalized_requested_emotion = requested_emotion.strip()
|
||||
normalized_reasoning = reasoning.strip()
|
||||
normalized_context_texts = _normalize_context_texts(context_texts)
|
||||
sample_size = 30
|
||||
|
||||
before_select_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.maisaka.before_select",
|
||||
stream_id=stream_id,
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
reasoning=normalized_reasoning,
|
||||
context_texts=list(normalized_context_texts),
|
||||
sample_size=sample_size,
|
||||
abort_message="表情选择已被 Hook 中止。",
|
||||
)
|
||||
if before_select_result.aborted:
|
||||
abort_message = str(before_select_result.kwargs.get("abort_message") or "表情选择已被 Hook 中止。").strip()
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message=abort_message or "表情选择已被 Hook 中止。",
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
)
|
||||
|
||||
before_select_kwargs = before_select_result.kwargs
|
||||
normalized_requested_emotion = str(
|
||||
before_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
|
||||
).strip()
|
||||
normalized_reasoning = str(before_select_kwargs.get("reasoning", normalized_reasoning) or "").strip()
|
||||
if isinstance(before_select_kwargs.get("context_texts"), list):
|
||||
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
|
||||
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
|
||||
|
||||
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
reasoning=normalized_reasoning,
|
||||
context_texts=normalized_context_texts,
|
||||
sample_size=sample_size,
|
||||
)
|
||||
after_select_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.maisaka.after_select",
|
||||
stream_id=stream_id,
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
reasoning=normalized_reasoning,
|
||||
context_texts=list(normalized_context_texts),
|
||||
sample_size=sample_size,
|
||||
selected_emoji=_serialize_emoji_for_hook(selected_emoji),
|
||||
selected_emoji_hash=str(selected_emoji.file_hash or "").strip() if selected_emoji is not None else "",
|
||||
matched_emotion=matched_emotion,
|
||||
abort_message="表情发送已被 Hook 中止。",
|
||||
)
|
||||
if after_select_result.aborted:
|
||||
abort_message = str(after_select_result.kwargs.get("abort_message") or "表情发送已被 Hook 中止。").strip()
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message=abort_message or "表情发送已被 Hook 中止。",
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
after_select_kwargs = after_select_result.kwargs
|
||||
normalized_requested_emotion = str(
|
||||
after_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
|
||||
).strip()
|
||||
matched_emotion = str(after_select_kwargs.get("matched_emotion", matched_emotion) or "").strip()
|
||||
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji_hash"))
|
||||
if override_emoji is None:
|
||||
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji"))
|
||||
if override_emoji is not None:
|
||||
selected_emoji = override_emoji
|
||||
|
||||
if selected_emoji is None:
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message="当前表情包库中没有可用表情。",
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -151,7 +286,7 @@ async def send_emoji_for_maisaka(
|
||||
message=f"发送表情包失败:{exc}",
|
||||
description=selected_emoji.description.strip(),
|
||||
emotions=_normalize_emotions(selected_emoji),
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
@@ -169,7 +304,7 @@ async def send_emoji_for_maisaka(
|
||||
message=f"发送表情包时发生异常:{exc}",
|
||||
description=selected_emoji.description.strip(),
|
||||
emotions=_normalize_emotions(selected_emoji),
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
@@ -181,7 +316,7 @@ async def send_emoji_for_maisaka(
|
||||
message="发送表情包失败。",
|
||||
description=description,
|
||||
emotions=emotions,
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
@@ -197,6 +332,6 @@ async def send_emoji_for_maisaka(
|
||||
emoji_base64=emoji_base64,
|
||||
description=description,
|
||||
emotions=emotions,
|
||||
requested_emotion=requested_emotion.strip(),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
@@ -18,28 +18,29 @@ install(extra_lines=3)
|
||||
logger = get_logger("sender")
|
||||
|
||||
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str], Optional[str]]] = None
|
||||
|
||||
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
|
||||
# TODO: 重构完成后完成webui相关
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
|
||||
"""获取 WebUI 聊天室广播器。
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
|
||||
Tuple[Any, Optional[str], Optional[str]]: ``(chat_manager, platform_name, default_group_id)`` 三元组;
|
||||
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
|
||||
"""
|
||||
global _webui_chat_broadcaster
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
|
||||
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
|
||||
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
|
||||
except ImportError:
|
||||
_webui_chat_broadcaster = (None, None)
|
||||
_webui_chat_broadcaster = (None, None, None)
|
||||
return _webui_chat_broadcaster
|
||||
|
||||
|
||||
@@ -76,7 +77,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
|
||||
try:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||
chat_manager, webui_platform, default_group_id = get_webui_chat_broadcaster()
|
||||
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
||||
|
||||
if is_webui_message and chat_manager is not None:
|
||||
@@ -97,8 +98,9 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
message_type = "rich"
|
||||
segments = message_segments
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
await chat_manager.broadcast_to_group(
|
||||
group_id=group_id or default_group_id or "",
|
||||
message={
|
||||
"type": "bot_message",
|
||||
"content": message.processed_plain_text,
|
||||
"message_type": message_type,
|
||||
@@ -110,7 +112,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"avatar": None,
|
||||
"is_bot": True,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
||||
|
||||
@@ -1,22 +1,25 @@
|
||||
from datetime import datetime
|
||||
from sqlmodel import select
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import difflib
|
||||
import json
|
||||
import re
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.data_models.expression_data_model import MaiExpression
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.data_models.expression_data_model import MaiExpression
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.config.config import global_config
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .expression_utils import check_expression_suitability, parse_expression_response
|
||||
|
||||
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
|
||||
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
|
||||
|
||||
|
||||
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册表达方式系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="expression.select.before_select",
|
||||
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
|
||||
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
|
||||
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
|
||||
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
|
||||
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
|
||||
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
|
||||
},
|
||||
required=["chat_id", "chat_info", "max_num", "think_level"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.select.after_selection",
|
||||
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
|
||||
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
|
||||
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
|
||||
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
|
||||
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
|
||||
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
|
||||
"selected_expressions": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "当前已选中的表达方式列表。",
|
||||
},
|
||||
"selected_expression_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "当前已选中的表达方式 ID 列表。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"chat_id",
|
||||
"chat_info",
|
||||
"max_num",
|
||||
"think_level",
|
||||
"selected_expressions",
|
||||
"selected_expression_ids",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.learn.after_extract",
|
||||
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
|
||||
"expressions": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "解析出的表达方式候选列表。",
|
||||
},
|
||||
"jargon_entries": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "解析出的黑话候选列表。",
|
||||
},
|
||||
},
|
||||
required=["session_id", "message_count", "expressions", "jargon_entries"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.learn.before_upsert",
|
||||
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"situation": {"type": "string", "description": "即将写入的情景文本。"},
|
||||
"style": {"type": "string", "description": "即将写入的风格文本。"},
|
||||
},
|
||||
required=["session_id", "situation", "style"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, session_id: str) -> None:
|
||||
"""初始化表达方式学习器。
|
||||
|
||||
Args:
|
||||
session_id: 当前会话 ID。
|
||||
"""
|
||||
|
||||
self.session_id = session_id
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
@@ -44,6 +161,110 @@ class ExpressionLearner:
|
||||
# 消息缓存
|
||||
self._messages_cache: List["SessionMessage"] = []
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
|
||||
"""将表达方式候选序列化为 Hook 载荷。
|
||||
|
||||
Args:
|
||||
expressions: 原始表达方式候选列表。
|
||||
|
||||
Returns:
|
||||
List[dict[str, str]]: 序列化后的表达方式候选。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"situation": str(situation).strip(),
|
||||
"style": str(style).strip(),
|
||||
"source_id": str(source_id).strip(),
|
||||
}
|
||||
for situation, style, source_id in expressions
|
||||
if str(situation).strip() and str(style).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
|
||||
"""从 Hook 载荷恢复表达方式候选列表。
|
||||
|
||||
Args:
|
||||
raw_expressions: Hook 返回的表达方式候选。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_expressions, list):
|
||||
return []
|
||||
|
||||
normalized_expressions: List[Tuple[str, str, str]] = []
|
||||
for raw_expression in raw_expressions:
|
||||
if not isinstance(raw_expression, dict):
|
||||
continue
|
||||
situation = str(raw_expression.get("situation") or "").strip()
|
||||
style = str(raw_expression.get("style") or "").strip()
|
||||
source_id = str(raw_expression.get("source_id") or "").strip()
|
||||
if not situation or not style:
|
||||
continue
|
||||
normalized_expressions.append((situation, style, source_id))
|
||||
return normalized_expressions
|
||||
|
||||
@staticmethod
|
||||
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
|
||||
"""将黑话候选序列化为 Hook 载荷。
|
||||
|
||||
Args:
|
||||
jargon_entries: 原始黑话候选列表。
|
||||
|
||||
Returns:
|
||||
List[dict[str, str]]: 序列化后的黑话候选列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"content": str(content).strip(),
|
||||
"source_id": str(source_id).strip(),
|
||||
}
|
||||
for content, source_id in jargon_entries
|
||||
if str(content).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
|
||||
"""从 Hook 载荷恢复黑话候选列表。
|
||||
|
||||
Args:
|
||||
raw_jargon_entries: Hook 返回的黑话候选列表。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 恢复后的黑话候选列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_jargon_entries, list):
|
||||
return []
|
||||
|
||||
normalized_entries: List[Tuple[str, str]] = []
|
||||
for raw_entry in raw_jargon_entries:
|
||||
if not isinstance(raw_entry, dict):
|
||||
continue
|
||||
content = str(raw_entry.get("content") or "").strip()
|
||||
source_id = str(raw_entry.get("source_id") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
normalized_entries.append((content, source_id))
|
||||
return normalized_entries
|
||||
|
||||
def add_messages(self, messages: List["SessionMessage"]) -> None:
|
||||
"""添加消息到缓存"""
|
||||
self._messages_cache.extend(messages)
|
||||
@@ -52,8 +273,12 @@ class ExpressionLearner:
|
||||
"""获取当前消息缓存的大小"""
|
||||
return len(self._messages_cache)
|
||||
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
|
||||
"""学习主流程"""
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
|
||||
"""执行表达方式学习主流程。
|
||||
|
||||
Args:
|
||||
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
|
||||
"""
|
||||
if not self._messages_cache:
|
||||
logger.debug("没有消息可供学习,跳过学习过程")
|
||||
return
|
||||
@@ -109,6 +334,25 @@ class ExpressionLearner:
|
||||
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
|
||||
jargon_entries = []
|
||||
|
||||
after_extract_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.after_extract",
|
||||
session_id=self.session_id,
|
||||
message_count=len(self._messages_cache),
|
||||
expressions=self._serialize_expressions(expressions),
|
||||
jargon_entries=self._serialize_jargon_entries(jargon_entries),
|
||||
)
|
||||
if after_extract_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
|
||||
return
|
||||
|
||||
after_extract_kwargs = after_extract_result.kwargs
|
||||
raw_expressions = after_extract_kwargs.get("expressions")
|
||||
if raw_expressions is not None:
|
||||
expressions = self._deserialize_expressions(raw_expressions)
|
||||
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
|
||||
if raw_jargon_entries is not None:
|
||||
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
|
||||
|
||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||
# TODO: 检测是否开启了
|
||||
if jargon_entries:
|
||||
@@ -135,6 +379,22 @@ class ExpressionLearner:
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for situation, style in learnt_expressions:
|
||||
before_upsert_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.before_upsert",
|
||||
session_id=self.session_id,
|
||||
situation=situation,
|
||||
style=style,
|
||||
)
|
||||
if before_upsert_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
|
||||
continue
|
||||
|
||||
upsert_kwargs = before_upsert_result.kwargs
|
||||
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
|
||||
style = str(upsert_kwargs.get("style", style) or "").strip()
|
||||
if not situation or not style:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
|
||||
continue
|
||||
await self._upsert_expression_to_db(situation, style)
|
||||
|
||||
# ====== 黑话相关 ======
|
||||
|
||||
@@ -1,27 +1,109 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""初始化表达方式选择器。"""
|
||||
|
||||
self.llm_model = LLMServiceClient(
|
||||
task_name="utils", request_type="expression.selector"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _normalize_selected_expressions(raw_expressions: Any) -> List[Dict[str, Any]]:
|
||||
"""从 Hook 载荷恢复表达方式选择结果。
|
||||
|
||||
Args:
|
||||
raw_expressions: Hook 返回的表达方式列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 恢复后的表达方式列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_expressions, list):
|
||||
return []
|
||||
|
||||
normalized_expressions: List[Dict[str, Any]] = []
|
||||
for raw_expression in raw_expressions:
|
||||
if not isinstance(raw_expression, dict):
|
||||
continue
|
||||
expression_id = raw_expression.get("id")
|
||||
situation = str(raw_expression.get("situation") or "").strip()
|
||||
style = str(raw_expression.get("style") or "").strip()
|
||||
source_id = str(raw_expression.get("source_id") or "").strip()
|
||||
if not isinstance(expression_id, int) or not situation or not style or not source_id:
|
||||
continue
|
||||
normalized_expression = dict(raw_expression)
|
||||
normalized_expression["id"] = expression_id
|
||||
normalized_expression["situation"] = situation
|
||||
normalized_expression["style"] = style
|
||||
normalized_expression["source_id"] = source_id
|
||||
normalized_expressions.append(normalized_expression)
|
||||
return normalized_expressions
|
||||
|
||||
@staticmethod
|
||||
def _normalize_selected_expression_ids(raw_ids: Any, expressions: List[Dict[str, Any]]) -> List[int]:
|
||||
"""规范化最终选中的表达方式 ID 列表。
|
||||
|
||||
Args:
|
||||
raw_ids: Hook 返回的 ID 列表。
|
||||
expressions: 当前最终表达方式列表。
|
||||
|
||||
Returns:
|
||||
List[int]: 规范化后的 ID 列表。
|
||||
"""
|
||||
|
||||
if isinstance(raw_ids, list):
|
||||
normalized_ids = [item for item in raw_ids if isinstance(item, int)]
|
||||
if normalized_ids:
|
||||
return normalized_ids
|
||||
return [expression["id"] for expression in expressions if isinstance(expression.get("id"), int)]
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
@@ -214,8 +296,7 @@ class ExpressionSelector:
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
选择适合的表达方式(使用classic模式:随机选择+LLM选择)
|
||||
"""选择适合的表达方式。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
@@ -233,11 +314,60 @@ class ExpressionSelector:
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
before_select_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.select.before_select",
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
target_message=target_message or "",
|
||||
reply_reason=reply_reason or "",
|
||||
think_level=think_level,
|
||||
)
|
||||
if before_select_result.aborted:
|
||||
logger.info(f"聊天流 {chat_id} 的表达方式选择被 Hook 中止")
|
||||
return [], []
|
||||
|
||||
before_select_kwargs = before_select_result.kwargs
|
||||
chat_id = str(before_select_kwargs.get("chat_id", chat_id) or "").strip() or chat_id
|
||||
chat_info = str(before_select_kwargs.get("chat_info", chat_info) or "")
|
||||
max_num = max(self._coerce_int(before_select_kwargs.get("max_num"), max_num), 1)
|
||||
raw_target_message = before_select_kwargs.get("target_message", target_message or "")
|
||||
target_message = str(raw_target_message or "").strip() or None
|
||||
raw_reply_reason = before_select_kwargs.get("reply_reason", reply_reason or "")
|
||||
reply_reason = str(raw_reply_reason or "").strip() or None
|
||||
think_level = self._coerce_int(before_select_kwargs.get("think_level"), think_level)
|
||||
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||
return await self._select_expressions_classic(
|
||||
selected_expressions, selected_ids = await self._select_expressions_classic(
|
||||
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||
)
|
||||
after_selection_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.select.after_selection",
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
target_message=target_message or "",
|
||||
reply_reason=reply_reason or "",
|
||||
think_level=think_level,
|
||||
selected_expressions=[dict(item) for item in selected_expressions],
|
||||
selected_expression_ids=list(selected_ids),
|
||||
)
|
||||
if after_selection_result.aborted:
|
||||
logger.info(f"聊天流 {chat_id} 的表达方式选择结果被 Hook 中止")
|
||||
return [], []
|
||||
|
||||
after_selection_kwargs = after_selection_result.kwargs
|
||||
raw_selected_expressions = after_selection_kwargs.get("selected_expressions")
|
||||
if raw_selected_expressions is not None:
|
||||
selected_expressions = self._normalize_selected_expressions(raw_selected_expressions)
|
||||
selected_ids = self._normalize_selected_expression_ids(
|
||||
after_selection_kwargs.get("selected_expression_ids"),
|
||||
selected_expressions,
|
||||
)
|
||||
if selected_expressions:
|
||||
self.update_expressions_last_active_time(selected_expressions)
|
||||
return selected_expressions, selected_ids
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Set, TypedDict
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, TypedDict
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
@@ -9,13 +9,15 @@ from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.data_models.jargon_data_model import MaiJargon
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .expression_utils import is_single_char_jargon
|
||||
|
||||
@@ -35,8 +37,140 @@ class JargonMeaningEntry(TypedDict):
|
||||
meaning: str
|
||||
|
||||
|
||||
def register_jargon_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册 jargon 系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="jargon.query.before_search",
|
||||
description="Maisaka 黑话查询工具执行检索前触发,可改写词条列表、检索参数或直接中止。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"words": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "准备查询的黑话词条列表。",
|
||||
},
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
|
||||
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
|
||||
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否允许精确命中失败后回退模糊检索。"},
|
||||
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
|
||||
},
|
||||
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="jargon.query.after_search",
|
||||
description="Maisaka 黑话查询工具完成检索后触发,可改写结果列表或中止返回。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"words": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "实际查询的黑话词条列表。",
|
||||
},
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
|
||||
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
|
||||
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否启用了模糊检索回退。"},
|
||||
"results": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "查询结果列表。",
|
||||
},
|
||||
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
|
||||
},
|
||||
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback", "results"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="jargon.extract.before_persist",
|
||||
description="黑话条目准备写入数据库前触发,可改写去重后的条目列表或跳过本次持久化。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"session_name": {"type": "string", "description": "当前会话展示名称。"},
|
||||
"entries": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "即将持久化的黑话条目列表。",
|
||||
},
|
||||
},
|
||||
required=["session_id", "session_name", "entries"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="jargon.inference.before_finalize",
|
||||
description="黑话含义推断完成、写回数据库前触发,可改写最终判定与含义结果。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"session_name": {"type": "string", "description": "当前会话展示名称。"},
|
||||
"content": {"type": "string", "description": "当前黑话词条。"},
|
||||
"count": {"type": "integer", "description": "当前词条累计命中次数。"},
|
||||
"raw_content_list": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "用于推断的原始上下文片段列表。",
|
||||
},
|
||||
"inference_with_context": {"type": "object", "description": "基于上下文的推断结果。"},
|
||||
"inference_with_content_only": {"type": "object", "description": "仅基于词条内容的推断结果。"},
|
||||
"comparison_result": {"type": "object", "description": "比较阶段输出结果。"},
|
||||
"is_jargon": {"type": "boolean", "description": "当前推断是否判定为黑话。"},
|
||||
"meaning": {"type": "string", "description": "当前推断出的黑话含义。"},
|
||||
"is_complete": {"type": "boolean", "description": "当前是否已完成全部推断流程。"},
|
||||
"last_inference_count": {"type": "integer", "description": "本次推断完成后应写回的 last_inference_count。"},
|
||||
},
|
||||
required=[
|
||||
"session_id",
|
||||
"session_name",
|
||||
"content",
|
||||
"count",
|
||||
"raw_content_list",
|
||||
"inference_with_context",
|
||||
"inference_with_content_only",
|
||||
"comparison_result",
|
||||
"is_jargon",
|
||||
"meaning",
|
||||
"is_complete",
|
||||
"last_inference_count",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class JargonMiner:
|
||||
def __init__(self, session_id: str, session_name: str) -> None:
|
||||
"""初始化黑话学习器。
|
||||
|
||||
Args:
|
||||
session_id: 当前会话 ID。
|
||||
session_name: 当前会话展示名称。
|
||||
"""
|
||||
|
||||
self.session_id = session_id
|
||||
self.session_name = session_name
|
||||
|
||||
@@ -46,13 +180,92 @@ class JargonMiner:
|
||||
# 黑话提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时使用的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _serialize_jargon_entries(entries: List[JargonEntry]) -> List[Dict[str, object]]:
|
||||
"""将黑话条目列表序列化为 Hook 可传输结构。
|
||||
|
||||
Args:
|
||||
entries: 原始黑话条目列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, object]]: 序列化后的条目列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"content": str(entry["content"]).strip(),
|
||||
"raw_content": sorted(str(item).strip() for item in entry["raw_content"] if str(item).strip()),
|
||||
}
|
||||
for entry in entries
|
||||
if str(entry["content"]).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_jargon_entries(raw_entries: Any) -> List[JargonEntry]:
|
||||
"""从 Hook 载荷恢复黑话条目列表。
|
||||
|
||||
Args:
|
||||
raw_entries: Hook 返回的条目数据。
|
||||
|
||||
Returns:
|
||||
List[JargonEntry]: 恢复后的黑话条目列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_entries, list):
|
||||
return []
|
||||
|
||||
normalized_entries: List[JargonEntry] = []
|
||||
for raw_entry in raw_entries:
|
||||
if not isinstance(raw_entry, dict):
|
||||
continue
|
||||
content = str(raw_entry.get("content") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
raw_content_values = raw_entry.get("raw_content")
|
||||
raw_content: Set[str] = set()
|
||||
if isinstance(raw_content_values, list):
|
||||
raw_content = {str(item).strip() for item in raw_content_values if str(item).strip()}
|
||||
normalized_entries.append({"content": content, "raw_content": raw_content})
|
||||
return normalized_entries
|
||||
|
||||
def get_cached_jargons(self) -> List[str]:
|
||||
"""获取缓存中的所有黑话列表"""
|
||||
return list(self.cache.keys())
|
||||
|
||||
async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
|
||||
"""
|
||||
对jargon进行含义推断
|
||||
"""对黑话条目执行含义推断。
|
||||
|
||||
Args:
|
||||
jargon_obj: 待推断的黑话数据对象。
|
||||
"""
|
||||
content = jargon_obj.content
|
||||
# 解析raw_content列表
|
||||
@@ -175,15 +388,45 @@ class JargonMiner:
|
||||
is_similar = comparison_result.get("is_similar", False)
|
||||
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
|
||||
|
||||
finalized_meaning = inference1.get("meaning", "") if is_jargon else ""
|
||||
is_complete = (jargon_obj.count or 0) >= 100
|
||||
last_inference_count = jargon_obj.count or 0
|
||||
finalize_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.inference.before_finalize",
|
||||
session_id=self.session_id,
|
||||
session_name=self.session_name,
|
||||
content=content,
|
||||
count=current_count,
|
||||
raw_content_list=list(raw_content_list),
|
||||
inference_with_context=dict(inference1),
|
||||
inference_with_content_only=dict(inference2),
|
||||
comparison_result=dict(comparison_result),
|
||||
is_jargon=is_jargon,
|
||||
meaning=finalized_meaning,
|
||||
is_complete=is_complete,
|
||||
last_inference_count=last_inference_count,
|
||||
)
|
||||
if finalize_result.aborted:
|
||||
logger.info(f"jargon {content} 的推断结果被 Hook 中止写回")
|
||||
return
|
||||
|
||||
finalize_kwargs = finalize_result.kwargs
|
||||
is_jargon = bool(finalize_kwargs.get("is_jargon", is_jargon))
|
||||
finalized_meaning = str(finalize_kwargs.get("meaning", finalized_meaning) or "").strip() if is_jargon else ""
|
||||
is_complete = bool(finalize_kwargs.get("is_complete", is_complete))
|
||||
last_inference_count = self._coerce_int(
|
||||
finalize_kwargs.get("last_inference_count"),
|
||||
last_inference_count,
|
||||
)
|
||||
|
||||
# 更新数据库记录
|
||||
jargon_obj.is_jargon = is_jargon
|
||||
jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
|
||||
jargon_obj.meaning = finalized_meaning
|
||||
# 更新最后一次判定的count值,避免重启后重复判定
|
||||
jargon_obj.last_inference_count = jargon_obj.count or 0
|
||||
jargon_obj.last_inference_count = last_inference_count
|
||||
|
||||
# 如果count>=100,标记为完成,不再进行推断
|
||||
if (jargon_obj.count or 0) >= 100:
|
||||
jargon_obj.is_complete = True
|
||||
jargon_obj.is_complete = is_complete
|
||||
|
||||
try:
|
||||
self._modify_jargon_entry(jargon_obj)
|
||||
@@ -232,6 +475,22 @@ class JargonMiner:
|
||||
merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
|
||||
|
||||
uniq_entries: List[JargonEntry] = list(merged_entries.values())
|
||||
before_persist_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.extract.before_persist",
|
||||
session_id=self.session_id,
|
||||
session_name=self.session_name,
|
||||
entries=self._serialize_jargon_entries(uniq_entries),
|
||||
)
|
||||
if before_persist_result.aborted:
|
||||
logger.info(f"[{self.session_name}] 黑话提取结果被 Hook 中止,不写入数据库")
|
||||
return
|
||||
|
||||
raw_hook_entries = before_persist_result.kwargs.get("entries")
|
||||
if raw_hook_entries is not None:
|
||||
uniq_entries = self._deserialize_jargon_entries(raw_hook_entries)
|
||||
if not uniq_entries:
|
||||
logger.info(f"[{self.session_name}] Hook 过滤后没有可写入的黑话条目")
|
||||
return
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
|
||||
@@ -54,6 +54,84 @@ class MaisakaReasoningEngine:
|
||||
self._runtime = runtime
|
||||
self._last_reasoning_content: str = ""
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_words(raw_words: Any) -> list[str]:
|
||||
"""清洗黑话查询词条列表。
|
||||
|
||||
Args:
|
||||
raw_words: 原始词条列表。
|
||||
|
||||
Returns:
|
||||
list[str]: 去重去空白后的词条列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_words, list):
|
||||
return []
|
||||
|
||||
normalized_words: list[str] = []
|
||||
seen_words: set[str] = set()
|
||||
for item in raw_words:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
word = item.strip()
|
||||
if not word or word in seen_words:
|
||||
continue
|
||||
seen_words.add(word)
|
||||
normalized_words.append(word)
|
||||
return normalized_words
|
||||
|
||||
@staticmethod
|
||||
def _normalize_jargon_query_results(raw_results: Any) -> list[dict[str, object]]:
|
||||
"""规范化黑话查询结果列表。
|
||||
|
||||
Args:
|
||||
raw_results: Hook 返回的结果列表。
|
||||
|
||||
Returns:
|
||||
list[dict[str, object]]: 清洗后的结果列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_results, list):
|
||||
return []
|
||||
|
||||
normalized_results: list[dict[str, object]] = []
|
||||
for raw_item in raw_results:
|
||||
if not isinstance(raw_item, dict):
|
||||
continue
|
||||
word = str(raw_item.get("word") or "").strip()
|
||||
matches = raw_item.get("matches")
|
||||
normalized_matches: list[dict[str, str]] = []
|
||||
if isinstance(matches, list):
|
||||
for match in matches:
|
||||
if not isinstance(match, dict):
|
||||
continue
|
||||
content = str(match.get("content") or "").strip()
|
||||
meaning = str(match.get("meaning") or "").strip()
|
||||
if not content or not meaning:
|
||||
continue
|
||||
normalized_matches.append({"content": content, "meaning": meaning})
|
||||
|
||||
normalized_results.append(
|
||||
{
|
||||
"word": word,
|
||||
"found": bool(raw_item.get("found", bool(normalized_matches))),
|
||||
"matches": normalized_matches,
|
||||
}
|
||||
)
|
||||
return normalized_results
|
||||
|
||||
def build_builtin_tool_handlers(self) -> dict[str, "BuiltinToolHandler"]:
|
||||
"""构造 Maisaka 内置工具处理器映射。
|
||||
|
||||
@@ -1012,6 +1090,35 @@ class MaisakaReasoningEngine:
|
||||
"查询黑话工具至少需要一个非空词条。",
|
||||
)
|
||||
|
||||
limit = 5
|
||||
case_sensitive = False
|
||||
enable_fuzzy_fallback = True
|
||||
before_search_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.query.before_search",
|
||||
words=list(words),
|
||||
session_id=self._runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
enable_fuzzy_fallback=enable_fuzzy_fallback,
|
||||
abort_message="黑话查询已被 Hook 中止。",
|
||||
)
|
||||
if before_search_result.aborted:
|
||||
abort_message = str(before_search_result.kwargs.get("abort_message") or "黑话查询已被 Hook 中止。").strip()
|
||||
return self._build_tool_failure_result(tool_call.func_name, abort_message or "黑话查询已被 Hook 中止。")
|
||||
|
||||
before_search_kwargs = before_search_result.kwargs
|
||||
if before_search_kwargs.get("words") is not None:
|
||||
words = self._normalize_words(before_search_kwargs.get("words"))
|
||||
if not words:
|
||||
return self._build_tool_failure_result(tool_call.func_name, "Hook 过滤后没有可查询的黑话词条。")
|
||||
try:
|
||||
limit = int(before_search_kwargs.get("limit", limit))
|
||||
except (TypeError, ValueError):
|
||||
limit = 5
|
||||
limit = max(limit, 1)
|
||||
case_sensitive = bool(before_search_kwargs.get("case_sensitive", case_sensitive))
|
||||
enable_fuzzy_fallback = bool(before_search_kwargs.get("enable_fuzzy_fallback", enable_fuzzy_fallback))
|
||||
|
||||
logger.info(f"{self._runtime.log_prefix} 已触发黑话查询: 词条={words!r}")
|
||||
|
||||
results: list[dict[str, object]] = []
|
||||
@@ -1019,17 +1126,19 @@ class MaisakaReasoningEngine:
|
||||
exact_matches = search_jargon(
|
||||
keyword=word,
|
||||
chat_id=self._runtime.session_id,
|
||||
limit=5,
|
||||
case_sensitive=False,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
fuzzy=False,
|
||||
)
|
||||
matched_entries = exact_matches or search_jargon(
|
||||
keyword=word,
|
||||
chat_id=self._runtime.session_id,
|
||||
limit=5,
|
||||
case_sensitive=False,
|
||||
fuzzy=True,
|
||||
)
|
||||
matched_entries = exact_matches
|
||||
if not matched_entries and enable_fuzzy_fallback:
|
||||
matched_entries = search_jargon(
|
||||
keyword=word,
|
||||
chat_id=self._runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
fuzzy=True,
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
@@ -1039,6 +1148,27 @@ class MaisakaReasoningEngine:
|
||||
}
|
||||
)
|
||||
|
||||
after_search_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.query.after_search",
|
||||
words=list(words),
|
||||
session_id=self._runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
enable_fuzzy_fallback=enable_fuzzy_fallback,
|
||||
results=list(results),
|
||||
abort_message="黑话查询结果已被 Hook 中止。",
|
||||
)
|
||||
if after_search_result.aborted:
|
||||
abort_message = str(after_search_result.kwargs.get("abort_message") or "黑话查询结果已被 Hook 中止。").strip()
|
||||
return self._build_tool_failure_result(
|
||||
tool_call.func_name,
|
||||
abort_message or "黑话查询结果已被 Hook 中止。",
|
||||
)
|
||||
|
||||
raw_results = after_search_result.kwargs.get("results")
|
||||
if raw_results is not None:
|
||||
results = self._normalize_jargon_query_results(raw_results)
|
||||
|
||||
logger.info(f"{self._runtime.log_prefix} 黑话查询完成: 结果={results!r}")
|
||||
return self._build_tool_success_result(
|
||||
tool_call.func_name,
|
||||
|
||||
@@ -20,11 +20,17 @@ def _get_builtin_hook_spec_registrars() -> List[HookSpecRegistrar]:
|
||||
"""
|
||||
|
||||
from src.chat.message_receive.bot import register_chat_hook_specs
|
||||
from src.chat.emoji_system.emoji_manager import register_emoji_hook_specs
|
||||
from src.learners.expression_learner import register_expression_hook_specs
|
||||
from src.learners.jargon_miner import register_jargon_hook_specs
|
||||
from src.maisaka.chat_loop_service import register_maisaka_hook_specs
|
||||
from src.services.send_service import register_send_service_hook_specs
|
||||
|
||||
return [
|
||||
register_chat_hook_specs,
|
||||
register_emoji_hook_specs,
|
||||
register_jargon_hook_specs,
|
||||
register_expression_hook_specs,
|
||||
register_send_service_hook_specs,
|
||||
register_maisaka_hook_specs,
|
||||
]
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
logger = get_logger("webui.logs_ws")
|
||||
router = APIRouter()
|
||||
@@ -148,24 +149,9 @@ async def broadcast_log(log_data: Dict):
|
||||
Args:
|
||||
log_data: 日志数据字典
|
||||
"""
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
# 格式化为 JSON
|
||||
message = json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
# 记录需要断开的连接
|
||||
disconnected = set()
|
||||
|
||||
# 广播到所有客户端
|
||||
for connection in active_connections:
|
||||
try:
|
||||
await connection.send_text(message)
|
||||
except Exception:
|
||||
# 发送失败,标记为断开
|
||||
disconnected.add(connection)
|
||||
|
||||
# 清理断开的连接
|
||||
if disconnected:
|
||||
active_connections.difference_update(disconnected)
|
||||
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
|
||||
await websocket_manager.broadcast_to_topic(
|
||||
domain="logs",
|
||||
topic="main",
|
||||
event="entry",
|
||||
data={"entry": log_data},
|
||||
)
|
||||
|
||||
@@ -18,12 +18,10 @@ def get_all_routers() -> List[APIRouter]:
|
||||
from src.webui.api.replier import router as replier_router
|
||||
from src.webui.routers.chat import router as chat_router
|
||||
from src.webui.routers.knowledge import router as knowledge_router
|
||||
from src.webui.routers.websocket.logs import router as logs_router
|
||||
from src.webui.routes import router as main_router
|
||||
|
||||
return [
|
||||
main_router,
|
||||
logs_router,
|
||||
knowledge_router,
|
||||
chat_router,
|
||||
planner_router,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Tuple
|
||||
|
||||
from .routes import router
|
||||
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
from .service import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
|
||||
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -13,16 +12,11 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
from .support import (
|
||||
from .service import (
|
||||
WEBUI_CHAT_GROUP_ID,
|
||||
WEBUI_CHAT_PLATFORM,
|
||||
authenticate_chat_websocket,
|
||||
chat_history,
|
||||
chat_manager,
|
||||
dispatch_chat_event,
|
||||
normalize_webui_user_id,
|
||||
resolve_initial_virtual_identity,
|
||||
send_initial_chat_state,
|
||||
)
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
@@ -113,55 +107,6 @@ async def clear_chat_history(
|
||||
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
token: Optional[str] = Query(default=None),
|
||||
) -> None:
|
||||
"""WebSocket 聊天端点。"""
|
||||
if not await authenticate_chat_websocket(websocket, token):
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
normalized_user_id = normalize_webui_user_id(user_id)
|
||||
current_user_name = user_name or "WebUI用户"
|
||||
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
|
||||
|
||||
await chat_manager.connect(websocket, session_id, normalized_user_id)
|
||||
try:
|
||||
await send_initial_chat_state(
|
||||
session_id=session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
current_user_name, current_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data=data,
|
||||
current_user_name=current_user_name,
|
||||
normalized_user_id=normalized_user_id,
|
||||
current_virtual_config=current_virtual_config,
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, normalized_user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info() -> Dict[str, object]:
|
||||
"""获取聊天室信息。"""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""WebUI 聊天路由支持逻辑。"""
|
||||
"""WebUI 聊天运行时服务。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, delete, select
|
||||
|
||||
@@ -17,8 +17,6 @@ from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
@@ -27,6 +25,8 @@ WEBUI_CHAT_PLATFORM = "webui"
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
AsyncMessageSender = Callable[[Dict[str, Any]], Awaitable[None]]
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置。"""
|
||||
@@ -52,13 +52,42 @@ class ChatHistoryMessage(BaseModel):
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSessionConnection:
|
||||
"""逻辑聊天会话连接信息。"""
|
||||
|
||||
session_id: str
|
||||
connection_id: str
|
||||
client_session_id: str
|
||||
user_id: str
|
||||
user_name: str
|
||||
active_group_id: str
|
||||
virtual_config: Optional[VirtualIdentityConfig]
|
||||
sender: AsyncMessageSender
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器。"""
|
||||
|
||||
def __init__(self, max_messages: int = 200) -> None:
|
||||
"""初始化聊天历史管理器。
|
||||
|
||||
Args:
|
||||
max_messages: 内存中允许处理的最大消息数。
|
||||
"""
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将内部消息对象转换为前端可消费的字典。
|
||||
|
||||
Args:
|
||||
msg: 内部统一消息对象。
|
||||
group_id: 当前会话所属的群组标识。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 面向 WebUI 的消息字典。
|
||||
"""
|
||||
del group_id
|
||||
user_info = msg.message_info.user_info
|
||||
user_id = user_info.user_id or ""
|
||||
is_bot = is_bot_self(msg.platform, user_id)
|
||||
@@ -74,10 +103,27 @@ class ChatHistoryManager:
|
||||
}
|
||||
|
||||
def _resolve_session_id(self, group_id: Optional[str]) -> str:
|
||||
"""根据群组标识解析聊天会话 ID。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
str: 内部聊天会话 ID。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""获取指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
limit: 最大返回条数。
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 历史消息列表。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -90,11 +136,19 @@ class ChatHistoryManager:
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"从数据库加载聊天记录失败: {exc}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
"""清空指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
int: 被删除的消息数量。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -104,66 +158,245 @@ class ChatHistoryManager:
|
||||
deleted = result.rowcount or 0
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"清空聊天记录失败: {exc}")
|
||||
return 0
|
||||
|
||||
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器。"""
|
||||
"""统一聊天逻辑会话管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {}
|
||||
"""初始化聊天逻辑会话管理器。"""
|
||||
self.active_connections: Dict[str, ChatSessionConnection] = {}
|
||||
self.client_sessions: Dict[Tuple[str, str], str] = {}
|
||||
self.connection_sessions: Dict[str, Set[str]] = {}
|
||||
self.group_sessions: Dict[str, Set[str]] = {}
|
||||
self.user_sessions: Dict[str, Set[str]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
def _bind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""为会话绑定群组索引。
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str) -> None:
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.setdefault(group_id, set())
|
||||
group_session_ids.add(session_id)
|
||||
|
||||
def _unbind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""移除会话与群组的索引关系。
|
||||
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.get(group_id)
|
||||
if group_session_ids is None:
|
||||
return
|
||||
|
||||
group_session_ids.discard(session_id)
|
||||
if not group_session_ids:
|
||||
del self.group_sessions[group_id]
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
session_id: str,
|
||||
connection_id: str,
|
||||
client_session_id: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
sender: AsyncMessageSender,
|
||||
) -> None:
|
||||
"""注册一个新的逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
client_session_id: 前端标签页使用的会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前展示昵称。
|
||||
virtual_config: 当前虚拟身份配置。
|
||||
sender: 发送消息到前端的异步回调。
|
||||
"""
|
||||
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
|
||||
if existing_session_id is not None:
|
||||
self.disconnect(existing_session_id)
|
||||
|
||||
active_group_id = get_current_group_id(virtual_config)
|
||||
session_connection = ChatSessionConnection(
|
||||
session_id=session_id,
|
||||
connection_id=connection_id,
|
||||
client_session_id=client_session_id,
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
active_group_id=active_group_id,
|
||||
virtual_config=virtual_config,
|
||||
sender=sender,
|
||||
)
|
||||
|
||||
self.active_connections[session_id] = session_connection
|
||||
self.client_sessions[(connection_id, client_session_id)] = session_id
|
||||
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
|
||||
self.user_sessions.setdefault(user_id, set()).add(session_id)
|
||||
self._bind_group(session_id, active_group_id)
|
||||
logger.info(
|
||||
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
|
||||
session_id,
|
||||
connection_id,
|
||||
client_session_id,
|
||||
user_id,
|
||||
active_group_id,
|
||||
)
|
||||
|
||||
def disconnect(self, session_id: str) -> None:
|
||||
"""断开一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
session_connection = self.active_connections.pop(session_id, None)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
|
||||
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
|
||||
if connection_session_ids is not None:
|
||||
connection_session_ids.discard(session_id)
|
||||
if not connection_session_ids:
|
||||
del self.connection_sessions[session_connection.connection_id]
|
||||
|
||||
user_session_ids = self.user_sessions.get(session_connection.user_id)
|
||||
if user_session_ids is not None:
|
||||
user_session_ids.discard(session_id)
|
||||
if not user_session_ids:
|
||||
del self.user_sessions[session_connection.user_id]
|
||||
|
||||
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
|
||||
|
||||
def disconnect_connection(self, connection_id: str) -> None:
|
||||
"""断开物理连接下的全部逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
"""
|
||||
session_ids = list(self.connection_sessions.get(connection_id, set()))
|
||||
for session_id in session_ids:
|
||||
self.disconnect(session_id)
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[ChatSessionConnection]:
|
||||
"""获取逻辑聊天会话信息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[ChatSessionConnection]: 会话存在时返回对应信息。
|
||||
"""
|
||||
return self.active_connections.get(session_id)
|
||||
|
||||
def get_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
|
||||
"""根据连接 ID 和前端会话 ID 查询内部会话 ID。
|
||||
|
||||
Args:
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
client_session_id: 前端标签页使用的会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 找到时返回内部会话 ID。
|
||||
"""
|
||||
return self.client_sessions.get((connection_id, client_session_id))
|
||||
|
||||
def update_session_context(
|
||||
self,
|
||||
session_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> None:
|
||||
"""更新会话上下文信息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_name: 最新昵称。
|
||||
virtual_config: 最新虚拟身份配置。
|
||||
"""
|
||||
session_connection = self.active_connections.get(session_id)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
next_group_id = get_current_group_id(virtual_config)
|
||||
if next_group_id != session_connection.active_group_id:
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
self._bind_group(session_id, next_group_id)
|
||||
session_connection.active_group_id = next_group_id
|
||||
|
||||
session_connection.user_name = user_name
|
||||
session_connection.virtual_config = virtual_config
|
||||
|
||||
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
"""向指定逻辑会话发送消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
message: 待发送的消息内容。
|
||||
"""
|
||||
session_connection = self.active_connections.get(session_id)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await session_connection.sender(message)
|
||||
except Exception as exc:
|
||||
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||
"""向全部逻辑聊天会话广播消息。
|
||||
|
||||
Args:
|
||||
message: 待广播的消息内容。
|
||||
"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
|
||||
"""向指定群组下的全部逻辑会话广播消息。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
message: 待广播的消息内容。
|
||||
"""
|
||||
for session_id in list(self.group_sessions.get(group_id, set())):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_history = ChatHistoryManager()
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
|
||||
"""判断当前是否启用了虚拟身份模式。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
bool: 已启用时返回 ``True``。
|
||||
"""
|
||||
return bool(virtual_config and virtual_config.enabled)
|
||||
|
||||
|
||||
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
|
||||
if token and verify_ws_token(token):
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
return True
|
||||
|
||||
if cookie_token := websocket.cookies.get("maibot_session"):
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
"""标准化 WebUI 用户 ID。
|
||||
|
||||
Args:
|
||||
user_id: 原始用户 ID。
|
||||
|
||||
Returns:
|
||||
str: 带统一前缀的用户 ID。
|
||||
"""
|
||||
if not user_id:
|
||||
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
if user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
@@ -172,12 +405,30 @@ def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
|
||||
|
||||
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
|
||||
"""根据人物 ID 查询人物信息。
|
||||
|
||||
Args:
|
||||
person_id: 人物 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PersonInfo]: 查到时返回人物信息。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
|
||||
"""根据人物信息构建虚拟身份配置。
|
||||
|
||||
Args:
|
||||
person: 人物信息对象。
|
||||
group_id: 逻辑群组 ID。
|
||||
group_name: 逻辑群组名称。
|
||||
|
||||
Returns:
|
||||
VirtualIdentityConfig: 虚拟身份配置对象。
|
||||
"""
|
||||
return VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
@@ -195,6 +446,17 @@ def resolve_initial_virtual_identity(
|
||||
group_name: Optional[str],
|
||||
group_id: Optional[str],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
"""根据初始参数解析虚拟身份配置。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
person_id: 人物 ID。
|
||||
group_name: 群组名称。
|
||||
group_id: 群组 ID。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 解析成功时返回虚拟身份配置。
|
||||
"""
|
||||
if not (platform and person_id):
|
||||
return None
|
||||
|
||||
@@ -210,11 +472,14 @@ def resolve_initial_virtual_identity(
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
|
||||
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
|
||||
virtual_config.user_nickname,
|
||||
virtual_config.platform,
|
||||
virtual_group_id,
|
||||
)
|
||||
return virtual_config
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"通过参数配置虚拟身份失败: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -224,6 +489,17 @@ def build_session_info_message(
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Dict[str, Any]:
|
||||
"""构建会话信息消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前昵称。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 会话信息消息。
|
||||
"""
|
||||
session_info_data: Dict[str, Any] = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
@@ -247,13 +523,41 @@ def build_session_info_message(
|
||||
|
||||
|
||||
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
|
||||
"""获取当前虚拟身份对应的历史群组 ID。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 虚拟身份启用时返回对应群组 ID。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.group_id
|
||||
return None
|
||||
|
||||
|
||||
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
"""获取当前会话的有效群组 ID。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 当前会话应使用的群组 ID。
|
||||
"""
|
||||
return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
|
||||
|
||||
|
||||
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
"""构建欢迎消息。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 欢迎消息文本。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return (
|
||||
@@ -264,6 +568,12 @@ def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> st
|
||||
|
||||
|
||||
async def send_chat_error(session_id: str, content: str) -> None:
|
||||
"""向指定会话发送错误消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
content: 错误消息内容。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
@@ -279,7 +589,17 @@ async def send_initial_chat_state(
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
include_welcome: bool = True,
|
||||
) -> None:
|
||||
"""向新会话发送初始化状态。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前昵称。
|
||||
virtual_config: 虚拟身份配置。
|
||||
include_welcome: 是否发送欢迎消息。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
build_session_info_message(
|
||||
@@ -290,30 +610,43 @@ async def send_initial_chat_state(
|
||||
),
|
||||
)
|
||||
|
||||
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
history_group_id = get_active_history_group_id(virtual_config)
|
||||
history = chat_history.get_history(50, history_group_id)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": build_welcome_message(virtual_config),
|
||||
"timestamp": time.time(),
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
"group_id": get_current_group_id(virtual_config),
|
||||
},
|
||||
)
|
||||
|
||||
if include_welcome:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": build_welcome_message(virtual_config),
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resolve_sender_identity(
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Tuple[str, str]:
|
||||
"""解析当前发送者身份。
|
||||
|
||||
Args:
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: ``(发送者昵称, 发送者用户 ID)``。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
|
||||
@@ -328,6 +661,19 @@ def create_message_data(
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建发送给聊天核心的消息数据。
|
||||
|
||||
Args:
|
||||
content: 文本内容。
|
||||
user_id: 用户 ID。
|
||||
user_name: 用户昵称。
|
||||
message_id: 消息 ID。
|
||||
is_at_bot: 是否默认艾特机器人。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 聊天核心可处理的消息数据。
|
||||
"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
@@ -389,6 +735,18 @@ async def handle_chat_message(
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> str:
|
||||
"""处理用户发送的聊天消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
data: 前端提交的消息数据。
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 处理后的最新昵称。
|
||||
"""
|
||||
content = str(data.get("content", "")).strip()
|
||||
if not content:
|
||||
return current_user_name
|
||||
@@ -401,11 +759,14 @@ async def handle_chat_message(
|
||||
normalized_user_id=normalized_user_id,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
target_group_id = get_current_group_id(current_virtual_config)
|
||||
|
||||
await chat_manager.broadcast(
|
||||
await chat_manager.broadcast_to_group(
|
||||
target_group_id,
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"group_id": target_group_id,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
@@ -414,7 +775,7 @@ async def handle_chat_message(
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
message_data = create_message_data(
|
||||
@@ -427,22 +788,37 @@ async def handle_chat_message(
|
||||
)
|
||||
|
||||
try:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": True})
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
|
||||
await chat_bot.message_process(message_data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
|
||||
except Exception as exc:
|
||||
logger.error(f"处理消息时出错: {exc}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
|
||||
finally:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": False})
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
|
||||
|
||||
return next_user_name
|
||||
|
||||
|
||||
async def handle_chat_ping(session_id: str) -> None:
|
||||
"""处理聊天心跳。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
|
||||
|
||||
|
||||
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
|
||||
"""处理昵称更新请求。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
data: 前端提交的数据。
|
||||
current_user_name: 当前昵称。
|
||||
|
||||
Returns:
|
||||
str: 更新后的昵称。
|
||||
"""
|
||||
new_name = str(data.get("user_name", "")).strip()
|
||||
if not new_name:
|
||||
return current_user_name
|
||||
@@ -463,6 +839,16 @@ async def enable_virtual_identity(
|
||||
session_prefix: str,
|
||||
virtual_data: Dict[str, Any],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
"""启用虚拟身份模式。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_prefix: 会话前缀,用于生成默认群组 ID。
|
||||
virtual_data: 前端提交的虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 启用成功时返回新的虚拟身份配置。
|
||||
"""
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
|
||||
return None
|
||||
@@ -470,16 +856,18 @@ async def enable_virtual_identity(
|
||||
person_id_value = str(virtual_data.get("person_id"))
|
||||
try:
|
||||
person = get_person_by_person_id(person_id_value)
|
||||
if not person:
|
||||
if person is None:
|
||||
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
|
||||
return None
|
||||
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
current_group_id = (
|
||||
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
if custom_group_id
|
||||
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
|
||||
)
|
||||
custom_group_id = str(virtual_data.get("group_id") or "").strip()
|
||||
if custom_group_id:
|
||||
current_group_id = custom_group_id
|
||||
if not current_group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{current_group_id}"
|
||||
else:
|
||||
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
|
||||
|
||||
current_virtual_config = build_virtual_identity_config(
|
||||
person=person,
|
||||
group_id=current_group_id,
|
||||
@@ -521,13 +909,18 @@ async def enable_virtual_identity(
|
||||
},
|
||||
)
|
||||
return current_virtual_config
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
|
||||
except Exception as exc:
|
||||
logger.error(f"设置虚拟身份失败: {exc}")
|
||||
await send_chat_error(session_id, f"设置虚拟身份失败: {str(exc)}")
|
||||
return None
|
||||
|
||||
|
||||
async def disable_virtual_identity(session_id: str) -> None:
|
||||
"""关闭虚拟身份模式。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
@@ -560,7 +953,18 @@ async def handle_virtual_identity_update(
|
||||
data: Dict[str, Any],
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
virtual_data = cast(dict[str, Any], data.get("config", {}))
|
||||
"""处理虚拟身份切换请求。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_id_prefix: 会话前缀。
|
||||
data: 前端提交的数据。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置。
|
||||
"""
|
||||
virtual_data = cast(Dict[str, Any], data.get("config", {}))
|
||||
if virtual_data.get("enabled"):
|
||||
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
|
||||
return next_config if next_config is not None else current_virtual_config
|
||||
@@ -577,6 +981,19 @@ async def dispatch_chat_event(
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
|
||||
"""分发聊天事件到对应的处理器。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_id_prefix: 会话前缀。
|
||||
data: 前端提交的数据。
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Tuple[str, Optional[VirtualIdentityConfig]]: ``(最新昵称, 最新虚拟身份配置)``。
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
if event_type == "message":
|
||||
next_user_name = await handle_chat_message(
|
||||
@@ -1,12 +1,15 @@
|
||||
"""插件进度实时推送支持。"""
|
||||
|
||||
from typing import Any, Dict, Optional, Set
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
logger = get_logger("webui.plugin_progress")
|
||||
|
||||
@@ -25,25 +28,29 @@ current_progress: Dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
def get_current_progress() -> Dict[str, Any]:
|
||||
"""获取当前插件进度快照。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 当前插件进度数据副本。
|
||||
"""
|
||||
return current_progress.copy()
|
||||
|
||||
|
||||
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
|
||||
"""向统一连接层广播插件进度更新。
|
||||
|
||||
Args:
|
||||
progress_data: 插件进度数据。
|
||||
"""
|
||||
global current_progress
|
||||
current_progress = progress_data.copy()
|
||||
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
message = json.dumps(progress_data, ensure_ascii=False)
|
||||
disconnected: Set[WebSocket] = set()
|
||||
|
||||
for websocket in active_connections:
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送进度更新失败: {e}")
|
||||
disconnected.add(websocket)
|
||||
|
||||
for websocket in disconnected:
|
||||
active_connections.discard(websocket)
|
||||
await websocket_manager.broadcast_to_topic(
|
||||
domain="plugin_progress",
|
||||
topic="main",
|
||||
event="update",
|
||||
data={"progress": progress_data},
|
||||
)
|
||||
|
||||
|
||||
async def update_progress(
|
||||
@@ -56,6 +63,18 @@ async def update_progress(
|
||||
total_plugins: int = 0,
|
||||
loaded_plugins: int = 0,
|
||||
) -> None:
|
||||
"""更新当前插件进度并广播。
|
||||
|
||||
Args:
|
||||
stage: 当前阶段。
|
||||
progress: 当前进度百分比。
|
||||
message: 进度说明消息。
|
||||
operation: 当前操作类型。
|
||||
error: 可选的错误信息。
|
||||
plugin_id: 当前处理的插件 ID。
|
||||
total_plugins: 总插件数量。
|
||||
loaded_plugins: 已处理插件数量。
|
||||
"""
|
||||
progress_data = {
|
||||
"operation": operation,
|
||||
"stage": stage,
|
||||
@@ -74,6 +93,12 @@ async def update_progress(
|
||||
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
|
||||
"""旧版插件进度 WebSocket 入口。
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
token: 可选的一次性握手 Token。
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
if token and verify_ws_token(token):
|
||||
@@ -105,17 +130,22 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
||||
data = await websocket.receive_text()
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端消息时出错: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"处理客户端消息时出错: {exc}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebSocket 错误: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"❌ WebSocket 错误: {exc}")
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
def get_progress_router() -> APIRouter:
|
||||
"""获取旧版插件进度路由对象。
|
||||
|
||||
Returns:
|
||||
APIRouter: 插件进度路由对象。
|
||||
"""
|
||||
return router
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""WebSocket 路由聚合导出。"""
|
||||
|
||||
from .auth import router as ws_auth_router
|
||||
from .logs import router as logs_router
|
||||
from .unified import router as unified_ws_router
|
||||
|
||||
__all__ = [
|
||||
"logs_router",
|
||||
"unified_ws_router",
|
||||
"ws_auth_router",
|
||||
]
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
"""WebSocket 日志推送路由兼容导出。"""
|
||||
|
||||
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
|
||||
|
||||
__all__ = [
|
||||
"active_connections",
|
||||
"broadcast_log",
|
||||
"load_recent_logs",
|
||||
"router",
|
||||
"websocket_logs",
|
||||
]
|
||||
297
src/webui/routers/websocket/manager.py
Normal file
297
src/webui/routers/websocket/manager.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""统一 WebSocket 连接管理器。"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.websocket")
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketConnection:
|
||||
"""统一 WebSocket 连接上下文。"""
|
||||
|
||||
connection_id: str
|
||||
websocket: WebSocket
|
||||
subscriptions: Set[str] = field(default_factory=set)
|
||||
chat_sessions: Dict[str, str] = field(default_factory=dict)
|
||||
send_queue: "asyncio.Queue[Optional[Dict[str, Any]]]" = field(default_factory=asyncio.Queue)
|
||||
sender_task: Optional["asyncio.Task[None]"] = None
|
||||
|
||||
|
||||
class UnifiedWebSocketManager:
|
||||
"""统一 WebSocket 连接管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化统一 WebSocket 连接管理器。"""
|
||||
self.connections: Dict[str, WebSocketConnection] = {}
|
||||
|
||||
def _build_subscription_key(self, domain: str, topic: str) -> str:
|
||||
"""构建订阅索引键。
|
||||
|
||||
Args:
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
|
||||
Returns:
|
||||
str: 订阅索引键。
|
||||
"""
|
||||
return f"{domain}:{topic}"
|
||||
|
||||
async def _sender_loop(self, connection: WebSocketConnection) -> None:
|
||||
"""串行发送指定连接的出站消息。
|
||||
|
||||
Args:
|
||||
connection: 目标连接上下文。
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
message = await connection.send_queue.get()
|
||||
if message is None:
|
||||
return
|
||||
await connection.websocket.send_json(message)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("统一 WebSocket 发送失败: connection=%s, error=%s", connection.connection_id, exc)
|
||||
|
||||
async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
|
||||
"""注册一个新的物理 WebSocket 连接。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
|
||||
Returns:
|
||||
WebSocketConnection: 新建的连接上下文。
|
||||
"""
|
||||
await websocket.accept()
|
||||
connection = WebSocketConnection(connection_id=connection_id, websocket=websocket)
|
||||
connection.sender_task = asyncio.create_task(self._sender_loop(connection))
|
||||
self.connections[connection_id] = connection
|
||||
return connection
|
||||
|
||||
async def disconnect(self, connection_id: str) -> None:
|
||||
"""断开并清理指定连接。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
"""
|
||||
connection = self.connections.pop(connection_id, None)
|
||||
if connection is None:
|
||||
return
|
||||
|
||||
await connection.send_queue.put(None)
|
||||
if connection.sender_task is not None:
|
||||
try:
|
||||
await connection.sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.debug("等待发送协程退出时出现异常: connection=%s, error=%s", connection_id, exc)
|
||||
|
||||
def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
|
||||
"""获取指定连接上下文。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
|
||||
Returns:
|
||||
Optional[WebSocketConnection]: 找到时返回连接上下文。
|
||||
"""
|
||||
return self.connections.get(connection_id)
|
||||
|
||||
def register_chat_session(self, connection_id: str, client_session_id: str, session_id: str) -> None:
|
||||
"""登记连接下的逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
session_id: 内部会话 ID。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.chat_sessions[client_session_id] = session_id
|
||||
|
||||
def unregister_chat_session(self, connection_id: str, client_session_id: str) -> None:
|
||||
"""移除连接下的逻辑聊天会话登记。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.chat_sessions.pop(client_session_id, None)
|
||||
|
||||
def get_chat_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
|
||||
"""查询连接下的内部聊天会话 ID。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 找到时返回内部会话 ID。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return None
|
||||
return connection.chat_sessions.get(client_session_id)
|
||||
|
||||
def subscribe(self, connection_id: str, domain: str, topic: str) -> None:
|
||||
"""登记连接的主题订阅。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.subscriptions.add(self._build_subscription_key(domain, topic))
|
||||
|
||||
def unsubscribe(self, connection_id: str, domain: str, topic: str) -> None:
|
||||
"""移除连接的主题订阅。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.subscriptions.discard(self._build_subscription_key(domain, topic))
|
||||
|
||||
def is_subscribed(self, connection_id: str, domain: str, topic: str) -> bool:
|
||||
"""判断连接是否订阅了指定主题。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
|
||||
Returns:
|
||||
bool: 已订阅时返回 ``True``。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return False
|
||||
return self._build_subscription_key(domain, topic) in connection.subscriptions
|
||||
|
||||
async def enqueue(self, connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""向指定连接的发送队列压入消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 待发送的消息。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
await connection.send_queue.put(message)
|
||||
|
||||
async def send_response(
|
||||
self,
|
||||
connection_id: str,
|
||||
request_id: Optional[str],
|
||||
ok: bool,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
error: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""发送统一响应消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
request_id: 请求 ID。
|
||||
ok: 请求是否成功。
|
||||
data: 成功响应数据。
|
||||
error: 失败响应数据。
|
||||
"""
|
||||
response_message: Dict[str, Any] = {
|
||||
"op": "response",
|
||||
"id": request_id,
|
||||
"ok": ok,
|
||||
}
|
||||
if data is not None:
|
||||
response_message["data"] = data
|
||||
if error is not None:
|
||||
response_message["error"] = error
|
||||
await self.enqueue(connection_id, response_message)
|
||||
|
||||
async def send_event(
|
||||
self,
|
||||
connection_id: str,
|
||||
domain: str,
|
||||
event: str,
|
||||
data: Dict[str, Any],
|
||||
session: Optional[str] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> None:
|
||||
"""发送统一事件消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
event: 事件名称。
|
||||
data: 事件数据。
|
||||
session: 可选的逻辑会话 ID。
|
||||
topic: 可选的主题名称。
|
||||
"""
|
||||
event_message: Dict[str, Any] = {
|
||||
"op": "event",
|
||||
"domain": domain,
|
||||
"event": event,
|
||||
"data": data,
|
||||
}
|
||||
if session is not None:
|
||||
event_message["session"] = session
|
||||
if topic is not None:
|
||||
event_message["topic"] = topic
|
||||
await self.enqueue(connection_id, event_message)
|
||||
|
||||
async def send_pong(self, connection_id: str, timestamp: float) -> None:
|
||||
"""发送心跳响应。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
timestamp: 当前时间戳。
|
||||
"""
|
||||
await self.enqueue(
|
||||
connection_id,
|
||||
{
|
||||
"op": "pong",
|
||||
"ts": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
async def broadcast_to_topic(self, domain: str, topic: str, event: str, data: Dict[str, Any]) -> None:
|
||||
"""向订阅指定主题的全部连接广播事件。
|
||||
|
||||
Args:
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
event: 事件名称。
|
||||
data: 事件数据。
|
||||
"""
|
||||
subscription_key = self._build_subscription_key(domain, topic)
|
||||
for connection in list(self.connections.values()):
|
||||
if subscription_key in connection.subscriptions:
|
||||
await self.send_event(
|
||||
connection.connection_id,
|
||||
domain=domain,
|
||||
event=event,
|
||||
data=data,
|
||||
topic=topic,
|
||||
)
|
||||
|
||||
|
||||
websocket_manager = UnifiedWebSocketManager()
|
||||
548
src/webui/routers/websocket/unified.py
Normal file
548
src/webui/routers/websocket/unified.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""统一 WebSocket 路由。"""
|
||||
|
||||
from typing import Any, Dict, Optional, Set, cast
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.logs_ws import load_recent_logs
|
||||
from src.webui.routers.chat.service import (
|
||||
chat_manager,
|
||||
dispatch_chat_event,
|
||||
normalize_webui_user_id,
|
||||
resolve_initial_virtual_identity,
|
||||
send_initial_chat_state,
|
||||
)
|
||||
from src.webui.routers.plugin.progress import get_current_progress
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
logger = get_logger("webui.unified_ws")
|
||||
router = APIRouter()
|
||||
_background_tasks: Set["asyncio.Task[None]"] = set()
|
||||
|
||||
|
||||
def _build_error(code: str, message: str) -> Dict[str, Any]:
|
||||
"""构建统一错误响应体。
|
||||
|
||||
Args:
|
||||
code: 错误码。
|
||||
message: 错误描述。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 统一错误对象。
|
||||
"""
|
||||
return {
|
||||
"code": code,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""从客户端消息中提取数据字段。
|
||||
|
||||
Args:
|
||||
message: 客户端消息。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 标准化后的数据字典。
|
||||
"""
|
||||
data = message.get("data", {})
|
||||
if isinstance(data, dict):
|
||||
return cast(Dict[str, Any], data)
|
||||
return {}
|
||||
|
||||
|
||||
def _track_background_task(task: "asyncio.Task[None]") -> None:
|
||||
"""登记后台任务并在完成后自动清理。
|
||||
|
||||
Args:
|
||||
task: 后台协程任务。
|
||||
"""
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
|
||||
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
|
||||
"""校验统一 WebSocket 连接的认证状态。
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
token: 可选的一次性握手 Token。
|
||||
|
||||
Returns:
|
||||
bool: 认证通过时返回 ``True``。
|
||||
"""
|
||||
if token and verify_ws_token(token):
|
||||
logger.debug("统一 WebSocket 使用临时 token 认证成功")
|
||||
return True
|
||||
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
logger.debug("统一 WebSocket 使用 Cookie 认证成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
|
||||
"""处理日志域订阅请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
request_id: 请求 ID。
|
||||
data: 订阅参数。
|
||||
"""
|
||||
replay_limit = int(data.get("replay", 100) or 100)
|
||||
replay_limit = max(0, min(replay_limit, 500))
|
||||
websocket_manager.subscribe(connection_id, domain="logs", topic="main")
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"domain": "logs", "topic": "main"},
|
||||
)
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="logs",
|
||||
event="snapshot",
|
||||
topic="main",
|
||||
data={"entries": load_recent_logs(limit=replay_limit)},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
|
||||
"""处理插件进度域订阅请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
request_id: 请求 ID。
|
||||
"""
|
||||
websocket_manager.subscribe(connection_id, domain="plugin_progress", topic="main")
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"domain": "plugin_progress", "topic": "main"},
|
||||
)
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="plugin_progress",
|
||||
event="snapshot",
|
||||
topic="main",
|
||||
data={"progress": get_current_progress()},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理主题订阅请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
domain = str(message.get("domain") or "").strip()
|
||||
topic = str(message.get("topic") or "").strip()
|
||||
data = _get_request_data(message)
|
||||
|
||||
if domain == "logs" and topic == "main":
|
||||
await _handle_logs_subscribe(connection_id, request_id, data)
|
||||
return
|
||||
|
||||
if domain == "plugin_progress" and topic == "main":
|
||||
await _handle_plugin_progress_subscribe(connection_id, request_id)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理主题退订请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
domain = str(message.get("domain") or "").strip()
|
||||
topic = str(message.get("topic") or "").strip()
|
||||
|
||||
if not domain or not topic:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
|
||||
)
|
||||
return
|
||||
|
||||
websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"domain": domain, "topic": topic},
|
||||
)
|
||||
|
||||
|
||||
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""打开一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
if not client_session_id:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
|
||||
)
|
||||
return
|
||||
|
||||
data = _get_request_data(message)
|
||||
normalized_user_id = normalize_webui_user_id(cast(Optional[str], data.get("user_id")))
|
||||
current_user_name = str(data.get("user_name") or "WebUI用户")
|
||||
current_virtual_config = resolve_initial_virtual_identity(
|
||||
platform=cast(Optional[str], data.get("platform")),
|
||||
person_id=cast(Optional[str], data.get("person_id")),
|
||||
group_name=cast(Optional[str], data.get("group_name")),
|
||||
group_id=cast(Optional[str], data.get("group_id")),
|
||||
)
|
||||
restore = bool(data.get("restore"))
|
||||
session_id = f"{connection_id}:{client_session_id}"
|
||||
|
||||
async def send_chat_event(chat_message: Dict[str, Any]) -> None:
|
||||
"""将聊天消息封装为统一事件并发送。
|
||||
|
||||
Args:
|
||||
chat_message: 聊天消息体。
|
||||
"""
|
||||
event_name = str(chat_message.get("type") or "message")
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="chat",
|
||||
event=event_name,
|
||||
session=client_session_id,
|
||||
data=chat_message,
|
||||
)
|
||||
|
||||
await chat_manager.connect(
|
||||
session_id=session_id,
|
||||
connection_id=connection_id,
|
||||
client_session_id=client_session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
sender=send_chat_event,
|
||||
)
|
||||
websocket_manager.register_chat_session(connection_id, client_session_id, session_id)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"session": client_session_id, "session_id": session_id},
|
||||
)
|
||||
await send_initial_chat_state(
|
||||
session_id=session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
include_welcome=not restore,
|
||||
)
|
||||
|
||||
|
||||
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""关闭一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
chat_manager.disconnect(session_id)
|
||||
websocket_manager.unregister_chat_session(connection_id, client_session_id)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"session": client_session_id},
|
||||
)
|
||||
|
||||
|
||||
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
|
||||
"""在后台处理聊天消息事件。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
data: 客户端提交的消息数据。
|
||||
"""
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
return
|
||||
|
||||
session_state = chat_manager.get_session(session_id)
|
||||
if session_state is None:
|
||||
return
|
||||
|
||||
next_user_name, next_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data=data,
|
||||
current_user_name=session_state.user_name,
|
||||
normalized_user_id=session_state.user_id,
|
||||
current_virtual_config=session_state.virtual_config,
|
||||
)
|
||||
chat_manager.update_session_context(
|
||||
session_id=session_id,
|
||||
user_name=next_user_name,
|
||||
virtual_config=next_virtual_config,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理聊天消息发送请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
data = _get_request_data(message)
|
||||
payload = {
|
||||
"type": "message",
|
||||
"content": data.get("content", ""),
|
||||
"user_name": data.get("user_name", ""),
|
||||
}
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"accepted": True, "session": client_session_id},
|
||||
)
|
||||
_track_background_task(asyncio.create_task(_process_chat_message(connection_id, client_session_id, payload)))
|
||||
|
||||
|
||||
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理聊天昵称更新请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
data = _get_request_data(message)
|
||||
session_state = chat_manager.get_session(session_id)
|
||||
if session_state is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
next_user_name, next_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data={
|
||||
"type": "update_nickname",
|
||||
"user_name": data.get("user_name", ""),
|
||||
},
|
||||
current_user_name=session_state.user_name,
|
||||
normalized_user_id=session_state.user_id,
|
||||
current_virtual_config=session_state.virtual_config,
|
||||
)
|
||||
chat_manager.update_session_context(
|
||||
session_id=session_id,
|
||||
user_name=next_user_name,
|
||||
virtual_config=next_virtual_config,
|
||||
)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"session": client_session_id, "user_name": next_user_name},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理聊天域调用请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
method = str(message.get("method") or "").strip()
|
||||
|
||||
if method == "session.open":
|
||||
await _open_chat_session(connection_id, message)
|
||||
return
|
||||
|
||||
if method == "session.close":
|
||||
await _close_chat_session(connection_id, message)
|
||||
return
|
||||
|
||||
if method == "message.send":
|
||||
await _handle_chat_message_send(connection_id, message)
|
||||
return
|
||||
|
||||
if method == "session.update_nickname":
|
||||
await _handle_chat_nickname_update(connection_id, message)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理统一调用请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
domain = str(message.get("domain") or "").strip()
|
||||
if domain == "chat":
|
||||
await _handle_chat_call(connection_id, message)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
|
||||
)
|
||||
|
||||
|
||||
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理统一 WebSocket 客户端消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
operation = str(message.get("op") or "").strip()
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
|
||||
if operation == "ping":
|
||||
await websocket_manager.send_pong(connection_id, time.time())
|
||||
return
|
||||
|
||||
if operation == "subscribe":
|
||||
await _handle_subscribe(connection_id, message)
|
||||
return
|
||||
|
||||
if operation == "unsubscribe":
|
||||
await _handle_unsubscribe(connection_id, message)
|
||||
return
|
||||
|
||||
if operation == "call":
|
||||
await _handle_call(connection_id, message)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
|
||||
"""统一 WebSocket 入口。
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
token: 可选的一次性握手 Token。
|
||||
"""
|
||||
if not await authenticate_websocket_connection(websocket, token):
|
||||
logger.warning("统一 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
connection_id = uuid.uuid4().hex
|
||||
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
|
||||
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="system",
|
||||
event="ready",
|
||||
data={"connection_id": connection_id, "timestamp": time.time()},
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_message = await websocket.receive_json()
|
||||
if not isinstance(raw_message, dict):
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=None,
|
||||
ok=False,
|
||||
error=_build_error("invalid_message", "消息必须是 JSON 对象"),
|
||||
)
|
||||
continue
|
||||
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
|
||||
except WebSocketDisconnect:
|
||||
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
|
||||
except Exception as exc:
|
||||
logger.error(f"统一 WebSocket 处理失败: {exc}")
|
||||
finally:
|
||||
chat_manager.disconnect_connection(connection_id)
|
||||
await websocket_manager.disconnect(connection_id)
|
||||
@@ -18,11 +18,11 @@ from src.webui.routers.expression import router as expression_router
|
||||
from src.webui.routers.jargon import router as jargon_router
|
||||
from src.webui.routers.model import router as model_router
|
||||
from src.webui.routers.person import router as person_router
|
||||
from src.webui.routers.plugin import get_progress_router
|
||||
from src.webui.routers.plugin import router as plugin_router
|
||||
from src.webui.routers.statistics import router as statistics_router
|
||||
from src.webui.routers.system import router as system_router
|
||||
from src.webui.routers.websocket.auth import router as ws_auth_router
|
||||
from src.webui.routers.websocket.unified import router as unified_ws_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
@@ -43,14 +43,14 @@ router.include_router(jargon_router)
|
||||
router.include_router(emoji_router)
|
||||
# 注册插件管理路由
|
||||
router.include_router(plugin_router)
|
||||
# 注册插件进度 WebSocket 路由
|
||||
router.include_router(get_progress_router())
|
||||
# 注册系统控制路由
|
||||
router.include_router(system_router)
|
||||
# 注册模型列表获取路由
|
||||
router.include_router(model_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
# 注册统一 WebSocket 路由
|
||||
router.include_router(unified_ws_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user