feat: add unified WebSocket connection manager and routing

- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending.
- Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management.
- Added support for logging and plugin progress subscriptions.
- Enhanced error handling and response structure for WebSocket operations.
This commit is contained in:
DrSmoothl
2026-04-02 22:08:52 +08:00
parent 7d0d429640
commit 1906890b67
28 changed files with 3845 additions and 1137 deletions

View File

@@ -1,25 +1,28 @@
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from rich.traceback import install
from sqlmodel import select
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import hashlib
import heapq
import Levenshtein
import random
import re
from src.common.logger import get_logger
from rich.traceback import install
from sqlmodel import select
import Levenshtein
from src.common.data_models.image_data_model import MaiEmoji
from src.common.database.database_model import Images, ImageType
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.utils.utils_image import ImageUtils
from src.prompt.prompt_manager import prompt_manager
from src.config.config import config_manager, global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.config.config import config_manager, global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("emoji")
@@ -33,6 +36,171 @@ EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表情包系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
emoji_schema = {
"type": "object",
"description": "当前表情包的序列化信息,主要包含 file_hash、description、emotions 等字段。",
}
string_array_schema = {
"type": "array",
"items": {"type": "string"},
}
return registry.register_hook_specs(
[
HookSpec(
name="emoji.maisaka.before_select",
description="Maisaka 表情发送工具选择表情前触发,可改写情绪、上下文和采样参数,或中止本次选择。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=["stream_id", "requested_emotion", "reasoning", "context_texts", "sample_size"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.maisaka.after_select",
description="Maisaka 已选出表情后触发,可替换选中的表情哈希、补充匹配情绪,或中止发送。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"selected_emoji": emoji_schema,
"selected_emoji_hash": {"type": "string", "description": "选中的表情哈希。"},
"matched_emotion": {"type": "string", "description": "最终命中的情绪标签。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=[
"stream_id",
"requested_emotion",
"reasoning",
"context_texts",
"sample_size",
"matched_emotion",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_description",
description="表情包描述生成并通过内容审查后触发,可改写描述文本或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前生成出的表情包描述。"},
"image_format": {"type": "string", "description": "表情图片格式。"},
},
required=["emoji", "description", "image_format"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_emotion",
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前表情包描述。"},
"emotions": {
**string_array_schema,
"description": "当前生成出的情绪标签列表。",
},
},
required=["emoji", "description", "emotions"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
"""将表情包对象序列化为 Hook 可传输载荷。
Args:
emoji: 待序列化的表情包对象。
Returns:
Optional[Dict[str, Any]]: 序列化后的字典;当表情为空时返回 ``None``。
"""
if emoji is None:
return None
return {
"file_hash": str(emoji.file_hash or "").strip(),
"file_name": emoji.file_name,
"full_path": str(emoji.full_path),
"description": emoji.description,
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
"query_count": int(emoji.query_count),
}
def _normalize_string_list(raw_values: Any) -> List[str]:
"""将任意列表值规范化为字符串列表。
Args:
raw_values: 待规范化的原始值。
Returns:
List[str]: 去空白后的字符串列表。
"""
if not isinstance(raw_values, list):
return []
return [str(item).strip() for item in raw_values if str(item).strip()]
def _ensure_directories() -> None:
"""确保表情包相关目录存在"""
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
@@ -642,6 +810,22 @@ class EmojiManager:
if "" in llm_response:
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_description",
emoji=_serialize_emoji_for_hook(target_emoji),
description=description,
image_format=image_format,
)
if hook_result.aborted:
logger.info(f"[构建描述] 表情包描述被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
normalized_description = str(hook_result.kwargs.get("description", description) or "").strip()
if not normalized_description:
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
description = normalized_description
target_emoji.description = description
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
return True, target_emoji
@@ -687,6 +871,23 @@ class EmojiManager:
elif len(emotions) > 2:
emotions = random.sample(emotions, 2)
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_emotion",
emoji=_serialize_emoji_for_hook(target_emoji),
description=target_emoji.description,
emotions=list(emotions),
)
if hook_result.aborted:
logger.info(f"[构建情感标签] 表情包情感标签被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
raw_emotions = hook_result.kwargs.get("emotions")
if raw_emotions is not None:
emotions = _normalize_string_list(raw_emotions)
if not emotions:
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}")
target_emoji.emotion = emotions
return True, target_emoji

View File

@@ -1,7 +1,7 @@
"""Maisaka 表情工具内置能力。"""
from dataclasses import dataclass, field
from typing import Sequence
from typing import Any, Optional, Sequence
import random
@@ -11,7 +11,7 @@ from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.services import send_service
from .emoji_manager import emoji_manager, emoji_manager_emotion_judge_llm
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
logger = get_logger("emoji_maisaka_tool")
@@ -29,6 +29,76 @@ class MaisakaEmojiSendResult:
matched_emotion: str = ""
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _coerce_positive_int(value: Any, default: int) -> int:
"""将任意值安全转换为正整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 规范化后的正整数。
"""
try:
normalized_value = int(value)
except (TypeError, ValueError):
return default
return normalized_value if normalized_value > 0 else default
def _normalize_context_texts(context_texts: Sequence[str] | None) -> list[str]:
"""清洗 Hook 和调用链传入的上下文文本列表。
Args:
context_texts: 原始上下文文本序列。
Returns:
list[str]: 过滤空白后的上下文文本列表。
"""
if not context_texts:
return []
return [str(item).strip() for item in context_texts if str(item).strip()]
def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
"""根据 Hook 返回值解析目标表情包对象。
Args:
raw_value: Hook 返回的 ``selected_emoji`` 或 ``selected_emoji_hash``。
Returns:
Optional[MaiEmoji]: 命中的表情包对象;未命中时返回 ``None``。
"""
raw_hash: str = ""
if isinstance(raw_value, dict):
raw_hash = str(raw_value.get("file_hash") or raw_value.get("hash") or "").strip()
elif isinstance(raw_value, str):
raw_hash = raw_value.strip()
if not raw_hash:
return None
for emoji in emoji_manager.emojis:
if emoji.file_hash == raw_hash:
return emoji
return None
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
"""提取并清洗单个表情的情绪标签。"""
@@ -129,16 +199,81 @@ async def send_emoji_for_maisaka(
) -> MaisakaEmojiSendResult:
"""为 Maisaka 选择并发送一个表情。"""
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
requested_emotion=requested_emotion,
reasoning=reasoning,
context_texts=context_texts,
normalized_requested_emotion = requested_emotion.strip()
normalized_reasoning = reasoning.strip()
normalized_context_texts = _normalize_context_texts(context_texts)
sample_size = 30
before_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.before_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
abort_message="表情选择已被 Hook 中止。",
)
if before_select_result.aborted:
abort_message = str(before_select_result.kwargs.get("abort_message") or "表情选择已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情选择已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
)
before_select_kwargs = before_select_result.kwargs
normalized_requested_emotion = str(
before_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
normalized_reasoning = str(before_select_kwargs.get("reasoning", normalized_reasoning) or "").strip()
if isinstance(before_select_kwargs.get("context_texts"), list):
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=normalized_context_texts,
sample_size=sample_size,
)
after_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.after_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
selected_emoji=_serialize_emoji_for_hook(selected_emoji),
selected_emoji_hash=str(selected_emoji.file_hash or "").strip() if selected_emoji is not None else "",
matched_emotion=matched_emotion,
abort_message="表情发送已被 Hook 中止。",
)
if after_select_result.aborted:
abort_message = str(after_select_result.kwargs.get("abort_message") or "表情发送已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情发送已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
after_select_kwargs = after_select_result.kwargs
normalized_requested_emotion = str(
after_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
matched_emotion = str(after_select_kwargs.get("matched_emotion", matched_emotion) or "").strip()
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji_hash"))
if override_emoji is None:
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji"))
if override_emoji is not None:
selected_emoji = override_emoji
if selected_emoji is None:
return MaisakaEmojiSendResult(
success=False,
message="当前表情包库中没有可用表情。",
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
try:
@@ -151,7 +286,7 @@ async def send_emoji_for_maisaka(
message=f"发送表情包失败:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
@@ -169,7 +304,7 @@ async def send_emoji_for_maisaka(
message=f"发送表情包时发生异常:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
@@ -181,7 +316,7 @@ async def send_emoji_for_maisaka(
message="发送表情包失败。",
description=description,
emotions=emotions,
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
@@ -197,6 +332,6 @@ async def send_emoji_for_maisaka(
emoji_base64=emoji_base64,
description=description,
emotions=emotions,
requested_emotion=requested_emotion.strip(),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)

View File

@@ -18,28 +18,29 @@ install(extra_lines=3)
logger = get_logger("sender")
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str], Optional[str]]] = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
"""获取 WebUI 聊天室广播器。
Returns:
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 元组;
Tuple[Any, Optional[str], Optional[str]]: ``(chat_manager, platform_name, default_group_id)`` 元组;
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
"""
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
except ImportError:
_webui_chat_broadcaster = (None, None)
_webui_chat_broadcaster = (None, None, None)
return _webui_chat_broadcaster
@@ -76,7 +77,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
chat_manager, webui_platform, default_group_id = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
@@ -97,8 +98,9 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
message_type = "rich"
segments = message_segments
await chat_manager.broadcast(
{
await chat_manager.broadcast_to_group(
group_id=group_id or default_group_id or "",
message={
"type": "bot_message",
"content": message.processed_plain_text,
"message_type": message_type,
@@ -110,7 +112,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
"avatar": None,
"is_bot": True,
},
}
},
)
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库

View File

@@ -1,22 +1,25 @@
from datetime import datetime
from sqlmodel import select
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import asyncio
import difflib
import json
import re
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
from src.common.data_models.expression_data_model import MaiExpression
from sqlmodel import select
from src.chat.utils.utils import is_bot_self
from src.common.data_models.expression_data_model import MaiExpression
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import check_expression_suitability, parse_expression_response
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表达方式系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="expression.select.before_select",
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
},
required=["chat_id", "chat_info", "max_num", "think_level"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.select.after_selection",
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
"selected_expressions": {
"type": "array",
"items": {"type": "object"},
"description": "当前已选中的表达方式列表。",
},
"selected_expression_ids": {
"type": "array",
"items": {"type": "integer"},
"description": "当前已选中的表达方式 ID 列表。",
},
},
required=[
"chat_id",
"chat_info",
"max_num",
"think_level",
"selected_expressions",
"selected_expression_ids",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.after_extract",
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
"expressions": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的表达方式候选列表。",
},
"jargon_entries": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的黑话候选列表。",
},
},
required=["session_id", "message_count", "expressions", "jargon_entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.before_upsert",
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"situation": {"type": "string", "description": "即将写入的情景文本。"},
"style": {"type": "string", "description": "即将写入的风格文本。"},
},
required=["session_id", "situation", "style"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class ExpressionLearner:
def __init__(self, session_id: str) -> None:
"""初始化表达方式学习器。
Args:
session_id: 当前会话 ID。
"""
self.session_id = session_id
# 学习锁,防止并发执行学习任务
@@ -44,6 +161,110 @@ class ExpressionLearner:
# 消息缓存
self._messages_cache: List["SessionMessage"] = []
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
"""将表达方式候选序列化为 Hook 载荷。
Args:
expressions: 原始表达方式候选列表。
Returns:
List[dict[str, str]]: 序列化后的表达方式候选。
"""
return [
{
"situation": str(situation).strip(),
"style": str(style).strip(),
"source_id": str(source_id).strip(),
}
for situation, style, source_id in expressions
if str(situation).strip() and str(style).strip()
]
@staticmethod
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
"""从 Hook 载荷恢复表达方式候选列表。
Args:
raw_expressions: Hook 返回的表达方式候选。
Returns:
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Tuple[str, str, str]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not situation or not style:
continue
normalized_expressions.append((situation, style, source_id))
return normalized_expressions
@staticmethod
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
"""将黑话候选序列化为 Hook 载荷。
Args:
jargon_entries: 原始黑话候选列表。
Returns:
List[dict[str, str]]: 序列化后的黑话候选列表。
"""
return [
{
"content": str(content).strip(),
"source_id": str(source_id).strip(),
}
for content, source_id in jargon_entries
if str(content).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
"""从 Hook 载荷恢复黑话候选列表。
Args:
raw_jargon_entries: Hook 返回的黑话候选列表。
Returns:
List[Tuple[str, str]]: 恢复后的黑话候选列表。
"""
if not isinstance(raw_jargon_entries, list):
return []
normalized_entries: List[Tuple[str, str]] = []
for raw_entry in raw_jargon_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
source_id = str(raw_entry.get("source_id") or "").strip()
if not content:
continue
normalized_entries.append((content, source_id))
return normalized_entries
def add_messages(self, messages: List["SessionMessage"]) -> None:
"""添加消息到缓存"""
self._messages_cache.extend(messages)
@@ -52,8 +273,12 @@ class ExpressionLearner:
"""获取当前消息缓存的大小"""
return len(self._messages_cache)
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
"""学习主流程"""
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
"""执行表达方式学习主流程
Args:
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
"""
if not self._messages_cache:
logger.debug("没有消息可供学习,跳过学习过程")
return
@@ -109,6 +334,25 @@ class ExpressionLearner:
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
after_extract_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.after_extract",
session_id=self.session_id,
message_count=len(self._messages_cache),
expressions=self._serialize_expressions(expressions),
jargon_entries=self._serialize_jargon_entries(jargon_entries),
)
if after_extract_result.aborted:
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
return
after_extract_kwargs = after_extract_result.kwargs
raw_expressions = after_extract_kwargs.get("expressions")
if raw_expressions is not None:
expressions = self._deserialize_expressions(raw_expressions)
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
if raw_jargon_entries is not None:
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
# TODO: 检测是否开启了
if jargon_entries:
@@ -135,6 +379,22 @@ class ExpressionLearner:
# 存储到数据库 Expression 表
for situation, style in learnt_expressions:
before_upsert_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.before_upsert",
session_id=self.session_id,
situation=situation,
style=style,
)
if before_upsert_result.aborted:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
continue
upsert_kwargs = before_upsert_result.kwargs
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
style = str(upsert_kwargs.get("style", style) or "").strip()
if not situation or not style:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
continue
await self._upsert_expression_to_db(situation, style)
# ====== 黑话相关 ======

View File

@@ -1,27 +1,109 @@
from typing import Any, Dict, List, Optional, Tuple
import json
import time
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.utils.utils_session import SessionUtils
from src.prompt.prompt_manager import prompt_manager
from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.learners.learner_utils_old import weighted_sample
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("expression_selector")
class ExpressionSelector:
def __init__(self):
def __init__(self) -> None:
"""初始化表达方式选择器。"""
self.llm_model = LLMServiceClient(
task_name="utils", request_type="expression.selector"
)
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _normalize_selected_expressions(raw_expressions: Any) -> List[Dict[str, Any]]:
"""从 Hook 载荷恢复表达方式选择结果。
Args:
raw_expressions: Hook 返回的表达方式列表。
Returns:
List[Dict[str, Any]]: 恢复后的表达方式列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Dict[str, Any]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
expression_id = raw_expression.get("id")
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not isinstance(expression_id, int) or not situation or not style or not source_id:
continue
normalized_expression = dict(raw_expression)
normalized_expression["id"] = expression_id
normalized_expression["situation"] = situation
normalized_expression["style"] = style
normalized_expression["source_id"] = source_id
normalized_expressions.append(normalized_expression)
return normalized_expressions
@staticmethod
def _normalize_selected_expression_ids(raw_ids: Any, expressions: List[Dict[str, Any]]) -> List[int]:
"""规范化最终选中的表达方式 ID 列表。
Args:
raw_ids: Hook 返回的 ID 列表。
expressions: 当前最终表达方式列表。
Returns:
List[int]: 规范化后的 ID 列表。
"""
if isinstance(raw_ids, list):
normalized_ids = [item for item in raw_ids if isinstance(item, int)]
if normalized_ids:
return normalized_ids
return [expression["id"] for expression in expressions if isinstance(expression.get("id"), int)]
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
@@ -214,8 +296,7 @@ class ExpressionSelector:
reply_reason: Optional[str] = None,
think_level: int = 1,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
选择适合的表达方式使用classic模式随机选择+LLM选择
"""选择适合的表达方式。
Args:
chat_id: 聊天流ID
@@ -233,11 +314,60 @@ class ExpressionSelector:
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
before_select_result = await self._get_runtime_manager().invoke_hook(
"expression.select.before_select",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
)
if before_select_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择被 Hook 中止")
return [], []
before_select_kwargs = before_select_result.kwargs
chat_id = str(before_select_kwargs.get("chat_id", chat_id) or "").strip() or chat_id
chat_info = str(before_select_kwargs.get("chat_info", chat_info) or "")
max_num = max(self._coerce_int(before_select_kwargs.get("max_num"), max_num), 1)
raw_target_message = before_select_kwargs.get("target_message", target_message or "")
target_message = str(raw_target_message or "").strip() or None
raw_reply_reason = before_select_kwargs.get("reply_reason", reply_reason or "")
reply_reason = str(raw_reply_reason or "").strip() or None
think_level = self._coerce_int(before_select_kwargs.get("think_level"), think_level)
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(
selected_expressions, selected_ids = await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
after_selection_result = await self._get_runtime_manager().invoke_hook(
"expression.select.after_selection",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
selected_expressions=[dict(item) for item in selected_expressions],
selected_expression_ids=list(selected_ids),
)
if after_selection_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择结果被 Hook 中止")
return [], []
after_selection_kwargs = after_selection_result.kwargs
raw_selected_expressions = after_selection_kwargs.get("selected_expressions")
if raw_selected_expressions is not None:
selected_expressions = self._normalize_selected_expressions(raw_selected_expressions)
selected_ids = self._normalize_selected_expression_ids(
after_selection_kwargs.get("selected_expression_ids"),
selected_expressions,
)
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
return selected_expressions, selected_ids
async def _select_expressions_classic(
self,

View File

@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Set, TypedDict
from typing import Any, Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
@@ -9,13 +9,15 @@ from json_repair import repair_json
from sqlmodel import select
from src.common.data_models.jargon_data_model import MaiJargon
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import is_single_char_jargon
@@ -35,8 +37,140 @@ class JargonMeaningEntry(TypedDict):
meaning: str
def register_jargon_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册 jargon 系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="jargon.query.before_search",
description="Maisaka 黑话查询工具执行检索前触发,可改写词条列表、检索参数或直接中止。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "准备查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否允许精确命中失败后回退模糊检索。"},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.query.after_search",
description="Maisaka 黑话查询工具完成检索后触发,可改写结果列表或中止返回。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "实际查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否启用了模糊检索回退。"},
"results": {
"type": "array",
"items": {"type": "object"},
"description": "查询结果列表。",
},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback", "results"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.extract.before_persist",
description="黑话条目准备写入数据库前触发,可改写去重后的条目列表或跳过本次持久化。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"entries": {
"type": "array",
"items": {"type": "object"},
"description": "即将持久化的黑话条目列表。",
},
},
required=["session_id", "session_name", "entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.inference.before_finalize",
description="黑话含义推断完成、写回数据库前触发,可改写最终判定与含义结果。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"content": {"type": "string", "description": "当前黑话词条。"},
"count": {"type": "integer", "description": "当前词条累计命中次数。"},
"raw_content_list": {
"type": "array",
"items": {"type": "string"},
"description": "用于推断的原始上下文片段列表。",
},
"inference_with_context": {"type": "object", "description": "基于上下文的推断结果。"},
"inference_with_content_only": {"type": "object", "description": "仅基于词条内容的推断结果。"},
"comparison_result": {"type": "object", "description": "比较阶段输出结果。"},
"is_jargon": {"type": "boolean", "description": "当前推断是否判定为黑话。"},
"meaning": {"type": "string", "description": "当前推断出的黑话含义。"},
"is_complete": {"type": "boolean", "description": "当前是否已完成全部推断流程。"},
"last_inference_count": {"type": "integer", "description": "本次推断完成后应写回的 last_inference_count。"},
},
required=[
"session_id",
"session_name",
"content",
"count",
"raw_content_list",
"inference_with_context",
"inference_with_content_only",
"comparison_result",
"is_jargon",
"meaning",
"is_complete",
"last_inference_count",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class JargonMiner:
def __init__(self, session_id: str, session_name: str) -> None:
"""初始化黑话学习器。
Args:
session_id: 当前会话 ID。
session_name: 当前会话展示名称。
"""
self.session_id = session_id
self.session_name = session_name
@@ -46,13 +180,92 @@ class JargonMiner:
# 黑话提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _serialize_jargon_entries(entries: List[JargonEntry]) -> List[Dict[str, object]]:
"""将黑话条目列表序列化为 Hook 可传输结构。
Args:
entries: 原始黑话条目列表。
Returns:
List[Dict[str, object]]: 序列化后的条目列表。
"""
return [
{
"content": str(entry["content"]).strip(),
"raw_content": sorted(str(item).strip() for item in entry["raw_content"] if str(item).strip()),
}
for entry in entries
if str(entry["content"]).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_entries: Any) -> List[JargonEntry]:
"""从 Hook 载荷恢复黑话条目列表。
Args:
raw_entries: Hook 返回的条目数据。
Returns:
List[JargonEntry]: 恢复后的黑话条目列表。
"""
if not isinstance(raw_entries, list):
return []
normalized_entries: List[JargonEntry] = []
for raw_entry in raw_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
if not content:
continue
raw_content_values = raw_entry.get("raw_content")
raw_content: Set[str] = set()
if isinstance(raw_content_values, list):
raw_content = {str(item).strip() for item in raw_content_values if str(item).strip()}
normalized_entries.append({"content": content, "raw_content": raw_content})
return normalized_entries
def get_cached_jargons(self) -> List[str]:
"""获取缓存中的所有黑话列表"""
return list(self.cache.keys())
async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
"""
对jargon进行含义推断
"""对黑话条目执行含义推断。
Args:
jargon_obj: 待推断的黑话数据对象。
"""
content = jargon_obj.content
# 解析raw_content列表
@@ -175,15 +388,45 @@ class JargonMiner:
is_similar = comparison_result.get("is_similar", False)
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
finalized_meaning = inference1.get("meaning", "") if is_jargon else ""
is_complete = (jargon_obj.count or 0) >= 100
last_inference_count = jargon_obj.count or 0
finalize_result = await self._get_runtime_manager().invoke_hook(
"jargon.inference.before_finalize",
session_id=self.session_id,
session_name=self.session_name,
content=content,
count=current_count,
raw_content_list=list(raw_content_list),
inference_with_context=dict(inference1),
inference_with_content_only=dict(inference2),
comparison_result=dict(comparison_result),
is_jargon=is_jargon,
meaning=finalized_meaning,
is_complete=is_complete,
last_inference_count=last_inference_count,
)
if finalize_result.aborted:
logger.info(f"jargon {content} 的推断结果被 Hook 中止写回")
return
finalize_kwargs = finalize_result.kwargs
is_jargon = bool(finalize_kwargs.get("is_jargon", is_jargon))
finalized_meaning = str(finalize_kwargs.get("meaning", finalized_meaning) or "").strip() if is_jargon else ""
is_complete = bool(finalize_kwargs.get("is_complete", is_complete))
last_inference_count = self._coerce_int(
finalize_kwargs.get("last_inference_count"),
last_inference_count,
)
# 更新数据库记录
jargon_obj.is_jargon = is_jargon
jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
jargon_obj.meaning = finalized_meaning
# 更新最后一次判定的count值避免重启后重复判定
jargon_obj.last_inference_count = jargon_obj.count or 0
jargon_obj.last_inference_count = last_inference_count
# 如果count>=100标记为完成不再进行推断
if (jargon_obj.count or 0) >= 100:
jargon_obj.is_complete = True
jargon_obj.is_complete = is_complete
try:
self._modify_jargon_entry(jargon_obj)
@@ -232,6 +475,22 @@ class JargonMiner:
merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
uniq_entries: List[JargonEntry] = list(merged_entries.values())
before_persist_result = await self._get_runtime_manager().invoke_hook(
"jargon.extract.before_persist",
session_id=self.session_id,
session_name=self.session_name,
entries=self._serialize_jargon_entries(uniq_entries),
)
if before_persist_result.aborted:
logger.info(f"[{self.session_name}] 黑话提取结果被 Hook 中止,不写入数据库")
return
raw_hook_entries = before_persist_result.kwargs.get("entries")
if raw_hook_entries is not None:
uniq_entries = self._deserialize_jargon_entries(raw_hook_entries)
if not uniq_entries:
logger.info(f"[{self.session_name}] Hook 过滤后没有可写入的黑话条目")
return
saved = 0
updated = 0

View File

@@ -54,6 +54,84 @@ class MaisakaReasoningEngine:
self._runtime = runtime
self._last_reasoning_content: str = ""
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _normalize_words(raw_words: Any) -> list[str]:
"""清洗黑话查询词条列表。
Args:
raw_words: 原始词条列表。
Returns:
list[str]: 去重去空白后的词条列表。
"""
if not isinstance(raw_words, list):
return []
normalized_words: list[str] = []
seen_words: set[str] = set()
for item in raw_words:
if not isinstance(item, str):
continue
word = item.strip()
if not word or word in seen_words:
continue
seen_words.add(word)
normalized_words.append(word)
return normalized_words
@staticmethod
def _normalize_jargon_query_results(raw_results: Any) -> list[dict[str, object]]:
"""规范化黑话查询结果列表。
Args:
raw_results: Hook 返回的结果列表。
Returns:
list[dict[str, object]]: 清洗后的结果列表。
"""
if not isinstance(raw_results, list):
return []
normalized_results: list[dict[str, object]] = []
for raw_item in raw_results:
if not isinstance(raw_item, dict):
continue
word = str(raw_item.get("word") or "").strip()
matches = raw_item.get("matches")
normalized_matches: list[dict[str, str]] = []
if isinstance(matches, list):
for match in matches:
if not isinstance(match, dict):
continue
content = str(match.get("content") or "").strip()
meaning = str(match.get("meaning") or "").strip()
if not content or not meaning:
continue
normalized_matches.append({"content": content, "meaning": meaning})
normalized_results.append(
{
"word": word,
"found": bool(raw_item.get("found", bool(normalized_matches))),
"matches": normalized_matches,
}
)
return normalized_results
def build_builtin_tool_handlers(self) -> dict[str, "BuiltinToolHandler"]:
"""构造 Maisaka 内置工具处理器映射。
@@ -1012,6 +1090,35 @@ class MaisakaReasoningEngine:
"查询黑话工具至少需要一个非空词条。",
)
limit = 5
case_sensitive = False
enable_fuzzy_fallback = True
before_search_result = await self._get_runtime_manager().invoke_hook(
"jargon.query.before_search",
words=list(words),
session_id=self._runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
enable_fuzzy_fallback=enable_fuzzy_fallback,
abort_message="黑话查询已被 Hook 中止。",
)
if before_search_result.aborted:
abort_message = str(before_search_result.kwargs.get("abort_message") or "黑话查询已被 Hook 中止。").strip()
return self._build_tool_failure_result(tool_call.func_name, abort_message or "黑话查询已被 Hook 中止。")
before_search_kwargs = before_search_result.kwargs
if before_search_kwargs.get("words") is not None:
words = self._normalize_words(before_search_kwargs.get("words"))
if not words:
return self._build_tool_failure_result(tool_call.func_name, "Hook 过滤后没有可查询的黑话词条。")
try:
limit = int(before_search_kwargs.get("limit", limit))
except (TypeError, ValueError):
limit = 5
limit = max(limit, 1)
case_sensitive = bool(before_search_kwargs.get("case_sensitive", case_sensitive))
enable_fuzzy_fallback = bool(before_search_kwargs.get("enable_fuzzy_fallback", enable_fuzzy_fallback))
logger.info(f"{self._runtime.log_prefix} 已触发黑话查询: 词条={words!r}")
results: list[dict[str, object]] = []
@@ -1019,17 +1126,19 @@ class MaisakaReasoningEngine:
exact_matches = search_jargon(
keyword=word,
chat_id=self._runtime.session_id,
limit=5,
case_sensitive=False,
limit=limit,
case_sensitive=case_sensitive,
fuzzy=False,
)
matched_entries = exact_matches or search_jargon(
keyword=word,
chat_id=self._runtime.session_id,
limit=5,
case_sensitive=False,
fuzzy=True,
)
matched_entries = exact_matches
if not matched_entries and enable_fuzzy_fallback:
matched_entries = search_jargon(
keyword=word,
chat_id=self._runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
fuzzy=True,
)
results.append(
{
@@ -1039,6 +1148,27 @@ class MaisakaReasoningEngine:
}
)
after_search_result = await self._get_runtime_manager().invoke_hook(
"jargon.query.after_search",
words=list(words),
session_id=self._runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
enable_fuzzy_fallback=enable_fuzzy_fallback,
results=list(results),
abort_message="黑话查询结果已被 Hook 中止。",
)
if after_search_result.aborted:
abort_message = str(after_search_result.kwargs.get("abort_message") or "黑话查询结果已被 Hook 中止。").strip()
return self._build_tool_failure_result(
tool_call.func_name,
abort_message or "黑话查询结果已被 Hook 中止。",
)
raw_results = after_search_result.kwargs.get("results")
if raw_results is not None:
results = self._normalize_jargon_query_results(raw_results)
logger.info(f"{self._runtime.log_prefix} 黑话查询完成: 结果={results!r}")
return self._build_tool_success_result(
tool_call.func_name,

View File

@@ -20,11 +20,17 @@ def _get_builtin_hook_spec_registrars() -> List[HookSpecRegistrar]:
"""
from src.chat.message_receive.bot import register_chat_hook_specs
from src.chat.emoji_system.emoji_manager import register_emoji_hook_specs
from src.learners.expression_learner import register_expression_hook_specs
from src.learners.jargon_miner import register_jargon_hook_specs
from src.maisaka.chat_loop_service import register_maisaka_hook_specs
from src.services.send_service import register_send_service_hook_specs
return [
register_chat_hook_specs,
register_emoji_hook_specs,
register_jargon_hook_specs,
register_expression_hook_specs,
register_send_service_hook_specs,
register_maisaka_hook_specs,
]

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.logs_ws")
router = APIRouter()
@@ -148,24 +149,9 @@ async def broadcast_log(log_data: Dict):
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
await websocket_manager.broadcast_to_topic(
domain="logs",
topic="main",
event="entry",
data={"entry": log_data},
)

View File

@@ -18,12 +18,10 @@ def get_all_routers() -> List[APIRouter]:
from src.webui.api.replier import router as replier_router
from src.webui.routers.chat import router as chat_router
from src.webui.routers.knowledge import router as knowledge_router
from src.webui.routers.websocket.logs import router as logs_router
from src.webui.routes import router as main_router
return [
main_router,
logs_router,
knowledge_router,
chat_router,
planner_router,

View File

@@ -1,7 +1,7 @@
from typing import Tuple
from .routes import router
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
from .service import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:

View File

@@ -1,9 +1,8 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
import uuid
from typing import Dict, Optional
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Query
from sqlalchemy import case, func
from sqlmodel import col, select
@@ -13,16 +12,11 @@ from src.common.logger import get_logger
from src.config.config import global_config
from src.webui.dependencies import require_auth
from .support import (
from .service import (
WEBUI_CHAT_GROUP_ID,
WEBUI_CHAT_PLATFORM,
authenticate_chat_websocket,
chat_history,
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
logger = get_logger("webui.chat")
@@ -113,55 +107,6 @@ async def clear_chat_history(
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
@router.websocket("/ws")
async def websocket_chat(
websocket: WebSocket,
user_id: Optional[str] = Query(default=None),
user_name: Optional[str] = Query(default="WebUI用户"),
platform: Optional[str] = Query(default=None),
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
token: Optional[str] = Query(default=None),
) -> None:
"""WebSocket 聊天端点。"""
if not await authenticate_chat_websocket(websocket, token):
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
session_id = str(uuid.uuid4())
normalized_user_id = normalize_webui_user_id(user_id)
current_user_name = user_name or "WebUI用户"
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
await chat_manager.connect(websocket, session_id, normalized_user_id)
try:
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
)
while True:
data = await websocket.receive_json()
current_user_name, current_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=current_user_name,
normalized_user_id=normalized_user_id,
current_virtual_config=current_virtual_config,
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
except Exception as e:
logger.error(f"WebSocket 错误: {e}")
finally:
chat_manager.disconnect(session_id, normalized_user_id)
@router.get("/info")
async def get_chat_info() -> Dict[str, object]:
"""获取聊天室信息。"""

View File

@@ -1,10 +1,10 @@
"""WebUI 聊天路由支持逻辑"""
"""WebUI 聊天运行时服务"""
from dataclasses import dataclass
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
from fastapi import WebSocket
from pydantic import BaseModel
from sqlmodel import col, delete, select
@@ -17,8 +17,6 @@ from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
@@ -27,6 +25,8 @@ WEBUI_CHAT_PLATFORM = "webui"
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
WEBUI_USER_ID_PREFIX = "webui_user_"
AsyncMessageSender = Callable[[Dict[str, Any]], Awaitable[None]]
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置。"""
@@ -52,13 +52,42 @@ class ChatHistoryMessage(BaseModel):
is_bot: bool = False
@dataclass
class ChatSessionConnection:
"""逻辑聊天会话连接信息。"""
session_id: str
connection_id: str
client_session_id: str
user_id: str
user_name: str
active_group_id: str
virtual_config: Optional[VirtualIdentityConfig]
sender: AsyncMessageSender
class ChatHistoryManager:
"""聊天历史管理器。"""
def __init__(self, max_messages: int = 200) -> None:
"""初始化聊天历史管理器。
Args:
max_messages: 内存中允许处理的最大消息数
"""
self.max_messages = max_messages
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将内部消息对象转换为前端可消费的字典。
Args:
msg: 内部统一消息对象
group_id: 当前会话所属的群组标识
Returns:
Dict[str, Any]: 面向 WebUI 的消息字典
"""
del group_id
user_info = msg.message_info.user_info
user_id = user_info.user_id or ""
is_bot = is_bot_self(msg.platform, user_id)
@@ -74,10 +103,27 @@ class ChatHistoryManager:
}
def _resolve_session_id(self, group_id: Optional[str]) -> str:
"""根据群组标识解析聊天会话 ID。
Args:
group_id: 群组标识
Returns:
str: 内部聊天会话 ID
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取指定会话的历史消息。
Args:
limit: 最大返回条数
group_id: 群组标识
Returns:
List[Dict[str, Any]]: 历史消息列表
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -90,11 +136,19 @@ class ChatHistoryManager:
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
except Exception as exc:
logger.error(f"从数据库加载聊天记录失败: {exc}")
return []
def clear_history(self, group_id: Optional[str] = None) -> int:
"""清空指定会话的历史消息。
Args:
group_id: 群组标识
Returns:
int: 被删除的消息数量
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -104,66 +158,245 @@ class ChatHistoryManager:
deleted = result.rowcount or 0
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
logger.error(f"清空聊天记录失败: {e}")
except Exception as exc:
logger.error(f"清空聊天记录失败: {exc}")
return 0
class ChatConnectionManager:
"""聊天连接管理器。"""
"""统一聊天逻辑会话管理器。"""
def __init__(self) -> None:
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {}
"""初始化聊天逻辑会话管理器。"""
self.active_connections: Dict[str, ChatSessionConnection] = {}
self.client_sessions: Dict[Tuple[str, str], str] = {}
self.connection_sessions: Dict[str, Set[str]] = {}
self.group_sessions: Dict[str, Set[str]] = {}
self.user_sessions: Dict[str, Set[str]] = {}
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def _bind_group(self, session_id: str, group_id: str) -> None:
"""为会话绑定群组索引。
def disconnect(self, session_id: str, user_id: str) -> None:
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.setdefault(group_id, set())
group_session_ids.add(session_id)
def _unbind_group(self, session_id: str, group_id: str) -> None:
"""移除会话与群组的索引关系。
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.get(group_id)
if group_session_ids is None:
return
group_session_ids.discard(session_id)
if not group_session_ids:
del self.group_sessions[group_id]
async def connect(
self,
session_id: str,
connection_id: str,
client_session_id: str,
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
sender: AsyncMessageSender,
) -> None:
"""注册一个新的逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
user_id: 规范化后的用户 ID
user_name: 当前展示昵称
virtual_config: 当前虚拟身份配置
sender: 发送消息到前端的异步回调
"""
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
if existing_session_id is not None:
self.disconnect(existing_session_id)
active_group_id = get_current_group_id(virtual_config)
session_connection = ChatSessionConnection(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=user_id,
user_name=user_name,
active_group_id=active_group_id,
virtual_config=virtual_config,
sender=sender,
)
self.active_connections[session_id] = session_connection
self.client_sessions[(connection_id, client_session_id)] = session_id
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
self.user_sessions.setdefault(user_id, set()).add(session_id)
self._bind_group(session_id, active_group_id)
logger.info(
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
session_id,
connection_id,
client_session_id,
user_id,
active_group_id,
)
def disconnect(self, session_id: str) -> None:
"""断开一个逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
"""
session_connection = self.active_connections.pop(session_id, None)
if session_connection is None:
return
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
self._unbind_group(session_id, session_connection.active_group_id)
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
if connection_session_ids is not None:
connection_session_ids.discard(session_id)
if not connection_session_ids:
del self.connection_sessions[session_connection.connection_id]
user_session_ids = self.user_sessions.get(session_connection.user_id)
if user_session_ids is not None:
user_session_ids.discard(session_id)
if not user_session_ids:
del self.user_sessions[session_connection.user_id]
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
def disconnect_connection(self, connection_id: str) -> None:
"""断开物理连接下的全部逻辑聊天会话。
Args:
connection_id: 物理 WebSocket 连接 ID
"""
session_ids = list(self.connection_sessions.get(connection_id, set()))
for session_id in session_ids:
self.disconnect(session_id)
def get_session(self, session_id: str) -> Optional[ChatSessionConnection]:
"""获取逻辑聊天会话信息。
Args:
session_id: 内部逻辑会话 ID
Returns:
Optional[ChatSessionConnection]: 会话存在时返回对应信息
"""
return self.active_connections.get(session_id)
def get_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""根据连接 ID 和前端会话 ID 查询内部会话 ID。
Args:
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
Returns:
Optional[str]: 找到时返回内部会话 ID
"""
return self.client_sessions.get((connection_id, client_session_id))
def update_session_context(
self,
session_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> None:
"""更新会话上下文信息。
Args:
session_id: 内部逻辑会话 ID
user_name: 最新昵称
virtual_config: 最新虚拟身份配置
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
next_group_id = get_current_group_id(virtual_config)
if next_group_id != session_connection.active_group_id:
self._unbind_group(session_id, session_connection.active_group_id)
self._bind_group(session_id, next_group_id)
session_connection.active_group_id = next_group_id
session_connection.user_name = user_name
session_connection.virtual_config = virtual_config
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
"""向指定逻辑会话发送消息。
Args:
session_id: 内部逻辑会话 ID
message: 发送消息内容
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
try:
await session_connection.sender(message)
except Exception as exc:
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
async def broadcast(self, message: Dict[str, Any]) -> None:
"""向全部逻辑聊天会话广播消息。
Args:
message: 待广播的消息内容
"""
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
"""向指定群组下的全部逻辑会话广播消息。
Args:
group_id: 群组标识
message: 待广播的消息内容
"""
for session_id in list(self.group_sessions.get(group_id, set())):
await self.send_message(session_id, message)
chat_history = ChatHistoryManager()
chat_manager = ChatConnectionManager()
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
"""判断当前是否启用了虚拟身份模式。
Args:
virtual_config: 虚拟身份配置
Returns:
bool: 已启用时返回 ``True``
"""
return bool(virtual_config and virtual_config.enabled)
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
if token and verify_ws_token(token):
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
return True
if cookie_token := websocket.cookies.get("maibot_session"):
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
return True
return False
def normalize_webui_user_id(user_id: Optional[str]) -> str:
"""标准化 WebUI 用户 ID。
Args:
user_id: 原始用户 ID
Returns:
str: 带统一前缀的用户 ID
"""
if not user_id:
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
if user_id.startswith(WEBUI_USER_ID_PREFIX):
@@ -172,12 +405,30 @@ def normalize_webui_user_id(user_id: Optional[str]) -> str:
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
"""根据人物 ID 查询人物信息。
Args:
person_id: 人物 ID
Returns:
Optional[PersonInfo]: 查到时返回人物信息
"""
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
return session.exec(statement).first()
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
"""根据人物信息构建虚拟身份配置。
Args:
person: 人物信息对象
group_id: 逻辑群组 ID
group_name: 逻辑群组名称
Returns:
VirtualIdentityConfig: 虚拟身份配置对象
"""
return VirtualIdentityConfig(
enabled=True,
platform=person.platform,
@@ -195,6 +446,17 @@ def resolve_initial_virtual_identity(
group_name: Optional[str],
group_id: Optional[str],
) -> Optional[VirtualIdentityConfig]:
"""根据初始参数解析虚拟身份配置。
Args:
platform: 平台名称
person_id: 人物 ID
group_name: 群组名称
group_id: 群组 ID
Returns:
Optional[VirtualIdentityConfig]: 解析成功时返回虚拟身份配置
"""
if not (platform and person_id):
return None
@@ -210,11 +472,14 @@ def resolve_initial_virtual_identity(
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
virtual_config.user_nickname,
virtual_config.platform,
virtual_group_id,
)
return virtual_config
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
except Exception as exc:
logger.warning(f"通过参数配置虚拟身份失败: {exc}")
return None
@@ -224,6 +489,17 @@ def build_session_info_message(
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Dict[str, Any]:
"""构建会话信息消息。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 会话信息消息
"""
session_info_data: Dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
@@ -247,13 +523,41 @@ def build_session_info_message(
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
"""获取当前虚拟身份对应的历史群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
Optional[str]: 虚拟身份启用时返回对应群组 ID
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.group_id
return None
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""获取当前会话的有效群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 当前会话应使用的群组 ID
"""
return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""构建欢迎消息。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 欢迎消息文本
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return (
@@ -264,6 +568,12 @@ def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> st
async def send_chat_error(session_id: str, content: str) -> None:
"""向指定会话发送错误消息。
Args:
session_id: 内部逻辑会话 ID
content: 错误消息内容
"""
await chat_manager.send_message(
session_id,
{
@@ -279,7 +589,17 @@ async def send_initial_chat_state(
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
include_welcome: bool = True,
) -> None:
"""向新会话发送初始化状态。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
include_welcome: 是否发送欢迎消息
"""
await chat_manager.send_message(
session_id,
build_session_info_message(
@@ -290,30 +610,43 @@ async def send_initial_chat_state(
),
)
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
history_group_id = get_active_history_group_id(virtual_config)
history = chat_history.get_history(50, history_group_id)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
"type": "history",
"messages": history,
"group_id": get_current_group_id(virtual_config),
},
)
if include_welcome:
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
},
)
def resolve_sender_identity(
current_user_name: str,
normalized_user_id: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, str]:
"""解析当前发送者身份。
Args:
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
virtual_config: 虚拟身份配置
Returns:
Tuple[str, str]: ``(发送者昵称, 发送者用户 ID)``
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
@@ -328,6 +661,19 @@ def create_message_data(
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> Dict[str, Any]:
"""构建发送给聊天核心的消息数据。
Args:
content: 文本内容
user_id: 用户 ID
user_name: 用户昵称
message_id: 消息 ID
is_at_bot: 是否默认艾特机器人
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 聊天核心可处理的消息数据
"""
if message_id is None:
message_id = str(uuid.uuid4())
@@ -389,6 +735,18 @@ async def handle_chat_message(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> str:
"""处理用户发送的聊天消息。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的消息数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
str: 处理后的最新昵称
"""
content = str(data.get("content", "")).strip()
if not content:
return current_user_name
@@ -401,11 +759,14 @@ async def handle_chat_message(
normalized_user_id=normalized_user_id,
virtual_config=current_virtual_config,
)
target_group_id = get_current_group_id(current_virtual_config)
await chat_manager.broadcast(
await chat_manager.broadcast_to_group(
target_group_id,
{
"type": "user_message",
"content": content,
"group_id": target_group_id,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
@@ -414,7 +775,7 @@ async def handle_chat_message(
"is_bot": False,
},
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
}
},
)
message_data = create_message_data(
@@ -427,22 +788,37 @@ async def handle_chat_message(
)
try:
await chat_manager.broadcast({"type": "typing", "is_typing": True})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
except Exception as exc:
logger.error(f"处理消息时出错: {exc}")
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
finally:
await chat_manager.broadcast({"type": "typing", "is_typing": False})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
return next_user_name
async def handle_chat_ping(session_id: str) -> None:
"""处理聊天心跳。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
"""处理昵称更新请求。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的数据
current_user_name: 当前昵称
Returns:
str: 更新后的昵称
"""
new_name = str(data.get("user_name", "")).strip()
if not new_name:
return current_user_name
@@ -463,6 +839,16 @@ async def enable_virtual_identity(
session_prefix: str,
virtual_data: Dict[str, Any],
) -> Optional[VirtualIdentityConfig]:
"""启用虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
session_prefix: 会话前缀用于生成默认群组 ID
virtual_data: 前端提交的虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 启用成功时返回新的虚拟身份配置
"""
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
return None
@@ -470,16 +856,18 @@ async def enable_virtual_identity(
person_id_value = str(virtual_data.get("person_id"))
try:
person = get_person_by_person_id(person_id_value)
if not person:
if person is None:
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
return None
custom_group_id = virtual_data.get("group_id")
current_group_id = (
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
if custom_group_id
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
)
custom_group_id = str(virtual_data.get("group_id") or "").strip()
if custom_group_id:
current_group_id = custom_group_id
if not current_group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{current_group_id}"
else:
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
current_virtual_config = build_virtual_identity_config(
person=person,
group_id=current_group_id,
@@ -521,13 +909,18 @@ async def enable_virtual_identity(
},
)
return current_virtual_config
except Exception as e:
logger.error(f"设置虚拟身份失败: {e}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
except Exception as exc:
logger.error(f"设置虚拟身份失败: {exc}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(exc)}")
return None
async def disable_virtual_identity(session_id: str) -> None:
"""关闭虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(
session_id,
{
@@ -560,7 +953,18 @@ async def handle_virtual_identity_update(
data: Dict[str, Any],
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Optional[VirtualIdentityConfig]:
virtual_data = cast(dict[str, Any], data.get("config", {}))
"""处理虚拟身份切换请求。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_virtual_config: 当前虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置
"""
virtual_data = cast(Dict[str, Any], data.get("config", {}))
if virtual_data.get("enabled"):
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
return next_config if next_config is not None else current_virtual_config
@@ -577,6 +981,19 @@ async def dispatch_chat_event(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
"""分发聊天事件到对应的处理器。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
Tuple[str, Optional[VirtualIdentityConfig]]: ``(最新昵称, 最新虚拟身份配置)``
"""
event_type = data.get("type")
if event_type == "message":
next_user_name = await handle_chat_message(

View File

@@ -1,12 +1,15 @@
"""插件进度实时推送支持。"""
from typing import Any, Dict, Optional, Set
import asyncio
import json
from typing import Any, Dict, Optional, Set
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.plugin_progress")
@@ -25,25 +28,29 @@ current_progress: Dict[str, Any] = {
}
def get_current_progress() -> Dict[str, Any]:
"""获取当前插件进度快照。
Returns:
Dict[str, Any]: 当前插件进度数据副本。
"""
return current_progress.copy()
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
"""向统一连接层广播插件进度更新。
Args:
progress_data: 插件进度数据。
"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected: Set[WebSocket] = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
for websocket in disconnected:
active_connections.discard(websocket)
await websocket_manager.broadcast_to_topic(
domain="plugin_progress",
topic="main",
event="update",
data={"progress": progress_data},
)
async def update_progress(
@@ -56,6 +63,18 @@ async def update_progress(
total_plugins: int = 0,
loaded_plugins: int = 0,
) -> None:
"""更新当前插件进度并广播。
Args:
stage: 当前阶段。
progress: 当前进度百分比。
message: 进度说明消息。
operation: 当前操作类型。
error: 可选的错误信息。
plugin_id: 当前处理的插件 ID。
total_plugins: 总插件数量。
loaded_plugins: 已处理插件数量。
"""
progress_data = {
"operation": operation,
"stage": stage,
@@ -74,6 +93,12 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""旧版插件进度 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
is_authenticated = False
if token and verify_ws_token(token):
@@ -105,17 +130,22 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
data = await websocket.receive_text()
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
except Exception as exc:
logger.error(f"处理客户端消息时出错: {exc}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
except Exception as exc:
logger.error(f"❌ WebSocket 错误: {exc}")
active_connections.discard(websocket)
def get_progress_router() -> APIRouter:
"""获取旧版插件进度路由对象。
Returns:
APIRouter: 插件进度路由对象。
"""
return router

View File

@@ -1,7 +1,9 @@
"""WebSocket 路由聚合导出。"""
from .auth import router as ws_auth_router
from .logs import router as logs_router
from .unified import router as unified_ws_router
__all__ = [
"logs_router",
"unified_ws_router",
"ws_auth_router",
]

View File

@@ -1,11 +0,0 @@
"""WebSocket 日志推送路由兼容导出。"""
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
__all__ = [
"active_connections",
"broadcast_log",
"load_recent_logs",
"router",
"websocket_logs",
]

View File

@@ -0,0 +1,297 @@
"""统一 WebSocket 连接管理器。"""
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set
from fastapi import WebSocket
from src.common.logger import get_logger
logger = get_logger("webui.websocket")
@dataclass
class WebSocketConnection:
"""统一 WebSocket 连接上下文。"""
connection_id: str
websocket: WebSocket
subscriptions: Set[str] = field(default_factory=set)
chat_sessions: Dict[str, str] = field(default_factory=dict)
send_queue: "asyncio.Queue[Optional[Dict[str, Any]]]" = field(default_factory=asyncio.Queue)
sender_task: Optional["asyncio.Task[None]"] = None
class UnifiedWebSocketManager:
"""统一 WebSocket 连接管理器。"""
def __init__(self) -> None:
"""初始化统一 WebSocket 连接管理器。"""
self.connections: Dict[str, WebSocketConnection] = {}
def _build_subscription_key(self, domain: str, topic: str) -> str:
"""构建订阅索引键。
Args:
domain: 业务域名称。
topic: 主题名称。
Returns:
str: 订阅索引键。
"""
return f"{domain}:{topic}"
async def _sender_loop(self, connection: WebSocketConnection) -> None:
"""串行发送指定连接的出站消息。
Args:
connection: 目标连接上下文。
"""
try:
while True:
message = await connection.send_queue.get()
if message is None:
return
await connection.websocket.send_json(message)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.error("统一 WebSocket 发送失败: connection=%s, error=%s", connection.connection_id, exc)
async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
"""注册一个新的物理 WebSocket 连接。
Args:
connection_id: 连接 ID。
websocket: FastAPI WebSocket 对象。
Returns:
WebSocketConnection: 新建的连接上下文。
"""
await websocket.accept()
connection = WebSocketConnection(connection_id=connection_id, websocket=websocket)
connection.sender_task = asyncio.create_task(self._sender_loop(connection))
self.connections[connection_id] = connection
return connection
async def disconnect(self, connection_id: str) -> None:
"""断开并清理指定连接。
Args:
connection_id: 连接 ID。
"""
connection = self.connections.pop(connection_id, None)
if connection is None:
return
await connection.send_queue.put(None)
if connection.sender_task is not None:
try:
await connection.sender_task
except asyncio.CancelledError:
pass
except Exception as exc:
logger.debug("等待发送协程退出时出现异常: connection=%s, error=%s", connection_id, exc)
def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
"""获取指定连接上下文。
Args:
connection_id: 连接 ID。
Returns:
Optional[WebSocketConnection]: 找到时返回连接上下文。
"""
return self.connections.get(connection_id)
def register_chat_session(self, connection_id: str, client_session_id: str, session_id: str) -> None:
"""登记连接下的逻辑聊天会话。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
session_id: 内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions[client_session_id] = session_id
def unregister_chat_session(self, connection_id: str, client_session_id: str) -> None:
"""移除连接下的逻辑聊天会话登记。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions.pop(client_session_id, None)
def get_chat_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""查询连接下的内部聊天会话 ID。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
Returns:
Optional[str]: 找到时返回内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return None
return connection.chat_sessions.get(client_session_id)
def subscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""登记连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.add(self._build_subscription_key(domain, topic))
def unsubscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""移除连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.discard(self._build_subscription_key(domain, topic))
def is_subscribed(self, connection_id: str, domain: str, topic: str) -> bool:
"""判断连接是否订阅了指定主题。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
Returns:
bool: 已订阅时返回 ``True``。
"""
connection = self.connections.get(connection_id)
if connection is None:
return False
return self._build_subscription_key(domain, topic) in connection.subscriptions
async def enqueue(self, connection_id: str, message: Dict[str, Any]) -> None:
"""向指定连接的发送队列压入消息。
Args:
connection_id: 连接 ID。
message: 待发送的消息。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
await connection.send_queue.put(message)
async def send_response(
self,
connection_id: str,
request_id: Optional[str],
ok: bool,
data: Optional[Dict[str, Any]] = None,
error: Optional[Dict[str, Any]] = None,
) -> None:
"""发送统一响应消息。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
ok: 请求是否成功。
data: 成功响应数据。
error: 失败响应数据。
"""
response_message: Dict[str, Any] = {
"op": "response",
"id": request_id,
"ok": ok,
}
if data is not None:
response_message["data"] = data
if error is not None:
response_message["error"] = error
await self.enqueue(connection_id, response_message)
async def send_event(
self,
connection_id: str,
domain: str,
event: str,
data: Dict[str, Any],
session: Optional[str] = None,
topic: Optional[str] = None,
) -> None:
"""发送统一事件消息。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
event: 事件名称。
data: 事件数据。
session: 可选的逻辑会话 ID。
topic: 可选的主题名称。
"""
event_message: Dict[str, Any] = {
"op": "event",
"domain": domain,
"event": event,
"data": data,
}
if session is not None:
event_message["session"] = session
if topic is not None:
event_message["topic"] = topic
await self.enqueue(connection_id, event_message)
async def send_pong(self, connection_id: str, timestamp: float) -> None:
"""发送心跳响应。
Args:
connection_id: 连接 ID。
timestamp: 当前时间戳。
"""
await self.enqueue(
connection_id,
{
"op": "pong",
"ts": timestamp,
},
)
async def broadcast_to_topic(self, domain: str, topic: str, event: str, data: Dict[str, Any]) -> None:
"""向订阅指定主题的全部连接广播事件。
Args:
domain: 业务域名称。
topic: 主题名称。
event: 事件名称。
data: 事件数据。
"""
subscription_key = self._build_subscription_key(domain, topic)
for connection in list(self.connections.values()):
if subscription_key in connection.subscriptions:
await self.send_event(
connection.connection_id,
domain=domain,
event=event,
data=data,
topic=topic,
)
websocket_manager = UnifiedWebSocketManager()

View File

@@ -0,0 +1,548 @@
"""统一 WebSocket 路由。"""
from typing import Any, Dict, Optional, Set, cast
import asyncio
import time
import uuid
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.logs_ws import load_recent_logs
from src.webui.routers.chat.service import (
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
from src.webui.routers.plugin.progress import get_current_progress
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.unified_ws")
router = APIRouter()
_background_tasks: Set["asyncio.Task[None]"] = set()
def _build_error(code: str, message: str) -> Dict[str, Any]:
"""构建统一错误响应体。
Args:
code: 错误码。
message: 错误描述。
Returns:
Dict[str, Any]: 统一错误对象。
"""
return {
"code": code,
"message": message,
}
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
"""从客户端消息中提取数据字段。
Args:
message: 客户端消息。
Returns:
Dict[str, Any]: 标准化后的数据字典。
"""
data = message.get("data", {})
if isinstance(data, dict):
return cast(Dict[str, Any], data)
return {}
def _track_background_task(task: "asyncio.Task[None]") -> None:
"""登记后台任务并在完成后自动清理。
Args:
task: 后台协程任务。
"""
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
"""校验统一 WebSocket 连接的认证状态。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
Returns:
bool: 认证通过时返回 ``True``。
"""
if token and verify_ws_token(token):
logger.debug("统一 WebSocket 使用临时 token 认证成功")
return True
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("统一 WebSocket 使用 Cookie 认证成功")
return True
return False
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
"""处理日志域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
data: 订阅参数。
"""
replay_limit = int(data.get("replay", 100) or 100)
replay_limit = max(0, min(replay_limit, 500))
websocket_manager.subscribe(connection_id, domain="logs", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "logs", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="logs",
event="snapshot",
topic="main",
data={"entries": load_recent_logs(limit=replay_limit)},
)
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
"""处理插件进度域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
"""
websocket_manager.subscribe(connection_id, domain="plugin_progress", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "plugin_progress", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="plugin_progress",
event="snapshot",
topic="main",
data={"progress": get_current_progress()},
)
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题订阅请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
data = _get_request_data(message)
if domain == "logs" and topic == "main":
await _handle_logs_subscribe(connection_id, request_id, data)
return
if domain == "plugin_progress" and topic == "main":
await _handle_plugin_progress_subscribe(connection_id, request_id)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
)
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题退订请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
if not domain or not topic:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
)
return
websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": domain, "topic": topic},
)
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""打开一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
if not client_session_id:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
)
return
data = _get_request_data(message)
normalized_user_id = normalize_webui_user_id(cast(Optional[str], data.get("user_id")))
current_user_name = str(data.get("user_name") or "WebUI用户")
current_virtual_config = resolve_initial_virtual_identity(
platform=cast(Optional[str], data.get("platform")),
person_id=cast(Optional[str], data.get("person_id")),
group_name=cast(Optional[str], data.get("group_name")),
group_id=cast(Optional[str], data.get("group_id")),
)
restore = bool(data.get("restore"))
session_id = f"{connection_id}:{client_session_id}"
async def send_chat_event(chat_message: Dict[str, Any]) -> None:
"""将聊天消息封装为统一事件并发送。
Args:
chat_message: 聊天消息体。
"""
event_name = str(chat_message.get("type") or "message")
await websocket_manager.send_event(
connection_id,
domain="chat",
event=event_name,
session=client_session_id,
data=chat_message,
)
await chat_manager.connect(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
sender=send_chat_event,
)
websocket_manager.register_chat_session(connection_id, client_session_id, session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "session_id": session_id},
)
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
include_welcome=not restore,
)
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""关闭一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
chat_manager.disconnect(session_id)
websocket_manager.unregister_chat_session(connection_id, client_session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id},
)
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
"""在后台处理聊天消息事件。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
data: 客户端提交的消息数据。
"""
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
return
session_state = chat_manager.get_session(session_id)
if session_state is None:
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天消息发送请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
payload = {
"type": "message",
"content": data.get("content", ""),
"user_name": data.get("user_name", ""),
}
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"accepted": True, "session": client_session_id},
)
_track_background_task(asyncio.create_task(_process_chat_message(connection_id, client_session_id, payload)))
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天昵称更新请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
session_state = chat_manager.get_session(session_id)
if session_state is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data={
"type": "update_nickname",
"user_name": data.get("user_name", ""),
},
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "user_name": next_user_name},
)
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天域调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
method = str(message.get("method") or "").strip()
if method == "session.open":
await _open_chat_session(connection_id, message)
return
if method == "session.close":
await _close_chat_session(connection_id, message)
return
if method == "message.send":
await _handle_chat_message_send(connection_id, message)
return
if method == "session.update_nickname":
await _handle_chat_nickname_update(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
)
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
if domain == "chat":
await _handle_chat_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
)
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一 WebSocket 客户端消息。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
operation = str(message.get("op") or "").strip()
request_id = cast(Optional[str], message.get("id"))
if operation == "ping":
await websocket_manager.send_pong(connection_id, time.time())
return
if operation == "subscribe":
await _handle_subscribe(connection_id, message)
return
if operation == "unsubscribe":
await _handle_unsubscribe(connection_id, message)
return
if operation == "call":
await _handle_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""统一 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
if not await authenticate_websocket_connection(websocket, token):
logger.warning("统一 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
connection_id = uuid.uuid4().hex
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
await websocket_manager.send_event(
connection_id,
domain="system",
event="ready",
data={"connection_id": connection_id, "timestamp": time.time()},
)
try:
while True:
raw_message = await websocket.receive_json()
if not isinstance(raw_message, dict):
await websocket_manager.send_response(
connection_id,
request_id=None,
ok=False,
error=_build_error("invalid_message", "消息必须是 JSON 对象"),
)
continue
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
except WebSocketDisconnect:
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
except Exception as exc:
logger.error(f"统一 WebSocket 处理失败: {exc}")
finally:
chat_manager.disconnect_connection(connection_id)
await websocket_manager.disconnect(connection_id)

View File

@@ -18,11 +18,11 @@ from src.webui.routers.expression import router as expression_router
from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.model import router as model_router
from src.webui.routers.person import router as person_router
from src.webui.routers.plugin import get_progress_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.system import router as system_router
from src.webui.routers.websocket.auth import router as ws_auth_router
from src.webui.routers.websocket.unified import router as unified_ws_router
logger = get_logger("webui.api")
@@ -43,14 +43,14 @@ router.include_router(jargon_router)
router.include_router(emoji_router)
# 注册插件管理路由
router.include_router(plugin_router)
# 注册插件进度 WebSocket 路由
router.include_router(get_progress_router())
# 注册系统控制路由
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
# 注册统一 WebSocket 路由
router.include_router(unified_ws_router)
class TokenVerifyRequest(BaseModel):