feat: add unified WebSocket connection manager and routing

- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending.
- Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management.
- Added support for logging and plugin progress subscriptions.
- Enhanced error handling and response structure for WebSocket operations.
This commit is contained in:
DrSmoothl
2026-04-02 22:08:52 +08:00
parent 7d0d429640
commit 1906890b67
28 changed files with 3845 additions and 1137 deletions

View File

@@ -1,22 +1,25 @@
from datetime import datetime
from sqlmodel import select
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import asyncio
import difflib
import json
import re
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
from src.common.data_models.expression_data_model import MaiExpression
from sqlmodel import select
from src.chat.utils.utils import is_bot_self
from src.common.data_models.expression_data_model import MaiExpression
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import check_expression_suitability, parse_expression_response
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表达方式系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="expression.select.before_select",
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
},
required=["chat_id", "chat_info", "max_num", "think_level"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.select.after_selection",
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
"selected_expressions": {
"type": "array",
"items": {"type": "object"},
"description": "当前已选中的表达方式列表。",
},
"selected_expression_ids": {
"type": "array",
"items": {"type": "integer"},
"description": "当前已选中的表达方式 ID 列表。",
},
},
required=[
"chat_id",
"chat_info",
"max_num",
"think_level",
"selected_expressions",
"selected_expression_ids",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.after_extract",
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
"expressions": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的表达方式候选列表。",
},
"jargon_entries": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的黑话候选列表。",
},
},
required=["session_id", "message_count", "expressions", "jargon_entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.before_upsert",
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"situation": {"type": "string", "description": "即将写入的情景文本。"},
"style": {"type": "string", "description": "即将写入的风格文本。"},
},
required=["session_id", "situation", "style"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class ExpressionLearner:
def __init__(self, session_id: str) -> None:
"""初始化表达方式学习器。
Args:
session_id: 当前会话 ID。
"""
self.session_id = session_id
# 学习锁,防止并发执行学习任务
@@ -44,6 +161,110 @@ class ExpressionLearner:
# 消息缓存
self._messages_cache: List["SessionMessage"] = []
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
"""将表达方式候选序列化为 Hook 载荷。
Args:
expressions: 原始表达方式候选列表。
Returns:
List[dict[str, str]]: 序列化后的表达方式候选。
"""
return [
{
"situation": str(situation).strip(),
"style": str(style).strip(),
"source_id": str(source_id).strip(),
}
for situation, style, source_id in expressions
if str(situation).strip() and str(style).strip()
]
@staticmethod
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
"""从 Hook 载荷恢复表达方式候选列表。
Args:
raw_expressions: Hook 返回的表达方式候选。
Returns:
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Tuple[str, str, str]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not situation or not style:
continue
normalized_expressions.append((situation, style, source_id))
return normalized_expressions
@staticmethod
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
"""将黑话候选序列化为 Hook 载荷。
Args:
jargon_entries: 原始黑话候选列表。
Returns:
List[dict[str, str]]: 序列化后的黑话候选列表。
"""
return [
{
"content": str(content).strip(),
"source_id": str(source_id).strip(),
}
for content, source_id in jargon_entries
if str(content).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
"""从 Hook 载荷恢复黑话候选列表。
Args:
raw_jargon_entries: Hook 返回的黑话候选列表。
Returns:
List[Tuple[str, str]]: 恢复后的黑话候选列表。
"""
if not isinstance(raw_jargon_entries, list):
return []
normalized_entries: List[Tuple[str, str]] = []
for raw_entry in raw_jargon_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
source_id = str(raw_entry.get("source_id") or "").strip()
if not content:
continue
normalized_entries.append((content, source_id))
return normalized_entries
def add_messages(self, messages: List["SessionMessage"]) -> None:
"""添加消息到缓存"""
self._messages_cache.extend(messages)
@@ -52,8 +273,12 @@ class ExpressionLearner:
"""获取当前消息缓存的大小"""
return len(self._messages_cache)
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
"""学习主流程"""
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
"""执行表达方式学习主流程
Args:
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
"""
if not self._messages_cache:
logger.debug("没有消息可供学习,跳过学习过程")
return
@@ -109,6 +334,25 @@ class ExpressionLearner:
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
after_extract_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.after_extract",
session_id=self.session_id,
message_count=len(self._messages_cache),
expressions=self._serialize_expressions(expressions),
jargon_entries=self._serialize_jargon_entries(jargon_entries),
)
if after_extract_result.aborted:
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
return
after_extract_kwargs = after_extract_result.kwargs
raw_expressions = after_extract_kwargs.get("expressions")
if raw_expressions is not None:
expressions = self._deserialize_expressions(raw_expressions)
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
if raw_jargon_entries is not None:
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
# TODO: 检测是否开启了
if jargon_entries:
@@ -135,6 +379,22 @@ class ExpressionLearner:
# 存储到数据库 Expression 表
for situation, style in learnt_expressions:
before_upsert_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.before_upsert",
session_id=self.session_id,
situation=situation,
style=style,
)
if before_upsert_result.aborted:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
continue
upsert_kwargs = before_upsert_result.kwargs
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
style = str(upsert_kwargs.get("style", style) or "").strip()
if not situation or not style:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
continue
await self._upsert_expression_to_db(situation, style)
# ====== 黑话相关 ======