feat: add unified WebSocket connection manager and routing
- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending. - Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management. - Added support for logging and plugin progress subscriptions. - Enhanced error handling and response structure for WebSocket operations.
This commit is contained in:
@@ -1,22 +1,25 @@
|
||||
from datetime import datetime
|
||||
from sqlmodel import select
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import difflib
|
||||
import json
|
||||
import re
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.data_models.expression_data_model import MaiExpression
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.data_models.expression_data_model import MaiExpression
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.config.config import global_config
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .expression_utils import check_expression_suitability, parse_expression_response
|
||||
|
||||
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
|
||||
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
|
||||
|
||||
|
||||
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册表达方式系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="expression.select.before_select",
|
||||
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
|
||||
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
|
||||
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
|
||||
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
|
||||
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
|
||||
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
|
||||
},
|
||||
required=["chat_id", "chat_info", "max_num", "think_level"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.select.after_selection",
|
||||
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
|
||||
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
|
||||
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
|
||||
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
|
||||
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
|
||||
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
|
||||
"selected_expressions": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "当前已选中的表达方式列表。",
|
||||
},
|
||||
"selected_expression_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "当前已选中的表达方式 ID 列表。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"chat_id",
|
||||
"chat_info",
|
||||
"max_num",
|
||||
"think_level",
|
||||
"selected_expressions",
|
||||
"selected_expression_ids",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.learn.after_extract",
|
||||
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
|
||||
"expressions": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "解析出的表达方式候选列表。",
|
||||
},
|
||||
"jargon_entries": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "解析出的黑话候选列表。",
|
||||
},
|
||||
},
|
||||
required=["session_id", "message_count", "expressions", "jargon_entries"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.learn.before_upsert",
|
||||
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"situation": {"type": "string", "description": "即将写入的情景文本。"},
|
||||
"style": {"type": "string", "description": "即将写入的风格文本。"},
|
||||
},
|
||||
required=["session_id", "situation", "style"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, session_id: str) -> None:
|
||||
"""初始化表达方式学习器。
|
||||
|
||||
Args:
|
||||
session_id: 当前会话 ID。
|
||||
"""
|
||||
|
||||
self.session_id = session_id
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
@@ -44,6 +161,110 @@ class ExpressionLearner:
|
||||
# 消息缓存
|
||||
self._messages_cache: List["SessionMessage"] = []
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
|
||||
"""将表达方式候选序列化为 Hook 载荷。
|
||||
|
||||
Args:
|
||||
expressions: 原始表达方式候选列表。
|
||||
|
||||
Returns:
|
||||
List[dict[str, str]]: 序列化后的表达方式候选。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"situation": str(situation).strip(),
|
||||
"style": str(style).strip(),
|
||||
"source_id": str(source_id).strip(),
|
||||
}
|
||||
for situation, style, source_id in expressions
|
||||
if str(situation).strip() and str(style).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
|
||||
"""从 Hook 载荷恢复表达方式候选列表。
|
||||
|
||||
Args:
|
||||
raw_expressions: Hook 返回的表达方式候选。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_expressions, list):
|
||||
return []
|
||||
|
||||
normalized_expressions: List[Tuple[str, str, str]] = []
|
||||
for raw_expression in raw_expressions:
|
||||
if not isinstance(raw_expression, dict):
|
||||
continue
|
||||
situation = str(raw_expression.get("situation") or "").strip()
|
||||
style = str(raw_expression.get("style") or "").strip()
|
||||
source_id = str(raw_expression.get("source_id") or "").strip()
|
||||
if not situation or not style:
|
||||
continue
|
||||
normalized_expressions.append((situation, style, source_id))
|
||||
return normalized_expressions
|
||||
|
||||
@staticmethod
|
||||
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
|
||||
"""将黑话候选序列化为 Hook 载荷。
|
||||
|
||||
Args:
|
||||
jargon_entries: 原始黑话候选列表。
|
||||
|
||||
Returns:
|
||||
List[dict[str, str]]: 序列化后的黑话候选列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"content": str(content).strip(),
|
||||
"source_id": str(source_id).strip(),
|
||||
}
|
||||
for content, source_id in jargon_entries
|
||||
if str(content).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
|
||||
"""从 Hook 载荷恢复黑话候选列表。
|
||||
|
||||
Args:
|
||||
raw_jargon_entries: Hook 返回的黑话候选列表。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 恢复后的黑话候选列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_jargon_entries, list):
|
||||
return []
|
||||
|
||||
normalized_entries: List[Tuple[str, str]] = []
|
||||
for raw_entry in raw_jargon_entries:
|
||||
if not isinstance(raw_entry, dict):
|
||||
continue
|
||||
content = str(raw_entry.get("content") or "").strip()
|
||||
source_id = str(raw_entry.get("source_id") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
normalized_entries.append((content, source_id))
|
||||
return normalized_entries
|
||||
|
||||
def add_messages(self, messages: List["SessionMessage"]) -> None:
|
||||
"""添加消息到缓存"""
|
||||
self._messages_cache.extend(messages)
|
||||
@@ -52,8 +273,12 @@ class ExpressionLearner:
|
||||
"""获取当前消息缓存的大小"""
|
||||
return len(self._messages_cache)
|
||||
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
|
||||
"""学习主流程"""
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
|
||||
"""执行表达方式学习主流程。
|
||||
|
||||
Args:
|
||||
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
|
||||
"""
|
||||
if not self._messages_cache:
|
||||
logger.debug("没有消息可供学习,跳过学习过程")
|
||||
return
|
||||
@@ -109,6 +334,25 @@ class ExpressionLearner:
|
||||
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
|
||||
jargon_entries = []
|
||||
|
||||
after_extract_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.after_extract",
|
||||
session_id=self.session_id,
|
||||
message_count=len(self._messages_cache),
|
||||
expressions=self._serialize_expressions(expressions),
|
||||
jargon_entries=self._serialize_jargon_entries(jargon_entries),
|
||||
)
|
||||
if after_extract_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
|
||||
return
|
||||
|
||||
after_extract_kwargs = after_extract_result.kwargs
|
||||
raw_expressions = after_extract_kwargs.get("expressions")
|
||||
if raw_expressions is not None:
|
||||
expressions = self._deserialize_expressions(raw_expressions)
|
||||
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
|
||||
if raw_jargon_entries is not None:
|
||||
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
|
||||
|
||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||
# TODO: 检测是否开启了
|
||||
if jargon_entries:
|
||||
@@ -135,6 +379,22 @@ class ExpressionLearner:
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for situation, style in learnt_expressions:
|
||||
before_upsert_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.before_upsert",
|
||||
session_id=self.session_id,
|
||||
situation=situation,
|
||||
style=style,
|
||||
)
|
||||
if before_upsert_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
|
||||
continue
|
||||
|
||||
upsert_kwargs = before_upsert_result.kwargs
|
||||
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
|
||||
style = str(upsert_kwargs.get("style", style) or "").strip()
|
||||
if not situation or not style:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
|
||||
continue
|
||||
await self._upsert_expression_to_db(situation, style)
|
||||
|
||||
# ====== 黑话相关 ======
|
||||
|
||||
Reference in New Issue
Block a user