merge: 同步 upstream/r-dev 并解决冲突
This commit is contained in:
@@ -1,25 +1,28 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import heapq
|
||||
import Levenshtein
|
||||
import random
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
|
||||
import Levenshtein
|
||||
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("emoji")
|
||||
@@ -33,6 +36,171 @@ EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
|
||||
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册表情包系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
emoji_schema = {
|
||||
"type": "object",
|
||||
"description": "当前表情包的序列化信息,主要包含 file_hash、description、emotions 等字段。",
|
||||
}
|
||||
string_array_schema = {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="emoji.maisaka.before_select",
|
||||
description="Maisaka 表情发送工具选择表情前触发,可改写情绪、上下文和采样参数,或中止本次选择。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"stream_id": {"type": "string", "description": "目标会话 ID。"},
|
||||
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
|
||||
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
|
||||
"context_texts": {
|
||||
**string_array_schema,
|
||||
"description": "最近聊天上下文文本列表。",
|
||||
},
|
||||
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
|
||||
"abort_message": {
|
||||
"type": "string",
|
||||
"description": "当 Hook 主动中止时可附带的失败提示。",
|
||||
},
|
||||
},
|
||||
required=["stream_id", "requested_emotion", "reasoning", "context_texts", "sample_size"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.maisaka.after_select",
|
||||
description="Maisaka 已选出表情后触发,可替换选中的表情哈希、补充匹配情绪,或中止发送。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"stream_id": {"type": "string", "description": "目标会话 ID。"},
|
||||
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
|
||||
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
|
||||
"context_texts": {
|
||||
**string_array_schema,
|
||||
"description": "最近聊天上下文文本列表。",
|
||||
},
|
||||
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
|
||||
"selected_emoji": emoji_schema,
|
||||
"selected_emoji_hash": {"type": "string", "description": "选中的表情哈希。"},
|
||||
"matched_emotion": {"type": "string", "description": "最终命中的情绪标签。"},
|
||||
"abort_message": {
|
||||
"type": "string",
|
||||
"description": "当 Hook 主动中止时可附带的失败提示。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"stream_id",
|
||||
"requested_emotion",
|
||||
"reasoning",
|
||||
"context_texts",
|
||||
"sample_size",
|
||||
"matched_emotion",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.register.after_build_description",
|
||||
description="表情包描述生成并通过内容审查后触发,可改写描述文本或拒绝本次注册。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"emoji": emoji_schema,
|
||||
"description": {"type": "string", "description": "当前生成出的表情包描述。"},
|
||||
"image_format": {"type": "string", "description": "表情图片格式。"},
|
||||
},
|
||||
required=["emoji", "description", "image_format"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.register.after_build_emotion",
|
||||
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"emoji": emoji_schema,
|
||||
"description": {"type": "string", "description": "当前表情包描述。"},
|
||||
"emotions": {
|
||||
**string_array_schema,
|
||||
"description": "当前生成出的情绪标签列表。",
|
||||
},
|
||||
},
|
||||
required=["emoji", "description", "emotions"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
|
||||
"""将表情包对象序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
emoji: 待序列化的表情包对象。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 序列化后的字典;当表情为空时返回 ``None``。
|
||||
"""
|
||||
|
||||
if emoji is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"file_hash": str(emoji.file_hash or "").strip(),
|
||||
"file_name": emoji.file_name,
|
||||
"full_path": str(emoji.full_path),
|
||||
"description": emoji.description,
|
||||
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
|
||||
"query_count": int(emoji.query_count),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_string_list(raw_values: Any) -> List[str]:
|
||||
"""将任意列表值规范化为字符串列表。
|
||||
|
||||
Args:
|
||||
raw_values: 待规范化的原始值。
|
||||
|
||||
Returns:
|
||||
List[str]: 去空白后的字符串列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_values, list):
|
||||
return []
|
||||
return [str(item).strip() for item in raw_values if str(item).strip()]
|
||||
|
||||
|
||||
def _ensure_directories() -> None:
|
||||
"""确保表情包相关目录存在"""
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
@@ -642,6 +810,22 @@ class EmojiManager:
|
||||
if "否" in llm_response:
|
||||
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_description",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=description,
|
||||
image_format=image_format,
|
||||
)
|
||||
if hook_result.aborted:
|
||||
logger.info(f"[构建描述] 表情包描述被 Hook 中止注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
normalized_description = str(hook_result.kwargs.get("description", description) or "").strip()
|
||||
if not normalized_description:
|
||||
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
description = normalized_description
|
||||
target_emoji.description = description
|
||||
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
|
||||
return True, target_emoji
|
||||
@@ -687,6 +871,23 @@ class EmojiManager:
|
||||
elif len(emotions) > 2:
|
||||
emotions = random.sample(emotions, 2)
|
||||
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_emotion",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=target_emoji.description,
|
||||
emotions=list(emotions),
|
||||
)
|
||||
if hook_result.aborted:
|
||||
logger.info(f"[构建情感标签] 表情包情感标签被 Hook 中止注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
raw_emotions = hook_result.kwargs.get("emotions")
|
||||
if raw_emotions is not None:
|
||||
emotions = _normalize_string_list(raw_emotions)
|
||||
if not emotions:
|
||||
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}")
|
||||
target_emoji.emotion = emotions
|
||||
return True, target_emoji
|
||||
|
||||
349
src/chat/emoji_system/maisaka_tool.py
Normal file
349
src/chat/emoji_system/maisaka_tool.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""Maisaka 表情工具内置能力。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import random
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.services import send_service
|
||||
|
||||
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
|
||||
|
||||
logger = get_logger("emoji_maisaka_tool")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MaisakaEmojiSendResult:
|
||||
"""Maisaka 表情发送结果。"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
emoji_base64: str = ""
|
||||
description: str = ""
|
||||
emotions: list[str] = field(default_factory=list)
|
||||
requested_emotion: str = ""
|
||||
matched_emotion: str = ""
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _coerce_positive_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为正整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时使用的默认值。
|
||||
|
||||
Returns:
|
||||
int: 规范化后的正整数。
|
||||
"""
|
||||
|
||||
try:
|
||||
normalized_value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return normalized_value if normalized_value > 0 else default
|
||||
|
||||
|
||||
def _normalize_context_texts(context_texts: Sequence[str] | None) -> list[str]:
|
||||
"""清洗 Hook 和调用链传入的上下文文本列表。
|
||||
|
||||
Args:
|
||||
context_texts: 原始上下文文本序列。
|
||||
|
||||
Returns:
|
||||
list[str]: 过滤空白后的上下文文本列表。
|
||||
"""
|
||||
|
||||
if not context_texts:
|
||||
return []
|
||||
return [str(item).strip() for item in context_texts if str(item).strip()]
|
||||
|
||||
|
||||
def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
|
||||
"""根据 Hook 返回值解析目标表情包对象。
|
||||
|
||||
Args:
|
||||
raw_value: Hook 返回的 ``selected_emoji`` 或 ``selected_emoji_hash``。
|
||||
|
||||
Returns:
|
||||
Optional[MaiEmoji]: 命中的表情包对象;未命中时返回 ``None``。
|
||||
"""
|
||||
|
||||
raw_hash: str = ""
|
||||
if isinstance(raw_value, dict):
|
||||
raw_hash = str(raw_value.get("file_hash") or raw_value.get("hash") or "").strip()
|
||||
elif isinstance(raw_value, str):
|
||||
raw_hash = raw_value.strip()
|
||||
|
||||
if not raw_hash:
|
||||
return None
|
||||
|
||||
for emoji in emoji_manager.emojis:
|
||||
if emoji.file_hash == raw_hash:
|
||||
return emoji
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
|
||||
"""提取并清洗单个表情的情绪标签。"""
|
||||
|
||||
return [str(item).strip() for item in emoji.emotion if str(item).strip()]
|
||||
|
||||
|
||||
def _build_recent_context_text(context_texts: Sequence[str], max_items: int = 5) -> str:
|
||||
"""构建供情绪判断使用的最近上下文文本。"""
|
||||
|
||||
normalized_items = [str(item).strip() for item in context_texts if str(item).strip()]
|
||||
if not normalized_items:
|
||||
return ""
|
||||
return "\n".join(normalized_items[-max_items:])
|
||||
|
||||
|
||||
async def _select_emoji_with_llm(
|
||||
*,
|
||||
sampled_emojis: Sequence[MaiEmoji],
|
||||
reasoning: str,
|
||||
context_text: str,
|
||||
) -> tuple[MaiEmoji, str]:
|
||||
"""让模型在采样表情中选择更合适的情绪标签。"""
|
||||
|
||||
emotion_map: dict[str, list[MaiEmoji]] = {}
|
||||
for emoji in sampled_emojis:
|
||||
for emotion in _normalize_emotions(emoji):
|
||||
emotion_map.setdefault(emotion, []).append(emoji)
|
||||
|
||||
available_emotions = list(emotion_map.keys())
|
||||
if not available_emotions:
|
||||
return random.choice(list(sampled_emojis)), ""
|
||||
|
||||
prompt = (
|
||||
"你正在为聊天场景选择一个最合适的表情包情绪标签。\n"
|
||||
f"发送原因:{reasoning or '辅助表达当前语气和情绪'}\n"
|
||||
f"最近聊天记录:\n{context_text or '(暂无额外上下文)'}\n\n"
|
||||
"可选情绪标签如下:\n"
|
||||
f"{chr(10).join(available_emotions)}\n\n"
|
||||
"请只返回一个最匹配的情绪标签,不要解释。"
|
||||
)
|
||||
|
||||
try:
|
||||
llm_result = await emoji_manager_emotion_judge_llm.generate_response(
|
||||
prompt,
|
||||
options=LLMGenerationOptions(temperature=0.3, max_tokens=60),
|
||||
)
|
||||
chosen_emotion = (llm_result.response or "").strip().strip("\"'")
|
||||
except Exception as exc:
|
||||
logger.warning(f"使用 LLM 选择表情情绪失败,将回退为随机选择: {exc}")
|
||||
chosen_emotion = ""
|
||||
|
||||
if chosen_emotion and chosen_emotion in emotion_map:
|
||||
return random.choice(emotion_map[chosen_emotion]), chosen_emotion
|
||||
return random.choice(list(sampled_emojis)), ""
|
||||
|
||||
|
||||
async def select_emoji_for_maisaka(
|
||||
*,
|
||||
requested_emotion: str = "",
|
||||
reasoning: str = "",
|
||||
context_texts: Sequence[str] | None = None,
|
||||
sample_size: int = 30,
|
||||
) -> tuple[MaiEmoji | None, str]:
|
||||
"""为 Maisaka 选择一个合适的表情。"""
|
||||
|
||||
available_emojis = list(emoji_manager.emojis)
|
||||
if not available_emojis:
|
||||
return None, ""
|
||||
|
||||
normalized_requested_emotion = requested_emotion.strip()
|
||||
if normalized_requested_emotion:
|
||||
matched_emojis = [
|
||||
emoji
|
||||
for emoji in available_emojis
|
||||
if normalized_requested_emotion.lower() in (emotion.lower() for emotion in _normalize_emotions(emoji))
|
||||
]
|
||||
if matched_emojis:
|
||||
return random.choice(matched_emojis), normalized_requested_emotion
|
||||
|
||||
sampled_emojis = random.sample(
|
||||
available_emojis,
|
||||
min(max(sample_size, 1), len(available_emojis)),
|
||||
)
|
||||
context_text = _build_recent_context_text(context_texts or [])
|
||||
return await _select_emoji_with_llm(
|
||||
sampled_emojis=sampled_emojis,
|
||||
reasoning=reasoning,
|
||||
context_text=context_text,
|
||||
)
|
||||
|
||||
|
||||
async def send_emoji_for_maisaka(
|
||||
*,
|
||||
stream_id: str,
|
||||
requested_emotion: str = "",
|
||||
reasoning: str = "",
|
||||
context_texts: Sequence[str] | None = None,
|
||||
) -> MaisakaEmojiSendResult:
|
||||
"""为 Maisaka 选择并发送一个表情。"""
|
||||
|
||||
normalized_requested_emotion = requested_emotion.strip()
|
||||
normalized_reasoning = reasoning.strip()
|
||||
normalized_context_texts = _normalize_context_texts(context_texts)
|
||||
sample_size = 30
|
||||
|
||||
before_select_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.maisaka.before_select",
|
||||
stream_id=stream_id,
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
reasoning=normalized_reasoning,
|
||||
context_texts=list(normalized_context_texts),
|
||||
sample_size=sample_size,
|
||||
abort_message="表情选择已被 Hook 中止。",
|
||||
)
|
||||
if before_select_result.aborted:
|
||||
abort_message = str(before_select_result.kwargs.get("abort_message") or "表情选择已被 Hook 中止。").strip()
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message=abort_message or "表情选择已被 Hook 中止。",
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
)
|
||||
|
||||
before_select_kwargs = before_select_result.kwargs
|
||||
normalized_requested_emotion = str(
|
||||
before_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
|
||||
).strip()
|
||||
normalized_reasoning = str(before_select_kwargs.get("reasoning", normalized_reasoning) or "").strip()
|
||||
if isinstance(before_select_kwargs.get("context_texts"), list):
|
||||
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
|
||||
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
|
||||
|
||||
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
reasoning=normalized_reasoning,
|
||||
context_texts=normalized_context_texts,
|
||||
sample_size=sample_size,
|
||||
)
|
||||
after_select_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.maisaka.after_select",
|
||||
stream_id=stream_id,
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
reasoning=normalized_reasoning,
|
||||
context_texts=list(normalized_context_texts),
|
||||
sample_size=sample_size,
|
||||
selected_emoji=_serialize_emoji_for_hook(selected_emoji),
|
||||
selected_emoji_hash=str(selected_emoji.file_hash or "").strip() if selected_emoji is not None else "",
|
||||
matched_emotion=matched_emotion,
|
||||
abort_message="表情发送已被 Hook 中止。",
|
||||
)
|
||||
if after_select_result.aborted:
|
||||
abort_message = str(after_select_result.kwargs.get("abort_message") or "表情发送已被 Hook 中止。").strip()
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message=abort_message or "表情发送已被 Hook 中止。",
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
after_select_kwargs = after_select_result.kwargs
|
||||
normalized_requested_emotion = str(
|
||||
after_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
|
||||
).strip()
|
||||
matched_emotion = str(after_select_kwargs.get("matched_emotion", matched_emotion) or "").strip()
|
||||
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji_hash"))
|
||||
if override_emoji is None:
|
||||
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji"))
|
||||
if override_emoji is not None:
|
||||
selected_emoji = override_emoji
|
||||
|
||||
if selected_emoji is None:
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message="当前表情包库中没有可用表情。",
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
try:
|
||||
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
|
||||
if not emoji_base64:
|
||||
raise ValueError("表情图片转换为 base64 失败")
|
||||
except Exception as exc:
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message=f"发送表情包失败:{exc}",
|
||||
description=selected_emoji.description.strip(),
|
||||
emotions=_normalize_emotions(selected_emoji),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
try:
|
||||
target_session = chat_manager.get_session_by_session_id(stream_id)
|
||||
if target_session is not None and target_session.platform == CLI_PLATFORM_NAME:
|
||||
preview_message = (
|
||||
f"已发送表情包:{selected_emoji.description.strip()}"
|
||||
if selected_emoji.description.strip()
|
||||
else "[表情包]"
|
||||
)
|
||||
render_cli_message(preview_message)
|
||||
sent = True
|
||||
else:
|
||||
sent = await send_service.emoji_to_stream(
|
||||
emoji_base64=emoji_base64,
|
||||
stream_id=stream_id,
|
||||
storage_message=True,
|
||||
set_reply=False,
|
||||
reply_message=None,
|
||||
)
|
||||
except Exception as exc:
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message=f"发送表情包时发生异常:{exc}",
|
||||
description=selected_emoji.description.strip(),
|
||||
emotions=_normalize_emotions(selected_emoji),
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
description = selected_emoji.description.strip()
|
||||
emotions = _normalize_emotions(selected_emoji)
|
||||
if not sent:
|
||||
return MaisakaEmojiSendResult(
|
||||
success=False,
|
||||
message="发送表情包失败。",
|
||||
description=description,
|
||||
emotions=emotions,
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
|
||||
emoji_manager.update_emoji_usage(selected_emoji)
|
||||
success_message = (
|
||||
f"已发送表情包:{description}(情绪:{', '.join(emotions)})"
|
||||
if emotions
|
||||
else f"已发送表情包:{description}"
|
||||
)
|
||||
return MaisakaEmojiSendResult(
|
||||
success=True,
|
||||
message=success_message,
|
||||
emoji_base64=emoji_base64,
|
||||
description=description,
|
||||
emotions=emotions,
|
||||
requested_emotion=normalized_requested_emotion,
|
||||
matched_emotion=matched_emotion,
|
||||
)
|
||||
@@ -44,6 +44,12 @@ class ImageManager:
|
||||
|
||||
logger.info("图片管理器初始化完成")
|
||||
|
||||
def _get_image_record(self, image_hash: str) -> Optional[Images]:
|
||||
"""根据哈希获取图片记录。"""
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
|
||||
return session.exec(statement).first()
|
||||
|
||||
async def get_image_description(
|
||||
self,
|
||||
*,
|
||||
@@ -76,9 +82,8 @@ class ImageManager:
|
||||
hash_str = hashlib.sha256(image_bytes).hexdigest()
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
|
||||
if record := session.exec(statement).first():
|
||||
if record := self._get_image_record(hash_str):
|
||||
if record.vlm_processed and record.description:
|
||||
return record.description
|
||||
except Exception as e:
|
||||
logger.error(f"查询图片描述时发生错误: {e}")
|
||||
@@ -86,12 +91,17 @@ class ImageManager:
|
||||
if not image_bytes:
|
||||
logger.warning("图片哈希值未找到,且未提供图片字节数据,返回无描述")
|
||||
return ""
|
||||
try:
|
||||
await self.ensure_image_saved(image_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件时发生错误: {e}")
|
||||
return ""
|
||||
if not wait_for_build:
|
||||
self._schedule_description_build(hash_str, image_bytes)
|
||||
return ""
|
||||
logger.info(f"图片描述未找到,哈希值: {hash_str},准备生成新描述")
|
||||
try:
|
||||
image = await self.save_image_and_process(image_bytes)
|
||||
image = await self.build_image_description(image_bytes)
|
||||
return image.description
|
||||
except Exception as e:
|
||||
logger.error(f"生成图片描述时发生错误: {e}")
|
||||
@@ -120,7 +130,7 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
logger.info(f"图片描述后台构建已开始,哈希值: {image_hash}")
|
||||
await self.save_image_and_process(image_bytes)
|
||||
await self.build_image_description(image_bytes)
|
||||
logger.info(f"图片描述后台构建完成,哈希值: {image_hash}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"图片描述后台构建失败,哈希值: {image_hash},错误: {exc}")
|
||||
@@ -201,6 +211,7 @@ class ImageManager:
|
||||
return False
|
||||
record.description = image.description
|
||||
record.last_used_time = datetime.now()
|
||||
record.vlm_processed = image.vlm_processed
|
||||
session.add(record)
|
||||
logger.info(f"成功更新图片描述: {image.file_hash},新描述: {image.description}")
|
||||
except Exception as e:
|
||||
@@ -239,22 +250,13 @@ class ImageManager:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def save_image_and_process(self, image_bytes: bytes) -> MaiImage:
|
||||
"""
|
||||
保存图片并生成描述
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): 图片的字节数据
|
||||
Returns:
|
||||
return (MaiImage): 包含图片信息的 MaiImage 对象
|
||||
Raises:
|
||||
Exception: 如果在保存或处理过程中发生错误
|
||||
"""
|
||||
async def ensure_image_saved(self, image_bytes: bytes) -> MaiImage:
|
||||
"""先保存图片记录,确保后续可以按哈希回填图片内容。"""
|
||||
hash_str = hashlib.sha256(image_bytes).hexdigest()
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=hash_str).limit(1)
|
||||
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
|
||||
if record := session.exec(statement).first():
|
||||
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
|
||||
record.last_used_time = datetime.now()
|
||||
@@ -270,18 +272,38 @@ class ImageManager:
|
||||
tmp_file_path = IMAGE_DIR / f"{hash_str}.tmp"
|
||||
with tmp_file_path.open("wb") as f:
|
||||
f.write(image_bytes)
|
||||
mai_image = MaiImage(full_path=(IMAGE_DIR / f"{hash_str}.tmp"), image_bytes=image_bytes)
|
||||
mai_image = MaiImage(full_path=tmp_file_path, image_bytes=image_bytes)
|
||||
await mai_image.calculate_hash_format()
|
||||
if not self.register_image_to_db(mai_image):
|
||||
raise RuntimeError(f"保存图片记录到数据库失败: {hash_str}")
|
||||
return mai_image
|
||||
|
||||
async def build_image_description(self, image_bytes: bytes) -> MaiImage:
|
||||
"""在图片已保存的前提下生成或补齐图片描述。"""
|
||||
mai_image = await self.ensure_image_saved(image_bytes)
|
||||
if mai_image.vlm_processed and mai_image.description:
|
||||
return mai_image
|
||||
|
||||
desc = await self._generate_image_description(image_bytes, mai_image.image_format)
|
||||
mai_image.description = desc
|
||||
mai_image.vlm_processed = True
|
||||
try:
|
||||
self.register_image_to_db(mai_image)
|
||||
except Exception as e:
|
||||
logger.error(f"保存新图片记录到数据库时发生错误: {e}")
|
||||
raise e
|
||||
if not self.update_image_description(mai_image):
|
||||
raise RuntimeError(f"更新图片描述失败: {mai_image.file_hash}")
|
||||
return mai_image
|
||||
|
||||
async def save_image_and_process(self, image_bytes: bytes) -> MaiImage:
|
||||
"""
|
||||
保存图片并生成描述
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): 图片的字节数据
|
||||
Returns:
|
||||
return (MaiImage): 包含图片信息的 MaiImage 对象
|
||||
Raises:
|
||||
Exception: 如果在保存或处理过程中发生错误
|
||||
"""
|
||||
return await self.build_image_description(image_bytes)
|
||||
|
||||
def cleanup_invalid_descriptions_in_db(self):
|
||||
"""
|
||||
清理数据库中无效的图片记录
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""聊天消息入口与主链路调度。"""
|
||||
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import os
|
||||
import traceback
|
||||
@@ -13,12 +14,15 @@ from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
from .message import SessionMessage
|
||||
from .chat_manager import chat_manager
|
||||
from .message import SessionMessage
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
@@ -29,7 +33,137 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
def register_chat_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册聊天消息主链内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="chat.receive.before_process",
|
||||
description="在入站消息执行 `SessionMessage.process()` 之前触发,可拦截或改写消息。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "当前入站消息的序列化 SessionMessage。",
|
||||
},
|
||||
},
|
||||
required=["message"],
|
||||
),
|
||||
default_timeout_ms=8000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="chat.receive.after_process",
|
||||
description="在入站消息完成轻量预处理后触发,可改写文本、消息体或中止后续链路。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "已完成 `process()` 的序列化 SessionMessage。",
|
||||
},
|
||||
},
|
||||
required=["message"],
|
||||
),
|
||||
default_timeout_ms=8000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="chat.command.before_execute",
|
||||
description="在命令匹配成功、实际执行前触发,可拦截命令或改写命令上下文。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "当前命令消息的序列化 SessionMessage。",
|
||||
},
|
||||
"command_name": {
|
||||
"type": "string",
|
||||
"description": "命中的命令名称。",
|
||||
},
|
||||
"plugin_id": {
|
||||
"type": "string",
|
||||
"description": "命令所属插件 ID。",
|
||||
},
|
||||
"matched_groups": {
|
||||
"type": "object",
|
||||
"description": "命令正则命名捕获结果。",
|
||||
},
|
||||
},
|
||||
required=["message", "command_name", "plugin_id", "matched_groups"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="chat.command.after_execute",
|
||||
description="在命令执行结束后触发,可调整返回文本和是否继续主链处理。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "当前命令消息的序列化 SessionMessage。",
|
||||
},
|
||||
"command_name": {
|
||||
"type": "string",
|
||||
"description": "命令名称。",
|
||||
},
|
||||
"plugin_id": {
|
||||
"type": "string",
|
||||
"description": "命令所属插件 ID。",
|
||||
},
|
||||
"matched_groups": {
|
||||
"type": "object",
|
||||
"description": "命令正则命名捕获结果。",
|
||||
},
|
||||
"success": {
|
||||
"type": "boolean",
|
||||
"description": "命令执行是否成功。",
|
||||
},
|
||||
"response": {
|
||||
"type": "string",
|
||||
"description": "命令返回文本。",
|
||||
},
|
||||
"intercept_message_level": {
|
||||
"type": "integer",
|
||||
"description": "命令拦截等级。",
|
||||
},
|
||||
"continue_process": {
|
||||
"type": "boolean",
|
||||
"description": "命令执行后是否继续后续消息处理。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"message",
|
||||
"command_name",
|
||||
"plugin_id",
|
||||
"matched_groups",
|
||||
"success",
|
||||
"intercept_message_level",
|
||||
"continue_process",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=False,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ChatBot:
|
||||
"""聊天机器人入口协调器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化聊天机器人入口。"""
|
||||
|
||||
@@ -44,6 +178,66 @@ class ChatBot:
|
||||
|
||||
self._started = True
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
async def _invoke_message_hook(
|
||||
self,
|
||||
hook_name: str,
|
||||
message: SessionMessage,
|
||||
**kwargs: Any,
|
||||
) -> tuple[HookDispatchResult, SessionMessage]:
|
||||
"""触发携带会话消息的命名 Hook。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
message: 当前会话消息。
|
||||
**kwargs: 需要附带传递的额外参数。
|
||||
|
||||
Returns:
|
||||
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
|
||||
"""
|
||||
|
||||
hook_result = await self._get_runtime_manager().invoke_hook(
|
||||
hook_name,
|
||||
message=serialize_session_message(message),
|
||||
**kwargs,
|
||||
)
|
||||
mutated_message = message
|
||||
raw_message = hook_result.kwargs.get("message")
|
||||
if raw_message is not None:
|
||||
try:
|
||||
mutated_message = deserialize_session_message(raw_message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
|
||||
return hook_result, mutated_message
|
||||
|
||||
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
|
||||
"""使用统一组件注册表处理命令。
|
||||
|
||||
@@ -71,6 +265,25 @@ class ChatBot:
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
before_result, message = await self._invoke_message_hook(
|
||||
"chat.command.before_execute",
|
||||
message,
|
||||
command_name=command_name,
|
||||
plugin_id=plugin_name,
|
||||
matched_groups=dict(matched_groups),
|
||||
)
|
||||
if before_result.aborted:
|
||||
logger.info(f"命令 {command_name} 被 Hook 中止,跳过命令执行")
|
||||
return True, None, False
|
||||
|
||||
hook_kwargs = before_result.kwargs
|
||||
command_name = str(hook_kwargs.get("command_name", command_name) or command_name)
|
||||
plugin_name = str(hook_kwargs.get("plugin_id", plugin_name) or plugin_name)
|
||||
matched_groups = (
|
||||
dict(hook_kwargs["matched_groups"])
|
||||
if isinstance(hook_kwargs.get("matched_groups"), dict)
|
||||
else dict(matched_groups)
|
||||
)
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_query_service.get_plugin_config(plugin_name)
|
||||
@@ -82,27 +295,43 @@ class ChatBot:
|
||||
plugin_config=plugin_config,
|
||||
matched_groups=matched_groups,
|
||||
)
|
||||
self._mark_command_message(message, intercept_message_level)
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_name} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return (
|
||||
True,
|
||||
response,
|
||||
not bool(intercept_message_level),
|
||||
) # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_name} - {e}")
|
||||
continue_process = not bool(intercept_message_level)
|
||||
except Exception as exc:
|
||||
logger.error(f"执行命令时出错: {command_name} - {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
success = False
|
||||
response = str(exc)
|
||||
intercept_message_level = 1
|
||||
continue_process = False
|
||||
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
after_result, message = await self._invoke_message_hook(
|
||||
"chat.command.after_execute",
|
||||
message,
|
||||
command_name=command_name,
|
||||
plugin_id=plugin_name,
|
||||
matched_groups=dict(matched_groups),
|
||||
success=success,
|
||||
response=response,
|
||||
intercept_message_level=intercept_message_level,
|
||||
continue_process=continue_process,
|
||||
)
|
||||
after_kwargs = after_result.kwargs
|
||||
success = bool(after_kwargs.get("success", success))
|
||||
raw_response = after_kwargs.get("response", response)
|
||||
response = None if raw_response is None else str(raw_response)
|
||||
intercept_message_level = self._coerce_int(
|
||||
after_kwargs.get("intercept_message_level", intercept_message_level),
|
||||
intercept_message_level,
|
||||
)
|
||||
continue_process = bool(after_kwargs.get("continue_process", continue_process))
|
||||
self._mark_command_message(message, intercept_message_level)
|
||||
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_name} - {response}")
|
||||
|
||||
return True, response, continue_process
|
||||
|
||||
return False, None, True
|
||||
|
||||
@@ -138,6 +367,17 @@ class ChatBot:
|
||||
cmd_result: Optional[str],
|
||||
continue_process: bool,
|
||||
) -> bool:
|
||||
"""处理命令链结果并决定是否终止主消息链。
|
||||
|
||||
Args:
|
||||
message: 当前命令消息。
|
||||
cmd_result: 命令响应文本。
|
||||
continue_process: 是否继续后续主链处理。
|
||||
|
||||
Returns:
|
||||
bool: ``True`` 表示已经终止后续主链。
|
||||
"""
|
||||
|
||||
if continue_process:
|
||||
return False
|
||||
|
||||
@@ -145,9 +385,18 @@ class ChatBot:
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return True
|
||||
|
||||
async def handle_notice_message(self, message: SessionMessage):
|
||||
async def handle_notice_message(self, message: SessionMessage) -> bool:
|
||||
"""处理通知类消息。
|
||||
|
||||
Args:
|
||||
message: 当前通知消息。
|
||||
|
||||
Returns:
|
||||
bool: 当前消息是否为通知消息。
|
||||
"""
|
||||
|
||||
if message.message_id != "notice":
|
||||
return
|
||||
return False
|
||||
|
||||
message.is_notify = True
|
||||
logger.debug("notice消息")
|
||||
@@ -203,9 +452,12 @@ class ChatBot:
|
||||
return True
|
||||
|
||||
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
|
||||
"""处理消息回送 ID 对应关系。
|
||||
|
||||
Args:
|
||||
raw_data: 平台适配器上报的原始回送载荷。
|
||||
"""
|
||||
用于专门处理回送消息ID的函数
|
||||
"""
|
||||
|
||||
message_data: Dict[str, Any] = raw_data.get("content", {})
|
||||
if not message_data:
|
||||
return
|
||||
@@ -218,18 +470,10 @@ class ChatBot:
|
||||
logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}")
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
heart_flow模式:使用思维流系统进行回复
|
||||
- 包含思维流状态管理
|
||||
- 在回复前进行观察和状态更新
|
||||
- 回复后更新思维流状态
|
||||
- 消息过滤
|
||||
- 记忆激活
|
||||
- 意愿计算
|
||||
- 消息生成和发送
|
||||
- 表情包处理
|
||||
- 性能计时
|
||||
"""处理统一格式的入站消息字典。
|
||||
|
||||
Args:
|
||||
message_data: 适配器整理后的统一消息字典。
|
||||
"""
|
||||
try:
|
||||
# 确保所有任务已启动
|
||||
@@ -253,7 +497,13 @@ class ChatBot:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def receive_message(self, message: SessionMessage):
|
||||
async def receive_message(self, message: SessionMessage) -> None:
|
||||
"""处理单条入站会话消息。
|
||||
|
||||
Args:
|
||||
message: 待处理的会话消息。
|
||||
"""
|
||||
|
||||
try:
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
@@ -272,6 +522,19 @@ class ChatBot:
|
||||
)
|
||||
|
||||
message.session_id = session_id # 正确初始化session_id
|
||||
before_process_result, message = await self._invoke_message_hook(
|
||||
"chat.receive.before_process",
|
||||
message,
|
||||
)
|
||||
if before_process_result.aborted:
|
||||
logger.info(f"消息 {message.message_id} 在预处理前被 Hook 中止")
|
||||
return
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
additional_config = message.message_info.additional_config
|
||||
if isinstance(additional_config, dict):
|
||||
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||||
|
||||
# TODO: 修复事件预处理部分
|
||||
# continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
@@ -286,14 +549,24 @@ class ChatBot:
|
||||
# if await self.handle_notice_message(message):
|
||||
# pass
|
||||
|
||||
# 处理消息内容,识别表情包等二进制数据并转化为文本描述
|
||||
if global_config.maisaka.direct_image_input:
|
||||
message.maisaka_original_raw_message = deepcopy(message.raw_message) # type: ignore[attr-defined]
|
||||
# 处理消息内容,识别表情包等二进制数据并转化为文本描述。
|
||||
# 如果 Maisaka 需要直接消费图片,会在后续构建 prompt 时按需回填图片二进制数据,
|
||||
# 这里不再复制整条原始消息。
|
||||
# 入站主链优先保证消息尽快入队,避免图片、表情包、语音分析阻塞适配器超时。
|
||||
await message.process(
|
||||
enable_heavy_media_analysis=False,
|
||||
enable_voice_transcription=False,
|
||||
)
|
||||
after_process_result, message = await self._invoke_message_hook(
|
||||
"chat.receive.after_process",
|
||||
message,
|
||||
)
|
||||
if after_process_result.aborted:
|
||||
logger.info(f"消息 {message.message_id} 在预处理后被 Hook 中止")
|
||||
return
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import asyncio
|
||||
from asyncio import Task
|
||||
from typing import Dict, List, Sequence, Tuple
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages
|
||||
@@ -36,6 +35,102 @@ class MsgIDMapping:
|
||||
|
||||
|
||||
class SessionMessage(MaiMessage):
|
||||
|
||||
#便于调试的打印函数
|
||||
def __str__(self) -> str:
|
||||
"""返回适合日志输出的消息摘要。"""
|
||||
return self.to_debug_string()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回适合调试场景的消息摘要。"""
|
||||
return self.to_debug_string()
|
||||
|
||||
def to_debug_string(self) -> str:
|
||||
"""构建包含引用信息的调试字符串。
|
||||
|
||||
Returns:
|
||||
str: 适合记录日志的消息摘要。
|
||||
"""
|
||||
user_info = self.message_info.user_info
|
||||
group_info = self.message_info.group_info
|
||||
chat_type = "group" if group_info else "private"
|
||||
group_id = group_info.group_id if group_info else None
|
||||
group_name = group_info.group_name if group_info else None
|
||||
component_summaries = [self._summarize_component(component) for component in self.raw_message.components]
|
||||
raw_components = ", ".join(component_summaries) if component_summaries else "empty"
|
||||
|
||||
return (
|
||||
"SessionMessage("
|
||||
f"message_id={self.message_id!r}, "
|
||||
f"platform={self.platform!r}, "
|
||||
f"chat_type={chat_type!r}, "
|
||||
f"group_id={group_id!r}, "
|
||||
f"group_name={group_name!r}, "
|
||||
f"user_id={user_info.user_id!r}, "
|
||||
f"user_nickname={user_info.user_nickname!r}, "
|
||||
f"user_cardname={user_info.user_cardname!r}, "
|
||||
f"reply_to={self.reply_to!r}, "
|
||||
f"processed_plain_text={self._truncate_text(self.processed_plain_text)}, "
|
||||
f"raw_components=[{raw_components}]"
|
||||
")"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _truncate_text(text: str | None, max_length: int = 120) -> str:
|
||||
"""截断较长文本,避免日志过长。
|
||||
|
||||
Args:
|
||||
text: 原始文本。
|
||||
max_length: 最大保留长度。
|
||||
|
||||
Returns:
|
||||
str: 截断后的文本表示。
|
||||
"""
|
||||
if text is None:
|
||||
return "None"
|
||||
normalized_text = text.replace("\r", "\\r").replace("\n", "\\n")
|
||||
if len(normalized_text) <= max_length:
|
||||
return repr(normalized_text)
|
||||
return repr(f"{normalized_text[:max_length]}...")
|
||||
|
||||
def _summarize_component(self, component: StandardMessageComponents) -> str:
|
||||
"""生成单个消息组件的调试摘要。
|
||||
|
||||
Args:
|
||||
component: 消息组件对象。
|
||||
|
||||
Returns:
|
||||
str: 组件摘要文本。
|
||||
"""
|
||||
if isinstance(component, TextComponent):
|
||||
return f"Text(text={self._truncate_text(component.text, 80)})"
|
||||
if isinstance(component, ImageComponent):
|
||||
return f"Image(content={self._truncate_text(component.content or None, 60)})"
|
||||
if isinstance(component, EmojiComponent):
|
||||
return f"Emoji(content={self._truncate_text(component.content or None, 60)})"
|
||||
if isinstance(component, AtComponent):
|
||||
target_name = component.target_user_cardname or component.target_user_nickname or component.target_user_id
|
||||
return f"At(target={target_name!r})"
|
||||
if isinstance(component, VoiceComponent):
|
||||
return f"Voice(content={self._truncate_text(component.content or None, 60)})"
|
||||
if isinstance(component, ReplyComponent):
|
||||
sender_name = (
|
||||
component.target_message_sender_cardname
|
||||
or component.target_message_sender_nickname
|
||||
or component.target_message_sender_id
|
||||
)
|
||||
return (
|
||||
"Reply("
|
||||
f"target_message_id={component.target_message_id!r}, "
|
||||
f"target_sender={sender_name!r}, "
|
||||
f"target_content={self._truncate_text(component.target_message_content, 80)}"
|
||||
")"
|
||||
)
|
||||
if isinstance(component, ForwardNodeComponent):
|
||||
return f"ForwardNode(count={len(component.forward_components)})"
|
||||
return f"{component.__class__.__name__}"
|
||||
#便于调试的打印函数end
|
||||
|
||||
async def process(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -18,28 +18,29 @@ install(extra_lines=3)
|
||||
logger = get_logger("sender")
|
||||
|
||||
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str], Optional[str]]] = None
|
||||
|
||||
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
|
||||
# TODO: 重构完成后完成webui相关
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
|
||||
"""获取 WebUI 聊天室广播器。
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
|
||||
Tuple[Any, Optional[str], Optional[str]]: ``(chat_manager, platform_name, default_group_id)`` 三元组;
|
||||
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
|
||||
"""
|
||||
global _webui_chat_broadcaster
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
|
||||
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
|
||||
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
|
||||
except ImportError:
|
||||
_webui_chat_broadcaster = (None, None)
|
||||
_webui_chat_broadcaster = (None, None, None)
|
||||
return _webui_chat_broadcaster
|
||||
|
||||
|
||||
@@ -76,7 +77,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
|
||||
try:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||
chat_manager, webui_platform, default_group_id = get_webui_chat_broadcaster()
|
||||
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
||||
|
||||
if is_webui_message and chat_manager is not None:
|
||||
@@ -97,8 +98,9 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
message_type = "rich"
|
||||
segments = message_segments
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
await chat_manager.broadcast_to_group(
|
||||
group_id=group_id or default_group_id or "",
|
||||
message={
|
||||
"type": "bot_message",
|
||||
"content": message.processed_plain_text,
|
||||
"message_type": message_type,
|
||||
@@ -110,7 +112,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"avatar": None,
|
||||
"is_bot": True,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
||||
|
||||
@@ -35,7 +35,7 @@ from src.services import llm_service as llm_api
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
||||
from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
|
||||
from src.learners.jargon_explainer_old import explain_jargon_in_context
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
||||
init_memory_retrieval_sys()
|
||||
@@ -688,39 +688,41 @@ class DefaultReplyer:
|
||||
return None
|
||||
|
||||
def get_chat_prompt_for_chat(self, chat_id: str) -> str:
|
||||
"""
|
||||
根据聊天流ID获取匹配的额外prompt(仅匹配group类型)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
str: 匹配的额外prompt内容,如果没有匹配则返回空字符串
|
||||
"""
|
||||
if not global_config.experimental.chat_prompts:
|
||||
"""根据聊天流 ID 获取匹配的额外 prompt。"""
|
||||
if not global_config.chat.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_str in global_config.experimental.chat_prompts:
|
||||
if not isinstance(chat_prompt_str, str):
|
||||
for chat_prompt_item in global_config.chat.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "rule_type") and hasattr(chat_prompt_item, "prompt"):
|
||||
if str(chat_prompt_item.rule_type or "").strip() != "group":
|
||||
continue
|
||||
|
||||
config_chat_id = self._build_chat_uid(
|
||||
str(chat_prompt_item.platform or "").strip(),
|
||||
str(chat_prompt_item.item_id or "").strip(),
|
||||
True,
|
||||
)
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
if config_chat_id == chat_id and prompt_content:
|
||||
logger.debug(f"匹配到群聊 prompt 配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
return prompt_content
|
||||
continue
|
||||
|
||||
# 解析配置字符串,检查类型是否为group
|
||||
parts = chat_prompt_str.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
if not isinstance(chat_prompt_item, str):
|
||||
continue
|
||||
|
||||
stream_type = parts[2]
|
||||
# 只匹配group类型
|
||||
if stream_type != "group":
|
||||
# 兼容旧格式的 platform:id:type:prompt 配置字符串。
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4 or parts[2] != "group":
|
||||
continue
|
||||
|
||||
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
|
||||
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_item)
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
config_chat_id, prompt_content = result
|
||||
if config_chat_id == chat_id:
|
||||
logger.debug(f"匹配到群聊prompt配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
logger.debug(f"匹配到群聊 prompt 配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
|
||||
453
src/chat/replyer/maisaka_generator_multi.py
Normal file
453
src/chat/replyer/maisaka_generator_multi.py
Normal file
@@ -0,0 +1,453 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
GenerationMetrics,
|
||||
LLMCompletionResult,
|
||||
ReplyGenerationResult,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, ReferenceMessage, SessionBackedMessage, ToolResultMessage
|
||||
from src.maisaka.message_adapter import parse_speaker_content
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaisakaReplyContext:
|
||||
"""Maisaka replyer 使用的回复上下文。"""
|
||||
|
||||
expression_habits: str = ""
|
||||
selected_expression_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ExpressionRecord:
|
||||
"""表达方式的轻量记录。"""
|
||||
|
||||
expression_id: Optional[int]
|
||||
situation: str
|
||||
style: str
|
||||
|
||||
|
||||
class MaisakaReplyGenerator:
|
||||
"""生成 Maisaka 的最终可见回复。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
request_type: str = "maisaka_replyer",
|
||||
) -> None:
|
||||
self.chat_stream = chat_stream
|
||||
self.request_type = request_type
|
||||
self.express_model = LLMServiceClient(
|
||||
task_name="replyer",
|
||||
request_type=request_type,
|
||||
)
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构建 replyer 使用的人设描述。"""
|
||||
try:
|
||||
bot_name = global_config.bot.nickname
|
||||
alias_names = global_config.bot.alias_names
|
||||
bot_aliases = f",也有人叫你{','.join(alias_names)}" if alias_names else ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
if (
|
||||
hasattr(global_config.personality, "states")
|
||||
and global_config.personality.states
|
||||
and hasattr(global_config.personality, "state_probability")
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
return f"你的名字是{bot_name}{bot_aliases},你{prompt_personality};"
|
||||
except Exception as exc:
|
||||
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
|
||||
return "你的名字是麦麦,你是一个活泼可爱的 AI 助手。"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_content(content: str, limit: int = 500) -> str:
|
||||
normalized = " ".join((content or "").split())
|
||||
if len(normalized) > limit:
|
||||
return normalized[:limit] + "..."
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _extract_visible_assistant_reply(message: AssistantMessage) -> str:
|
||||
del message
|
||||
return ""
|
||||
|
||||
def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
|
||||
speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
if speaker_name == bot_nickname:
|
||||
return self._normalize_content(body.strip())
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
|
||||
"""按说话人拆分用户消息。"""
|
||||
segments: List[tuple[Optional[str], str]] = []
|
||||
current_speaker: Optional[str] = None
|
||||
current_lines: List[str] = []
|
||||
|
||||
for raw_line in raw_content.splitlines():
|
||||
speaker_name, content_body = parse_speaker_content(raw_line)
|
||||
if speaker_name is not None:
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
current_speaker = speaker_name
|
||||
current_lines = [content_body]
|
||||
continue
|
||||
|
||||
current_lines.append(raw_line)
|
||||
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
|
||||
return segments
|
||||
|
||||
def _build_system_prompt(
|
||||
self,
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> str:
|
||||
"""构建 Maisaka replyer 使用的系统提示词。"""
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
try:
|
||||
system_prompt = load_prompt(
|
||||
"maisaka_replyer",
|
||||
bot_name=global_config.bot.nickname,
|
||||
time_block=f"当前时间:{current_time}",
|
||||
identity=self._personality_prompt,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
)
|
||||
except Exception:
|
||||
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
|
||||
|
||||
extra_sections: List[str] = []
|
||||
if expression_habits.strip():
|
||||
extra_sections.append(expression_habits.strip())
|
||||
if reply_reason.strip():
|
||||
extra_sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
|
||||
if not extra_sections:
|
||||
return system_prompt
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(extra_sections)
|
||||
|
||||
def _build_reply_instruction(self) -> str:
|
||||
"""构建追加在上下文末尾的回复指令。"""
|
||||
return "请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。"
|
||||
|
||||
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
|
||||
"""将 replyer 上下文拆成多条 LLM 消息。"""
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
default_user_name = global_config.maisaka.user_name.strip() or "User"
|
||||
messages: List[Message] = []
|
||||
|
||||
for message in chat_history:
|
||||
if isinstance(message, (ReferenceMessage, ToolResultMessage)):
|
||||
continue
|
||||
|
||||
if isinstance(message, SessionBackedMessage):
|
||||
guided_reply = self._extract_guided_bot_reply(message)
|
||||
if guided_reply:
|
||||
messages.append(
|
||||
MessageBuilder().set_role(RoleType.Assistant).add_text_content(guided_reply).build()
|
||||
)
|
||||
continue
|
||||
|
||||
for speaker_name, content_body in self._split_user_message_segments(message.processed_plain_text):
|
||||
content = self._normalize_content(content_body)
|
||||
if not content:
|
||||
continue
|
||||
|
||||
visible_speaker = speaker_name or default_user_name
|
||||
if visible_speaker == bot_nickname:
|
||||
messages.append(
|
||||
MessageBuilder().set_role(RoleType.Assistant).add_text_content(content).build()
|
||||
)
|
||||
continue
|
||||
|
||||
user_content = f"[{visible_speaker}]{content}"
|
||||
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(user_content).build())
|
||||
continue
|
||||
|
||||
if isinstance(message, AssistantMessage):
|
||||
visible_reply = self._extract_visible_assistant_reply(message)
|
||||
if visible_reply:
|
||||
messages.append(
|
||||
MessageBuilder().set_role(RoleType.Assistant).add_text_content(visible_reply).build()
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
def _build_request_messages(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> List[Message]:
|
||||
"""构建发给大模型的消息列表。"""
|
||||
messages: List[Message] = []
|
||||
system_prompt = self._build_system_prompt(
|
||||
reply_reason=reply_reason,
|
||||
expression_habits=expression_habits,
|
||||
)
|
||||
instruction = self._build_reply_instruction()
|
||||
|
||||
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
|
||||
messages.extend(self._build_history_messages(chat_history))
|
||||
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _build_request_prompt_preview(messages: List[Message]) -> str:
|
||||
"""将消息列表转为便于调试的文本预览。"""
|
||||
preview_lines: List[str] = []
|
||||
for message in messages:
|
||||
role_name = message.role.value.capitalize()
|
||||
preview_lines.append(f"{role_name}: {message.get_text_content()}")
|
||||
return "\n\n".join(preview_lines)
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
"""解析当前回复使用的会话 ID。"""
|
||||
if stream_id:
|
||||
return stream_id
|
||||
if self.chat_stream is not None:
|
||||
return self.chat_stream.session_id
|
||||
return ""
|
||||
|
||||
async def _build_reply_context(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
stream_id: Optional[str],
|
||||
) -> MaisakaReplyContext:
|
||||
"""在 replyer 内部构建表达习惯和黑话解释。"""
|
||||
session_id = self._resolve_session_id(stream_id)
|
||||
if not session_id:
|
||||
logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
expression_habits, selected_expression_ids = self._build_expression_habits(
|
||||
session_id=session_id,
|
||||
chat_history=chat_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
)
|
||||
return MaisakaReplyContext(
|
||||
expression_habits=expression_habits,
|
||||
selected_expression_ids=selected_expression_ids,
|
||||
)
|
||||
|
||||
def _build_expression_habits(
|
||||
self,
|
||||
session_id: str,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
) -> tuple[str, List[int]]:
|
||||
"""查询并格式化适合当前会话的表达习惯。"""
|
||||
del chat_history
|
||||
del reply_message
|
||||
del reply_reason
|
||||
|
||||
expression_records = self._load_expression_records(session_id)
|
||||
if not expression_records:
|
||||
return "", []
|
||||
|
||||
lines: List[str] = []
|
||||
selected_ids: List[int] = []
|
||||
for expression in expression_records:
|
||||
if expression.expression_id is not None:
|
||||
selected_ids.append(expression.expression_id)
|
||||
lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")
|
||||
|
||||
block = "【表达习惯参考】\n" + "\n".join(lines)
|
||||
logger.info(
|
||||
f"已构建 Maisaka 表达习惯: 会话标识={session_id} "
|
||||
f"数量={len(selected_ids)} 表达编号={selected_ids!r}"
|
||||
)
|
||||
return block, selected_ids
|
||||
|
||||
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
|
||||
"""提取表达方式静态数据,避免 detached ORM 对象。"""
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
if global_config.expression.expression_checked_only:
|
||||
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
|
||||
query = query.where(
|
||||
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(query.limit(5)).all()
|
||||
return [
|
||||
_ExpressionRecord(
|
||||
expression_id=expression.id,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
)
|
||||
for expression in expressions
|
||||
]
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List[object]] = None,
|
||||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[SessionMessage] = None,
|
||||
reply_time_point: Optional[float] = None,
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
log_reply: bool = True,
|
||||
chat_history: Optional[List[LLMContextMessage]] = None,
|
||||
expression_habits: str = "",
|
||||
selected_expression_ids: Optional[List[int]] = None,
|
||||
) -> Tuple[bool, ReplyGenerationResult]:
|
||||
"""结合上下文生成 Maisaka 的最终可见回复。"""
|
||||
del available_actions
|
||||
del chosen_actions
|
||||
del extra_info
|
||||
del from_plugin
|
||||
del log_reply
|
||||
del reply_time_point
|
||||
del think_level
|
||||
del unknown_words
|
||||
|
||||
result = ReplyGenerationResult()
|
||||
if chat_history is None:
|
||||
result.error_message = "聊天历史为空"
|
||||
return False, result
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} "
|
||||
f"历史消息数={len(chat_history)} 目标消息编号="
|
||||
f"{reply_message.message_id if reply_message else None}"
|
||||
)
|
||||
|
||||
filtered_history = [
|
||||
message
|
||||
for message in chat_history
|
||||
if not isinstance(message, (ReferenceMessage, ToolResultMessage))
|
||||
]
|
||||
|
||||
logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}")
|
||||
|
||||
# Validate that express_model is properly initialized
|
||||
if self.express_model is None:
|
||||
logger.error("Maisaka 回复器的回复模型未初始化")
|
||||
result.error_message = "回复模型尚未初始化"
|
||||
return False, result
|
||||
|
||||
try:
|
||||
reply_context = await self._build_reply_context(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
stream_id=stream_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建回复上下文失败: {exc}"
|
||||
return False, result
|
||||
|
||||
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
|
||||
result.selected_expression_ids = (
|
||||
list(selected_expression_ids)
|
||||
if selected_expression_ids is not None
|
||||
else list(reply_context.selected_expression_ids)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
|
||||
try:
|
||||
request_messages = self._build_request_messages(
|
||||
chat_history=filtered_history,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建提示词失败: {exc}"
|
||||
return False, result
|
||||
|
||||
prompt_preview = self._build_request_prompt_preview(request_messages)
|
||||
|
||||
def message_factory(_client: object) -> List[Message]:
|
||||
return request_messages
|
||||
|
||||
result.completion.request_prompt = prompt_preview
|
||||
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\nMaisaka 回复器提示词:\n{prompt_preview}\n")
|
||||
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
generation_result = await self.express_model.generate_response_with_messages(message_factory=message_factory)
|
||||
except Exception as exc:
|
||||
logger.exception("Maisaka 回复器调用失败")
|
||||
result.error_message = str(exc)
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
|
||||
)
|
||||
return False, result
|
||||
|
||||
response_text = (generation_result.response or "").strip()
|
||||
result.success = bool(response_text)
|
||||
result.completion = LLMCompletionResult(
|
||||
request_prompt=prompt_preview,
|
||||
response_text=response_text,
|
||||
reasoning_text=generation_result.reasoning or "",
|
||||
model_name=generation_result.model_name or "",
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
)
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
|
||||
)
|
||||
|
||||
if global_config.debug.show_replyer_reasoning and result.completion.reasoning_text:
|
||||
logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}")
|
||||
|
||||
if not result.success:
|
||||
result.error_message = "回复器返回了空内容"
|
||||
logger.warning("Maisaka 回复器返回了空内容")
|
||||
return False, result
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器生成成功: 回复文本={response_text!r} "
|
||||
f"总耗时毫秒={result.metrics.overall_ms} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
result.text_fragments = [response_text]
|
||||
return True, result
|
||||
21
src/chat/replyer/maisaka_replyer_factory.py
Normal file
21
src/chat/replyer/maisaka_replyer_factory.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Type
|
||||
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def get_maisaka_replyer_class() -> Type[object]:
|
||||
"""根据配置返回 Maisaka replyer 类。"""
|
||||
generator_type = global_config.maisaka.replyer_generator_type
|
||||
if generator_type == "multi":
|
||||
from .maisaka_generator_multi import MaisakaReplyGenerator
|
||||
|
||||
return MaisakaReplyGenerator
|
||||
|
||||
from .maisaka_generator import MaisakaReplyGenerator
|
||||
|
||||
return MaisakaReplyGenerator
|
||||
|
||||
|
||||
def get_maisaka_replyer_generator_type() -> str:
|
||||
"""返回当前配置的 Maisaka replyer 生成器类型。"""
|
||||
return global_config.maisaka.replyer_generator_type
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,14 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.replyer.maisaka_replyer_factory import (
|
||||
get_maisaka_replyer_class,
|
||||
get_maisaka_replyer_generator_type,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
|
||||
logger = get_logger("ReplyerManager")
|
||||
@@ -23,14 +26,15 @@ class ReplyerManager:
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
replyer_type: str = "default",
|
||||
) -> Optional["DefaultReplyer | MaisakaReplyGenerator | PrivateReplyer"]:
|
||||
) -> Optional["DefaultReplyer | PrivateReplyer | Any"]:
|
||||
"""按会话和 replyer 类型获取实例。"""
|
||||
stream_id = chat_stream.session_id if chat_stream else chat_id
|
||||
if not stream_id:
|
||||
logger.warning("[ReplyerManager] 缺少 stream_id,无法获取 replyer")
|
||||
return None
|
||||
|
||||
cache_key = f"{replyer_type}:{stream_id}"
|
||||
generator_type = get_maisaka_replyer_generator_type() if replyer_type == "maisaka" else ""
|
||||
cache_key = f"{replyer_type}:{generator_type}:{stream_id}"
|
||||
if cache_key in self._repliers:
|
||||
logger.info(f"[ReplyerManager] 命中缓存 replyer: cache_key={cache_key}")
|
||||
return self._repliers[cache_key]
|
||||
@@ -47,10 +51,10 @@ class ReplyerManager:
|
||||
|
||||
try:
|
||||
if replyer_type == "maisaka":
|
||||
logger.info("[ReplyerManager] importing MaisakaReplyGenerator")
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
logger.info(f"[ReplyerManager] 选择 MaisakaReplyGenerator: generator_type={generator_type}")
|
||||
maisaka_replyer_class = get_maisaka_replyer_class()
|
||||
|
||||
replyer = MaisakaReplyGenerator(
|
||||
replyer = maisaka_replyer_class(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@ from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages, ModelUsage, OnlineTime, ToolRecord
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
|
||||
@@ -3,41 +3,22 @@ MaiSaka CLI and conversation loop.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from rich import box
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from src.know_u.knowledge import KnowledgeLearner, retrieve_relevant_knowledge
|
||||
from src.know_u.knowledge_store import get_knowledge_store
|
||||
from src.chat.heart_flow.heartflow_manager import heartflow_manager
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.mcp_module import MCPManager
|
||||
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
|
||||
|
||||
from src.maisaka.chat_loop_service import MaisakaChatLoopService
|
||||
from src.maisaka.context_messages import (
|
||||
AssistantMessage,
|
||||
LLMContextMessage,
|
||||
SessionBackedMessage,
|
||||
ToolResultMessage,
|
||||
)
|
||||
from src.maisaka.message_adapter import format_speaker_content
|
||||
from src.maisaka.tool_handlers import (
|
||||
ToolHandlerContext,
|
||||
handle_mcp_tool,
|
||||
handle_stop,
|
||||
handle_unknown_tool,
|
||||
handle_wait,
|
||||
)
|
||||
|
||||
from .maisaka_cli_sender import CLI_PLATFORM_NAME
|
||||
from .console import console
|
||||
from .input_reader import InputReader
|
||||
|
||||
@@ -45,41 +26,13 @@ from .input_reader import InputReader
|
||||
class BufferCLI:
|
||||
"""Maisaka 命令行交互入口。"""
|
||||
|
||||
_CLI_PLATFORM = CLI_PLATFORM_NAME
|
||||
_CLI_USER_ID = "maisaka_user"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._chat_loop_service: Optional[MaisakaChatLoopService] = None
|
||||
self._reply_generator = MaisakaReplyGenerator()
|
||||
self._reader = InputReader()
|
||||
self._chat_history: Optional[list[LLMContextMessage]] = None
|
||||
self._knowledge_store = get_knowledge_store()
|
||||
self._knowledge_learner = KnowledgeLearner("maisaka_cli")
|
||||
self._knowledge_min_messages_for_extraction = 10
|
||||
self._knowledge_min_extraction_interval = 30
|
||||
self._last_knowledge_extraction_time = 0.0
|
||||
|
||||
knowledge_stats = self._knowledge_store.get_stats()
|
||||
if knowledge_stats["total_items"] > 0:
|
||||
console.print(f"[success]知识库中已有 {knowledge_stats['total_items']} 条数据[/success]")
|
||||
else:
|
||||
console.print("[muted]知识库已初始化,当前没有数据[/muted]")
|
||||
|
||||
self._chat_start_time: Optional[datetime] = None
|
||||
self._last_user_input_time: Optional[datetime] = None
|
||||
self._last_assistant_response_time: Optional[datetime] = None
|
||||
self._user_input_times: list[datetime] = []
|
||||
self._mcp_manager: Optional[MCPManager] = None
|
||||
self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None
|
||||
self._init_llm()
|
||||
|
||||
def _init_llm(self) -> None:
|
||||
"""初始化 Maisaka 使用的聊天服务。"""
|
||||
thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower()
|
||||
enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None
|
||||
|
||||
_ = enable_thinking
|
||||
self._chat_loop_service = MaisakaChatLoopService()
|
||||
|
||||
model_name = self._get_current_model_name()
|
||||
console.print(f"[success]大模型服务已初始化[/success] [muted](模型: {model_name})[/muted]")
|
||||
self._message_receiver = HeartFCMessageReceiver()
|
||||
self._session: BotChatSession | None = None
|
||||
|
||||
@staticmethod
|
||||
def _get_current_model_name() -> str:
|
||||
@@ -92,354 +45,59 @@ class BufferCLI:
|
||||
pass
|
||||
return "未配置"
|
||||
|
||||
def _build_tool_context(self) -> ToolHandlerContext:
|
||||
"""构建工具处理的共享上下文。"""
|
||||
tool_context = ToolHandlerContext(
|
||||
reader=self._reader,
|
||||
user_input_times=self._user_input_times,
|
||||
)
|
||||
tool_context.last_user_input_time = self._last_user_input_time
|
||||
return tool_context
|
||||
|
||||
def _show_banner(self) -> None:
|
||||
"""渲染启动横幅。"""
|
||||
banner = Text()
|
||||
banner.append("MaiSaka", style="bold cyan")
|
||||
banner.append(" v2.0\n", style="muted")
|
||||
banner.append(f"模型: {self._get_current_model_name()}\n", style="muted")
|
||||
banner.append("输入内容开始对话 | Ctrl+C 退出", style="muted")
|
||||
|
||||
console.print(Panel(banner, box=box.DOUBLE_EDGE, border_style="cyan", padding=(1, 2)))
|
||||
console.print()
|
||||
|
||||
async def _start_chat(self, user_text: str) -> None:
|
||||
"""追加用户输入并继续内部循环。"""
|
||||
if self._chat_loop_service is None:
|
||||
console.print("[warning]大模型服务尚未初始化,已跳过本次对话。[/warning]")
|
||||
return
|
||||
|
||||
now = datetime.now()
|
||||
self._last_user_input_time = now
|
||||
self._user_input_times.append(now)
|
||||
|
||||
if self._chat_history is None:
|
||||
self._chat_start_time = now
|
||||
self._last_assistant_response_time = None
|
||||
self._chat_history = self._chat_loop_service.build_chat_context(user_text)
|
||||
self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)])
|
||||
else:
|
||||
self._chat_history.append(
|
||||
self._build_cli_context_message(
|
||||
user_text=user_text,
|
||||
timestamp=now,
|
||||
source_kind="user",
|
||||
)
|
||||
)
|
||||
self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)])
|
||||
|
||||
await self._run_llm_loop(self._chat_history)
|
||||
|
||||
@staticmethod
|
||||
def _build_cli_context_message(
|
||||
def _build_cli_session_message(
|
||||
*,
|
||||
user_text: str,
|
||||
timestamp: datetime,
|
||||
source_kind: str = "user",
|
||||
speaker_name: Optional[str] = None,
|
||||
) -> SessionBackedMessage:
|
||||
"""为 CLI 构造新的上下文消息。"""
|
||||
resolved_speaker_name = speaker_name or global_config.maisaka.user_name.strip() or "用户"
|
||||
visible_text = format_speaker_content(
|
||||
resolved_speaker_name,
|
||||
user_text,
|
||||
timestamp,
|
||||
)
|
||||
planner_prefix = (
|
||||
f"[时间]{timestamp.strftime('%H:%M:%S')}\n"
|
||||
f"[用户]{resolved_speaker_name}\n"
|
||||
"[用户群昵称]\n"
|
||||
"[msg_id]\n"
|
||||
"[发言内容]"
|
||||
)
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
|
||||
return SessionBackedMessage(
|
||||
raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]),
|
||||
visible_text=visible_text,
|
||||
) -> SessionMessage:
|
||||
"""构造一条供 heartflow 复用的 CLI 用户消息。"""
|
||||
message = SessionMessage(
|
||||
message_id=f"maisaka_cli_{int(timestamp.timestamp() * 1000)}",
|
||||
timestamp=timestamp,
|
||||
source_kind=source_kind,
|
||||
platform=BufferCLI._CLI_PLATFORM,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_cli_session_message(user_text: str, timestamp: datetime) -> SessionMessage:
|
||||
"""为 CLI 的知识学习构造兼容 SessionMessage。"""
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
|
||||
message = SessionMessage(message_id=f"maisaka_cli_{int(timestamp.timestamp() * 1000)}", timestamp=timestamp, platform="maisaka")
|
||||
user_name = global_config.maisaka.user_name.strip() or "用户"
|
||||
message.message_info = MessageInfo(
|
||||
user_info=UserInfo(
|
||||
user_id="maisaka_user",
|
||||
user_nickname=global_config.maisaka.user_name.strip() or "用户",
|
||||
user_id=BufferCLI._CLI_USER_ID,
|
||||
user_nickname=user_name,
|
||||
user_cardname=None,
|
||||
),
|
||||
group_info=None,
|
||||
additional_config={},
|
||||
)
|
||||
message.session_id = "maisaka_cli"
|
||||
message.raw_message = MessageSequence([])
|
||||
visible_text = format_speaker_content(
|
||||
global_config.maisaka.user_name.strip() or "用户",
|
||||
user_text,
|
||||
timestamp,
|
||||
)
|
||||
message.raw_message.text(visible_text)
|
||||
message.processed_plain_text = visible_text
|
||||
message.display_message = visible_text
|
||||
message.raw_message = MessageSequence([TextComponent(text=user_text)])
|
||||
message.processed_plain_text = user_text
|
||||
message.display_message = user_text
|
||||
message.initialized = True
|
||||
return message
|
||||
|
||||
def _trigger_knowledge_learning(self, messages: list[SessionMessage]) -> None:
|
||||
"""在 CLI 会话中按批次触发 knowledge 学习。"""
|
||||
if not global_config.maisaka.enable_knowledge_module:
|
||||
return
|
||||
|
||||
self._knowledge_learner.add_messages(messages)
|
||||
|
||||
elapsed = time.monotonic() - self._last_knowledge_extraction_time
|
||||
if elapsed < self._knowledge_min_extraction_interval:
|
||||
return
|
||||
|
||||
cache_size = self._knowledge_learner.get_cache_size()
|
||||
if cache_size < self._knowledge_min_messages_for_extraction:
|
||||
return
|
||||
|
||||
self._last_knowledge_extraction_time = time.monotonic()
|
||||
asyncio.create_task(self._run_knowledge_learning())
|
||||
|
||||
async def _run_knowledge_learning(self) -> None:
|
||||
"""后台执行 knowledge 学习,避免阻塞主对话。"""
|
||||
try:
|
||||
added_count = await self._knowledge_learner.learn()
|
||||
if added_count > 0 and global_config.maisaka.show_thinking:
|
||||
console.print(f"[muted]知识学习已完成,新增 {added_count} 条数据。[/muted]")
|
||||
except Exception as exc:
|
||||
console.print(f"[warning]知识学习失败:{exc}[/warning]")
|
||||
|
||||
async def _run_llm_loop(self, chat_history: list[LLMContextMessage]) -> None:
|
||||
"""
|
||||
Main inner loop for the Maisaka planner.
|
||||
|
||||
Each round may produce internal thoughts and optionally call tools:
|
||||
- reply(msg_id): generate a visible reply for the current round
|
||||
- no_reply(): skip visible output and continue the loop
|
||||
- wait(seconds): wait for new user input
|
||||
- stop(): stop the current inner loop and return to idle
|
||||
"""
|
||||
if self._chat_loop_service is None:
|
||||
return
|
||||
|
||||
consecutive_errors = 0
|
||||
last_had_tool_calls = True
|
||||
|
||||
while True:
|
||||
if last_had_tool_calls:
|
||||
tasks = []
|
||||
status_text_parts = []
|
||||
|
||||
if global_config.maisaka.enable_knowledge_module:
|
||||
tasks.append(("knowledge", retrieve_relevant_knowledge(self._chat_loop_service, chat_history)))
|
||||
status_text_parts.append("知识库")
|
||||
|
||||
with console.status(
|
||||
f"[info]{' + '.join(status_text_parts)} 分析中...[/info]",
|
||||
spinner="dots",
|
||||
):
|
||||
results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True)
|
||||
|
||||
knowledge_analysis = ""
|
||||
if global_config.maisaka.enable_knowledge_module:
|
||||
knowledge_result = results[0] if results else None
|
||||
if isinstance(knowledge_result, Exception):
|
||||
console.print(f"[warning]知识分析失败:{knowledge_result}[/warning]")
|
||||
elif isinstance(knowledge_result, str) and knowledge_result.strip():
|
||||
knowledge_analysis = knowledge_result
|
||||
if global_config.maisaka.show_thinking:
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(knowledge_analysis),
|
||||
title="知识",
|
||||
border_style="bright_magenta",
|
||||
padding=(0, 1),
|
||||
style="dim",
|
||||
)
|
||||
)
|
||||
|
||||
if chat_history and isinstance(chat_history[-1], AssistantMessage) and chat_history[-1].source == "perception":
|
||||
chat_history.pop()
|
||||
|
||||
perception_parts = []
|
||||
if knowledge_analysis:
|
||||
perception_parts.append(f"知识库\n{knowledge_analysis}")
|
||||
|
||||
if perception_parts:
|
||||
chat_history.append(
|
||||
AssistantMessage(
|
||||
content="\n\n".join(perception_parts),
|
||||
timestamp=datetime.now(),
|
||||
source_kind="perception",
|
||||
)
|
||||
)
|
||||
elif global_config.maisaka.show_thinking:
|
||||
console.print("[muted]上一轮没有使用工具,本轮跳过模块分析。[/muted]")
|
||||
|
||||
with console.status("[info]正在思考...[/info]", spinner="dots"):
|
||||
try:
|
||||
response = await self._chat_loop_service.chat_loop_step(chat_history)
|
||||
consecutive_errors = 0
|
||||
except Exception as exc:
|
||||
consecutive_errors += 1
|
||||
console.print(f"[error]大模型调用失败:{exc}[/error]")
|
||||
if consecutive_errors >= 3:
|
||||
console.print("[error]连续失败次数过多,结束对话。[/error]\n")
|
||||
break
|
||||
continue
|
||||
|
||||
chat_history.append(response.raw_message)
|
||||
self._last_assistant_response_time = datetime.now()
|
||||
|
||||
if global_config.maisaka.show_thinking and response.content:
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(response.content),
|
||||
title="思考",
|
||||
border_style="dim",
|
||||
padding=(1, 2),
|
||||
style="dim",
|
||||
)
|
||||
)
|
||||
|
||||
if response.content and not response.tool_calls:
|
||||
last_had_tool_calls = False
|
||||
continue
|
||||
|
||||
if not response.tool_calls:
|
||||
last_had_tool_calls = False
|
||||
continue
|
||||
|
||||
should_stop = False
|
||||
tool_context = self._build_tool_context()
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
if tool_call.func_name == "stop":
|
||||
await handle_stop(tool_call, chat_history)
|
||||
should_stop = True
|
||||
|
||||
elif tool_call.func_name == "reply":
|
||||
reply = await self._generate_visible_reply(chat_history, response.content or "")
|
||||
chat_history.append(
|
||||
ToolResultMessage(
|
||||
content="已生成并记录可见回复。",
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call.call_id,
|
||||
tool_name=tool_call.func_name,
|
||||
)
|
||||
)
|
||||
chat_history.append(
|
||||
self._build_cli_context_message(
|
||||
user_text=reply,
|
||||
timestamp=datetime.now(),
|
||||
source_kind="guided_reply",
|
||||
speaker_name=global_config.bot.nickname.strip() or "MaiSaka",
|
||||
)
|
||||
)
|
||||
|
||||
elif tool_call.func_name == "no_reply":
|
||||
if global_config.maisaka.show_thinking:
|
||||
console.print("[muted]本轮未发送可见回复。[/muted]")
|
||||
chat_history.append(
|
||||
ToolResultMessage(
|
||||
content="本轮未发送可见回复。",
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call.call_id,
|
||||
tool_name=tool_call.func_name,
|
||||
)
|
||||
)
|
||||
|
||||
elif tool_call.func_name == "wait":
|
||||
tool_result = await handle_wait(tool_call, chat_history, tool_context)
|
||||
if tool_context.last_user_input_time != self._last_user_input_time:
|
||||
self._last_user_input_time = tool_context.last_user_input_time
|
||||
if tool_result.startswith("[[QUIT]]"):
|
||||
should_stop = True
|
||||
|
||||
elif self._mcp_manager and self._mcp_manager.is_mcp_tool(tool_call.func_name):
|
||||
await handle_mcp_tool(tool_call, chat_history, self._mcp_manager)
|
||||
|
||||
else:
|
||||
await handle_unknown_tool(tool_call, chat_history)
|
||||
|
||||
if should_stop:
|
||||
console.print("[muted]对话已暂停,等待新的输入...[/muted]\n")
|
||||
break
|
||||
|
||||
last_had_tool_calls = True
|
||||
|
||||
async def _init_mcp(self) -> None:
|
||||
"""初始化 MCP 服务并注册暴露的工具。"""
|
||||
self._mcp_host_bridge = MCPHostLLMBridge(
|
||||
sampling_task_name=global_config.mcp.client.sampling.task_name,
|
||||
async def _dispatch_input(self, user_text: str) -> None:
|
||||
"""将 CLI 输入转发到 heartflow 路径。"""
|
||||
message = self._build_cli_session_message(
|
||||
user_text=user_text,
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
self._mcp_manager = await MCPManager.from_app_config(
|
||||
global_config.mcp,
|
||||
host_callbacks=self._mcp_host_bridge.build_callbacks(),
|
||||
chat_manager.register_message(message)
|
||||
self._session = await chat_manager.get_or_create_session(
|
||||
platform=self._CLI_PLATFORM,
|
||||
user_id=self._CLI_USER_ID,
|
||||
)
|
||||
|
||||
if self._mcp_manager and self._chat_loop_service:
|
||||
mcp_tools = self._mcp_manager.get_openai_tools()
|
||||
if mcp_tools:
|
||||
self._chat_loop_service.set_extra_tools(mcp_tools)
|
||||
summary = self._mcp_manager.get_feature_summary()
|
||||
console.print(
|
||||
Panel(
|
||||
f"已加载 {len(mcp_tools)} 个 MCP 工具。\n{summary}",
|
||||
title="MCP 能力",
|
||||
border_style="green",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
async def _generate_visible_reply(self, chat_history: list[LLMContextMessage], latest_thought: str) -> str:
|
||||
"""根据最新思考生成并输出可见回复。"""
|
||||
if not latest_thought:
|
||||
return ""
|
||||
|
||||
with console.status("[info]正在生成可见回复...[/info]", spinner="dots"):
|
||||
success, result = await self._reply_generator.generate_reply_with_context(
|
||||
reply_reason=latest_thought,
|
||||
chat_history=chat_history,
|
||||
)
|
||||
if success and result.text_fragments:
|
||||
reply = result.text_fragments[0]
|
||||
else:
|
||||
reply = "..."
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(reply),
|
||||
title="MaiSaka",
|
||||
border_style="magenta",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
return reply
|
||||
await self._message_receiver.process_message(message)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""主交互循环。"""
|
||||
if global_config.mcp.enable:
|
||||
await self._init_mcp()
|
||||
else:
|
||||
console.print("[muted]MCP 已禁用(mcp.enable=false)[/muted]")
|
||||
|
||||
self._reader.start(asyncio.get_event_loop())
|
||||
self._show_banner()
|
||||
|
||||
@@ -447,17 +105,17 @@ class BufferCLI:
|
||||
while True:
|
||||
console.print("[bold cyan]> [/bold cyan]", end="")
|
||||
raw_input = await self._reader.get_line()
|
||||
|
||||
if raw_input is None:
|
||||
console.print("\n[muted]再见![/muted]")
|
||||
console.print("\n[muted]再见[/muted]")
|
||||
break
|
||||
|
||||
raw_input = raw_input.strip()
|
||||
if not raw_input:
|
||||
user_text = raw_input.strip()
|
||||
if not user_text:
|
||||
continue
|
||||
|
||||
await self._start_chat(raw_input)
|
||||
await self._dispatch_input(user_text)
|
||||
finally:
|
||||
if self._mcp_manager:
|
||||
await self._mcp_manager.close()
|
||||
self._mcp_host_bridge = None
|
||||
if self._session is not None:
|
||||
runtime = heartflow_manager.heartflow_chat_list.pop(self._session.session_id, None)
|
||||
if runtime is not None:
|
||||
await runtime.stop()
|
||||
|
||||
27
src/cli/maisaka_cli_sender.py
Normal file
27
src/cli/maisaka_cli_sender.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Maisaka CLI 展示适配。"""
|
||||
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
from .console import console
|
||||
|
||||
CLI_PLATFORM_NAME = "maisaka_cli"
|
||||
|
||||
logger = get_logger("maisaka_cli_sender")
|
||||
|
||||
|
||||
def render_cli_message(content: str, *, title: str = "") -> None:
|
||||
"""将 CLI 私聊实例的消息展示到终端。"""
|
||||
preview_text = content.strip() or "..."
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(preview_text),
|
||||
title=title or global_config.bot.nickname.strip() or "MaiSaka",
|
||||
border_style="magenta",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
logger.info(f"[CLI] 已将消息输出到终端: content={preview_text!r}")
|
||||
@@ -21,7 +21,6 @@ from .official_configs import (
|
||||
DatabaseConfig,
|
||||
DebugConfig,
|
||||
EmojiConfig,
|
||||
ExperimentalConfig,
|
||||
ExpressionConfig,
|
||||
KeywordReactionConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
@@ -56,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.2.0"
|
||||
CONFIG_VERSION: str = "8.3.0"
|
||||
MODEL_CONFIG_VERSION: str = "1.13.1"
|
||||
|
||||
logger = get_logger("config")
|
||||
@@ -113,13 +112,10 @@ class Config(ConfigBase):
|
||||
debug: DebugConfig = Field(default_factory=DebugConfig)
|
||||
"""调试配置类"""
|
||||
|
||||
experimental: ExperimentalConfig = Field(default_factory=ExperimentalConfig)
|
||||
"""实验性功能配置类"""
|
||||
|
||||
maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig)
|
||||
"""maim_message配置类"""
|
||||
|
||||
lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig)
|
||||
lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig, repr=False)
|
||||
"""LPMM知识库配置类"""
|
||||
|
||||
webui: WebUIConfig = Field(default_factory=WebUIConfig)
|
||||
|
||||
@@ -30,7 +30,14 @@ def recursive_parse_item_to_table(
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, ConfigBase):
|
||||
config_table.add(config_item_name, recursive_parse_item_to_table(value, override_repr=override_repr))
|
||||
config_table.add(
|
||||
config_item_name,
|
||||
recursive_parse_item_to_table(
|
||||
value,
|
||||
is_inline_table=is_inline_table,
|
||||
override_repr=override_repr,
|
||||
),
|
||||
)
|
||||
else:
|
||||
config_table.add(
|
||||
config_item_name, convert_field(config_item_name, config_item_info, value, override_repr=override_repr)
|
||||
|
||||
@@ -268,11 +268,23 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append("expression.manual_reflect_operator_id")
|
||||
|
||||
chat = _as_dict(data.get("chat"))
|
||||
if chat is None:
|
||||
chat = {}
|
||||
data["chat"] = chat
|
||||
|
||||
mem = _as_dict(data.get("memory"))
|
||||
if mem is not None:
|
||||
if _migrate_target_item_list(mem, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
reasons.append("memory.global_memory_blacklist")
|
||||
for removed_key in (
|
||||
"agent_timeout_seconds",
|
||||
"global_memory",
|
||||
"global_memory_blacklist",
|
||||
"max_agent_iterations",
|
||||
):
|
||||
if removed_key in mem:
|
||||
mem.pop(removed_key, None)
|
||||
migrated_any = True
|
||||
reasons.append(f"memory.{removed_key}_removed")
|
||||
|
||||
exp = _as_dict(data.get("experimental"))
|
||||
if exp is not None:
|
||||
@@ -280,7 +292,16 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append("experimental.chat_prompts")
|
||||
|
||||
chat = _as_dict(data.get("chat"))
|
||||
for key in ("private_plan_style", "group_chat_prompt", "private_chat_prompts", "chat_prompts"):
|
||||
if key in exp and key not in chat:
|
||||
chat[key] = exp[key]
|
||||
migrated_any = True
|
||||
reasons.append(f"experimental.{key}_moved_to_chat")
|
||||
|
||||
data.pop("experimental", None)
|
||||
migrated_any = True
|
||||
reasons.append("experimental_removed")
|
||||
|
||||
if chat is not None and "think_mode" in chat:
|
||||
chat.pop("think_mode", None)
|
||||
migrated_any = True
|
||||
|
||||
@@ -244,15 +244,45 @@ class ChatConfig(ConfigBase):
|
||||
},
|
||||
)
|
||||
"""每个聊天流最大保存的Plan/Reply日志数量,超过此数量时会自动删除最老的日志"""
|
||||
|
||||
llm_quote: bool = Field(
|
||||
default=False,
|
||||
private_plan_style: str = Field(
|
||||
default=(
|
||||
"1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用\n"
|
||||
"2.如果相同的内容已经被执行,请不要重复执行\n"
|
||||
"3.某句话如果已经被回复过,不要重复回复"
|
||||
),
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "quote",
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user",
|
||||
},
|
||||
)
|
||||
"""是否在 reply action 中启用 quote 参数,启用后 LLM 可以控制是否引用消息"""
|
||||
"""_wrap_私聊说话规则,行为风格"""
|
||||
|
||||
group_chat_prompt: str = Field(
|
||||
default="不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "users",
|
||||
},
|
||||
)
|
||||
"""_wrap_群聊通用注意事项"""
|
||||
|
||||
private_chat_prompts: str = Field(
|
||||
default="",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user",
|
||||
},
|
||||
)
|
||||
"""_wrap_私聊通用注意事项"""
|
||||
|
||||
chat_prompts: list["ExtraPromptItem"] = Field(
|
||||
default_factory=lambda: [],
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "list",
|
||||
},
|
||||
)
|
||||
"""_wrap_为指定聊天添加额外的 prompt 配置列表"""
|
||||
|
||||
enable_talk_value_rules: bool = Field(
|
||||
default=True,
|
||||
@@ -410,7 +440,6 @@ class MemoryConfig(ConfigBase):
|
||||
},
|
||||
)
|
||||
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
|
||||
|
||||
chat_history_topic_check_message_threshold: int = Field(
|
||||
default=80,
|
||||
ge=1,
|
||||
@@ -462,10 +491,6 @@ class MemoryConfig(ConfigBase):
|
||||
|
||||
def model_post_init(self, context: Optional[dict] = None) -> None:
|
||||
"""验证配置值"""
|
||||
if self.max_agent_iterations < 1:
|
||||
raise ValueError(f"max_agent_iterations 必须至少为1,当前值: {self.max_agent_iterations}")
|
||||
if self.agent_timeout_seconds <= 0:
|
||||
raise ValueError(f"agent_timeout_seconds 必须大于0,当前值: {self.agent_timeout_seconds}")
|
||||
if self.chat_history_topic_check_message_threshold < 1:
|
||||
raise ValueError(
|
||||
f"chat_history_topic_check_message_threshold 必须至少为1,当前值: {self.chat_history_topic_check_message_threshold}"
|
||||
@@ -1070,39 +1095,13 @@ class ExtraPromptItem(ConfigBase):
|
||||
"""额外的prompt内容"""
|
||||
|
||||
def model_post_init(self, context: Optional[dict] = None) -> None:
|
||||
if not self.platform and not self.item_id and not self.prompt:
|
||||
return super().model_post_init(context)
|
||||
if not self.platform or not self.item_id or not self.prompt:
|
||||
raise ValueError("ExtraPromptItem 中 platform, id 和 prompt 不能为空")
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
class ExperimentalConfig(ConfigBase):
|
||||
"""实验功能配置类"""
|
||||
|
||||
__ui_parent__ = "debug"
|
||||
|
||||
private_plan_style: str = Field(
|
||||
default=(
|
||||
"1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用"
|
||||
"2.如果相同的内容已经被执行,请不要重复执行"
|
||||
"3.某句话如果已经被回复过,不要重复回复"
|
||||
),
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user",
|
||||
},
|
||||
)
|
||||
"""_wrap_私聊说话规则,行为风格(实验性功能)"""
|
||||
|
||||
chat_prompts: list[ExtraPromptItem] = Field(
|
||||
default_factory=lambda: [],
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "list",
|
||||
},
|
||||
)
|
||||
"""_wrap_为指定聊天添加额外的prompt配置列表"""
|
||||
|
||||
|
||||
class MaimMessageConfig(ConfigBase):
|
||||
"""maim_message配置类"""
|
||||
|
||||
@@ -1473,7 +1472,6 @@ class MaiSakaConfig(ConfigBase):
|
||||
|
||||
__ui_label__ = "MaiSaka"
|
||||
__ui_icon__ = "message-circle"
|
||||
__ui_parent__ = "experimental"
|
||||
|
||||
enable_knowledge_module: bool = Field(
|
||||
default=True,
|
||||
@@ -1483,16 +1481,6 @@ class MaiSakaConfig(ConfigBase):
|
||||
},
|
||||
)
|
||||
"""启用知识库模块"""
|
||||
|
||||
show_analyze_cognition_prompt: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "terminal",
|
||||
},
|
||||
)
|
||||
"""是否在 CLI 中显示 analyze_cognition 的 Prompt"""
|
||||
|
||||
show_thinking: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
@@ -1529,6 +1517,15 @@ class MaiSakaConfig(ConfigBase):
|
||||
)
|
||||
"""是否将新接收的用户发言合并为单个用户消息"""
|
||||
|
||||
replyer_generator_type: Literal["legacy", "multi"] = Field(
|
||||
default="legacy",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""Maisaka replyer 生成器类型:legacy(旧版单 prompt)/ multi(多消息版)"""
|
||||
|
||||
max_internal_rounds: int = Field(
|
||||
default=6,
|
||||
ge=1,
|
||||
@@ -1568,24 +1565,14 @@ class MaiSakaConfig(ConfigBase):
|
||||
)
|
||||
"""工具筛选阶段最多保留的非内置工具数量"""
|
||||
|
||||
terminal_image_preview: bool = Field(
|
||||
default=False,
|
||||
terminal_image_display_mode: Literal["legacy", "path_link"] = Field(
|
||||
default="legacy",
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-widget": "select",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否渲染低分辨率终端预览图片"""
|
||||
|
||||
terminal_image_preview_width: int = Field(
|
||||
default=24,
|
||||
ge=8,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "columns",
|
||||
},
|
||||
)
|
||||
"""Maisaka终端图片预览的字符宽度"""
|
||||
"""图片展示模式:legacy(仅显示元信息)/ path_link(可点击本地路径)"""
|
||||
|
||||
|
||||
class MCPAuthorizationConfig(ConfigBase):
|
||||
@@ -1969,6 +1956,129 @@ class MCPConfig(ConfigBase):
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
class PluginRuntimeRenderConfig(ConfigBase):
|
||||
"""插件运行时浏览器渲染配置。"""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否启用插件运行时浏览器渲染能力"""
|
||||
|
||||
browser_ws_endpoint: str = Field(
|
||||
default="",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "link",
|
||||
},
|
||||
)
|
||||
"""优先复用的现有 Chromium CDP 地址,可填写 ws/http 端点"""
|
||||
|
||||
executable_path: str = Field(
|
||||
default="",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "folder",
|
||||
},
|
||||
)
|
||||
"""浏览器可执行文件路径,留空时自动探测本机 Chrome/Chromium"""
|
||||
|
||||
browser_install_root: str = Field(
|
||||
default="data/playwright-browsers",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "hard-drive",
|
||||
},
|
||||
)
|
||||
"""Playwright 托管浏览器目录,自动下载 Chromium 时会复用该目录"""
|
||||
|
||||
headless: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "monitor",
|
||||
},
|
||||
)
|
||||
"""是否以无头模式启动浏览器"""
|
||||
|
||||
launch_args: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"--disable-gpu",
|
||||
"--disable-dev-shm-usage",
|
||||
"--disable-setuid-sandbox",
|
||||
"--no-sandbox",
|
||||
"--no-zygote",
|
||||
],
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "terminal",
|
||||
},
|
||||
)
|
||||
"""浏览器启动参数列表"""
|
||||
|
||||
concurrency_limit: int = Field(
|
||||
default=2,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "number",
|
||||
"x-icon": "layers",
|
||||
},
|
||||
)
|
||||
"""同时允许进行的最大渲染任务数"""
|
||||
|
||||
startup_timeout_sec: float = Field(
|
||||
default=20.0,
|
||||
gt=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "number",
|
||||
"x-icon": "clock",
|
||||
},
|
||||
)
|
||||
"""浏览器连接或启动超时时间(秒)"""
|
||||
|
||||
render_timeout_sec: float = Field(
|
||||
default=15.0,
|
||||
gt=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "number",
|
||||
"x-icon": "timer",
|
||||
},
|
||||
)
|
||||
"""单次渲染默认超时时间(秒)"""
|
||||
|
||||
auto_download_chromium: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "download",
|
||||
},
|
||||
)
|
||||
"""未检测到可用浏览器时,是否自动下载 Playwright Chromium"""
|
||||
|
||||
download_connection_timeout_sec: float = Field(
|
||||
default=120.0,
|
||||
gt=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "number",
|
||||
"x-icon": "cloud-lightning",
|
||||
},
|
||||
)
|
||||
"""自动下载 Chromium 时的连接超时时间(秒)"""
|
||||
|
||||
restart_after_render_count: int = Field(
|
||||
default=200,
|
||||
ge=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "number",
|
||||
"x-icon": "refresh-cw",
|
||||
},
|
||||
)
|
||||
"""累计渲染指定次数后自动重建本地浏览器,0 表示关闭该策略"""
|
||||
|
||||
|
||||
class PluginRuntimeConfig(ConfigBase):
|
||||
"""插件运行时配置类"""
|
||||
|
||||
@@ -2031,3 +2141,6 @@ class PluginRuntimeConfig(ConfigBase):
|
||||
自定义 IPC Socket 路径(仅 Linux/macOS 生效)
|
||||
留空则自动生成临时路径
|
||||
"""
|
||||
|
||||
render: PluginRuntimeRenderConfig = Field(default_factory=PluginRuntimeRenderConfig)
|
||||
"""浏览器渲染能力配置"""
|
||||
|
||||
@@ -1,22 +1,25 @@
|
||||
from datetime import datetime
|
||||
from sqlmodel import select
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import difflib
|
||||
import json
|
||||
import re
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.data_models.expression_data_model import MaiExpression
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.data_models.expression_data_model import MaiExpression
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.config.config import global_config
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .expression_utils import check_expression_suitability, parse_expression_response
|
||||
|
||||
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
|
||||
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
|
||||
|
||||
|
||||
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册表达方式系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="expression.select.before_select",
|
||||
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
|
||||
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
|
||||
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
|
||||
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
|
||||
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
|
||||
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
|
||||
},
|
||||
required=["chat_id", "chat_info", "max_num", "think_level"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.select.after_selection",
|
||||
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
|
||||
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
|
||||
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
|
||||
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
|
||||
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
|
||||
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
|
||||
"selected_expressions": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "当前已选中的表达方式列表。",
|
||||
},
|
||||
"selected_expression_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "当前已选中的表达方式 ID 列表。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"chat_id",
|
||||
"chat_info",
|
||||
"max_num",
|
||||
"think_level",
|
||||
"selected_expressions",
|
||||
"selected_expression_ids",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.learn.after_extract",
|
||||
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
|
||||
"expressions": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "解析出的表达方式候选列表。",
|
||||
},
|
||||
"jargon_entries": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "解析出的黑话候选列表。",
|
||||
},
|
||||
},
|
||||
required=["session_id", "message_count", "expressions", "jargon_entries"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="expression.learn.before_upsert",
|
||||
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"situation": {"type": "string", "description": "即将写入的情景文本。"},
|
||||
"style": {"type": "string", "description": "即将写入的风格文本。"},
|
||||
},
|
||||
required=["session_id", "situation", "style"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, session_id: str) -> None:
|
||||
"""初始化表达方式学习器。
|
||||
|
||||
Args:
|
||||
session_id: 当前会话 ID。
|
||||
"""
|
||||
|
||||
self.session_id = session_id
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
@@ -44,6 +161,110 @@ class ExpressionLearner:
|
||||
# 消息缓存
|
||||
self._messages_cache: List["SessionMessage"] = []
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
|
||||
"""将表达方式候选序列化为 Hook 载荷。
|
||||
|
||||
Args:
|
||||
expressions: 原始表达方式候选列表。
|
||||
|
||||
Returns:
|
||||
List[dict[str, str]]: 序列化后的表达方式候选。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"situation": str(situation).strip(),
|
||||
"style": str(style).strip(),
|
||||
"source_id": str(source_id).strip(),
|
||||
}
|
||||
for situation, style, source_id in expressions
|
||||
if str(situation).strip() and str(style).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
|
||||
"""从 Hook 载荷恢复表达方式候选列表。
|
||||
|
||||
Args:
|
||||
raw_expressions: Hook 返回的表达方式候选。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_expressions, list):
|
||||
return []
|
||||
|
||||
normalized_expressions: List[Tuple[str, str, str]] = []
|
||||
for raw_expression in raw_expressions:
|
||||
if not isinstance(raw_expression, dict):
|
||||
continue
|
||||
situation = str(raw_expression.get("situation") or "").strip()
|
||||
style = str(raw_expression.get("style") or "").strip()
|
||||
source_id = str(raw_expression.get("source_id") or "").strip()
|
||||
if not situation or not style:
|
||||
continue
|
||||
normalized_expressions.append((situation, style, source_id))
|
||||
return normalized_expressions
|
||||
|
||||
@staticmethod
|
||||
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
|
||||
"""将黑话候选序列化为 Hook 载荷。
|
||||
|
||||
Args:
|
||||
jargon_entries: 原始黑话候选列表。
|
||||
|
||||
Returns:
|
||||
List[dict[str, str]]: 序列化后的黑话候选列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"content": str(content).strip(),
|
||||
"source_id": str(source_id).strip(),
|
||||
}
|
||||
for content, source_id in jargon_entries
|
||||
if str(content).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
|
||||
"""从 Hook 载荷恢复黑话候选列表。
|
||||
|
||||
Args:
|
||||
raw_jargon_entries: Hook 返回的黑话候选列表。
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 恢复后的黑话候选列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_jargon_entries, list):
|
||||
return []
|
||||
|
||||
normalized_entries: List[Tuple[str, str]] = []
|
||||
for raw_entry in raw_jargon_entries:
|
||||
if not isinstance(raw_entry, dict):
|
||||
continue
|
||||
content = str(raw_entry.get("content") or "").strip()
|
||||
source_id = str(raw_entry.get("source_id") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
normalized_entries.append((content, source_id))
|
||||
return normalized_entries
|
||||
|
||||
def add_messages(self, messages: List["SessionMessage"]) -> None:
|
||||
"""添加消息到缓存"""
|
||||
self._messages_cache.extend(messages)
|
||||
@@ -52,8 +273,12 @@ class ExpressionLearner:
|
||||
"""获取当前消息缓存的大小"""
|
||||
return len(self._messages_cache)
|
||||
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
|
||||
"""学习主流程"""
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
|
||||
"""执行表达方式学习主流程。
|
||||
|
||||
Args:
|
||||
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
|
||||
"""
|
||||
if not self._messages_cache:
|
||||
logger.debug("没有消息可供学习,跳过学习过程")
|
||||
return
|
||||
@@ -109,6 +334,25 @@ class ExpressionLearner:
|
||||
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
|
||||
jargon_entries = []
|
||||
|
||||
after_extract_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.after_extract",
|
||||
session_id=self.session_id,
|
||||
message_count=len(self._messages_cache),
|
||||
expressions=self._serialize_expressions(expressions),
|
||||
jargon_entries=self._serialize_jargon_entries(jargon_entries),
|
||||
)
|
||||
if after_extract_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
|
||||
return
|
||||
|
||||
after_extract_kwargs = after_extract_result.kwargs
|
||||
raw_expressions = after_extract_kwargs.get("expressions")
|
||||
if raw_expressions is not None:
|
||||
expressions = self._deserialize_expressions(raw_expressions)
|
||||
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
|
||||
if raw_jargon_entries is not None:
|
||||
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
|
||||
|
||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||
# TODO: 检测是否开启了
|
||||
if jargon_entries:
|
||||
@@ -135,6 +379,22 @@ class ExpressionLearner:
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for situation, style in learnt_expressions:
|
||||
before_upsert_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.before_upsert",
|
||||
session_id=self.session_id,
|
||||
situation=situation,
|
||||
style=style,
|
||||
)
|
||||
if before_upsert_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
|
||||
continue
|
||||
|
||||
upsert_kwargs = before_upsert_result.kwargs
|
||||
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
|
||||
style = str(upsert_kwargs.get("style", style) or "").strip()
|
||||
if not situation or not style:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
|
||||
continue
|
||||
await self._upsert_expression_to_db(situation, style)
|
||||
|
||||
# ====== 黑话相关 ======
|
||||
|
||||
@@ -1,27 +1,109 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""初始化表达方式选择器。"""
|
||||
|
||||
self.llm_model = LLMServiceClient(
|
||||
task_name="utils", request_type="expression.selector"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _normalize_selected_expressions(raw_expressions: Any) -> List[Dict[str, Any]]:
|
||||
"""从 Hook 载荷恢复表达方式选择结果。
|
||||
|
||||
Args:
|
||||
raw_expressions: Hook 返回的表达方式列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 恢复后的表达方式列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_expressions, list):
|
||||
return []
|
||||
|
||||
normalized_expressions: List[Dict[str, Any]] = []
|
||||
for raw_expression in raw_expressions:
|
||||
if not isinstance(raw_expression, dict):
|
||||
continue
|
||||
expression_id = raw_expression.get("id")
|
||||
situation = str(raw_expression.get("situation") or "").strip()
|
||||
style = str(raw_expression.get("style") or "").strip()
|
||||
source_id = str(raw_expression.get("source_id") or "").strip()
|
||||
if not isinstance(expression_id, int) or not situation or not style or not source_id:
|
||||
continue
|
||||
normalized_expression = dict(raw_expression)
|
||||
normalized_expression["id"] = expression_id
|
||||
normalized_expression["situation"] = situation
|
||||
normalized_expression["style"] = style
|
||||
normalized_expression["source_id"] = source_id
|
||||
normalized_expressions.append(normalized_expression)
|
||||
return normalized_expressions
|
||||
|
||||
@staticmethod
|
||||
def _normalize_selected_expression_ids(raw_ids: Any, expressions: List[Dict[str, Any]]) -> List[int]:
|
||||
"""规范化最终选中的表达方式 ID 列表。
|
||||
|
||||
Args:
|
||||
raw_ids: Hook 返回的 ID 列表。
|
||||
expressions: 当前最终表达方式列表。
|
||||
|
||||
Returns:
|
||||
List[int]: 规范化后的 ID 列表。
|
||||
"""
|
||||
|
||||
if isinstance(raw_ids, list):
|
||||
normalized_ids = [item for item in raw_ids if isinstance(item, int)]
|
||||
if normalized_ids:
|
||||
return normalized_ids
|
||||
return [expression["id"] for expression in expressions if isinstance(expression.get("id"), int)]
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
@@ -214,8 +296,7 @@ class ExpressionSelector:
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
选择适合的表达方式(使用classic模式:随机选择+LLM选择)
|
||||
"""选择适合的表达方式。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
@@ -233,11 +314,60 @@ class ExpressionSelector:
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
before_select_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.select.before_select",
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
target_message=target_message or "",
|
||||
reply_reason=reply_reason or "",
|
||||
think_level=think_level,
|
||||
)
|
||||
if before_select_result.aborted:
|
||||
logger.info(f"聊天流 {chat_id} 的表达方式选择被 Hook 中止")
|
||||
return [], []
|
||||
|
||||
before_select_kwargs = before_select_result.kwargs
|
||||
chat_id = str(before_select_kwargs.get("chat_id", chat_id) or "").strip() or chat_id
|
||||
chat_info = str(before_select_kwargs.get("chat_info", chat_info) or "")
|
||||
max_num = max(self._coerce_int(before_select_kwargs.get("max_num"), max_num), 1)
|
||||
raw_target_message = before_select_kwargs.get("target_message", target_message or "")
|
||||
target_message = str(raw_target_message or "").strip() or None
|
||||
raw_reply_reason = before_select_kwargs.get("reply_reason", reply_reason or "")
|
||||
reply_reason = str(raw_reply_reason or "").strip() or None
|
||||
think_level = self._coerce_int(before_select_kwargs.get("think_level"), think_level)
|
||||
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||
return await self._select_expressions_classic(
|
||||
selected_expressions, selected_ids = await self._select_expressions_classic(
|
||||
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||
)
|
||||
after_selection_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.select.after_selection",
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
target_message=target_message or "",
|
||||
reply_reason=reply_reason or "",
|
||||
think_level=think_level,
|
||||
selected_expressions=[dict(item) for item in selected_expressions],
|
||||
selected_expression_ids=list(selected_ids),
|
||||
)
|
||||
if after_selection_result.aborted:
|
||||
logger.info(f"聊天流 {chat_id} 的表达方式选择结果被 Hook 中止")
|
||||
return [], []
|
||||
|
||||
after_selection_kwargs = after_selection_result.kwargs
|
||||
raw_selected_expressions = after_selection_kwargs.get("selected_expressions")
|
||||
if raw_selected_expressions is not None:
|
||||
selected_expressions = self._normalize_selected_expressions(raw_selected_expressions)
|
||||
selected_ids = self._normalize_selected_expression_ids(
|
||||
after_selection_kwargs.get("selected_expression_ids"),
|
||||
selected_expressions,
|
||||
)
|
||||
if selected_expressions:
|
||||
self.update_expressions_last_active_time(selected_expressions)
|
||||
return selected_expressions, selected_ids
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Set, TypedDict
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, TypedDict
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
@@ -9,13 +9,15 @@ from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.data_models.jargon_data_model import MaiJargon
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .expression_utils import is_single_char_jargon
|
||||
|
||||
@@ -35,8 +37,140 @@ class JargonMeaningEntry(TypedDict):
|
||||
meaning: str
|
||||
|
||||
|
||||
def register_jargon_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册 jargon 系统内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="jargon.query.before_search",
|
||||
description="Maisaka 黑话查询工具执行检索前触发,可改写词条列表、检索参数或直接中止。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"words": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "准备查询的黑话词条列表。",
|
||||
},
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
|
||||
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
|
||||
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否允许精确命中失败后回退模糊检索。"},
|
||||
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
|
||||
},
|
||||
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="jargon.query.after_search",
|
||||
description="Maisaka 黑话查询工具完成检索后触发,可改写结果列表或中止返回。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"words": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "实际查询的黑话词条列表。",
|
||||
},
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
|
||||
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
|
||||
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否启用了模糊检索回退。"},
|
||||
"results": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "查询结果列表。",
|
||||
},
|
||||
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
|
||||
},
|
||||
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback", "results"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="jargon.extract.before_persist",
|
||||
description="黑话条目准备写入数据库前触发,可改写去重后的条目列表或跳过本次持久化。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"session_name": {"type": "string", "description": "当前会话展示名称。"},
|
||||
"entries": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
"description": "即将持久化的黑话条目列表。",
|
||||
},
|
||||
},
|
||||
required=["session_id", "session_name", "entries"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="jargon.inference.before_finalize",
|
||||
description="黑话含义推断完成、写回数据库前触发,可改写最终判定与含义结果。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"session_id": {"type": "string", "description": "当前会话 ID。"},
|
||||
"session_name": {"type": "string", "description": "当前会话展示名称。"},
|
||||
"content": {"type": "string", "description": "当前黑话词条。"},
|
||||
"count": {"type": "integer", "description": "当前词条累计命中次数。"},
|
||||
"raw_content_list": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "用于推断的原始上下文片段列表。",
|
||||
},
|
||||
"inference_with_context": {"type": "object", "description": "基于上下文的推断结果。"},
|
||||
"inference_with_content_only": {"type": "object", "description": "仅基于词条内容的推断结果。"},
|
||||
"comparison_result": {"type": "object", "description": "比较阶段输出结果。"},
|
||||
"is_jargon": {"type": "boolean", "description": "当前推断是否判定为黑话。"},
|
||||
"meaning": {"type": "string", "description": "当前推断出的黑话含义。"},
|
||||
"is_complete": {"type": "boolean", "description": "当前是否已完成全部推断流程。"},
|
||||
"last_inference_count": {"type": "integer", "description": "本次推断完成后应写回的 last_inference_count。"},
|
||||
},
|
||||
required=[
|
||||
"session_id",
|
||||
"session_name",
|
||||
"content",
|
||||
"count",
|
||||
"raw_content_list",
|
||||
"inference_with_context",
|
||||
"inference_with_content_only",
|
||||
"comparison_result",
|
||||
"is_jargon",
|
||||
"meaning",
|
||||
"is_complete",
|
||||
"last_inference_count",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class JargonMiner:
|
||||
def __init__(self, session_id: str, session_name: str) -> None:
|
||||
"""初始化黑话学习器。
|
||||
|
||||
Args:
|
||||
session_id: 当前会话 ID。
|
||||
session_name: 当前会话展示名称。
|
||||
"""
|
||||
|
||||
self.session_id = session_id
|
||||
self.session_name = session_name
|
||||
|
||||
@@ -46,13 +180,92 @@ class JargonMiner:
|
||||
# 黑话提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时使用的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _serialize_jargon_entries(entries: List[JargonEntry]) -> List[Dict[str, object]]:
|
||||
"""将黑话条目列表序列化为 Hook 可传输结构。
|
||||
|
||||
Args:
|
||||
entries: 原始黑话条目列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, object]]: 序列化后的条目列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"content": str(entry["content"]).strip(),
|
||||
"raw_content": sorted(str(item).strip() for item in entry["raw_content"] if str(item).strip()),
|
||||
}
|
||||
for entry in entries
|
||||
if str(entry["content"]).strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_jargon_entries(raw_entries: Any) -> List[JargonEntry]:
|
||||
"""从 Hook 载荷恢复黑话条目列表。
|
||||
|
||||
Args:
|
||||
raw_entries: Hook 返回的条目数据。
|
||||
|
||||
Returns:
|
||||
List[JargonEntry]: 恢复后的黑话条目列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_entries, list):
|
||||
return []
|
||||
|
||||
normalized_entries: List[JargonEntry] = []
|
||||
for raw_entry in raw_entries:
|
||||
if not isinstance(raw_entry, dict):
|
||||
continue
|
||||
content = str(raw_entry.get("content") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
raw_content_values = raw_entry.get("raw_content")
|
||||
raw_content: Set[str] = set()
|
||||
if isinstance(raw_content_values, list):
|
||||
raw_content = {str(item).strip() for item in raw_content_values if str(item).strip()}
|
||||
normalized_entries.append({"content": content, "raw_content": raw_content})
|
||||
return normalized_entries
|
||||
|
||||
def get_cached_jargons(self) -> List[str]:
|
||||
"""获取缓存中的所有黑话列表"""
|
||||
return list(self.cache.keys())
|
||||
|
||||
async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
|
||||
"""
|
||||
对jargon进行含义推断
|
||||
"""对黑话条目执行含义推断。
|
||||
|
||||
Args:
|
||||
jargon_obj: 待推断的黑话数据对象。
|
||||
"""
|
||||
content = jargon_obj.content
|
||||
# 解析raw_content列表
|
||||
@@ -175,15 +388,45 @@ class JargonMiner:
|
||||
is_similar = comparison_result.get("is_similar", False)
|
||||
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
|
||||
|
||||
finalized_meaning = inference1.get("meaning", "") if is_jargon else ""
|
||||
is_complete = (jargon_obj.count or 0) >= 100
|
||||
last_inference_count = jargon_obj.count or 0
|
||||
finalize_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.inference.before_finalize",
|
||||
session_id=self.session_id,
|
||||
session_name=self.session_name,
|
||||
content=content,
|
||||
count=current_count,
|
||||
raw_content_list=list(raw_content_list),
|
||||
inference_with_context=dict(inference1),
|
||||
inference_with_content_only=dict(inference2),
|
||||
comparison_result=dict(comparison_result),
|
||||
is_jargon=is_jargon,
|
||||
meaning=finalized_meaning,
|
||||
is_complete=is_complete,
|
||||
last_inference_count=last_inference_count,
|
||||
)
|
||||
if finalize_result.aborted:
|
||||
logger.info(f"jargon {content} 的推断结果被 Hook 中止写回")
|
||||
return
|
||||
|
||||
finalize_kwargs = finalize_result.kwargs
|
||||
is_jargon = bool(finalize_kwargs.get("is_jargon", is_jargon))
|
||||
finalized_meaning = str(finalize_kwargs.get("meaning", finalized_meaning) or "").strip() if is_jargon else ""
|
||||
is_complete = bool(finalize_kwargs.get("is_complete", is_complete))
|
||||
last_inference_count = self._coerce_int(
|
||||
finalize_kwargs.get("last_inference_count"),
|
||||
last_inference_count,
|
||||
)
|
||||
|
||||
# 更新数据库记录
|
||||
jargon_obj.is_jargon = is_jargon
|
||||
jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
|
||||
jargon_obj.meaning = finalized_meaning
|
||||
# 更新最后一次判定的count值,避免重启后重复判定
|
||||
jargon_obj.last_inference_count = jargon_obj.count or 0
|
||||
jargon_obj.last_inference_count = last_inference_count
|
||||
|
||||
# 如果count>=100,标记为完成,不再进行推断
|
||||
if (jargon_obj.count or 0) >= 100:
|
||||
jargon_obj.is_complete = True
|
||||
jargon_obj.is_complete = is_complete
|
||||
|
||||
try:
|
||||
self._modify_jargon_entry(jargon_obj)
|
||||
@@ -232,6 +475,22 @@ class JargonMiner:
|
||||
merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
|
||||
|
||||
uniq_entries: List[JargonEntry] = list(merged_entries.values())
|
||||
before_persist_result = await self._get_runtime_manager().invoke_hook(
|
||||
"jargon.extract.before_persist",
|
||||
session_id=self.session_id,
|
||||
session_name=self.session_name,
|
||||
entries=self._serialize_jargon_entries(uniq_entries),
|
||||
)
|
||||
if before_persist_result.aborted:
|
||||
logger.info(f"[{self.session_name}] 黑话提取结果被 Hook 中止,不写入数据库")
|
||||
return
|
||||
|
||||
raw_hook_entries = before_persist_result.kwargs.get("entries")
|
||||
if raw_hook_entries is not None:
|
||||
uniq_entries = self._deserialize_jargon_entries(raw_hook_entries)
|
||||
if not uniq_entries:
|
||||
logger.info(f"[{self.session_name}] Hook 过滤后没有可写入的黑话条目")
|
||||
return
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
||||
|
||||
from json_repair import repair_json
|
||||
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
|
||||
@@ -27,6 +27,7 @@ from openai.types.chat import (
|
||||
)
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
|
||||
@@ -62,6 +63,9 @@ from .base_client import (
|
||||
|
||||
logger = get_logger("llm_models")
|
||||
|
||||
SUPPORTED_OPENAI_IMAGE_FORMATS = {"jpeg", "png", "webp"}
|
||||
"""OpenAI 兼容图片输入稳定支持的格式集合。"""
|
||||
|
||||
THINK_CONTENT_PATTERN = re.compile(
|
||||
r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
|
||||
re.DOTALL,
|
||||
@@ -149,14 +153,85 @@ def _build_image_content_part(part: ImageMessagePart) -> ChatCompletionContentPa
|
||||
Returns:
|
||||
ChatCompletionContentPartImageParam: OpenAI 兼容的图片片段。
|
||||
"""
|
||||
normalized_image = _normalize_image_part_for_openai(part)
|
||||
if normalized_image is None:
|
||||
raise ValueError("图片数据无效,无法构建图片消息片段")
|
||||
|
||||
image_format, image_base64 = normalized_image
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/{part.normalized_image_format};base64,{part.image_base64}",
|
||||
"url": f"data:image/{image_format};base64,{image_base64}",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _normalize_image_part_for_openai(part: ImageMessagePart) -> Tuple[str, str] | None:
|
||||
"""将图片片段规范化为 OpenAI 兼容格式。
|
||||
|
||||
Args:
|
||||
part: 内部图片片段。
|
||||
|
||||
Returns:
|
||||
Tuple[str, str] | None: `(image_format, image_base64)`;无法解析时返回 `None`。
|
||||
"""
|
||||
try:
|
||||
image_bytes = base64.b64decode(part.image_base64, validate=True)
|
||||
except (binascii.Error, ValueError) as exc:
|
||||
logger.warning(f"图片 Base64 解码失败,已跳过该图片片段: {exc}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with PILImage.open(io.BytesIO(image_bytes)) as image:
|
||||
image_format = (image.format or part.normalized_image_format).lower()
|
||||
if image_format in {"jpg", "jpeg"}:
|
||||
image_format = "jpeg"
|
||||
|
||||
if image_format in SUPPORTED_OPENAI_IMAGE_FORMATS:
|
||||
return image_format, part.image_base64
|
||||
|
||||
if image_format == "gif":
|
||||
frame_count = getattr(image, "n_frames", 1)
|
||||
frames: List[PILImage.Image] = []
|
||||
durations: List[int] = []
|
||||
|
||||
for frame_index in range(frame_count):
|
||||
image.seek(frame_index)
|
||||
frame = image.copy()
|
||||
if frame.mode not in {"RGB", "RGBA"}:
|
||||
frame = frame.convert("RGBA")
|
||||
frames.append(frame)
|
||||
durations.append(int(image.info.get("duration", 100) or 100))
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
save_kwargs: Dict[str, Any] = {
|
||||
"format": "WEBP",
|
||||
"save_all": True,
|
||||
"append_images": frames[1:],
|
||||
"duration": durations,
|
||||
"loop": int(image.info.get("loop", 0) or 0),
|
||||
}
|
||||
if frame_count > 1:
|
||||
save_kwargs["lossless"] = True
|
||||
|
||||
frames[0].save(output_buffer, **save_kwargs)
|
||||
converted_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8")
|
||||
return "webp", converted_base64
|
||||
|
||||
image.seek(0)
|
||||
normalized_image = image.copy()
|
||||
if normalized_image.mode not in {"RGB", "RGBA"}:
|
||||
normalized_image = normalized_image.convert("RGBA")
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
normalized_image.save(output_buffer, format="PNG")
|
||||
converted_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8")
|
||||
return "png", converted_base64
|
||||
except Exception as exc:
|
||||
logger.warning(f"图片内容无法被识别为有效图片,已跳过该图片片段: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def _convert_response_format(response_format: RespFormat | None) -> Any:
|
||||
"""将内部响应格式转换为 OpenAI 兼容结构。
|
||||
|
||||
@@ -222,7 +297,21 @@ def _convert_user_message_content(message: Message) -> str | List[ChatCompletion
|
||||
if isinstance(part, TextMessagePart):
|
||||
content.append(_build_text_content_part(part.text))
|
||||
continue
|
||||
content.append(_build_image_content_part(part))
|
||||
|
||||
normalized_image = _normalize_image_part_for_openai(part)
|
||||
if normalized_image is None:
|
||||
content.append(_build_text_content_part("[图片内容不可用]"))
|
||||
continue
|
||||
|
||||
image_format, image_base64 = normalized_image
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/{image_format};base64,{image_base64}",
|
||||
},
|
||||
}
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
@@ -314,13 +403,15 @@ def _convert_tool_options(tool_options: List[ToolOption]) -> List[ChatCompletion
|
||||
"""
|
||||
converted_tools: List[ChatCompletionToolParam] = []
|
||||
for tool_option in tool_options:
|
||||
parameters_schema = cast(
|
||||
Dict[str, object],
|
||||
tool_option.parameters_schema or {"type": "object", "properties": {}},
|
||||
)
|
||||
function_schema: FunctionDefinition = {
|
||||
"name": tool_option.name,
|
||||
"description": tool_option.description,
|
||||
"parameters": parameters_schema,
|
||||
}
|
||||
parameters_schema = tool_option.parameters_schema
|
||||
if parameters_schema is not None:
|
||||
function_schema["parameters"] = cast(Dict[str, object], parameters_schema)
|
||||
converted_tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
|
||||
@@ -88,6 +88,15 @@ def _build_parameters_schema_from_property_map(property_map: Dict[str, Any]) ->
|
||||
return parameters_schema
|
||||
|
||||
|
||||
def _build_empty_object_schema() -> Dict[str, Any]:
|
||||
"""构建无参工具使用的空对象 Schema。"""
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolParam:
|
||||
"""工具参数定义。"""
|
||||
@@ -333,9 +342,8 @@ class ToolOption:
|
||||
function_schema: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters_schema or _build_empty_object_schema(),
|
||||
}
|
||||
if self.parameters_schema is not None:
|
||||
function_schema["parameters"] = self.parameters_schema
|
||||
return {
|
||||
"type": "function",
|
||||
"function": function_schema,
|
||||
|
||||
@@ -843,12 +843,6 @@ class LLMOrchestrator:
|
||||
|
||||
for _ in range(max_attempts):
|
||||
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
|
||||
if self.request_type.startswith("maisaka_"):
|
||||
logger.info(
|
||||
f"LLMOrchestrator[{self.request_type}] 已选择模型 model={model_info.name} "
|
||||
f"provider={api_provider.name} request_type={request_type.value}"
|
||||
)
|
||||
|
||||
message_list = []
|
||||
if message_factory:
|
||||
message_list = message_factory(client)
|
||||
|
||||
71
src/maisaka/builtin_tool/__init__.py
Normal file
71
src/maisaka/builtin_tool/__init__.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Maisaka 内置工具聚合入口。"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
from .no_reply import get_tool_spec as get_no_reply_tool_spec
|
||||
from .no_reply import handle_tool as handle_no_reply_tool
|
||||
from .query_jargon import get_tool_spec as get_query_jargon_tool_spec
|
||||
from .query_jargon import handle_tool as handle_query_jargon_tool
|
||||
from .query_person_info import get_tool_spec as get_query_person_info_tool_spec
|
||||
from .query_person_info import handle_tool as handle_query_person_info_tool
|
||||
from .reply import get_tool_spec as get_reply_tool_spec
|
||||
from .reply import handle_tool as handle_reply_tool
|
||||
from .send_emoji import get_tool_spec as get_send_emoji_tool_spec
|
||||
from .send_emoji import handle_tool as handle_send_emoji_tool
|
||||
from .wait import get_tool_spec as get_wait_tool_spec
|
||||
from .wait import handle_tool as handle_wait_tool
|
||||
|
||||
BuiltinToolHandler = Callable[[ToolInvocation, Optional[ToolExecutionContext]], Awaitable[ToolExecutionResult]]
|
||||
|
||||
|
||||
def get_builtin_tool_specs() -> List[ToolSpec]:
|
||||
"""获取默认启用的内置工具声明列表。"""
|
||||
|
||||
return [
|
||||
get_wait_tool_spec(),
|
||||
get_reply_tool_spec(),
|
||||
get_query_jargon_tool_spec(),
|
||||
get_no_reply_tool_spec(),
|
||||
get_send_emoji_tool_spec(),
|
||||
]
|
||||
|
||||
|
||||
def get_all_builtin_tool_specs() -> List[ToolSpec]:
|
||||
"""获取全部内置工具声明列表。"""
|
||||
|
||||
return [
|
||||
get_wait_tool_spec(),
|
||||
get_reply_tool_spec(),
|
||||
get_query_jargon_tool_spec(),
|
||||
get_query_person_info_tool_spec(),
|
||||
get_no_reply_tool_spec(),
|
||||
get_send_emoji_tool_spec(),
|
||||
]
|
||||
|
||||
|
||||
def get_builtin_tools() -> List[ToolDefinitionInput]:
|
||||
"""获取兼容旧模型层的内置工具定义。"""
|
||||
|
||||
return [tool_spec.to_llm_definition() for tool_spec in get_builtin_tool_specs()]
|
||||
|
||||
|
||||
def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str, BuiltinToolHandler]:
|
||||
"""构建内置工具处理器映射。"""
|
||||
|
||||
return {
|
||||
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
|
||||
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
|
||||
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
|
||||
"query_person_info": lambda invocation, context=None: handle_query_person_info_tool(
|
||||
tool_ctx,
|
||||
invocation,
|
||||
context,
|
||||
),
|
||||
"wait": lambda invocation, context=None: handle_wait_tool(tool_ctx, invocation, context),
|
||||
"send_emoji": lambda invocation, context=None: handle_send_emoji_tool(tool_ctx, invocation, context),
|
||||
}
|
||||
185
src/maisaka/builtin_tool/context.py
Normal file
185
src/maisaka/builtin_tool/context.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Maisaka 内置工具执行上下文。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from base64 import b64decode
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.common.data_models.message_component_data_model import EmojiComponent, MessageSequence, TextComponent
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolExecutionResult
|
||||
|
||||
from ..context_messages import SessionBackedMessage
|
||||
from ..message_adapter import format_speaker_content
|
||||
from ..planner_message_utils import build_planner_prefix, build_session_backed_text_message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..reasoning_engine import MaisakaReasoningEngine
|
||||
from ..runtime import MaisakaHeartFlowChatting
|
||||
|
||||
|
||||
class BuiltinToolRuntimeContext:
|
||||
"""为拆分后的内置工具提供统一运行时能力。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: "MaisakaReasoningEngine",
|
||||
runtime: "MaisakaHeartFlowChatting",
|
||||
) -> None:
|
||||
self.engine = engine
|
||||
self.runtime = runtime
|
||||
|
||||
@staticmethod
|
||||
def build_success_result(
|
||||
tool_name: str,
|
||||
content: str = "",
|
||||
structured_content: Any = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""构造统一工具成功结果。"""
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=True,
|
||||
content=content,
|
||||
structured_content=structured_content,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_failure_result(
|
||||
tool_name: str,
|
||||
error_message: str,
|
||||
structured_content: Any = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""构造统一工具失败结果。"""
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_name=tool_name,
|
||||
success=False,
|
||||
error_message=error_message,
|
||||
structured_content=structured_content,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def normalize_words(raw_words: Any) -> List[str]:
|
||||
"""清洗黑话查询词条列表。"""
|
||||
|
||||
if not isinstance(raw_words, list):
|
||||
return []
|
||||
|
||||
normalized_words: List[str] = []
|
||||
seen_words: set[str] = set()
|
||||
for item in raw_words:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
word = item.strip()
|
||||
if not word or word in seen_words:
|
||||
continue
|
||||
seen_words.add(word)
|
||||
normalized_words.append(word)
|
||||
return normalized_words
|
||||
|
||||
@staticmethod
|
||||
def normalize_jargon_query_results(raw_results: Any) -> List[Dict[str, object]]:
|
||||
"""规范化黑话查询结果列表。"""
|
||||
|
||||
if not isinstance(raw_results, list):
|
||||
return []
|
||||
|
||||
normalized_results: List[Dict[str, object]] = []
|
||||
for raw_item in raw_results:
|
||||
if not isinstance(raw_item, dict):
|
||||
continue
|
||||
word = str(raw_item.get("word") or "").strip()
|
||||
matches = raw_item.get("matches")
|
||||
normalized_matches: List[Dict[str, str]] = []
|
||||
if isinstance(matches, list):
|
||||
for match in matches:
|
||||
if not isinstance(match, dict):
|
||||
continue
|
||||
content = str(match.get("content") or "").strip()
|
||||
meaning = str(match.get("meaning") or "").strip()
|
||||
if not content or not meaning:
|
||||
continue
|
||||
normalized_matches.append({"content": content, "meaning": meaning})
|
||||
|
||||
normalized_results.append(
|
||||
{
|
||||
"word": word,
|
||||
"found": bool(raw_item.get("found", bool(normalized_matches))),
|
||||
"matches": normalized_matches,
|
||||
}
|
||||
)
|
||||
return normalized_results
|
||||
|
||||
@staticmethod
|
||||
def post_process_reply_text(reply_text: str) -> List[str]:
|
||||
"""沿用旧回复链的文本后处理,执行分段与错别字注入。"""
|
||||
|
||||
processed_segments: List[str] = []
|
||||
for segment in process_llm_response(reply_text):
|
||||
normalized_segment = segment.strip()
|
||||
if normalized_segment:
|
||||
processed_segments.append(normalized_segment)
|
||||
|
||||
if processed_segments:
|
||||
return processed_segments
|
||||
return [reply_text.strip()]
|
||||
|
||||
def get_runtime_manager(self) -> Any:
|
||||
"""获取插件运行时管理器。"""
|
||||
|
||||
return self.engine._get_runtime_manager()
|
||||
|
||||
def append_guided_reply_to_chat_history(self, reply_text: str) -> None:
|
||||
"""将引导回复写回 Maisaka 历史。"""
|
||||
|
||||
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
|
||||
reply_timestamp = datetime.now()
|
||||
history_message = build_session_backed_text_message(
|
||||
speaker_name=bot_name,
|
||||
text=reply_text,
|
||||
timestamp=reply_timestamp,
|
||||
source_kind="guided_reply",
|
||||
)
|
||||
self.runtime._chat_history.append(history_message)
|
||||
|
||||
def append_sent_emoji_to_chat_history(
|
||||
self,
|
||||
*,
|
||||
emoji_base64: str,
|
||||
success_message: str,
|
||||
) -> None:
|
||||
"""将 bot 主动发送的表情包同步到 Maisaka 历史。"""
|
||||
|
||||
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
|
||||
reply_timestamp = datetime.now()
|
||||
planner_prefix = build_planner_prefix(
|
||||
timestamp=reply_timestamp,
|
||||
user_name=bot_name,
|
||||
)
|
||||
history_message = SessionBackedMessage(
|
||||
raw_message=MessageSequence(
|
||||
[
|
||||
TextComponent(planner_prefix),
|
||||
EmojiComponent(
|
||||
binary_hash="",
|
||||
content=success_message,
|
||||
binary_data=b64decode(emoji_base64),
|
||||
),
|
||||
]
|
||||
),
|
||||
visible_text=format_speaker_content(
|
||||
bot_name,
|
||||
"[表情包]",
|
||||
reply_timestamp,
|
||||
),
|
||||
timestamp=reply_timestamp,
|
||||
source_kind="guided_reply",
|
||||
)
|
||||
self.runtime._chat_history.append(history_message)
|
||||
34
src/maisaka/builtin_tool/no_reply.py
Normal file
34
src/maisaka/builtin_tool/no_reply.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""no_reply 内置工具。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 no_reply 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="no_reply",
|
||||
brief_description="本轮不进行回复,等待其他用户的新消息。",
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 no_reply 内置工具。"""
|
||||
|
||||
del context
|
||||
tool_ctx.runtime._enter_stop_state()
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"当前对话循环已暂停,等待新消息到来。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
143
src/maisaka/builtin_tool/query_jargon.py
Normal file
143
src/maisaka/builtin_tool/query_jargon.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""query_jargon 内置工具。"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.learners.jargon_explainer import search_jargon
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 query_jargon 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="query_jargon",
|
||||
brief_description="查询当前聊天上下文中的黑话或词条含义。",
|
||||
detailed_description="参数说明:\n- words:array,必填。要查询的词条列表。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"words": {
|
||||
"type": "array",
|
||||
"description": "要查询的词条列表。",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["words"],
|
||||
},
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 query_jargon 内置工具。"""
|
||||
|
||||
del context
|
||||
raw_words = invocation.arguments.get("words")
|
||||
|
||||
if not isinstance(raw_words, list):
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"查询黑话工具需要提供 `words` 数组参数。",
|
||||
)
|
||||
|
||||
words = tool_ctx.normalize_words(raw_words)
|
||||
if not words:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"查询黑话工具至少需要一个非空词条。",
|
||||
)
|
||||
|
||||
limit = 5
|
||||
case_sensitive = False
|
||||
enable_fuzzy_fallback = True
|
||||
before_search_result = await tool_ctx.get_runtime_manager().invoke_hook(
|
||||
"jargon.query.before_search",
|
||||
words=list(words),
|
||||
session_id=tool_ctx.runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
enable_fuzzy_fallback=enable_fuzzy_fallback,
|
||||
abort_message="黑话查询已被 Hook 中止。",
|
||||
)
|
||||
if before_search_result.aborted:
|
||||
abort_message = str(before_search_result.kwargs.get("abort_message") or "黑话查询已被 Hook 中止。").strip()
|
||||
return tool_ctx.build_failure_result(invocation.tool_name, abort_message or "黑话查询已被 Hook 中止。")
|
||||
|
||||
before_search_kwargs = before_search_result.kwargs
|
||||
if before_search_kwargs.get("words") is not None:
|
||||
words = tool_ctx.normalize_words(before_search_kwargs.get("words"))
|
||||
|
||||
if not words:
|
||||
return tool_ctx.build_failure_result(invocation.tool_name, "Hook 过滤后没有可查询的黑话词条。")
|
||||
|
||||
try:
|
||||
limit = int(before_search_kwargs.get("limit", limit))
|
||||
except (TypeError, ValueError):
|
||||
limit = 5
|
||||
limit = max(limit, 1)
|
||||
case_sensitive = bool(before_search_kwargs.get("case_sensitive", case_sensitive))
|
||||
enable_fuzzy_fallback = bool(before_search_kwargs.get("enable_fuzzy_fallback", enable_fuzzy_fallback))
|
||||
|
||||
results: List[Dict[str, object]] = []
|
||||
for word in words:
|
||||
exact_matches = search_jargon(
|
||||
keyword=word,
|
||||
chat_id=tool_ctx.runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
fuzzy=False,
|
||||
)
|
||||
matched_entries = exact_matches
|
||||
if not matched_entries and enable_fuzzy_fallback:
|
||||
matched_entries = search_jargon(
|
||||
keyword=word,
|
||||
chat_id=tool_ctx.runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
fuzzy=True,
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"word": word,
|
||||
"found": bool(matched_entries),
|
||||
"matches": matched_entries,
|
||||
}
|
||||
)
|
||||
|
||||
after_search_result = await tool_ctx.get_runtime_manager().invoke_hook(
|
||||
"jargon.query.after_search",
|
||||
words=list(words),
|
||||
session_id=tool_ctx.runtime.session_id,
|
||||
limit=limit,
|
||||
case_sensitive=case_sensitive,
|
||||
enable_fuzzy_fallback=enable_fuzzy_fallback,
|
||||
results=list(results),
|
||||
abort_message="黑话查询结果已被 Hook 中止。",
|
||||
)
|
||||
if after_search_result.aborted:
|
||||
abort_message = str(after_search_result.kwargs.get("abort_message") or "黑话查询结果已被 Hook 中止。").strip()
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
abort_message or "黑话查询结果已被 Hook 中止。",
|
||||
)
|
||||
|
||||
raw_results = after_search_result.kwargs.get("results")
|
||||
if raw_results is not None:
|
||||
results = tool_ctx.normalize_jargon_query_results(raw_results)
|
||||
|
||||
structured_content: Dict[str, Any] = {"results": results}
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
json.dumps(structured_content, ensure_ascii=False),
|
||||
structured_content=structured_content,
|
||||
)
|
||||
183
src/maisaka/builtin_tool/query_person_info.py
Normal file
183
src/maisaka/builtin_tool/query_person_info.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""query_person_info 内置工具。"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.know_u.knowledge_store import get_knowledge_store
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
|
||||
def get_tool_spec(*, enabled: bool = False) -> ToolSpec:
|
||||
"""获取 query_person_info 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="query_person_info",
|
||||
brief_description="查询某个人的档案和相关记忆信息。",
|
||||
detailed_description=(
|
||||
"参数说明:\n"
|
||||
"- person_name:string,必填。人物名称、昵称或用户 ID。\n"
|
||||
"- limit:integer,可选。最多返回多少条匹配记录,默认 3。"
|
||||
),
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person_name": {
|
||||
"type": "string",
|
||||
"description": "人物名称、昵称或用户 ID。",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "最多返回多少条匹配记录。",
|
||||
"default": 3,
|
||||
},
|
||||
},
|
||||
"required": ["person_name"],
|
||||
},
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
enabled=enabled,
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 query_person_info 内置工具。"""
|
||||
|
||||
del context
|
||||
raw_person_name = invocation.arguments.get("person_name")
|
||||
raw_limit = invocation.arguments.get("limit", 3)
|
||||
|
||||
if not isinstance(raw_person_name, str):
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"查询人物信息工具需要提供字符串类型的 `person_name` 参数。",
|
||||
)
|
||||
|
||||
person_name = raw_person_name.strip()
|
||||
if not person_name:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"查询人物信息工具需要提供非空的 `person_name` 参数。",
|
||||
)
|
||||
|
||||
try:
|
||||
limit = max(1, min(int(raw_limit), 10))
|
||||
except (TypeError, ValueError):
|
||||
limit = 3
|
||||
|
||||
persons = _query_person_records(person_name, limit)
|
||||
result: Dict[str, Any] = {
|
||||
"query": person_name,
|
||||
"persons": persons,
|
||||
"related_knowledge": _query_related_knowledge(person_name, persons, limit),
|
||||
}
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
json.dumps(result, ensure_ascii=False),
|
||||
structured_content=result,
|
||||
)
|
||||
|
||||
|
||||
def _query_person_records(person_name: str, limit: int) -> List[Dict[str, Any]]:
|
||||
"""按名称、昵称或用户 ID 查询人物档案。"""
|
||||
|
||||
with get_db_session() as session:
|
||||
records = session.exec(
|
||||
select(PersonInfo)
|
||||
.where(
|
||||
col(PersonInfo.person_name).contains(person_name)
|
||||
| col(PersonInfo.user_nickname).contains(person_name)
|
||||
| col(PersonInfo.user_id).contains(person_name)
|
||||
)
|
||||
.order_by(col(PersonInfo.last_known_time).desc(), col(PersonInfo.id).desc())
|
||||
.limit(limit)
|
||||
).all()
|
||||
persons: List[Dict[str, Any]] = []
|
||||
for record in records:
|
||||
memory_points: List[str] = []
|
||||
if record.memory_points:
|
||||
try:
|
||||
parsed_points = json.loads(record.memory_points)
|
||||
if isinstance(parsed_points, list):
|
||||
memory_points = [str(point).strip() for point in parsed_points if str(point).strip()]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
memory_points = []
|
||||
|
||||
persons.append(
|
||||
{
|
||||
"person_id": record.person_id,
|
||||
"person_name": record.person_name or "",
|
||||
"user_nickname": record.user_nickname,
|
||||
"user_id": record.user_id,
|
||||
"platform": record.platform,
|
||||
"name_reason": record.name_reason or "",
|
||||
"is_known": record.is_known,
|
||||
"know_counts": record.know_counts,
|
||||
"memory_points": memory_points[:20],
|
||||
"last_known_time": record.last_known_time.isoformat() if record.last_known_time is not None else None,
|
||||
}
|
||||
)
|
||||
|
||||
return persons
|
||||
|
||||
|
||||
def _query_related_knowledge(
|
||||
person_name: str,
|
||||
persons: List[Dict[str, Any]],
|
||||
limit: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""从 Maisaka knowledge 中补充检索与该人物相关的条目。"""
|
||||
|
||||
store = get_knowledge_store()
|
||||
knowledge_items: List[Dict[str, Any]] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for person in persons:
|
||||
matched_items = store.get_knowledge_by_user(
|
||||
platform=str(person.get("platform", "")).strip(),
|
||||
user_id=str(person.get("user_id", "")).strip(),
|
||||
user_nickname=str(person.get("user_nickname", "")).strip(),
|
||||
person_name=str(person.get("person_name", "")).strip(),
|
||||
limit=max(limit, 5),
|
||||
)
|
||||
for item in matched_items:
|
||||
item_id = str(item.get("id", "")).strip()
|
||||
if item_id and item_id in seen_ids:
|
||||
continue
|
||||
if item_id:
|
||||
seen_ids.add(item_id)
|
||||
knowledge_items.append(item)
|
||||
|
||||
if not knowledge_items:
|
||||
fallback_items = store.search_knowledge(person_name, limit=max(limit, 5))
|
||||
for item in fallback_items:
|
||||
item_id = str(item.get("id", "")).strip()
|
||||
if item_id and item_id in seen_ids:
|
||||
continue
|
||||
if item_id:
|
||||
seen_ids.add(item_id)
|
||||
knowledge_items.append(item)
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
for item in knowledge_items:
|
||||
results.append(
|
||||
{
|
||||
"id": str(item.get("id", "")).strip(),
|
||||
"category_id": str(item.get("category_id", "")).strip(),
|
||||
"category_name": str(item.get("category_name", "")).strip(),
|
||||
"content": str(item.get("content", "")).strip(),
|
||||
"metadata": item.get("metadata", {}),
|
||||
"created_at": item.get("created_at"),
|
||||
}
|
||||
)
|
||||
return results
|
||||
188
src/maisaka/builtin_tool/reply.py
Normal file
188
src/maisaka/builtin_tool/reply.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""reply 内置工具。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.services import send_service
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
logger = get_logger("maisaka_builtin_reply")
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 reply 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="reply",
|
||||
brief_description="根据当前思考生成并发送一条可见回复。",
|
||||
detailed_description=(
|
||||
"参数说明:\n"
|
||||
"- msg_id:string,必填。要回复的目标用户消息编号。\n"
|
||||
"- quote:boolean,可选。当有非常明确的回复目标时,以引用回复的方式发送,默认 true。\n"
|
||||
"- unknown_words:array,可选。回复前可能需要查询的黑话或词条列表。"
|
||||
),
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"msg_id": {
|
||||
"type": "string",
|
||||
"description": "要回复的目标用户消息编号。",
|
||||
},
|
||||
"quote": {
|
||||
"type": "boolean",
|
||||
"description": "当有非常明确的回复目标时,以引用回复的方式发送。",
|
||||
"default": True,
|
||||
},
|
||||
"unknown_words": {
|
||||
"type": "array",
|
||||
"description": "回复前可能需要查询的黑话或词条列表。",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["msg_id"],
|
||||
},
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 reply 内置工具。"""
|
||||
|
||||
latest_thought = context.reasoning if context is not None else invocation.reasoning
|
||||
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
|
||||
quote_reply = bool(invocation.arguments.get("quote", True))
|
||||
raw_unknown_words = invocation.arguments.get("unknown_words")
|
||||
unknown_words = raw_unknown_words if isinstance(raw_unknown_words, list) else None
|
||||
|
||||
if not target_message_id:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"回复工具需要提供有效的 `msg_id` 参数。",
|
||||
)
|
||||
|
||||
target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id)
|
||||
if target_message is None:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
f"未找到要回复的目标消息,msg_id={target_message_id}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{tool_ctx.runtime.log_prefix} 已触发回复工具 "
|
||||
f"目标消息编号={target_message_id} 引用回复={quote_reply} 最新思考={latest_thought!r}"
|
||||
)
|
||||
try:
|
||||
replyer = replyer_manager.get_replyer(
|
||||
chat_stream=tool_ctx.runtime.chat_stream,
|
||||
request_type="maisaka_replyer",
|
||||
replyer_type="maisaka",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"{tool_ctx.runtime.log_prefix} 获取回复生成器时发生异常: 目标消息编号={target_message_id}"
|
||||
)
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"获取 Maisaka 回复生成器时发生异常。",
|
||||
)
|
||||
|
||||
if replyer is None:
|
||||
logger.error(f"{tool_ctx.runtime.log_prefix} 获取 Maisaka 回复生成器失败")
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"Maisaka 回复生成器当前不可用。",
|
||||
)
|
||||
|
||||
try:
|
||||
success, reply_result = await replyer.generate_reply_with_context(
|
||||
reply_reason=latest_thought,
|
||||
stream_id=tool_ctx.runtime.session_id,
|
||||
reply_message=target_message,
|
||||
chat_history=tool_ctx.runtime._chat_history,
|
||||
unknown_words=unknown_words,
|
||||
log_reply=False,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
f"{tool_ctx.runtime.log_prefix} 回复生成器执行异常: 目标消息编号={target_message_id} 异常={exc}"
|
||||
)
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"生成可见回复时发生异常。",
|
||||
)
|
||||
|
||||
reply_text = reply_result.completion.response_text.strip() if success else ""
|
||||
if not reply_text:
|
||||
logger.warning(
|
||||
f"{tool_ctx.runtime.log_prefix} 回复生成器返回空文本: "
|
||||
f"目标消息编号={target_message_id} 错误信息={reply_result.error_message!r}"
|
||||
)
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"生成可见回复失败。",
|
||||
)
|
||||
|
||||
reply_segments = tool_ctx.post_process_reply_text(reply_text)
|
||||
combined_reply_text = "".join(reply_segments)
|
||||
try:
|
||||
sent = False
|
||||
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
|
||||
for segment in reply_segments:
|
||||
render_cli_message(segment)
|
||||
sent = True
|
||||
else:
|
||||
for index, segment in enumerate(reply_segments):
|
||||
sent = await send_service.text_to_stream(
|
||||
text=segment,
|
||||
stream_id=tool_ctx.runtime.session_id,
|
||||
set_reply=quote_reply if index == 0 else False,
|
||||
reply_message=target_message if quote_reply and index == 0 else None,
|
||||
selected_expressions=reply_result.selected_expression_ids or None,
|
||||
typing=index > 0,
|
||||
)
|
||||
if not sent:
|
||||
break
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"{tool_ctx.runtime.log_prefix} 发送文字消息时发生异常,目标消息编号={target_message_id}"
|
||||
)
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"发送可见回复时发生异常。",
|
||||
)
|
||||
|
||||
if not sent:
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
"可见回复生成成功,但发送失败。",
|
||||
structured_content={
|
||||
"msg_id": target_message_id,
|
||||
"quote": quote_reply,
|
||||
"reply_segments": reply_segments,
|
||||
},
|
||||
)
|
||||
|
||||
target_user_info = target_message.message_info.user_info
|
||||
target_user_name = target_user_info.user_cardname or target_user_info.user_nickname or target_user_info.user_id
|
||||
|
||||
tool_ctx.append_guided_reply_to_chat_history(combined_reply_text)
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
"回复已生成并发送。",
|
||||
structured_content={
|
||||
"msg_id": target_message_id,
|
||||
"quote": quote_reply,
|
||||
"reply_text": combined_reply_text,
|
||||
"reply_segments": reply_segments,
|
||||
"target_user_name": target_user_name,
|
||||
},
|
||||
)
|
||||
106
src/maisaka/builtin_tool/send_emoji.py
Normal file
106
src/maisaka/builtin_tool/send_emoji.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""send_emoji 内置工具。"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.chat.emoji_system.maisaka_tool import send_emoji_for_maisaka
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
logger = get_logger("maisaka_builtin_send_emoji")
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 send_emoji 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="send_emoji",
|
||||
brief_description="发送一个合适的表情包来辅助表达情绪。",
|
||||
detailed_description="参数说明:\n- emotion:string,可选。希望表达的情绪,例如 happy、sad、angry 等。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"emotion": {
|
||||
"type": "string",
|
||||
"description": "希望表达的情绪,例如 happy、sad、angry 等。",
|
||||
},
|
||||
},
|
||||
},
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 send_emoji 内置工具。"""
|
||||
|
||||
del context
|
||||
emotion = str(invocation.arguments.get("emotion") or "").strip()
|
||||
context_texts = [
|
||||
message.get_history_text()
|
||||
for message in tool_ctx.runtime._chat_history[-5:]
|
||||
if message.get_history_text().strip()
|
||||
]
|
||||
structured_result: Dict[str, Any] = {
|
||||
"success": False,
|
||||
"message": "",
|
||||
"description": "",
|
||||
"emotion": [],
|
||||
"requested_emotion": emotion,
|
||||
"matched_emotion": "",
|
||||
}
|
||||
|
||||
logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具,请求情绪={emotion!r}")
|
||||
|
||||
try:
|
||||
send_result = await send_emoji_for_maisaka(
|
||||
stream_id=tool_ctx.runtime.session_id,
|
||||
requested_emotion=emotion,
|
||||
reasoning=tool_ctx.engine.last_reasoning_content,
|
||||
context_texts=context_texts,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(f"{tool_ctx.runtime.log_prefix} 发送表情包时发生异常: {exc}")
|
||||
structured_result["message"] = f"发送表情包时发生异常:{exc}"
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
structured_result["message"],
|
||||
structured_content=structured_result,
|
||||
)
|
||||
|
||||
structured_result["description"] = send_result.description
|
||||
structured_result["emotion"] = list(send_result.emotions)
|
||||
structured_result["matched_emotion"] = send_result.matched_emotion
|
||||
structured_result["message"] = send_result.message
|
||||
|
||||
if send_result.success:
|
||||
logger.info(
|
||||
f"{tool_ctx.runtime.log_prefix} 表情包发送成功 "
|
||||
f"描述={send_result.description!r} 情绪标签={send_result.emotions} "
|
||||
f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}"
|
||||
)
|
||||
tool_ctx.append_sent_emoji_to_chat_history(
|
||||
emoji_base64=send_result.emoji_base64,
|
||||
success_message=send_result.message,
|
||||
)
|
||||
structured_result["success"] = True
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
send_result.message,
|
||||
structured_content=structured_result,
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"{tool_ctx.runtime.log_prefix} 表情包发送失败 "
|
||||
f"请求情绪={emotion!r} 错误信息={send_result.message}"
|
||||
)
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
structured_result["message"],
|
||||
structured_content=structured_result,
|
||||
)
|
||||
51
src/maisaka/builtin_tool/wait.py
Normal file
51
src/maisaka/builtin_tool/wait.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""wait 内置工具。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
|
||||
from .context import BuiltinToolRuntimeContext
|
||||
|
||||
|
||||
def get_tool_spec() -> ToolSpec:
|
||||
"""获取 wait 工具声明。"""
|
||||
|
||||
return ToolSpec(
|
||||
name="wait",
|
||||
brief_description="暂停当前对话并等待用户新的输入。",
|
||||
detailed_description="参数说明:\n- seconds:integer,必填。等待的秒数。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"seconds": {
|
||||
"type": "integer",
|
||||
"description": "等待的秒数。",
|
||||
},
|
||||
},
|
||||
"required": ["seconds"],
|
||||
},
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
async def handle_tool(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行 wait 内置工具。"""
|
||||
|
||||
del context
|
||||
seconds = invocation.arguments.get("seconds", 30)
|
||||
try:
|
||||
wait_seconds = int(seconds)
|
||||
except (TypeError, ValueError):
|
||||
wait_seconds = 30
|
||||
wait_seconds = max(0, wait_seconds)
|
||||
tool_ctx.runtime._enter_wait_state(seconds=wait_seconds, tool_call_id=invocation.call_id)
|
||||
return tool_ctx.build_success_result(
|
||||
invocation.tool_name,
|
||||
f"当前对话循环进入等待状态,最长等待 {wait_seconds} 秒。",
|
||||
metadata={"pause_execution": True},
|
||||
)
|
||||
@@ -1,159 +0,0 @@
|
||||
"""Maisaka 内置工具声明。"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from src.core.tooling import ToolSpec, build_tool_detailed_description
|
||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
|
||||
|
||||
def _build_tool_spec(
|
||||
name: str,
|
||||
brief_description: str,
|
||||
parameters_schema: Dict[str, Any] | None = None,
|
||||
detailed_description: str = "",
|
||||
) -> ToolSpec:
|
||||
"""构建单个内置工具声明。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
brief_description: 简要描述。
|
||||
parameters_schema: 参数 Schema。
|
||||
detailed_description: 详细描述;为空时自动根据参数生成。
|
||||
|
||||
Returns:
|
||||
ToolSpec: 构建完成的工具声明。
|
||||
"""
|
||||
|
||||
normalized_schema = deepcopy(parameters_schema) if parameters_schema is not None else None
|
||||
return ToolSpec(
|
||||
name=name,
|
||||
brief_description=brief_description,
|
||||
detailed_description=(
|
||||
detailed_description.strip()
|
||||
or build_tool_detailed_description(normalized_schema)
|
||||
),
|
||||
parameters_schema=normalized_schema,
|
||||
provider_name="maisaka_builtin",
|
||||
provider_type="builtin",
|
||||
)
|
||||
|
||||
|
||||
def create_builtin_tool_specs() -> List[ToolSpec]:
|
||||
"""创建 Maisaka 内置工具声明列表。
|
||||
|
||||
Returns:
|
||||
List[ToolSpec]: 内置工具声明列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
_build_tool_spec(
|
||||
name="wait",
|
||||
brief_description="暂停当前对话并等待用户新的输入。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"seconds": {
|
||||
"type": "integer",
|
||||
"description": "等待的秒数。",
|
||||
},
|
||||
},
|
||||
"required": ["seconds"],
|
||||
},
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="reply",
|
||||
brief_description="根据当前思考生成并发送一条可见回复。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"msg_id": {
|
||||
"type": "string",
|
||||
"description": "要回复的目标用户消息编号。",
|
||||
},
|
||||
"quote": {
|
||||
"type": "boolean",
|
||||
"description": "是否以引用回复的方式发送。",
|
||||
"default": True,
|
||||
},
|
||||
"unknown_words": {
|
||||
"type": "array",
|
||||
"description": "回复前可能需要查询的黑话或词条列表。",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["msg_id"],
|
||||
},
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="query_jargon",
|
||||
brief_description="查询当前聊天上下文中的黑话或词条含义。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"words": {
|
||||
"type": "array",
|
||||
"description": "要查询的词条列表。",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["words"],
|
||||
},
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="query_person_info",
|
||||
brief_description="查询某个人的档案和相关记忆信息。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person_name": {
|
||||
"type": "string",
|
||||
"description": "人物名称、昵称或用户 ID。",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "最多返回多少条匹配记录。",
|
||||
"default": 3,
|
||||
},
|
||||
},
|
||||
"required": ["person_name"],
|
||||
},
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="no_reply",
|
||||
brief_description="本轮不进行回复,等待其他用户的新消息。",
|
||||
),
|
||||
_build_tool_spec(
|
||||
name="send_emoji",
|
||||
brief_description="发送一个合适的表情包来辅助表达情绪。",
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"emotion": {
|
||||
"type": "string",
|
||||
"description": "希望表达的情绪,例如 happy、sad、angry 等。",
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_builtin_tool_specs() -> List[ToolSpec]:
|
||||
"""获取 Maisaka 内置工具声明。
|
||||
|
||||
Returns:
|
||||
List[ToolSpec]: 内置工具声明列表。
|
||||
"""
|
||||
|
||||
return create_builtin_tool_specs()
|
||||
|
||||
|
||||
def get_builtin_tools() -> List[ToolDefinitionInput]:
|
||||
"""获取兼容旧模型层的内置工具定义。
|
||||
|
||||
Returns:
|
||||
List[ToolDefinitionInput]: 可直接传给模型层的工具定义。
|
||||
"""
|
||||
|
||||
return [tool_spec.to_llm_definition() for tool_spec in create_builtin_tool_specs()]
|
||||
@@ -1,28 +1,23 @@
|
||||
"""Maisaka 对话循环服务。"""
|
||||
|
||||
from base64 import b64decode
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from time import perf_counter
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
from rich.console import Group, RenderableType
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.pretty import Pretty
|
||||
from rich.text import Text
|
||||
|
||||
from src.cli.console import console
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry, ToolSpec
|
||||
from src.know_u.knowledge import extract_category_ids_from_result
|
||||
@@ -30,11 +25,20 @@ from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
|
||||
from src.plugin_runtime.hook_payloads import (
|
||||
deserialize_prompt_messages,
|
||||
deserialize_tool_calls,
|
||||
serialize_prompt_messages,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_definitions,
|
||||
)
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .builtin_tools import get_builtin_tools
|
||||
from .context_messages import AssistantMessage, LLMContextMessage, SessionBackedMessage
|
||||
from .message_adapter import format_speaker_content
|
||||
from .builtin_tool import get_builtin_tools
|
||||
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
|
||||
from .prompt_cli_renderer import PromptCLIVisualizer
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -44,6 +48,11 @@ class ChatResponse:
|
||||
content: Optional[str]
|
||||
tool_calls: List[ToolCall]
|
||||
raw_message: AssistantMessage
|
||||
selected_history_count: int
|
||||
prompt_tokens: int
|
||||
built_message_count: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ToolFilterSelection(BaseModel):
|
||||
@@ -56,12 +65,131 @@ class ToolFilterSelection(BaseModel):
|
||||
logger = get_logger("maisaka_chat_loop")
|
||||
|
||||
|
||||
def register_maisaka_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册 Maisaka 规划器内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="maisaka.planner.before_request",
|
||||
description="在 Maisaka 向模型发起规划请求前触发,可改写消息窗口与工具定义。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"messages": {
|
||||
"type": "array",
|
||||
"description": "即将发给模型的 PromptMessage 列表。",
|
||||
},
|
||||
"tool_definitions": {
|
||||
"type": "array",
|
||||
"description": "当前候选工具定义列表。",
|
||||
},
|
||||
"selected_history_count": {
|
||||
"type": "integer",
|
||||
"description": "当前选中的上下文消息数量。",
|
||||
},
|
||||
"built_message_count": {
|
||||
"type": "integer",
|
||||
"description": "实际发送给模型的消息数量。",
|
||||
},
|
||||
"selection_reason": {
|
||||
"type": "string",
|
||||
"description": "上下文选择说明。",
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "当前会话 ID。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"messages",
|
||||
"tool_definitions",
|
||||
"selected_history_count",
|
||||
"built_message_count",
|
||||
"selection_reason",
|
||||
"session_id",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=6000,
|
||||
allow_abort=False,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="maisaka.planner.after_response",
|
||||
description="在 Maisaka 收到模型响应后触发,可调整文本结果与工具调用列表。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"response": {
|
||||
"type": "string",
|
||||
"description": "模型返回的文本内容。",
|
||||
},
|
||||
"tool_calls": {
|
||||
"type": "array",
|
||||
"description": "模型返回的工具调用列表。",
|
||||
},
|
||||
"selected_history_count": {
|
||||
"type": "integer",
|
||||
"description": "当前选中的上下文消息数量。",
|
||||
},
|
||||
"built_message_count": {
|
||||
"type": "integer",
|
||||
"description": "实际发送给模型的消息数量。",
|
||||
},
|
||||
"selection_reason": {
|
||||
"type": "string",
|
||||
"description": "上下文选择说明。",
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "当前会话 ID。",
|
||||
},
|
||||
"prompt_tokens": {
|
||||
"type": "integer",
|
||||
"description": "输入 Token 数。",
|
||||
},
|
||||
"completion_tokens": {
|
||||
"type": "integer",
|
||||
"description": "输出 Token 数。",
|
||||
},
|
||||
"total_tokens": {
|
||||
"type": "integer",
|
||||
"description": "总 Token 数。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"response",
|
||||
"tool_calls",
|
||||
"selected_history_count",
|
||||
"built_message_count",
|
||||
"selection_reason",
|
||||
"session_id",
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=6000,
|
||||
allow_abort=False,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class MaisakaChatLoopService:
|
||||
"""负责 Maisaka 主对话循环、系统提示词和终端渲染。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_system_prompt: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_group_chat: Optional[bool] = None,
|
||||
temperature: float = 0.5,
|
||||
max_tokens: int = 2048,
|
||||
) -> None:
|
||||
@@ -69,12 +197,16 @@ class MaisakaChatLoopService:
|
||||
|
||||
Args:
|
||||
chat_system_prompt: 可选的系统提示词。
|
||||
session_id: 当前会话 ID,用于匹配会话级额外提示。
|
||||
is_group_chat: 当前会话是否为群聊。
|
||||
temperature: 规划器温度参数。
|
||||
max_tokens: 规划器最大输出长度。
|
||||
"""
|
||||
|
||||
self._temperature = temperature
|
||||
self._max_tokens = max_tokens
|
||||
self._is_group_chat = is_group_chat
|
||||
self._session_id = session_id or ""
|
||||
self._extra_tools: List[ToolOption] = []
|
||||
self._interrupt_flag: asyncio.Event | None = None
|
||||
self._tool_registry: ToolRegistry | None = None
|
||||
@@ -97,6 +229,35 @@ class MaisakaChatLoopService:
|
||||
|
||||
return self._personality_prompt
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的输入值。
|
||||
default: 转换失败时的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构造人格提示词。"""
|
||||
|
||||
@@ -127,19 +288,13 @@ class MaisakaChatLoopService:
|
||||
Args:
|
||||
tools_section: 额外注入到提示词中的工具说明片段。
|
||||
"""
|
||||
|
||||
if self._prompts_loaded:
|
||||
return
|
||||
|
||||
async with self._prompt_load_lock:
|
||||
if self._prompts_loaded:
|
||||
return
|
||||
|
||||
try:
|
||||
self._chat_system_prompt = load_prompt(
|
||||
"maisaka_chat",
|
||||
file_tools_section=tools_section,
|
||||
bot_name=global_config.bot.nickname,
|
||||
group_chat_attention_block=self._build_group_chat_attention_block(),
|
||||
identity=self._personality_prompt,
|
||||
)
|
||||
except Exception:
|
||||
@@ -147,6 +302,74 @@ class MaisakaChatLoopService:
|
||||
|
||||
self._prompts_loaded = True
|
||||
|
||||
def _build_group_chat_attention_block(self) -> str:
|
||||
"""构建当前聊天场景下的额外注意事项块。"""
|
||||
|
||||
prompt_lines: List[str] = []
|
||||
|
||||
if self._is_group_chat is True:
|
||||
if group_chat_prompt := str(global_config.chat.group_chat_prompt or "").strip():
|
||||
prompt_lines.append(f"通用注意事项:\n{group_chat_prompt}")
|
||||
elif self._is_group_chat is False:
|
||||
if private_chat_prompt := str(global_config.chat.private_chat_prompts or "").strip():
|
||||
prompt_lines.append(f"通用注意事项:\n{private_chat_prompt}")
|
||||
|
||||
if self._session_id:
|
||||
if chat_prompt := self._get_chat_prompt_for_chat(self._session_id, self._is_group_chat).strip():
|
||||
prompt_lines.append(f"当前聊天额外注意事项:\n{chat_prompt}")
|
||||
|
||||
if not prompt_lines:
|
||||
return ""
|
||||
|
||||
return "在该聊天中的注意事项:\n" + "\n\n".join(prompt_lines) + "\n"
|
||||
|
||||
@staticmethod
|
||||
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
|
||||
"""根据聊天流 ID 获取匹配的额外提示。"""
|
||||
|
||||
if not global_config.chat.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_item in global_config.chat.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "platform"):
|
||||
platform = str(chat_prompt_item.platform or "").strip()
|
||||
item_id = str(chat_prompt_item.item_id or "").strip()
|
||||
rule_type = str(chat_prompt_item.rule_type or "").strip()
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
elif isinstance(chat_prompt_item, str):
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
|
||||
platform, item_id, rule_type, prompt_content = parts
|
||||
platform = platform.strip()
|
||||
item_id = item_id.strip()
|
||||
rule_type = rule_type.strip()
|
||||
prompt_content = prompt_content.strip()
|
||||
else:
|
||||
continue
|
||||
|
||||
if not platform or not item_id or not prompt_content:
|
||||
continue
|
||||
|
||||
if rule_type == "group":
|
||||
config_is_group = True
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
|
||||
elif rule_type == "private":
|
||||
config_is_group = False
|
||||
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
|
||||
else:
|
||||
continue
|
||||
|
||||
if is_group_chat is not None and config_is_group != is_group_chat:
|
||||
continue
|
||||
|
||||
if config_chat_id == chat_id:
|
||||
logger.debug(f"匹配到 Maisaka 聊天额外提示,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
|
||||
def set_extra_tools(self, tools: Sequence[ToolDefinitionInput]) -> None:
|
||||
"""设置额外工具定义。
|
||||
|
||||
@@ -468,259 +691,6 @@ class MaisakaChatLoopService:
|
||||
|
||||
return extract_category_ids_from_result(generation_result.response or "")
|
||||
|
||||
@staticmethod
|
||||
def _get_role_badge_style(role: str) -> str:
|
||||
"""返回终端中角色标签的样式。
|
||||
|
||||
Args:
|
||||
role: 消息角色名称。
|
||||
|
||||
Returns:
|
||||
str: Rich 可识别的样式字符串。
|
||||
"""
|
||||
|
||||
if role == "system":
|
||||
return "bold white on blue"
|
||||
if role == "user":
|
||||
return "bold black on green"
|
||||
if role == "assistant":
|
||||
return "bold black on yellow"
|
||||
if role == "tool":
|
||||
return "bold white on magenta"
|
||||
return "bold white on bright_black"
|
||||
|
||||
@staticmethod
|
||||
def _get_role_badge_label(role: str) -> str:
|
||||
"""返回终端中角色标签的中文名称。
|
||||
|
||||
Args:
|
||||
role: 消息角色名称。
|
||||
|
||||
Returns:
|
||||
str: 用于展示的中文角色名称。
|
||||
"""
|
||||
|
||||
if role == "system":
|
||||
return "系统"
|
||||
if role == "user":
|
||||
return "用户"
|
||||
if role == "assistant":
|
||||
return "助手"
|
||||
if role == "tool":
|
||||
return "工具"
|
||||
return "未知"
|
||||
|
||||
@staticmethod
|
||||
def _build_terminal_image_preview(image_base64: str) -> Optional[str]:
|
||||
"""构造终端图片预览字符画。
|
||||
|
||||
Args:
|
||||
image_base64: 图片的 Base64 编码。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 生成成功时返回字符画文本,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
ascii_chars = " .:-=+*#%@"
|
||||
|
||||
try:
|
||||
image_bytes = b64decode(image_base64)
|
||||
with PILImage.open(BytesIO(image_bytes)) as image:
|
||||
grayscale = image.convert("L")
|
||||
width, height = grayscale.size
|
||||
if width <= 0 or height <= 0:
|
||||
return None
|
||||
|
||||
preview_width = max(8, int(global_config.maisaka.terminal_image_preview_width))
|
||||
preview_height = max(1, int(height * (preview_width / width) * 0.5))
|
||||
resized = grayscale.resize((preview_width, preview_height))
|
||||
pixels = list(resized.tobytes())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
rows: List[str] = []
|
||||
for row_index in range(preview_height):
|
||||
row_pixels = pixels[row_index * preview_width : (row_index + 1) * preview_width]
|
||||
row = "".join(ascii_chars[min(len(ascii_chars) - 1, pixel * len(ascii_chars) // 256)] for pixel in row_pixels)
|
||||
rows.append(row)
|
||||
|
||||
return "\n".join(rows)
|
||||
|
||||
@classmethod
|
||||
def _render_message_content(cls, content: Any) -> RenderableType:
|
||||
"""将消息内容渲染为终端可展示对象。
|
||||
|
||||
Args:
|
||||
content: 原始消息内容。
|
||||
|
||||
Returns:
|
||||
RenderableType: Rich 可渲染对象。
|
||||
"""
|
||||
|
||||
if isinstance(content, str):
|
||||
return Text(content)
|
||||
|
||||
if isinstance(content, list):
|
||||
parts: List[RenderableType] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(Text(item))
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
approx_size = max(0, len(image_base64) * 3 // 4)
|
||||
size_text = f"{approx_size / 1024:.1f} KB" if approx_size >= 1024 else f"{approx_size} B"
|
||||
preview_parts: List[RenderableType] = [
|
||||
Text(f"图片格式 image/{image_format} {size_text}\nbase64 内容已省略", style="magenta")
|
||||
]
|
||||
if global_config.maisaka.terminal_image_preview:
|
||||
preview_text = cls._build_terminal_image_preview(image_base64)
|
||||
if preview_text:
|
||||
preview_parts.append(Text(preview_text, style="white"))
|
||||
parts.append(
|
||||
Panel(
|
||||
Group(*preview_parts),
|
||||
border_style="magenta",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
continue
|
||||
if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
|
||||
parts.append(Text(item["text"]))
|
||||
else:
|
||||
parts.append(Pretty(item, expand_all=True))
|
||||
return Group(*parts) if parts else Text("")
|
||||
|
||||
if content is None:
|
||||
return Text("")
|
||||
|
||||
return Pretty(content, expand_all=True)
|
||||
|
||||
@staticmethod
|
||||
def _format_tool_call_for_display(tool_call: Any) -> Dict[str, Any]:
|
||||
"""将工具调用对象格式化为易读字典。
|
||||
|
||||
Args:
|
||||
tool_call: 原始工具调用对象或字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 适合终端展示的工具调用字典。
|
||||
"""
|
||||
|
||||
if isinstance(tool_call, dict):
|
||||
function_info = tool_call.get("function", {})
|
||||
return {
|
||||
"id": tool_call.get("id"),
|
||||
"name": function_info.get("name", tool_call.get("name")),
|
||||
"arguments": function_info.get("arguments", tool_call.get("arguments")),
|
||||
}
|
||||
|
||||
return {
|
||||
"id": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
|
||||
"name": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
|
||||
"arguments": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
|
||||
}
|
||||
|
||||
def _render_tool_call_panel(self, tool_call: Any, index: int, parent_index: int) -> Panel:
|
||||
"""渲染单个工具调用面板。
|
||||
|
||||
Args:
|
||||
tool_call: 原始工具调用对象。
|
||||
index: 工具调用在当前消息中的序号。
|
||||
parent_index: 所属消息的序号。
|
||||
|
||||
Returns:
|
||||
Panel: 工具调用展示面板。
|
||||
"""
|
||||
|
||||
title = Text.assemble(
|
||||
Text(" 工具调用 ", style="bold white on magenta"),
|
||||
Text(f" #{parent_index}.{index}", style="muted"),
|
||||
)
|
||||
return Panel(
|
||||
Pretty(self._format_tool_call_for_display(tool_call), expand_all=True),
|
||||
title=title,
|
||||
border_style="magenta",
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
def _render_message_panel(self, message: Any, index: int) -> Panel:
|
||||
"""渲染单条消息面板。
|
||||
|
||||
Args:
|
||||
message: 原始消息对象或字典。
|
||||
index: 消息序号。
|
||||
|
||||
Returns:
|
||||
Panel: 终端展示面板。
|
||||
"""
|
||||
|
||||
if isinstance(message, dict):
|
||||
raw_role = message.get("role", "unknown")
|
||||
content = message.get("content")
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
else:
|
||||
raw_role = getattr(message, "role", "unknown")
|
||||
content = getattr(message, "content", None)
|
||||
tool_call_id = getattr(message, "tool_call_id", None)
|
||||
|
||||
role = raw_role.value if isinstance(raw_role, RoleType) else str(raw_role)
|
||||
title = Text.assemble(
|
||||
Text(f" {self._get_role_badge_label(role)} ", style=self._get_role_badge_style(role)),
|
||||
Text(f" #{index}", style="muted"),
|
||||
)
|
||||
|
||||
parts: List[RenderableType] = []
|
||||
if content not in (None, "", []):
|
||||
parts.append(Text(" 消息 ", style="bold cyan"))
|
||||
parts.append(self._render_message_content(content))
|
||||
|
||||
if tool_call_id:
|
||||
parts.append(
|
||||
Text.assemble(
|
||||
Text(" 工具调用编号 ", style="bold magenta"),
|
||||
Text(" "),
|
||||
Text(str(tool_call_id), style="magenta"),
|
||||
)
|
||||
)
|
||||
|
||||
if not parts:
|
||||
parts.append(Text("[空消息]", style="muted"))
|
||||
|
||||
return Panel(
|
||||
Group(*parts),
|
||||
title=title,
|
||||
border_style="dim",
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_token_count(token_count: int) -> str:
|
||||
"""格式化 token 数量展示文本。"""
|
||||
if token_count >= 10_000:
|
||||
return f"{token_count / 1000:.1f}k"
|
||||
return str(token_count)
|
||||
|
||||
@classmethod
|
||||
def _build_prompt_stats_text(
|
||||
cls,
|
||||
*,
|
||||
selected_history_count: int,
|
||||
built_message_count: int,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
) -> str:
|
||||
"""构造本轮 prompt 的统计信息文本。"""
|
||||
return (
|
||||
f"已选上下文消息数={selected_history_count} "
|
||||
f"大模型消息数={built_message_count} "
|
||||
f"实际输入Token={cls._format_token_count(prompt_tokens)} "
|
||||
f"输出Token={cls._format_token_count(completion_tokens)} "
|
||||
f"总Token={cls._format_token_count(total_tokens)}"
|
||||
)
|
||||
|
||||
async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse:
|
||||
"""执行一轮 Maisaka 规划器请求。
|
||||
|
||||
@@ -756,13 +726,30 @@ class MaisakaChatLoopService:
|
||||
else:
|
||||
all_tools = [*get_builtin_tools(), *self._extra_tools]
|
||||
|
||||
ordered_panels: List[Panel] = []
|
||||
for index, msg in enumerate(built_messages, start=1):
|
||||
ordered_panels.append(self._render_message_panel(msg, index))
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if tool_calls:
|
||||
for tool_call_index, tool_call in enumerate(tool_calls, start=1):
|
||||
ordered_panels.append(self._render_tool_call_panel(tool_call, tool_call_index, index))
|
||||
before_request_result = await self._get_runtime_manager().invoke_hook(
|
||||
"maisaka.planner.before_request",
|
||||
messages=serialize_prompt_messages(built_messages),
|
||||
tool_definitions=serialize_tool_definitions(all_tools),
|
||||
selected_history_count=len(selected_history),
|
||||
built_message_count=len(built_messages),
|
||||
selection_reason=selection_reason,
|
||||
session_id=self._session_id,
|
||||
)
|
||||
before_request_kwargs = before_request_result.kwargs
|
||||
raw_messages = before_request_kwargs.get("messages")
|
||||
if isinstance(raw_messages, list):
|
||||
try:
|
||||
built_messages = deserialize_prompt_messages(raw_messages)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Hook maisaka.planner.before_request 返回的 messages 无法反序列化,已忽略: {exc}")
|
||||
raw_tool_definitions = before_request_kwargs.get("tool_definitions")
|
||||
if isinstance(raw_tool_definitions, list):
|
||||
all_tools = [item for item in raw_tool_definitions if isinstance(item, dict)]
|
||||
|
||||
ordered_panels = PromptCLIVisualizer.build_prompt_panels(
|
||||
built_messages,
|
||||
image_display_mode=global_config.maisaka.terminal_image_display_mode,
|
||||
)
|
||||
|
||||
if global_config.maisaka.show_thinking and ordered_panels:
|
||||
console.print(
|
||||
@@ -795,7 +782,7 @@ class MaisakaChatLoopService:
|
||||
request_elapsed = perf_counter() - request_started_at
|
||||
logger.info(f"规划器请求完成,耗时={request_elapsed:.3f} 秒")
|
||||
|
||||
prompt_stats_text = self._build_prompt_stats_text(
|
||||
prompt_stats_text = PromptCLIVisualizer.build_prompt_stats_text(
|
||||
selected_history_count=len(selected_history),
|
||||
built_message_count=len(built_messages),
|
||||
prompt_tokens=generation_result.prompt_tokens,
|
||||
@@ -804,28 +791,63 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
|
||||
|
||||
final_response = generation_result.response or ""
|
||||
final_tool_calls = list(generation_result.tool_calls or [])
|
||||
after_response_result = await self._get_runtime_manager().invoke_hook(
|
||||
"maisaka.planner.after_response",
|
||||
response=final_response,
|
||||
tool_calls=serialize_tool_calls(final_tool_calls),
|
||||
selected_history_count=len(selected_history),
|
||||
built_message_count=len(built_messages),
|
||||
selection_reason=selection_reason,
|
||||
session_id=self._session_id,
|
||||
prompt_tokens=generation_result.prompt_tokens,
|
||||
completion_tokens=generation_result.completion_tokens,
|
||||
total_tokens=generation_result.total_tokens,
|
||||
)
|
||||
after_response_kwargs = after_response_result.kwargs
|
||||
if "response" in after_response_kwargs:
|
||||
final_response = str(after_response_kwargs.get("response") or "")
|
||||
raw_tool_calls = after_response_kwargs.get("tool_calls")
|
||||
if isinstance(raw_tool_calls, list):
|
||||
try:
|
||||
final_tool_calls = deserialize_tool_calls(raw_tool_calls)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Hook maisaka.planner.after_response 返回的 tool_calls 无法反序列化,已忽略: {exc}")
|
||||
prompt_tokens = self._coerce_int(after_response_kwargs.get("prompt_tokens"), generation_result.prompt_tokens)
|
||||
completion_tokens = self._coerce_int(
|
||||
after_response_kwargs.get("completion_tokens"),
|
||||
generation_result.completion_tokens,
|
||||
)
|
||||
total_tokens = self._coerce_int(after_response_kwargs.get("total_tokens"), generation_result.total_tokens)
|
||||
|
||||
tool_call_summaries = [
|
||||
{
|
||||
"调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
|
||||
"工具名": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
|
||||
"参数": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
|
||||
}
|
||||
for tool_call in (generation_result.tool_calls or [])
|
||||
for tool_call in final_tool_calls
|
||||
]
|
||||
logger.info(
|
||||
f"Maisaka 规划器返回结果: 内容={generation_result.response or ''!r} "
|
||||
f"Maisaka 规划器返回结果: 内容={final_response!r} "
|
||||
f"工具调用={tool_call_summaries}"
|
||||
)
|
||||
|
||||
raw_message = AssistantMessage(
|
||||
content=generation_result.response or "",
|
||||
content=final_response,
|
||||
timestamp=datetime.now(),
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
tool_calls=final_tool_calls,
|
||||
)
|
||||
return ChatResponse(
|
||||
content=generation_result.response,
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
content=final_response or None,
|
||||
tool_calls=final_tool_calls,
|
||||
raw_message=raw_message,
|
||||
selected_history_count=len(selected_history),
|
||||
prompt_tokens=prompt_tokens,
|
||||
built_message_count=len(built_messages),
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -859,6 +881,7 @@ class MaisakaChatLoopService:
|
||||
|
||||
selected_indices.reverse()
|
||||
selected_history = [chat_history[index] for index in selected_indices]
|
||||
selected_history = MaisakaChatLoopService._drop_leading_orphan_tool_results(selected_history)
|
||||
return (
|
||||
selected_history,
|
||||
(
|
||||
@@ -868,34 +891,31 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_chat_context(user_text: str) -> List[LLMContextMessage]:
|
||||
"""根据用户输入构造最小对话上下文。
|
||||
def _drop_leading_orphan_tool_results(
|
||||
selected_history: List[LLMContextMessage],
|
||||
) -> List[LLMContextMessage]:
|
||||
"""移除窗口前缀中缺少对应 tool_call 的工具结果消息。"""
|
||||
|
||||
Args:
|
||||
user_text: 用户输入文本。
|
||||
if not selected_history:
|
||||
return selected_history
|
||||
|
||||
Returns:
|
||||
List[LLMContextMessage]: 构造好的上下文消息列表。
|
||||
"""
|
||||
available_tool_call_ids = {
|
||||
tool_call.call_id
|
||||
for message in selected_history
|
||||
if isinstance(message, AssistantMessage)
|
||||
for tool_call in message.tool_calls
|
||||
if tool_call.call_id
|
||||
}
|
||||
|
||||
timestamp = datetime.now()
|
||||
visible_text = format_speaker_content(
|
||||
global_config.maisaka.user_name.strip() or "用户",
|
||||
user_text,
|
||||
timestamp,
|
||||
)
|
||||
planner_prefix = (
|
||||
f"[时间]{timestamp.strftime('%H:%M:%S')}\n"
|
||||
f"[用户]{global_config.maisaka.user_name.strip() or '用户'}\n"
|
||||
"[用户群昵称]\n"
|
||||
"[msg_id]\n"
|
||||
"[发言内容]"
|
||||
)
|
||||
return [
|
||||
SessionBackedMessage(
|
||||
raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]),
|
||||
visible_text=visible_text,
|
||||
timestamp=timestamp,
|
||||
source_kind="user",
|
||||
)
|
||||
]
|
||||
first_valid_index = 0
|
||||
while first_valid_index < len(selected_history):
|
||||
message = selected_history[first_valid_index]
|
||||
if not isinstance(message, ToolResultMessage):
|
||||
break
|
||||
if message.tool_call_id in available_tool_call_ids:
|
||||
break
|
||||
first_valid_index += 1
|
||||
|
||||
if first_valid_index == 0:
|
||||
return selected_history
|
||||
return selected_history[first_valid_index:]
|
||||
@@ -11,7 +11,13 @@ import base64
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
EmojiComponent,
|
||||
ImageComponent,
|
||||
MessageSequence,
|
||||
ReplyComponent,
|
||||
TextComponent,
|
||||
)
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
@@ -27,6 +33,44 @@ def _guess_image_format(image_bytes: bytes) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent) -> bool:
|
||||
"""将表情组件追加到 LLM 消息构建器。"""
|
||||
image_format = _guess_image_format(component.binary_data)
|
||||
if image_format and component.binary_data:
|
||||
builder.add_text_content("[消息类型]表情包")
|
||||
builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
|
||||
return True
|
||||
|
||||
if component.content:
|
||||
builder.add_text_content(component.content)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _append_image_component(builder: MessageBuilder, component: ImageComponent) -> bool:
|
||||
"""将图片组件追加到 LLM 消息构建器。"""
|
||||
image_format = _guess_image_format(component.binary_data)
|
||||
if image_format and component.binary_data:
|
||||
builder.add_text_content("[消息类型]图片")
|
||||
builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
|
||||
return True
|
||||
|
||||
if component.content:
|
||||
builder.add_text_content(component.content)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _append_reply_component(builder: MessageBuilder, component: ReplyComponent) -> bool:
|
||||
"""将回复组件追加到 LLM 消息构建器。"""
|
||||
target_message_id = component.target_message_id.strip()
|
||||
if not target_message_id:
|
||||
return False
|
||||
|
||||
builder.add_text_content(f"[引用回复]({target_message_id})")
|
||||
return True
|
||||
|
||||
|
||||
def _build_message_from_sequence(
|
||||
role: RoleType,
|
||||
message_sequence: MessageSequence,
|
||||
@@ -50,16 +94,17 @@ def _build_message_from_sequence(
|
||||
has_content = True
|
||||
continue
|
||||
|
||||
if isinstance(component, (EmojiComponent, ImageComponent)):
|
||||
image_format = _guess_image_format(component.binary_data)
|
||||
if image_format and component.binary_data:
|
||||
builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
|
||||
has_content = True
|
||||
continue
|
||||
if isinstance(component, EmojiComponent):
|
||||
has_content = _append_emoji_component(builder, component) or has_content
|
||||
continue
|
||||
|
||||
if component.content:
|
||||
builder.add_text_content(component.content)
|
||||
has_content = True
|
||||
if isinstance(component, ImageComponent):
|
||||
has_content = _append_image_component(builder, component) or has_content
|
||||
continue
|
||||
|
||||
if isinstance(component, ReplyComponent):
|
||||
has_content = _append_reply_component(builder, component) or has_content
|
||||
continue
|
||||
|
||||
if not has_content and fallback_text:
|
||||
builder.add_text_content(fallback_text)
|
||||
|
||||
@@ -5,7 +5,13 @@ from datetime import datetime
|
||||
from typing import Optional
|
||||
import re
|
||||
|
||||
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
EmojiComponent,
|
||||
ImageComponent,
|
||||
MessageSequence,
|
||||
ReplyComponent,
|
||||
TextComponent,
|
||||
)
|
||||
|
||||
SPEAKER_PREFIX_PATTERN = re.compile(
|
||||
r"^(?:(?P<timestamp>\d{2}:\d{2}:\d{2}))?(?:\[msg_id:(?P<message_id>[^\]]+)\])?\[(?P<speaker>[^\]]+)\](?P<content>.*)$",
|
||||
@@ -65,5 +71,11 @@ def build_visible_text_from_sequence(message_sequence: MessageSequence) -> str:
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
parts.append("[图片]")
|
||||
continue
|
||||
|
||||
if isinstance(component, ReplyComponent):
|
||||
target_message_id = component.target_message_id.strip()
|
||||
if target_message_id:
|
||||
parts.append(f"[引用回复]({target_message_id})")
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
109
src/maisaka/planner_message_utils.py
Normal file
109
src/maisaka/planner_message_utils.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Maisaka 规划器消息构造工具。"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
|
||||
from .context_messages import SessionBackedMessage
|
||||
from .message_adapter import format_speaker_content
|
||||
|
||||
|
||||
def build_planner_prefix(
|
||||
*,
|
||||
timestamp: datetime,
|
||||
user_name: str,
|
||||
group_card: str = "",
|
||||
message_id: Optional[str] = None,
|
||||
include_message_id: bool = True,
|
||||
) -> str:
|
||||
"""构造 Maisaka 规划器使用的统一消息前缀。
|
||||
|
||||
Args:
|
||||
timestamp: 消息时间。
|
||||
user_name: 展示给规划器的用户名。
|
||||
group_card: 群昵称。
|
||||
message_id: 消息 ID。
|
||||
include_message_id: 是否输出 `msg_id` 段。
|
||||
|
||||
Returns:
|
||||
str: 拼接完成的规划器前缀。
|
||||
"""
|
||||
|
||||
prefix_parts = [
|
||||
f"[时间]{timestamp.strftime('%H:%M:%S')}\n",
|
||||
f"[用户名]{user_name}\n",
|
||||
f"[用户群昵称]{group_card}\n",
|
||||
]
|
||||
if include_message_id:
|
||||
prefix_parts.append(f"[msg_id]{message_id or ''}\n")
|
||||
prefix_parts.append("[发言内容]")
|
||||
return "".join(prefix_parts)
|
||||
|
||||
|
||||
def build_planner_user_prefix_from_session_message(message: SessionMessage) -> str:
|
||||
"""根据真实会话消息构造规划器前缀。
|
||||
|
||||
Args:
|
||||
message: 原始会话消息。
|
||||
|
||||
Returns:
|
||||
str: 规划器前缀字符串。
|
||||
"""
|
||||
|
||||
user_info = message.message_info.user_info
|
||||
user_name = user_info.user_nickname or user_info.user_id
|
||||
return build_planner_prefix(
|
||||
timestamp=message.timestamp,
|
||||
user_name=user_name,
|
||||
group_card=user_info.user_cardname or "",
|
||||
message_id=message.message_id,
|
||||
include_message_id=not message.is_notify and bool(message.message_id),
|
||||
)
|
||||
|
||||
|
||||
def build_session_backed_text_message(
|
||||
*,
|
||||
speaker_name: str,
|
||||
text: str,
|
||||
timestamp: datetime,
|
||||
source_kind: str,
|
||||
group_card: str = "",
|
||||
message_id: Optional[str] = None,
|
||||
include_message_id: bool = True,
|
||||
) -> SessionBackedMessage:
|
||||
"""构造带规划器前缀的纯文本历史消息。
|
||||
|
||||
Args:
|
||||
speaker_name: 发言者名称。
|
||||
text: 发言内容。
|
||||
timestamp: 发言时间。
|
||||
source_kind: 上下文来源类型。
|
||||
group_card: 群昵称。
|
||||
message_id: 消息 ID。
|
||||
include_message_id: 是否输出 `msg_id` 段。
|
||||
|
||||
Returns:
|
||||
SessionBackedMessage: 可直接写入历史的上下文消息。
|
||||
"""
|
||||
|
||||
planner_prefix = build_planner_prefix(
|
||||
timestamp=timestamp,
|
||||
user_name=speaker_name,
|
||||
group_card=group_card,
|
||||
message_id=message_id,
|
||||
include_message_id=include_message_id,
|
||||
)
|
||||
return SessionBackedMessage(
|
||||
raw_message=MessageSequence([TextComponent(f"{planner_prefix}{text}")]),
|
||||
visible_text=format_speaker_content(
|
||||
speaker_name,
|
||||
text,
|
||||
timestamp,
|
||||
message_id if include_message_id else None,
|
||||
),
|
||||
timestamp=timestamp,
|
||||
message_id=message_id,
|
||||
source_kind=source_kind,
|
||||
)
|
||||
306
src/maisaka/prompt_cli_renderer.py
Normal file
306
src/maisaka/prompt_cli_renderer.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""CLI 下的 Prompt 可视化渲染模块。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from base64 import b64decode
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
from typing import Any, Dict, List, Literal
|
||||
|
||||
import tempfile
|
||||
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
from rich.console import Group, RenderableType
|
||||
from rich.pretty import Pretty
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
DATA_IMAGE_DIR = PROJECT_ROOT / "data" / "images"
|
||||
|
||||
|
||||
class PromptImageDisplayMode(str, Enum):
|
||||
"""图片在终端中的展示模式。"""
|
||||
|
||||
LEGACY = "legacy"
|
||||
"""不新增链接,仅保留原有的元信息展示。"""
|
||||
|
||||
PATH_LINK = "path_link"
|
||||
"""把图片落盘到临时目录并输出可点击路径。"""
|
||||
|
||||
|
||||
class PromptImageDisplaySettings(BaseModel):
|
||||
"""图片展示参数。"""
|
||||
|
||||
display_mode: PromptImageDisplayMode = PydanticField(default=PromptImageDisplayMode.LEGACY)
|
||||
"""图片展示模式。"""
|
||||
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _MessageRenderResult:
|
||||
"""可渲染结果与是否有工具调用信息。"""
|
||||
|
||||
message_panel: Panel
|
||||
tool_call_panels: List[Panel]
|
||||
|
||||
|
||||
class PromptCLIVisualizer:
|
||||
"""负责构建 CLI 下 prompt 展示所需的所有可视化组件。"""
|
||||
|
||||
@staticmethod
|
||||
def _get_role_badge_style(role: str) -> str:
|
||||
if role == "system":
|
||||
return "bold white on blue"
|
||||
if role == "user":
|
||||
return "bold black on green"
|
||||
if role == "assistant":
|
||||
return "bold black on yellow"
|
||||
if role == "tool":
|
||||
return "bold white on magenta"
|
||||
return "bold white on bright_black"
|
||||
|
||||
@staticmethod
|
||||
def _get_role_badge_label(role: str) -> str:
|
||||
if role == "system":
|
||||
return "系统"
|
||||
if role == "user":
|
||||
return "用户"
|
||||
if role == "assistant":
|
||||
return "助手"
|
||||
if role == "tool":
|
||||
return "工具"
|
||||
return "未知"
|
||||
|
||||
@staticmethod
|
||||
def _format_token_count(token_count: int) -> str:
|
||||
if token_count >= 10_000:
|
||||
return f"{token_count / 1000:.1f}k"
|
||||
return str(token_count)
|
||||
|
||||
@classmethod
|
||||
def build_prompt_stats_text(
|
||||
cls,
|
||||
*,
|
||||
selected_history_count: int,
|
||||
built_message_count: int,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
) -> str:
|
||||
"""构造 prompt 统计文本。"""
|
||||
return (
|
||||
f"上下文消息数量={selected_history_count} "
|
||||
f"已构建消息数={built_message_count} "
|
||||
f"实际输入Token={cls._format_token_count(prompt_tokens)} "
|
||||
f"输出Token={cls._format_token_count(completion_tokens)} "
|
||||
f"总Token={cls._format_token_count(total_tokens)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_image_format(image_format: str) -> str:
|
||||
"""归一化图片扩展名。"""
|
||||
normalized = image_format.strip().lower()
|
||||
if normalized == "jpg":
|
||||
return "jpeg"
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _build_image_cache_path(image_format: str, image_base64: str) -> Path:
|
||||
image_format = PromptCLIVisualizer._normalize_image_format(image_format)
|
||||
root = Path(tempfile.gettempdir()) / "maisaka_prompt_images"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest()
|
||||
return root / f"{digest}.{image_format}"
|
||||
|
||||
@staticmethod
|
||||
def _build_file_uri(file_path: Path) -> str:
|
||||
normalized = file_path.as_posix()
|
||||
return f"file:///{quote(normalized, safe='/:')}"
|
||||
|
||||
@staticmethod
|
||||
def _build_official_image_path(image_format: str, image_base64: str) -> Path | None:
|
||||
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format)
|
||||
try:
|
||||
image_bytes = b64decode(image_base64)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
digest = hashlib.sha256(image_bytes).hexdigest()
|
||||
official_path = DATA_IMAGE_DIR / f"{digest}.{normalized_format}"
|
||||
if official_path.exists():
|
||||
return official_path
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_image_file_link(image_format: str, image_base64: str) -> tuple[str, Path] | None:
|
||||
"""优先返回正式图片路径;不存在时回退到临时缓存路径。"""
|
||||
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format) or "bin"
|
||||
official_path = PromptCLIVisualizer._build_official_image_path(image_format, image_base64)
|
||||
if official_path is not None:
|
||||
return PromptCLIVisualizer._build_file_uri(official_path), official_path
|
||||
|
||||
try:
|
||||
image_bytes = b64decode(image_base64)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
path = PromptCLIVisualizer._build_image_cache_path(normalized_format, image_base64)
|
||||
if not path.exists():
|
||||
try:
|
||||
path.write_bytes(image_bytes)
|
||||
except Exception:
|
||||
return None
|
||||
return PromptCLIVisualizer._build_file_uri(path), path
|
||||
|
||||
@classmethod
|
||||
def _render_image_item(cls, image_format: str, image_base64: str, settings: PromptImageDisplaySettings) -> Panel:
|
||||
normalized_format = cls._normalize_image_format(image_format)
|
||||
approx_size = max(0, len(image_base64) * 3 // 4)
|
||||
size_text = f"{approx_size / 1024:.1f} KB" if approx_size >= 1024 else f"{approx_size} B"
|
||||
|
||||
preview_parts: List[RenderableType] = [
|
||||
Text(f"图片格式 image/{normalized_format} {size_text}", style="magenta")
|
||||
]
|
||||
|
||||
if settings.display_mode == PromptImageDisplayMode.PATH_LINK:
|
||||
path_result = cls._build_image_file_link(image_format, image_base64)
|
||||
if path_result is not None:
|
||||
file_uri, file_path = path_result
|
||||
preview_parts.append(Text.from_markup(f"\n[link={file_uri}]点击打开图片[/link]", style="cyan"))
|
||||
preview_parts.append(Text(f"\n{file_path}", style="dim"))
|
||||
|
||||
return Panel(
|
||||
Group(*preview_parts),
|
||||
border_style="magenta",
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _render_message_content(cls, content: Any, settings: PromptImageDisplaySettings) -> RenderableType:
|
||||
if isinstance(content, str):
|
||||
return Text(content)
|
||||
|
||||
if isinstance(content, list):
|
||||
parts: List[RenderableType] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(Text(item))
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image_format, image_base64 = item
|
||||
if isinstance(image_format, str) and isinstance(image_base64, str):
|
||||
parts.append(cls._render_image_item(image_format, image_base64, settings))
|
||||
continue
|
||||
if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
|
||||
parts.append(Text(item["text"]))
|
||||
else:
|
||||
parts.append(Pretty(item, expand_all=True))
|
||||
return Group(*parts) if parts else Text("")
|
||||
|
||||
if content is None:
|
||||
return Text("")
|
||||
|
||||
return Pretty(content, expand_all=True)
|
||||
|
||||
@classmethod
|
||||
def format_tool_call_for_display(cls, tool_call: Any) -> Dict[str, Any]:
|
||||
if isinstance(tool_call, dict):
|
||||
function_info = tool_call.get("function", {})
|
||||
return {
|
||||
"id": tool_call.get("id"),
|
||||
"name": function_info.get("name", tool_call.get("name")),
|
||||
"arguments": function_info.get("arguments", tool_call.get("arguments")),
|
||||
}
|
||||
|
||||
return {
|
||||
"id": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
|
||||
"name": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
|
||||
"arguments": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _render_tool_call_panel(cls, tool_call: Any, index: int, parent_index: int) -> Panel:
|
||||
title = Text.assemble(
|
||||
Text(" 工具调用 ", style="bold white on magenta"),
|
||||
Text(f" #{parent_index}.{index}", style="muted"),
|
||||
)
|
||||
return Panel(
|
||||
Pretty(cls.format_tool_call_for_display(tool_call), expand_all=True),
|
||||
title=title,
|
||||
border_style="magenta",
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _render_message_panel(cls, message: Any, index: int, settings: PromptImageDisplaySettings) -> _MessageRenderResult:
|
||||
if isinstance(message, dict):
|
||||
raw_role = message.get("role", "unknown")
|
||||
content = message.get("content")
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
else:
|
||||
raw_role = getattr(message, "role", "unknown")
|
||||
content = getattr(message, "content", None)
|
||||
tool_call_id = getattr(message, "tool_call_id", None)
|
||||
|
||||
role = raw_role.value if hasattr(raw_role, "value") else str(raw_role)
|
||||
title = Text.assemble(
|
||||
Text(f" {cls._get_role_badge_label(role)} ", style=cls._get_role_badge_style(role)),
|
||||
Text(f" #{index}", style="muted"),
|
||||
)
|
||||
|
||||
parts: List[RenderableType] = []
|
||||
if content not in (None, "", []):
|
||||
parts.append(Text(" 内容 ", style="bold cyan"))
|
||||
parts.append(cls._render_message_content(content, settings))
|
||||
|
||||
if tool_call_id:
|
||||
parts.append(
|
||||
Text.assemble(
|
||||
Text(" 工具调用ID ", style="bold magenta"),
|
||||
Text(" "),
|
||||
Text(str(tool_call_id), style="magenta"),
|
||||
)
|
||||
)
|
||||
|
||||
if not parts:
|
||||
parts.append(Text("[空]", style="muted"))
|
||||
|
||||
message_panel = Panel(
|
||||
Group(*parts),
|
||||
title=title,
|
||||
border_style="dim",
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
tool_call_panels: List[Panel] = []
|
||||
tool_calls = getattr(message, "tool_calls", None)
|
||||
if tool_calls:
|
||||
for tool_call_index, tool_call in enumerate(tool_calls, start=1):
|
||||
tool_call_panels.append(cls._render_tool_call_panel(tool_call, tool_call_index, index))
|
||||
|
||||
return _MessageRenderResult(message_panel=message_panel, tool_call_panels=tool_call_panels)
|
||||
|
||||
@classmethod
|
||||
def build_prompt_panels(
|
||||
cls,
|
||||
messages: list[Any],
|
||||
*,
|
||||
image_display_mode: Literal["legacy", "path_link"],
|
||||
) -> List[Panel]:
|
||||
"""构建完整 prompt 可视化面板。"""
|
||||
if image_display_mode not in {mode.value for mode in PromptImageDisplayMode}:
|
||||
image_display_mode = PromptImageDisplayMode.LEGACY
|
||||
settings = PromptImageDisplaySettings(
|
||||
display_mode=PromptImageDisplayMode(image_display_mode),
|
||||
)
|
||||
|
||||
ordered_panels: List[Panel] = []
|
||||
for index, message in enumerate(messages, start=1):
|
||||
message_render_result = cls._render_message_panel(message, index, settings)
|
||||
ordered_panels.append(message_render_result.message_panel)
|
||||
ordered_panels.extend(message_render_result.tool_call_panels)
|
||||
return ordered_panels
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,10 @@ from typing import Literal, Optional
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from src.cli.console import console
|
||||
from src.chat.heart_flow.heartFC_utils import CycleDetail
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
@@ -45,7 +49,10 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
session_name = chat_manager.get_session_name(session_id) or session_id
|
||||
self.log_prefix = f"[{session_name}]"
|
||||
self._chat_loop_service = MaisakaChatLoopService()
|
||||
self._chat_loop_service = MaisakaChatLoopService(
|
||||
session_id=session_id,
|
||||
is_group_chat=self.chat_stream.is_group_session,
|
||||
)
|
||||
self._chat_history: list[LLMContextMessage] = []
|
||||
self.history_loop: list[CycleDetail] = []
|
||||
|
||||
@@ -431,6 +438,40 @@ class MaisakaHeartFlowChatting:
|
||||
|
||||
return GroupInfo(group_id=group_info.group_id, group_name=group_info.group_name)
|
||||
|
||||
@staticmethod
|
||||
def _format_token_count(token_count: int) -> str:
|
||||
"""格式化 token 数量展示文本。"""
|
||||
if token_count >= 10_000:
|
||||
return f"{token_count / 1000:.1f}k"
|
||||
return str(token_count)
|
||||
|
||||
def _render_context_usage_panel(
|
||||
self,
|
||||
*,
|
||||
selected_history_count: int,
|
||||
prompt_tokens: int,
|
||||
) -> None:
|
||||
"""在终端展示当前聊天流的上下文占用情况。"""
|
||||
if not global_config.maisaka.show_thinking:
|
||||
return
|
||||
|
||||
session_name = chat_manager.get_session_name(self.session_id) or self.session_id
|
||||
body = "\n".join(
|
||||
[
|
||||
f"聊天流: {session_name}",
|
||||
f"Chat ID: {self.session_id}",
|
||||
f"上下文占用: {selected_history_count}条 / {self._format_token_count(prompt_tokens)}",
|
||||
]
|
||||
)
|
||||
console.print(
|
||||
Panel(
|
||||
Text(body),
|
||||
title="MaiSaka 上下文占用",
|
||||
border_style="bright_blue",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
def _log_cycle_started(self, cycle_detail: CycleDetail, round_index: int) -> None:
|
||||
logger.info(
|
||||
f"{self.log_prefix} MaiSaka 轮次开始: 循环编号={cycle_detail.cycle_id} "
|
||||
|
||||
@@ -32,19 +32,6 @@ class ToolHandlerContext:
|
||||
self.last_user_input_time: Optional[datetime] = None
|
||||
|
||||
|
||||
async def handle_stop(tc: ToolCall, chat_history: list[LLMContextMessage]) -> None:
|
||||
"""处理 stop 工具。"""
|
||||
console.print("[accent]调用工具: stop()[/accent]")
|
||||
chat_history.append(
|
||||
ToolResultMessage(
|
||||
content="当前轮次结束后将停止对话循环。",
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tc.call_id,
|
||||
tool_name=tc.func_name,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def handle_wait(tc: ToolCall, chat_history: list[LLMContextMessage], ctx: ToolHandlerContext) -> str:
|
||||
"""处理 wait 工具。"""
|
||||
seconds = (tc.args or {}).get("seconds", 30)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Dict, Optional
|
||||
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
|
||||
|
||||
from .builtin_tools import get_builtin_tool_specs
|
||||
from .builtin_tool import get_builtin_tool_specs
|
||||
|
||||
BuiltinToolHandler = Callable[[ToolInvocation, Optional[ToolExecutionContext]], Awaitable[ToolExecutionResult]]
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import httpx
|
||||
from src.cli.console import console
|
||||
from src.core.tooling import ToolExecutionResult
|
||||
|
||||
from .config import MCPClientRuntimeConfig, MCPRootRuntimeConfig, MCPServerRuntimeConfig
|
||||
from .config import MCPClientRuntimeConfig, MCPServerRuntimeConfig
|
||||
from .hooks import MCPHostCallbacks
|
||||
from .models import (
|
||||
MCPPromptResult,
|
||||
|
||||
@@ -20,5 +20,8 @@ ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
|
||||
ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
|
||||
"""Runner 启动时可视为已满足的外部插件依赖版本映射(JSON 对象)"""
|
||||
|
||||
ENV_BLOCKED_PLUGIN_REASONS = "MAIBOT_BLOCKED_PLUGIN_REASONS"
|
||||
"""Runner 启动时收到的拒绝加载插件原因映射(JSON 对象)"""
|
||||
|
||||
ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
|
||||
"""Runner 启动时注入的全局配置快照(JSON 对象)"""
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from .components import RuntimeComponentCapabilityMixin
|
||||
from .core import RuntimeCoreCapabilityMixin
|
||||
from .data import RuntimeDataCapabilityMixin
|
||||
from .render import RuntimeRenderCapabilityMixin
|
||||
|
||||
__all__ = [
|
||||
"RuntimeComponentCapabilityMixin",
|
||||
"RuntimeCoreCapabilityMixin",
|
||||
"RuntimeDataCapabilityMixin",
|
||||
"RuntimeRenderCapabilityMixin",
|
||||
]
|
||||
|
||||
@@ -458,6 +458,17 @@ class RuntimeComponentCapabilityMixin:
|
||||
async def _cap_component_get_plugin_info(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""获取指定插件的基础信息。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前调用方插件 ID。
|
||||
capability: 当前能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 插件基础信息响应。
|
||||
"""
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
try:
|
||||
sv = self._get_supervisor_for_plugin(plugin_name)
|
||||
@@ -473,10 +484,46 @@ class RuntimeComponentCapabilityMixin:
|
||||
"description": "",
|
||||
"author": "",
|
||||
"enabled": True,
|
||||
"default_config": reg.default_config,
|
||||
"config_schema": reg.config_schema,
|
||||
},
|
||||
}
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
|
||||
async def _cap_component_get_plugin_config_schema(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""获取指定插件注册时上报的配置 Schema。
|
||||
|
||||
Args:
|
||||
plugin_id: 当前调用方插件 ID。
|
||||
capability: 当前能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 包含配置 Schema 与默认配置的响应。
|
||||
"""
|
||||
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
try:
|
||||
sv = self._get_supervisor_for_plugin(plugin_name)
|
||||
except RuntimeError as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
if sv is None:
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
|
||||
registration = sv._registered_plugins.get(plugin_name)
|
||||
if registration is None:
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"plugin_id": plugin_name,
|
||||
"schema": registration.config_schema,
|
||||
"default_config": registration.default_config,
|
||||
}
|
||||
|
||||
async def _cap_component_list_loaded_plugins(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
|
||||
@@ -81,6 +81,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
|
||||
|
||||
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
|
||||
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
|
||||
_register("component.get_plugin_config_schema", manager._cap_component_get_plugin_config_schema)
|
||||
_register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
|
||||
_register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
|
||||
_register("component.enable", manager._cap_component_enable)
|
||||
@@ -90,4 +91,5 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
|
||||
_register("component.reload_plugin", manager._cap_component_reload_plugin)
|
||||
|
||||
_register("knowledge.search", manager._cap_knowledge_search)
|
||||
_register("render.html2png", manager._cap_render_html2png)
|
||||
logger.debug("已注册全部主程序能力实现")
|
||||
|
||||
121
src/plugin_runtime/capabilities/render.py
Normal file
121
src/plugin_runtime/capabilities/render.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""插件运行时的浏览器渲染能力。"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services.html_render_service import HtmlRenderRequest, get_html_render_service
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
|
||||
|
||||
class RuntimeRenderCapabilityMixin:
|
||||
"""插件运行时的浏览器渲染能力混入。"""
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值尽量转换为整数。
|
||||
|
||||
Args:
|
||||
value: 原始输入值。
|
||||
default: 转换失败时返回的默认值。
|
||||
|
||||
Returns:
|
||||
int: 规范化后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _coerce_float(value: Any, default: float) -> float:
|
||||
"""将任意值尽量转换为浮点数。
|
||||
|
||||
Args:
|
||||
value: 原始输入值。
|
||||
default: 转换失败时返回的默认值。
|
||||
|
||||
Returns:
|
||||
float: 规范化后的浮点结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _coerce_bool(value: Any, default: bool = False) -> bool:
|
||||
"""将任意值转换为布尔值。
|
||||
|
||||
Args:
|
||||
value: 原始输入值。
|
||||
default: 输入为空时返回的默认值。
|
||||
|
||||
Returns:
|
||||
bool: 规范化后的布尔结果。
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
normalized_value = value.strip().lower()
|
||||
if normalized_value in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
if normalized_value in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
return bool(value)
|
||||
|
||||
def _build_html_render_request(self, args: Dict[str, Any]) -> HtmlRenderRequest:
|
||||
"""根据 capability 调用参数构造渲染请求。
|
||||
|
||||
Args:
|
||||
args: capability 调用参数。
|
||||
|
||||
Returns:
|
||||
HtmlRenderRequest: 结构化后的渲染请求。
|
||||
"""
|
||||
|
||||
viewport = args.get("viewport", {})
|
||||
viewport_width = 900
|
||||
viewport_height = 500
|
||||
if isinstance(viewport, dict):
|
||||
viewport_width = self._coerce_int(viewport.get("width"), viewport_width)
|
||||
viewport_height = self._coerce_int(viewport.get("height"), viewport_height)
|
||||
|
||||
return HtmlRenderRequest(
|
||||
html=str(args.get("html", "") or ""),
|
||||
selector=str(args.get("selector", "body") or "body"),
|
||||
viewport_width=viewport_width,
|
||||
viewport_height=viewport_height,
|
||||
device_scale_factor=self._coerce_float(args.get("device_scale_factor"), 2.0),
|
||||
full_page=self._coerce_bool(args.get("full_page"), False),
|
||||
omit_background=self._coerce_bool(args.get("omit_background"), False),
|
||||
wait_until=str(args.get("wait_until", "load") or "load"),
|
||||
wait_for_selector=str(args.get("wait_for_selector", "") or ""),
|
||||
wait_for_timeout_ms=self._coerce_int(args.get("wait_for_timeout_ms"), 0),
|
||||
timeout_ms=self._coerce_int(args.get("timeout_ms"), 0),
|
||||
allow_network=self._coerce_bool(args.get("allow_network"), False),
|
||||
)
|
||||
|
||||
async def _cap_render_html2png(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""将 HTML 内容渲染为 PNG 图片。
|
||||
|
||||
Args:
|
||||
plugin_id: 调用该能力的插件 ID。
|
||||
capability: 当前能力名称。
|
||||
args: 能力调用参数。
|
||||
|
||||
Returns:
|
||||
Any: 标准化后的能力返回结构。
|
||||
"""
|
||||
|
||||
del plugin_id, capability
|
||||
try:
|
||||
request = self._build_html_render_request(args)
|
||||
result = await get_html_render_service().render_html_to_png(request)
|
||||
return {"success": True, "result": result.to_payload()}
|
||||
except Exception as exc:
|
||||
logger.error(f"[cap.render.html2png] 执行失败: {exc}", exc_info=True)
|
||||
return {"success": False, "error": str(exc)}
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -858,5 +859,77 @@ class ComponentQueryService:
|
||||
logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True)
|
||||
return None
|
||||
|
||||
def get_plugin_default_config(self, plugin_name: str) -> Optional[dict]:
|
||||
"""获取指定插件注册时上报的默认配置。
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称。
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 默认配置字典;未找到时返回 ``None``。
|
||||
"""
|
||||
|
||||
runtime_manager = self._get_runtime_manager()
|
||||
try:
|
||||
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
|
||||
except RuntimeError as exc:
|
||||
logger.error(f"读取插件默认配置失败: {exc}")
|
||||
return None
|
||||
|
||||
if supervisor is None:
|
||||
return None
|
||||
|
||||
registration = supervisor._registered_plugins.get(plugin_name)
|
||||
if registration is None:
|
||||
return None
|
||||
return dict(registration.default_config)
|
||||
|
||||
def get_plugin_config_schema(self, plugin_name: str) -> Optional[dict]:
|
||||
"""获取指定插件注册时上报的配置 Schema。
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称。
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 配置 Schema;未找到时返回 ``None``。
|
||||
"""
|
||||
|
||||
runtime_manager = self._get_runtime_manager()
|
||||
try:
|
||||
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
|
||||
except RuntimeError as exc:
|
||||
logger.error(f"读取插件配置 Schema 失败: {exc}")
|
||||
return None
|
||||
|
||||
if supervisor is None:
|
||||
return None
|
||||
|
||||
registration = supervisor._registered_plugins.get(plugin_name)
|
||||
if registration is None:
|
||||
return None
|
||||
return dict(registration.config_schema)
|
||||
|
||||
def list_hook_specs(self) -> list[dict[str, Any]]:
|
||||
"""返回当前运行时公开的 Hook 规格清单。
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: 可直接序列化给 WebUI 的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
runtime_manager = self._get_runtime_manager()
|
||||
return [
|
||||
{
|
||||
"name": spec.name,
|
||||
"description": spec.description,
|
||||
"parameters_schema": deepcopy(spec.parameters_schema),
|
||||
"default_timeout_ms": spec.default_timeout_ms,
|
||||
"allow_blocking": spec.allow_blocking,
|
||||
"allow_observe": spec.allow_observe,
|
||||
"allow_abort": spec.allow_abort,
|
||||
"allow_kwargs_mutation": spec.allow_kwargs_mutation,
|
||||
}
|
||||
for spec in runtime_manager.list_hook_specs()
|
||||
]
|
||||
|
||||
|
||||
component_query_service = ComponentQueryService()
|
||||
|
||||
441
src/plugin_runtime/dependency_pipeline.py
Normal file
441
src/plugin_runtime/dependency_pipeline.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""插件 Python 依赖流水线。
|
||||
|
||||
负责在 Host 侧统一完成以下工作:
|
||||
1. 扫描插件 Manifest;
|
||||
2. 检测插件与主程序、插件与插件之间的 Python 依赖冲突;
|
||||
3. 为可加载插件自动安装缺失的 Python 依赖;
|
||||
4. 产出最终的拒绝加载列表,供运行时使用。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from packaging.utils import canonicalize_name
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator, PluginManifest
|
||||
|
||||
|
||||
logger = get_logger("plugin_runtime.dependency_pipeline")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PackageDependencyUsage:
|
||||
"""记录单个插件对某个 Python 包的依赖声明。"""
|
||||
|
||||
package_name: str
|
||||
plugin_id: str
|
||||
version_spec: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CombinedPackageRequirement:
|
||||
"""表示一个已经合并后的 Python 包安装需求。"""
|
||||
|
||||
package_name: str
|
||||
plugin_ids: Tuple[str, ...]
|
||||
requirement_text: str
|
||||
version_spec: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DependencyPipelinePlan:
|
||||
"""表示一次依赖分析后得到的计划。"""
|
||||
|
||||
blocked_plugin_reasons: Dict[str, str]
|
||||
install_requirements: Tuple[CombinedPackageRequirement, ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DependencyPipelineResult:
|
||||
"""表示一次依赖流水线执行后的结果。"""
|
||||
|
||||
blocked_plugin_reasons: Dict[str, str]
|
||||
environment_changed: bool
|
||||
install_requirements: Tuple[CombinedPackageRequirement, ...]
|
||||
|
||||
|
||||
class PluginDependencyPipeline:
|
||||
"""插件依赖流水线。
|
||||
|
||||
该类不负责插件启停,只负责对插件目录进行依赖分析,并在必要时
|
||||
使用 ``uv`` 为可加载插件补齐缺失的 Python 依赖。
|
||||
"""
|
||||
|
||||
def __init__(self, project_root: Optional[Path] = None) -> None:
|
||||
"""初始化依赖流水线。
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录;留空时自动推断。
|
||||
"""
|
||||
|
||||
self._project_root: Path = project_root or Path(__file__).resolve().parents[2]
|
||||
self._manifest_validator: ManifestValidator = ManifestValidator(
|
||||
project_root=self._project_root,
|
||||
validate_python_package_dependencies=False,
|
||||
)
|
||||
|
||||
async def execute(self, plugin_dirs: Iterable[Path]) -> DependencyPipelineResult:
|
||||
"""执行完整的依赖分析与自动安装流程。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 需要扫描的插件根目录集合。
|
||||
|
||||
Returns:
|
||||
DependencyPipelineResult: 最终的阻止加载结果与环境变更状态。
|
||||
"""
|
||||
|
||||
plan = self.build_plan(plugin_dirs)
|
||||
if not plan.install_requirements:
|
||||
return DependencyPipelineResult(
|
||||
blocked_plugin_reasons=dict(plan.blocked_plugin_reasons),
|
||||
environment_changed=False,
|
||||
install_requirements=plan.install_requirements,
|
||||
)
|
||||
|
||||
install_succeeded, error_message = await self._install_requirements(plan.install_requirements)
|
||||
if install_succeeded:
|
||||
return DependencyPipelineResult(
|
||||
blocked_plugin_reasons=dict(plan.blocked_plugin_reasons),
|
||||
environment_changed=True,
|
||||
install_requirements=plan.install_requirements,
|
||||
)
|
||||
|
||||
blocked_plugin_reasons = dict(plan.blocked_plugin_reasons)
|
||||
affected_plugin_ids = sorted(
|
||||
{
|
||||
plugin_id
|
||||
for requirement in plan.install_requirements
|
||||
for plugin_id in requirement.plugin_ids
|
||||
}
|
||||
)
|
||||
for plugin_id in affected_plugin_ids:
|
||||
self._append_block_reason(
|
||||
blocked_plugin_reasons,
|
||||
plugin_id,
|
||||
f"自动安装 Python 依赖失败: {error_message}",
|
||||
)
|
||||
|
||||
return DependencyPipelineResult(
|
||||
blocked_plugin_reasons=blocked_plugin_reasons,
|
||||
environment_changed=False,
|
||||
install_requirements=plan.install_requirements,
|
||||
)
|
||||
|
||||
def build_plan(self, plugin_dirs: Iterable[Path]) -> DependencyPipelinePlan:
|
||||
"""构建依赖分析计划。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 需要扫描的插件根目录集合。
|
||||
|
||||
Returns:
|
||||
DependencyPipelinePlan: 分析后的阻止加载列表与安装计划。
|
||||
"""
|
||||
|
||||
manifests = self._collect_manifests(plugin_dirs)
|
||||
blocked_plugin_reasons = self._detect_host_conflicts(manifests)
|
||||
plugin_conflict_reasons = self._detect_plugin_conflicts(manifests, blocked_plugin_reasons)
|
||||
for plugin_id, reason in plugin_conflict_reasons.items():
|
||||
self._append_block_reason(blocked_plugin_reasons, plugin_id, reason)
|
||||
|
||||
install_requirements = self._build_install_requirements(manifests, blocked_plugin_reasons)
|
||||
return DependencyPipelinePlan(
|
||||
blocked_plugin_reasons=blocked_plugin_reasons,
|
||||
install_requirements=install_requirements,
|
||||
)
|
||||
|
||||
def _collect_manifests(self, plugin_dirs: Iterable[Path]) -> Dict[str, PluginManifest]:
|
||||
"""收集所有可成功解析的插件 Manifest。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 需要扫描的插件根目录集合。
|
||||
|
||||
Returns:
|
||||
Dict[str, PluginManifest]: 以插件 ID 为键的 Manifest 映射。
|
||||
"""
|
||||
|
||||
manifests: Dict[str, PluginManifest] = {}
|
||||
for _plugin_path, manifest in self._manifest_validator.iter_plugin_manifests(plugin_dirs):
|
||||
manifests[manifest.id] = manifest
|
||||
return manifests
|
||||
|
||||
def _detect_host_conflicts(self, manifests: Dict[str, PluginManifest]) -> Dict[str, str]:
|
||||
"""检测插件与主程序依赖之间的冲突。
|
||||
|
||||
Args:
|
||||
manifests: 当前已解析到的插件 Manifest 映射。
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 需要被阻止加载的插件及原因。
|
||||
"""
|
||||
|
||||
host_requirements = self._manifest_validator.load_host_dependency_requirements()
|
||||
blocked_plugin_reasons: Dict[str, str] = {}
|
||||
|
||||
for manifest in manifests.values():
|
||||
for dependency in manifest.python_package_dependencies:
|
||||
package_specifier = self._manifest_validator.build_specifier_set(dependency.version_spec)
|
||||
if package_specifier is None:
|
||||
self._append_block_reason(
|
||||
blocked_plugin_reasons,
|
||||
manifest.id,
|
||||
f"Python 包依赖声明无效: {dependency.name}{dependency.version_spec}",
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_package_name = canonicalize_name(dependency.name)
|
||||
host_requirement = host_requirements.get(normalized_package_name)
|
||||
if host_requirement is None:
|
||||
continue
|
||||
|
||||
if self._manifest_validator.requirements_may_overlap(
|
||||
host_requirement.specifier,
|
||||
package_specifier,
|
||||
):
|
||||
continue
|
||||
|
||||
host_specifier_text = str(host_requirement.specifier or "") or "任意版本"
|
||||
self._append_block_reason(
|
||||
blocked_plugin_reasons,
|
||||
manifest.id,
|
||||
(
|
||||
f"Python 包依赖与主程序冲突: {dependency.name} 需要 "
|
||||
f"{dependency.version_spec},主程序约束为 {host_specifier_text}"
|
||||
),
|
||||
)
|
||||
|
||||
return blocked_plugin_reasons
|
||||
|
||||
def _detect_plugin_conflicts(
|
||||
self,
|
||||
manifests: Dict[str, PluginManifest],
|
||||
blocked_plugin_reasons: Dict[str, str],
|
||||
) -> Dict[str, str]:
|
||||
"""检测插件之间的 Python 依赖冲突。
|
||||
|
||||
Args:
|
||||
manifests: 当前已解析到的插件 Manifest 映射。
|
||||
blocked_plugin_reasons: 已经因为其他原因被阻止加载的插件。
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 新增的插件冲突原因映射。
|
||||
"""
|
||||
|
||||
blocked_by_plugin_conflicts: Dict[str, str] = {}
|
||||
dependency_usages = self._collect_package_usages(manifests, blocked_plugin_reasons)
|
||||
|
||||
for _package_name, usages in dependency_usages.items():
|
||||
display_package_name = usages[0].package_name
|
||||
for index, left_usage in enumerate(usages):
|
||||
for right_usage in usages[index + 1 :]:
|
||||
left_specifier = self._manifest_validator.build_specifier_set(left_usage.version_spec)
|
||||
right_specifier = self._manifest_validator.build_specifier_set(right_usage.version_spec)
|
||||
if left_specifier is None or right_specifier is None:
|
||||
continue
|
||||
|
||||
if self._manifest_validator.requirements_may_overlap(left_specifier, right_specifier):
|
||||
continue
|
||||
|
||||
left_reason = (
|
||||
f"Python 包依赖冲突: 与插件 {right_usage.plugin_id} 在 {display_package_name} 上的约束不兼容 "
|
||||
f"({left_usage.version_spec} vs {right_usage.version_spec})"
|
||||
)
|
||||
right_reason = (
|
||||
f"Python 包依赖冲突: 与插件 {left_usage.plugin_id} 在 {display_package_name} 上的约束不兼容 "
|
||||
f"({right_usage.version_spec} vs {left_usage.version_spec})"
|
||||
)
|
||||
self._append_block_reason(blocked_by_plugin_conflicts, left_usage.plugin_id, left_reason)
|
||||
self._append_block_reason(blocked_by_plugin_conflicts, right_usage.plugin_id, right_reason)
|
||||
|
||||
return blocked_by_plugin_conflicts
|
||||
|
||||
def _collect_package_usages(
|
||||
self,
|
||||
manifests: Dict[str, PluginManifest],
|
||||
blocked_plugin_reasons: Dict[str, str],
|
||||
) -> Dict[str, List[PackageDependencyUsage]]:
|
||||
"""收集所有未被阻止加载插件的包依赖声明。
|
||||
|
||||
Args:
|
||||
manifests: 当前已解析到的插件 Manifest 映射。
|
||||
blocked_plugin_reasons: 已经被阻止加载的插件及原因。
|
||||
|
||||
Returns:
|
||||
Dict[str, List[PackageDependencyUsage]]: 按规范化包名分组后的依赖声明。
|
||||
"""
|
||||
|
||||
dependency_usages: Dict[str, List[PackageDependencyUsage]] = {}
|
||||
for manifest in manifests.values():
|
||||
if manifest.id in blocked_plugin_reasons:
|
||||
continue
|
||||
|
||||
for dependency in manifest.python_package_dependencies:
|
||||
normalized_package_name = canonicalize_name(dependency.name)
|
||||
dependency_usages.setdefault(normalized_package_name, []).append(
|
||||
PackageDependencyUsage(
|
||||
package_name=dependency.name,
|
||||
plugin_id=manifest.id,
|
||||
version_spec=dependency.version_spec,
|
||||
)
|
||||
)
|
||||
|
||||
return dependency_usages
|
||||
|
||||
def _build_install_requirements(
|
||||
self,
|
||||
manifests: Dict[str, PluginManifest],
|
||||
blocked_plugin_reasons: Dict[str, str],
|
||||
) -> Tuple[CombinedPackageRequirement, ...]:
|
||||
"""构建需要安装到当前环境的 Python 包需求列表。
|
||||
|
||||
Args:
|
||||
manifests: 当前已解析到的插件 Manifest 映射。
|
||||
blocked_plugin_reasons: 已经被阻止加载的插件及原因。
|
||||
|
||||
Returns:
|
||||
Tuple[CombinedPackageRequirement, ...]: 需要安装或调整版本的依赖列表。
|
||||
"""
|
||||
|
||||
combined_requirements: List[CombinedPackageRequirement] = []
|
||||
dependency_usages = self._collect_package_usages(manifests, blocked_plugin_reasons)
|
||||
|
||||
for usages in dependency_usages.values():
|
||||
merged_specifier_text = self._merge_specifier_texts([usage.version_spec for usage in usages])
|
||||
package_name = usages[0].package_name
|
||||
requirement_text = f"{package_name}{merged_specifier_text}"
|
||||
installed_version = self._manifest_validator.get_installed_package_version(package_name)
|
||||
if installed_version is not None and self._manifest_validator.version_matches_specifier(
|
||||
installed_version,
|
||||
merged_specifier_text,
|
||||
):
|
||||
continue
|
||||
|
||||
combined_requirements.append(
|
||||
CombinedPackageRequirement(
|
||||
package_name=package_name,
|
||||
plugin_ids=tuple(sorted({usage.plugin_id for usage in usages})),
|
||||
requirement_text=requirement_text,
|
||||
version_spec=merged_specifier_text,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(sorted(combined_requirements, key=lambda requirement: canonicalize_name(requirement.package_name)))
|
||||
|
||||
@staticmethod
|
||||
def _merge_specifier_texts(specifier_texts: Sequence[str]) -> str:
|
||||
"""合并多个版本约束文本。
|
||||
|
||||
Args:
|
||||
specifier_texts: 需要合并的版本约束文本序列。
|
||||
|
||||
Returns:
|
||||
str: 合并后的版本约束文本。
|
||||
"""
|
||||
|
||||
merged_parts: List[str] = []
|
||||
for specifier_text in specifier_texts:
|
||||
for part in str(specifier_text or "").split(","):
|
||||
normalized_part = part.strip()
|
||||
if not normalized_part or normalized_part in merged_parts:
|
||||
continue
|
||||
merged_parts.append(normalized_part)
|
||||
return f"{','.join(merged_parts)}" if merged_parts else ""
|
||||
|
||||
async def _install_requirements(self, requirements: Sequence[CombinedPackageRequirement]) -> Tuple[bool, str]:
|
||||
"""安装指定的 Python 包需求列表。
|
||||
|
||||
Args:
|
||||
requirements: 需要安装的依赖列表。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: 安装是否成功,以及错误摘要。
|
||||
"""
|
||||
|
||||
requirement_texts = [requirement.requirement_text for requirement in requirements]
|
||||
if not requirement_texts:
|
||||
return True, ""
|
||||
|
||||
logger.info(f"开始自动安装插件 Python 依赖: {', '.join(requirement_texts)}")
|
||||
command = self._build_install_command(requirement_texts)
|
||||
|
||||
try:
|
||||
completed_process = await asyncio.to_thread(
|
||||
subprocess.run,
|
||||
command,
|
||||
capture_output=True,
|
||||
check=False,
|
||||
cwd=self._project_root,
|
||||
text=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
return False, str(exc)
|
||||
|
||||
if completed_process.returncode == 0:
|
||||
logger.info("插件 Python 依赖自动安装完成")
|
||||
return True, ""
|
||||
|
||||
output = self._summarize_install_error(completed_process.stdout, completed_process.stderr)
|
||||
return False, output or f"命令执行失败,退出码 {completed_process.returncode}"
|
||||
|
||||
@staticmethod
|
||||
def _build_install_command(requirement_texts: Sequence[str]) -> List[str]:
|
||||
"""构造依赖安装命令。
|
||||
|
||||
Args:
|
||||
requirement_texts: 待安装的依赖文本序列。
|
||||
|
||||
Returns:
|
||||
List[str]: 适用于 ``subprocess.run`` 的命令参数列表。
|
||||
"""
|
||||
|
||||
if shutil.which("uv"):
|
||||
return ["uv", "pip", "install", "--python", sys.executable, *requirement_texts]
|
||||
return [sys.executable, "-m", "pip", "install", *requirement_texts]
|
||||
|
||||
@staticmethod
|
||||
def _summarize_install_error(stdout: str, stderr: str) -> str:
|
||||
"""提炼安装失败输出。
|
||||
|
||||
Args:
|
||||
stdout: 标准输出内容。
|
||||
stderr: 标准错误内容。
|
||||
|
||||
Returns:
|
||||
str: 简短的错误摘要。
|
||||
"""
|
||||
|
||||
merged_output = "\n".join(part.strip() for part in (stderr, stdout) if part and part.strip()).strip()
|
||||
if not merged_output:
|
||||
return ""
|
||||
lines = [line.strip() for line in merged_output.splitlines() if line.strip()]
|
||||
return " | ".join(lines[-5:])
|
||||
|
||||
@staticmethod
|
||||
def _append_block_reason(
|
||||
blocked_plugin_reasons: Dict[str, str],
|
||||
plugin_id: str,
|
||||
reason: str,
|
||||
) -> None:
|
||||
"""向阻止加载映射中追加原因。
|
||||
|
||||
Args:
|
||||
blocked_plugin_reasons: 待更新的阻止加载映射。
|
||||
plugin_id: 目标插件 ID。
|
||||
reason: 需要追加的原因文本。
|
||||
"""
|
||||
|
||||
existing_reason = blocked_plugin_reasons.get(plugin_id)
|
||||
if existing_reason is None:
|
||||
blocked_plugin_reasons[plugin_id] = reason
|
||||
return
|
||||
|
||||
existing_parts = [part.strip() for part in existing_reason.split(";") if part.strip()]
|
||||
if reason in existing_parts:
|
||||
return
|
||||
blocked_plugin_reasons[plugin_id] = f"{existing_reason};{reason}"
|
||||
52
src/plugin_runtime/hook_catalog.py
Normal file
52
src/plugin_runtime/hook_catalog.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""内置命名 Hook 目录注册器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import List
|
||||
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
|
||||
HookSpecRegistrar = Callable[[HookSpecRegistry], List[HookSpec]]
|
||||
"""单个业务模块向注册中心写入 Hook 规格的注册器签名。"""
|
||||
|
||||
|
||||
def _get_builtin_hook_spec_registrars() -> List[HookSpecRegistrar]:
|
||||
"""返回当前内置 Hook 规格注册器列表。
|
||||
|
||||
Returns:
|
||||
List[HookSpecRegistrar]: 已启用的内置 Hook 注册器列表。
|
||||
"""
|
||||
|
||||
from src.chat.message_receive.bot import register_chat_hook_specs
|
||||
from src.chat.emoji_system.emoji_manager import register_emoji_hook_specs
|
||||
from src.learners.expression_learner import register_expression_hook_specs
|
||||
from src.learners.jargon_miner import register_jargon_hook_specs
|
||||
from src.maisaka.chat_loop_service import register_maisaka_hook_specs
|
||||
from src.services.send_service import register_send_service_hook_specs
|
||||
|
||||
return [
|
||||
register_chat_hook_specs,
|
||||
register_emoji_hook_specs,
|
||||
register_jargon_hook_specs,
|
||||
register_expression_hook_specs,
|
||||
register_send_service_hook_specs,
|
||||
register_maisaka_hook_specs,
|
||||
]
|
||||
|
||||
|
||||
def register_builtin_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""向注册中心写入全部内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 本次完成注册后的全部内置 Hook 规格。
|
||||
"""
|
||||
|
||||
registered_specs: List[HookSpec] = []
|
||||
for registrar in _get_builtin_hook_spec_registrars():
|
||||
registered_specs.extend(registrar(registry))
|
||||
return registered_specs
|
||||
178
src/plugin_runtime/hook_payloads.py
Normal file
178
src/plugin_runtime/hook_payloads.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""运行时 Hook 载荷序列化辅助。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.llm_service_data_models import PromptMessage
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, normalize_tool_options
|
||||
from src.plugin_runtime.host.message_utils import PluginMessageUtils
|
||||
|
||||
|
||||
def serialize_session_message(message: SessionMessage) -> Dict[str, Any]:
|
||||
"""将会话消息序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
message: 待序列化的会话消息。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 可通过插件运行时传输的消息字典。
|
||||
"""
|
||||
|
||||
return dict(PluginMessageUtils._session_message_to_dict(message))
|
||||
|
||||
|
||||
def deserialize_session_message(raw_message: Any) -> SessionMessage:
|
||||
"""从 Hook 载荷恢复会话消息。
|
||||
|
||||
Args:
|
||||
raw_message: Hook 返回的消息字典。
|
||||
|
||||
Returns:
|
||||
SessionMessage: 恢复后的会话消息对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 消息结构不合法时抛出。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_message, dict):
|
||||
raise ValueError("Hook 返回的 `message` 必须是字典")
|
||||
return PluginMessageUtils._build_session_message_from_dict(raw_message)
|
||||
|
||||
|
||||
def serialize_tool_calls(tool_calls: Sequence[ToolCall] | None) -> List[Dict[str, Any]]:
|
||||
"""将工具调用列表序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
tool_calls: 原始工具调用列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 序列化后的工具调用列表。
|
||||
"""
|
||||
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"id": tool_call.call_id,
|
||||
"function": {
|
||||
"name": tool_call.func_name,
|
||||
"arguments": dict(tool_call.args or {}),
|
||||
},
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def deserialize_tool_calls(raw_tool_calls: Any) -> List[ToolCall]:
|
||||
"""从 Hook 载荷恢复工具调用列表。
|
||||
|
||||
Args:
|
||||
raw_tool_calls: Hook 返回的工具调用列表。
|
||||
|
||||
Returns:
|
||||
List[ToolCall]: 恢复后的工具调用列表。
|
||||
|
||||
Raises:
|
||||
ValueError: 结构不合法时抛出。
|
||||
"""
|
||||
|
||||
if raw_tool_calls in (None, []):
|
||||
return []
|
||||
if not isinstance(raw_tool_calls, list):
|
||||
raise ValueError("Hook 返回的 `tool_calls` 必须是列表")
|
||||
|
||||
normalized_tool_calls: List[ToolCall] = []
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
raise ValueError("Hook 返回的工具调用项必须是字典")
|
||||
|
||||
function_info = raw_tool_call.get("function", {})
|
||||
if isinstance(function_info, dict):
|
||||
function_name = function_info.get("name")
|
||||
function_arguments = function_info.get("arguments")
|
||||
else:
|
||||
function_name = raw_tool_call.get("name")
|
||||
function_arguments = raw_tool_call.get("arguments")
|
||||
|
||||
call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
|
||||
if not isinstance(call_id, str) or not isinstance(function_name, str):
|
||||
raise ValueError("Hook 返回的工具调用缺少 `id` 或函数名称")
|
||||
|
||||
normalized_tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
func_name=function_name,
|
||||
args=function_arguments if isinstance(function_arguments, dict) else {},
|
||||
)
|
||||
)
|
||||
return normalized_tool_calls
|
||||
|
||||
|
||||
def serialize_prompt_messages(messages: Sequence[Message]) -> List[PromptMessage]:
|
||||
"""将 LLM 消息列表序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
messages: 原始 LLM 消息列表。
|
||||
|
||||
Returns:
|
||||
List[PromptMessage]: 序列化后的消息字典列表。
|
||||
"""
|
||||
|
||||
serialized_messages: List[PromptMessage] = []
|
||||
for message in messages:
|
||||
serialized_message: PromptMessage = {
|
||||
"role": message.role.value,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.tool_call_id:
|
||||
serialized_message["tool_call_id"] = message.tool_call_id
|
||||
if message.tool_calls:
|
||||
serialized_message["tool_calls"] = serialize_tool_calls(message.tool_calls)
|
||||
serialized_messages.append(serialized_message)
|
||||
return serialized_messages
|
||||
|
||||
|
||||
def deserialize_prompt_messages(raw_messages: Any) -> List[Message]:
|
||||
"""从 Hook 载荷恢复 LLM 消息列表。
|
||||
|
||||
Args:
|
||||
raw_messages: Hook 返回的消息列表。
|
||||
|
||||
Returns:
|
||||
List[Message]: 恢复后的 LLM 消息列表。
|
||||
|
||||
Raises:
|
||||
ValueError: 结构不合法时抛出。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_messages, list):
|
||||
raise ValueError("Hook 返回的 `messages` 必须是列表")
|
||||
|
||||
from src.services.llm_service import _build_message_from_dict
|
||||
|
||||
normalized_messages: List[Message] = []
|
||||
for raw_message in raw_messages:
|
||||
if not isinstance(raw_message, dict):
|
||||
raise ValueError("Hook 返回的消息项必须是字典")
|
||||
normalized_messages.append(_build_message_from_dict(raw_message))
|
||||
return normalized_messages
|
||||
|
||||
|
||||
def serialize_tool_definitions(tool_definitions: Sequence[ToolDefinitionInput]) -> List[Dict[str, Any]]:
|
||||
"""将工具定义列表序列化为 Hook 可传输载荷。
|
||||
|
||||
Args:
|
||||
tool_definitions: 原始工具定义列表。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 序列化后的工具定义列表。
|
||||
"""
|
||||
|
||||
normalized_tool_options = normalize_tool_options(list(tool_definitions))
|
||||
if not normalized_tool_options:
|
||||
return []
|
||||
return [tool_option.to_openai_function_schema() for tool_option in normalized_tool_options]
|
||||
31
src/plugin_runtime/hook_schema_utils.py
Normal file
31
src/plugin_runtime/hook_schema_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Hook 参数模型构造辅助。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
|
||||
def build_object_schema(
|
||||
properties: Dict[str, Dict[str, Any]],
|
||||
*,
|
||||
required: Sequence[str] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构造对象级 JSON Schema。
|
||||
|
||||
Args:
|
||||
properties: 字段定义映射。
|
||||
required: 必填字段名列表。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 标准化后的对象级 Schema。
|
||||
"""
|
||||
|
||||
schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": deepcopy(properties),
|
||||
}
|
||||
normalized_required = [str(item).strip() for item in (required or []) if str(item).strip()]
|
||||
if normalized_required:
|
||||
schema["required"] = normalized_required
|
||||
return schema
|
||||
@@ -18,9 +18,37 @@ import re
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import build_tool_detailed_description
|
||||
|
||||
from .hook_spec_registry import HookSpecRegistry
|
||||
|
||||
logger = get_logger("plugin_runtime.host.component_registry")
|
||||
|
||||
|
||||
class ComponentRegistrationError(ValueError):
|
||||
"""组件注册失败异常。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
component_name: str = "",
|
||||
component_type: str = "",
|
||||
plugin_id: str = "",
|
||||
) -> None:
|
||||
"""初始化组件注册失败异常。
|
||||
|
||||
Args:
|
||||
message: 原始错误信息。
|
||||
component_name: 组件名称。
|
||||
component_type: 组件类型。
|
||||
plugin_id: 插件 ID。
|
||||
"""
|
||||
|
||||
self.component_name = str(component_name or "").strip()
|
||||
self.component_type = str(component_type or "").strip()
|
||||
self.plugin_id = str(plugin_id or "").strip()
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ComponentTypes(str, Enum):
|
||||
ACTION = "ACTION"
|
||||
COMMAND = "COMMAND"
|
||||
@@ -359,7 +387,14 @@ class ComponentRegistry:
|
||||
供业务层查询可用组件、匹配命令、调度 action/event 等。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, hook_spec_registry: Optional[HookSpecRegistry] = None) -> None:
|
||||
"""初始化组件注册表。
|
||||
|
||||
Args:
|
||||
hook_spec_registry: 可选的 Hook 规格注册中心;提供后会在注册
|
||||
HookHandler 时执行规格校验。
|
||||
"""
|
||||
|
||||
# 全量索引
|
||||
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
|
||||
|
||||
@@ -370,6 +405,7 @@ class ComponentRegistry:
|
||||
|
||||
# 按插件索引
|
||||
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
|
||||
self._hook_spec_registry = hook_spec_registry
|
||||
|
||||
@staticmethod
|
||||
def _convert_action_metadata_to_tool_metadata(
|
||||
@@ -475,77 +511,211 @@ class ComponentRegistry:
|
||||
type_dict.clear()
|
||||
self._by_plugin.clear()
|
||||
|
||||
# ====== 注册 / 注销 ======
|
||||
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""注册单个组件
|
||||
@staticmethod
|
||||
def _is_legacy_action_component(component: ComponentEntry) -> bool:
|
||||
"""判断组件是否为兼容旧 Action 的 Tool 条目。
|
||||
|
||||
Args:
|
||||
name: 组件名称(不含插件id前缀)
|
||||
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
|
||||
plugin_id: 插件id
|
||||
metadata: 组件元数据
|
||||
component: 待判断的组件条目。
|
||||
|
||||
Returns:
|
||||
success (bool): 是否成功注册(失败原因通常是组件类型无效)
|
||||
bool: 是否为兼容旧 Action 组件。
|
||||
"""
|
||||
|
||||
if not isinstance(component, ToolEntry):
|
||||
return False
|
||||
return str(component.metadata.get("legacy_component_type", "") or "").strip().upper() == "ACTION"
|
||||
|
||||
def _validate_hook_handler_entry(self, component: HookHandlerEntry) -> None:
|
||||
"""校验 HookHandler 是否满足已注册的 Hook 规格。
|
||||
|
||||
Args:
|
||||
component: 待校验的 HookHandler 条目。
|
||||
|
||||
Raises:
|
||||
ComponentRegistrationError: HookHandler 声明不合法时抛出。
|
||||
"""
|
||||
|
||||
if self._hook_spec_registry is None:
|
||||
return
|
||||
|
||||
hook_spec = self._hook_spec_registry.get_hook_spec(component.hook)
|
||||
if hook_spec is None:
|
||||
raise ComponentRegistrationError(
|
||||
f"HookHandler {component.full_name} 声明了未注册的 Hook: {component.hook}",
|
||||
component_name=component.name,
|
||||
component_type=component.component_type.value,
|
||||
plugin_id=component.plugin_id,
|
||||
)
|
||||
|
||||
if component.is_blocking and not hook_spec.allow_blocking:
|
||||
raise ComponentRegistrationError(
|
||||
f"HookHandler {component.full_name} 不能注册为 blocking:Hook {component.hook} 不允许 blocking 处理器",
|
||||
component_name=component.name,
|
||||
component_type=component.component_type.value,
|
||||
plugin_id=component.plugin_id,
|
||||
)
|
||||
|
||||
if component.is_observe and not hook_spec.allow_observe:
|
||||
raise ComponentRegistrationError(
|
||||
f"HookHandler {component.full_name} 不能注册为 observe:Hook {component.hook} 不允许 observe 处理器",
|
||||
component_name=component.name,
|
||||
component_type=component.component_type.value,
|
||||
plugin_id=component.plugin_id,
|
||||
)
|
||||
|
||||
if component.error_policy == "abort" and not hook_spec.allow_abort:
|
||||
raise ComponentRegistrationError(
|
||||
f"HookHandler {component.full_name} 不能使用 error_policy=abort:Hook {component.hook} 不允许 abort",
|
||||
component_name=component.name,
|
||||
component_type=component.component_type.value,
|
||||
plugin_id=component.plugin_id,
|
||||
)
|
||||
|
||||
def _build_component_entry(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
) -> ComponentEntry:
|
||||
"""根据声明构造组件条目。
|
||||
|
||||
Args:
|
||||
name: 组件名称。
|
||||
component_type: 组件类型。
|
||||
plugin_id: 插件 ID。
|
||||
metadata: 组件元数据。
|
||||
|
||||
Returns:
|
||||
ComponentEntry: 已构造并完成校验的组件条目。
|
||||
|
||||
Raises:
|
||||
ComponentRegistrationError: 组件声明不合法时抛出。
|
||||
"""
|
||||
|
||||
try:
|
||||
normalized_type = self._normalize_component_type(component_type)
|
||||
normalized_metadata = dict(metadata)
|
||||
if normalized_type == ComponentTypes.ACTION:
|
||||
normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
|
||||
comp = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
|
||||
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.COMMAND:
|
||||
comp = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.TOOL:
|
||||
comp = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.EVENT_HANDLER:
|
||||
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
elif normalized_type == ComponentTypes.HOOK_HANDLER:
|
||||
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
self._validate_hook_handler_entry(component)
|
||||
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
|
||||
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
|
||||
else:
|
||||
raise ValueError(f"组件类型 {component_type} 不存在")
|
||||
except ValueError:
|
||||
logger.error(f"组件类型 {component_type} 不存在")
|
||||
return False
|
||||
raise ComponentRegistrationError(
|
||||
f"组件类型 {component_type} 不存在",
|
||||
component_name=name,
|
||||
component_type=component_type,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
except ComponentRegistrationError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise ComponentRegistrationError(
|
||||
str(exc),
|
||||
component_name=name,
|
||||
component_type=component_type,
|
||||
plugin_id=plugin_id,
|
||||
) from exc
|
||||
|
||||
if comp.full_name in self._components:
|
||||
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
|
||||
old_comp = self._components[comp.full_name]
|
||||
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
|
||||
old_list = self._by_plugin.get(old_comp.plugin_id)
|
||||
if old_list is not None:
|
||||
with contextlib.suppress(ValueError):
|
||||
old_list.remove(old_comp)
|
||||
# 从旧类型索引中移除,防止类型变更时幽灵残留
|
||||
if old_type_dict := self._by_type.get(old_comp.component_type):
|
||||
old_type_dict.pop(comp.full_name, None)
|
||||
return component
|
||||
|
||||
self._components[comp.full_name] = comp
|
||||
self._by_type[comp.component_type][comp.full_name] = comp
|
||||
self._by_plugin.setdefault(plugin_id, []).append(comp)
|
||||
def _remove_existing_component_entry(self, component: ComponentEntry) -> None:
|
||||
"""移除同名旧组件条目。
|
||||
|
||||
Args:
|
||||
component: 即将写入的新组件条目。
|
||||
"""
|
||||
|
||||
if component.full_name not in self._components:
|
||||
return
|
||||
|
||||
logger.warning(f"组件 {component.full_name} 已存在,覆盖")
|
||||
old_component = self._components[component.full_name]
|
||||
old_list = self._by_plugin.get(old_component.plugin_id)
|
||||
if old_list is not None:
|
||||
with contextlib.suppress(ValueError):
|
||||
old_list.remove(old_component)
|
||||
if old_type_dict := self._by_type.get(old_component.component_type):
|
||||
old_type_dict.pop(component.full_name, None)
|
||||
|
||||
def _add_component_entry(self, component: ComponentEntry) -> None:
|
||||
"""写入单个组件条目到全部索引。
|
||||
|
||||
Args:
|
||||
component: 待写入的组件条目。
|
||||
"""
|
||||
|
||||
self._remove_existing_component_entry(component)
|
||||
self._components[component.full_name] = component
|
||||
self._by_type[component.component_type][component.full_name] = component
|
||||
self._by_plugin.setdefault(component.plugin_id, []).append(component)
|
||||
|
||||
# ====== 注册 / 注销 ======
|
||||
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""注册单个组件。
|
||||
|
||||
Args:
|
||||
name: 组件名称(不含插件 ID 前缀)。
|
||||
component_type: 组件类型(如 ``ACTION``、``COMMAND`` 等)。
|
||||
plugin_id: 插件 ID。
|
||||
metadata: 组件元数据。
|
||||
|
||||
Returns:
|
||||
bool: 注册成功时恒为 ``True``。
|
||||
|
||||
Raises:
|
||||
ComponentRegistrationError: 组件声明不合法时抛出。
|
||||
"""
|
||||
|
||||
component = self._build_component_entry(name, component_type, plugin_id, metadata)
|
||||
self._add_component_entry(component)
|
||||
return True
|
||||
|
||||
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
|
||||
"""批量注册一个插件的所有组件,返回成功注册数。
|
||||
"""批量替换一个插件的组件集合。
|
||||
|
||||
该方法会先完整校验所有组件声明,只有全部通过后才会替换旧组件,
|
||||
从而避免插件进入半注册状态。
|
||||
|
||||
Args:
|
||||
plugin_id (str): 插件id
|
||||
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
|
||||
plugin_id: 插件 ID。
|
||||
components: 组件声明字典列表。
|
||||
|
||||
Returns:
|
||||
count (int): 成功注册的组件数量
|
||||
int: 实际注册的组件数量。
|
||||
|
||||
Raises:
|
||||
ComponentRegistrationError: 任一组件声明不合法时抛出。
|
||||
"""
|
||||
count = 0
|
||||
for comp_data in components:
|
||||
ok = self.register_component(
|
||||
name=comp_data.get("name", ""),
|
||||
component_type=comp_data.get("component_type", ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=comp_data.get("metadata", {}),
|
||||
|
||||
prepared_components: List[ComponentEntry] = []
|
||||
for component_data in components:
|
||||
prepared_components.append(
|
||||
self._build_component_entry(
|
||||
name=str(component_data.get("name", "") or ""),
|
||||
component_type=str(component_data.get("component_type", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=component_data.get("metadata", {})
|
||||
if isinstance(component_data.get("metadata"), dict)
|
||||
else {},
|
||||
)
|
||||
)
|
||||
if ok:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
self.remove_components_by_plugin(plugin_id)
|
||||
for component in prepared_components:
|
||||
self._add_component_entry(component)
|
||||
return len(prepared_components)
|
||||
|
||||
def remove_components_by_plugin(self, plugin_id: str) -> int:
|
||||
"""移除某个插件的所有组件,返回移除数量。
|
||||
@@ -652,6 +822,17 @@ class ComponentRegistry:
|
||||
except ValueError:
|
||||
logger.error(f"组件类型 {component_type} 不存在")
|
||||
raise
|
||||
|
||||
if comp_type == ComponentTypes.ACTION:
|
||||
action_components = [
|
||||
component
|
||||
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
|
||||
if self._is_legacy_action_component(component)
|
||||
]
|
||||
if enabled_only:
|
||||
return [component for component in action_components if self.check_component_enabled(component, session_id)]
|
||||
return action_components
|
||||
|
||||
type_dict = self._by_type.get(comp_type, {})
|
||||
if enabled_only:
|
||||
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
|
||||
@@ -854,6 +1035,34 @@ class ComponentRegistry:
|
||||
tools.append(comp)
|
||||
return tools
|
||||
|
||||
def get_tools_for_llm(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""兼容旧接口,返回可供 LLM 使用的工具条目列表。
|
||||
|
||||
Args:
|
||||
enabled_only: 是否仅返回启用的组件。
|
||||
session_id: 可选的会话 ID,若提供则考虑会话禁用状态。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 兼容旧结构的工具组件字典列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
{
|
||||
"name": tool.full_name,
|
||||
"description": tool.description,
|
||||
"parameters": (
|
||||
dict(tool.parameters_raw)
|
||||
if isinstance(tool.parameters_raw, dict) and tool.parameters_raw
|
||||
else tool._get_parameters_schema() or {}
|
||||
),
|
||||
"parameters_raw": tool.parameters_raw,
|
||||
"enabled": tool.enabled,
|
||||
"plugin_id": tool.plugin_id,
|
||||
}
|
||||
for tool in self.get_tools(enabled_only=enabled_only, session_id=session_id)
|
||||
if not self._is_legacy_action_component(tool)
|
||||
]
|
||||
|
||||
# ====== 统计信息 ======
|
||||
def get_stats(self) -> StatusDict:
|
||||
"""获取注册统计。
|
||||
@@ -863,9 +1072,21 @@ class ComponentRegistry:
|
||||
"""
|
||||
return StatusDict(
|
||||
total=len(self._components),
|
||||
action=len(self._by_type[ComponentTypes.ACTION]),
|
||||
action=len(
|
||||
[
|
||||
component
|
||||
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
|
||||
if self._is_legacy_action_component(component)
|
||||
]
|
||||
),
|
||||
command=len(self._by_type[ComponentTypes.COMMAND]),
|
||||
tool=len(self._by_type[ComponentTypes.TOOL]),
|
||||
tool=len(
|
||||
[
|
||||
component
|
||||
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
|
||||
if not self._is_legacy_action_component(component)
|
||||
]
|
||||
),
|
||||
event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]),
|
||||
hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]),
|
||||
message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]),
|
||||
|
||||
@@ -26,6 +26,8 @@ import contextlib
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
from .hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .component_registry import HookHandlerEntry
|
||||
from .supervisor import PluginRunnerSupervisor
|
||||
@@ -33,29 +35,6 @@ if TYPE_CHECKING:
|
||||
logger = get_logger("plugin_runtime.host.hook_dispatcher")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HookSpec:
|
||||
"""命名 Hook 的静态规格定义。
|
||||
|
||||
Attributes:
|
||||
name: Hook 的唯一名称。
|
||||
description: Hook 描述。
|
||||
default_timeout_ms: 默认超时毫秒数;为 `0` 时退回系统默认值。
|
||||
allow_blocking: 是否允许注册阻塞处理器。
|
||||
allow_observe: 是否允许注册观察处理器。
|
||||
allow_abort: 是否允许处理器中止当前 Hook 调用。
|
||||
allow_kwargs_mutation: 是否允许阻塞处理器修改 `kwargs`。
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
default_timeout_ms: int = 0
|
||||
allow_blocking: bool = True
|
||||
allow_observe: bool = True
|
||||
allow_abort: bool = True
|
||||
allow_kwargs_mutation: bool = True
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HookHandlerExecutionResult:
|
||||
"""单个 HookHandler 的执行结果。
|
||||
@@ -121,17 +100,19 @@ class HookDispatcher:
|
||||
def __init__(
|
||||
self,
|
||||
supervisors_provider: Optional[Callable[[], Sequence["PluginRunnerSupervisor"]]] = None,
|
||||
hook_spec_registry: Optional[HookSpecRegistry] = None,
|
||||
) -> None:
|
||||
"""初始化 Hook 分发器。
|
||||
|
||||
Args:
|
||||
supervisors_provider: 可选的 Supervisor 提供器。若调用 `invoke_hook()`
|
||||
时未显式传入 `supervisors`,则使用该回调获取目标 Supervisor 列表。
|
||||
hook_spec_registry: 可选的 Hook 规格注册中心;留空时使用独立注册中心。
|
||||
"""
|
||||
|
||||
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
||||
self._hook_specs: Dict[str, HookSpec] = {}
|
||||
self._supervisors_provider = supervisors_provider
|
||||
self._hook_spec_registry = hook_spec_registry or HookSpecRegistry()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止分发器并取消所有未完成的观察任务。"""
|
||||
@@ -148,16 +129,7 @@ class HookDispatcher:
|
||||
spec: 需要注册的 Hook 规格。
|
||||
"""
|
||||
|
||||
normalized_name = self._normalize_hook_name(spec.name)
|
||||
self._hook_specs[normalized_name] = HookSpec(
|
||||
name=normalized_name,
|
||||
description=spec.description,
|
||||
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
|
||||
allow_blocking=bool(spec.allow_blocking),
|
||||
allow_observe=bool(spec.allow_observe),
|
||||
allow_abort=bool(spec.allow_abort),
|
||||
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
|
||||
)
|
||||
self._hook_spec_registry.register_hook_spec(spec)
|
||||
|
||||
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
|
||||
"""批量注册命名 Hook 规格。
|
||||
@@ -180,14 +152,37 @@ class HookDispatcher:
|
||||
"""
|
||||
|
||||
normalized_name = self._normalize_hook_name(hook_name)
|
||||
if normalized_name in self._hook_specs:
|
||||
return self._hook_specs[normalized_name]
|
||||
registered_spec = self._hook_spec_registry.get_hook_spec(normalized_name)
|
||||
if registered_spec is not None:
|
||||
return registered_spec
|
||||
|
||||
return HookSpec(
|
||||
name=normalized_name,
|
||||
parameters_schema={},
|
||||
default_timeout_ms=self._get_default_timeout_ms(),
|
||||
)
|
||||
|
||||
def unregister_hook_spec(self, hook_name: str) -> bool:
|
||||
"""注销指定命名 Hook 规格。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
|
||||
Returns:
|
||||
bool: 是否成功注销。
|
||||
"""
|
||||
|
||||
return self._hook_spec_registry.unregister_hook_spec(hook_name)
|
||||
|
||||
def list_hook_specs(self) -> List[HookSpec]:
|
||||
"""返回当前全部显式注册的 Hook 规格。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 已注册 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return self._hook_spec_registry.list_hook_specs()
|
||||
|
||||
async def invoke_hook(
|
||||
self,
|
||||
hook_name: str,
|
||||
|
||||
190
src/plugin_runtime/host/hook_spec_registry.py
Normal file
190
src/plugin_runtime/host/hook_spec_registry.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""命名 Hook 规格注册中心。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HookSpec:
|
||||
"""命名 Hook 的静态规格定义。
|
||||
|
||||
Attributes:
|
||||
name: Hook 的唯一名称。
|
||||
description: Hook 描述。
|
||||
parameters_schema: Hook 参数模型,使用对象级 JSON Schema 表示。
|
||||
default_timeout_ms: 默认超时毫秒数;为 ``0`` 时退回系统默认值。
|
||||
allow_blocking: 是否允许注册阻塞处理器。
|
||||
allow_observe: 是否允许注册观察处理器。
|
||||
allow_abort: 是否允许处理器中止当前 Hook 调用。
|
||||
allow_kwargs_mutation: 是否允许阻塞处理器修改 ``kwargs``。
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
parameters_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
default_timeout_ms: int = 0
|
||||
allow_blocking: bool = True
|
||||
allow_observe: bool = True
|
||||
allow_abort: bool = True
|
||||
allow_kwargs_mutation: bool = True
|
||||
|
||||
|
||||
class HookSpecRegistry:
|
||||
"""命名 Hook 规格注册中心。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化 Hook 规格注册中心。"""
|
||||
|
||||
self._hook_specs: Dict[str, HookSpec] = {}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hook_name(hook_name: str) -> str:
|
||||
"""规范化 Hook 名称。
|
||||
|
||||
Args:
|
||||
hook_name: 原始 Hook 名称。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的 Hook 名称。
|
||||
|
||||
Raises:
|
||||
ValueError: Hook 名称为空时抛出。
|
||||
"""
|
||||
|
||||
normalized_name = str(hook_name or "").strip()
|
||||
if not normalized_name:
|
||||
raise ValueError("Hook 名称不能为空")
|
||||
return normalized_name
|
||||
|
||||
@staticmethod
|
||||
def _normalize_parameters_schema(raw_schema: Any) -> Dict[str, Any]:
|
||||
"""规范化 Hook 参数模型。
|
||||
|
||||
Args:
|
||||
raw_schema: 原始参数模型。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 规范化后的对象级 JSON Schema。
|
||||
|
||||
Raises:
|
||||
ValueError: 参数模型不是合法对象级 Schema 时抛出。
|
||||
"""
|
||||
|
||||
if raw_schema is None:
|
||||
return {}
|
||||
if not isinstance(raw_schema, dict):
|
||||
raise ValueError("Hook 参数模型必须是字典")
|
||||
if not raw_schema:
|
||||
return {}
|
||||
|
||||
normalized_schema = deepcopy(raw_schema)
|
||||
schema_type = normalized_schema.get("type")
|
||||
properties = normalized_schema.get("properties")
|
||||
if schema_type not in {"", None, "object"} and properties is None:
|
||||
raise ValueError("Hook 参数模型必须是 object 类型或属性映射")
|
||||
if schema_type in {"", None} and properties is None:
|
||||
normalized_schema = {
|
||||
"type": "object",
|
||||
"properties": normalized_schema,
|
||||
}
|
||||
elif schema_type in {"", None}:
|
||||
normalized_schema["type"] = "object"
|
||||
|
||||
if normalized_schema.get("type") != "object":
|
||||
raise ValueError("Hook 参数模型必须是 object 类型")
|
||||
return normalized_schema
|
||||
|
||||
@classmethod
|
||||
def _normalize_spec(cls, spec: HookSpec) -> HookSpec:
|
||||
"""规范化 Hook 规格对象。
|
||||
|
||||
Args:
|
||||
spec: 原始 Hook 规格。
|
||||
|
||||
Returns:
|
||||
HookSpec: 规范化后的 Hook 规格副本。
|
||||
"""
|
||||
|
||||
return HookSpec(
|
||||
name=cls._normalize_hook_name(spec.name),
|
||||
description=str(spec.description or "").strip(),
|
||||
parameters_schema=cls._normalize_parameters_schema(spec.parameters_schema),
|
||||
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
|
||||
allow_blocking=bool(spec.allow_blocking),
|
||||
allow_observe=bool(spec.allow_observe),
|
||||
allow_abort=bool(spec.allow_abort),
|
||||
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空全部 Hook 规格。"""
|
||||
|
||||
self._hook_specs.clear()
|
||||
|
||||
def register_hook_spec(self, spec: HookSpec) -> HookSpec:
|
||||
"""注册单个 Hook 规格。
|
||||
|
||||
Args:
|
||||
spec: 需要注册的 Hook 规格。
|
||||
|
||||
Returns:
|
||||
HookSpec: 规范化后实际注册的 Hook 规格。
|
||||
"""
|
||||
|
||||
normalized_spec = self._normalize_spec(spec)
|
||||
self._hook_specs[normalized_spec.name] = normalized_spec
|
||||
return normalized_spec
|
||||
|
||||
def register_hook_specs(self, specs: Sequence[HookSpec]) -> List[HookSpec]:
|
||||
"""批量注册 Hook 规格。
|
||||
|
||||
Args:
|
||||
specs: 需要注册的 Hook 规格列表。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 规范化后实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return [self.register_hook_spec(spec) for spec in specs]
|
||||
|
||||
def unregister_hook_spec(self, hook_name: str) -> bool:
|
||||
"""注销指定 Hook 规格。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
|
||||
Returns:
|
||||
bool: 是否成功删除。
|
||||
"""
|
||||
|
||||
normalized_name = self._normalize_hook_name(hook_name)
|
||||
return self._hook_specs.pop(normalized_name, None) is not None
|
||||
|
||||
def get_hook_spec(self, hook_name: str) -> Optional[HookSpec]:
|
||||
"""获取指定 Hook 的显式规格。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
|
||||
Returns:
|
||||
Optional[HookSpec]: 已注册时返回规格副本,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
normalized_name = self._normalize_hook_name(hook_name)
|
||||
spec = self._hook_specs.get(normalized_name)
|
||||
return None if spec is None else self._normalize_spec(spec)
|
||||
|
||||
def list_hook_specs(self) -> List[HookSpec]:
|
||||
"""返回当前全部 Hook 规格。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 按 Hook 名称升序排列的规格副本列表。
|
||||
"""
|
||||
|
||||
return [
|
||||
self._normalize_spec(spec)
|
||||
for _, spec in sorted(self._hook_specs.items(), key=lambda item: item[0])
|
||||
]
|
||||
@@ -14,6 +14,7 @@ from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, Ro
|
||||
from src.platform_io.drivers import PluginPlatformDriver
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
from src.plugin_runtime import (
|
||||
ENV_BLOCKED_PLUGIN_REASONS,
|
||||
ENV_EXTERNAL_PLUGIN_IDS,
|
||||
ENV_GLOBAL_CONFIG_SNAPSHOT,
|
||||
ENV_HOST_VERSION,
|
||||
@@ -27,6 +28,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
ConfigUpdatedPayload,
|
||||
Envelope,
|
||||
HealthPayload,
|
||||
InspectPluginConfigPayload,
|
||||
InspectPluginConfigResultPayload,
|
||||
MessageGatewayStateUpdatePayload,
|
||||
MessageGatewayStateUpdateResultPayload,
|
||||
PROTOCOL_VERSION,
|
||||
@@ -39,6 +42,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
RunnerReadyPayload,
|
||||
ShutdownPayload,
|
||||
UnregisterPluginPayload,
|
||||
ValidatePluginConfigPayload,
|
||||
ValidatePluginConfigResultPayload,
|
||||
)
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
@@ -50,6 +55,7 @@ from .capability_service import CapabilityService
|
||||
from .component_registry import ComponentRegistry
|
||||
from .event_dispatcher import EventDispatcher
|
||||
from .hook_dispatcher import HookDispatchResult, HookDispatcher
|
||||
from .hook_spec_registry import HookSpecRegistry
|
||||
from .logger_bridge import RunnerLogBridge
|
||||
from .message_gateway import MessageGateway
|
||||
from .rpc_server import RPCServer
|
||||
@@ -59,6 +65,7 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("plugin_runtime.host.runner_manager")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _MessageGatewayRuntimeState:
|
||||
"""保存消息网关当前的运行时连接状态。"""
|
||||
@@ -81,6 +88,7 @@ class PluginRunnerSupervisor:
|
||||
self,
|
||||
plugin_dirs: Optional[List[Path]] = None,
|
||||
group_name: str = "third_party",
|
||||
hook_spec_registry: Optional[HookSpecRegistry] = None,
|
||||
socket_path: Optional[str] = None,
|
||||
health_check_interval_sec: Optional[float] = None,
|
||||
max_restart_attempts: Optional[int] = None,
|
||||
@@ -91,6 +99,7 @@ class PluginRunnerSupervisor:
|
||||
Args:
|
||||
plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
|
||||
group_name: 当前 Supervisor 所属运行时分组名称。
|
||||
hook_spec_registry: 可选的共享 Hook 规格注册中心。
|
||||
socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
|
||||
health_check_interval_sec: 健康检查间隔,单位秒。
|
||||
max_restart_attempts: 自动重启 Runner 的最大次数。
|
||||
@@ -100,18 +109,19 @@ class PluginRunnerSupervisor:
|
||||
self._group_name: str = str(group_name or "third_party").strip() or "third_party"
|
||||
self._plugin_dirs: List[Path] = plugin_dirs or []
|
||||
self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0
|
||||
self._runner_spawn_timeout: float = (
|
||||
runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0
|
||||
)
|
||||
self._runner_spawn_timeout: float = runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0
|
||||
self._max_restart_attempts: int = max_restart_attempts or runtime_config.max_restart_attempts or 3
|
||||
|
||||
self._transport = create_transport_server(socket_path=socket_path)
|
||||
self._authorization = AuthorizationManager()
|
||||
self._capability_service = CapabilityService(self._authorization)
|
||||
self._api_registry = APIRegistry()
|
||||
self._component_registry = ComponentRegistry()
|
||||
self._component_registry = ComponentRegistry(hook_spec_registry=hook_spec_registry)
|
||||
self._event_dispatcher = EventDispatcher(self._component_registry)
|
||||
self._hook_dispatcher = HookDispatcher(lambda: [self])
|
||||
self._hook_dispatcher = HookDispatcher(
|
||||
lambda: [self],
|
||||
hook_spec_registry=hook_spec_registry,
|
||||
)
|
||||
self._message_gateway = MessageGateway(self._component_registry)
|
||||
self._log_bridge = RunnerLogBridge()
|
||||
|
||||
@@ -122,6 +132,7 @@ class PluginRunnerSupervisor:
|
||||
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
|
||||
self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
|
||||
self._external_available_plugins: Dict[str, str] = {}
|
||||
self._blocked_plugin_reasons: Dict[str, str] = {}
|
||||
self._runner_ready_events: asyncio.Event = asyncio.Event()
|
||||
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
|
||||
self._health_task: Optional[asyncio.Task[None]] = None
|
||||
@@ -200,9 +211,19 @@ class PluginRunnerSupervisor:
|
||||
Returns:
|
||||
Dict[str, str]: 已注册插件版本映射,键为插件 ID,值为插件版本。
|
||||
"""
|
||||
return {
|
||||
plugin_id: registration.plugin_version
|
||||
for plugin_id, registration in self._registered_plugins.items()
|
||||
return {plugin_id: registration.plugin_version for plugin_id, registration in self._registered_plugins.items()}
|
||||
|
||||
def set_blocked_plugin_reasons(self, blocked_plugin_reasons: Dict[str, str]) -> None:
|
||||
"""设置当前 Runner 启动时应拒绝加载的插件列表。
|
||||
|
||||
Args:
|
||||
blocked_plugin_reasons: 需要拒绝加载的插件及原因映射。
|
||||
"""
|
||||
|
||||
self._blocked_plugin_reasons = {
|
||||
str(plugin_id or "").strip(): str(reason or "").strip()
|
||||
for plugin_id, reason in blocked_plugin_reasons.items()
|
||||
if str(plugin_id or "").strip() and str(reason or "").strip()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -550,6 +571,82 @@ class PluginRunnerSupervisor:
|
||||
|
||||
return bool(response.payload.get("acknowledged", False))
|
||||
|
||||
async def validate_plugin_config(self, plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""请求 Runner 使用插件自身配置模型校验配置。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
config_data: 待校验的配置内容。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 插件模型归一化后的配置字典。
|
||||
|
||||
Raises:
|
||||
ValueError: 插件拒绝该配置或校验失败时抛出。
|
||||
"""
|
||||
|
||||
payload = ValidatePluginConfigPayload(config_data=config_data)
|
||||
try:
|
||||
response = await self._rpc_server.send_request(
|
||||
"plugin.validate_config",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
timeout_ms=10000,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"插件配置校验请求失败: {exc}") from exc
|
||||
|
||||
if response.error:
|
||||
raise ValueError(str(response.error.get("message", "插件配置校验失败")))
|
||||
|
||||
result = ValidatePluginConfigResultPayload.model_validate(response.payload)
|
||||
if not result.success:
|
||||
raise ValueError("插件配置校验失败")
|
||||
return dict(result.normalized_config)
|
||||
|
||||
async def inspect_plugin_config(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
use_provided_config: bool = False,
|
||||
) -> InspectPluginConfigResultPayload:
|
||||
"""请求 Runner 解析插件配置元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
config_data: 可选的配置内容。
|
||||
use_provided_config: 是否优先使用传入配置而不是磁盘配置。
|
||||
|
||||
Returns:
|
||||
InspectPluginConfigResultPayload: 插件配置解析结果。
|
||||
|
||||
Raises:
|
||||
ValueError: Runner 无法解析插件或返回了错误响应时抛出。
|
||||
"""
|
||||
|
||||
payload = InspectPluginConfigPayload(
|
||||
config_data=config_data or {},
|
||||
use_provided_config=use_provided_config,
|
||||
)
|
||||
try:
|
||||
response = await self._rpc_server.send_request(
|
||||
"plugin.inspect_config",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
timeout_ms=10000,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"插件配置解析请求失败: {exc}") from exc
|
||||
|
||||
if response.error:
|
||||
raise ValueError(str(response.error.get("message", "插件配置解析失败")))
|
||||
|
||||
result = InspectPluginConfigResultPayload.model_validate(response.payload)
|
||||
if not result.success:
|
||||
raise ValueError("插件配置解析失败")
|
||||
return result
|
||||
|
||||
def get_config_reload_subscribers(self, scope: str) -> List[str]:
|
||||
"""返回订阅指定全局配置广播的插件列表。
|
||||
|
||||
@@ -608,6 +705,7 @@ class PluginRunnerSupervisor:
|
||||
Raises:
|
||||
TimeoutError: 在超时时间内 Runner 未完成初始化。
|
||||
"""
|
||||
|
||||
async def wait_for_ready() -> RunnerReadyPayload:
|
||||
"""轮询等待 Runner 上报就绪。"""
|
||||
while True:
|
||||
@@ -681,15 +779,25 @@ class PluginRunnerSupervisor:
|
||||
|
||||
component_declarations = [component.model_dump() for component in payload.components]
|
||||
runtime_components, api_components = self._split_component_declarations(component_declarations)
|
||||
self._component_registry.remove_components_by_plugin(payload.plugin_id)
|
||||
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
||||
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
||||
try:
|
||||
registered_count = self._component_registry.register_plugin_components(
|
||||
payload.plugin_id,
|
||||
runtime_components,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {payload.plugin_id} 组件注册失败: {exc}")
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_BAD_PAYLOAD.value,
|
||||
str(exc),
|
||||
details={
|
||||
"plugin_id": payload.plugin_id,
|
||||
"component_count": len(runtime_components),
|
||||
},
|
||||
)
|
||||
|
||||
registered_count = self._component_registry.register_plugin_components(
|
||||
payload.plugin_id,
|
||||
runtime_components,
|
||||
)
|
||||
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
||||
registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
|
||||
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
||||
self._registered_plugins[payload.plugin_id] = payload
|
||||
self._message_gateway_states[payload.plugin_id] = {}
|
||||
|
||||
@@ -1058,7 +1166,9 @@ class PluginRunnerSupervisor:
|
||||
route_key = RouteKey(platform=platform)
|
||||
|
||||
route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata)
|
||||
account_id = route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None
|
||||
account_id = (
|
||||
route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None
|
||||
)
|
||||
scope = route_key.scope or route_scope or runtime_state.scope or gateway_entry.scope or None
|
||||
return RouteKey(
|
||||
platform=platform,
|
||||
@@ -1208,6 +1318,7 @@ class PluginRunnerSupervisor:
|
||||
global_config_snapshot = config_manager.get_global_config().model_dump(mode="json")
|
||||
global_config_snapshot["model"] = config_manager.get_model_config().model_dump(mode="json")
|
||||
return {
|
||||
ENV_BLOCKED_PLUGIN_REASONS: json.dumps(self._blocked_plugin_reasons, ensure_ascii=False),
|
||||
ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False),
|
||||
ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
|
||||
ENV_HOST_VERSION: PROTOCOL_VERSION,
|
||||
|
||||
@@ -8,10 +8,25 @@
|
||||
5. 提供统一的能力实现注册接口,使插件可以调用主程序功能
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
import tomlkit
|
||||
|
||||
@@ -23,10 +38,15 @@ from src.plugin_runtime.capabilities import (
|
||||
RuntimeComponentCapabilityMixin,
|
||||
RuntimeCoreCapabilityMixin,
|
||||
RuntimeDataCapabilityMixin,
|
||||
RuntimeRenderCapabilityMixin,
|
||||
)
|
||||
from src.plugin_runtime.capabilities.registry import register_capability_impls
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher, HookSpec
|
||||
from src.plugin_runtime.dependency_pipeline import PluginDependencyPipeline
|
||||
from src.plugin_runtime.hook_catalog import register_builtin_hook_specs
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
|
||||
from src.plugin_runtime.protocol.envelope import InspectPluginConfigResultPayload
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -50,10 +70,19 @@ _EVENT_TYPE_MAP: Dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DependencySyncState:
|
||||
"""表示一次插件依赖同步后的状态。"""
|
||||
|
||||
blocked_changed_plugin_ids: Set[str]
|
||||
environment_changed: bool
|
||||
|
||||
|
||||
class PluginRuntimeManager(
|
||||
RuntimeCoreCapabilityMixin,
|
||||
RuntimeDataCapabilityMixin,
|
||||
RuntimeComponentCapabilityMixin,
|
||||
RuntimeRenderCapabilityMixin,
|
||||
):
|
||||
"""插件运行时管理器(单例)
|
||||
|
||||
@@ -71,10 +100,17 @@ class PluginRuntimeManager(
|
||||
self._plugin_source_watcher_subscription_id: Optional[str] = None
|
||||
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
|
||||
self._plugin_path_cache: Dict[str, Path] = {}
|
||||
self._manifest_validator: ManifestValidator = ManifestValidator()
|
||||
self._manifest_validator: ManifestValidator = ManifestValidator(validate_python_package_dependencies=False)
|
||||
self._plugin_dependency_pipeline: PluginDependencyPipeline = PluginDependencyPipeline()
|
||||
self._blocked_plugin_reasons: Dict[str, str] = {}
|
||||
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
|
||||
self._config_reload_callback_registered: bool = False
|
||||
self._hook_dispatcher: HookDispatcher = HookDispatcher(lambda: self.supervisors)
|
||||
self._hook_spec_registry: HookSpecRegistry = HookSpecRegistry()
|
||||
self._builtin_hook_specs_registered: bool = False
|
||||
self._hook_dispatcher: HookDispatcher = HookDispatcher(
|
||||
lambda: self.supervisors,
|
||||
hook_spec_registry=self._hook_spec_registry,
|
||||
)
|
||||
|
||||
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
|
||||
"""接收 Platform IO 审核后的入站消息并送入主消息链。
|
||||
@@ -109,7 +145,7 @@ class PluginRuntimeManager(
|
||||
@classmethod
|
||||
def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
|
||||
"""扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
|
||||
validator = ManifestValidator()
|
||||
validator = ManifestValidator(validate_python_package_dependencies=False)
|
||||
return validator.build_plugin_dependency_map(plugin_dirs)
|
||||
|
||||
@classmethod
|
||||
@@ -142,6 +178,233 @@ class PluginRuntimeManager(
|
||||
return ["third_party", "builtin"]
|
||||
return ["builtin", "third_party"]
|
||||
|
||||
@staticmethod
|
||||
def _instantiate_supervisor(supervisor_cls: Any, **kwargs: Any) -> Any:
|
||||
"""兼容不同构造签名地实例化 Supervisor。
|
||||
|
||||
Args:
|
||||
supervisor_cls: 目标 Supervisor 类。
|
||||
**kwargs: 期望传入的构造参数。
|
||||
|
||||
Returns:
|
||||
Any: 实例化后的 Supervisor。
|
||||
"""
|
||||
|
||||
signature = inspect.signature(supervisor_cls)
|
||||
accepts_var_keyword = any(
|
||||
parameter.kind == inspect.Parameter.VAR_KEYWORD
|
||||
for parameter in signature.parameters.values()
|
||||
)
|
||||
if accepts_var_keyword:
|
||||
return supervisor_cls(**kwargs)
|
||||
|
||||
supported_kwargs = {
|
||||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if key in signature.parameters
|
||||
}
|
||||
return supervisor_cls(**supported_kwargs)
|
||||
|
||||
def _resolve_runtime_plugin_dirs(self) -> Tuple[List[Path], List[Path]]:
|
||||
"""解析当前运行时应管理的插件根目录。
|
||||
|
||||
Returns:
|
||||
Tuple[List[Path], List[Path]]: 内置插件目录列表与第三方插件目录列表。
|
||||
"""
|
||||
|
||||
return self._get_builtin_plugin_dirs(), self._get_third_party_plugin_dirs()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_supervisor_socket_paths() -> Tuple[Optional[str], Optional[str]]:
|
||||
"""解析内置与第三方 Supervisor 的 IPC 地址。
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[str], Optional[str]]: 内置 Runner 与第三方 Runner 的 socket 地址。
|
||||
"""
|
||||
|
||||
runtime_config = config_manager.get_global_config().plugin_runtime
|
||||
socket_path_base = runtime_config.ipc_socket_path or None
|
||||
builtin_socket = f"{socket_path_base}-builtin" if socket_path_base else None
|
||||
third_party_socket = f"{socket_path_base}-third_party" if socket_path_base else None
|
||||
return builtin_socket, third_party_socket
|
||||
|
||||
def _apply_blocked_plugin_reasons_to_supervisors(self) -> None:
|
||||
"""将当前阻止加载插件列表同步到全部 Supervisor。"""
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
set_blocked_plugin_reasons = getattr(supervisor, "set_blocked_plugin_reasons", None)
|
||||
if callable(set_blocked_plugin_reasons):
|
||||
set_blocked_plugin_reasons(self._blocked_plugin_reasons)
|
||||
|
||||
def _set_blocked_plugin_reasons(self, blocked_plugin_reasons: Dict[str, str]) -> Set[str]:
|
||||
"""更新 Host 侧维护的阻止加载插件列表。
|
||||
|
||||
Args:
|
||||
blocked_plugin_reasons: 最新的阻止加载插件及原因映射。
|
||||
|
||||
Returns:
|
||||
Set[str]: 本次发生状态变化的插件 ID 集合。
|
||||
"""
|
||||
|
||||
normalized_reasons = {
|
||||
str(plugin_id or "").strip(): str(reason or "").strip()
|
||||
for plugin_id, reason in blocked_plugin_reasons.items()
|
||||
if str(plugin_id or "").strip() and str(reason or "").strip()
|
||||
}
|
||||
changed_plugin_ids = {
|
||||
plugin_id
|
||||
for plugin_id in set(self._blocked_plugin_reasons) | set(normalized_reasons)
|
||||
if self._blocked_plugin_reasons.get(plugin_id) != normalized_reasons.get(plugin_id)
|
||||
}
|
||||
self._blocked_plugin_reasons = normalized_reasons
|
||||
self._apply_blocked_plugin_reasons_to_supervisors()
|
||||
return changed_plugin_ids
|
||||
|
||||
async def _sync_plugin_dependencies(self, plugin_dirs: Sequence[Path]) -> DependencySyncState:
|
||||
"""执行插件依赖同步,并刷新阻止加载插件列表。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 当前需要参与分析的插件根目录列表。
|
||||
|
||||
Returns:
|
||||
DependencySyncState: 同步后的环境变更状态与阻止列表变化集合。
|
||||
"""
|
||||
|
||||
result = await self._plugin_dependency_pipeline.execute(plugin_dirs)
|
||||
changed_plugin_ids = self._set_blocked_plugin_reasons(result.blocked_plugin_reasons)
|
||||
return DependencySyncState(
|
||||
blocked_changed_plugin_ids=changed_plugin_ids,
|
||||
environment_changed=result.environment_changed,
|
||||
)
|
||||
|
||||
def _build_supervisors(self, builtin_dirs: Sequence[Path], third_party_dirs: Sequence[Path]) -> None:
|
||||
"""根据目录列表创建当前运行时所需的 Supervisor。
|
||||
|
||||
Args:
|
||||
builtin_dirs: 内置插件目录列表。
|
||||
third_party_dirs: 第三方插件目录列表。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
builtin_socket, third_party_socket = self._resolve_supervisor_socket_paths()
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
|
||||
if builtin_dirs:
|
||||
builtin_supervisor = self._instantiate_supervisor(
|
||||
PluginSupervisor,
|
||||
plugin_dirs=list(builtin_dirs),
|
||||
group_name="builtin",
|
||||
hook_spec_registry=self._hook_spec_registry,
|
||||
socket_path=builtin_socket,
|
||||
)
|
||||
self._builtin_supervisor = builtin_supervisor
|
||||
self._register_capability_impls(builtin_supervisor)
|
||||
|
||||
if third_party_dirs:
|
||||
third_party_supervisor = self._instantiate_supervisor(
|
||||
PluginSupervisor,
|
||||
plugin_dirs=list(third_party_dirs),
|
||||
group_name="third_party",
|
||||
hook_spec_registry=self._hook_spec_registry,
|
||||
socket_path=third_party_socket,
|
||||
)
|
||||
self._third_party_supervisor = third_party_supervisor
|
||||
self._register_capability_impls(third_party_supervisor)
|
||||
|
||||
self._apply_blocked_plugin_reasons_to_supervisors()
|
||||
|
||||
async def _start_supervisors(
|
||||
self,
|
||||
builtin_dirs: Sequence[Path],
|
||||
third_party_dirs: Sequence[Path],
|
||||
) -> List["PluginSupervisor"]:
|
||||
"""按依赖顺序启动当前已创建的 Supervisor。
|
||||
|
||||
Args:
|
||||
builtin_dirs: 内置插件目录列表。
|
||||
third_party_dirs: 第三方插件目录列表。
|
||||
|
||||
Returns:
|
||||
List[PluginSupervisor]: 成功启动的 Supervisor 列表。
|
||||
"""
|
||||
|
||||
started_supervisors: List["PluginSupervisor"] = []
|
||||
supervisor_groups: Dict[str, Optional["PluginSupervisor"]] = {
|
||||
"builtin": self._builtin_supervisor,
|
||||
"third_party": self._third_party_supervisor,
|
||||
}
|
||||
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
|
||||
|
||||
try:
|
||||
for group_name in start_order:
|
||||
supervisor = supervisor_groups.get(group_name)
|
||||
if supervisor is None:
|
||||
continue
|
||||
|
||||
external_plugin_versions = {
|
||||
plugin_id: plugin_version
|
||||
for started_supervisor in started_supervisors
|
||||
for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
|
||||
}
|
||||
supervisor.set_external_available_plugins(external_plugin_versions)
|
||||
set_blocked_plugin_reasons = getattr(supervisor, "set_blocked_plugin_reasons", None)
|
||||
if callable(set_blocked_plugin_reasons):
|
||||
set_blocked_plugin_reasons(self._blocked_plugin_reasons)
|
||||
await supervisor.start()
|
||||
started_supervisors.append(supervisor)
|
||||
except Exception:
|
||||
await asyncio.gather(*(supervisor.stop() for supervisor in started_supervisors), return_exceptions=True)
|
||||
raise
|
||||
|
||||
return started_supervisors
|
||||
|
||||
async def _stop_supervisors(self) -> None:
|
||||
"""停止当前全部 Supervisor。"""
|
||||
|
||||
supervisors = self.supervisors
|
||||
if not supervisors:
|
||||
return
|
||||
|
||||
await asyncio.gather(*(supervisor.stop() for supervisor in supervisors), return_exceptions=True)
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
|
||||
async def _restart_supervisors(self, reason: str) -> bool:
|
||||
"""重启当前全部 Supervisor。
|
||||
|
||||
Args:
|
||||
reason: 本次重启的原因。
|
||||
|
||||
Returns:
|
||||
bool: 是否重启成功。
|
||||
"""
|
||||
|
||||
builtin_dirs, third_party_dirs = self._resolve_runtime_plugin_dirs()
|
||||
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(builtin_dirs + third_party_dirs):
|
||||
details = "; ".join(
|
||||
f"{plugin_id}: {', '.join(str(path) for path in paths)}"
|
||||
for plugin_id, paths in sorted(duplicate_plugin_ids.items())
|
||||
)
|
||||
logger.error(f"检测到重复插件 ID,拒绝执行 Supervisor 重启: {details}")
|
||||
return False
|
||||
|
||||
logger.info(f"开始重启插件运行时 Supervisor: {reason}")
|
||||
await self._stop_supervisors()
|
||||
self._build_supervisors(builtin_dirs, third_party_dirs)
|
||||
|
||||
try:
|
||||
await self._start_supervisors(builtin_dirs, third_party_dirs)
|
||||
except Exception as exc:
|
||||
logger.error(f"重启插件运行时 Supervisor 失败: {exc}", exc_info=True)
|
||||
await self._stop_supervisors()
|
||||
return False
|
||||
|
||||
self._refresh_plugin_config_watch_subscriptions()
|
||||
logger.info(f"插件运行时 Supervisor 已重启完成: {reason}")
|
||||
return True
|
||||
|
||||
# ─── 生命周期 ─────────────────────────────────────────────
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -155,10 +418,7 @@ class PluginRuntimeManager(
|
||||
logger.info("插件运行时已在配置中禁用,跳过启动")
|
||||
return
|
||||
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
builtin_dirs = self._get_builtin_plugin_dirs()
|
||||
third_party_dirs = self._get_third_party_plugin_dirs()
|
||||
builtin_dirs, third_party_dirs = self._resolve_runtime_plugin_dirs()
|
||||
|
||||
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(builtin_dirs + third_party_dirs):
|
||||
details = "; ".join(
|
||||
@@ -172,56 +432,19 @@ class PluginRuntimeManager(
|
||||
logger.info("未找到任何插件目录,跳过插件运行时启动")
|
||||
return
|
||||
|
||||
dependency_sync_state = await self._sync_plugin_dependencies(builtin_dirs + third_party_dirs)
|
||||
if dependency_sync_state.environment_changed:
|
||||
logger.info("插件依赖流水线已更新当前 Python 环境,启动时将直接加载最新环境")
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
self._build_supervisors(builtin_dirs, third_party_dirs)
|
||||
|
||||
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
|
||||
socket_path_base = _cfg.ipc_socket_path or None
|
||||
|
||||
# 当用户指定了自定义路径时,为两个 Supervisor 添加后缀以避免 UDS 冲突
|
||||
builtin_socket = f"{socket_path_base}-builtin" if socket_path_base else None
|
||||
third_party_socket = f"{socket_path_base}-third_party" if socket_path_base else None
|
||||
|
||||
# 创建两个 Supervisor,各自拥有独立的 socket / Runner 子进程
|
||||
if builtin_dirs:
|
||||
self._builtin_supervisor = PluginSupervisor(
|
||||
plugin_dirs=builtin_dirs,
|
||||
group_name="builtin",
|
||||
socket_path=builtin_socket,
|
||||
)
|
||||
self._register_capability_impls(self._builtin_supervisor)
|
||||
|
||||
if third_party_dirs:
|
||||
self._third_party_supervisor = PluginSupervisor(
|
||||
plugin_dirs=third_party_dirs,
|
||||
group_name="third_party",
|
||||
socket_path=third_party_socket,
|
||||
)
|
||||
self._register_capability_impls(self._third_party_supervisor)
|
||||
|
||||
started_supervisors: List[PluginSupervisor] = []
|
||||
started_supervisors: List["PluginSupervisor"] = []
|
||||
try:
|
||||
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
|
||||
await platform_io_manager.ensure_send_pipeline_ready()
|
||||
|
||||
supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
|
||||
"builtin": self._builtin_supervisor,
|
||||
"third_party": self._third_party_supervisor,
|
||||
}
|
||||
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
|
||||
|
||||
for group_name in start_order:
|
||||
supervisor = supervisor_groups.get(group_name)
|
||||
if supervisor is None:
|
||||
continue
|
||||
|
||||
external_plugin_versions = {
|
||||
plugin_id: plugin_version
|
||||
for started_supervisor in started_supervisors
|
||||
for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
|
||||
}
|
||||
supervisor.set_external_available_plugins(external_plugin_versions)
|
||||
await supervisor.start()
|
||||
started_supervisors.append(supervisor)
|
||||
started_supervisors = await self._start_supervisors(builtin_dirs, third_party_dirs)
|
||||
|
||||
await self._start_plugin_file_watcher()
|
||||
config_manager.register_reload_callback(self._config_reload_callback)
|
||||
@@ -315,6 +538,7 @@ class PluginRuntimeManager(
|
||||
spec: 需要注册的 Hook 规格。
|
||||
"""
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
self._hook_dispatcher.register_hook_spec(spec)
|
||||
|
||||
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
|
||||
@@ -324,8 +548,41 @@ class PluginRuntimeManager(
|
||||
specs: 需要注册的 Hook 规格序列。
|
||||
"""
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
self._hook_dispatcher.register_hook_specs(specs)
|
||||
|
||||
def unregister_hook_spec(self, hook_name: str) -> bool:
|
||||
"""注销指定命名 Hook 规格。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
|
||||
Returns:
|
||||
bool: 是否成功注销。
|
||||
"""
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
return self._hook_dispatcher.unregister_hook_spec(hook_name)
|
||||
|
||||
def list_hook_specs(self) -> List[HookSpec]:
|
||||
"""返回当前全部命名 Hook 规格。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 当前已注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
self.ensure_builtin_hook_specs_registered()
|
||||
return self._hook_dispatcher.list_hook_specs()
|
||||
|
||||
def ensure_builtin_hook_specs_registered(self) -> None:
|
||||
"""确保内置 Hook 规格已经注册到共享中心表。"""
|
||||
|
||||
if self._builtin_hook_specs_registered:
|
||||
return
|
||||
|
||||
register_builtin_hook_specs(self._hook_spec_registry)
|
||||
self._builtin_hook_specs_registered = True
|
||||
|
||||
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
|
||||
"""根据当前已注册插件构建全局依赖图。"""
|
||||
|
||||
@@ -364,9 +621,7 @@ class PluginRuntimeManager(
|
||||
"""构建当前已注册插件到所属 Supervisor 的映射。"""
|
||||
|
||||
return {
|
||||
plugin_id: supervisor
|
||||
for supervisor in self.supervisors
|
||||
for plugin_id in supervisor.get_loaded_plugin_ids()
|
||||
plugin_id: supervisor for supervisor in self.supervisors for plugin_id in supervisor.get_loaded_plugin_ids()
|
||||
}
|
||||
|
||||
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
|
||||
@@ -411,9 +666,7 @@ class PluginRuntimeManager(
|
||||
local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
|
||||
local_dependency_map = {
|
||||
plugin_id: {
|
||||
dependency
|
||||
for dependency in dependency_map.get(plugin_id, set())
|
||||
if dependency in local_plugin_ids
|
||||
dependency for dependency in dependency_map.get(plugin_id, set()) if dependency in local_plugin_ids
|
||||
}
|
||||
for plugin_id in local_plugin_ids
|
||||
}
|
||||
@@ -440,13 +693,26 @@ class PluginRuntimeManager(
|
||||
"""
|
||||
|
||||
normalized_plugin_ids = [
|
||||
normalized_plugin_id
|
||||
for plugin_id in plugin_ids
|
||||
if (normalized_plugin_id := str(plugin_id or "").strip())
|
||||
normalized_plugin_id for plugin_id in plugin_ids if (normalized_plugin_id := str(plugin_id or "").strip())
|
||||
]
|
||||
if not normalized_plugin_ids:
|
||||
return True
|
||||
|
||||
blocked_plugin_ids = [plugin_id for plugin_id in normalized_plugin_ids if plugin_id in self._blocked_plugin_reasons]
|
||||
if blocked_plugin_ids:
|
||||
logger.warning(
|
||||
"以下插件当前被依赖流水线阻止加载,已拒绝重载请求: "
|
||||
+ ", ".join(
|
||||
f"{plugin_id} ({self._blocked_plugin_reasons[plugin_id]})"
|
||||
for plugin_id in sorted(blocked_plugin_ids)
|
||||
)
|
||||
)
|
||||
normalized_plugin_ids = [
|
||||
plugin_id for plugin_id in normalized_plugin_ids if plugin_id not in self._blocked_plugin_reasons
|
||||
]
|
||||
if not normalized_plugin_ids:
|
||||
return False
|
||||
|
||||
dependency_map = self._build_registered_dependency_map()
|
||||
supervisor_by_plugin = self._build_registered_supervisor_map()
|
||||
supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
|
||||
@@ -518,9 +784,7 @@ class PluginRuntimeManager(
|
||||
return False
|
||||
|
||||
config_payload = (
|
||||
config_data
|
||||
if config_data is not None
|
||||
else self._load_plugin_config_for_supervisor(sv, plugin_id)
|
||||
config_data if config_data is not None else self._load_plugin_config_for_supervisor(sv, plugin_id)
|
||||
)
|
||||
return await sv.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
@@ -529,6 +793,91 @@ class PluginRuntimeManager(
|
||||
config_scope=config_scope,
|
||||
)
|
||||
|
||||
async def validate_plugin_config(self, plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
"""请求运行时按插件自身配置模型校验配置。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
config_data: 待校验的配置内容。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 校验成功时返回规范化后的配置;若插件不存在、
|
||||
当前不可路由或运行时不可用,则返回 ``None`` 以便调用方回退到弱推断方案。
|
||||
|
||||
Raises:
|
||||
ValueError: 插件已加载,但配置校验失败时抛出。
|
||||
"""
|
||||
|
||||
if not self._started:
|
||||
return None
|
||||
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(plugin_id)
|
||||
except RuntimeError as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置校验路由失败,将回退到静态 Schema: {exc}")
|
||||
return None
|
||||
|
||||
if supervisor is None:
|
||||
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
|
||||
if supervisor is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await supervisor.validate_plugin_config(plugin_id, config_data)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 运行时配置校验不可用,将回退到静态 Schema: {exc}")
|
||||
return None
|
||||
|
||||
async def inspect_plugin_config(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
use_provided_config: bool = False,
|
||||
) -> InspectPluginConfigResultPayload | None:
|
||||
"""请求运行时解析插件配置元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
config_data: 可选的配置内容。
|
||||
use_provided_config: 是否优先使用传入的配置内容而不是磁盘配置。
|
||||
|
||||
Returns:
|
||||
InspectPluginConfigResultPayload | None: 解析成功时返回结构化结果;若插件
|
||||
当前不可路由或运行时不可用,则返回 ``None``。
|
||||
|
||||
Raises:
|
||||
ValueError: 插件存在,但运行时明确拒绝解析请求时抛出。
|
||||
"""
|
||||
|
||||
if not self._started:
|
||||
return None
|
||||
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(plugin_id)
|
||||
except RuntimeError as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置解析路由失败: {exc}")
|
||||
return None
|
||||
|
||||
if supervisor is None:
|
||||
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
|
||||
if supervisor is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await supervisor.inspect_plugin_config(
|
||||
plugin_id=plugin_id,
|
||||
config_data=config_data,
|
||||
use_provided_config=use_provided_config,
|
||||
)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置解析不可用: {exc}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
|
||||
"""规范化配置热重载范围列表。
|
||||
@@ -731,11 +1080,25 @@ class PluginRuntimeManager(
|
||||
return matches[0] if matches else None
|
||||
|
||||
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
|
||||
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
|
||||
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
reason: 加载或重载原因。
|
||||
|
||||
Returns:
|
||||
bool: 插件最终是否处于已加载状态。
|
||||
"""
|
||||
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id:
|
||||
return False
|
||||
if normalized_plugin_id in self._blocked_plugin_reasons:
|
||||
logger.warning(
|
||||
f"插件 {normalized_plugin_id} 当前被依赖流水线阻止加载: "
|
||||
f"{self._blocked_plugin_reasons[normalized_plugin_id]}"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
|
||||
@@ -749,17 +1112,18 @@ class PluginRuntimeManager(
|
||||
if supervisor is None:
|
||||
return False
|
||||
|
||||
return await supervisor.reload_plugins(
|
||||
reloaded = await supervisor.reload_plugins(
|
||||
plugin_ids=[normalized_plugin_id],
|
||||
reason=reason,
|
||||
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
|
||||
)
|
||||
return reloaded and normalized_plugin_id in supervisor.get_loaded_plugin_ids()
|
||||
|
||||
@classmethod
|
||||
def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
|
||||
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
|
||||
plugin_locations: Dict[str, List[Path]] = {}
|
||||
validator = ManifestValidator()
|
||||
validator = ManifestValidator(validate_python_package_dependencies=False)
|
||||
for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs):
|
||||
plugin_locations.setdefault(manifest.id, []).append(plugin_path)
|
||||
|
||||
@@ -869,7 +1233,9 @@ class PluginRuntimeManager(
|
||||
if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
|
||||
return cached_path
|
||||
|
||||
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
|
||||
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(
|
||||
getattr(supervisor, "_plugin_dirs", [])
|
||||
):
|
||||
if candidate_plugin_id != plugin_id:
|
||||
continue
|
||||
self._plugin_path_cache[plugin_id] = plugin_path
|
||||
@@ -878,15 +1244,16 @@ class PluginRuntimeManager(
|
||||
return None
|
||||
|
||||
def _refresh_plugin_config_watch_subscriptions(self) -> None:
|
||||
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
|
||||
"""按当前可识别插件集合刷新 config.toml 的单插件订阅。
|
||||
|
||||
当插件热重载后,插件集合或目录位置可能发生变化,因此需要重新对齐
|
||||
watcher 的订阅,确保每个插件配置变更只触发对应 plugin_id。
|
||||
这里不仅覆盖当前已注册插件,也覆盖已存在但暂未激活的合法插件。
|
||||
"""
|
||||
if self._plugin_file_watcher is None:
|
||||
return
|
||||
|
||||
desired_plugin_paths = dict(self._iter_registered_plugin_paths())
|
||||
desired_plugin_paths = dict(self._iter_watchable_plugin_paths())
|
||||
self._plugin_path_cache = desired_plugin_paths.copy()
|
||||
desired_config_paths = {
|
||||
plugin_id: self._resolve_plugin_config_path(plugin_id, plugin_path)
|
||||
@@ -909,9 +1276,7 @@ class PluginRuntimeManager(
|
||||
)
|
||||
self._plugin_config_watcher_subscriptions[plugin_id] = (config_path, subscription_id)
|
||||
|
||||
def _build_plugin_config_change_callback(
|
||||
self, plugin_id: str
|
||||
) -> Callable[[Sequence[FileChange]], Awaitable[None]]:
|
||||
def _build_plugin_config_change_callback(self, plugin_id: str) -> Callable[[Sequence[FileChange]], Awaitable[None]]:
|
||||
"""为指定插件生成配置文件变更回调。"""
|
||||
|
||||
async def _callback(changes: Sequence[FileChange]) -> None:
|
||||
@@ -931,6 +1296,18 @@ class PluginRuntimeManager(
|
||||
if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
|
||||
yield plugin_id, plugin_path
|
||||
|
||||
def _iter_watchable_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
|
||||
"""迭代应被配置监听器追踪的插件目录。
|
||||
|
||||
Returns:
|
||||
Iterable[Tuple[str, Path]]: ``(plugin_id, plugin_path)`` 迭代器。
|
||||
"""
|
||||
|
||||
watchable_plugin_paths = dict(self._iter_discovered_plugin_paths(self._iter_plugin_dirs()))
|
||||
for plugin_id, plugin_path in self._iter_registered_plugin_paths():
|
||||
watchable_plugin_paths.setdefault(plugin_id, plugin_path)
|
||||
yield from watchable_plugin_paths.items()
|
||||
|
||||
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
|
||||
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
|
||||
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
|
||||
@@ -958,18 +1335,43 @@ class PluginRuntimeManager(
|
||||
return
|
||||
|
||||
if supervisor is None:
|
||||
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
|
||||
if supervisor is None:
|
||||
return
|
||||
|
||||
plugin_is_loaded = plugin_id in getattr(supervisor, "_registered_plugins", {})
|
||||
|
||||
try:
|
||||
snapshot = await supervisor.inspect_plugin_config(plugin_id)
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更解析失败: {exc}")
|
||||
return
|
||||
|
||||
try:
|
||||
config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
|
||||
delivered = await supervisor.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=config_payload,
|
||||
config_version="",
|
||||
config_scope="self",
|
||||
)
|
||||
if not delivered:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
|
||||
if plugin_is_loaded and snapshot.enabled:
|
||||
delivered = await supervisor.notify_plugin_config_updated(
|
||||
plugin_id=plugin_id,
|
||||
config_data=dict(snapshot.normalized_config),
|
||||
config_version="",
|
||||
config_scope="self",
|
||||
)
|
||||
if not delivered:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
|
||||
return
|
||||
|
||||
if plugin_is_loaded and not snapshot.enabled:
|
||||
reloaded = await self.reload_plugins_globally([plugin_id], reason="config_disabled")
|
||||
if not reloaded:
|
||||
logger.warning(f"插件 {plugin_id} 禁用配置已写入,但运行时卸载失败")
|
||||
return
|
||||
|
||||
if not snapshot.enabled:
|
||||
logger.info(f"插件 {plugin_id} 当前处于禁用状态,跳过自动加载")
|
||||
return
|
||||
|
||||
loaded = await self.load_plugin_globally(plugin_id, reason="config_enabled")
|
||||
if not loaded:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更后自动加载失败")
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
|
||||
|
||||
@@ -983,7 +1385,8 @@ class PluginRuntimeManager(
|
||||
if not self._started or not changes:
|
||||
return
|
||||
|
||||
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(list(self._iter_plugin_dirs())):
|
||||
plugin_dirs = list(self._iter_plugin_dirs())
|
||||
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(plugin_dirs):
|
||||
details = "; ".join(
|
||||
f"{plugin_id}: {', '.join(str(path) for path in paths)}"
|
||||
for plugin_id, paths in sorted(duplicate_plugin_ids.items())
|
||||
@@ -991,21 +1394,24 @@ class PluginRuntimeManager(
|
||||
logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}")
|
||||
return
|
||||
|
||||
changed_plugin_ids: List[str] = []
|
||||
changed_paths = [change.path.resolve() for change in changes]
|
||||
relevant_source_changes = [
|
||||
change.path.resolve()
|
||||
for change in changes
|
||||
if change.path.name in {"plugin.py", "_manifest.json"} or change.path.suffix == ".py"
|
||||
]
|
||||
if not relevant_source_changes:
|
||||
return
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
for path in changed_paths:
|
||||
plugin_id = self._match_plugin_id_for_supervisor(supervisor, path)
|
||||
if plugin_id is None:
|
||||
continue
|
||||
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
|
||||
if plugin_id not in changed_plugin_ids:
|
||||
changed_plugin_ids.append(plugin_id)
|
||||
dependency_sync_state = await self._sync_plugin_dependencies(plugin_dirs)
|
||||
restart_reason = "file_watcher"
|
||||
if dependency_sync_state.environment_changed:
|
||||
restart_reason = "file_watcher_dependency_install"
|
||||
elif dependency_sync_state.blocked_changed_plugin_ids:
|
||||
restart_reason = "file_watcher_blocklist_changed"
|
||||
|
||||
if changed_plugin_ids:
|
||||
await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
|
||||
self._refresh_plugin_config_watch_subscriptions()
|
||||
restarted = await self._restart_supervisors(restart_reason)
|
||||
if not restarted:
|
||||
logger.warning(f"插件源码变更后重启 Supervisor 失败: {restart_reason}")
|
||||
|
||||
@staticmethod
|
||||
def _plugin_dir_matches(path: Path, plugin_dir: Path) -> bool:
|
||||
@@ -1023,7 +1429,10 @@ class PluginRuntimeManager(
|
||||
return plugin_id
|
||||
|
||||
for plugin_id, plugin_path in self._plugin_path_cache.items():
|
||||
if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
|
||||
if not any(
|
||||
self._plugin_dir_matches(plugin_path, Path(plugin_dir))
|
||||
for plugin_dir in getattr(supervisor, "_plugin_dirs", [])
|
||||
):
|
||||
continue
|
||||
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
|
||||
return plugin_id
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""RPC Envelope 消息模型
|
||||
"""RPC Envelope 消息模型。
|
||||
|
||||
定义 Host 与 Runner 之间所有 RPC 消息的统一信封格式。
|
||||
使用 Pydantic 进行 schema 定义与校验。
|
||||
使用 Pydantic 进行 Schema 定义与校验。
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
@@ -39,12 +39,23 @@ class ConfigReloadScope(str, Enum):
|
||||
|
||||
# ====== 请求 ID 生成器 ======
|
||||
class RequestIdGenerator:
|
||||
"""单调递增 int64 请求 ID 生成器"""
|
||||
"""单调递增 int64 请求 ID 生成器。"""
|
||||
|
||||
def __init__(self, start: int = 1) -> None:
|
||||
"""初始化请求 ID 生成器。
|
||||
|
||||
Args:
|
||||
start: 起始请求 ID。
|
||||
"""
|
||||
self._counter = start
|
||||
|
||||
async def next(self) -> int:
|
||||
"""返回下一个请求 ID。
|
||||
|
||||
Returns:
|
||||
int: 下一个可用的请求 ID。
|
||||
"""
|
||||
|
||||
current = self._counter
|
||||
self._counter += 1
|
||||
return current
|
||||
@@ -52,7 +63,7 @@ class RequestIdGenerator:
|
||||
|
||||
# ====== Envelope 模型 ======
|
||||
class Envelope(BaseModel):
|
||||
"""RPC 统一消息封装
|
||||
"""RPC 统一消息封装。
|
||||
|
||||
所有 Host <-> Runner 消息均封装为此格式。
|
||||
序列化流程:Envelope -> .model_dump() -> MsgPack encode
|
||||
@@ -79,18 +90,44 @@ class Envelope(BaseModel):
|
||||
"""错误信息 (仅 response)"""
|
||||
|
||||
def is_request(self) -> bool:
|
||||
"""判断当前信封是否为请求消息。
|
||||
|
||||
Returns:
|
||||
bool: 当前消息类型是否为 ``REQUEST``。
|
||||
"""
|
||||
|
||||
return self.message_type == MessageType.REQUEST
|
||||
|
||||
def is_response(self) -> bool:
|
||||
"""判断当前信封是否为响应消息。
|
||||
|
||||
Returns:
|
||||
bool: 当前消息类型是否为 ``RESPONSE``。
|
||||
"""
|
||||
|
||||
return self.message_type == MessageType.RESPONSE
|
||||
|
||||
def is_broadcast(self) -> bool:
|
||||
"""判断当前信封是否为广播消息。
|
||||
|
||||
Returns:
|
||||
bool: 当前消息类型是否为 ``BROADCAST``。
|
||||
"""
|
||||
|
||||
return self.message_type == MessageType.BROADCAST
|
||||
|
||||
def make_response(
|
||||
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
|
||||
) -> "Envelope":
|
||||
"""基于当前请求创建对应的响应信封"""
|
||||
"""基于当前请求创建对应的响应信封。
|
||||
|
||||
Args:
|
||||
payload: 响应业务载荷。
|
||||
error: 响应错误信息。
|
||||
|
||||
Returns:
|
||||
Envelope: 对应的响应信封。
|
||||
"""
|
||||
return Envelope(
|
||||
protocol_version=self.protocol_version,
|
||||
request_id=self.request_id,
|
||||
@@ -102,7 +139,16 @@ class Envelope(BaseModel):
|
||||
)
|
||||
|
||||
def make_error_response(self, code: str, message: str = "", details: Optional[Dict[str, Any]] = None) -> "Envelope":
|
||||
"""基于当前请求创建错误响应"""
|
||||
"""基于当前请求创建错误响应。
|
||||
|
||||
Args:
|
||||
code: 错误码。
|
||||
message: 错误描述。
|
||||
details: 详细错误信息。
|
||||
|
||||
Returns:
|
||||
Envelope: 错误响应信封。
|
||||
"""
|
||||
return self.make_response(
|
||||
error={
|
||||
"code": code,
|
||||
@@ -141,9 +187,7 @@ class ComponentDeclaration(BaseModel):
|
||||
|
||||
name: str = Field(description="组件名称")
|
||||
"""组件名称"""
|
||||
component_type: str = Field(
|
||||
description="组件类型:action/command/tool/event_handler/hook_handler/message_gateway"
|
||||
)
|
||||
component_type: str = Field(description="组件类型:action/command/tool/event_handler/hook_handler/message_gateway")
|
||||
"""组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
|
||||
plugin_id: str = Field(description="所属插件 ID")
|
||||
"""所属插件 ID"""
|
||||
@@ -170,6 +214,10 @@ class RegisterPluginPayload(BaseModel):
|
||||
"""插件级依赖插件 ID 列表"""
|
||||
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
|
||||
"""订阅的全局配置热重载范围"""
|
||||
default_config: Dict[str, Any] = Field(default_factory=dict, description="插件默认配置")
|
||||
"""插件默认配置"""
|
||||
config_schema: Dict[str, Any] = Field(default_factory=dict, description="插件配置 Schema")
|
||||
"""插件配置 Schema"""
|
||||
|
||||
|
||||
class BootstrapPluginPayload(BaseModel):
|
||||
@@ -240,6 +288,8 @@ class RunnerReadyPayload(BaseModel):
|
||||
"""已完成初始化的插件列表"""
|
||||
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
|
||||
"""初始化失败的插件列表"""
|
||||
inactive_plugins: List[str] = Field(default_factory=list, description="当前因禁用或依赖不可用而未激活的插件列表")
|
||||
"""当前因禁用或依赖不可用而未激活的插件列表"""
|
||||
|
||||
|
||||
# ====== 配置更新 ======
|
||||
@@ -256,6 +306,50 @@ class ConfigUpdatedPayload(BaseModel):
|
||||
"""配置内容"""
|
||||
|
||||
|
||||
class ValidatePluginConfigPayload(BaseModel):
|
||||
"""plugin.validate_config 请求 payload。"""
|
||||
|
||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="待校验的配置内容")
|
||||
"""待校验的配置内容"""
|
||||
|
||||
|
||||
class InspectPluginConfigPayload(BaseModel):
|
||||
"""plugin.inspect_config 请求 payload。"""
|
||||
|
||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="可选的配置内容")
|
||||
"""可选的配置内容"""
|
||||
use_provided_config: bool = Field(default=False, description="是否优先使用请求中携带的配置内容")
|
||||
"""是否优先使用请求中携带的配置内容"""
|
||||
|
||||
|
||||
class InspectPluginConfigResultPayload(BaseModel):
|
||||
"""plugin.inspect_config 响应 payload。"""
|
||||
|
||||
success: bool = Field(description="是否解析成功")
|
||||
"""是否解析成功"""
|
||||
default_config: Dict[str, Any] = Field(default_factory=dict, description="插件默认配置")
|
||||
"""插件默认配置"""
|
||||
config_schema: Dict[str, Any] = Field(default_factory=dict, description="插件配置 Schema")
|
||||
"""插件配置 Schema"""
|
||||
normalized_config: Dict[str, Any] = Field(default_factory=dict, description="归一化后的配置内容")
|
||||
"""归一化后的配置内容"""
|
||||
changed: bool = Field(default=False, description="是否在归一化过程中自动补齐或修正了配置")
|
||||
"""是否在归一化过程中自动补齐或修正了配置"""
|
||||
enabled: bool = Field(default=True, description="插件在当前配置下是否应被视为启用")
|
||||
"""插件在当前配置下是否应被视为启用"""
|
||||
|
||||
|
||||
class ValidatePluginConfigResultPayload(BaseModel):
|
||||
"""plugin.validate_config 响应 payload。"""
|
||||
|
||||
success: bool = Field(description="是否校验成功")
|
||||
"""是否校验成功"""
|
||||
normalized_config: Dict[str, Any] = Field(default_factory=dict, description="校验后的规范化配置")
|
||||
"""校验后的规范化配置"""
|
||||
changed: bool = Field(default=False, description="是否在校验过程中自动补齐或归一化")
|
||||
"""是否在校验过程中自动补齐或归一化"""
|
||||
|
||||
|
||||
# ====== 关停 ======
|
||||
class ShutdownPayload(BaseModel):
|
||||
"""plugin.shutdown / plugin.prepare_shutdown payload"""
|
||||
@@ -314,6 +408,8 @@ class ReloadPluginResultPayload(BaseModel):
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
inactive_plugins: List[str] = Field(default_factory=list, description="本次处于未激活状态的插件列表")
|
||||
"""本次处于未激活状态的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
@@ -329,6 +425,8 @@ class ReloadPluginsResultPayload(BaseModel):
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
inactive_plugins: List[str] = Field(default_factory=list, description="本次处于未激活状态的插件列表")
|
||||
"""本次处于未激活状态的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
@@ -609,6 +609,7 @@ class ManifestValidator:
|
||||
host_version: str = "",
|
||||
sdk_version: str = "",
|
||||
project_root: Optional[Path] = None,
|
||||
validate_python_package_dependencies: bool = True,
|
||||
) -> None:
|
||||
"""初始化 Manifest 校验器。
|
||||
|
||||
@@ -616,10 +617,12 @@ class ManifestValidator:
|
||||
host_version: 当前 Host 版本号;留空时自动从主程序 ``pyproject.toml`` 读取。
|
||||
sdk_version: 当前 SDK 版本号;留空时自动从运行环境中探测。
|
||||
project_root: 项目根目录;留空时自动推断。
|
||||
validate_python_package_dependencies: 是否校验 Python 包依赖与当前环境的关系。
|
||||
"""
|
||||
self._project_root: Path = project_root or self._resolve_project_root()
|
||||
self._host_version: str = host_version or self._detect_default_host_version(self._project_root)
|
||||
self._sdk_version: str = sdk_version or self._detect_default_sdk_version(self._project_root)
|
||||
self._validate_python_package_dependencies: bool = validate_python_package_dependencies
|
||||
self.errors: List[str] = []
|
||||
self.warnings: List[str] = []
|
||||
|
||||
@@ -823,9 +826,10 @@ class ManifestValidator:
|
||||
if not sdk_ok:
|
||||
self.errors.append(f"SDK 版本不兼容: {sdk_message} (当前 SDK: {self._sdk_version})")
|
||||
|
||||
self._validate_python_package_dependencies(manifest)
|
||||
if self._validate_python_package_dependencies:
|
||||
self._validate_python_package_dependencies_against_runtime(manifest)
|
||||
|
||||
def _validate_python_package_dependencies(self, manifest: PluginManifest) -> None:
|
||||
def _validate_python_package_dependencies_against_runtime(self, manifest: PluginManifest) -> None:
|
||||
"""校验 Python 包依赖与主程序运行环境是否冲突。
|
||||
|
||||
Args:
|
||||
@@ -865,6 +869,68 @@ class ManifestValidator:
|
||||
f"主程序依赖约束为 {host_specifier or '任意版本'}"
|
||||
)
|
||||
|
||||
def load_host_dependency_requirements(self) -> Dict[str, Requirement]:
|
||||
"""读取主程序在 ``pyproject.toml`` 中声明的依赖约束。
|
||||
|
||||
Returns:
|
||||
Dict[str, Requirement]: 以规范化包名为键的依赖约束映射。
|
||||
"""
|
||||
|
||||
return self._load_host_dependency_requirements(self._project_root)
|
||||
|
||||
def get_installed_package_version(self, package_name: str) -> Optional[str]:
|
||||
"""查询当前运行环境中指定包的安装版本。
|
||||
|
||||
Args:
|
||||
package_name: 需要查询的包名。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 已安装版本号;未安装时返回 ``None``。
|
||||
"""
|
||||
|
||||
return self._get_installed_package_version(package_name)
|
||||
|
||||
@staticmethod
|
||||
def build_specifier_set(version_spec: str) -> Optional[SpecifierSet]:
|
||||
"""将版本约束文本转换为 ``SpecifierSet``。
|
||||
|
||||
Args:
|
||||
version_spec: 原始版本约束文本。
|
||||
|
||||
Returns:
|
||||
Optional[SpecifierSet]: 转换成功时返回约束对象,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
return ManifestValidator._build_specifier_set(version_spec)
|
||||
|
||||
@staticmethod
|
||||
def version_matches_specifier(version: str, version_spec: str) -> bool:
|
||||
"""判断版本号是否满足给定约束。
|
||||
|
||||
Args:
|
||||
version: 待判断的版本号。
|
||||
version_spec: 版本约束表达式。
|
||||
|
||||
Returns:
|
||||
bool: 是否满足约束。
|
||||
"""
|
||||
|
||||
return ManifestValidator._version_matches_specifier(version, version_spec)
|
||||
|
||||
@classmethod
|
||||
def requirements_may_overlap(cls, left: SpecifierSet, right: SpecifierSet) -> bool:
|
||||
"""判断两个版本约束是否可能存在交集。
|
||||
|
||||
Args:
|
||||
left: 左侧版本约束。
|
||||
right: 右侧版本约束。
|
||||
|
||||
Returns:
|
||||
bool: 若两者可能同时满足则返回 ``True``。
|
||||
"""
|
||||
|
||||
return cls._requirements_may_overlap(left, right)
|
||||
|
||||
def _log_errors(self) -> None:
|
||||
"""输出当前累计的 Manifest 校验错误。"""
|
||||
for error_message in self.errors:
|
||||
|
||||
@@ -75,6 +75,35 @@ class PluginLoader:
|
||||
self._failed_plugins: Dict[str, str] = {}
|
||||
self._manifest_validator = ManifestValidator(host_version=host_version)
|
||||
self._compat_hook_installed = False
|
||||
self._blocked_plugin_reasons: Dict[str, str] = {}
|
||||
|
||||
def set_blocked_plugin_reasons(self, blocked_plugin_reasons: Optional[Dict[str, str]] = None) -> None:
|
||||
"""更新当前加载器持有的拒绝加载插件列表。
|
||||
|
||||
Args:
|
||||
blocked_plugin_reasons: 需要拒绝加载的插件及原因映射。
|
||||
"""
|
||||
|
||||
self._blocked_plugin_reasons = {
|
||||
str(plugin_id or "").strip(): str(reason or "").strip()
|
||||
for plugin_id, reason in (blocked_plugin_reasons or {}).items()
|
||||
if str(plugin_id or "").strip() and str(reason or "").strip()
|
||||
}
|
||||
|
||||
def get_blocked_plugin_reason(self, plugin_id: str) -> Optional[str]:
|
||||
"""返回指定插件当前的拒绝加载原因。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 若插件被阻止加载则返回原因,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id:
|
||||
return None
|
||||
return self._blocked_plugin_reasons.get(normalized_plugin_id)
|
||||
|
||||
def discover_and_load(
|
||||
self,
|
||||
@@ -156,6 +185,11 @@ class PluginLoader:
|
||||
return None
|
||||
|
||||
plugin_id = manifest.id
|
||||
if blocked_reason := self.get_blocked_plugin_reason(plugin_id):
|
||||
self._failed_plugins[plugin_id] = blocked_reason
|
||||
logger.warning(f"插件 {plugin_id} 已被 Host 依赖流水线阻止加载: {blocked_reason}")
|
||||
return None
|
||||
|
||||
return plugin_id, (plugin_dir, manifest, plugin_path)
|
||||
|
||||
def _record_duplicate_candidates(self, duplicate_candidates: Dict[str, List[Path]]) -> None:
|
||||
|
||||
@@ -9,8 +9,10 @@
|
||||
6. 转发插件的能力调用到 Host
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -23,8 +25,11 @@ import sys
|
||||
import time
|
||||
import tomllib
|
||||
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_console_handler, get_logger, initialize_logging
|
||||
from src.plugin_runtime import (
|
||||
ENV_BLOCKED_PLUGIN_REASONS,
|
||||
ENV_EXTERNAL_PLUGIN_IDS,
|
||||
ENV_HOST_VERSION,
|
||||
ENV_IPC_ADDRESS,
|
||||
@@ -37,6 +42,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
ConfigUpdatedPayload,
|
||||
Envelope,
|
||||
HealthPayload,
|
||||
InspectPluginConfigPayload,
|
||||
InspectPluginConfigResultPayload,
|
||||
InvokePayload,
|
||||
InvokeResultPayload,
|
||||
RegisterPluginPayload,
|
||||
@@ -46,6 +53,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
ReloadPluginsResultPayload,
|
||||
RunnerReadyPayload,
|
||||
UnregisterPluginPayload,
|
||||
ValidatePluginConfigPayload,
|
||||
ValidatePluginConfigResultPayload,
|
||||
)
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode
|
||||
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
|
||||
@@ -79,6 +88,72 @@ class _ContextAwarePlugin(Protocol):
|
||||
"""
|
||||
|
||||
|
||||
class _ConfigAwarePlugin(Protocol):
|
||||
"""支持声明式插件配置能力的插件协议。"""
|
||||
|
||||
def normalize_plugin_config(self, config_data: Optional[Mapping[str, Any]]) -> Tuple[Dict[str, Any], bool]:
|
||||
"""对插件配置进行归一化与补齐。
|
||||
|
||||
Args:
|
||||
config_data: 原始配置数据。
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, Any], bool]: 归一化后的配置,以及是否发生自动变更。
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""注入插件当前配置。
|
||||
|
||||
Args:
|
||||
config: 当前最新插件配置。
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
def get_default_config(self) -> Dict[str, Any]:
|
||||
"""返回插件默认配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 默认配置字典。
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
def get_webui_config_schema(
|
||||
self,
|
||||
*,
|
||||
plugin_id: str = "",
|
||||
plugin_name: str = "",
|
||||
plugin_version: str = "",
|
||||
plugin_description: str = "",
|
||||
plugin_author: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""返回插件配置 Schema。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
plugin_name: 插件名称。
|
||||
plugin_version: 插件版本。
|
||||
plugin_description: 插件描述。
|
||||
plugin_author: 插件作者。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: WebUI 配置 Schema。
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class PluginActivationStatus(str, Enum):
|
||||
"""描述插件激活结果。"""
|
||||
|
||||
LOADED = "loaded"
|
||||
INACTIVE = "inactive"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
def _install_shutdown_signal_handlers(
|
||||
mark_runner_shutting_down: Callable[[], None],
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
@@ -122,6 +197,7 @@ class PluginRunner:
|
||||
session_token: str,
|
||||
plugin_dirs: List[str],
|
||||
external_available_plugins: Optional[Dict[str, str]] = None,
|
||||
blocked_plugin_reasons: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
"""初始化 Runner。
|
||||
|
||||
@@ -130,6 +206,7 @@ class PluginRunner:
|
||||
session_token: 握手用会话令牌。
|
||||
plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
|
||||
external_available_plugins: 视为已满足的外部依赖插件版本映射。
|
||||
blocked_plugin_reasons: 需要拒绝加载的插件及原因映射。
|
||||
"""
|
||||
self._host_address: str = host_address
|
||||
self._session_token: str = session_token
|
||||
@@ -139,9 +216,15 @@ class PluginRunner:
|
||||
for plugin_id, plugin_version in (external_available_plugins or {}).items()
|
||||
if str(plugin_id or "").strip() and str(plugin_version or "").strip()
|
||||
}
|
||||
self._blocked_plugin_reasons: Dict[str, str] = {
|
||||
str(plugin_id or "").strip(): str(reason or "").strip()
|
||||
for plugin_id, reason in (blocked_plugin_reasons or {}).items()
|
||||
if str(plugin_id or "").strip() and str(reason or "").strip()
|
||||
}
|
||||
|
||||
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
|
||||
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
|
||||
self._loader.set_blocked_plugin_reasons(self._blocked_plugin_reasons)
|
||||
self._start_time: float = time.monotonic()
|
||||
self._shutting_down: bool = False
|
||||
self._reload_lock: asyncio.Lock = asyncio.Lock()
|
||||
@@ -174,13 +257,43 @@ class PluginRunner:
|
||||
|
||||
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
|
||||
failed_plugins: Set[str] = set(self._loader.failed_plugins.keys())
|
||||
inactive_plugins: Set[str] = set()
|
||||
available_plugin_versions: Dict[str, str] = dict(self._external_available_plugins)
|
||||
for meta in plugins:
|
||||
ok = await self._activate_plugin(meta)
|
||||
if not ok:
|
||||
unsatisfied_dependencies = [
|
||||
dependency.id
|
||||
for dependency in meta.manifest.plugin_dependencies
|
||||
if dependency.id not in available_plugin_versions
|
||||
or not self._loader.manifest_validator.is_plugin_dependency_satisfied(
|
||||
dependency,
|
||||
available_plugin_versions[dependency.id],
|
||||
)
|
||||
]
|
||||
if unsatisfied_dependencies:
|
||||
if any(dependency_id in inactive_plugins for dependency_id in unsatisfied_dependencies):
|
||||
logger.info(
|
||||
f"插件 {meta.plugin_id} 依赖的插件当前未激活,跳过本次启动: {', '.join(unsatisfied_dependencies)}"
|
||||
)
|
||||
inactive_plugins.add(meta.plugin_id)
|
||||
continue
|
||||
failed_plugins.add(meta.plugin_id)
|
||||
continue
|
||||
|
||||
successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins]
|
||||
await self._notify_ready(successful_plugins, sorted(failed_plugins))
|
||||
activation_status = await self._activate_plugin(meta)
|
||||
if activation_status == PluginActivationStatus.LOADED:
|
||||
available_plugin_versions[meta.plugin_id] = meta.version
|
||||
continue
|
||||
if activation_status == PluginActivationStatus.INACTIVE:
|
||||
inactive_plugins.add(meta.plugin_id)
|
||||
continue
|
||||
failed_plugins.add(meta.plugin_id)
|
||||
|
||||
successful_plugins = [
|
||||
meta.plugin_id
|
||||
for meta in plugins
|
||||
if meta.plugin_id not in failed_plugins and meta.plugin_id not in inactive_plugins
|
||||
]
|
||||
await self._notify_ready(successful_plugins, sorted(failed_plugins), sorted(inactive_plugins))
|
||||
|
||||
# 5. 等待直到收到关停信号
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
@@ -271,14 +384,11 @@ class PluginRunner:
|
||||
始终绑定为当前插件实例,避免伪造其他插件身份申请能力。
|
||||
"""
|
||||
if plugin_id and plugin_id != bound_plugin_id:
|
||||
logger.warning(
|
||||
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份"
|
||||
)
|
||||
logger.warning(f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份")
|
||||
normalized_method = str(method or "").strip()
|
||||
if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS:
|
||||
raise PermissionError(
|
||||
f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: "
|
||||
f"{normalized_method or '<empty>'}"
|
||||
f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: {normalized_method or '<empty>'}"
|
||||
)
|
||||
resp = await rpc_client.send_request(
|
||||
method=normalized_method,
|
||||
@@ -293,17 +403,101 @@ class PluginRunner:
|
||||
cast(_ContextAwarePlugin, instance)._set_context(ctx)
|
||||
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
|
||||
|
||||
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""在 Runner 侧为插件实例注入当前插件配置。"""
|
||||
instance = meta.instance
|
||||
if not hasattr(instance, "set_plugin_config"):
|
||||
return
|
||||
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""在 Runner 侧为插件实例注入当前插件配置。
|
||||
|
||||
Args:
|
||||
meta: 插件元数据。
|
||||
config_data: 可选的配置数据;留空时自动从插件目录读取。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 归一化后的当前插件配置。
|
||||
"""
|
||||
instance = meta.instance
|
||||
raw_config = config_data if config_data is not None else self._load_plugin_config(meta.plugin_dir)
|
||||
plugin_config, should_persist = self._normalize_plugin_config(instance, raw_config)
|
||||
config_path = Path(meta.plugin_dir) / "config.toml"
|
||||
default_config = self._get_plugin_default_config(instance)
|
||||
should_initialize_file = not config_path.exists() and bool(default_config)
|
||||
if should_persist or should_initialize_file:
|
||||
self._save_plugin_config(meta.plugin_dir, plugin_config)
|
||||
if hasattr(instance, "set_plugin_config"):
|
||||
try:
|
||||
cast(_ConfigAwarePlugin, instance).set_plugin_config(plugin_config)
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
|
||||
return plugin_config
|
||||
|
||||
def _normalize_plugin_config(
|
||||
self,
|
||||
instance: object,
|
||||
config_data: Optional[Dict[str, Any]],
|
||||
*,
|
||||
suppress_errors: bool = True,
|
||||
) -> Tuple[Dict[str, Any], bool]:
|
||||
"""对插件配置做统一归一化处理。
|
||||
|
||||
Args:
|
||||
instance: 插件实例。
|
||||
config_data: 原始配置数据。
|
||||
suppress_errors: 是否在归一化失败时吞掉异常并回退原始配置。
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, Any], bool]: 归一化后的配置,以及是否需要回写文件。
|
||||
"""
|
||||
|
||||
normalized_config = dict(config_data or {})
|
||||
if not hasattr(instance, "normalize_plugin_config"):
|
||||
return normalized_config, False
|
||||
|
||||
plugin_config = config_data if config_data is not None else self._load_plugin_config(meta.plugin_dir, meta.plugin_id)
|
||||
try:
|
||||
instance.set_plugin_config(plugin_config)
|
||||
return cast(_ConfigAwarePlugin, instance).normalize_plugin_config(normalized_config)
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
|
||||
if not suppress_errors:
|
||||
raise
|
||||
logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}")
|
||||
return normalized_config, False
|
||||
|
||||
@staticmethod
|
||||
def _is_plugin_enabled(config_data: Optional[Mapping[str, Any]]) -> bool:
|
||||
"""根据配置内容判断插件是否应被视为启用。
|
||||
|
||||
Args:
|
||||
config_data: 当前插件配置。
|
||||
|
||||
Returns:
|
||||
bool: 插件是否启用。
|
||||
"""
|
||||
|
||||
if not isinstance(config_data, Mapping):
|
||||
return True
|
||||
|
||||
plugin_section = config_data.get("plugin")
|
||||
if not isinstance(plugin_section, Mapping):
|
||||
return True
|
||||
|
||||
enabled_value = plugin_section.get("enabled", True)
|
||||
if isinstance(enabled_value, str):
|
||||
normalized_value = enabled_value.strip().lower()
|
||||
if normalized_value in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
if normalized_value in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
return bool(enabled_value)
|
||||
|
||||
@staticmethod
|
||||
def _save_plugin_config(plugin_dir: str, config_data: Dict[str, Any]) -> None:
|
||||
"""将插件配置写回到 ``config.toml``。
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录。
|
||||
config_data: 需要写回的配置字典。
|
||||
"""
|
||||
|
||||
config_path = Path(plugin_dir) / "config.toml"
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with config_path.open("w", encoding="utf-8") as handle:
|
||||
handle.write(tomlkit.dumps(config_data))
|
||||
|
||||
@staticmethod
|
||||
def _load_plugin_config(plugin_dir: str, plugin_id: str = "") -> Dict[str, Any]:
|
||||
@@ -322,6 +516,99 @@ class PluginRunner:
|
||||
|
||||
return loaded if isinstance(loaded, dict) else {}
|
||||
|
||||
def _resolve_plugin_candidate(self, plugin_id: str) -> Tuple[Optional[PluginCandidate], Optional[str]]:
|
||||
"""解析指定插件的候选目录。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[PluginCandidate], Optional[str]]: 候选插件与错误信息。
|
||||
"""
|
||||
|
||||
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
|
||||
if plugin_id in duplicate_candidates:
|
||||
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
|
||||
return None, f"检测到重复插件 ID: {conflict_paths}"
|
||||
|
||||
candidate = candidates.get(plugin_id)
|
||||
if candidate is None:
|
||||
return None, f"未找到插件: {plugin_id}"
|
||||
return candidate, None
|
||||
|
||||
def _resolve_plugin_meta_for_config_request(
|
||||
self,
|
||||
plugin_id: str,
|
||||
) -> Tuple[Optional[PluginMeta], bool, Optional[str]]:
|
||||
"""为配置相关请求解析插件元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[PluginMeta], bool, Optional[str]]: 依次为插件元数据、
|
||||
是否为临时冷加载实例、以及错误信息。
|
||||
"""
|
||||
|
||||
loaded_meta = self._loader.get_plugin(plugin_id)
|
||||
if loaded_meta is not None:
|
||||
return loaded_meta, False, None
|
||||
|
||||
candidate, error_message = self._resolve_plugin_candidate(plugin_id)
|
||||
if candidate is None:
|
||||
return None, False, error_message
|
||||
|
||||
try:
|
||||
meta = self._loader.load_candidate(plugin_id, candidate)
|
||||
except Exception as exc:
|
||||
return None, False, str(exc)
|
||||
if meta is None:
|
||||
return None, False, "插件模块加载失败"
|
||||
return meta, True, None
|
||||
|
||||
def _inspect_plugin_config(
|
||||
self,
|
||||
meta: PluginMeta,
|
||||
*,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
use_provided_config: bool = False,
|
||||
suppress_errors: bool = True,
|
||||
) -> InspectPluginConfigResultPayload:
|
||||
"""解析插件代码定义的配置元数据。
|
||||
|
||||
Args:
|
||||
meta: 插件元数据。
|
||||
config_data: 可选的配置内容。
|
||||
use_provided_config: 是否优先使用传入的配置内容。
|
||||
suppress_errors: 是否在归一化失败时回退原始配置。
|
||||
|
||||
Returns:
|
||||
InspectPluginConfigResultPayload: 结构化解析结果。
|
||||
"""
|
||||
|
||||
raw_config = config_data if use_provided_config else self._load_plugin_config(meta.plugin_dir)
|
||||
if use_provided_config and config_data is None:
|
||||
raw_config = {}
|
||||
|
||||
normalized_config, changed = self._normalize_plugin_config(
|
||||
meta.instance,
|
||||
raw_config,
|
||||
suppress_errors=suppress_errors,
|
||||
)
|
||||
default_config = self._get_plugin_default_config(meta.instance)
|
||||
if not normalized_config and not raw_config and default_config:
|
||||
normalized_config = dict(default_config)
|
||||
changed = True
|
||||
|
||||
return InspectPluginConfigResultPayload(
|
||||
success=True,
|
||||
default_config=default_config,
|
||||
config_schema=self._get_plugin_config_schema(meta),
|
||||
normalized_config=normalized_config,
|
||||
changed=changed,
|
||||
enabled=self._is_plugin_enabled(normalized_config),
|
||||
)
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""注册 Host -> Runner 的方法处理器。"""
|
||||
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
|
||||
@@ -335,6 +622,8 @@ class PluginRunner:
|
||||
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
|
||||
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
|
||||
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
|
||||
self._rpc_client.register_method("plugin.inspect_config", self._handle_inspect_plugin_config)
|
||||
self._rpc_client.register_method("plugin.validate_config", self._handle_validate_plugin_config)
|
||||
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
|
||||
self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins)
|
||||
|
||||
@@ -452,6 +741,8 @@ class PluginRunner:
|
||||
capabilities_required=meta.capabilities_required,
|
||||
dependencies=meta.dependencies,
|
||||
config_reload_subscriptions=config_reload_subscriptions,
|
||||
default_config=self._get_plugin_default_config(instance),
|
||||
config_schema=self._get_plugin_config_schema(meta),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -463,12 +754,62 @@ class PluginRunner:
|
||||
)
|
||||
if response.error:
|
||||
raise RuntimeError(response.error.get("message", "插件注册失败"))
|
||||
response_payload = response.payload if isinstance(response.payload, dict) else {}
|
||||
if not bool(response_payload.get("accepted", True)):
|
||||
raise RuntimeError(str(response_payload.get("reason", "插件注册失败")))
|
||||
logger.info(f"插件 {meta.plugin_id} 注册完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_plugin_default_config(instance: object) -> Dict[str, Any]:
|
||||
"""获取插件默认配置。
|
||||
|
||||
Args:
|
||||
instance: 插件实例。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 默认配置;插件未声明时返回空字典。
|
||||
"""
|
||||
|
||||
if not hasattr(instance, "get_default_config"):
|
||||
return {}
|
||||
try:
|
||||
default_config = cast(_ConfigAwarePlugin, instance).get_default_config()
|
||||
except Exception as exc:
|
||||
logger.warning(f"读取插件默认配置失败: {exc}")
|
||||
return {}
|
||||
return default_config if isinstance(default_config, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _get_plugin_config_schema(meta: PluginMeta) -> Dict[str, Any]:
|
||||
"""获取插件 WebUI 配置 Schema。
|
||||
|
||||
Args:
|
||||
meta: 插件元数据。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 插件配置 Schema;插件未声明时返回空字典。
|
||||
"""
|
||||
|
||||
instance = meta.instance
|
||||
if not hasattr(instance, "get_webui_config_schema"):
|
||||
return {}
|
||||
try:
|
||||
schema = cast(_ConfigAwarePlugin, instance).get_webui_config_schema(
|
||||
plugin_id=meta.plugin_id,
|
||||
plugin_name=meta.manifest.name,
|
||||
plugin_version=meta.version,
|
||||
plugin_description=meta.manifest.description,
|
||||
plugin_author=meta.manifest.author.name,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"构造插件配置 Schema 失败: {exc}")
|
||||
return {}
|
||||
return schema if isinstance(schema, dict) else {}
|
||||
|
||||
async def _unregister_plugin(self, plugin_id: str, reason: str) -> None:
|
||||
"""通知 Host 注销指定插件。
|
||||
|
||||
@@ -526,36 +867,40 @@ class PluginRunner:
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True)
|
||||
|
||||
async def _activate_plugin(self, meta: PluginMeta) -> bool:
|
||||
async def _activate_plugin(self, meta: PluginMeta) -> PluginActivationStatus:
|
||||
"""完成插件注入、授权、生命周期和组件注册。
|
||||
|
||||
Args:
|
||||
meta: 待激活的插件元数据。
|
||||
|
||||
Returns:
|
||||
bool: 是否激活成功。
|
||||
PluginActivationStatus: 插件激活结果。
|
||||
"""
|
||||
self._inject_context(meta.plugin_id, meta.instance)
|
||||
self._apply_plugin_config(meta)
|
||||
plugin_config = self._apply_plugin_config(meta)
|
||||
if not self._is_plugin_enabled(plugin_config):
|
||||
logger.info(f"插件 {meta.plugin_id} 已在配置中禁用,跳过激活")
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return PluginActivationStatus.INACTIVE
|
||||
|
||||
if not await self._bootstrap_plugin(meta):
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
return PluginActivationStatus.FAILED
|
||||
|
||||
if not await self._register_plugin(meta):
|
||||
await self._invoke_plugin_on_unload(meta)
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
return PluginActivationStatus.FAILED
|
||||
|
||||
if not await self._invoke_plugin_on_load(meta):
|
||||
await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
return PluginActivationStatus.FAILED
|
||||
|
||||
self._loader.set_loaded_plugin(meta)
|
||||
return True
|
||||
return PluginActivationStatus.LOADED
|
||||
|
||||
async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None:
|
||||
"""卸载单个插件并清理 Host/Runner 两侧状态。
|
||||
@@ -632,7 +977,9 @@ class PluginRunner:
|
||||
continue
|
||||
dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids}
|
||||
|
||||
indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()}
|
||||
indegree: Dict[str, int] = {
|
||||
plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()
|
||||
}
|
||||
reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph}
|
||||
|
||||
for plugin_id, dependencies in dependency_graph.items():
|
||||
@@ -678,9 +1025,7 @@ class PluginRunner:
|
||||
for failed_plugin_id, failure_reason in failed_plugins.items():
|
||||
rollback_failure = rollback_failures.get(failed_plugin_id)
|
||||
if rollback_failure:
|
||||
finalized_failures[failed_plugin_id] = (
|
||||
f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
|
||||
)
|
||||
finalized_failures[failed_plugin_id] = f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
|
||||
else:
|
||||
finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)"
|
||||
|
||||
@@ -716,6 +1061,7 @@ class PluginRunner:
|
||||
requested_plugin_id=plugin_id,
|
||||
reloaded_plugins=batch_result.reloaded_plugins,
|
||||
unloaded_plugins=batch_result.unloaded_plugins,
|
||||
inactive_plugins=batch_result.inactive_plugins,
|
||||
failed_plugins=batch_result.failed_plugins,
|
||||
)
|
||||
|
||||
@@ -762,9 +1108,7 @@ class PluginRunner:
|
||||
failed_plugins=failed_plugins,
|
||||
)
|
||||
|
||||
target_plugin_ids: Set[str] = {
|
||||
plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids
|
||||
}
|
||||
target_plugin_ids: Set[str] = {plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids}
|
||||
if loaded_root_plugin_ids := reload_root_ids & loaded_plugin_ids:
|
||||
target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids))
|
||||
|
||||
@@ -812,6 +1156,8 @@ class PluginRunner:
|
||||
},
|
||||
}
|
||||
reloaded_plugins: List[str] = []
|
||||
inactive_plugins: List[str] = []
|
||||
inactive_plugin_ids: Set[str] = set()
|
||||
|
||||
for load_plugin_id in load_order:
|
||||
if load_plugin_id in failed_plugins:
|
||||
@@ -822,10 +1168,28 @@ class PluginRunner:
|
||||
continue
|
||||
|
||||
_, manifest, _ = candidate
|
||||
unsatisfied_dependency_ids = [
|
||||
dependency.id
|
||||
for dependency in manifest.plugin_dependencies
|
||||
if dependency.id not in available_plugins
|
||||
or not self._loader.manifest_validator.is_plugin_dependency_satisfied(
|
||||
dependency,
|
||||
available_plugins[dependency.id],
|
||||
)
|
||||
]
|
||||
if unsatisfied_dependencies := self._loader.manifest_validator.get_unsatisfied_plugin_dependencies(
|
||||
manifest,
|
||||
available_plugin_versions=available_plugins,
|
||||
):
|
||||
if load_plugin_id not in reload_root_ids and any(
|
||||
dependency_id in inactive_plugin_ids for dependency_id in unsatisfied_dependency_ids
|
||||
):
|
||||
logger.info(
|
||||
f"插件 {load_plugin_id} 的依赖当前未激活,保留为未激活状态: {', '.join(unsatisfied_dependencies)}"
|
||||
)
|
||||
inactive_plugin_ids.add(load_plugin_id)
|
||||
inactive_plugins.append(load_plugin_id)
|
||||
continue
|
||||
failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}"
|
||||
continue
|
||||
|
||||
@@ -835,9 +1199,13 @@ class PluginRunner:
|
||||
continue
|
||||
|
||||
activated = await self._activate_plugin(meta)
|
||||
if not activated:
|
||||
if activated == PluginActivationStatus.FAILED:
|
||||
failed_plugins[load_plugin_id] = "插件初始化失败"
|
||||
continue
|
||||
if activated == PluginActivationStatus.INACTIVE:
|
||||
inactive_plugin_ids.add(load_plugin_id)
|
||||
inactive_plugins.append(load_plugin_id)
|
||||
continue
|
||||
|
||||
available_plugins[load_plugin_id] = meta.version
|
||||
reloaded_plugins.append(load_plugin_id)
|
||||
@@ -872,7 +1240,7 @@ class PluginRunner:
|
||||
rollback_failures[rollback_plugin_id] = str(exc)
|
||||
continue
|
||||
|
||||
if not restored:
|
||||
if restored != PluginActivationStatus.LOADED:
|
||||
rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
|
||||
|
||||
return ReloadPluginsResultPayload(
|
||||
@@ -880,29 +1248,40 @@ class PluginRunner:
|
||||
requested_plugin_ids=normalized_plugin_ids,
|
||||
reloaded_plugins=[],
|
||||
unloaded_plugins=unloaded_plugins,
|
||||
inactive_plugins=[],
|
||||
failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
|
||||
)
|
||||
|
||||
requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids)
|
||||
requested_plugin_success = all(
|
||||
plugin_id in reloaded_plugins or plugin_id in inactive_plugins for plugin_id in reload_root_ids
|
||||
)
|
||||
|
||||
return ReloadPluginsResultPayload(
|
||||
success=requested_plugin_success and not failed_plugins,
|
||||
requested_plugin_ids=normalized_plugin_ids,
|
||||
reloaded_plugins=reloaded_plugins,
|
||||
unloaded_plugins=unloaded_plugins,
|
||||
inactive_plugins=inactive_plugins,
|
||||
failed_plugins=failed_plugins,
|
||||
)
|
||||
|
||||
async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None:
|
||||
async def _notify_ready(
|
||||
self,
|
||||
loaded_plugins: List[str],
|
||||
failed_plugins: List[str],
|
||||
inactive_plugins: List[str],
|
||||
) -> None:
|
||||
"""通知 Host 当前 Runner 已完成插件初始化。
|
||||
|
||||
Args:
|
||||
loaded_plugins: 成功初始化的插件列表。
|
||||
failed_plugins: 初始化失败的插件列表。
|
||||
inactive_plugins: 因禁用或依赖不可用而未激活的插件列表。
|
||||
"""
|
||||
payload = RunnerReadyPayload(
|
||||
loaded_plugins=loaded_plugins,
|
||||
failed_plugins=failed_plugins,
|
||||
inactive_plugins=inactive_plugins,
|
||||
)
|
||||
await self._rpc_client.send_request(
|
||||
"runner.ready",
|
||||
@@ -1128,6 +1507,87 @@ class PluginRunner:
|
||||
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
return envelope.make_response(payload={"acknowledged": True})
|
||||
|
||||
async def _handle_inspect_plugin_config(self, envelope: Envelope) -> Envelope:
|
||||
"""处理插件配置元数据解析请求。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: RPC 响应信封。
|
||||
"""
|
||||
|
||||
try:
|
||||
payload = InspectPluginConfigPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
plugin_id = envelope.plugin_id
|
||||
meta, is_temporary_meta, error_message = self._resolve_plugin_meta_for_config_request(plugin_id)
|
||||
if meta is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||
error_message or f"未找到插件: {plugin_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
result = self._inspect_plugin_config(
|
||||
meta,
|
||||
config_data=payload.config_data,
|
||||
use_provided_config=payload.use_provided_config,
|
||||
suppress_errors=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
finally:
|
||||
if is_temporary_meta:
|
||||
self._loader.purge_plugin_modules(plugin_id, meta.plugin_dir)
|
||||
|
||||
return envelope.make_response(payload=result.model_dump())
|
||||
|
||||
async def _handle_validate_plugin_config(self, envelope: Envelope) -> Envelope:
|
||||
"""处理插件配置校验请求。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: RPC 响应信封。
|
||||
"""
|
||||
|
||||
try:
|
||||
payload = ValidatePluginConfigPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
plugin_id = envelope.plugin_id
|
||||
meta, is_temporary_meta, error_message = self._resolve_plugin_meta_for_config_request(plugin_id)
|
||||
if meta is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||
error_message or f"未找到插件: {plugin_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
inspection_result = self._inspect_plugin_config(
|
||||
meta,
|
||||
config_data=payload.config_data,
|
||||
use_provided_config=True,
|
||||
suppress_errors=False,
|
||||
)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
finally:
|
||||
if is_temporary_meta:
|
||||
self._loader.purge_plugin_modules(plugin_id, meta.plugin_dir)
|
||||
|
||||
result = ValidatePluginConfigResultPayload(
|
||||
success=True,
|
||||
normalized_config=inspection_result.normalized_config,
|
||||
changed=inspection_result.changed,
|
||||
)
|
||||
return envelope.make_response(payload=result.model_dump())
|
||||
|
||||
async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope:
|
||||
"""处理按插件 ID 的精确重载请求。
|
||||
|
||||
@@ -1189,6 +1649,7 @@ class PluginRunner:
|
||||
|
||||
async def _async_main() -> None:
|
||||
"""异步主入口"""
|
||||
blocked_plugin_reasons_raw = os.environ.get(ENV_BLOCKED_PLUGIN_REASONS, "")
|
||||
host_address = os.environ.pop(ENV_IPC_ADDRESS, "")
|
||||
external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "")
|
||||
session_token = os.environ.pop(ENV_SESSION_TOKEN, "")
|
||||
@@ -1208,14 +1669,30 @@ async def _async_main() -> None:
|
||||
logger.warning("外部依赖插件版本映射格式非法,已回退为空映射")
|
||||
external_plugin_ids = {}
|
||||
|
||||
try:
|
||||
blocked_plugin_reasons = json.loads(blocked_plugin_reasons_raw) if blocked_plugin_reasons_raw else {}
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("解析阻止加载插件原因映射失败,已回退为空映射")
|
||||
blocked_plugin_reasons = {}
|
||||
if not isinstance(blocked_plugin_reasons, dict):
|
||||
logger.warning("阻止加载插件原因映射格式非法,已回退为空映射")
|
||||
blocked_plugin_reasons = {}
|
||||
|
||||
runner_kwargs: Dict[str, Any] = {
|
||||
"external_available_plugins": {
|
||||
str(plugin_id): str(plugin_version) for plugin_id, plugin_version in external_plugin_ids.items()
|
||||
}
|
||||
}
|
||||
if blocked_plugin_reasons:
|
||||
runner_kwargs["blocked_plugin_reasons"] = {
|
||||
str(plugin_id): str(reason) for plugin_id, reason in blocked_plugin_reasons.items()
|
||||
}
|
||||
|
||||
runner = PluginRunner(
|
||||
host_address,
|
||||
session_token,
|
||||
plugin_dirs,
|
||||
external_available_plugins={
|
||||
str(plugin_id): str(plugin_version)
|
||||
for plugin_id, plugin_version in external_plugin_ids.items()
|
||||
},
|
||||
**runner_kwargs,
|
||||
)
|
||||
|
||||
# 注册信号处理
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
{
|
||||
"manifest_version": 2,
|
||||
"version": "2.0.0",
|
||||
"name": "Emoji插件 (Emoji Actions)",
|
||||
"description": "可以发送和管理 Emoji",
|
||||
"author": {
|
||||
"name": "SengokuCola",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"urls": {
|
||||
"repository": "https://github.com/MaiM-with-u/maibot",
|
||||
"homepage": "https://github.com/MaiM-with-u/maibot",
|
||||
"documentation": "https://github.com/MaiM-with-u/maibot",
|
||||
"issues": "https://github.com/MaiM-with-u/maibot/issues"
|
||||
},
|
||||
"host_application": {
|
||||
"min_version": "1.0.0",
|
||||
"max_version": "1.0.0"
|
||||
},
|
||||
"sdk": {
|
||||
"min_version": "2.0.0",
|
||||
"max_version": "2.99.99"
|
||||
},
|
||||
"dependencies": [],
|
||||
"capabilities": [
|
||||
"emoji.get_random",
|
||||
"message.get_recent",
|
||||
"message.build_readable",
|
||||
"llm.generate",
|
||||
"send.emoji",
|
||||
"config.get"
|
||||
],
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
"supported_locales": [
|
||||
"zh-CN"
|
||||
]
|
||||
},
|
||||
"id": "builtin.emoji-plugin"
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
"""Emoji 插件 — 新 SDK 版本
|
||||
|
||||
根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
|
||||
"""
|
||||
|
||||
from maibot_sdk import Action, MaiBotPlugin
|
||||
from maibot_sdk.types import ActivationType
|
||||
|
||||
import random
|
||||
|
||||
|
||||
class EmojiPlugin(MaiBotPlugin):
|
||||
"""表情包插件"""
|
||||
|
||||
@Action(
|
||||
"emoji",
|
||||
description="发送表情包辅助表达情绪",
|
||||
activation_type=ActivationType.RANDOM,
|
||||
activation_probability=0.3,
|
||||
parallel_action=True,
|
||||
action_require=[
|
||||
"发送表情包辅助表达情绪",
|
||||
"表达情绪时可以选择使用",
|
||||
"不要连续发送,如果你已经发过[表情包],就不要选择此动作",
|
||||
],
|
||||
associated_types=["emoji"],
|
||||
)
|
||||
async def handle_emoji(self, stream_id: str = "", reasoning: str = "", chat_id: str = "", **kwargs):
|
||||
"""执行表情动作"""
|
||||
reason = reasoning or "表达当前情绪"
|
||||
|
||||
# 1. 随机获取30个表情包
|
||||
sampled_emojis = await self.ctx.emoji.get_random(30)
|
||||
if not sampled_emojis:
|
||||
return False, "无法获取随机表情包"
|
||||
|
||||
# 2. 按情感分组
|
||||
emotion_map: dict[str, list] = {}
|
||||
for emoji in sampled_emojis:
|
||||
emo = emoji.get("emotion", "")
|
||||
if emo not in emotion_map:
|
||||
emotion_map[emo] = []
|
||||
emotion_map[emo].append(emoji)
|
||||
|
||||
available_emotions = list(emotion_map.keys())
|
||||
|
||||
if not available_emotions:
|
||||
# 无情感标签,随机发送
|
||||
chosen = random.choice(sampled_emojis)
|
||||
await self.ctx.send.emoji(chosen["base64"], stream_id)
|
||||
return True, "随机发送了表情包"
|
||||
|
||||
# 3. 获取最近消息作为上下文
|
||||
messages_text = ""
|
||||
if chat_id:
|
||||
recent_messages = await self.ctx.message.get_recent(chat_id=chat_id, limit=5)
|
||||
if recent_messages:
|
||||
messages_text = await self.ctx.message.build_readable(
|
||||
recent_messages,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
)
|
||||
|
||||
# 4. 构建 prompt 让 LLM 选择情感
|
||||
available_emotions_str = "\n".join(available_emotions)
|
||||
prompt = f"""你正在进行QQ聊天,你需要根据聊天记录,选出一个合适的情感标签。
|
||||
请你根据以下原因和聊天记录进行选择
|
||||
原因:{reason}
|
||||
聊天记录:
|
||||
{messages_text}
|
||||
|
||||
这里是可用的情感标签:
|
||||
{available_emotions_str}
|
||||
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
|
||||
"""
|
||||
|
||||
# 5. 调用 LLM
|
||||
llm_result = await self.ctx.llm.generate(prompt=prompt, model_name="utils")
|
||||
if not llm_result or not llm_result.get("success"):
|
||||
chosen = random.choice(sampled_emojis)
|
||||
await self.ctx.send.emoji(chosen["base64"], stream_id)
|
||||
return True, "LLM调用失败,随机发送了表情包"
|
||||
|
||||
chosen_emotion = llm_result.get("response", "").strip().replace('"', "").replace("'", "")
|
||||
|
||||
# 6. 根据选择的情感匹配表情包
|
||||
if chosen_emotion in emotion_map:
|
||||
chosen = random.choice(emotion_map[chosen_emotion])
|
||||
else:
|
||||
chosen = random.choice(sampled_emojis)
|
||||
|
||||
# 7. 发送
|
||||
send_ok = await self.ctx.send.emoji(chosen["base64"], stream_id)
|
||||
if send_ok:
|
||||
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
|
||||
return False, "发送表情包失败"
|
||||
|
||||
async def on_load(self) -> None:
|
||||
"""处理插件加载。"""
|
||||
|
||||
# 从插件配置读取 emoji_chance 来覆盖默认概率
|
||||
await self.ctx.config.get("emoji.emoji_chance")
|
||||
|
||||
async def on_unload(self) -> None:
|
||||
"""处理插件卸载。"""
|
||||
|
||||
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||
"""处理配置热重载事件。
|
||||
|
||||
Args:
|
||||
scope: 配置变更范围。
|
||||
config_data: 最新配置数据。
|
||||
version: 配置版本号。
|
||||
"""
|
||||
|
||||
del config_data
|
||||
del version
|
||||
if scope == "self":
|
||||
await self.ctx.config.get("emoji.emoji_chance")
|
||||
|
||||
|
||||
def create_plugin() -> EmojiPlugin:
|
||||
"""创建 Emoji 插件实例。
|
||||
|
||||
Returns:
|
||||
EmojiPlugin: 新的 Emoji 插件实例。
|
||||
"""
|
||||
|
||||
return EmojiPlugin()
|
||||
937
src/services/html_render_service.py
Normal file
937
src/services/html_render_service.py
Normal file
@@ -0,0 +1,937 @@
|
||||
"""HTML 浏览器渲染服务。
|
||||
|
||||
负责在 Host 侧复用已有浏览器,并将 HTML 内容渲染为 PNG 图片。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from importlib import metadata
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Tuple, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
|
||||
from src.common.logger import PROJECT_ROOT, get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.official_configs import PluginRuntimeRenderConfig
|
||||
|
||||
logger = get_logger("services.html_render_service")
|
||||
|
||||
_NETWORK_ALLOW_SCHEMES = frozenset({"about", "blob", "data", "file"})
|
||||
_WINDOWS_BROWSER_PATHS = (
|
||||
Path("C:/Program Files/Google/Chrome/Application/chrome.exe"),
|
||||
Path("C:/Program Files (x86)/Google/Chrome/Application/chrome.exe"),
|
||||
Path("C:/Program Files/Microsoft/Edge/Application/msedge.exe"),
|
||||
Path("C:/Program Files (x86)/Microsoft/Edge/Application/msedge.exe"),
|
||||
)
|
||||
_MACOS_BROWSER_PATHS = (
|
||||
Path("/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"),
|
||||
Path("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"),
|
||||
)
|
||||
_UNIX_BROWSER_NAMES = (
|
||||
"chromium",
|
||||
"chromium-browser",
|
||||
"google-chrome",
|
||||
"google-chrome-stable",
|
||||
"microsoft-edge",
|
||||
"msedge",
|
||||
)
|
||||
_PLAYWRIGHT_MANAGED_BROWSER_PREFIXES = ("chromium-", "chrome-", "chrome-headless-shell-")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HtmlRenderRequest:
|
||||
"""描述一次 HTML 转 PNG 请求。"""
|
||||
|
||||
html: str
|
||||
selector: str = "body"
|
||||
viewport_width: int = 900
|
||||
viewport_height: int = 500
|
||||
device_scale_factor: float = 2.0
|
||||
full_page: bool = False
|
||||
omit_background: bool = False
|
||||
wait_until: str = "load"
|
||||
wait_for_selector: str = ""
|
||||
wait_for_timeout_ms: int = 0
|
||||
timeout_ms: int = 10000
|
||||
allow_network: bool = False
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class HtmlRenderResult:
|
||||
"""描述一次 HTML 转 PNG 的输出结果。"""
|
||||
|
||||
image_base64: str
|
||||
mime_type: str
|
||||
width: int
|
||||
height: int
|
||||
render_ms: int
|
||||
|
||||
def to_payload(self) -> Dict[str, Any]:
|
||||
"""将结果序列化为能力层返回结构。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 可直接返回给插件运行时的结构化数据。
|
||||
"""
|
||||
|
||||
return {
|
||||
"image_base64": self.image_base64,
|
||||
"mime_type": self.mime_type,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"render_ms": self.render_ms,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ManagedBrowserRecord:
|
||||
"""记录 Playwright 托管浏览器的本地状态。"""
|
||||
|
||||
browser_name: str
|
||||
browsers_path: str
|
||||
install_source: Literal["auto_download", "existing_cache"]
|
||||
playwright_version: str
|
||||
recorded_at: str
|
||||
last_verified_at: str
|
||||
|
||||
def to_dict(self) -> Dict[str, str]:
|
||||
"""将浏览器记录转换为可持久化字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 可写入 JSON 文件的字典结构。
|
||||
"""
|
||||
|
||||
return {
|
||||
"browser_name": self.browser_name,
|
||||
"browsers_path": self.browsers_path,
|
||||
"install_source": self.install_source,
|
||||
"playwright_version": self.playwright_version,
|
||||
"recorded_at": self.recorded_at,
|
||||
"last_verified_at": self.last_verified_at,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: Dict[str, Any]) -> Optional["ManagedBrowserRecord"]:
|
||||
"""从字典中恢复浏览器状态记录。
|
||||
|
||||
Args:
|
||||
payload: 原始字典数据。
|
||||
|
||||
Returns:
|
||||
Optional[ManagedBrowserRecord]: 解析成功时返回记录对象,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
browser_name = str(payload.get("browser_name", "") or "").strip()
|
||||
browsers_path = str(payload.get("browsers_path", "") or "").strip()
|
||||
install_source = str(payload.get("install_source", "") or "").strip()
|
||||
playwright_version = str(payload.get("playwright_version", "") or "").strip()
|
||||
recorded_at = str(payload.get("recorded_at", "") or "").strip()
|
||||
last_verified_at = str(payload.get("last_verified_at", "") or "").strip()
|
||||
if not all([browser_name, browsers_path, install_source, playwright_version, recorded_at, last_verified_at]):
|
||||
return None
|
||||
if install_source not in {"auto_download", "existing_cache"}:
|
||||
return None
|
||||
validated_install_source = cast(Literal["auto_download", "existing_cache"], install_source)
|
||||
return cls(
|
||||
browser_name=browser_name,
|
||||
browsers_path=browsers_path,
|
||||
install_source=validated_install_source,
|
||||
playwright_version=playwright_version,
|
||||
recorded_at=recorded_at,
|
||||
last_verified_at=last_verified_at,
|
||||
)
|
||||
|
||||
|
||||
class HTMLRenderService:
|
||||
"""HTML 浏览器渲染服务。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化渲染服务。"""
|
||||
|
||||
self._browser: Any = None
|
||||
self._browser_lock: asyncio.Lock = asyncio.Lock()
|
||||
self._connected_via_cdp: bool = False
|
||||
self._playwright: Any = None
|
||||
self._render_count: int = 0
|
||||
self._render_semaphore: Optional[asyncio.Semaphore] = None
|
||||
self._render_semaphore_limit: int = 0
|
||||
|
||||
def _get_render_config(self) -> PluginRuntimeRenderConfig:
|
||||
"""读取当前插件运行时的浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
PluginRuntimeRenderConfig: 当前生效的浏览器渲染配置。
|
||||
"""
|
||||
|
||||
return config_manager.get_global_config().plugin_runtime.render
|
||||
|
||||
def _get_render_semaphore(self) -> asyncio.Semaphore:
|
||||
"""根据当前配置返回渲染并发信号量。
|
||||
|
||||
Returns:
|
||||
asyncio.Semaphore: 控制并发的信号量对象。
|
||||
"""
|
||||
|
||||
config = self._get_render_config()
|
||||
limit = max(1, int(config.concurrency_limit))
|
||||
if self._render_semaphore is None or self._render_semaphore_limit != limit:
|
||||
self._render_semaphore = asyncio.Semaphore(limit)
|
||||
self._render_semaphore_limit = limit
|
||||
return self._render_semaphore
|
||||
|
||||
async def render_html_to_png(self, request: HtmlRenderRequest) -> HtmlRenderResult:
|
||||
"""将 HTML 内容渲染为 PNG 图片。
|
||||
|
||||
Args:
|
||||
request: 本次渲染请求。
|
||||
|
||||
Returns:
|
||||
HtmlRenderResult: 渲染结果。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 浏览器能力被禁用、Playwright 不可用或浏览器启动失败时抛出。
|
||||
ValueError: 请求参数非法时抛出。
|
||||
"""
|
||||
|
||||
config = self._get_render_config()
|
||||
if not config.enabled:
|
||||
raise RuntimeError("插件运行时浏览器渲染能力已禁用")
|
||||
|
||||
normalized_request = self._normalize_request(request, config)
|
||||
semaphore = self._get_render_semaphore()
|
||||
async with semaphore:
|
||||
start_time = time.perf_counter()
|
||||
browser = await self._ensure_browser(config)
|
||||
context: Any = None
|
||||
try:
|
||||
context = await browser.new_context(
|
||||
device_scale_factor=normalized_request.device_scale_factor,
|
||||
locale="zh-CN",
|
||||
viewport={
|
||||
"width": normalized_request.viewport_width,
|
||||
"height": normalized_request.viewport_height,
|
||||
},
|
||||
)
|
||||
page = await context.new_page()
|
||||
await self._configure_page(page, normalized_request)
|
||||
image_bytes = await self._capture_image(page, normalized_request)
|
||||
width, height = self._measure_image_size(image_bytes)
|
||||
self._render_count += 1
|
||||
await self._maybe_restart_browser(config)
|
||||
return HtmlRenderResult(
|
||||
image_base64=base64.b64encode(image_bytes).decode("utf-8"),
|
||||
mime_type="image/png",
|
||||
width=width,
|
||||
height=height,
|
||||
render_ms=int((time.perf_counter() - start_time) * 1000),
|
||||
)
|
||||
except Exception:
|
||||
await self.reset_browser(restart_playwright=False)
|
||||
raise
|
||||
finally:
|
||||
if context is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await context.close()
|
||||
|
||||
async def reset_browser(self, restart_playwright: bool = False) -> None:
|
||||
"""关闭当前缓存的浏览器实例。
|
||||
|
||||
Args:
|
||||
restart_playwright: 是否同时关闭 Playwright 运行时。
|
||||
"""
|
||||
|
||||
async with self._browser_lock:
|
||||
await self._close_browser_unlocked(restart_playwright=restart_playwright)
|
||||
|
||||
async def _close_browser_unlocked(self, restart_playwright: bool = False) -> None:
|
||||
"""在已持有锁的情况下关闭浏览器与 Playwright。
|
||||
|
||||
Args:
|
||||
restart_playwright: 是否同时关闭 Playwright 运行时。
|
||||
"""
|
||||
|
||||
if self._browser is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await self._browser.close()
|
||||
self._browser = None
|
||||
self._connected_via_cdp = False
|
||||
if restart_playwright and self._playwright is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await self._playwright.stop()
|
||||
self._playwright = None
|
||||
|
||||
async def _ensure_browser(self, config: PluginRuntimeRenderConfig) -> Any:
|
||||
"""获取可复用的浏览器实例。
|
||||
|
||||
Args:
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
Any: Playwright Browser 对象。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当无法连接或启动浏览器时抛出。
|
||||
"""
|
||||
|
||||
async with self._browser_lock:
|
||||
if self._is_browser_connected(self._browser):
|
||||
logger.debug("HTML 渲染服务复用进程内缓存浏览器实例")
|
||||
return self._browser
|
||||
|
||||
await self._close_browser_unlocked(restart_playwright=False)
|
||||
self._prepare_playwright_environment(config)
|
||||
playwright = await self._ensure_playwright()
|
||||
browser = await self._connect_to_existing_browser(playwright, config)
|
||||
if browser is None:
|
||||
browser = await self._launch_browser(playwright, config)
|
||||
self._connected_via_cdp = False
|
||||
else:
|
||||
self._connected_via_cdp = True
|
||||
|
||||
self._browser = browser
|
||||
self._bind_browser_events(browser)
|
||||
return browser
|
||||
|
||||
async def _ensure_playwright(self) -> Any:
|
||||
"""懒加载并启动 Playwright 运行时。
|
||||
|
||||
Returns:
|
||||
Any: 已启动的 Playwright 对象。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当前环境未安装 Playwright 时抛出。
|
||||
"""
|
||||
|
||||
if self._playwright is not None:
|
||||
return self._playwright
|
||||
|
||||
try:
|
||||
from playwright.async_api import async_playwright
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
"当前环境未安装 Python Playwright,请先在宿主环境安装 `playwright` 依赖。"
|
||||
) from exc
|
||||
|
||||
self._playwright = await async_playwright().start()
|
||||
return self._playwright
|
||||
|
||||
@staticmethod
|
||||
def _is_browser_connected(browser: Any) -> bool:
|
||||
"""判断浏览器对象当前是否仍然可用。
|
||||
|
||||
Args:
|
||||
browser: 待检查的浏览器对象。
|
||||
|
||||
Returns:
|
||||
bool: 若浏览器仍连接,则返回 ``True``。
|
||||
"""
|
||||
|
||||
if browser is None:
|
||||
return False
|
||||
try:
|
||||
return bool(browser.is_connected())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _connect_to_existing_browser(self, playwright: Any, config: PluginRuntimeRenderConfig) -> Any:
|
||||
"""优先连接外部已有的 Chromium 浏览器。
|
||||
|
||||
Args:
|
||||
playwright: 已启动的 Playwright 对象。
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
Any: 连接成功时返回 Browser;否则返回 ``None``。
|
||||
"""
|
||||
|
||||
if not config.browser_ws_endpoint.strip():
|
||||
return None
|
||||
|
||||
try:
|
||||
timeout_ms = int(config.startup_timeout_sec * 1000)
|
||||
logger.info(
|
||||
"HTML 渲染服务准备连接现有浏览器: "
|
||||
f"endpoint={config.browser_ws_endpoint.strip()}, timeout_ms={timeout_ms}"
|
||||
)
|
||||
browser = await playwright.chromium.connect_over_cdp(
|
||||
config.browser_ws_endpoint.strip(),
|
||||
timeout=timeout_ms,
|
||||
)
|
||||
logger.info("HTML 渲染服务已连接到现有浏览器")
|
||||
return browser
|
||||
except Exception as exc:
|
||||
logger.warning(f"连接现有浏览器失败,将回退为本地启动: {exc}")
|
||||
return None
|
||||
|
||||
async def _launch_browser(self, playwright: Any, config: PluginRuntimeRenderConfig) -> Any:
|
||||
"""启动本地 Chromium 浏览器。
|
||||
|
||||
Args:
|
||||
playwright: 已启动的 Playwright 对象。
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
Any: 新启动的 Browser 对象。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 浏览器启动失败时抛出。
|
||||
"""
|
||||
|
||||
launch_options = self._build_launch_options(config)
|
||||
logger.info(
|
||||
"HTML 渲染服务准备启动浏览器: "
|
||||
f"source={'system' if 'executable_path' in launch_options else 'managed'}, "
|
||||
f"headless={bool(launch_options.get('headless'))}, "
|
||||
f"timeout_ms={int(launch_options.get('timeout', 0))}"
|
||||
)
|
||||
try:
|
||||
browser = await playwright.chromium.launch(**launch_options)
|
||||
if "executable_path" in launch_options:
|
||||
logger.info(f"HTML 渲染服务已启动本机浏览器: executable_path={launch_options['executable_path']}")
|
||||
else:
|
||||
self._update_managed_browser_record(config, install_source="existing_cache")
|
||||
logger.info("HTML 渲染服务已启动 Playwright 托管浏览器")
|
||||
return browser
|
||||
except Exception as exc:
|
||||
if self._should_auto_download_browser(exc, launch_options, config):
|
||||
logger.warning(f"HTML 渲染服务未找到可用浏览器,将尝试自动下载 Chromium: {exc}")
|
||||
await self._install_chromium_browser(config)
|
||||
retry_browser = await playwright.chromium.launch(**launch_options)
|
||||
self._update_managed_browser_record(config, install_source="auto_download")
|
||||
logger.info("HTML 渲染服务已自动下载并启动 Chromium")
|
||||
return retry_browser
|
||||
raise RuntimeError(f"启动本地浏览器失败: {exc}") from exc
|
||||
|
||||
def _bind_browser_events(self, browser: Any) -> None:
|
||||
"""为浏览器绑定断线回调。
|
||||
|
||||
Args:
|
||||
browser: 需要绑定事件的浏览器对象。
|
||||
"""
|
||||
|
||||
try:
|
||||
browser.on("disconnected", self._handle_browser_disconnected)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def _handle_browser_disconnected(self, *_args: Any) -> None:
|
||||
"""处理浏览器断线事件。
|
||||
|
||||
Args:
|
||||
*_args: 浏览器断线事件透传的参数。
|
||||
"""
|
||||
|
||||
self._browser = None
|
||||
self._connected_via_cdp = False
|
||||
logger.warning("HTML 渲染浏览器已断开,将在下次请求时重新建立连接")
|
||||
|
||||
def _build_launch_options(self, config: PluginRuntimeRenderConfig) -> Dict[str, Any]:
|
||||
"""构造本地浏览器启动参数。
|
||||
|
||||
Args:
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 可直接传给 Playwright 的启动参数。
|
||||
"""
|
||||
|
||||
launch_options: Dict[str, Any] = {
|
||||
"args": list(config.launch_args),
|
||||
"headless": bool(config.headless),
|
||||
"timeout": int(config.startup_timeout_sec * 1000),
|
||||
}
|
||||
executable_path = self._resolve_executable_path(config)
|
||||
if executable_path:
|
||||
launch_options["executable_path"] = executable_path
|
||||
return launch_options
|
||||
|
||||
@staticmethod
|
||||
def _should_auto_download_browser(
|
||||
exc: Exception,
|
||||
launch_options: Dict[str, Any],
|
||||
config: PluginRuntimeRenderConfig,
|
||||
) -> bool:
|
||||
"""判断当前启动错误是否适合自动下载 Chromium 后重试。
|
||||
|
||||
Args:
|
||||
exc: 浏览器启动异常。
|
||||
launch_options: 本次启动参数。
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
bool: 若应自动下载后重试,则返回 ``True``。
|
||||
"""
|
||||
|
||||
if "executable_path" in launch_options:
|
||||
logger.debug("当前启动参数已指定本机浏览器路径,不进入自动下载分支")
|
||||
return False
|
||||
if not config.auto_download_chromium:
|
||||
logger.warning("HTML 渲染服务未检测到可用浏览器,且已禁用自动下载 Chromium")
|
||||
return False
|
||||
error_text = str(exc).lower()
|
||||
should_download = "executable doesn't exist" in error_text or "browser executable" in error_text
|
||||
if not should_download:
|
||||
logger.warning(f"浏览器启动失败,但错误不属于可自动下载恢复的类型: {exc}")
|
||||
return should_download
|
||||
|
||||
def _resolve_executable_path(self, config: PluginRuntimeRenderConfig) -> str:
|
||||
"""解析实际应使用的浏览器可执行文件路径。
|
||||
|
||||
Args:
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
str: 命中的浏览器可执行文件路径;未命中时返回空字符串。
|
||||
"""
|
||||
|
||||
configured_path = config.executable_path.strip()
|
||||
if configured_path:
|
||||
path = Path(configured_path).expanduser()
|
||||
if path.exists():
|
||||
logger.info(f"HTML 渲染服务使用配置指定的浏览器路径: {path}")
|
||||
return str(path)
|
||||
logger.warning(f"配置的浏览器路径不存在,将尝试自动探测: {configured_path}")
|
||||
|
||||
detected_path = self._detect_local_browser_executable()
|
||||
if detected_path:
|
||||
logger.info(f"HTML 渲染服务自动探测到本机浏览器: {detected_path}")
|
||||
else:
|
||||
logger.info("HTML 渲染服务未探测到本机浏览器,将尝试使用 Playwright 托管浏览器")
|
||||
return detected_path
|
||||
|
||||
def _prepare_playwright_environment(self, config: PluginRuntimeRenderConfig) -> Path:
|
||||
"""准备 Playwright 运行所需的共享浏览器目录环境变量。
|
||||
|
||||
Args:
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
Path: Playwright 浏览器缓存目录。
|
||||
"""
|
||||
|
||||
browsers_path = self._get_managed_browsers_path(config)
|
||||
browsers_path.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["PLAYWRIGHT_BROWSERS_PATH"] = str(browsers_path)
|
||||
logger.debug(f"HTML 渲染服务使用 Playwright 浏览器目录: {browsers_path}")
|
||||
return browsers_path
|
||||
|
||||
def _get_managed_browsers_path(self, config: PluginRuntimeRenderConfig) -> Path:
|
||||
"""获取 Playwright 托管浏览器目录。
|
||||
|
||||
Args:
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
Path: 托管浏览器目录的绝对路径。
|
||||
"""
|
||||
|
||||
configured_path = config.browser_install_root.strip()
|
||||
if not configured_path:
|
||||
return (PROJECT_ROOT / "data" / "playwright-browsers").resolve()
|
||||
candidate_path = Path(configured_path).expanduser()
|
||||
if candidate_path.is_absolute():
|
||||
return candidate_path.resolve()
|
||||
return (PROJECT_ROOT / candidate_path).resolve()
|
||||
|
||||
def _get_browser_state_path(self) -> Path:
|
||||
"""获取托管浏览器状态文件路径。
|
||||
|
||||
Returns:
|
||||
Path: 浏览器状态文件路径。
|
||||
"""
|
||||
|
||||
return (PROJECT_ROOT / "data" / "plugin_runtime" / "html_render_browser_state.json").resolve()
|
||||
|
||||
def _load_managed_browser_record(self) -> Optional[ManagedBrowserRecord]:
|
||||
"""读取最近一次成功使用的托管浏览器记录。
|
||||
|
||||
Returns:
|
||||
Optional[ManagedBrowserRecord]: 解析成功时返回记录对象,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
state_path = self._get_browser_state_path()
|
||||
if not state_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
raw_payload = json.loads(state_path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning(f"HTML 渲染浏览器状态文件读取失败,将忽略并继续: {state_path}")
|
||||
return None
|
||||
if not isinstance(raw_payload, dict):
|
||||
logger.warning(f"HTML 渲染浏览器状态文件格式无效,将忽略并继续: {state_path}")
|
||||
return None
|
||||
browser_record = ManagedBrowserRecord.from_dict(raw_payload)
|
||||
if browser_record is not None:
|
||||
logger.debug(
|
||||
"HTML 渲染服务已加载浏览器状态记录: "
|
||||
f"source={browser_record.install_source}, path={browser_record.browsers_path}, "
|
||||
f"verified_at={browser_record.last_verified_at}"
|
||||
)
|
||||
return browser_record
|
||||
|
||||
def _save_managed_browser_record(self, record: ManagedBrowserRecord) -> None:
|
||||
"""保存托管浏览器记录。
|
||||
|
||||
Args:
|
||||
record: 待保存的浏览器记录。
|
||||
"""
|
||||
|
||||
state_path = self._get_browser_state_path()
|
||||
state_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
state_path.write_text(
|
||||
json.dumps(record.to_dict(), ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
logger.info(
|
||||
"HTML 渲染服务已写入浏览器状态记录: "
|
||||
f"path={state_path}, source={record.install_source}, browsers_path={record.browsers_path}"
|
||||
)
|
||||
|
||||
def _update_managed_browser_record(
|
||||
self,
|
||||
config: PluginRuntimeRenderConfig,
|
||||
install_source: Literal["auto_download", "existing_cache"],
|
||||
) -> None:
|
||||
"""更新托管 Chromium 的使用记录。
|
||||
|
||||
Args:
|
||||
config: 当前浏览器渲染配置。
|
||||
install_source: 本次记录的浏览器来源。
|
||||
"""
|
||||
|
||||
browsers_path = self._get_managed_browsers_path(config)
|
||||
if not self._has_managed_browser_artifact(browsers_path):
|
||||
return
|
||||
|
||||
now_iso = datetime.now(timezone.utc).isoformat()
|
||||
existing_record = self._load_managed_browser_record()
|
||||
recorded_at = now_iso
|
||||
if existing_record is not None and existing_record.browsers_path == str(browsers_path):
|
||||
recorded_at = existing_record.recorded_at
|
||||
|
||||
self._save_managed_browser_record(
|
||||
ManagedBrowserRecord(
|
||||
browser_name="chromium",
|
||||
browsers_path=str(browsers_path),
|
||||
install_source=install_source,
|
||||
playwright_version=self._get_playwright_version(),
|
||||
recorded_at=recorded_at,
|
||||
last_verified_at=now_iso,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"HTML 渲染服务已更新托管浏览器记录: "
|
||||
f"source={install_source}, browsers_path={browsers_path}, last_verified_at={now_iso}"
|
||||
)
|
||||
|
||||
async def _install_chromium_browser(self, config: PluginRuntimeRenderConfig) -> None:
|
||||
"""自动下载 Playwright Chromium 浏览器。
|
||||
|
||||
Args:
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 下载失败时抛出。
|
||||
"""
|
||||
|
||||
browsers_path = self._prepare_playwright_environment(config)
|
||||
logger.warning(
|
||||
"HTML 渲染服务开始自动下载 Chromium: "
|
||||
f"target_dir={browsers_path}, timeout_sec={config.download_connection_timeout_sec}"
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["PLAYWRIGHT_BROWSERS_PATH"] = str(browsers_path)
|
||||
env["PLAYWRIGHT_DOWNLOAD_CONNECTION_TIMEOUT"] = str(int(config.download_connection_timeout_sec * 1000))
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
sys.executable,
|
||||
"-m",
|
||||
"playwright",
|
||||
"install",
|
||||
"chromium",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
stdout_bytes, stderr_bytes = await process.communicate()
|
||||
if process.returncode != 0:
|
||||
stderr_text = stderr_bytes.decode("utf-8", errors="ignore").strip()
|
||||
stdout_text = stdout_bytes.decode("utf-8", errors="ignore").strip()
|
||||
error_detail = stderr_text or stdout_text or f"退出码 {process.returncode}"
|
||||
raise RuntimeError(f"自动下载 Chromium 失败: {error_detail}")
|
||||
|
||||
if not self._has_managed_browser_artifact(browsers_path):
|
||||
raise RuntimeError("Chromium 下载完成后未检测到可用浏览器文件")
|
||||
logger.info(f"HTML 渲染服务自动下载 Chromium 完成: target_dir={browsers_path}")
|
||||
|
||||
@staticmethod
|
||||
def _get_playwright_version() -> str:
|
||||
"""读取当前环境中的 Playwright 版本号。
|
||||
|
||||
Returns:
|
||||
str: Playwright 版本字符串;读取失败时返回 ``unknown``。
|
||||
"""
|
||||
|
||||
try:
|
||||
return metadata.version("playwright")
|
||||
except metadata.PackageNotFoundError:
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def _has_managed_browser_artifact(browsers_path: Path) -> bool:
|
||||
"""检查共享目录中是否存在可用的 Playwright 托管浏览器。
|
||||
|
||||
Args:
|
||||
browsers_path: Playwright 浏览器目录。
|
||||
|
||||
Returns:
|
||||
bool: 若检测到 Chromium/Chrome 相关浏览器文件夹,则返回 ``True``。
|
||||
"""
|
||||
|
||||
if not browsers_path.exists():
|
||||
return False
|
||||
for child_path in browsers_path.iterdir():
|
||||
if not child_path.is_dir():
|
||||
continue
|
||||
if child_path.name.startswith(_PLAYWRIGHT_MANAGED_BROWSER_PREFIXES):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _detect_local_browser_executable(self) -> str:
|
||||
"""自动探测当前宿主系统中的可复用浏览器路径。
|
||||
|
||||
Returns:
|
||||
str: 命中的浏览器可执行文件路径;未命中时返回空字符串。
|
||||
"""
|
||||
|
||||
for browser_name in _UNIX_BROWSER_NAMES:
|
||||
resolved_path = shutil.which(browser_name)
|
||||
if resolved_path:
|
||||
return resolved_path
|
||||
|
||||
for candidate_path in self._get_candidate_executable_paths():
|
||||
if candidate_path.exists():
|
||||
return str(candidate_path)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _get_candidate_executable_paths() -> Tuple[Path, ...]:
|
||||
"""返回当前平台常见浏览器路径候选集合。
|
||||
|
||||
Returns:
|
||||
Tuple[Path, ...]: 可能存在浏览器可执行文件的路径列表。
|
||||
"""
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
return _WINDOWS_BROWSER_PATHS
|
||||
if sys.platform == "darwin":
|
||||
return _MACOS_BROWSER_PATHS
|
||||
return ()
|
||||
|
||||
async def _configure_page(self, page: Any, request: HtmlRenderRequest) -> None:
|
||||
"""为页面设置超时、网络策略并写入 HTML。
|
||||
|
||||
Args:
|
||||
page: Playwright 页面对象。
|
||||
request: 当前渲染请求。
|
||||
"""
|
||||
|
||||
page.set_default_timeout(request.timeout_ms)
|
||||
await page.route(
|
||||
"**/*",
|
||||
functools.partial(self._handle_network_route, allow_network=request.allow_network),
|
||||
)
|
||||
await page.set_content(
|
||||
request.html,
|
||||
timeout=request.timeout_ms,
|
||||
wait_until=request.wait_until,
|
||||
)
|
||||
if request.wait_for_selector:
|
||||
await page.locator(request.wait_for_selector).first.wait_for(
|
||||
state="attached",
|
||||
timeout=request.timeout_ms,
|
||||
)
|
||||
if request.wait_for_timeout_ms > 0:
|
||||
await page.wait_for_timeout(request.wait_for_timeout_ms)
|
||||
|
||||
async def _handle_network_route(self, route: Any, allow_network: bool) -> None:
|
||||
"""处理页面资源请求的网络准入策略。
|
||||
|
||||
Args:
|
||||
route: Playwright 路由对象。
|
||||
allow_network: 是否允许页面访问外部网络资源。
|
||||
"""
|
||||
|
||||
request_url = str(route.request.url)
|
||||
if allow_network or self._is_network_request_allowed(request_url):
|
||||
await route.continue_()
|
||||
return
|
||||
await route.abort()
|
||||
|
||||
@staticmethod
|
||||
def _is_network_request_allowed(request_url: str) -> bool:
|
||||
"""判断某个资源 URL 是否属于本地安全资源。
|
||||
|
||||
Args:
|
||||
request_url: 待判断的资源地址。
|
||||
|
||||
Returns:
|
||||
bool: 若请求可在无网络模式下放行,则返回 ``True``。
|
||||
"""
|
||||
|
||||
if not request_url:
|
||||
return False
|
||||
parsed_url = urlparse(request_url)
|
||||
return parsed_url.scheme in _NETWORK_ALLOW_SCHEMES
|
||||
|
||||
async def _capture_image(self, page: Any, request: HtmlRenderRequest) -> bytes:
|
||||
"""从页面或目标元素中截取 PNG 图片。
|
||||
|
||||
Args:
|
||||
page: Playwright 页面对象。
|
||||
request: 当前渲染请求。
|
||||
|
||||
Returns:
|
||||
bytes: PNG 二进制内容。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 目标元素不存在或截图结果为空时抛出。
|
||||
"""
|
||||
|
||||
if request.full_page and request.selector == "body":
|
||||
image_bytes = await page.screenshot(
|
||||
full_page=True,
|
||||
omit_background=request.omit_background,
|
||||
timeout=request.timeout_ms,
|
||||
type="png",
|
||||
)
|
||||
else:
|
||||
locator = page.locator(request.selector).first
|
||||
await locator.wait_for(state="visible", timeout=request.timeout_ms)
|
||||
image_bytes = await locator.screenshot(
|
||||
omit_background=request.omit_background,
|
||||
timeout=request.timeout_ms,
|
||||
type="png",
|
||||
)
|
||||
|
||||
if not image_bytes:
|
||||
raise RuntimeError("浏览器截图结果为空")
|
||||
return image_bytes
|
||||
|
||||
@staticmethod
|
||||
def _measure_image_size(image_bytes: bytes) -> Tuple[int, int]:
|
||||
"""读取 PNG 图片的真实像素尺寸。
|
||||
|
||||
Args:
|
||||
image_bytes: PNG 图片二进制内容。
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: 图片宽高像素值。
|
||||
"""
|
||||
|
||||
from PIL import Image
|
||||
|
||||
with Image.open(BytesIO(image_bytes)) as image:
|
||||
return int(image.width), int(image.height)
|
||||
|
||||
async def _maybe_restart_browser(self, config: PluginRuntimeRenderConfig) -> None:
|
||||
"""按策略决定是否重建本地浏览器实例。
|
||||
|
||||
Args:
|
||||
config: 当前浏览器渲染配置。
|
||||
"""
|
||||
|
||||
restart_after = int(config.restart_after_render_count)
|
||||
if restart_after <= 0 or self._connected_via_cdp:
|
||||
return
|
||||
if self._render_count % restart_after != 0:
|
||||
return
|
||||
await self.reset_browser(restart_playwright=False)
|
||||
logger.info("HTML 渲染服务已按累计次数策略重建本地浏览器")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_request(
|
||||
request: HtmlRenderRequest,
|
||||
config: PluginRuntimeRenderConfig,
|
||||
) -> HtmlRenderRequest:
|
||||
"""规范化并补齐 HTML 渲染请求。
|
||||
|
||||
Args:
|
||||
request: 原始渲染请求。
|
||||
config: 当前浏览器渲染配置。
|
||||
|
||||
Returns:
|
||||
HtmlRenderRequest: 规范化后的请求对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 请求缺少必要字段或取值非法时抛出。
|
||||
"""
|
||||
|
||||
html = request.html.strip()
|
||||
if not html:
|
||||
raise ValueError("缺少必要参数 html")
|
||||
|
||||
selector = request.selector.strip() or "body"
|
||||
wait_until = HTMLRenderService._normalize_wait_until(request.wait_until)
|
||||
timeout_ms = request.timeout_ms
|
||||
if timeout_ms <= 0:
|
||||
timeout_ms = int(config.render_timeout_sec * 1000)
|
||||
|
||||
return HtmlRenderRequest(
|
||||
html=html,
|
||||
selector=selector,
|
||||
viewport_width=max(1, int(request.viewport_width)),
|
||||
viewport_height=max(1, int(request.viewport_height)),
|
||||
device_scale_factor=max(1.0, float(request.device_scale_factor)),
|
||||
full_page=bool(request.full_page),
|
||||
omit_background=bool(request.omit_background),
|
||||
wait_until=wait_until,
|
||||
wait_for_selector=request.wait_for_selector.strip(),
|
||||
wait_for_timeout_ms=max(0, int(request.wait_for_timeout_ms)),
|
||||
timeout_ms=max(1, int(timeout_ms)),
|
||||
allow_network=bool(request.allow_network),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_wait_until(wait_until: str) -> str:
|
||||
"""规范化页面等待阶段参数。
|
||||
|
||||
Args:
|
||||
wait_until: 原始等待阶段字符串。
|
||||
|
||||
Returns:
|
||||
str: Playwright 支持的等待阶段值。
|
||||
"""
|
||||
|
||||
normalized_wait_until = wait_until.strip().lower()
|
||||
if normalized_wait_until in {"commit", "domcontentloaded", "load", "networkidle"}:
|
||||
return normalized_wait_until
|
||||
return "load"
|
||||
|
||||
|
||||
_html_render_service: Optional[HTMLRenderService] = None
|
||||
|
||||
|
||||
def get_html_render_service() -> HTMLRenderService:
|
||||
"""获取 HTML 浏览器渲染服务单例。
|
||||
|
||||
Returns:
|
||||
HTMLRenderService: 全局唯一的浏览器渲染服务实例。
|
||||
"""
|
||||
|
||||
global _html_render_service
|
||||
if _html_render_service is None:
|
||||
_html_render_service = HTMLRenderService()
|
||||
return _html_render_service
|
||||
@@ -40,10 +40,213 @@ from src.common.utils.utils_message import MessageUtils
|
||||
from src.config.config import global_config
|
||||
from src.platform_io import DeliveryBatch, get_platform_io_manager
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
logger = get_logger("send_service")
|
||||
|
||||
|
||||
def register_send_service_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
"""注册发送服务内置 Hook 规格。
|
||||
|
||||
Args:
|
||||
registry: 目标 Hook 规格注册中心。
|
||||
|
||||
Returns:
|
||||
List[HookSpec]: 实际注册的 Hook 规格列表。
|
||||
"""
|
||||
|
||||
return registry.register_hook_specs(
|
||||
[
|
||||
HookSpec(
|
||||
name="send_service.after_build_message",
|
||||
description="在出站 SessionMessage 构建完成后触发,可改写消息体或取消发送。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "待发送消息的序列化 SessionMessage。",
|
||||
},
|
||||
"stream_id": {
|
||||
"type": "string",
|
||||
"description": "目标会话 ID。",
|
||||
},
|
||||
"display_message": {
|
||||
"type": "string",
|
||||
"description": "展示层文本。",
|
||||
},
|
||||
"typing": {
|
||||
"type": "boolean",
|
||||
"description": "是否模拟打字。",
|
||||
},
|
||||
"set_reply": {
|
||||
"type": "boolean",
|
||||
"description": "是否附带引用回复。",
|
||||
},
|
||||
"storage_message": {
|
||||
"type": "boolean",
|
||||
"description": "发送成功后是否写库。",
|
||||
},
|
||||
"show_log": {
|
||||
"type": "boolean",
|
||||
"description": "是否输出发送日志。",
|
||||
},
|
||||
},
|
||||
required=[
|
||||
"message",
|
||||
"stream_id",
|
||||
"display_message",
|
||||
"typing",
|
||||
"set_reply",
|
||||
"storage_message",
|
||||
"show_log",
|
||||
],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="send_service.before_send",
|
||||
description="在真正调用 Platform IO 发送前触发,可改写消息或取消本次发送。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "待发送消息的序列化 SessionMessage。",
|
||||
},
|
||||
"typing": {
|
||||
"type": "boolean",
|
||||
"description": "是否模拟打字。",
|
||||
},
|
||||
"set_reply": {
|
||||
"type": "boolean",
|
||||
"description": "是否附带引用回复。",
|
||||
},
|
||||
"reply_message_id": {
|
||||
"type": "string",
|
||||
"description": "被引用消息 ID。",
|
||||
},
|
||||
"storage_message": {
|
||||
"type": "boolean",
|
||||
"description": "发送成功后是否写库。",
|
||||
},
|
||||
"show_log": {
|
||||
"type": "boolean",
|
||||
"description": "是否输出发送日志。",
|
||||
},
|
||||
},
|
||||
required=["message", "typing", "set_reply", "storage_message", "show_log"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="send_service.after_send",
|
||||
description="在发送流程结束后触发,用于观察最终发送结果。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"message": {
|
||||
"type": "object",
|
||||
"description": "本次发送消息的序列化 SessionMessage。",
|
||||
},
|
||||
"sent": {
|
||||
"type": "boolean",
|
||||
"description": "本次发送是否成功。",
|
||||
},
|
||||
"typing": {
|
||||
"type": "boolean",
|
||||
"description": "是否模拟打字。",
|
||||
},
|
||||
"set_reply": {
|
||||
"type": "boolean",
|
||||
"description": "是否附带引用回复。",
|
||||
},
|
||||
"reply_message_id": {
|
||||
"type": "string",
|
||||
"description": "被引用消息 ID。",
|
||||
},
|
||||
"storage_message": {
|
||||
"type": "boolean",
|
||||
"description": "发送成功后是否写库。",
|
||||
},
|
||||
"show_log": {
|
||||
"type": "boolean",
|
||||
"description": "是否输出发送日志。",
|
||||
},
|
||||
},
|
||||
required=["message", "sent", "typing", "set_reply", "storage_message", "show_log"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=False,
|
||||
allow_kwargs_mutation=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool) -> bool:
|
||||
"""将任意值安全转换为布尔值。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 当值为空时使用的默认值。
|
||||
|
||||
Returns:
|
||||
bool: 转换后的布尔值。
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return default
|
||||
return bool(value)
|
||||
|
||||
|
||||
async def _invoke_send_hook(
|
||||
hook_name: str,
|
||||
message: SessionMessage,
|
||||
**kwargs: Any,
|
||||
) -> tuple[HookDispatchResult, SessionMessage]:
|
||||
"""触发携带出站消息的命名 Hook。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
message: 当前待发送消息。
|
||||
**kwargs: 需要附带的额外参数。
|
||||
|
||||
Returns:
|
||||
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
|
||||
"""
|
||||
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
hook_name,
|
||||
message=serialize_session_message(message),
|
||||
**kwargs,
|
||||
)
|
||||
mutated_message = message
|
||||
raw_message = hook_result.kwargs.get("message")
|
||||
if raw_message is not None:
|
||||
try:
|
||||
mutated_message = deserialize_session_message(raw_message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
|
||||
return hook_result, mutated_message
|
||||
|
||||
|
||||
def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]:
|
||||
"""从目标会话继承 Platform IO 路由元数据。
|
||||
|
||||
@@ -484,6 +687,27 @@ async def _send_via_platform_io(
|
||||
Returns:
|
||||
bool: 发送成功时返回 ``True``。
|
||||
"""
|
||||
before_send_result, message = await _invoke_send_hook(
|
||||
"send_service.before_send",
|
||||
message,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message_id=reply_message_id,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
if before_send_result.aborted:
|
||||
logger.info(f"[SendService] 消息 {message.message_id} 在发送前被 Hook 中止")
|
||||
return False
|
||||
|
||||
before_kwargs = before_send_result.kwargs
|
||||
typing = _coerce_bool(before_kwargs.get("typing"), typing)
|
||||
set_reply = _coerce_bool(before_kwargs.get("set_reply"), set_reply)
|
||||
storage_message = _coerce_bool(before_kwargs.get("storage_message"), storage_message)
|
||||
show_log = _coerce_bool(before_kwargs.get("show_log"), show_log)
|
||||
raw_reply_message_id = before_kwargs.get("reply_message_id", reply_message_id)
|
||||
reply_message_id = None if raw_reply_message_id in {None, ""} else str(raw_reply_message_id)
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
try:
|
||||
await platform_io_manager.ensure_send_pipeline_ready()
|
||||
@@ -515,6 +739,18 @@ async def _send_via_platform_io(
|
||||
logger.debug(traceback.format_exc())
|
||||
return False
|
||||
|
||||
sent = bool(delivery_batch.has_success)
|
||||
await _invoke_send_hook(
|
||||
"send_service.after_send",
|
||||
message,
|
||||
sent=sent,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message_id=reply_message_id,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
if delivery_batch.has_success:
|
||||
if storage_message:
|
||||
_store_sent_message(message)
|
||||
@@ -622,6 +858,26 @@ async def _send_to_target(
|
||||
if outbound_message is None:
|
||||
return False
|
||||
|
||||
after_build_result, outbound_message = await _invoke_send_hook(
|
||||
"send_service.after_build_message",
|
||||
outbound_message,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
if after_build_result.aborted:
|
||||
logger.info(f"[SendService] 消息 {outbound_message.message_id} 在构建后被 Hook 中止")
|
||||
return False
|
||||
|
||||
after_build_kwargs = after_build_result.kwargs
|
||||
typing = _coerce_bool(after_build_kwargs.get("typing"), typing)
|
||||
set_reply = _coerce_bool(after_build_kwargs.get("set_reply"), set_reply)
|
||||
storage_message = _coerce_bool(after_build_kwargs.get("storage_message"), storage_message)
|
||||
show_log = _coerce_bool(after_build_kwargs.get("show_log"), show_log)
|
||||
|
||||
sent = await send_session_message(
|
||||
outbound_message,
|
||||
typing=typing,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import inspect
|
||||
from typing import Any, Dict, List, get_args, get_origin
|
||||
|
||||
import inspect
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
@@ -56,7 +57,7 @@ class ConfigSchemaGenerator:
|
||||
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
|
||||
return cls.generate_config_schema(annotation)
|
||||
|
||||
if origin in {list, tuple} and args:
|
||||
if origin in {list, set, tuple} and args:
|
||||
first = args[0]
|
||||
if inspect.isclass(first) and issubclass(first, ConfigBase):
|
||||
return cls.generate_config_schema(first)
|
||||
@@ -83,7 +84,7 @@ class ConfigSchemaGenerator:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin is list and args:
|
||||
if origin in {list, set} and args:
|
||||
schema["items"] = {"type": cls._map_field_type(args[0])}
|
||||
|
||||
if options := cls._extract_options(annotation):
|
||||
@@ -120,7 +121,7 @@ class ConfigSchemaGenerator:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin in {list, tuple}:
|
||||
if origin in {list, set, tuple}:
|
||||
return "array"
|
||||
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
|
||||
return "object"
|
||||
@@ -133,7 +134,7 @@ class ConfigSchemaGenerator:
|
||||
if annotation is str:
|
||||
return "string"
|
||||
|
||||
if origin in {list, tuple} and args:
|
||||
if origin in {list, set, tuple} and args:
|
||||
return "array"
|
||||
|
||||
if origin in {dict}:
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
logger = get_logger("webui.logs_ws")
|
||||
router = APIRouter()
|
||||
@@ -148,24 +149,9 @@ async def broadcast_log(log_data: Dict):
|
||||
Args:
|
||||
log_data: 日志数据字典
|
||||
"""
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
# 格式化为 JSON
|
||||
message = json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
# 记录需要断开的连接
|
||||
disconnected = set()
|
||||
|
||||
# 广播到所有客户端
|
||||
for connection in active_connections:
|
||||
try:
|
||||
await connection.send_text(message)
|
||||
except Exception:
|
||||
# 发送失败,标记为断开
|
||||
disconnected.add(connection)
|
||||
|
||||
# 清理断开的连接
|
||||
if disconnected:
|
||||
active_connections.difference_update(disconnected)
|
||||
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
|
||||
await websocket_manager.broadcast_to_topic(
|
||||
domain="logs",
|
||||
topic="main",
|
||||
event="entry",
|
||||
data={"entry": log_data},
|
||||
)
|
||||
|
||||
@@ -18,13 +18,13 @@ def get_all_routers() -> List[APIRouter]:
|
||||
from src.webui.api.replier import router as replier_router
|
||||
from src.webui.routers.chat import router as chat_router
|
||||
from src.webui.routers.memory import compat_router as memory_compat_router
|
||||
from src.webui.routers.websocket.logs import router as logs_router
|
||||
from src.webui.routers.knowledge import router as knowledge_router
|
||||
from src.webui.routes import router as main_router
|
||||
|
||||
return [
|
||||
main_router,
|
||||
memory_compat_router,
|
||||
logs_router,
|
||||
knowledge_router,
|
||||
chat_router,
|
||||
planner_router,
|
||||
replier_router,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Tuple
|
||||
|
||||
from .routes import router
|
||||
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
from .service import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
|
||||
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -13,16 +12,11 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
from .support import (
|
||||
from .service import (
|
||||
WEBUI_CHAT_GROUP_ID,
|
||||
WEBUI_CHAT_PLATFORM,
|
||||
authenticate_chat_websocket,
|
||||
chat_history,
|
||||
chat_manager,
|
||||
dispatch_chat_event,
|
||||
normalize_webui_user_id,
|
||||
resolve_initial_virtual_identity,
|
||||
send_initial_chat_state,
|
||||
)
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
@@ -113,55 +107,6 @@ async def clear_chat_history(
|
||||
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
token: Optional[str] = Query(default=None),
|
||||
) -> None:
|
||||
"""WebSocket 聊天端点。"""
|
||||
if not await authenticate_chat_websocket(websocket, token):
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
normalized_user_id = normalize_webui_user_id(user_id)
|
||||
current_user_name = user_name or "WebUI用户"
|
||||
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
|
||||
|
||||
await chat_manager.connect(websocket, session_id, normalized_user_id)
|
||||
try:
|
||||
await send_initial_chat_state(
|
||||
session_id=session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
current_user_name, current_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data=data,
|
||||
current_user_name=current_user_name,
|
||||
normalized_user_id=normalized_user_id,
|
||||
current_virtual_config=current_virtual_config,
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, normalized_user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info() -> Dict[str, object]:
|
||||
"""获取聊天室信息。"""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""WebUI 聊天路由支持逻辑。"""
|
||||
"""WebUI 聊天运行时服务。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, delete, select
|
||||
|
||||
@@ -17,8 +17,6 @@ from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
@@ -27,6 +25,8 @@ WEBUI_CHAT_PLATFORM = "webui"
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
AsyncMessageSender = Callable[[Dict[str, Any]], Awaitable[None]]
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置。"""
|
||||
@@ -52,13 +52,42 @@ class ChatHistoryMessage(BaseModel):
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSessionConnection:
|
||||
"""逻辑聊天会话连接信息。"""
|
||||
|
||||
session_id: str
|
||||
connection_id: str
|
||||
client_session_id: str
|
||||
user_id: str
|
||||
user_name: str
|
||||
active_group_id: str
|
||||
virtual_config: Optional[VirtualIdentityConfig]
|
||||
sender: AsyncMessageSender
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器。"""
|
||||
|
||||
def __init__(self, max_messages: int = 200) -> None:
|
||||
"""初始化聊天历史管理器。
|
||||
|
||||
Args:
|
||||
max_messages: 内存中允许处理的最大消息数。
|
||||
"""
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将内部消息对象转换为前端可消费的字典。
|
||||
|
||||
Args:
|
||||
msg: 内部统一消息对象。
|
||||
group_id: 当前会话所属的群组标识。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 面向 WebUI 的消息字典。
|
||||
"""
|
||||
del group_id
|
||||
user_info = msg.message_info.user_info
|
||||
user_id = user_info.user_id or ""
|
||||
is_bot = is_bot_self(msg.platform, user_id)
|
||||
@@ -74,10 +103,27 @@ class ChatHistoryManager:
|
||||
}
|
||||
|
||||
def _resolve_session_id(self, group_id: Optional[str]) -> str:
|
||||
"""根据群组标识解析聊天会话 ID。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
str: 内部聊天会话 ID。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""获取指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
limit: 最大返回条数。
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 历史消息列表。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -90,11 +136,19 @@ class ChatHistoryManager:
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"从数据库加载聊天记录失败: {exc}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
"""清空指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
|
||||
Returns:
|
||||
int: 被删除的消息数量。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
try:
|
||||
@@ -104,66 +158,245 @@ class ChatHistoryManager:
|
||||
deleted = result.rowcount or 0
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"清空聊天记录失败: {exc}")
|
||||
return 0
|
||||
|
||||
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器。"""
|
||||
"""统一聊天逻辑会话管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {}
|
||||
"""初始化聊天逻辑会话管理器。"""
|
||||
self.active_connections: Dict[str, ChatSessionConnection] = {}
|
||||
self.client_sessions: Dict[Tuple[str, str], str] = {}
|
||||
self.connection_sessions: Dict[str, Set[str]] = {}
|
||||
self.group_sessions: Dict[str, Set[str]] = {}
|
||||
self.user_sessions: Dict[str, Set[str]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
def _bind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""为会话绑定群组索引。
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str) -> None:
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.setdefault(group_id, set())
|
||||
group_session_ids.add(session_id)
|
||||
|
||||
def _unbind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""移除会话与群组的索引关系。
|
||||
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.get(group_id)
|
||||
if group_session_ids is None:
|
||||
return
|
||||
|
||||
group_session_ids.discard(session_id)
|
||||
if not group_session_ids:
|
||||
del self.group_sessions[group_id]
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
session_id: str,
|
||||
connection_id: str,
|
||||
client_session_id: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
sender: AsyncMessageSender,
|
||||
) -> None:
|
||||
"""注册一个新的逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
client_session_id: 前端标签页使用的会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前展示昵称。
|
||||
virtual_config: 当前虚拟身份配置。
|
||||
sender: 发送消息到前端的异步回调。
|
||||
"""
|
||||
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
|
||||
if existing_session_id is not None:
|
||||
self.disconnect(existing_session_id)
|
||||
|
||||
active_group_id = get_current_group_id(virtual_config)
|
||||
session_connection = ChatSessionConnection(
|
||||
session_id=session_id,
|
||||
connection_id=connection_id,
|
||||
client_session_id=client_session_id,
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
active_group_id=active_group_id,
|
||||
virtual_config=virtual_config,
|
||||
sender=sender,
|
||||
)
|
||||
|
||||
self.active_connections[session_id] = session_connection
|
||||
self.client_sessions[(connection_id, client_session_id)] = session_id
|
||||
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
|
||||
self.user_sessions.setdefault(user_id, set()).add(session_id)
|
||||
self._bind_group(session_id, active_group_id)
|
||||
logger.info(
|
||||
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
|
||||
session_id,
|
||||
connection_id,
|
||||
client_session_id,
|
||||
user_id,
|
||||
active_group_id,
|
||||
)
|
||||
|
||||
def disconnect(self, session_id: str) -> None:
|
||||
"""断开一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
session_connection = self.active_connections.pop(session_id, None)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
|
||||
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
|
||||
if connection_session_ids is not None:
|
||||
connection_session_ids.discard(session_id)
|
||||
if not connection_session_ids:
|
||||
del self.connection_sessions[session_connection.connection_id]
|
||||
|
||||
user_session_ids = self.user_sessions.get(session_connection.user_id)
|
||||
if user_session_ids is not None:
|
||||
user_session_ids.discard(session_id)
|
||||
if not user_session_ids:
|
||||
del self.user_sessions[session_connection.user_id]
|
||||
|
||||
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
|
||||
|
||||
def disconnect_connection(self, connection_id: str) -> None:
|
||||
"""断开物理连接下的全部逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
"""
|
||||
session_ids = list(self.connection_sessions.get(connection_id, set()))
|
||||
for session_id in session_ids:
|
||||
self.disconnect(session_id)
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[ChatSessionConnection]:
|
||||
"""获取逻辑聊天会话信息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[ChatSessionConnection]: 会话存在时返回对应信息。
|
||||
"""
|
||||
return self.active_connections.get(session_id)
|
||||
|
||||
def get_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
|
||||
"""根据连接 ID 和前端会话 ID 查询内部会话 ID。
|
||||
|
||||
Args:
|
||||
connection_id: 物理 WebSocket 连接 ID。
|
||||
client_session_id: 前端标签页使用的会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 找到时返回内部会话 ID。
|
||||
"""
|
||||
return self.client_sessions.get((connection_id, client_session_id))
|
||||
|
||||
def update_session_context(
|
||||
self,
|
||||
session_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> None:
|
||||
"""更新会话上下文信息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_name: 最新昵称。
|
||||
virtual_config: 最新虚拟身份配置。
|
||||
"""
|
||||
session_connection = self.active_connections.get(session_id)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
next_group_id = get_current_group_id(virtual_config)
|
||||
if next_group_id != session_connection.active_group_id:
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
self._bind_group(session_id, next_group_id)
|
||||
session_connection.active_group_id = next_group_id
|
||||
|
||||
session_connection.user_name = user_name
|
||||
session_connection.virtual_config = virtual_config
|
||||
|
||||
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
"""向指定逻辑会话发送消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
message: 待发送的消息内容。
|
||||
"""
|
||||
session_connection = self.active_connections.get(session_id)
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await session_connection.sender(message)
|
||||
except Exception as exc:
|
||||
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||
"""向全部逻辑聊天会话广播消息。
|
||||
|
||||
Args:
|
||||
message: 待广播的消息内容。
|
||||
"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
|
||||
"""向指定群组下的全部逻辑会话广播消息。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
message: 待广播的消息内容。
|
||||
"""
|
||||
for session_id in list(self.group_sessions.get(group_id, set())):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_history = ChatHistoryManager()
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
|
||||
"""判断当前是否启用了虚拟身份模式。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
bool: 已启用时返回 ``True``。
|
||||
"""
|
||||
return bool(virtual_config and virtual_config.enabled)
|
||||
|
||||
|
||||
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
|
||||
if token and verify_ws_token(token):
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
return True
|
||||
|
||||
if cookie_token := websocket.cookies.get("maibot_session"):
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
"""标准化 WebUI 用户 ID。
|
||||
|
||||
Args:
|
||||
user_id: 原始用户 ID。
|
||||
|
||||
Returns:
|
||||
str: 带统一前缀的用户 ID。
|
||||
"""
|
||||
if not user_id:
|
||||
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
if user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
@@ -172,12 +405,30 @@ def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
|
||||
|
||||
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
|
||||
"""根据人物 ID 查询人物信息。
|
||||
|
||||
Args:
|
||||
person_id: 人物 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PersonInfo]: 查到时返回人物信息。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
|
||||
"""根据人物信息构建虚拟身份配置。
|
||||
|
||||
Args:
|
||||
person: 人物信息对象。
|
||||
group_id: 逻辑群组 ID。
|
||||
group_name: 逻辑群组名称。
|
||||
|
||||
Returns:
|
||||
VirtualIdentityConfig: 虚拟身份配置对象。
|
||||
"""
|
||||
return VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
@@ -195,6 +446,17 @@ def resolve_initial_virtual_identity(
|
||||
group_name: Optional[str],
|
||||
group_id: Optional[str],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
"""根据初始参数解析虚拟身份配置。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
person_id: 人物 ID。
|
||||
group_name: 群组名称。
|
||||
group_id: 群组 ID。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 解析成功时返回虚拟身份配置。
|
||||
"""
|
||||
if not (platform and person_id):
|
||||
return None
|
||||
|
||||
@@ -210,11 +472,14 @@ def resolve_initial_virtual_identity(
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
|
||||
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
|
||||
virtual_config.user_nickname,
|
||||
virtual_config.platform,
|
||||
virtual_group_id,
|
||||
)
|
||||
return virtual_config
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"通过参数配置虚拟身份失败: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -224,6 +489,17 @@ def build_session_info_message(
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Dict[str, Any]:
|
||||
"""构建会话信息消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前昵称。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 会话信息消息。
|
||||
"""
|
||||
session_info_data: Dict[str, Any] = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
@@ -247,13 +523,41 @@ def build_session_info_message(
|
||||
|
||||
|
||||
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
|
||||
"""获取当前虚拟身份对应的历史群组 ID。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 虚拟身份启用时返回对应群组 ID。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.group_id
|
||||
return None
|
||||
|
||||
|
||||
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
"""获取当前会话的有效群组 ID。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 当前会话应使用的群组 ID。
|
||||
"""
|
||||
return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
|
||||
|
||||
|
||||
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
"""构建欢迎消息。
|
||||
|
||||
Args:
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 欢迎消息文本。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return (
|
||||
@@ -264,6 +568,12 @@ def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> st
|
||||
|
||||
|
||||
async def send_chat_error(session_id: str, content: str) -> None:
|
||||
"""向指定会话发送错误消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
content: 错误消息内容。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
@@ -279,7 +589,17 @@ async def send_initial_chat_state(
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
include_welcome: bool = True,
|
||||
) -> None:
|
||||
"""向新会话发送初始化状态。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
user_id: 规范化后的用户 ID。
|
||||
user_name: 当前昵称。
|
||||
virtual_config: 虚拟身份配置。
|
||||
include_welcome: 是否发送欢迎消息。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
build_session_info_message(
|
||||
@@ -290,30 +610,43 @@ async def send_initial_chat_state(
|
||||
),
|
||||
)
|
||||
|
||||
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
history_group_id = get_active_history_group_id(virtual_config)
|
||||
history = chat_history.get_history(50, history_group_id)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": build_welcome_message(virtual_config),
|
||||
"timestamp": time.time(),
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
"group_id": get_current_group_id(virtual_config),
|
||||
},
|
||||
)
|
||||
|
||||
if include_welcome:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": build_welcome_message(virtual_config),
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resolve_sender_identity(
|
||||
current_user_name: str,
|
||||
normalized_user_id: str,
|
||||
virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Tuple[str, str]:
|
||||
"""解析当前发送者身份。
|
||||
|
||||
Args:
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: ``(发送者昵称, 发送者用户 ID)``。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
|
||||
@@ -328,6 +661,19 @@ def create_message_data(
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建发送给聊天核心的消息数据。
|
||||
|
||||
Args:
|
||||
content: 文本内容。
|
||||
user_id: 用户 ID。
|
||||
user_name: 用户昵称。
|
||||
message_id: 消息 ID。
|
||||
is_at_bot: 是否默认艾特机器人。
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 聊天核心可处理的消息数据。
|
||||
"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
@@ -389,6 +735,18 @@ async def handle_chat_message(
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> str:
|
||||
"""处理用户发送的聊天消息。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
data: 前端提交的消息数据。
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
str: 处理后的最新昵称。
|
||||
"""
|
||||
content = str(data.get("content", "")).strip()
|
||||
if not content:
|
||||
return current_user_name
|
||||
@@ -401,11 +759,14 @@ async def handle_chat_message(
|
||||
normalized_user_id=normalized_user_id,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
target_group_id = get_current_group_id(current_virtual_config)
|
||||
|
||||
await chat_manager.broadcast(
|
||||
await chat_manager.broadcast_to_group(
|
||||
target_group_id,
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"group_id": target_group_id,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
@@ -414,7 +775,7 @@ async def handle_chat_message(
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
message_data = create_message_data(
|
||||
@@ -427,22 +788,37 @@ async def handle_chat_message(
|
||||
)
|
||||
|
||||
try:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": True})
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
|
||||
await chat_bot.message_process(message_data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
|
||||
except Exception as exc:
|
||||
logger.error(f"处理消息时出错: {exc}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
|
||||
finally:
|
||||
await chat_manager.broadcast({"type": "typing", "is_typing": False})
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
|
||||
|
||||
return next_user_name
|
||||
|
||||
|
||||
async def handle_chat_ping(session_id: str) -> None:
|
||||
"""处理聊天心跳。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
|
||||
|
||||
|
||||
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
|
||||
"""处理昵称更新请求。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
data: 前端提交的数据。
|
||||
current_user_name: 当前昵称。
|
||||
|
||||
Returns:
|
||||
str: 更新后的昵称。
|
||||
"""
|
||||
new_name = str(data.get("user_name", "")).strip()
|
||||
if not new_name:
|
||||
return current_user_name
|
||||
@@ -463,6 +839,16 @@ async def enable_virtual_identity(
|
||||
session_prefix: str,
|
||||
virtual_data: Dict[str, Any],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
"""启用虚拟身份模式。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_prefix: 会话前缀,用于生成默认群组 ID。
|
||||
virtual_data: 前端提交的虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 启用成功时返回新的虚拟身份配置。
|
||||
"""
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
|
||||
return None
|
||||
@@ -470,16 +856,18 @@ async def enable_virtual_identity(
|
||||
person_id_value = str(virtual_data.get("person_id"))
|
||||
try:
|
||||
person = get_person_by_person_id(person_id_value)
|
||||
if not person:
|
||||
if person is None:
|
||||
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
|
||||
return None
|
||||
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
current_group_id = (
|
||||
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
if custom_group_id
|
||||
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
|
||||
)
|
||||
custom_group_id = str(virtual_data.get("group_id") or "").strip()
|
||||
if custom_group_id:
|
||||
current_group_id = custom_group_id
|
||||
if not current_group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{current_group_id}"
|
||||
else:
|
||||
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
|
||||
|
||||
current_virtual_config = build_virtual_identity_config(
|
||||
person=person,
|
||||
group_id=current_group_id,
|
||||
@@ -521,13 +909,18 @@ async def enable_virtual_identity(
|
||||
},
|
||||
)
|
||||
return current_virtual_config
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
|
||||
except Exception as exc:
|
||||
logger.error(f"设置虚拟身份失败: {exc}")
|
||||
await send_chat_error(session_id, f"设置虚拟身份失败: {str(exc)}")
|
||||
return None
|
||||
|
||||
|
||||
async def disable_virtual_identity(session_id: str) -> None:
|
||||
"""关闭虚拟身份模式。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
@@ -560,7 +953,18 @@ async def handle_virtual_identity_update(
|
||||
data: Dict[str, Any],
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
virtual_data = cast(dict[str, Any], data.get("config", {}))
|
||||
"""处理虚拟身份切换请求。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_id_prefix: 会话前缀。
|
||||
data: 前端提交的数据。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置。
|
||||
"""
|
||||
virtual_data = cast(Dict[str, Any], data.get("config", {}))
|
||||
if virtual_data.get("enabled"):
|
||||
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
|
||||
return next_config if next_config is not None else current_virtual_config
|
||||
@@ -577,6 +981,19 @@ async def dispatch_chat_event(
|
||||
normalized_user_id: str,
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
|
||||
"""分发聊天事件到对应的处理器。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
session_id_prefix: 会话前缀。
|
||||
data: 前端提交的数据。
|
||||
current_user_name: 当前昵称。
|
||||
normalized_user_id: 规范化后的用户 ID。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Tuple[str, Optional[VirtualIdentityConfig]]: ``(最新昵称, 最新虚拟身份配置)``。
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
if event_type == "message":
|
||||
next_user_name = await handle_chat_message(
|
||||
@@ -24,10 +24,8 @@ from src.config.official_configs import (
|
||||
ChineseTypoConfig,
|
||||
DebugConfig,
|
||||
EmojiConfig,
|
||||
ExperimentalConfig,
|
||||
ExpressionConfig,
|
||||
KeywordReactionConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MaimMessageConfig,
|
||||
MemoryConfig,
|
||||
MessageReceiveConfig,
|
||||
@@ -109,9 +107,7 @@ async def get_config_section_schema(section_name: str):
|
||||
- response_post_process: ResponsePostProcessConfig
|
||||
- response_splitter: ResponseSplitterConfig
|
||||
- telemetry: TelemetryConfig
|
||||
- experimental: ExperimentalConfig
|
||||
- maim_message: MaimMessageConfig
|
||||
- lpmm_knowledge: LPMMKnowledgeConfig
|
||||
- memory: MemoryConfig
|
||||
- debug: DebugConfig
|
||||
- voice: VoiceConfig
|
||||
@@ -133,9 +129,7 @@ async def get_config_section_schema(section_name: str):
|
||||
"response_post_process": ResponsePostProcessConfig,
|
||||
"response_splitter": ResponseSplitterConfig,
|
||||
"telemetry": TelemetryConfig,
|
||||
"experimental": ExperimentalConfig,
|
||||
"maim_message": MaimMessageConfig,
|
||||
"lpmm_knowledge": LPMMKnowledgeConfig,
|
||||
"memory": MemoryConfig,
|
||||
"debug": DebugConfig,
|
||||
"voice": VoiceConfig,
|
||||
|
||||
@@ -6,11 +6,13 @@ from .catalog import router as catalog_router
|
||||
from .config_routes import router as config_router
|
||||
from .management import router as management_router
|
||||
from .progress import get_progress_router, update_progress
|
||||
from .runtime_routes import router as runtime_router
|
||||
|
||||
router = APIRouter(prefix="/plugins", tags=["插件管理"])
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(management_router)
|
||||
router.include_router(config_router)
|
||||
router.include_router(runtime_router)
|
||||
|
||||
set_update_progress_callback(update_progress)
|
||||
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import json
|
||||
"""插件配置相关 WebUI 路由。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Cookie, HTTPException
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.protocol.envelope import InspectPluginConfigResultPayload
|
||||
from src.webui.utils.toml_utils import save_toml_with_format
|
||||
|
||||
from .schemas import UpdatePluginConfigRequest, UpdatePluginRawConfigRequest
|
||||
from .support import (
|
||||
backup_file,
|
||||
coerce_types,
|
||||
find_plugin_instance,
|
||||
find_plugin_path_by_id,
|
||||
get_plugin_config_path,
|
||||
normalize_dotted_keys,
|
||||
@@ -39,6 +40,16 @@ def _to_builtin_data(obj: Any) -> Any:
|
||||
|
||||
|
||||
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> Dict[str, Any]:
|
||||
"""根据当前配置内容自动推断一个兜底 Schema。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
current_config: 当前配置对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 可供前端渲染的兜底 Schema。
|
||||
"""
|
||||
|
||||
schema: Dict[str, Any] = {
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_info": {
|
||||
@@ -134,33 +145,187 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> Di
|
||||
return schema
|
||||
|
||||
|
||||
def _coerce_scalar_value(field_schema: Dict[str, Any], value: Any) -> Any:
|
||||
"""根据字段 Schema 规范化单个字段值。
|
||||
|
||||
Args:
|
||||
field_schema: 单个字段 Schema。
|
||||
value: 当前字段值。
|
||||
|
||||
Returns:
|
||||
Any: 规范化后的字段值。
|
||||
"""
|
||||
|
||||
field_type = str(field_schema.get("type", "") or "").lower()
|
||||
if field_type == "boolean" and isinstance(value, str):
|
||||
normalized_value = value.strip().lower()
|
||||
if normalized_value in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if normalized_value in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
if field_type == "number" and isinstance(value, str):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return value
|
||||
if field_type == "array" and isinstance(value, str):
|
||||
return [item.strip() for item in value.split(",") if item.strip()]
|
||||
return value
|
||||
|
||||
|
||||
def _coerce_config_by_plugin_schema(schema: Dict[str, Any], config_data: Dict[str, Any]) -> None:
|
||||
"""根据插件配置 Schema 就地规范化配置值类型。
|
||||
|
||||
Args:
|
||||
schema: 插件配置 Schema。
|
||||
config_data: 待规范化的配置字典。
|
||||
"""
|
||||
|
||||
sections = schema.get("sections")
|
||||
if not isinstance(sections, dict):
|
||||
return
|
||||
|
||||
for section_name, section_schema in sections.items():
|
||||
if not isinstance(section_schema, dict):
|
||||
continue
|
||||
if section_name not in config_data or not isinstance(config_data[section_name], dict):
|
||||
continue
|
||||
|
||||
section_fields = section_schema.get("fields")
|
||||
if not isinstance(section_fields, dict):
|
||||
continue
|
||||
|
||||
section_config = cast(Dict[str, Any], config_data[section_name])
|
||||
for field_name, field_schema in section_fields.items():
|
||||
if field_name not in section_config or not isinstance(field_schema, dict):
|
||||
continue
|
||||
section_config[field_name] = _coerce_scalar_value(field_schema, section_config[field_name])
|
||||
|
||||
|
||||
def _build_toml_document(config_data: Dict[str, Any]) -> tomlkit.TOMLDocument:
|
||||
"""将普通字典转换为 TOML 文档对象。
|
||||
|
||||
Args:
|
||||
config_data: 原始配置字典。
|
||||
|
||||
Returns:
|
||||
tomlkit.TOMLDocument: 解析后的 TOML 文档。
|
||||
"""
|
||||
|
||||
if not config_data:
|
||||
return tomlkit.document()
|
||||
return tomlkit.parse(tomlkit.dumps(config_data))
|
||||
|
||||
|
||||
def _load_plugin_config_from_disk(plugin_path: Path) -> Dict[str, Any]:
|
||||
"""从磁盘读取插件配置。
|
||||
|
||||
Args:
|
||||
plugin_path: 插件目录。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 当前配置字典;文件不存在时返回空字典。
|
||||
"""
|
||||
|
||||
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
loaded_config = tomlkit.load(file_obj).unwrap()
|
||||
return loaded_config if isinstance(loaded_config, dict) else {}
|
||||
|
||||
|
||||
async def _inspect_plugin_config_via_runtime(
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
use_provided_config: bool = False,
|
||||
) -> InspectPluginConfigResultPayload | None:
|
||||
"""通过插件运行时解析配置元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
config_data: 可选的配置内容。
|
||||
use_provided_config: 是否优先使用传入配置而不是磁盘配置。
|
||||
|
||||
Returns:
|
||||
InspectPluginConfigResultPayload | None: 运行时可用时返回解析结果,否则返回 ``None``。
|
||||
|
||||
Raises:
|
||||
ValueError: 插件运行时明确拒绝解析请求时抛出。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
runtime_manager = get_plugin_runtime_manager()
|
||||
return await runtime_manager.inspect_plugin_config(
|
||||
plugin_id,
|
||||
config_data,
|
||||
use_provided_config=use_provided_config,
|
||||
)
|
||||
|
||||
|
||||
async def _validate_plugin_config_via_runtime(plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
"""通过插件运行时对配置进行校验。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
config_data: 待校验的配置内容。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 校验成功时返回规范化后的配置;若运行时不可用则返回
|
||||
``None``,由调用方自行回退到静态 Schema 方案。
|
||||
|
||||
Raises:
|
||||
ValueError: 插件运行时明确判定配置非法时抛出。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
runtime_manager = get_plugin_runtime_manager()
|
||||
return await runtime_manager.validate_plugin_config(plugin_id, config_data)
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}/schema")
|
||||
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
"""按插件 ID 返回配置 Schema。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
maibot_session: 当前会话令牌。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含 Schema 的响应字典。
|
||||
"""
|
||||
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_instance = find_plugin_instance(plugin_id)
|
||||
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
|
||||
return {"success": True, "schema": plugin_instance.get_webui_config_schema()}
|
||||
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
schema_json_path = resolve_plugin_file_path(plugin_path, "config_schema.json")
|
||||
if schema_json_path.exists():
|
||||
try:
|
||||
with open(schema_json_path, "r", encoding="utf-8") as file_obj:
|
||||
return {"success": True, "schema": json.load(file_obj)}
|
||||
except Exception as e:
|
||||
logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}")
|
||||
try:
|
||||
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置 Schema 解析失败,将回退到弱推断: {exc}")
|
||||
runtime_snapshot = None
|
||||
|
||||
current_config: Any = {}
|
||||
config_path = get_plugin_config_path(plugin_id, plugin_path)
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
current_config = tomlkit.load(file_obj)
|
||||
if runtime_snapshot is not None and runtime_snapshot.config_schema:
|
||||
return {"success": True, "schema": dict(runtime_snapshot.config_schema)}
|
||||
|
||||
current_config: Any = (
|
||||
dict(runtime_snapshot.normalized_config)
|
||||
if runtime_snapshot is not None
|
||||
else _load_plugin_config_from_disk(plugin_path)
|
||||
)
|
||||
|
||||
return {"success": True, "schema": _build_schema_from_current_config(plugin_id, current_config)}
|
||||
except HTTPException:
|
||||
@@ -172,6 +337,16 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
|
||||
|
||||
@router.get("/config/{plugin_id}/raw")
|
||||
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
"""获取插件原始 TOML 配置内容。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
maibot_session: 当前会话令牌。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含原始配置文本的响应字典。
|
||||
"""
|
||||
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取插件原始配置: {plugin_id}")
|
||||
|
||||
@@ -199,6 +374,17 @@ async def update_plugin_config_raw(
|
||||
request: UpdatePluginRawConfigRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""更新插件原始 TOML 配置内容。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
request: 原始配置更新请求。
|
||||
maibot_session: 当前会话令牌。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新结果。
|
||||
"""
|
||||
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"更新插件原始配置: {plugin_id}")
|
||||
|
||||
@@ -232,6 +418,16 @@ async def update_plugin_config_raw(
|
||||
|
||||
@router.get("/config/{plugin_id}")
|
||||
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
"""获取插件配置字典。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
maibot_session: 当前会话令牌。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 当前配置响应。
|
||||
"""
|
||||
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"获取插件配置: {plugin_id}")
|
||||
|
||||
@@ -241,12 +437,24 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = get_plugin_config_path(plugin_id, plugin_path)
|
||||
try:
|
||||
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置读取失败,将回退到磁盘内容: {exc}")
|
||||
runtime_snapshot = None
|
||||
|
||||
if runtime_snapshot is not None:
|
||||
message = "配置文件不存在,已返回默认配置" if not config_path.exists() else ""
|
||||
return {
|
||||
"success": True,
|
||||
"config": dict(runtime_snapshot.normalized_config),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if not config_path.exists():
|
||||
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
config = tomlkit.load(file_obj)
|
||||
return {"success": True, "config": _to_builtin_data(config)}
|
||||
return {"success": True, "config": _load_plugin_config_from_disk(plugin_path)}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -260,21 +468,40 @@ async def update_plugin_config(
|
||||
request: UpdatePluginConfigRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""更新插件结构化配置。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
request: 结构化配置更新请求。
|
||||
maibot_session: 当前会话令牌。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新结果。
|
||||
"""
|
||||
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"更新插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_instance = find_plugin_instance(plugin_id)
|
||||
config_data = request.config or {}
|
||||
if plugin_instance and isinstance(config_data, dict):
|
||||
config_data = normalize_dotted_keys(config_data)
|
||||
if isinstance(plugin_instance.config_schema, dict):
|
||||
coerce_types(plugin_instance.config_schema, config_data)
|
||||
|
||||
plugin_path = find_plugin_path_by_id(plugin_id)
|
||||
if plugin_path is None:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_data = request.config or {}
|
||||
if isinstance(config_data, dict):
|
||||
config_data = normalize_dotted_keys(config_data)
|
||||
runtime_validated_config = await _validate_plugin_config_via_runtime(plugin_id, config_data)
|
||||
if isinstance(runtime_validated_config, dict):
|
||||
config_data = runtime_validated_config
|
||||
else:
|
||||
runtime_snapshot = await _inspect_plugin_config_via_runtime(
|
||||
plugin_id,
|
||||
config_data,
|
||||
use_provided_config=True,
|
||||
)
|
||||
if runtime_snapshot is not None and runtime_snapshot.config_schema:
|
||||
_coerce_config_by_plugin_schema(dict(runtime_snapshot.config_schema), config_data)
|
||||
|
||||
config_path = get_plugin_config_path(plugin_id, plugin_path)
|
||||
backup_path = backup_file(config_path, "backup")
|
||||
if backup_path is not None:
|
||||
@@ -284,6 +511,8 @@ async def update_plugin_config(
|
||||
save_toml_with_format(config_data, str(config_path))
|
||||
logger.info(f"已更新插件配置: {plugin_id}")
|
||||
return {"success": True, "message": "配置已保存", "note": "配置更改将自动热更新到对应插件"}
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -293,6 +522,16 @@ async def update_plugin_config(
|
||||
|
||||
@router.post("/config/{plugin_id}/reset")
|
||||
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
"""重置插件配置文件。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
maibot_session: 当前会话令牌。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 重置结果。
|
||||
"""
|
||||
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"重置插件配置: {plugin_id}")
|
||||
|
||||
@@ -317,6 +556,16 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
|
||||
|
||||
@router.post("/config/{plugin_id}/toggle")
|
||||
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
"""切换插件启用状态。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
maibot_session: 当前会话令牌。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 切换结果。
|
||||
"""
|
||||
|
||||
require_plugin_token(maibot_session)
|
||||
logger.info(f"切换插件状态: {plugin_id}")
|
||||
|
||||
@@ -326,16 +575,29 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = get_plugin_config_path(plugin_id, plugin_path)
|
||||
config = tomlkit.document()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
config = tomlkit.load(file_obj)
|
||||
try:
|
||||
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(f"插件 {plugin_id} 状态切换前配置解析失败,将回退到磁盘内容: {exc}")
|
||||
runtime_snapshot = None
|
||||
|
||||
if "plugin" not in config:
|
||||
current_config = (
|
||||
dict(runtime_snapshot.normalized_config)
|
||||
if runtime_snapshot is not None
|
||||
else _load_plugin_config_from_disk(plugin_path)
|
||||
)
|
||||
config = _build_toml_document(current_config)
|
||||
|
||||
plugin_section = config.get("plugin")
|
||||
if plugin_section is None or not hasattr(plugin_section, "get"):
|
||||
config["plugin"] = tomlkit.table()
|
||||
|
||||
plugin_config = cast(Any, config["plugin"])
|
||||
current_enabled = bool(plugin_config.get("enabled", True))
|
||||
current_enabled = (
|
||||
bool(runtime_snapshot.enabled)
|
||||
if runtime_snapshot is not None
|
||||
else bool(plugin_config.get("enabled", True))
|
||||
)
|
||||
new_enabled = not current_enabled
|
||||
plugin_config["enabled"] = new_enabled
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -347,7 +609,7 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
|
||||
"success": True,
|
||||
"enabled": new_enabled,
|
||||
"message": f"插件已{status}",
|
||||
"note": "状态更改将在下次加载插件时生效",
|
||||
"note": "状态更改将自动热更新到对应插件",
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
"""插件进度实时推送支持。"""
|
||||
|
||||
from typing import Any, Dict, Optional, Set
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
logger = get_logger("webui.plugin_progress")
|
||||
|
||||
@@ -25,25 +28,29 @@ current_progress: Dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
def get_current_progress() -> Dict[str, Any]:
|
||||
"""获取当前插件进度快照。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 当前插件进度数据副本。
|
||||
"""
|
||||
return current_progress.copy()
|
||||
|
||||
|
||||
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
|
||||
"""向统一连接层广播插件进度更新。
|
||||
|
||||
Args:
|
||||
progress_data: 插件进度数据。
|
||||
"""
|
||||
global current_progress
|
||||
current_progress = progress_data.copy()
|
||||
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
message = json.dumps(progress_data, ensure_ascii=False)
|
||||
disconnected: Set[WebSocket] = set()
|
||||
|
||||
for websocket in active_connections:
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送进度更新失败: {e}")
|
||||
disconnected.add(websocket)
|
||||
|
||||
for websocket in disconnected:
|
||||
active_connections.discard(websocket)
|
||||
await websocket_manager.broadcast_to_topic(
|
||||
domain="plugin_progress",
|
||||
topic="main",
|
||||
event="update",
|
||||
data={"progress": progress_data},
|
||||
)
|
||||
|
||||
|
||||
async def update_progress(
|
||||
@@ -56,6 +63,18 @@ async def update_progress(
|
||||
total_plugins: int = 0,
|
||||
loaded_plugins: int = 0,
|
||||
) -> None:
|
||||
"""更新当前插件进度并广播。
|
||||
|
||||
Args:
|
||||
stage: 当前阶段。
|
||||
progress: 当前进度百分比。
|
||||
message: 进度说明消息。
|
||||
operation: 当前操作类型。
|
||||
error: 可选的错误信息。
|
||||
plugin_id: 当前处理的插件 ID。
|
||||
total_plugins: 总插件数量。
|
||||
loaded_plugins: 已处理插件数量。
|
||||
"""
|
||||
progress_data = {
|
||||
"operation": operation,
|
||||
"stage": stage,
|
||||
@@ -74,6 +93,12 @@ async def update_progress(
|
||||
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
|
||||
"""旧版插件进度 WebSocket 入口。
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
token: 可选的一次性握手 Token。
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
if token and verify_ws_token(token):
|
||||
@@ -105,17 +130,22 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
|
||||
data = await websocket.receive_text()
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端消息时出错: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"处理客户端消息时出错: {exc}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebSocket 错误: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"❌ WebSocket 错误: {exc}")
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
def get_progress_router() -> APIRouter:
|
||||
"""获取旧版插件进度路由对象。
|
||||
|
||||
Returns:
|
||||
APIRouter: 插件进度路由对象。
|
||||
"""
|
||||
return router
|
||||
|
||||
28
src/webui/routers/plugin/runtime_routes.py
Normal file
28
src/webui/routers/plugin/runtime_routes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""插件运行时相关 WebUI 路由。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie
|
||||
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
|
||||
from .schemas import HookSpecListResponse, HookSpecResponse
|
||||
from .support import require_plugin_token
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/runtime/hooks", response_model=HookSpecListResponse)
|
||||
async def list_runtime_hook_specs(maibot_session: Optional[str] = Cookie(None)) -> HookSpecListResponse:
|
||||
"""返回当前插件运行时公开的 Hook 规格清单。
|
||||
|
||||
Args:
|
||||
maibot_session: 当前 WebUI 会话令牌。
|
||||
|
||||
Returns:
|
||||
HookSpecListResponse: Hook 规格列表响应。
|
||||
"""
|
||||
|
||||
require_plugin_token(maibot_session)
|
||||
hooks = [HookSpecResponse(**hook_data) for hook_data in component_query_service.list_hook_specs()]
|
||||
return HookSpecListResponse(success=True, hooks=hooks)
|
||||
@@ -111,3 +111,19 @@ class UpdatePluginConfigRequest(BaseModel):
|
||||
|
||||
class UpdatePluginRawConfigRequest(BaseModel):
|
||||
config: str = Field(..., description="原始 TOML 配置内容")
|
||||
|
||||
|
||||
class HookSpecResponse(BaseModel):
|
||||
name: str = Field(..., description="Hook 名称")
|
||||
description: str = Field("", description="Hook 描述")
|
||||
parameters_schema: Dict[str, Any] = Field(default_factory=dict, description="Hook 参数模型")
|
||||
default_timeout_ms: int = Field(..., description="默认超时毫秒数")
|
||||
allow_blocking: bool = Field(..., description="是否允许 blocking 处理器")
|
||||
allow_observe: bool = Field(..., description="是否允许 observe 处理器")
|
||||
allow_abort: bool = Field(..., description="是否允许 abort")
|
||||
allow_kwargs_mutation: bool = Field(..., description="是否允许修改 kwargs")
|
||||
|
||||
|
||||
class HookSpecListResponse(BaseModel):
|
||||
success: bool = Field(..., description="是否成功")
|
||||
hooks: List[HookSpecResponse] = Field(default_factory=list, description="Hook 规格列表")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from .auth import router as ws_auth_router
|
||||
from .logs import router as logs_router
|
||||
"""WebSocket 路由包。"""
|
||||
|
||||
__all__ = [
|
||||
"logs_router",
|
||||
"ws_auth_router",
|
||||
"auth",
|
||||
"manager",
|
||||
"unified",
|
||||
]
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
"""WebSocket 日志推送路由兼容导出。"""
|
||||
|
||||
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
|
||||
|
||||
__all__ = [
|
||||
"active_connections",
|
||||
"broadcast_log",
|
||||
"load_recent_logs",
|
||||
"router",
|
||||
"websocket_logs",
|
||||
]
|
||||
297
src/webui/routers/websocket/manager.py
Normal file
297
src/webui/routers/websocket/manager.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""统一 WebSocket 连接管理器。"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.websocket")
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketConnection:
|
||||
"""统一 WebSocket 连接上下文。"""
|
||||
|
||||
connection_id: str
|
||||
websocket: WebSocket
|
||||
subscriptions: Set[str] = field(default_factory=set)
|
||||
chat_sessions: Dict[str, str] = field(default_factory=dict)
|
||||
send_queue: "asyncio.Queue[Optional[Dict[str, Any]]]" = field(default_factory=asyncio.Queue)
|
||||
sender_task: Optional["asyncio.Task[None]"] = None
|
||||
|
||||
|
||||
class UnifiedWebSocketManager:
|
||||
"""统一 WebSocket 连接管理器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化统一 WebSocket 连接管理器。"""
|
||||
self.connections: Dict[str, WebSocketConnection] = {}
|
||||
|
||||
def _build_subscription_key(self, domain: str, topic: str) -> str:
|
||||
"""构建订阅索引键。
|
||||
|
||||
Args:
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
|
||||
Returns:
|
||||
str: 订阅索引键。
|
||||
"""
|
||||
return f"{domain}:{topic}"
|
||||
|
||||
async def _sender_loop(self, connection: WebSocketConnection) -> None:
|
||||
"""串行发送指定连接的出站消息。
|
||||
|
||||
Args:
|
||||
connection: 目标连接上下文。
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
message = await connection.send_queue.get()
|
||||
if message is None:
|
||||
return
|
||||
await connection.websocket.send_json(message)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("统一 WebSocket 发送失败: connection=%s, error=%s", connection.connection_id, exc)
|
||||
|
||||
async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
|
||||
"""注册一个新的物理 WebSocket 连接。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
|
||||
Returns:
|
||||
WebSocketConnection: 新建的连接上下文。
|
||||
"""
|
||||
await websocket.accept()
|
||||
connection = WebSocketConnection(connection_id=connection_id, websocket=websocket)
|
||||
connection.sender_task = asyncio.create_task(self._sender_loop(connection))
|
||||
self.connections[connection_id] = connection
|
||||
return connection
|
||||
|
||||
async def disconnect(self, connection_id: str) -> None:
|
||||
"""断开并清理指定连接。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
"""
|
||||
connection = self.connections.pop(connection_id, None)
|
||||
if connection is None:
|
||||
return
|
||||
|
||||
await connection.send_queue.put(None)
|
||||
if connection.sender_task is not None:
|
||||
try:
|
||||
await connection.sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.debug("等待发送协程退出时出现异常: connection=%s, error=%s", connection_id, exc)
|
||||
|
||||
def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
|
||||
"""获取指定连接上下文。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
|
||||
Returns:
|
||||
Optional[WebSocketConnection]: 找到时返回连接上下文。
|
||||
"""
|
||||
return self.connections.get(connection_id)
|
||||
|
||||
def register_chat_session(self, connection_id: str, client_session_id: str, session_id: str) -> None:
|
||||
"""登记连接下的逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
session_id: 内部会话 ID。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.chat_sessions[client_session_id] = session_id
|
||||
|
||||
def unregister_chat_session(self, connection_id: str, client_session_id: str) -> None:
|
||||
"""移除连接下的逻辑聊天会话登记。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.chat_sessions.pop(client_session_id, None)
|
||||
|
||||
def get_chat_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
|
||||
"""查询连接下的内部聊天会话 ID。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 找到时返回内部会话 ID。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return None
|
||||
return connection.chat_sessions.get(client_session_id)
|
||||
|
||||
def subscribe(self, connection_id: str, domain: str, topic: str) -> None:
|
||||
"""登记连接的主题订阅。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.subscriptions.add(self._build_subscription_key(domain, topic))
|
||||
|
||||
def unsubscribe(self, connection_id: str, domain: str, topic: str) -> None:
|
||||
"""移除连接的主题订阅。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
connection.subscriptions.discard(self._build_subscription_key(domain, topic))
|
||||
|
||||
def is_subscribed(self, connection_id: str, domain: str, topic: str) -> bool:
|
||||
"""判断连接是否订阅了指定主题。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
|
||||
Returns:
|
||||
bool: 已订阅时返回 ``True``。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return False
|
||||
return self._build_subscription_key(domain, topic) in connection.subscriptions
|
||||
|
||||
async def enqueue(self, connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""向指定连接的发送队列压入消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 待发送的消息。
|
||||
"""
|
||||
connection = self.connections.get(connection_id)
|
||||
if connection is None:
|
||||
return
|
||||
await connection.send_queue.put(message)
|
||||
|
||||
async def send_response(
|
||||
self,
|
||||
connection_id: str,
|
||||
request_id: Optional[str],
|
||||
ok: bool,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
error: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""发送统一响应消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
request_id: 请求 ID。
|
||||
ok: 请求是否成功。
|
||||
data: 成功响应数据。
|
||||
error: 失败响应数据。
|
||||
"""
|
||||
response_message: Dict[str, Any] = {
|
||||
"op": "response",
|
||||
"id": request_id,
|
||||
"ok": ok,
|
||||
}
|
||||
if data is not None:
|
||||
response_message["data"] = data
|
||||
if error is not None:
|
||||
response_message["error"] = error
|
||||
await self.enqueue(connection_id, response_message)
|
||||
|
||||
async def send_event(
|
||||
self,
|
||||
connection_id: str,
|
||||
domain: str,
|
||||
event: str,
|
||||
data: Dict[str, Any],
|
||||
session: Optional[str] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> None:
|
||||
"""发送统一事件消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
domain: 业务域名称。
|
||||
event: 事件名称。
|
||||
data: 事件数据。
|
||||
session: 可选的逻辑会话 ID。
|
||||
topic: 可选的主题名称。
|
||||
"""
|
||||
event_message: Dict[str, Any] = {
|
||||
"op": "event",
|
||||
"domain": domain,
|
||||
"event": event,
|
||||
"data": data,
|
||||
}
|
||||
if session is not None:
|
||||
event_message["session"] = session
|
||||
if topic is not None:
|
||||
event_message["topic"] = topic
|
||||
await self.enqueue(connection_id, event_message)
|
||||
|
||||
async def send_pong(self, connection_id: str, timestamp: float) -> None:
|
||||
"""发送心跳响应。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
timestamp: 当前时间戳。
|
||||
"""
|
||||
await self.enqueue(
|
||||
connection_id,
|
||||
{
|
||||
"op": "pong",
|
||||
"ts": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
async def broadcast_to_topic(self, domain: str, topic: str, event: str, data: Dict[str, Any]) -> None:
|
||||
"""向订阅指定主题的全部连接广播事件。
|
||||
|
||||
Args:
|
||||
domain: 业务域名称。
|
||||
topic: 主题名称。
|
||||
event: 事件名称。
|
||||
data: 事件数据。
|
||||
"""
|
||||
subscription_key = self._build_subscription_key(domain, topic)
|
||||
for connection in list(self.connections.values()):
|
||||
if subscription_key in connection.subscriptions:
|
||||
await self.send_event(
|
||||
connection.connection_id,
|
||||
domain=domain,
|
||||
event=event,
|
||||
data=data,
|
||||
topic=topic,
|
||||
)
|
||||
|
||||
|
||||
websocket_manager = UnifiedWebSocketManager()
|
||||
548
src/webui/routers/websocket/unified.py
Normal file
548
src/webui/routers/websocket/unified.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""统一 WebSocket 路由。"""
|
||||
|
||||
from typing import Any, Dict, Optional, Set, cast
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.logs_ws import load_recent_logs
|
||||
from src.webui.routers.chat.service import (
|
||||
chat_manager,
|
||||
dispatch_chat_event,
|
||||
normalize_webui_user_id,
|
||||
resolve_initial_virtual_identity,
|
||||
send_initial_chat_state,
|
||||
)
|
||||
from src.webui.routers.plugin.progress import get_current_progress
|
||||
from src.webui.routers.websocket.auth import verify_ws_token
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
logger = get_logger("webui.unified_ws")
|
||||
router = APIRouter()
|
||||
_background_tasks: Set["asyncio.Task[None]"] = set()
|
||||
|
||||
|
||||
def _build_error(code: str, message: str) -> Dict[str, Any]:
|
||||
"""构建统一错误响应体。
|
||||
|
||||
Args:
|
||||
code: 错误码。
|
||||
message: 错误描述。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 统一错误对象。
|
||||
"""
|
||||
return {
|
||||
"code": code,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""从客户端消息中提取数据字段。
|
||||
|
||||
Args:
|
||||
message: 客户端消息。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 标准化后的数据字典。
|
||||
"""
|
||||
data = message.get("data", {})
|
||||
if isinstance(data, dict):
|
||||
return cast(Dict[str, Any], data)
|
||||
return {}
|
||||
|
||||
|
||||
def _track_background_task(task: "asyncio.Task[None]") -> None:
|
||||
"""登记后台任务并在完成后自动清理。
|
||||
|
||||
Args:
|
||||
task: 后台协程任务。
|
||||
"""
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
|
||||
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
|
||||
"""校验统一 WebSocket 连接的认证状态。
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
token: 可选的一次性握手 Token。
|
||||
|
||||
Returns:
|
||||
bool: 认证通过时返回 ``True``。
|
||||
"""
|
||||
if token and verify_ws_token(token):
|
||||
logger.debug("统一 WebSocket 使用临时 token 认证成功")
|
||||
return True
|
||||
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
logger.debug("统一 WebSocket 使用 Cookie 认证成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
|
||||
"""处理日志域订阅请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
request_id: 请求 ID。
|
||||
data: 订阅参数。
|
||||
"""
|
||||
replay_limit = int(data.get("replay", 100) or 100)
|
||||
replay_limit = max(0, min(replay_limit, 500))
|
||||
websocket_manager.subscribe(connection_id, domain="logs", topic="main")
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"domain": "logs", "topic": "main"},
|
||||
)
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="logs",
|
||||
event="snapshot",
|
||||
topic="main",
|
||||
data={"entries": load_recent_logs(limit=replay_limit)},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
|
||||
"""处理插件进度域订阅请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
request_id: 请求 ID。
|
||||
"""
|
||||
websocket_manager.subscribe(connection_id, domain="plugin_progress", topic="main")
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"domain": "plugin_progress", "topic": "main"},
|
||||
)
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="plugin_progress",
|
||||
event="snapshot",
|
||||
topic="main",
|
||||
data={"progress": get_current_progress()},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理主题订阅请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
domain = str(message.get("domain") or "").strip()
|
||||
topic = str(message.get("topic") or "").strip()
|
||||
data = _get_request_data(message)
|
||||
|
||||
if domain == "logs" and topic == "main":
|
||||
await _handle_logs_subscribe(connection_id, request_id, data)
|
||||
return
|
||||
|
||||
if domain == "plugin_progress" and topic == "main":
|
||||
await _handle_plugin_progress_subscribe(connection_id, request_id)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理主题退订请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
domain = str(message.get("domain") or "").strip()
|
||||
topic = str(message.get("topic") or "").strip()
|
||||
|
||||
if not domain or not topic:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
|
||||
)
|
||||
return
|
||||
|
||||
websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"domain": domain, "topic": topic},
|
||||
)
|
||||
|
||||
|
||||
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""打开一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
if not client_session_id:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
|
||||
)
|
||||
return
|
||||
|
||||
data = _get_request_data(message)
|
||||
normalized_user_id = normalize_webui_user_id(cast(Optional[str], data.get("user_id")))
|
||||
current_user_name = str(data.get("user_name") or "WebUI用户")
|
||||
current_virtual_config = resolve_initial_virtual_identity(
|
||||
platform=cast(Optional[str], data.get("platform")),
|
||||
person_id=cast(Optional[str], data.get("person_id")),
|
||||
group_name=cast(Optional[str], data.get("group_name")),
|
||||
group_id=cast(Optional[str], data.get("group_id")),
|
||||
)
|
||||
restore = bool(data.get("restore"))
|
||||
session_id = f"{connection_id}:{client_session_id}"
|
||||
|
||||
async def send_chat_event(chat_message: Dict[str, Any]) -> None:
|
||||
"""将聊天消息封装为统一事件并发送。
|
||||
|
||||
Args:
|
||||
chat_message: 聊天消息体。
|
||||
"""
|
||||
event_name = str(chat_message.get("type") or "message")
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="chat",
|
||||
event=event_name,
|
||||
session=client_session_id,
|
||||
data=chat_message,
|
||||
)
|
||||
|
||||
await chat_manager.connect(
|
||||
session_id=session_id,
|
||||
connection_id=connection_id,
|
||||
client_session_id=client_session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
sender=send_chat_event,
|
||||
)
|
||||
websocket_manager.register_chat_session(connection_id, client_session_id, session_id)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"session": client_session_id, "session_id": session_id},
|
||||
)
|
||||
await send_initial_chat_state(
|
||||
session_id=session_id,
|
||||
user_id=normalized_user_id,
|
||||
user_name=current_user_name,
|
||||
virtual_config=current_virtual_config,
|
||||
include_welcome=not restore,
|
||||
)
|
||||
|
||||
|
||||
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""关闭一个逻辑聊天会话。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
chat_manager.disconnect(session_id)
|
||||
websocket_manager.unregister_chat_session(connection_id, client_session_id)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"session": client_session_id},
|
||||
)
|
||||
|
||||
|
||||
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
|
||||
"""在后台处理聊天消息事件。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
client_session_id: 前端会话 ID。
|
||||
data: 客户端提交的消息数据。
|
||||
"""
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
return
|
||||
|
||||
session_state = chat_manager.get_session(session_id)
|
||||
if session_state is None:
|
||||
return
|
||||
|
||||
next_user_name, next_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data=data,
|
||||
current_user_name=session_state.user_name,
|
||||
normalized_user_id=session_state.user_id,
|
||||
current_virtual_config=session_state.virtual_config,
|
||||
)
|
||||
chat_manager.update_session_context(
|
||||
session_id=session_id,
|
||||
user_name=next_user_name,
|
||||
virtual_config=next_virtual_config,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理聊天消息发送请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
data = _get_request_data(message)
|
||||
payload = {
|
||||
"type": "message",
|
||||
"content": data.get("content", ""),
|
||||
"user_name": data.get("user_name", ""),
|
||||
}
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"accepted": True, "session": client_session_id},
|
||||
)
|
||||
_track_background_task(asyncio.create_task(_process_chat_message(connection_id, client_session_id, payload)))
|
||||
|
||||
|
||||
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理聊天昵称更新请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
client_session_id = str(message.get("session") or "").strip()
|
||||
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
|
||||
if session_id is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
data = _get_request_data(message)
|
||||
session_state = chat_manager.get_session(session_id)
|
||||
if session_state is None:
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
|
||||
)
|
||||
return
|
||||
|
||||
next_user_name, next_virtual_config = await dispatch_chat_event(
|
||||
session_id=session_id,
|
||||
session_id_prefix=session_id[:8],
|
||||
data={
|
||||
"type": "update_nickname",
|
||||
"user_name": data.get("user_name", ""),
|
||||
},
|
||||
current_user_name=session_state.user_name,
|
||||
normalized_user_id=session_state.user_id,
|
||||
current_virtual_config=session_state.virtual_config,
|
||||
)
|
||||
chat_manager.update_session_context(
|
||||
session_id=session_id,
|
||||
user_name=next_user_name,
|
||||
virtual_config=next_virtual_config,
|
||||
)
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=True,
|
||||
data={"session": client_session_id, "user_name": next_user_name},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理聊天域调用请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
method = str(message.get("method") or "").strip()
|
||||
|
||||
if method == "session.open":
|
||||
await _open_chat_session(connection_id, message)
|
||||
return
|
||||
|
||||
if method == "session.close":
|
||||
await _close_chat_session(connection_id, message)
|
||||
return
|
||||
|
||||
if method == "message.send":
|
||||
await _handle_chat_message_send(connection_id, message)
|
||||
return
|
||||
|
||||
if method == "session.update_nickname":
|
||||
await _handle_chat_nickname_update(connection_id, message)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理统一调用请求。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
domain = str(message.get("domain") or "").strip()
|
||||
if domain == "chat":
|
||||
await _handle_chat_call(connection_id, message)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
|
||||
)
|
||||
|
||||
|
||||
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
|
||||
"""处理统一 WebSocket 客户端消息。
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID。
|
||||
message: 客户端消息。
|
||||
"""
|
||||
operation = str(message.get("op") or "").strip()
|
||||
request_id = cast(Optional[str], message.get("id"))
|
||||
|
||||
if operation == "ping":
|
||||
await websocket_manager.send_pong(connection_id, time.time())
|
||||
return
|
||||
|
||||
if operation == "subscribe":
|
||||
await _handle_subscribe(connection_id, message)
|
||||
return
|
||||
|
||||
if operation == "unsubscribe":
|
||||
await _handle_unsubscribe(connection_id, message)
|
||||
return
|
||||
|
||||
if operation == "call":
|
||||
await _handle_call(connection_id, message)
|
||||
return
|
||||
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=request_id,
|
||||
ok=False,
|
||||
error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
|
||||
"""统一 WebSocket 入口。
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket 对象。
|
||||
token: 可选的一次性握手 Token。
|
||||
"""
|
||||
if not await authenticate_websocket_connection(websocket, token):
|
||||
logger.warning("统一 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
connection_id = uuid.uuid4().hex
|
||||
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
|
||||
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
|
||||
await websocket_manager.send_event(
|
||||
connection_id,
|
||||
domain="system",
|
||||
event="ready",
|
||||
data={"connection_id": connection_id, "timestamp": time.time()},
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_message = await websocket.receive_json()
|
||||
if not isinstance(raw_message, dict):
|
||||
await websocket_manager.send_response(
|
||||
connection_id,
|
||||
request_id=None,
|
||||
ok=False,
|
||||
error=_build_error("invalid_message", "消息必须是 JSON 对象"),
|
||||
)
|
||||
continue
|
||||
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
|
||||
except WebSocketDisconnect:
|
||||
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
|
||||
except Exception as exc:
|
||||
logger.error(f"统一 WebSocket 处理失败: {exc}")
|
||||
finally:
|
||||
chat_manager.disconnect_connection(connection_id)
|
||||
await websocket_manager.disconnect(connection_id)
|
||||
@@ -19,11 +19,11 @@ from src.webui.routers.jargon import router as jargon_router
|
||||
from src.webui.routers.memory import router as memory_router
|
||||
from src.webui.routers.model import router as model_router
|
||||
from src.webui.routers.person import router as person_router
|
||||
from src.webui.routers.plugin import get_progress_router
|
||||
from src.webui.routers.plugin import router as plugin_router
|
||||
from src.webui.routers.statistics import router as statistics_router
|
||||
from src.webui.routers.system import router as system_router
|
||||
from src.webui.routers.websocket.auth import router as ws_auth_router
|
||||
from src.webui.routers.websocket.unified import router as unified_ws_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
@@ -44,8 +44,6 @@ router.include_router(jargon_router)
|
||||
router.include_router(emoji_router)
|
||||
# 注册插件管理路由
|
||||
router.include_router(plugin_router)
|
||||
# 注册插件进度 WebSocket 路由
|
||||
router.include_router(get_progress_router())
|
||||
# 注册系统控制路由
|
||||
router.include_router(system_router)
|
||||
# 注册模型列表获取路由
|
||||
@@ -54,6 +52,8 @@ router.include_router(model_router)
|
||||
router.include_router(memory_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
# 注册统一 WebSocket 路由
|
||||
router.include_router(unified_ws_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user