feat: add unified WebSocket connection manager and routing

- Implemented UnifiedWebSocketManager for managing WebSocket connections, including subscription handling and message sending.
- Created unified WebSocket router to handle client messages, including authentication, subscription, and chat session management.
- Added support for logging and plugin progress subscriptions.
- Enhanced error handling and response structure for WebSocket operations.
This commit is contained in:
DrSmoothl
2026-04-02 22:08:52 +08:00
parent 7d0d429640
commit 1906890b67
28 changed files with 3845 additions and 1137 deletions

View File

@@ -1,22 +1,25 @@
from datetime import datetime
from sqlmodel import select
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import asyncio
import difflib
import json
import re
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
from src.common.data_models.expression_data_model import MaiExpression
from sqlmodel import select
from src.chat.utils.utils import is_bot_self
from src.common.data_models.expression_data_model import MaiExpression
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import check_expression_suitability, parse_expression_response
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表达方式系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="expression.select.before_select",
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
},
required=["chat_id", "chat_info", "max_num", "think_level"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.select.after_selection",
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
"selected_expressions": {
"type": "array",
"items": {"type": "object"},
"description": "当前已选中的表达方式列表。",
},
"selected_expression_ids": {
"type": "array",
"items": {"type": "integer"},
"description": "当前已选中的表达方式 ID 列表。",
},
},
required=[
"chat_id",
"chat_info",
"max_num",
"think_level",
"selected_expressions",
"selected_expression_ids",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.after_extract",
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
"expressions": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的表达方式候选列表。",
},
"jargon_entries": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的黑话候选列表。",
},
},
required=["session_id", "message_count", "expressions", "jargon_entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.before_upsert",
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"situation": {"type": "string", "description": "即将写入的情景文本。"},
"style": {"type": "string", "description": "即将写入的风格文本。"},
},
required=["session_id", "situation", "style"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class ExpressionLearner:
def __init__(self, session_id: str) -> None:
"""初始化表达方式学习器。
Args:
session_id: 当前会话 ID。
"""
self.session_id = session_id
# 学习锁,防止并发执行学习任务
@@ -44,6 +161,110 @@ class ExpressionLearner:
# 消息缓存
self._messages_cache: List["SessionMessage"] = []
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
"""将表达方式候选序列化为 Hook 载荷。
Args:
expressions: 原始表达方式候选列表。
Returns:
List[dict[str, str]]: 序列化后的表达方式候选。
"""
return [
{
"situation": str(situation).strip(),
"style": str(style).strip(),
"source_id": str(source_id).strip(),
}
for situation, style, source_id in expressions
if str(situation).strip() and str(style).strip()
]
@staticmethod
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
"""从 Hook 载荷恢复表达方式候选列表。
Args:
raw_expressions: Hook 返回的表达方式候选。
Returns:
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Tuple[str, str, str]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not situation or not style:
continue
normalized_expressions.append((situation, style, source_id))
return normalized_expressions
@staticmethod
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
"""将黑话候选序列化为 Hook 载荷。
Args:
jargon_entries: 原始黑话候选列表。
Returns:
List[dict[str, str]]: 序列化后的黑话候选列表。
"""
return [
{
"content": str(content).strip(),
"source_id": str(source_id).strip(),
}
for content, source_id in jargon_entries
if str(content).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
"""从 Hook 载荷恢复黑话候选列表。
Args:
raw_jargon_entries: Hook 返回的黑话候选列表。
Returns:
List[Tuple[str, str]]: 恢复后的黑话候选列表。
"""
if not isinstance(raw_jargon_entries, list):
return []
normalized_entries: List[Tuple[str, str]] = []
for raw_entry in raw_jargon_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
source_id = str(raw_entry.get("source_id") or "").strip()
if not content:
continue
normalized_entries.append((content, source_id))
return normalized_entries
def add_messages(self, messages: List["SessionMessage"]) -> None:
"""添加消息到缓存"""
self._messages_cache.extend(messages)
@@ -52,8 +273,12 @@ class ExpressionLearner:
"""获取当前消息缓存的大小"""
return len(self._messages_cache)
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
"""学习主流程"""
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
"""执行表达方式学习主流程
Args:
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
"""
if not self._messages_cache:
logger.debug("没有消息可供学习,跳过学习过程")
return
@@ -109,6 +334,25 @@ class ExpressionLearner:
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
after_extract_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.after_extract",
session_id=self.session_id,
message_count=len(self._messages_cache),
expressions=self._serialize_expressions(expressions),
jargon_entries=self._serialize_jargon_entries(jargon_entries),
)
if after_extract_result.aborted:
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
return
after_extract_kwargs = after_extract_result.kwargs
raw_expressions = after_extract_kwargs.get("expressions")
if raw_expressions is not None:
expressions = self._deserialize_expressions(raw_expressions)
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
if raw_jargon_entries is not None:
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
# TODO: 检测是否开启了
if jargon_entries:
@@ -135,6 +379,22 @@ class ExpressionLearner:
# 存储到数据库 Expression 表
for situation, style in learnt_expressions:
before_upsert_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.before_upsert",
session_id=self.session_id,
situation=situation,
style=style,
)
if before_upsert_result.aborted:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
continue
upsert_kwargs = before_upsert_result.kwargs
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
style = str(upsert_kwargs.get("style", style) or "").strip()
if not situation or not style:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
continue
await self._upsert_expression_to_db(situation, style)
# ====== 黑话相关 ======

View File

@@ -1,27 +1,109 @@
from typing import Any, Dict, List, Optional, Tuple
import json
import time
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.utils.utils_session import SessionUtils
from src.prompt.prompt_manager import prompt_manager
from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.learners.learner_utils_old import weighted_sample
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("expression_selector")
class ExpressionSelector:
def __init__(self):
def __init__(self) -> None:
"""初始化表达方式选择器。"""
self.llm_model = LLMServiceClient(
task_name="utils", request_type="expression.selector"
)
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _normalize_selected_expressions(raw_expressions: Any) -> List[Dict[str, Any]]:
"""从 Hook 载荷恢复表达方式选择结果。
Args:
raw_expressions: Hook 返回的表达方式列表。
Returns:
List[Dict[str, Any]]: 恢复后的表达方式列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Dict[str, Any]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
expression_id = raw_expression.get("id")
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not isinstance(expression_id, int) or not situation or not style or not source_id:
continue
normalized_expression = dict(raw_expression)
normalized_expression["id"] = expression_id
normalized_expression["situation"] = situation
normalized_expression["style"] = style
normalized_expression["source_id"] = source_id
normalized_expressions.append(normalized_expression)
return normalized_expressions
@staticmethod
def _normalize_selected_expression_ids(raw_ids: Any, expressions: List[Dict[str, Any]]) -> List[int]:
"""规范化最终选中的表达方式 ID 列表。
Args:
raw_ids: Hook 返回的 ID 列表。
expressions: 当前最终表达方式列表。
Returns:
List[int]: 规范化后的 ID 列表。
"""
if isinstance(raw_ids, list):
normalized_ids = [item for item in raw_ids if isinstance(item, int)]
if normalized_ids:
return normalized_ids
return [expression["id"] for expression in expressions if isinstance(expression.get("id"), int)]
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
@@ -214,8 +296,7 @@ class ExpressionSelector:
reply_reason: Optional[str] = None,
think_level: int = 1,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
选择适合的表达方式使用classic模式随机选择+LLM选择
"""选择适合的表达方式。
Args:
chat_id: 聊天流ID
@@ -233,11 +314,60 @@ class ExpressionSelector:
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
before_select_result = await self._get_runtime_manager().invoke_hook(
"expression.select.before_select",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
)
if before_select_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择被 Hook 中止")
return [], []
before_select_kwargs = before_select_result.kwargs
chat_id = str(before_select_kwargs.get("chat_id", chat_id) or "").strip() or chat_id
chat_info = str(before_select_kwargs.get("chat_info", chat_info) or "")
max_num = max(self._coerce_int(before_select_kwargs.get("max_num"), max_num), 1)
raw_target_message = before_select_kwargs.get("target_message", target_message or "")
target_message = str(raw_target_message or "").strip() or None
raw_reply_reason = before_select_kwargs.get("reply_reason", reply_reason or "")
reply_reason = str(raw_reply_reason or "").strip() or None
think_level = self._coerce_int(before_select_kwargs.get("think_level"), think_level)
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(
selected_expressions, selected_ids = await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
after_selection_result = await self._get_runtime_manager().invoke_hook(
"expression.select.after_selection",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
selected_expressions=[dict(item) for item in selected_expressions],
selected_expression_ids=list(selected_ids),
)
if after_selection_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择结果被 Hook 中止")
return [], []
after_selection_kwargs = after_selection_result.kwargs
raw_selected_expressions = after_selection_kwargs.get("selected_expressions")
if raw_selected_expressions is not None:
selected_expressions = self._normalize_selected_expressions(raw_selected_expressions)
selected_ids = self._normalize_selected_expression_ids(
after_selection_kwargs.get("selected_expression_ids"),
selected_expressions,
)
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
return selected_expressions, selected_ids
async def _select_expressions_classic(
self,

View File

@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Set, TypedDict
from typing import Any, Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
@@ -9,13 +9,15 @@ from json_repair import repair_json
from sqlmodel import select
from src.common.data_models.jargon_data_model import MaiJargon
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import is_single_char_jargon
@@ -35,8 +37,140 @@ class JargonMeaningEntry(TypedDict):
meaning: str
def register_jargon_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册 jargon 系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="jargon.query.before_search",
description="Maisaka 黑话查询工具执行检索前触发,可改写词条列表、检索参数或直接中止。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "准备查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否允许精确命中失败后回退模糊检索。"},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.query.after_search",
description="Maisaka 黑话查询工具完成检索后触发,可改写结果列表或中止返回。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "实际查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否启用了模糊检索回退。"},
"results": {
"type": "array",
"items": {"type": "object"},
"description": "查询结果列表。",
},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback", "results"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.extract.before_persist",
description="黑话条目准备写入数据库前触发,可改写去重后的条目列表或跳过本次持久化。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"entries": {
"type": "array",
"items": {"type": "object"},
"description": "即将持久化的黑话条目列表。",
},
},
required=["session_id", "session_name", "entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.inference.before_finalize",
description="黑话含义推断完成、写回数据库前触发,可改写最终判定与含义结果。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"content": {"type": "string", "description": "当前黑话词条。"},
"count": {"type": "integer", "description": "当前词条累计命中次数。"},
"raw_content_list": {
"type": "array",
"items": {"type": "string"},
"description": "用于推断的原始上下文片段列表。",
},
"inference_with_context": {"type": "object", "description": "基于上下文的推断结果。"},
"inference_with_content_only": {"type": "object", "description": "仅基于词条内容的推断结果。"},
"comparison_result": {"type": "object", "description": "比较阶段输出结果。"},
"is_jargon": {"type": "boolean", "description": "当前推断是否判定为黑话。"},
"meaning": {"type": "string", "description": "当前推断出的黑话含义。"},
"is_complete": {"type": "boolean", "description": "当前是否已完成全部推断流程。"},
"last_inference_count": {"type": "integer", "description": "本次推断完成后应写回的 last_inference_count。"},
},
required=[
"session_id",
"session_name",
"content",
"count",
"raw_content_list",
"inference_with_context",
"inference_with_content_only",
"comparison_result",
"is_jargon",
"meaning",
"is_complete",
"last_inference_count",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class JargonMiner:
def __init__(self, session_id: str, session_name: str) -> None:
"""初始化黑话学习器。
Args:
session_id: 当前会话 ID。
session_name: 当前会话展示名称。
"""
self.session_id = session_id
self.session_name = session_name
@@ -46,13 +180,92 @@ class JargonMiner:
# 黑话提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _serialize_jargon_entries(entries: List[JargonEntry]) -> List[Dict[str, object]]:
"""将黑话条目列表序列化为 Hook 可传输结构。
Args:
entries: 原始黑话条目列表。
Returns:
List[Dict[str, object]]: 序列化后的条目列表。
"""
return [
{
"content": str(entry["content"]).strip(),
"raw_content": sorted(str(item).strip() for item in entry["raw_content"] if str(item).strip()),
}
for entry in entries
if str(entry["content"]).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_entries: Any) -> List[JargonEntry]:
"""从 Hook 载荷恢复黑话条目列表。
Args:
raw_entries: Hook 返回的条目数据。
Returns:
List[JargonEntry]: 恢复后的黑话条目列表。
"""
if not isinstance(raw_entries, list):
return []
normalized_entries: List[JargonEntry] = []
for raw_entry in raw_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
if not content:
continue
raw_content_values = raw_entry.get("raw_content")
raw_content: Set[str] = set()
if isinstance(raw_content_values, list):
raw_content = {str(item).strip() for item in raw_content_values if str(item).strip()}
normalized_entries.append({"content": content, "raw_content": raw_content})
return normalized_entries
def get_cached_jargons(self) -> List[str]:
"""获取缓存中的所有黑话列表"""
return list(self.cache.keys())
async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
"""
对jargon进行含义推断
"""对黑话条目执行含义推断。
Args:
jargon_obj: 待推断的黑话数据对象。
"""
content = jargon_obj.content
# 解析raw_content列表
@@ -175,15 +388,45 @@ class JargonMiner:
is_similar = comparison_result.get("is_similar", False)
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
finalized_meaning = inference1.get("meaning", "") if is_jargon else ""
is_complete = (jargon_obj.count or 0) >= 100
last_inference_count = jargon_obj.count or 0
finalize_result = await self._get_runtime_manager().invoke_hook(
"jargon.inference.before_finalize",
session_id=self.session_id,
session_name=self.session_name,
content=content,
count=current_count,
raw_content_list=list(raw_content_list),
inference_with_context=dict(inference1),
inference_with_content_only=dict(inference2),
comparison_result=dict(comparison_result),
is_jargon=is_jargon,
meaning=finalized_meaning,
is_complete=is_complete,
last_inference_count=last_inference_count,
)
if finalize_result.aborted:
logger.info(f"jargon {content} 的推断结果被 Hook 中止写回")
return
finalize_kwargs = finalize_result.kwargs
is_jargon = bool(finalize_kwargs.get("is_jargon", is_jargon))
finalized_meaning = str(finalize_kwargs.get("meaning", finalized_meaning) or "").strip() if is_jargon else ""
is_complete = bool(finalize_kwargs.get("is_complete", is_complete))
last_inference_count = self._coerce_int(
finalize_kwargs.get("last_inference_count"),
last_inference_count,
)
# 更新数据库记录
jargon_obj.is_jargon = is_jargon
jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
jargon_obj.meaning = finalized_meaning
# 更新最后一次判定的count值避免重启后重复判定
jargon_obj.last_inference_count = jargon_obj.count or 0
jargon_obj.last_inference_count = last_inference_count
# 如果count>=100标记为完成不再进行推断
if (jargon_obj.count or 0) >= 100:
jargon_obj.is_complete = True
jargon_obj.is_complete = is_complete
try:
self._modify_jargon_entry(jargon_obj)
@@ -232,6 +475,22 @@ class JargonMiner:
merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
uniq_entries: List[JargonEntry] = list(merged_entries.values())
before_persist_result = await self._get_runtime_manager().invoke_hook(
"jargon.extract.before_persist",
session_id=self.session_id,
session_name=self.session_name,
entries=self._serialize_jargon_entries(uniq_entries),
)
if before_persist_result.aborted:
logger.info(f"[{self.session_name}] 黑话提取结果被 Hook 中止,不写入数据库")
return
raw_hook_entries = before_persist_result.kwargs.get("entries")
if raw_hook_entries is not None:
uniq_entries = self._deserialize_jargon_entries(raw_hook_entries)
if not uniq_entries:
logger.info(f"[{self.session_name}] Hook 过滤后没有可写入的黑话条目")
return
saved = 0
updated = 0