merge: 同步 upstream/r-dev 并解决冲突

This commit is contained in:
DawnARC
2026-04-03 19:56:45 +08:00
186 changed files with 14212 additions and 6705 deletions

View File

@@ -1,25 +1,28 @@
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from rich.traceback import install
from sqlmodel import select
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import hashlib
import heapq
import Levenshtein
import random
import re
from src.common.logger import get_logger
from rich.traceback import install
from sqlmodel import select
import Levenshtein
from src.common.data_models.image_data_model import MaiEmoji
from src.common.database.database_model import Images, ImageType
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.utils.utils_image import ImageUtils
from src.prompt.prompt_manager import prompt_manager
from src.config.config import config_manager, global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.config.config import config_manager, global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("emoji")
@@ -33,6 +36,171 @@ EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表情包系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
emoji_schema = {
"type": "object",
"description": "当前表情包的序列化信息,主要包含 file_hash、description、emotions 等字段。",
}
string_array_schema = {
"type": "array",
"items": {"type": "string"},
}
return registry.register_hook_specs(
[
HookSpec(
name="emoji.maisaka.before_select",
description="Maisaka 表情发送工具选择表情前触发,可改写情绪、上下文和采样参数,或中止本次选择。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=["stream_id", "requested_emotion", "reasoning", "context_texts", "sample_size"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.maisaka.after_select",
description="Maisaka 已选出表情后触发,可替换选中的表情哈希、补充匹配情绪,或中止发送。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"selected_emoji": emoji_schema,
"selected_emoji_hash": {"type": "string", "description": "选中的表情哈希。"},
"matched_emotion": {"type": "string", "description": "最终命中的情绪标签。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=[
"stream_id",
"requested_emotion",
"reasoning",
"context_texts",
"sample_size",
"matched_emotion",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_description",
description="表情包描述生成并通过内容审查后触发,可改写描述文本或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前生成出的表情包描述。"},
"image_format": {"type": "string", "description": "表情图片格式。"},
},
required=["emoji", "description", "image_format"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_emotion",
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前表情包描述。"},
"emotions": {
**string_array_schema,
"description": "当前生成出的情绪标签列表。",
},
},
required=["emoji", "description", "emotions"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
"""将表情包对象序列化为 Hook 可传输载荷。
Args:
emoji: 待序列化的表情包对象。
Returns:
Optional[Dict[str, Any]]: 序列化后的字典;当表情为空时返回 ``None``。
"""
if emoji is None:
return None
return {
"file_hash": str(emoji.file_hash or "").strip(),
"file_name": emoji.file_name,
"full_path": str(emoji.full_path),
"description": emoji.description,
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
"query_count": int(emoji.query_count),
}
def _normalize_string_list(raw_values: Any) -> List[str]:
"""将任意列表值规范化为字符串列表。
Args:
raw_values: 待规范化的原始值。
Returns:
List[str]: 去空白后的字符串列表。
"""
if not isinstance(raw_values, list):
return []
return [str(item).strip() for item in raw_values if str(item).strip()]
def _ensure_directories() -> None:
"""确保表情包相关目录存在"""
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
@@ -642,6 +810,22 @@ class EmojiManager:
if "" in llm_response:
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_description",
emoji=_serialize_emoji_for_hook(target_emoji),
description=description,
image_format=image_format,
)
if hook_result.aborted:
logger.info(f"[构建描述] 表情包描述被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
normalized_description = str(hook_result.kwargs.get("description", description) or "").strip()
if not normalized_description:
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
description = normalized_description
target_emoji.description = description
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
return True, target_emoji
@@ -687,6 +871,23 @@ class EmojiManager:
elif len(emotions) > 2:
emotions = random.sample(emotions, 2)
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_emotion",
emoji=_serialize_emoji_for_hook(target_emoji),
description=target_emoji.description,
emotions=list(emotions),
)
if hook_result.aborted:
logger.info(f"[构建情感标签] 表情包情感标签被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
raw_emotions = hook_result.kwargs.get("emotions")
if raw_emotions is not None:
emotions = _normalize_string_list(raw_emotions)
if not emotions:
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}")
target_emoji.emotion = emotions
return True, target_emoji

View File

@@ -0,0 +1,349 @@
"""Maisaka 表情工具内置能力。"""
from dataclasses import dataclass, field
from typing import Any, Optional, Sequence
import random
from src.chat.message_receive.chat_manager import chat_manager
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
from src.common.data_models.image_data_model import MaiEmoji
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.services import send_service
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
logger = get_logger("emoji_maisaka_tool")
@dataclass(slots=True)
class MaisakaEmojiSendResult:
"""Maisaka 表情发送结果。"""
success: bool
message: str
emoji_base64: str = ""
description: str = ""
emotions: list[str] = field(default_factory=list)
requested_emotion: str = ""
matched_emotion: str = ""
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _coerce_positive_int(value: Any, default: int) -> int:
"""将任意值安全转换为正整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 规范化后的正整数。
"""
try:
normalized_value = int(value)
except (TypeError, ValueError):
return default
return normalized_value if normalized_value > 0 else default
def _normalize_context_texts(context_texts: Sequence[str] | None) -> list[str]:
"""清洗 Hook 和调用链传入的上下文文本列表。
Args:
context_texts: 原始上下文文本序列。
Returns:
list[str]: 过滤空白后的上下文文本列表。
"""
if not context_texts:
return []
return [str(item).strip() for item in context_texts if str(item).strip()]
def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
"""根据 Hook 返回值解析目标表情包对象。
Args:
raw_value: Hook 返回的 ``selected_emoji`` 或 ``selected_emoji_hash``。
Returns:
Optional[MaiEmoji]: 命中的表情包对象;未命中时返回 ``None``。
"""
raw_hash: str = ""
if isinstance(raw_value, dict):
raw_hash = str(raw_value.get("file_hash") or raw_value.get("hash") or "").strip()
elif isinstance(raw_value, str):
raw_hash = raw_value.strip()
if not raw_hash:
return None
for emoji in emoji_manager.emojis:
if emoji.file_hash == raw_hash:
return emoji
return None
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
"""提取并清洗单个表情的情绪标签。"""
return [str(item).strip() for item in emoji.emotion if str(item).strip()]
def _build_recent_context_text(context_texts: Sequence[str], max_items: int = 5) -> str:
"""构建供情绪判断使用的最近上下文文本。"""
normalized_items = [str(item).strip() for item in context_texts if str(item).strip()]
if not normalized_items:
return ""
return "\n".join(normalized_items[-max_items:])
async def _select_emoji_with_llm(
*,
sampled_emojis: Sequence[MaiEmoji],
reasoning: str,
context_text: str,
) -> tuple[MaiEmoji, str]:
"""让模型在采样表情中选择更合适的情绪标签。"""
emotion_map: dict[str, list[MaiEmoji]] = {}
for emoji in sampled_emojis:
for emotion in _normalize_emotions(emoji):
emotion_map.setdefault(emotion, []).append(emoji)
available_emotions = list(emotion_map.keys())
if not available_emotions:
return random.choice(list(sampled_emojis)), ""
prompt = (
"你正在为聊天场景选择一个最合适的表情包情绪标签。\n"
f"发送原因:{reasoning or '辅助表达当前语气和情绪'}\n"
f"最近聊天记录:\n{context_text or '(暂无额外上下文)'}\n\n"
"可选情绪标签如下:\n"
f"{chr(10).join(available_emotions)}\n\n"
"请只返回一个最匹配的情绪标签,不要解释。"
)
try:
llm_result = await emoji_manager_emotion_judge_llm.generate_response(
prompt,
options=LLMGenerationOptions(temperature=0.3, max_tokens=60),
)
chosen_emotion = (llm_result.response or "").strip().strip("\"'")
except Exception as exc:
logger.warning(f"使用 LLM 选择表情情绪失败,将回退为随机选择: {exc}")
chosen_emotion = ""
if chosen_emotion and chosen_emotion in emotion_map:
return random.choice(emotion_map[chosen_emotion]), chosen_emotion
return random.choice(list(sampled_emojis)), ""
async def select_emoji_for_maisaka(
*,
requested_emotion: str = "",
reasoning: str = "",
context_texts: Sequence[str] | None = None,
sample_size: int = 30,
) -> tuple[MaiEmoji | None, str]:
"""为 Maisaka 选择一个合适的表情。"""
available_emojis = list(emoji_manager.emojis)
if not available_emojis:
return None, ""
normalized_requested_emotion = requested_emotion.strip()
if normalized_requested_emotion:
matched_emojis = [
emoji
for emoji in available_emojis
if normalized_requested_emotion.lower() in (emotion.lower() for emotion in _normalize_emotions(emoji))
]
if matched_emojis:
return random.choice(matched_emojis), normalized_requested_emotion
sampled_emojis = random.sample(
available_emojis,
min(max(sample_size, 1), len(available_emojis)),
)
context_text = _build_recent_context_text(context_texts or [])
return await _select_emoji_with_llm(
sampled_emojis=sampled_emojis,
reasoning=reasoning,
context_text=context_text,
)
async def send_emoji_for_maisaka(
*,
stream_id: str,
requested_emotion: str = "",
reasoning: str = "",
context_texts: Sequence[str] | None = None,
) -> MaisakaEmojiSendResult:
"""为 Maisaka 选择并发送一个表情。"""
normalized_requested_emotion = requested_emotion.strip()
normalized_reasoning = reasoning.strip()
normalized_context_texts = _normalize_context_texts(context_texts)
sample_size = 30
before_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.before_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
abort_message="表情选择已被 Hook 中止。",
)
if before_select_result.aborted:
abort_message = str(before_select_result.kwargs.get("abort_message") or "表情选择已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情选择已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
)
before_select_kwargs = before_select_result.kwargs
normalized_requested_emotion = str(
before_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
normalized_reasoning = str(before_select_kwargs.get("reasoning", normalized_reasoning) or "").strip()
if isinstance(before_select_kwargs.get("context_texts"), list):
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=normalized_context_texts,
sample_size=sample_size,
)
after_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.after_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
selected_emoji=_serialize_emoji_for_hook(selected_emoji),
selected_emoji_hash=str(selected_emoji.file_hash or "").strip() if selected_emoji is not None else "",
matched_emotion=matched_emotion,
abort_message="表情发送已被 Hook 中止。",
)
if after_select_result.aborted:
abort_message = str(after_select_result.kwargs.get("abort_message") or "表情发送已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情发送已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
after_select_kwargs = after_select_result.kwargs
normalized_requested_emotion = str(
after_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
matched_emotion = str(after_select_kwargs.get("matched_emotion", matched_emotion) or "").strip()
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji_hash"))
if override_emoji is None:
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji"))
if override_emoji is not None:
selected_emoji = override_emoji
if selected_emoji is None:
return MaisakaEmojiSendResult(
success=False,
message="当前表情包库中没有可用表情。",
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
try:
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
raise ValueError("表情图片转换为 base64 失败")
except Exception as exc:
return MaisakaEmojiSendResult(
success=False,
message=f"发送表情包失败:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
try:
target_session = chat_manager.get_session_by_session_id(stream_id)
if target_session is not None and target_session.platform == CLI_PLATFORM_NAME:
preview_message = (
f"已发送表情包:{selected_emoji.description.strip()}"
if selected_emoji.description.strip()
else "[表情包]"
)
render_cli_message(preview_message)
sent = True
else:
sent = await send_service.emoji_to_stream(
emoji_base64=emoji_base64,
stream_id=stream_id,
storage_message=True,
set_reply=False,
reply_message=None,
)
except Exception as exc:
return MaisakaEmojiSendResult(
success=False,
message=f"发送表情包时发生异常:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
description = selected_emoji.description.strip()
emotions = _normalize_emotions(selected_emoji)
if not sent:
return MaisakaEmojiSendResult(
success=False,
message="发送表情包失败。",
description=description,
emotions=emotions,
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
emoji_manager.update_emoji_usage(selected_emoji)
success_message = (
f"已发送表情包:{description}(情绪:{', '.join(emotions)}"
if emotions
else f"已发送表情包:{description}"
)
return MaisakaEmojiSendResult(
success=True,
message=success_message,
emoji_base64=emoji_base64,
description=description,
emotions=emotions,
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)

View File

@@ -44,6 +44,12 @@ class ImageManager:
logger.info("图片管理器初始化完成")
def _get_image_record(self, image_hash: str) -> Optional[Images]:
"""根据哈希获取图片记录。"""
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
return session.exec(statement).first()
async def get_image_description(
self,
*,
@@ -76,9 +82,8 @@ class ImageManager:
hash_str = hashlib.sha256(image_bytes).hexdigest()
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
if record := session.exec(statement).first():
if record := self._get_image_record(hash_str):
if record.vlm_processed and record.description:
return record.description
except Exception as e:
logger.error(f"查询图片描述时发生错误: {e}")
@@ -86,12 +91,17 @@ class ImageManager:
if not image_bytes:
logger.warning("图片哈希值未找到,且未提供图片字节数据,返回无描述")
return ""
try:
await self.ensure_image_saved(image_bytes)
except Exception as e:
logger.error(f"保存图片文件时发生错误: {e}")
return ""
if not wait_for_build:
self._schedule_description_build(hash_str, image_bytes)
return ""
logger.info(f"图片描述未找到,哈希值: {hash_str},准备生成新描述")
try:
image = await self.save_image_and_process(image_bytes)
image = await self.build_image_description(image_bytes)
return image.description
except Exception as e:
logger.error(f"生成图片描述时发生错误: {e}")
@@ -120,7 +130,7 @@ class ImageManager:
"""
try:
logger.info(f"图片描述后台构建已开始,哈希值: {image_hash}")
await self.save_image_and_process(image_bytes)
await self.build_image_description(image_bytes)
logger.info(f"图片描述后台构建完成,哈希值: {image_hash}")
except Exception as exc:
logger.warning(f"图片描述后台构建失败,哈希值: {image_hash},错误: {exc}")
@@ -201,6 +211,7 @@ class ImageManager:
return False
record.description = image.description
record.last_used_time = datetime.now()
record.vlm_processed = image.vlm_processed
session.add(record)
logger.info(f"成功更新图片描述: {image.file_hash},新描述: {image.description}")
except Exception as e:
@@ -239,22 +250,13 @@ class ImageManager:
return False
return True
async def save_image_and_process(self, image_bytes: bytes) -> MaiImage:
"""
保存图片并生成描述
Args:
image_bytes (bytes): 图片的字节数据
Returns:
return (MaiImage): 包含图片信息的 MaiImage 对象
Raises:
Exception: 如果在保存或处理过程中发生错误
"""
async def ensure_image_saved(self, image_bytes: bytes) -> MaiImage:
"""先保存图片记录,确保后续可以按哈希回填图片内容。"""
hash_str = hashlib.sha256(image_bytes).hexdigest()
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=hash_str).limit(1)
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
if record := session.exec(statement).first():
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
record.last_used_time = datetime.now()
@@ -270,18 +272,38 @@ class ImageManager:
tmp_file_path = IMAGE_DIR / f"{hash_str}.tmp"
with tmp_file_path.open("wb") as f:
f.write(image_bytes)
mai_image = MaiImage(full_path=(IMAGE_DIR / f"{hash_str}.tmp"), image_bytes=image_bytes)
mai_image = MaiImage(full_path=tmp_file_path, image_bytes=image_bytes)
await mai_image.calculate_hash_format()
if not self.register_image_to_db(mai_image):
raise RuntimeError(f"保存图片记录到数据库失败: {hash_str}")
return mai_image
async def build_image_description(self, image_bytes: bytes) -> MaiImage:
"""在图片已保存的前提下生成或补齐图片描述。"""
mai_image = await self.ensure_image_saved(image_bytes)
if mai_image.vlm_processed and mai_image.description:
return mai_image
desc = await self._generate_image_description(image_bytes, mai_image.image_format)
mai_image.description = desc
mai_image.vlm_processed = True
try:
self.register_image_to_db(mai_image)
except Exception as e:
logger.error(f"保存新图片记录到数据库时发生错误: {e}")
raise e
if not self.update_image_description(mai_image):
raise RuntimeError(f"更新图片描述失败: {mai_image.file_hash}")
return mai_image
async def save_image_and_process(self, image_bytes: bytes) -> MaiImage:
"""
保存图片并生成描述
Args:
image_bytes (bytes): 图片的字节数据
Returns:
return (MaiImage): 包含图片信息的 MaiImage 对象
Raises:
Exception: 如果在保存或处理过程中发生错误
"""
return await self.build_image_description(image_bytes)
def cleanup_invalid_descriptions_in_db(self):
"""
清理数据库中无效的图片记录

View File

@@ -1,6 +1,7 @@
"""聊天消息入口与主链路调度。"""
from contextlib import suppress
from copy import deepcopy
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import os
import traceback
@@ -13,12 +14,15 @@ from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.platform_io.route_key_factory import RouteKeyFactory
from src.core.announcement_manager import global_announcement_manager
from src.plugin_runtime.component_query import component_query_service
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from .message import SessionMessage
from .chat_manager import chat_manager
from .message import SessionMessage
# 定义日志配置
@@ -29,7 +33,137 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
logger = get_logger("chat")
def register_chat_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册聊天消息主链内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="chat.receive.before_process",
description="在入站消息执行 `SessionMessage.process()` 之前触发,可拦截或改写消息。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "当前入站消息的序列化 SessionMessage。",
},
},
required=["message"],
),
default_timeout_ms=8000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="chat.receive.after_process",
description="在入站消息完成轻量预处理后触发,可改写文本、消息体或中止后续链路。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "已完成 `process()` 的序列化 SessionMessage。",
},
},
required=["message"],
),
default_timeout_ms=8000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="chat.command.before_execute",
description="在命令匹配成功、实际执行前触发,可拦截命令或改写命令上下文。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "当前命令消息的序列化 SessionMessage。",
},
"command_name": {
"type": "string",
"description": "命中的命令名称。",
},
"plugin_id": {
"type": "string",
"description": "命令所属插件 ID。",
},
"matched_groups": {
"type": "object",
"description": "命令正则命名捕获结果。",
},
},
required=["message", "command_name", "plugin_id", "matched_groups"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="chat.command.after_execute",
description="在命令执行结束后触发,可调整返回文本和是否继续主链处理。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "当前命令消息的序列化 SessionMessage。",
},
"command_name": {
"type": "string",
"description": "命令名称。",
},
"plugin_id": {
"type": "string",
"description": "命令所属插件 ID。",
},
"matched_groups": {
"type": "object",
"description": "命令正则命名捕获结果。",
},
"success": {
"type": "boolean",
"description": "命令执行是否成功。",
},
"response": {
"type": "string",
"description": "命令返回文本。",
},
"intercept_message_level": {
"type": "integer",
"description": "命令拦截等级。",
},
"continue_process": {
"type": "boolean",
"description": "命令执行后是否继续后续消息处理。",
},
},
required=[
"message",
"command_name",
"plugin_id",
"matched_groups",
"success",
"intercept_message_level",
"continue_process",
],
),
default_timeout_ms=5000,
allow_abort=False,
allow_kwargs_mutation=True,
),
]
)
class ChatBot:
"""聊天机器人入口协调器。"""
def __init__(self) -> None:
"""初始化聊天机器人入口。"""
@@ -44,6 +178,66 @@ class ChatBot:
self._started = True
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
async def _invoke_message_hook(
self,
hook_name: str,
message: SessionMessage,
**kwargs: Any,
) -> tuple[HookDispatchResult, SessionMessage]:
"""触发携带会话消息的命名 Hook。
Args:
hook_name: 目标 Hook 名称。
message: 当前会话消息。
**kwargs: 需要附带传递的额外参数。
Returns:
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
"""
hook_result = await self._get_runtime_manager().invoke_hook(
hook_name,
message=serialize_session_message(message),
**kwargs,
)
mutated_message = message
raw_message = hook_result.kwargs.get("message")
if raw_message is not None:
try:
mutated_message = deserialize_session_message(raw_message)
except Exception as exc:
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
return hook_result, mutated_message
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
"""使用统一组件注册表处理命令。
@@ -71,6 +265,25 @@ class ChatBot:
return False, None, True
message.is_command = True
before_result, message = await self._invoke_message_hook(
"chat.command.before_execute",
message,
command_name=command_name,
plugin_id=plugin_name,
matched_groups=dict(matched_groups),
)
if before_result.aborted:
logger.info(f"命令 {command_name} 被 Hook 中止,跳过命令执行")
return True, None, False
hook_kwargs = before_result.kwargs
command_name = str(hook_kwargs.get("command_name", command_name) or command_name)
plugin_name = str(hook_kwargs.get("plugin_id", plugin_name) or plugin_name)
matched_groups = (
dict(hook_kwargs["matched_groups"])
if isinstance(hook_kwargs.get("matched_groups"), dict)
else dict(matched_groups)
)
# 获取插件配置
plugin_config = component_query_service.get_plugin_config(plugin_name)
@@ -82,27 +295,43 @@ class ChatBot:
plugin_config=plugin_config,
matched_groups=matched_groups,
)
self._mark_command_message(message, intercept_message_level)
# 记录命令执行结果
if success:
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
else:
logger.warning(f"命令执行失败: {command_name} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_name} - {e}")
continue_process = not bool(intercept_message_level)
except Exception as exc:
logger.error(f"执行命令时出错: {command_name} - {exc}")
logger.error(traceback.format_exc())
success = False
response = str(exc)
intercept_message_level = 1
continue_process = False
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息
after_result, message = await self._invoke_message_hook(
"chat.command.after_execute",
message,
command_name=command_name,
plugin_id=plugin_name,
matched_groups=dict(matched_groups),
success=success,
response=response,
intercept_message_level=intercept_message_level,
continue_process=continue_process,
)
after_kwargs = after_result.kwargs
success = bool(after_kwargs.get("success", success))
raw_response = after_kwargs.get("response", response)
response = None if raw_response is None else str(raw_response)
intercept_message_level = self._coerce_int(
after_kwargs.get("intercept_message_level", intercept_message_level),
intercept_message_level,
)
continue_process = bool(after_kwargs.get("continue_process", continue_process))
self._mark_command_message(message, intercept_message_level)
if success:
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
else:
logger.warning(f"命令执行失败: {command_name} - {response}")
return True, response, continue_process
return False, None, True
@@ -138,6 +367,17 @@ class ChatBot:
cmd_result: Optional[str],
continue_process: bool,
) -> bool:
"""处理命令链结果并决定是否终止主消息链。
Args:
message: 当前命令消息。
cmd_result: 命令响应文本。
continue_process: 是否继续后续主链处理。
Returns:
bool: ``True`` 表示已经终止后续主链。
"""
if continue_process:
return False
@@ -145,9 +385,18 @@ class ChatBot:
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
return True
async def handle_notice_message(self, message: SessionMessage):
async def handle_notice_message(self, message: SessionMessage) -> bool:
"""处理通知类消息。
Args:
message: 当前通知消息。
Returns:
bool: 当前消息是否为通知消息。
"""
if message.message_id != "notice":
return
return False
message.is_notify = True
logger.debug("notice消息")
@@ -203,9 +452,12 @@ class ChatBot:
return True
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
"""处理消息回送 ID 对应关系。
Args:
raw_data: 平台适配器上报的原始回送载荷。
"""
用于专门处理回送消息ID的函数
"""
message_data: Dict[str, Any] = raw_data.get("content", {})
if not message_data:
return
@@ -218,18 +470,10 @@ class ChatBot:
logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}")
async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
heart_flow模式使用思维流系统进行回复
- 包含思维流状态管理
- 在回复前进行观察和状态更新
- 回复后更新思维流状态
- 消息过滤
- 记忆激活
- 意愿计算
- 消息生成和发送
- 表情包处理
- 性能计时
"""处理统一格式的入站消息字典。
Args:
message_data: 适配器整理后的统一消息字典。
"""
try:
# 确保所有任务已启动
@@ -253,7 +497,13 @@ class ChatBot:
logger.error(f"预处理消息失败: {e}")
traceback.print_exc()
async def receive_message(self, message: SessionMessage):
async def receive_message(self, message: SessionMessage) -> None:
"""处理单条入站会话消息。
Args:
message: 待处理的会话消息。
"""
try:
group_info = message.message_info.group_info
user_info = message.message_info.user_info
@@ -272,6 +522,19 @@ class ChatBot:
)
message.session_id = session_id # 正确初始化session_id
before_process_result, message = await self._invoke_message_hook(
"chat.receive.before_process",
message,
)
if before_process_result.aborted:
logger.info(f"消息 {message.message_id} 在预处理前被 Hook 中止")
return
group_info = message.message_info.group_info
user_info = message.message_info.user_info
additional_config = message.message_info.additional_config
if isinstance(additional_config, dict):
account_id, scope = RouteKeyFactory.extract_components(additional_config)
# TODO: 修复事件预处理部分
# continue_flag, modified_message = await events_manager.handle_mai_events(
@@ -286,14 +549,24 @@ class ChatBot:
# if await self.handle_notice_message(message):
# pass
# 处理消息内容,识别表情包等二进制数据并转化为文本描述
if global_config.maisaka.direct_image_input:
message.maisaka_original_raw_message = deepcopy(message.raw_message) # type: ignore[attr-defined]
# 处理消息内容,识别表情包等二进制数据并转化为文本描述
# 如果 Maisaka 需要直接消费图片,会在后续构建 prompt 时按需回填图片二进制数据,
# 这里不再复制整条原始消息。
# 入站主链优先保证消息尽快入队,避免图片、表情包、语音分析阻塞适配器超时。
await message.process(
enable_heavy_media_analysis=False,
enable_voice_transcription=False,
)
after_process_result, message = await self._invoke_message_hook(
"chat.receive.after_process",
message,
)
if after_process_result.aborted:
logger.info(f"消息 {message.message_id} 在预处理后被 Hook 中止")
return
group_info = message.message_info.group_info
user_info = message.message_info.user_info
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配

View File

@@ -1,11 +1,10 @@
import asyncio
from asyncio import Task
from typing import Dict, List, Sequence, Tuple
from rich.traceback import install
from sqlmodel import select
import asyncio
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages
@@ -36,6 +35,102 @@ class MsgIDMapping:
class SessionMessage(MaiMessage):
#便于调试的打印函数
def __str__(self) -> str:
"""返回适合日志输出的消息摘要。"""
return self.to_debug_string()
def __repr__(self) -> str:
"""返回适合调试场景的消息摘要。"""
return self.to_debug_string()
def to_debug_string(self) -> str:
"""构建包含引用信息的调试字符串。
Returns:
str: 适合记录日志的消息摘要。
"""
user_info = self.message_info.user_info
group_info = self.message_info.group_info
chat_type = "group" if group_info else "private"
group_id = group_info.group_id if group_info else None
group_name = group_info.group_name if group_info else None
component_summaries = [self._summarize_component(component) for component in self.raw_message.components]
raw_components = ", ".join(component_summaries) if component_summaries else "empty"
return (
"SessionMessage("
f"message_id={self.message_id!r}, "
f"platform={self.platform!r}, "
f"chat_type={chat_type!r}, "
f"group_id={group_id!r}, "
f"group_name={group_name!r}, "
f"user_id={user_info.user_id!r}, "
f"user_nickname={user_info.user_nickname!r}, "
f"user_cardname={user_info.user_cardname!r}, "
f"reply_to={self.reply_to!r}, "
f"processed_plain_text={self._truncate_text(self.processed_plain_text)}, "
f"raw_components=[{raw_components}]"
")"
)
@staticmethod
def _truncate_text(text: str | None, max_length: int = 120) -> str:
"""截断较长文本,避免日志过长。
Args:
text: 原始文本。
max_length: 最大保留长度。
Returns:
str: 截断后的文本表示。
"""
if text is None:
return "None"
normalized_text = text.replace("\r", "\\r").replace("\n", "\\n")
if len(normalized_text) <= max_length:
return repr(normalized_text)
return repr(f"{normalized_text[:max_length]}...")
def _summarize_component(self, component: StandardMessageComponents) -> str:
"""生成单个消息组件的调试摘要。
Args:
component: 消息组件对象。
Returns:
str: 组件摘要文本。
"""
if isinstance(component, TextComponent):
return f"Text(text={self._truncate_text(component.text, 80)})"
if isinstance(component, ImageComponent):
return f"Image(content={self._truncate_text(component.content or None, 60)})"
if isinstance(component, EmojiComponent):
return f"Emoji(content={self._truncate_text(component.content or None, 60)})"
if isinstance(component, AtComponent):
target_name = component.target_user_cardname or component.target_user_nickname or component.target_user_id
return f"At(target={target_name!r})"
if isinstance(component, VoiceComponent):
return f"Voice(content={self._truncate_text(component.content or None, 60)})"
if isinstance(component, ReplyComponent):
sender_name = (
component.target_message_sender_cardname
or component.target_message_sender_nickname
or component.target_message_sender_id
)
return (
"Reply("
f"target_message_id={component.target_message_id!r}, "
f"target_sender={sender_name!r}, "
f"target_content={self._truncate_text(component.target_message_content, 80)}"
")"
)
if isinstance(component, ForwardNodeComponent):
return f"ForwardNode(count={len(component.forward_components)})"
return f"{component.__class__.__name__}"
#便于调试的打印函数end
async def process(
self,
*,

View File

@@ -18,28 +18,29 @@ install(extra_lines=3)
logger = get_logger("sender")
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str], Optional[str]]] = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
"""获取 WebUI 聊天室广播器。
Returns:
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 元组;
Tuple[Any, Optional[str], Optional[str]]: ``(chat_manager, platform_name, default_group_id)`` 元组;
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
"""
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
except ImportError:
_webui_chat_broadcaster = (None, None)
_webui_chat_broadcaster = (None, None, None)
return _webui_chat_broadcaster
@@ -76,7 +77,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
chat_manager, webui_platform, default_group_id = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
@@ -97,8 +98,9 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
message_type = "rich"
segments = message_segments
await chat_manager.broadcast(
{
await chat_manager.broadcast_to_group(
group_id=group_id or default_group_id or "",
message={
"type": "bot_message",
"content": message.processed_plain_text,
"message_type": message_type,
@@ -110,7 +112,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
"avatar": None,
"is_bot": True,
},
}
},
)
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库

View File

@@ -35,7 +35,7 @@ from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
from src.learners.jargon_explainer_old import explain_jargon_in_context
from src.chat.utils.common_utils import TempMethodsExpression
init_memory_retrieval_sys()
@@ -688,39 +688,41 @@ class DefaultReplyer:
return None
def get_chat_prompt_for_chat(self, chat_id: str) -> str:
"""
根据聊天流ID获取匹配的额外prompt仅匹配group类型
Args:
chat_id: 聊天流ID哈希值
Returns:
str: 匹配的额外prompt内容如果没有匹配则返回空字符串
"""
if not global_config.experimental.chat_prompts:
"""根据聊天流 ID 获取匹配的额外 prompt。"""
if not global_config.chat.chat_prompts:
return ""
for chat_prompt_str in global_config.experimental.chat_prompts:
if not isinstance(chat_prompt_str, str):
for chat_prompt_item in global_config.chat.chat_prompts:
if hasattr(chat_prompt_item, "rule_type") and hasattr(chat_prompt_item, "prompt"):
if str(chat_prompt_item.rule_type or "").strip() != "group":
continue
config_chat_id = self._build_chat_uid(
str(chat_prompt_item.platform or "").strip(),
str(chat_prompt_item.item_id or "").strip(),
True,
)
prompt_content = str(chat_prompt_item.prompt or "").strip()
if config_chat_id == chat_id and prompt_content:
logger.debug(f"匹配到群聊 prompt 配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
return prompt_content
continue
# 解析配置字符串检查类型是否为group
parts = chat_prompt_str.split(":", 3)
if len(parts) != 4:
if not isinstance(chat_prompt_item, str):
continue
stream_type = parts[2]
# 只匹配group类型
if stream_type != "group":
# 兼容旧格式的 platform:id:type:prompt 配置字符串。
parts = chat_prompt_item.split(":", 3)
if len(parts) != 4 or parts[2] != "group":
continue
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_item)
if result is None:
continue
config_chat_id, prompt_content = result
if config_chat_id == chat_id:
logger.debug(f"匹配到群聊prompt配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
logger.debug(f"匹配到群聊 prompt 配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
return prompt_content
return ""

View File

@@ -0,0 +1,453 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import random
import time
from sqlmodel import select
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.data_models.reply_generation_data_models import (
GenerationMetrics,
LLMCompletionResult,
ReplyGenerationResult,
)
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import global_config
from src.core.types import ActionInfo
from src.services.llm_service import LLMServiceClient
from src.chat.message_receive.message import SessionMessage
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, ReferenceMessage, SessionBackedMessage, ToolResultMessage
from src.maisaka.message_adapter import parse_speaker_content
logger = get_logger("replyer")
@dataclass
class MaisakaReplyContext:
"""Maisaka replyer 使用的回复上下文。"""
expression_habits: str = ""
selected_expression_ids: List[int] = field(default_factory=list)
@dataclass
class _ExpressionRecord:
"""表达方式的轻量记录。"""
expression_id: Optional[int]
situation: str
style: str
class MaisakaReplyGenerator:
"""生成 Maisaka 的最终可见回复。"""
def __init__(
self,
chat_stream: Optional[BotChatSession] = None,
request_type: str = "maisaka_replyer",
) -> None:
self.chat_stream = chat_stream
self.request_type = request_type
self.express_model = LLMServiceClient(
task_name="replyer",
request_type=request_type,
)
self._personality_prompt = self._build_personality_prompt()
def _build_personality_prompt(self) -> str:
"""构建 replyer 使用的人设描述。"""
try:
bot_name = global_config.bot.nickname
alias_names = global_config.bot.alias_names
bot_aliases = f",也有人叫你{','.join(alias_names)}" if alias_names else ""
prompt_personality = global_config.personality.personality
if (
hasattr(global_config.personality, "states")
and global_config.personality.states
and hasattr(global_config.personality, "state_probability")
and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
return f"你的名字是{bot_name}{bot_aliases},你{prompt_personality};"
except Exception as exc:
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
return "你的名字是麦麦,你是一个活泼可爱的 AI 助手。"
@staticmethod
def _normalize_content(content: str, limit: int = 500) -> str:
normalized = " ".join((content or "").split())
if len(normalized) > limit:
return normalized[:limit] + "..."
return normalized
@staticmethod
def _extract_visible_assistant_reply(message: AssistantMessage) -> str:
del message
return ""
def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
bot_nickname = global_config.bot.nickname.strip() or "Bot"
if speaker_name == bot_nickname:
return self._normalize_content(body.strip())
return ""
@staticmethod
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
"""按说话人拆分用户消息。"""
segments: List[tuple[Optional[str], str]] = []
current_speaker: Optional[str] = None
current_lines: List[str] = []
for raw_line in raw_content.splitlines():
speaker_name, content_body = parse_speaker_content(raw_line)
if speaker_name is not None:
if current_lines:
segments.append((current_speaker, "\n".join(current_lines)))
current_speaker = speaker_name
current_lines = [content_body]
continue
current_lines.append(raw_line)
if current_lines:
segments.append((current_speaker, "\n".join(current_lines)))
return segments
def _build_system_prompt(
self,
reply_reason: str,
expression_habits: str = "",
) -> str:
"""构建 Maisaka replyer 使用的系统提示词。"""
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try:
system_prompt = load_prompt(
"maisaka_replyer",
bot_name=global_config.bot.nickname,
time_block=f"当前时间:{current_time}",
identity=self._personality_prompt,
reply_style=global_config.personality.reply_style,
)
except Exception:
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
extra_sections: List[str] = []
if expression_habits.strip():
extra_sections.append(expression_habits.strip())
if reply_reason.strip():
extra_sections.append(f"【回复信息参考】\n{reply_reason}")
if not extra_sections:
return system_prompt
return f"{system_prompt}\n\n" + "\n\n".join(extra_sections)
def _build_reply_instruction(self) -> str:
"""构建追加在上下文末尾的回复指令。"""
return "请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。"
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
"""将 replyer 上下文拆成多条 LLM 消息。"""
bot_nickname = global_config.bot.nickname.strip() or "Bot"
default_user_name = global_config.maisaka.user_name.strip() or "User"
messages: List[Message] = []
for message in chat_history:
if isinstance(message, (ReferenceMessage, ToolResultMessage)):
continue
if isinstance(message, SessionBackedMessage):
guided_reply = self._extract_guided_bot_reply(message)
if guided_reply:
messages.append(
MessageBuilder().set_role(RoleType.Assistant).add_text_content(guided_reply).build()
)
continue
for speaker_name, content_body in self._split_user_message_segments(message.processed_plain_text):
content = self._normalize_content(content_body)
if not content:
continue
visible_speaker = speaker_name or default_user_name
if visible_speaker == bot_nickname:
messages.append(
MessageBuilder().set_role(RoleType.Assistant).add_text_content(content).build()
)
continue
user_content = f"[{visible_speaker}]{content}"
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(user_content).build())
continue
if isinstance(message, AssistantMessage):
visible_reply = self._extract_visible_assistant_reply(message)
if visible_reply:
messages.append(
MessageBuilder().set_role(RoleType.Assistant).add_text_content(visible_reply).build()
)
return messages
def _build_request_messages(
self,
chat_history: List[LLMContextMessage],
reply_reason: str,
expression_habits: str = "",
) -> List[Message]:
"""构建发给大模型的消息列表。"""
messages: List[Message] = []
system_prompt = self._build_system_prompt(
reply_reason=reply_reason,
expression_habits=expression_habits,
)
instruction = self._build_reply_instruction()
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
messages.extend(self._build_history_messages(chat_history))
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
return messages
@staticmethod
def _build_request_prompt_preview(messages: List[Message]) -> str:
"""将消息列表转为便于调试的文本预览。"""
preview_lines: List[str] = []
for message in messages:
role_name = message.role.value.capitalize()
preview_lines.append(f"{role_name}: {message.get_text_content()}")
return "\n\n".join(preview_lines)
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
"""解析当前回复使用的会话 ID。"""
if stream_id:
return stream_id
if self.chat_stream is not None:
return self.chat_stream.session_id
return ""
async def _build_reply_context(
self,
chat_history: List[LLMContextMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
stream_id: Optional[str],
) -> MaisakaReplyContext:
"""在 replyer 内部构建表达习惯和黑话解释。"""
session_id = self._resolve_session_id(stream_id)
if not session_id:
logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识")
return MaisakaReplyContext()
expression_habits, selected_expression_ids = self._build_expression_habits(
session_id=session_id,
chat_history=chat_history,
reply_message=reply_message,
reply_reason=reply_reason,
)
return MaisakaReplyContext(
expression_habits=expression_habits,
selected_expression_ids=selected_expression_ids,
)
def _build_expression_habits(
self,
session_id: str,
chat_history: List[LLMContextMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
) -> tuple[str, List[int]]:
"""查询并格式化适合当前会话的表达习惯。"""
del chat_history
del reply_message
del reply_reason
expression_records = self._load_expression_records(session_id)
if not expression_records:
return "", []
lines: List[str] = []
selected_ids: List[int] = []
for expression in expression_records:
if expression.expression_id is not None:
selected_ids.append(expression.expression_id)
lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")
block = "【表达习惯参考】\n" + "\n".join(lines)
logger.info(
f"已构建 Maisaka 表达习惯: 会话标识={session_id} "
f"数量={len(selected_ids)} 表达编号={selected_ids!r}"
)
return block, selected_ids
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
"""提取表达方式静态数据,避免 detached ORM 对象。"""
with get_db_session(auto_commit=False) as session:
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
if global_config.expression.expression_checked_only:
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
query = query.where(
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
expressions = session.exec(query.limit(5)).all()
return [
_ExpressionRecord(
expression_id=expression.id,
situation=expression.situation,
style=expression.style,
)
for expression in expressions
]
async def generate_reply_with_context(
self,
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[object]] = None,
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[SessionMessage] = None,
reply_time_point: Optional[float] = None,
think_level: int = 1,
unknown_words: Optional[List[str]] = None,
log_reply: bool = True,
chat_history: Optional[List[LLMContextMessage]] = None,
expression_habits: str = "",
selected_expression_ids: Optional[List[int]] = None,
) -> Tuple[bool, ReplyGenerationResult]:
"""结合上下文生成 Maisaka 的最终可见回复。"""
del available_actions
del chosen_actions
del extra_info
del from_plugin
del log_reply
del reply_time_point
del think_level
del unknown_words
result = ReplyGenerationResult()
if chat_history is None:
result.error_message = "聊天历史为空"
return False, result
logger.info(
f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} "
f"历史消息数={len(chat_history)} 目标消息编号="
f"{reply_message.message_id if reply_message else None}"
)
filtered_history = [
message
for message in chat_history
if not isinstance(message, (ReferenceMessage, ToolResultMessage))
]
logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}")
# Validate that express_model is properly initialized
if self.express_model is None:
logger.error("Maisaka 回复器的回复模型未初始化")
result.error_message = "回复模型尚未初始化"
return False, result
try:
reply_context = await self._build_reply_context(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
stream_id=stream_id,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}")
result.error_message = f"构建回复上下文失败: {exc}"
return False, result
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
result.selected_expression_ids = (
list(selected_expression_ids)
if selected_expression_ids is not None
else list(reply_context.selected_expression_ids)
)
logger.info(
f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} "
f"已选表达编号={result.selected_expression_ids!r}"
)
try:
request_messages = self._build_request_messages(
chat_history=filtered_history,
reply_reason=reply_reason or "",
expression_habits=merged_expression_habits,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}")
result.error_message = f"构建提示词失败: {exc}"
return False, result
prompt_preview = self._build_request_prompt_preview(request_messages)
def message_factory(_client: object) -> List[Message]:
return request_messages
result.completion.request_prompt = prompt_preview
if global_config.debug.show_replyer_prompt:
logger.info(f"\nMaisaka 回复器提示词:\n{prompt_preview}\n")
started_at = time.perf_counter()
try:
generation_result = await self.express_model.generate_response_with_messages(message_factory=message_factory)
except Exception as exc:
logger.exception("Maisaka 回复器调用失败")
result.error_message = str(exc)
result.metrics = GenerationMetrics(
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
)
return False, result
response_text = (generation_result.response or "").strip()
result.success = bool(response_text)
result.completion = LLMCompletionResult(
request_prompt=prompt_preview,
response_text=response_text,
reasoning_text=generation_result.reasoning or "",
model_name=generation_result.model_name or "",
tool_calls=generation_result.tool_calls or [],
)
result.metrics = GenerationMetrics(
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
)
if global_config.debug.show_replyer_reasoning and result.completion.reasoning_text:
logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}")
if not result.success:
result.error_message = "回复器返回了空内容"
logger.warning("Maisaka 回复器返回了空内容")
return False, result
logger.info(
f"Maisaka 回复器生成成功: 回复文本={response_text!r} "
f"总耗时毫秒={result.metrics.overall_ms} "
f"已选表达编号={result.selected_expression_ids!r}"
)
result.text_fragments = [response_text]
return True, result

View File

@@ -0,0 +1,21 @@
from typing import Type
from src.config.config import global_config
def get_maisaka_replyer_class() -> Type[object]:
"""根据配置返回 Maisaka replyer 类。"""
generator_type = global_config.maisaka.replyer_generator_type
if generator_type == "multi":
from .maisaka_generator_multi import MaisakaReplyGenerator
return MaisakaReplyGenerator
from .maisaka_generator import MaisakaReplyGenerator
return MaisakaReplyGenerator
def get_maisaka_replyer_generator_type() -> str:
"""返回当前配置的 Maisaka replyer 生成器类型。"""
return global_config.maisaka.replyer_generator_type

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.replyer.maisaka_replyer_factory import (
get_maisaka_replyer_class,
get_maisaka_replyer_generator_type,
)
from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
from src.chat.replyer.private_generator import PrivateReplyer
logger = get_logger("ReplyerManager")
@@ -23,14 +26,15 @@ class ReplyerManager:
chat_id: Optional[str] = None,
request_type: str = "replyer",
replyer_type: str = "default",
) -> Optional["DefaultReplyer | MaisakaReplyGenerator | PrivateReplyer"]:
) -> Optional["DefaultReplyer | PrivateReplyer | Any"]:
"""按会话和 replyer 类型获取实例。"""
stream_id = chat_stream.session_id if chat_stream else chat_id
if not stream_id:
logger.warning("[ReplyerManager] 缺少 stream_id无法获取 replyer")
return None
cache_key = f"{replyer_type}:{stream_id}"
generator_type = get_maisaka_replyer_generator_type() if replyer_type == "maisaka" else ""
cache_key = f"{replyer_type}:{generator_type}:{stream_id}"
if cache_key in self._repliers:
logger.info(f"[ReplyerManager] 命中缓存 replyer: cache_key={cache_key}")
return self._repliers[cache_key]
@@ -47,10 +51,10 @@ class ReplyerManager:
try:
if replyer_type == "maisaka":
logger.info("[ReplyerManager] importing MaisakaReplyGenerator")
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
logger.info(f"[ReplyerManager] 选择 MaisakaReplyGenerator: generator_type={generator_type}")
maisaka_replyer_class = get_maisaka_replyer_class()
replyer = MaisakaReplyGenerator(
replyer = maisaka_replyer_class(
chat_stream=target_stream,
request_type=request_type,
)

View File

@@ -15,7 +15,6 @@ from src.common.database.database import get_db_session
from src.common.database.database_model import Messages, ModelUsage, OnlineTime, ToolRecord
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
from src.config.config import global_config
logger = get_logger("maibot_statistic")

View File

@@ -3,41 +3,22 @@ MaiSaka CLI and conversation loop.
"""
from datetime import datetime
from typing import Optional
import asyncio
import os
import time
from rich import box
from rich.markdown import Markdown
from rich.panel import Panel
from rich.text import Text
from src.know_u.knowledge import KnowledgeLearner, retrieve_relevant_knowledge
from src.know_u.knowledge_store import get_knowledge_store
from src.chat.heart_flow.heartflow_manager import heartflow_manager
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
from src.chat.message_receive.message import SessionMessage
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.config.config import config_manager, global_config
from src.mcp_module import MCPManager
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
from src.maisaka.chat_loop_service import MaisakaChatLoopService
from src.maisaka.context_messages import (
AssistantMessage,
LLMContextMessage,
SessionBackedMessage,
ToolResultMessage,
)
from src.maisaka.message_adapter import format_speaker_content
from src.maisaka.tool_handlers import (
ToolHandlerContext,
handle_mcp_tool,
handle_stop,
handle_unknown_tool,
handle_wait,
)
from .maisaka_cli_sender import CLI_PLATFORM_NAME
from .console import console
from .input_reader import InputReader
@@ -45,41 +26,13 @@ from .input_reader import InputReader
class BufferCLI:
"""Maisaka 命令行交互入口。"""
_CLI_PLATFORM = CLI_PLATFORM_NAME
_CLI_USER_ID = "maisaka_user"
def __init__(self) -> None:
self._chat_loop_service: Optional[MaisakaChatLoopService] = None
self._reply_generator = MaisakaReplyGenerator()
self._reader = InputReader()
self._chat_history: Optional[list[LLMContextMessage]] = None
self._knowledge_store = get_knowledge_store()
self._knowledge_learner = KnowledgeLearner("maisaka_cli")
self._knowledge_min_messages_for_extraction = 10
self._knowledge_min_extraction_interval = 30
self._last_knowledge_extraction_time = 0.0
knowledge_stats = self._knowledge_store.get_stats()
if knowledge_stats["total_items"] > 0:
console.print(f"[success]知识库中已有 {knowledge_stats['total_items']} 条数据[/success]")
else:
console.print("[muted]知识库已初始化,当前没有数据[/muted]")
self._chat_start_time: Optional[datetime] = None
self._last_user_input_time: Optional[datetime] = None
self._last_assistant_response_time: Optional[datetime] = None
self._user_input_times: list[datetime] = []
self._mcp_manager: Optional[MCPManager] = None
self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None
self._init_llm()
def _init_llm(self) -> None:
"""初始化 Maisaka 使用的聊天服务。"""
thinking_env = os.getenv("ENABLE_THINKING", "").strip().lower()
enable_thinking: Optional[bool] = True if thinking_env == "true" else False if thinking_env == "false" else None
_ = enable_thinking
self._chat_loop_service = MaisakaChatLoopService()
model_name = self._get_current_model_name()
console.print(f"[success]大模型服务已初始化[/success] [muted](模型: {model_name})[/muted]")
self._message_receiver = HeartFCMessageReceiver()
self._session: BotChatSession | None = None
@staticmethod
def _get_current_model_name() -> str:
@@ -92,354 +45,59 @@ class BufferCLI:
pass
return "未配置"
def _build_tool_context(self) -> ToolHandlerContext:
"""构建工具处理的共享上下文。"""
tool_context = ToolHandlerContext(
reader=self._reader,
user_input_times=self._user_input_times,
)
tool_context.last_user_input_time = self._last_user_input_time
return tool_context
def _show_banner(self) -> None:
"""渲染启动横幅。"""
banner = Text()
banner.append("MaiSaka", style="bold cyan")
banner.append(" v2.0\n", style="muted")
banner.append(f"模型: {self._get_current_model_name()}\n", style="muted")
banner.append("输入内容开始对话 | Ctrl+C 退出", style="muted")
console.print(Panel(banner, box=box.DOUBLE_EDGE, border_style="cyan", padding=(1, 2)))
console.print()
async def _start_chat(self, user_text: str) -> None:
"""追加用户输入并继续内部循环。"""
if self._chat_loop_service is None:
console.print("[warning]大模型服务尚未初始化,已跳过本次对话。[/warning]")
return
now = datetime.now()
self._last_user_input_time = now
self._user_input_times.append(now)
if self._chat_history is None:
self._chat_start_time = now
self._last_assistant_response_time = None
self._chat_history = self._chat_loop_service.build_chat_context(user_text)
self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)])
else:
self._chat_history.append(
self._build_cli_context_message(
user_text=user_text,
timestamp=now,
source_kind="user",
)
)
self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)])
await self._run_llm_loop(self._chat_history)
@staticmethod
def _build_cli_context_message(
def _build_cli_session_message(
*,
user_text: str,
timestamp: datetime,
source_kind: str = "user",
speaker_name: Optional[str] = None,
) -> SessionBackedMessage:
"""为 CLI 构造新的上下文消息。"""
resolved_speaker_name = speaker_name or global_config.maisaka.user_name.strip() or "用户"
visible_text = format_speaker_content(
resolved_speaker_name,
user_text,
timestamp,
)
planner_prefix = (
f"[时间]{timestamp.strftime('%H:%M:%S')}\n"
f"[用户]{resolved_speaker_name}\n"
"[用户群昵称]\n"
"[msg_id]\n"
"[发言内容]"
)
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
return SessionBackedMessage(
raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]),
visible_text=visible_text,
) -> SessionMessage:
"""构造一条供 heartflow 复用的 CLI 用户消息。"""
message = SessionMessage(
message_id=f"maisaka_cli_{int(timestamp.timestamp() * 1000)}",
timestamp=timestamp,
source_kind=source_kind,
platform=BufferCLI._CLI_PLATFORM,
)
@staticmethod
def _build_cli_session_message(user_text: str, timestamp: datetime) -> SessionMessage:
"""为 CLI 的知识学习构造兼容 SessionMessage。"""
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import MessageSequence
message = SessionMessage(message_id=f"maisaka_cli_{int(timestamp.timestamp() * 1000)}", timestamp=timestamp, platform="maisaka")
user_name = global_config.maisaka.user_name.strip() or "用户"
message.message_info = MessageInfo(
user_info=UserInfo(
user_id="maisaka_user",
user_nickname=global_config.maisaka.user_name.strip() or "用户",
user_id=BufferCLI._CLI_USER_ID,
user_nickname=user_name,
user_cardname=None,
),
group_info=None,
additional_config={},
)
message.session_id = "maisaka_cli"
message.raw_message = MessageSequence([])
visible_text = format_speaker_content(
global_config.maisaka.user_name.strip() or "用户",
user_text,
timestamp,
)
message.raw_message.text(visible_text)
message.processed_plain_text = visible_text
message.display_message = visible_text
message.raw_message = MessageSequence([TextComponent(text=user_text)])
message.processed_plain_text = user_text
message.display_message = user_text
message.initialized = True
return message
def _trigger_knowledge_learning(self, messages: list[SessionMessage]) -> None:
""" CLI 会话中按批次触发 knowledge 学习"""
if not global_config.maisaka.enable_knowledge_module:
return
self._knowledge_learner.add_messages(messages)
elapsed = time.monotonic() - self._last_knowledge_extraction_time
if elapsed < self._knowledge_min_extraction_interval:
return
cache_size = self._knowledge_learner.get_cache_size()
if cache_size < self._knowledge_min_messages_for_extraction:
return
self._last_knowledge_extraction_time = time.monotonic()
asyncio.create_task(self._run_knowledge_learning())
async def _run_knowledge_learning(self) -> None:
"""后台执行 knowledge 学习,避免阻塞主对话。"""
try:
added_count = await self._knowledge_learner.learn()
if added_count > 0 and global_config.maisaka.show_thinking:
console.print(f"[muted]知识学习已完成,新增 {added_count} 条数据。[/muted]")
except Exception as exc:
console.print(f"[warning]知识学习失败:{exc}[/warning]")
async def _run_llm_loop(self, chat_history: list[LLMContextMessage]) -> None:
"""
Main inner loop for the Maisaka planner.
Each round may produce internal thoughts and optionally call tools:
- reply(msg_id): generate a visible reply for the current round
- no_reply(): skip visible output and continue the loop
- wait(seconds): wait for new user input
- stop(): stop the current inner loop and return to idle
"""
if self._chat_loop_service is None:
return
consecutive_errors = 0
last_had_tool_calls = True
while True:
if last_had_tool_calls:
tasks = []
status_text_parts = []
if global_config.maisaka.enable_knowledge_module:
tasks.append(("knowledge", retrieve_relevant_knowledge(self._chat_loop_service, chat_history)))
status_text_parts.append("知识库")
with console.status(
f"[info]{' + '.join(status_text_parts)} 分析中...[/info]",
spinner="dots",
):
results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True)
knowledge_analysis = ""
if global_config.maisaka.enable_knowledge_module:
knowledge_result = results[0] if results else None
if isinstance(knowledge_result, Exception):
console.print(f"[warning]知识分析失败:{knowledge_result}[/warning]")
elif isinstance(knowledge_result, str) and knowledge_result.strip():
knowledge_analysis = knowledge_result
if global_config.maisaka.show_thinking:
console.print(
Panel(
Markdown(knowledge_analysis),
title="知识",
border_style="bright_magenta",
padding=(0, 1),
style="dim",
)
)
if chat_history and isinstance(chat_history[-1], AssistantMessage) and chat_history[-1].source == "perception":
chat_history.pop()
perception_parts = []
if knowledge_analysis:
perception_parts.append(f"知识库\n{knowledge_analysis}")
if perception_parts:
chat_history.append(
AssistantMessage(
content="\n\n".join(perception_parts),
timestamp=datetime.now(),
source_kind="perception",
)
)
elif global_config.maisaka.show_thinking:
console.print("[muted]上一轮没有使用工具,本轮跳过模块分析。[/muted]")
with console.status("[info]正在思考...[/info]", spinner="dots"):
try:
response = await self._chat_loop_service.chat_loop_step(chat_history)
consecutive_errors = 0
except Exception as exc:
consecutive_errors += 1
console.print(f"[error]大模型调用失败:{exc}[/error]")
if consecutive_errors >= 3:
console.print("[error]连续失败次数过多,结束对话。[/error]\n")
break
continue
chat_history.append(response.raw_message)
self._last_assistant_response_time = datetime.now()
if global_config.maisaka.show_thinking and response.content:
console.print(
Panel(
Markdown(response.content),
title="思考",
border_style="dim",
padding=(1, 2),
style="dim",
)
)
if response.content and not response.tool_calls:
last_had_tool_calls = False
continue
if not response.tool_calls:
last_had_tool_calls = False
continue
should_stop = False
tool_context = self._build_tool_context()
for tool_call in response.tool_calls:
if tool_call.func_name == "stop":
await handle_stop(tool_call, chat_history)
should_stop = True
elif tool_call.func_name == "reply":
reply = await self._generate_visible_reply(chat_history, response.content or "")
chat_history.append(
ToolResultMessage(
content="已生成并记录可见回复。",
timestamp=datetime.now(),
tool_call_id=tool_call.call_id,
tool_name=tool_call.func_name,
)
)
chat_history.append(
self._build_cli_context_message(
user_text=reply,
timestamp=datetime.now(),
source_kind="guided_reply",
speaker_name=global_config.bot.nickname.strip() or "MaiSaka",
)
)
elif tool_call.func_name == "no_reply":
if global_config.maisaka.show_thinking:
console.print("[muted]本轮未发送可见回复。[/muted]")
chat_history.append(
ToolResultMessage(
content="本轮未发送可见回复。",
timestamp=datetime.now(),
tool_call_id=tool_call.call_id,
tool_name=tool_call.func_name,
)
)
elif tool_call.func_name == "wait":
tool_result = await handle_wait(tool_call, chat_history, tool_context)
if tool_context.last_user_input_time != self._last_user_input_time:
self._last_user_input_time = tool_context.last_user_input_time
if tool_result.startswith("[[QUIT]]"):
should_stop = True
elif self._mcp_manager and self._mcp_manager.is_mcp_tool(tool_call.func_name):
await handle_mcp_tool(tool_call, chat_history, self._mcp_manager)
else:
await handle_unknown_tool(tool_call, chat_history)
if should_stop:
console.print("[muted]对话已暂停,等待新的输入...[/muted]\n")
break
last_had_tool_calls = True
async def _init_mcp(self) -> None:
"""初始化 MCP 服务并注册暴露的工具。"""
self._mcp_host_bridge = MCPHostLLMBridge(
sampling_task_name=global_config.mcp.client.sampling.task_name,
async def _dispatch_input(self, user_text: str) -> None:
""" CLI 输入转发到 heartflow 路径"""
message = self._build_cli_session_message(
user_text=user_text,
timestamp=datetime.now(),
)
self._mcp_manager = await MCPManager.from_app_config(
global_config.mcp,
host_callbacks=self._mcp_host_bridge.build_callbacks(),
chat_manager.register_message(message)
self._session = await chat_manager.get_or_create_session(
platform=self._CLI_PLATFORM,
user_id=self._CLI_USER_ID,
)
if self._mcp_manager and self._chat_loop_service:
mcp_tools = self._mcp_manager.get_openai_tools()
if mcp_tools:
self._chat_loop_service.set_extra_tools(mcp_tools)
summary = self._mcp_manager.get_feature_summary()
console.print(
Panel(
f"已加载 {len(mcp_tools)} 个 MCP 工具。\n{summary}",
title="MCP 能力",
border_style="green",
padding=(0, 1),
)
)
async def _generate_visible_reply(self, chat_history: list[LLMContextMessage], latest_thought: str) -> str:
"""根据最新思考生成并输出可见回复。"""
if not latest_thought:
return ""
with console.status("[info]正在生成可见回复...[/info]", spinner="dots"):
success, result = await self._reply_generator.generate_reply_with_context(
reply_reason=latest_thought,
chat_history=chat_history,
)
if success and result.text_fragments:
reply = result.text_fragments[0]
else:
reply = "..."
console.print(
Panel(
Markdown(reply),
title="MaiSaka",
border_style="magenta",
padding=(1, 2),
)
)
return reply
await self._message_receiver.process_message(message)
async def run(self) -> None:
"""主交互循环。"""
if global_config.mcp.enable:
await self._init_mcp()
else:
console.print("[muted]MCP 已禁用mcp.enable=false[/muted]")
self._reader.start(asyncio.get_event_loop())
self._show_banner()
@@ -447,17 +105,17 @@ class BufferCLI:
while True:
console.print("[bold cyan]> [/bold cyan]", end="")
raw_input = await self._reader.get_line()
if raw_input is None:
console.print("\n[muted]再见[/muted]")
console.print("\n[muted]再见[/muted]")
break
raw_input = raw_input.strip()
if not raw_input:
user_text = raw_input.strip()
if not user_text:
continue
await self._start_chat(raw_input)
await self._dispatch_input(user_text)
finally:
if self._mcp_manager:
await self._mcp_manager.close()
self._mcp_host_bridge = None
if self._session is not None:
runtime = heartflow_manager.heartflow_chat_list.pop(self._session.session_id, None)
if runtime is not None:
await runtime.stop()

View File

@@ -0,0 +1,27 @@
"""Maisaka CLI 展示适配。"""
from rich.markdown import Markdown
from rich.panel import Panel
from src.common.logger import get_logger
from src.config.config import global_config
from .console import console
CLI_PLATFORM_NAME = "maisaka_cli"
logger = get_logger("maisaka_cli_sender")
def render_cli_message(content: str, *, title: str = "") -> None:
"""将 CLI 私聊实例的消息展示到终端。"""
preview_text = content.strip() or "..."
console.print(
Panel(
Markdown(preview_text),
title=title or global_config.bot.nickname.strip() or "MaiSaka",
border_style="magenta",
padding=(1, 2),
)
)
logger.info(f"[CLI] 已将消息输出到终端: content={preview_text!r}")

View File

@@ -21,7 +21,6 @@ from .official_configs import (
DatabaseConfig,
DebugConfig,
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
KeywordReactionConfig,
LPMMKnowledgeConfig,
@@ -56,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
MMC_VERSION: str = "1.0.0"
CONFIG_VERSION: str = "8.2.0"
CONFIG_VERSION: str = "8.3.0"
MODEL_CONFIG_VERSION: str = "1.13.1"
logger = get_logger("config")
@@ -113,13 +112,10 @@ class Config(ConfigBase):
debug: DebugConfig = Field(default_factory=DebugConfig)
"""调试配置类"""
experimental: ExperimentalConfig = Field(default_factory=ExperimentalConfig)
"""实验性功能配置类"""
maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig)
"""maim_message配置类"""
lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig)
lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig, repr=False)
"""LPMM知识库配置类"""
webui: WebUIConfig = Field(default_factory=WebUIConfig)

View File

@@ -30,7 +30,14 @@ def recursive_parse_item_to_table(
if value is None:
continue
if isinstance(value, ConfigBase):
config_table.add(config_item_name, recursive_parse_item_to_table(value, override_repr=override_repr))
config_table.add(
config_item_name,
recursive_parse_item_to_table(
value,
is_inline_table=is_inline_table,
override_repr=override_repr,
),
)
else:
config_table.add(
config_item_name, convert_field(config_item_name, config_item_info, value, override_repr=override_repr)

View File

@@ -268,11 +268,23 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
migrated_any = True
reasons.append("expression.manual_reflect_operator_id")
chat = _as_dict(data.get("chat"))
if chat is None:
chat = {}
data["chat"] = chat
mem = _as_dict(data.get("memory"))
if mem is not None:
if _migrate_target_item_list(mem, "global_memory_blacklist"):
migrated_any = True
reasons.append("memory.global_memory_blacklist")
for removed_key in (
"agent_timeout_seconds",
"global_memory",
"global_memory_blacklist",
"max_agent_iterations",
):
if removed_key in mem:
mem.pop(removed_key, None)
migrated_any = True
reasons.append(f"memory.{removed_key}_removed")
exp = _as_dict(data.get("experimental"))
if exp is not None:
@@ -280,7 +292,16 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
migrated_any = True
reasons.append("experimental.chat_prompts")
chat = _as_dict(data.get("chat"))
for key in ("private_plan_style", "group_chat_prompt", "private_chat_prompts", "chat_prompts"):
if key in exp and key not in chat:
chat[key] = exp[key]
migrated_any = True
reasons.append(f"experimental.{key}_moved_to_chat")
data.pop("experimental", None)
migrated_any = True
reasons.append("experimental_removed")
if chat is not None and "think_mode" in chat:
chat.pop("think_mode", None)
migrated_any = True

View File

@@ -244,15 +244,45 @@ class ChatConfig(ConfigBase):
},
)
"""每个聊天流最大保存的Plan/Reply日志数量超过此数量时会自动删除最老的日志"""
llm_quote: bool = Field(
default=False,
private_plan_style: str = Field(
default=(
"1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用\n"
"2.如果相同的内容已经被执行,请不要重复执行\n"
"3.某句话如果已经被回复过,不要重复回复"
),
json_schema_extra={
"x-widget": "switch",
"x-icon": "quote",
"x-widget": "textarea",
"x-icon": "user",
},
)
"""是否在 reply action 中启用 quote 参数,启用后 LLM 可以控制是否引用消息"""
"""_wrap_私聊说话规则行为风格"""
group_chat_prompt: str = Field(
default="不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。",
json_schema_extra={
"x-widget": "textarea",
"x-icon": "users",
},
)
"""_wrap_群聊通用注意事项"""
private_chat_prompts: str = Field(
default="",
json_schema_extra={
"x-widget": "textarea",
"x-icon": "user",
},
)
"""_wrap_私聊通用注意事项"""
chat_prompts: list["ExtraPromptItem"] = Field(
default_factory=lambda: [],
json_schema_extra={
"x-widget": "custom",
"x-icon": "list",
},
)
"""_wrap_为指定聊天添加额外的 prompt 配置列表"""
enable_talk_value_rules: bool = Field(
default=True,
@@ -410,7 +440,6 @@ class MemoryConfig(ConfigBase):
},
)
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
chat_history_topic_check_message_threshold: int = Field(
default=80,
ge=1,
@@ -462,10 +491,6 @@ class MemoryConfig(ConfigBase):
def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证配置值"""
if self.max_agent_iterations < 1:
raise ValueError(f"max_agent_iterations 必须至少为1当前值: {self.max_agent_iterations}")
if self.agent_timeout_seconds <= 0:
raise ValueError(f"agent_timeout_seconds 必须大于0当前值: {self.agent_timeout_seconds}")
if self.chat_history_topic_check_message_threshold < 1:
raise ValueError(
f"chat_history_topic_check_message_threshold 必须至少为1当前值: {self.chat_history_topic_check_message_threshold}"
@@ -1070,39 +1095,13 @@ class ExtraPromptItem(ConfigBase):
"""额外的prompt内容"""
def model_post_init(self, context: Optional[dict] = None) -> None:
if not self.platform and not self.item_id and not self.prompt:
return super().model_post_init(context)
if not self.platform or not self.item_id or not self.prompt:
raise ValueError("ExtraPromptItem 中 platform, id 和 prompt 不能为空")
return super().model_post_init(context)
class ExperimentalConfig(ConfigBase):
"""实验功能配置类"""
__ui_parent__ = "debug"
private_plan_style: str = Field(
default=(
"1.思考**所有**的可用的action中的**每个动作**是否符合当下条件,如果动作使用条件符合聊天内容就使用"
"2.如果相同的内容已经被执行,请不要重复执行"
"3.某句话如果已经被回复过,不要重复回复"
),
json_schema_extra={
"x-widget": "textarea",
"x-icon": "user",
},
)
"""_wrap_私聊说话规则行为风格实验性功能"""
chat_prompts: list[ExtraPromptItem] = Field(
default_factory=lambda: [],
json_schema_extra={
"x-widget": "custom",
"x-icon": "list",
},
)
"""_wrap_为指定聊天添加额外的prompt配置列表"""
class MaimMessageConfig(ConfigBase):
"""maim_message配置类"""
@@ -1473,7 +1472,6 @@ class MaiSakaConfig(ConfigBase):
__ui_label__ = "MaiSaka"
__ui_icon__ = "message-circle"
__ui_parent__ = "experimental"
enable_knowledge_module: bool = Field(
default=True,
@@ -1483,16 +1481,6 @@ class MaiSakaConfig(ConfigBase):
},
)
"""启用知识库模块"""
show_analyze_cognition_prompt: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "terminal",
},
)
"""是否在 CLI 中显示 analyze_cognition 的 Prompt"""
show_thinking: bool = Field(
default=True,
json_schema_extra={
@@ -1529,6 +1517,15 @@ class MaiSakaConfig(ConfigBase):
)
"""是否将新接收的用户发言合并为单个用户消息"""
replyer_generator_type: Literal["legacy", "multi"] = Field(
default="legacy",
json_schema_extra={
"x-widget": "select",
"x-icon": "git-branch",
},
)
"""Maisaka replyer 生成器类型legacy旧版单 prompt/ multi多消息版"""
max_internal_rounds: int = Field(
default=6,
ge=1,
@@ -1568,24 +1565,14 @@ class MaiSakaConfig(ConfigBase):
)
"""工具筛选阶段最多保留的非内置工具数量"""
terminal_image_preview: bool = Field(
default=False,
terminal_image_display_mode: Literal["legacy", "path_link"] = Field(
default="legacy",
json_schema_extra={
"x-widget": "switch",
"x-widget": "select",
"x-icon": "image",
},
)
"""是否渲染低分辨率终端预览图片"""
terminal_image_preview_width: int = Field(
default=24,
ge=8,
json_schema_extra={
"x-widget": "input",
"x-icon": "columns",
},
)
"""Maisaka终端图片预览的字符宽度"""
"""图片展示模式legacy仅显示元信息/ path_link可点击本地路径"""
class MCPAuthorizationConfig(ConfigBase):
@@ -1969,6 +1956,129 @@ class MCPConfig(ConfigBase):
return super().model_post_init(context)
class PluginRuntimeRenderConfig(ConfigBase):
"""插件运行时浏览器渲染配置。"""
enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "image",
},
)
"""是否启用插件运行时浏览器渲染能力"""
browser_ws_endpoint: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "link",
},
)
"""优先复用的现有 Chromium CDP 地址,可填写 ws/http 端点"""
executable_path: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "folder",
},
)
"""浏览器可执行文件路径,留空时自动探测本机 Chrome/Chromium"""
browser_install_root: str = Field(
default="data/playwright-browsers",
json_schema_extra={
"x-widget": "input",
"x-icon": "hard-drive",
},
)
"""Playwright 托管浏览器目录,自动下载 Chromium 时会复用该目录"""
headless: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "monitor",
},
)
"""是否以无头模式启动浏览器"""
launch_args: list[str] = Field(
default_factory=lambda: [
"--disable-gpu",
"--disable-dev-shm-usage",
"--disable-setuid-sandbox",
"--no-sandbox",
"--no-zygote",
],
json_schema_extra={
"x-widget": "custom",
"x-icon": "terminal",
},
)
"""浏览器启动参数列表"""
concurrency_limit: int = Field(
default=2,
ge=1,
json_schema_extra={
"x-widget": "number",
"x-icon": "layers",
},
)
"""同时允许进行的最大渲染任务数"""
startup_timeout_sec: float = Field(
default=20.0,
gt=0,
json_schema_extra={
"x-widget": "number",
"x-icon": "clock",
},
)
"""浏览器连接或启动超时时间(秒)"""
render_timeout_sec: float = Field(
default=15.0,
gt=0,
json_schema_extra={
"x-widget": "number",
"x-icon": "timer",
},
)
"""单次渲染默认超时时间(秒)"""
auto_download_chromium: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "download",
},
)
"""未检测到可用浏览器时,是否自动下载 Playwright Chromium"""
download_connection_timeout_sec: float = Field(
default=120.0,
gt=0,
json_schema_extra={
"x-widget": "number",
"x-icon": "cloud-lightning",
},
)
"""自动下载 Chromium 时的连接超时时间(秒)"""
restart_after_render_count: int = Field(
default=200,
ge=0,
json_schema_extra={
"x-widget": "number",
"x-icon": "refresh-cw",
},
)
"""累计渲染指定次数后自动重建本地浏览器0 表示关闭该策略"""
class PluginRuntimeConfig(ConfigBase):
"""插件运行时配置类"""
@@ -2031,3 +2141,6 @@ class PluginRuntimeConfig(ConfigBase):
自定义 IPC Socket 路径(仅 Linux/macOS 生效)
留空则自动生成临时路径
"""
render: PluginRuntimeRenderConfig = Field(default_factory=PluginRuntimeRenderConfig)
"""浏览器渲染能力配置"""

View File

@@ -1,22 +1,25 @@
from datetime import datetime
from sqlmodel import select
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
import asyncio
import difflib
import json
import re
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
from src.common.data_models.expression_data_model import MaiExpression
from sqlmodel import select
from src.chat.utils.utils import is_bot_self
from src.common.data_models.expression_data_model import MaiExpression
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import check_expression_suitability, parse_expression_response
@@ -34,8 +37,122 @@ summary_model = LLMServiceClient(task_name="utils", request_type="expression.sum
check_model = LLMServiceClient(task_name="utils", request_type="expression.check")
def register_expression_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表达方式系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="expression.select.before_select",
description="表达方式选择流程开始前触发,可改写会话上下文、选择参数或中止本次选择。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
},
required=["chat_id", "chat_info", "max_num", "think_level"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.select.after_selection",
description="表达方式选择完成后触发,可改写最终选中的表达方式列表与 ID。",
parameters_schema=build_object_schema(
{
"chat_id": {"type": "string", "description": "当前聊天流 ID。"},
"chat_info": {"type": "string", "description": "用于选择表达方式的聊天上下文。"},
"max_num": {"type": "integer", "description": "最大可选表达方式数量。"},
"target_message": {"type": "string", "description": "当前目标回复消息文本。"},
"reply_reason": {"type": "string", "description": "规划器给出的回复理由。"},
"think_level": {"type": "integer", "description": "表达方式选择思考级别。"},
"selected_expressions": {
"type": "array",
"items": {"type": "object"},
"description": "当前已选中的表达方式列表。",
},
"selected_expression_ids": {
"type": "array",
"items": {"type": "integer"},
"description": "当前已选中的表达方式 ID 列表。",
},
},
required=[
"chat_id",
"chat_info",
"max_num",
"think_level",
"selected_expressions",
"selected_expression_ids",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.after_extract",
description="表达方式学习解析出表达/黑话候选后触发,可改写候选集或直接终止本轮学习。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"message_count": {"type": "integer", "description": "本轮参与学习的消息数量。"},
"expressions": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的表达方式候选列表。",
},
"jargon_entries": {
"type": "array",
"items": {"type": "object"},
"description": "解析出的黑话候选列表。",
},
},
required=["session_id", "message_count", "expressions", "jargon_entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="expression.learn.before_upsert",
description="表达方式写入数据库前触发,可改写情景/风格文本或跳过本条写入。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"situation": {"type": "string", "description": "即将写入的情景文本。"},
"style": {"type": "string", "description": "即将写入的风格文本。"},
},
required=["session_id", "situation", "style"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class ExpressionLearner:
def __init__(self, session_id: str) -> None:
"""初始化表达方式学习器。
Args:
session_id: 当前会话 ID。
"""
self.session_id = session_id
# 学习锁,防止并发执行学习任务
@@ -44,6 +161,110 @@ class ExpressionLearner:
# 消息缓存
self._messages_cache: List["SessionMessage"] = []
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _serialize_expressions(expressions: List[Tuple[str, str, str]]) -> List[dict[str, str]]:
"""将表达方式候选序列化为 Hook 载荷。
Args:
expressions: 原始表达方式候选列表。
Returns:
List[dict[str, str]]: 序列化后的表达方式候选。
"""
return [
{
"situation": str(situation).strip(),
"style": str(style).strip(),
"source_id": str(source_id).strip(),
}
for situation, style, source_id in expressions
if str(situation).strip() and str(style).strip()
]
@staticmethod
def _deserialize_expressions(raw_expressions: Any) -> List[Tuple[str, str, str]]:
"""从 Hook 载荷恢复表达方式候选列表。
Args:
raw_expressions: Hook 返回的表达方式候选。
Returns:
List[Tuple[str, str, str]]: 恢复后的表达方式候选列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Tuple[str, str, str]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not situation or not style:
continue
normalized_expressions.append((situation, style, source_id))
return normalized_expressions
@staticmethod
def _serialize_jargon_entries(jargon_entries: List[Tuple[str, str]]) -> List[dict[str, str]]:
"""将黑话候选序列化为 Hook 载荷。
Args:
jargon_entries: 原始黑话候选列表。
Returns:
List[dict[str, str]]: 序列化后的黑话候选列表。
"""
return [
{
"content": str(content).strip(),
"source_id": str(source_id).strip(),
}
for content, source_id in jargon_entries
if str(content).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_jargon_entries: Any) -> List[Tuple[str, str]]:
"""从 Hook 载荷恢复黑话候选列表。
Args:
raw_jargon_entries: Hook 返回的黑话候选列表。
Returns:
List[Tuple[str, str]]: 恢复后的黑话候选列表。
"""
if not isinstance(raw_jargon_entries, list):
return []
normalized_entries: List[Tuple[str, str]] = []
for raw_entry in raw_jargon_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
source_id = str(raw_entry.get("source_id") or "").strip()
if not content:
continue
normalized_entries.append((content, source_id))
return normalized_entries
def add_messages(self, messages: List["SessionMessage"]) -> None:
"""添加消息到缓存"""
self._messages_cache.extend(messages)
@@ -52,8 +273,12 @@ class ExpressionLearner:
"""获取当前消息缓存的大小"""
return len(self._messages_cache)
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
"""学习主流程"""
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
"""执行表达方式学习主流程
Args:
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
"""
if not self._messages_cache:
logger.debug("没有消息可供学习,跳过学习过程")
return
@@ -109,6 +334,25 @@ class ExpressionLearner:
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
after_extract_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.after_extract",
session_id=self.session_id,
message_count=len(self._messages_cache),
expressions=self._serialize_expressions(expressions),
jargon_entries=self._serialize_jargon_entries(jargon_entries),
)
if after_extract_result.aborted:
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
return
after_extract_kwargs = after_extract_result.kwargs
raw_expressions = after_extract_kwargs.get("expressions")
if raw_expressions is not None:
expressions = self._deserialize_expressions(raw_expressions)
raw_jargon_entries = after_extract_kwargs.get("jargon_entries")
if raw_jargon_entries is not None:
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
# TODO: 检测是否开启了
if jargon_entries:
@@ -135,6 +379,22 @@ class ExpressionLearner:
# 存储到数据库 Expression 表
for situation, style in learnt_expressions:
before_upsert_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.before_upsert",
session_id=self.session_id,
situation=situation,
style=style,
)
if before_upsert_result.aborted:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
continue
upsert_kwargs = before_upsert_result.kwargs
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
style = str(upsert_kwargs.get("style", style) or "").strip()
if not situation or not style:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
continue
await self._upsert_expression_to_db(situation, style)
# ====== 黑话相关 ======

View File

@@ -1,27 +1,109 @@
from typing import Any, Dict, List, Optional, Tuple
import json
import time
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.services.llm_service import LLMServiceClient
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.utils.utils_session import SessionUtils
from src.prompt.prompt_manager import prompt_manager
from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.learners.learner_utils_old import weighted_sample
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("expression_selector")
class ExpressionSelector:
def __init__(self):
def __init__(self) -> None:
"""初始化表达方式选择器。"""
self.llm_model = LLMServiceClient(
task_name="utils", request_type="expression.selector"
)
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _normalize_selected_expressions(raw_expressions: Any) -> List[Dict[str, Any]]:
"""从 Hook 载荷恢复表达方式选择结果。
Args:
raw_expressions: Hook 返回的表达方式列表。
Returns:
List[Dict[str, Any]]: 恢复后的表达方式列表。
"""
if not isinstance(raw_expressions, list):
return []
normalized_expressions: List[Dict[str, Any]] = []
for raw_expression in raw_expressions:
if not isinstance(raw_expression, dict):
continue
expression_id = raw_expression.get("id")
situation = str(raw_expression.get("situation") or "").strip()
style = str(raw_expression.get("style") or "").strip()
source_id = str(raw_expression.get("source_id") or "").strip()
if not isinstance(expression_id, int) or not situation or not style or not source_id:
continue
normalized_expression = dict(raw_expression)
normalized_expression["id"] = expression_id
normalized_expression["situation"] = situation
normalized_expression["style"] = style
normalized_expression["source_id"] = source_id
normalized_expressions.append(normalized_expression)
return normalized_expressions
@staticmethod
def _normalize_selected_expression_ids(raw_ids: Any, expressions: List[Dict[str, Any]]) -> List[int]:
"""规范化最终选中的表达方式 ID 列表。
Args:
raw_ids: Hook 返回的 ID 列表。
expressions: 当前最终表达方式列表。
Returns:
List[int]: 规范化后的 ID 列表。
"""
if isinstance(raw_ids, list):
normalized_ids = [item for item in raw_ids if isinstance(item, int)]
if normalized_ids:
return normalized_ids
return [expression["id"] for expression in expressions if isinstance(expression.get("id"), int)]
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
@@ -214,8 +296,7 @@ class ExpressionSelector:
reply_reason: Optional[str] = None,
think_level: int = 1,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
选择适合的表达方式使用classic模式随机选择+LLM选择
"""选择适合的表达方式。
Args:
chat_id: 聊天流ID
@@ -233,11 +314,60 @@ class ExpressionSelector:
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
before_select_result = await self._get_runtime_manager().invoke_hook(
"expression.select.before_select",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
)
if before_select_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择被 Hook 中止")
return [], []
before_select_kwargs = before_select_result.kwargs
chat_id = str(before_select_kwargs.get("chat_id", chat_id) or "").strip() or chat_id
chat_info = str(before_select_kwargs.get("chat_info", chat_info) or "")
max_num = max(self._coerce_int(before_select_kwargs.get("max_num"), max_num), 1)
raw_target_message = before_select_kwargs.get("target_message", target_message or "")
target_message = str(raw_target_message or "").strip() or None
raw_reply_reason = before_select_kwargs.get("reply_reason", reply_reason or "")
reply_reason = str(raw_reply_reason or "").strip() or None
think_level = self._coerce_int(before_select_kwargs.get("think_level"), think_level)
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(
selected_expressions, selected_ids = await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
after_selection_result = await self._get_runtime_manager().invoke_hook(
"expression.select.after_selection",
chat_id=chat_id,
chat_info=chat_info,
max_num=max_num,
target_message=target_message or "",
reply_reason=reply_reason or "",
think_level=think_level,
selected_expressions=[dict(item) for item in selected_expressions],
selected_expression_ids=list(selected_ids),
)
if after_selection_result.aborted:
logger.info(f"聊天流 {chat_id} 的表达方式选择结果被 Hook 中止")
return [], []
after_selection_kwargs = after_selection_result.kwargs
raw_selected_expressions = after_selection_kwargs.get("selected_expressions")
if raw_selected_expressions is not None:
selected_expressions = self._normalize_selected_expressions(raw_selected_expressions)
selected_ids = self._normalize_selected_expression_ids(
after_selection_kwargs.get("selected_expression_ids"),
selected_expressions,
)
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
return selected_expressions, selected_ids
async def _select_expressions_classic(
self,

View File

@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Set, TypedDict
from typing import Any, Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
@@ -9,13 +9,15 @@ from json_repair import repair_json
from sqlmodel import select
from src.common.data_models.jargon_data_model import MaiJargon
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
from .expression_utils import is_single_char_jargon
@@ -35,8 +37,140 @@ class JargonMeaningEntry(TypedDict):
meaning: str
def register_jargon_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册 jargon 系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="jargon.query.before_search",
description="Maisaka 黑话查询工具执行检索前触发,可改写词条列表、检索参数或直接中止。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "准备查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否允许精确命中失败后回退模糊检索。"},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.query.after_search",
description="Maisaka 黑话查询工具完成检索后触发,可改写结果列表或中止返回。",
parameters_schema=build_object_schema(
{
"words": {
"type": "array",
"items": {"type": "string"},
"description": "实际查询的黑话词条列表。",
},
"session_id": {"type": "string", "description": "当前会话 ID。"},
"limit": {"type": "integer", "description": "单个词条的最大返回条数。"},
"case_sensitive": {"type": "boolean", "description": "是否大小写敏感。"},
"enable_fuzzy_fallback": {"type": "boolean", "description": "是否启用了模糊检索回退。"},
"results": {
"type": "array",
"items": {"type": "object"},
"description": "查询结果列表。",
},
"abort_message": {"type": "string", "description": "Hook 主动中止时的失败提示。"},
},
required=["words", "session_id", "limit", "case_sensitive", "enable_fuzzy_fallback", "results"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.extract.before_persist",
description="黑话条目准备写入数据库前触发,可改写去重后的条目列表或跳过本次持久化。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"entries": {
"type": "array",
"items": {"type": "object"},
"description": "即将持久化的黑话条目列表。",
},
},
required=["session_id", "session_name", "entries"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="jargon.inference.before_finalize",
description="黑话含义推断完成、写回数据库前触发,可改写最终判定与含义结果。",
parameters_schema=build_object_schema(
{
"session_id": {"type": "string", "description": "当前会话 ID。"},
"session_name": {"type": "string", "description": "当前会话展示名称。"},
"content": {"type": "string", "description": "当前黑话词条。"},
"count": {"type": "integer", "description": "当前词条累计命中次数。"},
"raw_content_list": {
"type": "array",
"items": {"type": "string"},
"description": "用于推断的原始上下文片段列表。",
},
"inference_with_context": {"type": "object", "description": "基于上下文的推断结果。"},
"inference_with_content_only": {"type": "object", "description": "仅基于词条内容的推断结果。"},
"comparison_result": {"type": "object", "description": "比较阶段输出结果。"},
"is_jargon": {"type": "boolean", "description": "当前推断是否判定为黑话。"},
"meaning": {"type": "string", "description": "当前推断出的黑话含义。"},
"is_complete": {"type": "boolean", "description": "当前是否已完成全部推断流程。"},
"last_inference_count": {"type": "integer", "description": "本次推断完成后应写回的 last_inference_count。"},
},
required=[
"session_id",
"session_name",
"content",
"count",
"raw_content_list",
"inference_with_context",
"inference_with_content_only",
"comparison_result",
"is_jargon",
"meaning",
"is_complete",
"last_inference_count",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
class JargonMiner:
def __init__(self, session_id: str, session_name: str) -> None:
"""初始化黑话学习器。
Args:
session_id: 当前会话 ID。
session_name: 当前会话展示名称。
"""
self.session_id = session_id
self.session_name = session_name
@@ -46,13 +180,92 @@ class JargonMiner:
# 黑话提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _serialize_jargon_entries(entries: List[JargonEntry]) -> List[Dict[str, object]]:
"""将黑话条目列表序列化为 Hook 可传输结构。
Args:
entries: 原始黑话条目列表。
Returns:
List[Dict[str, object]]: 序列化后的条目列表。
"""
return [
{
"content": str(entry["content"]).strip(),
"raw_content": sorted(str(item).strip() for item in entry["raw_content"] if str(item).strip()),
}
for entry in entries
if str(entry["content"]).strip()
]
@staticmethod
def _deserialize_jargon_entries(raw_entries: Any) -> List[JargonEntry]:
"""从 Hook 载荷恢复黑话条目列表。
Args:
raw_entries: Hook 返回的条目数据。
Returns:
List[JargonEntry]: 恢复后的黑话条目列表。
"""
if not isinstance(raw_entries, list):
return []
normalized_entries: List[JargonEntry] = []
for raw_entry in raw_entries:
if not isinstance(raw_entry, dict):
continue
content = str(raw_entry.get("content") or "").strip()
if not content:
continue
raw_content_values = raw_entry.get("raw_content")
raw_content: Set[str] = set()
if isinstance(raw_content_values, list):
raw_content = {str(item).strip() for item in raw_content_values if str(item).strip()}
normalized_entries.append({"content": content, "raw_content": raw_content})
return normalized_entries
def get_cached_jargons(self) -> List[str]:
"""获取缓存中的所有黑话列表"""
return list(self.cache.keys())
async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
"""
对jargon进行含义推断
"""对黑话条目执行含义推断。
Args:
jargon_obj: 待推断的黑话数据对象。
"""
content = jargon_obj.content
# 解析raw_content列表
@@ -175,15 +388,45 @@ class JargonMiner:
is_similar = comparison_result.get("is_similar", False)
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
finalized_meaning = inference1.get("meaning", "") if is_jargon else ""
is_complete = (jargon_obj.count or 0) >= 100
last_inference_count = jargon_obj.count or 0
finalize_result = await self._get_runtime_manager().invoke_hook(
"jargon.inference.before_finalize",
session_id=self.session_id,
session_name=self.session_name,
content=content,
count=current_count,
raw_content_list=list(raw_content_list),
inference_with_context=dict(inference1),
inference_with_content_only=dict(inference2),
comparison_result=dict(comparison_result),
is_jargon=is_jargon,
meaning=finalized_meaning,
is_complete=is_complete,
last_inference_count=last_inference_count,
)
if finalize_result.aborted:
logger.info(f"jargon {content} 的推断结果被 Hook 中止写回")
return
finalize_kwargs = finalize_result.kwargs
is_jargon = bool(finalize_kwargs.get("is_jargon", is_jargon))
finalized_meaning = str(finalize_kwargs.get("meaning", finalized_meaning) or "").strip() if is_jargon else ""
is_complete = bool(finalize_kwargs.get("is_complete", is_complete))
last_inference_count = self._coerce_int(
finalize_kwargs.get("last_inference_count"),
last_inference_count,
)
# 更新数据库记录
jargon_obj.is_jargon = is_jargon
jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
jargon_obj.meaning = finalized_meaning
# 更新最后一次判定的count值避免重启后重复判定
jargon_obj.last_inference_count = jargon_obj.count or 0
jargon_obj.last_inference_count = last_inference_count
# 如果count>=100标记为完成不再进行推断
if (jargon_obj.count or 0) >= 100:
jargon_obj.is_complete = True
jargon_obj.is_complete = is_complete
try:
self._modify_jargon_entry(jargon_obj)
@@ -232,6 +475,22 @@ class JargonMiner:
merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
uniq_entries: List[JargonEntry] = list(merged_entries.values())
before_persist_result = await self._get_runtime_manager().invoke_hook(
"jargon.extract.before_persist",
session_id=self.session_id,
session_name=self.session_name,
entries=self._serialize_jargon_entries(uniq_entries),
)
if before_persist_result.aborted:
logger.info(f"[{self.session_name}] 黑话提取结果被 Hook 中止,不写入数据库")
return
raw_hook_entries = before_persist_result.kwargs.get("entries")
if raw_hook_entries is not None:
uniq_entries = self._deserialize_jargon_entries(raw_hook_entries)
if not uniq_entries:
logger.info(f"[{self.session_name}] Hook 过滤后没有可写入的黑话条目")
return
saved = 0
updated = 0

View File

@@ -1,12 +1,12 @@
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
import asyncio
import base64
import binascii
import io
import json
import re
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
from json_repair import repair_json
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
@@ -27,6 +27,7 @@ from openai.types.chat import (
)
from openai.types.shared_params.function_definition import FunctionDefinition
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from PIL import Image as PILImage
from src.common.logger import get_logger
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
@@ -62,6 +63,9 @@ from .base_client import (
logger = get_logger("llm_models")
SUPPORTED_OPENAI_IMAGE_FORMATS = {"jpeg", "png", "webp"}
"""OpenAI 兼容图片输入稳定支持的格式集合。"""
THINK_CONTENT_PATTERN = re.compile(
r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
re.DOTALL,
@@ -149,14 +153,85 @@ def _build_image_content_part(part: ImageMessagePart) -> ChatCompletionContentPa
Returns:
ChatCompletionContentPartImageParam: OpenAI 兼容的图片片段。
"""
normalized_image = _normalize_image_part_for_openai(part)
if normalized_image is None:
raise ValueError("图片数据无效,无法构建图片消息片段")
image_format, image_base64 = normalized_image
return {
"type": "image_url",
"image_url": {
"url": f"data:image/{part.normalized_image_format};base64,{part.image_base64}",
"url": f"data:image/{image_format};base64,{image_base64}",
},
}
def _normalize_image_part_for_openai(part: ImageMessagePart) -> Tuple[str, str] | None:
"""将图片片段规范化为 OpenAI 兼容格式。
Args:
part: 内部图片片段。
Returns:
Tuple[str, str] | None: `(image_format, image_base64)`;无法解析时返回 `None`。
"""
try:
image_bytes = base64.b64decode(part.image_base64, validate=True)
except (binascii.Error, ValueError) as exc:
logger.warning(f"图片 Base64 解码失败,已跳过该图片片段: {exc}")
return None
try:
with PILImage.open(io.BytesIO(image_bytes)) as image:
image_format = (image.format or part.normalized_image_format).lower()
if image_format in {"jpg", "jpeg"}:
image_format = "jpeg"
if image_format in SUPPORTED_OPENAI_IMAGE_FORMATS:
return image_format, part.image_base64
if image_format == "gif":
frame_count = getattr(image, "n_frames", 1)
frames: List[PILImage.Image] = []
durations: List[int] = []
for frame_index in range(frame_count):
image.seek(frame_index)
frame = image.copy()
if frame.mode not in {"RGB", "RGBA"}:
frame = frame.convert("RGBA")
frames.append(frame)
durations.append(int(image.info.get("duration", 100) or 100))
output_buffer = io.BytesIO()
save_kwargs: Dict[str, Any] = {
"format": "WEBP",
"save_all": True,
"append_images": frames[1:],
"duration": durations,
"loop": int(image.info.get("loop", 0) or 0),
}
if frame_count > 1:
save_kwargs["lossless"] = True
frames[0].save(output_buffer, **save_kwargs)
converted_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8")
return "webp", converted_base64
image.seek(0)
normalized_image = image.copy()
if normalized_image.mode not in {"RGB", "RGBA"}:
normalized_image = normalized_image.convert("RGBA")
output_buffer = io.BytesIO()
normalized_image.save(output_buffer, format="PNG")
converted_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8")
return "png", converted_base64
except Exception as exc:
logger.warning(f"图片内容无法被识别为有效图片,已跳过该图片片段: {exc}")
return None
def _convert_response_format(response_format: RespFormat | None) -> Any:
"""将内部响应格式转换为 OpenAI 兼容结构。
@@ -222,7 +297,21 @@ def _convert_user_message_content(message: Message) -> str | List[ChatCompletion
if isinstance(part, TextMessagePart):
content.append(_build_text_content_part(part.text))
continue
content.append(_build_image_content_part(part))
normalized_image = _normalize_image_part_for_openai(part)
if normalized_image is None:
content.append(_build_text_content_part("[图片内容不可用]"))
continue
image_format, image_base64 = normalized_image
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/{image_format};base64,{image_base64}",
},
}
)
return content
@@ -314,13 +403,15 @@ def _convert_tool_options(tool_options: List[ToolOption]) -> List[ChatCompletion
"""
converted_tools: List[ChatCompletionToolParam] = []
for tool_option in tool_options:
parameters_schema = cast(
Dict[str, object],
tool_option.parameters_schema or {"type": "object", "properties": {}},
)
function_schema: FunctionDefinition = {
"name": tool_option.name,
"description": tool_option.description,
"parameters": parameters_schema,
}
parameters_schema = tool_option.parameters_schema
if parameters_schema is not None:
function_schema["parameters"] = cast(Dict[str, object], parameters_schema)
converted_tools.append(
{
"type": "function",

View File

@@ -88,6 +88,15 @@ def _build_parameters_schema_from_property_map(property_map: Dict[str, Any]) ->
return parameters_schema
def _build_empty_object_schema() -> Dict[str, Any]:
"""构建无参工具使用的空对象 Schema。"""
return {
"type": "object",
"properties": {},
}
@dataclass(slots=True)
class ToolParam:
"""工具参数定义。"""
@@ -333,9 +342,8 @@ class ToolOption:
function_schema: Dict[str, Any] = {
"name": self.name,
"description": self.description,
"parameters": self.parameters_schema or _build_empty_object_schema(),
}
if self.parameters_schema is not None:
function_schema["parameters"] = self.parameters_schema
return {
"type": "function",
"function": function_schema,

View File

@@ -843,12 +843,6 @@ class LLMOrchestrator:
for _ in range(max_attempts):
model_info, api_provider, client = self._select_model(exclude_models=failed_models_this_request)
if self.request_type.startswith("maisaka_"):
logger.info(
f"LLMOrchestrator[{self.request_type}] 已选择模型 model={model_info.name} "
f"provider={api_provider.name} request_type={request_type.value}"
)
message_list = []
if message_factory:
message_list = message_factory(client)

View File

@@ -0,0 +1,71 @@
"""Maisaka 内置工具聚合入口。"""
from collections.abc import Awaitable, Callable
from typing import Dict, List, Optional
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
from .context import BuiltinToolRuntimeContext
from .no_reply import get_tool_spec as get_no_reply_tool_spec
from .no_reply import handle_tool as handle_no_reply_tool
from .query_jargon import get_tool_spec as get_query_jargon_tool_spec
from .query_jargon import handle_tool as handle_query_jargon_tool
from .query_person_info import get_tool_spec as get_query_person_info_tool_spec
from .query_person_info import handle_tool as handle_query_person_info_tool
from .reply import get_tool_spec as get_reply_tool_spec
from .reply import handle_tool as handle_reply_tool
from .send_emoji import get_tool_spec as get_send_emoji_tool_spec
from .send_emoji import handle_tool as handle_send_emoji_tool
from .wait import get_tool_spec as get_wait_tool_spec
from .wait import handle_tool as handle_wait_tool
BuiltinToolHandler = Callable[[ToolInvocation, Optional[ToolExecutionContext]], Awaitable[ToolExecutionResult]]
def get_builtin_tool_specs() -> List[ToolSpec]:
"""获取默认启用的内置工具声明列表。"""
return [
get_wait_tool_spec(),
get_reply_tool_spec(),
get_query_jargon_tool_spec(),
get_no_reply_tool_spec(),
get_send_emoji_tool_spec(),
]
def get_all_builtin_tool_specs() -> List[ToolSpec]:
"""获取全部内置工具声明列表。"""
return [
get_wait_tool_spec(),
get_reply_tool_spec(),
get_query_jargon_tool_spec(),
get_query_person_info_tool_spec(),
get_no_reply_tool_spec(),
get_send_emoji_tool_spec(),
]
def get_builtin_tools() -> List[ToolDefinitionInput]:
"""获取兼容旧模型层的内置工具定义。"""
return [tool_spec.to_llm_definition() for tool_spec in get_builtin_tool_specs()]
def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str, BuiltinToolHandler]:
"""构建内置工具处理器映射。"""
return {
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
"query_person_info": lambda invocation, context=None: handle_query_person_info_tool(
tool_ctx,
invocation,
context,
),
"wait": lambda invocation, context=None: handle_wait_tool(tool_ctx, invocation, context),
"send_emoji": lambda invocation, context=None: handle_send_emoji_tool(tool_ctx, invocation, context),
}

View File

@@ -0,0 +1,185 @@
"""Maisaka 内置工具执行上下文。"""
from __future__ import annotations
from base64 import b64decode
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from src.chat.utils.utils import process_llm_response
from src.common.data_models.message_component_data_model import EmojiComponent, MessageSequence, TextComponent
from src.config.config import global_config
from src.core.tooling import ToolExecutionResult
from ..context_messages import SessionBackedMessage
from ..message_adapter import format_speaker_content
from ..planner_message_utils import build_planner_prefix, build_session_backed_text_message
if TYPE_CHECKING:
from ..reasoning_engine import MaisakaReasoningEngine
from ..runtime import MaisakaHeartFlowChatting
class BuiltinToolRuntimeContext:
"""为拆分后的内置工具提供统一运行时能力。"""
def __init__(
self,
engine: "MaisakaReasoningEngine",
runtime: "MaisakaHeartFlowChatting",
) -> None:
self.engine = engine
self.runtime = runtime
@staticmethod
def build_success_result(
tool_name: str,
content: str = "",
structured_content: Any = None,
metadata: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
"""构造统一工具成功结果。"""
return ToolExecutionResult(
tool_name=tool_name,
success=True,
content=content,
structured_content=structured_content,
metadata=dict(metadata or {}),
)
@staticmethod
def build_failure_result(
tool_name: str,
error_message: str,
structured_content: Any = None,
metadata: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
"""构造统一工具失败结果。"""
return ToolExecutionResult(
tool_name=tool_name,
success=False,
error_message=error_message,
structured_content=structured_content,
metadata=dict(metadata or {}),
)
@staticmethod
def normalize_words(raw_words: Any) -> List[str]:
"""清洗黑话查询词条列表。"""
if not isinstance(raw_words, list):
return []
normalized_words: List[str] = []
seen_words: set[str] = set()
for item in raw_words:
if not isinstance(item, str):
continue
word = item.strip()
if not word or word in seen_words:
continue
seen_words.add(word)
normalized_words.append(word)
return normalized_words
@staticmethod
def normalize_jargon_query_results(raw_results: Any) -> List[Dict[str, object]]:
"""规范化黑话查询结果列表。"""
if not isinstance(raw_results, list):
return []
normalized_results: List[Dict[str, object]] = []
for raw_item in raw_results:
if not isinstance(raw_item, dict):
continue
word = str(raw_item.get("word") or "").strip()
matches = raw_item.get("matches")
normalized_matches: List[Dict[str, str]] = []
if isinstance(matches, list):
for match in matches:
if not isinstance(match, dict):
continue
content = str(match.get("content") or "").strip()
meaning = str(match.get("meaning") or "").strip()
if not content or not meaning:
continue
normalized_matches.append({"content": content, "meaning": meaning})
normalized_results.append(
{
"word": word,
"found": bool(raw_item.get("found", bool(normalized_matches))),
"matches": normalized_matches,
}
)
return normalized_results
@staticmethod
def post_process_reply_text(reply_text: str) -> List[str]:
"""沿用旧回复链的文本后处理,执行分段与错别字注入。"""
processed_segments: List[str] = []
for segment in process_llm_response(reply_text):
normalized_segment = segment.strip()
if normalized_segment:
processed_segments.append(normalized_segment)
if processed_segments:
return processed_segments
return [reply_text.strip()]
def get_runtime_manager(self) -> Any:
"""获取插件运行时管理器。"""
return self.engine._get_runtime_manager()
def append_guided_reply_to_chat_history(self, reply_text: str) -> None:
"""将引导回复写回 Maisaka 历史。"""
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
reply_timestamp = datetime.now()
history_message = build_session_backed_text_message(
speaker_name=bot_name,
text=reply_text,
timestamp=reply_timestamp,
source_kind="guided_reply",
)
self.runtime._chat_history.append(history_message)
def append_sent_emoji_to_chat_history(
self,
*,
emoji_base64: str,
success_message: str,
) -> None:
"""将 bot 主动发送的表情包同步到 Maisaka 历史。"""
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
reply_timestamp = datetime.now()
planner_prefix = build_planner_prefix(
timestamp=reply_timestamp,
user_name=bot_name,
)
history_message = SessionBackedMessage(
raw_message=MessageSequence(
[
TextComponent(planner_prefix),
EmojiComponent(
binary_hash="",
content=success_message,
binary_data=b64decode(emoji_base64),
),
]
),
visible_text=format_speaker_content(
bot_name,
"[表情包]",
reply_timestamp,
),
timestamp=reply_timestamp,
source_kind="guided_reply",
)
self.runtime._chat_history.append(history_message)

View File

@@ -0,0 +1,34 @@
"""no_reply 内置工具。"""
from typing import Optional
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from .context import BuiltinToolRuntimeContext
def get_tool_spec() -> ToolSpec:
"""获取 no_reply 工具声明。"""
return ToolSpec(
name="no_reply",
brief_description="本轮不进行回复,等待其他用户的新消息。",
provider_name="maisaka_builtin",
provider_type="builtin",
)
async def handle_tool(
tool_ctx: BuiltinToolRuntimeContext,
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 no_reply 内置工具。"""
del context
tool_ctx.runtime._enter_stop_state()
return tool_ctx.build_success_result(
invocation.tool_name,
"当前对话循环已暂停,等待新消息到来。",
metadata={"pause_execution": True},
)

View File

@@ -0,0 +1,143 @@
"""query_jargon 内置工具。"""
from typing import Any, Dict, List, Optional
import json
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.learners.jargon_explainer import search_jargon
from .context import BuiltinToolRuntimeContext
def get_tool_spec() -> ToolSpec:
"""获取 query_jargon 工具声明。"""
return ToolSpec(
name="query_jargon",
brief_description="查询当前聊天上下文中的黑话或词条含义。",
detailed_description="参数说明:\n- wordsarray必填。要查询的词条列表。",
parameters_schema={
"type": "object",
"properties": {
"words": {
"type": "array",
"description": "要查询的词条列表。",
"items": {"type": "string"},
},
},
"required": ["words"],
},
provider_name="maisaka_builtin",
provider_type="builtin",
)
async def handle_tool(
tool_ctx: BuiltinToolRuntimeContext,
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 query_jargon 内置工具。"""
del context
raw_words = invocation.arguments.get("words")
if not isinstance(raw_words, list):
return tool_ctx.build_failure_result(
invocation.tool_name,
"查询黑话工具需要提供 `words` 数组参数。",
)
words = tool_ctx.normalize_words(raw_words)
if not words:
return tool_ctx.build_failure_result(
invocation.tool_name,
"查询黑话工具至少需要一个非空词条。",
)
limit = 5
case_sensitive = False
enable_fuzzy_fallback = True
before_search_result = await tool_ctx.get_runtime_manager().invoke_hook(
"jargon.query.before_search",
words=list(words),
session_id=tool_ctx.runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
enable_fuzzy_fallback=enable_fuzzy_fallback,
abort_message="黑话查询已被 Hook 中止。",
)
if before_search_result.aborted:
abort_message = str(before_search_result.kwargs.get("abort_message") or "黑话查询已被 Hook 中止。").strip()
return tool_ctx.build_failure_result(invocation.tool_name, abort_message or "黑话查询已被 Hook 中止。")
before_search_kwargs = before_search_result.kwargs
if before_search_kwargs.get("words") is not None:
words = tool_ctx.normalize_words(before_search_kwargs.get("words"))
if not words:
return tool_ctx.build_failure_result(invocation.tool_name, "Hook 过滤后没有可查询的黑话词条。")
try:
limit = int(before_search_kwargs.get("limit", limit))
except (TypeError, ValueError):
limit = 5
limit = max(limit, 1)
case_sensitive = bool(before_search_kwargs.get("case_sensitive", case_sensitive))
enable_fuzzy_fallback = bool(before_search_kwargs.get("enable_fuzzy_fallback", enable_fuzzy_fallback))
results: List[Dict[str, object]] = []
for word in words:
exact_matches = search_jargon(
keyword=word,
chat_id=tool_ctx.runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
fuzzy=False,
)
matched_entries = exact_matches
if not matched_entries and enable_fuzzy_fallback:
matched_entries = search_jargon(
keyword=word,
chat_id=tool_ctx.runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
fuzzy=True,
)
results.append(
{
"word": word,
"found": bool(matched_entries),
"matches": matched_entries,
}
)
after_search_result = await tool_ctx.get_runtime_manager().invoke_hook(
"jargon.query.after_search",
words=list(words),
session_id=tool_ctx.runtime.session_id,
limit=limit,
case_sensitive=case_sensitive,
enable_fuzzy_fallback=enable_fuzzy_fallback,
results=list(results),
abort_message="黑话查询结果已被 Hook 中止。",
)
if after_search_result.aborted:
abort_message = str(after_search_result.kwargs.get("abort_message") or "黑话查询结果已被 Hook 中止。").strip()
return tool_ctx.build_failure_result(
invocation.tool_name,
abort_message or "黑话查询结果已被 Hook 中止。",
)
raw_results = after_search_result.kwargs.get("results")
if raw_results is not None:
results = tool_ctx.normalize_jargon_query_results(raw_results)
structured_content: Dict[str, Any] = {"results": results}
return tool_ctx.build_success_result(
invocation.tool_name,
json.dumps(structured_content, ensure_ascii=False),
structured_content=structured_content,
)

View File

@@ -0,0 +1,183 @@
"""query_person_info 内置工具。"""
from typing import Any, Dict, List, Optional
import json
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.know_u.knowledge_store import get_knowledge_store
from .context import BuiltinToolRuntimeContext
def get_tool_spec(*, enabled: bool = False) -> ToolSpec:
"""获取 query_person_info 工具声明。"""
return ToolSpec(
name="query_person_info",
brief_description="查询某个人的档案和相关记忆信息。",
detailed_description=(
"参数说明:\n"
"- person_namestring必填。人物名称、昵称或用户 ID。\n"
"- limitinteger可选。最多返回多少条匹配记录默认 3。"
),
parameters_schema={
"type": "object",
"properties": {
"person_name": {
"type": "string",
"description": "人物名称、昵称或用户 ID。",
},
"limit": {
"type": "integer",
"description": "最多返回多少条匹配记录。",
"default": 3,
},
},
"required": ["person_name"],
},
provider_name="maisaka_builtin",
provider_type="builtin",
enabled=enabled,
)
async def handle_tool(
tool_ctx: BuiltinToolRuntimeContext,
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 query_person_info 内置工具。"""
del context
raw_person_name = invocation.arguments.get("person_name")
raw_limit = invocation.arguments.get("limit", 3)
if not isinstance(raw_person_name, str):
return tool_ctx.build_failure_result(
invocation.tool_name,
"查询人物信息工具需要提供字符串类型的 `person_name` 参数。",
)
person_name = raw_person_name.strip()
if not person_name:
return tool_ctx.build_failure_result(
invocation.tool_name,
"查询人物信息工具需要提供非空的 `person_name` 参数。",
)
try:
limit = max(1, min(int(raw_limit), 10))
except (TypeError, ValueError):
limit = 3
persons = _query_person_records(person_name, limit)
result: Dict[str, Any] = {
"query": person_name,
"persons": persons,
"related_knowledge": _query_related_knowledge(person_name, persons, limit),
}
return tool_ctx.build_success_result(
invocation.tool_name,
json.dumps(result, ensure_ascii=False),
structured_content=result,
)
def _query_person_records(person_name: str, limit: int) -> List[Dict[str, Any]]:
"""按名称、昵称或用户 ID 查询人物档案。"""
with get_db_session() as session:
records = session.exec(
select(PersonInfo)
.where(
col(PersonInfo.person_name).contains(person_name)
| col(PersonInfo.user_nickname).contains(person_name)
| col(PersonInfo.user_id).contains(person_name)
)
.order_by(col(PersonInfo.last_known_time).desc(), col(PersonInfo.id).desc())
.limit(limit)
).all()
persons: List[Dict[str, Any]] = []
for record in records:
memory_points: List[str] = []
if record.memory_points:
try:
parsed_points = json.loads(record.memory_points)
if isinstance(parsed_points, list):
memory_points = [str(point).strip() for point in parsed_points if str(point).strip()]
except (json.JSONDecodeError, TypeError, ValueError):
memory_points = []
persons.append(
{
"person_id": record.person_id,
"person_name": record.person_name or "",
"user_nickname": record.user_nickname,
"user_id": record.user_id,
"platform": record.platform,
"name_reason": record.name_reason or "",
"is_known": record.is_known,
"know_counts": record.know_counts,
"memory_points": memory_points[:20],
"last_known_time": record.last_known_time.isoformat() if record.last_known_time is not None else None,
}
)
return persons
def _query_related_knowledge(
person_name: str,
persons: List[Dict[str, Any]],
limit: int,
) -> List[Dict[str, Any]]:
"""从 Maisaka knowledge 中补充检索与该人物相关的条目。"""
store = get_knowledge_store()
knowledge_items: List[Dict[str, Any]] = []
seen_ids: set[str] = set()
for person in persons:
matched_items = store.get_knowledge_by_user(
platform=str(person.get("platform", "")).strip(),
user_id=str(person.get("user_id", "")).strip(),
user_nickname=str(person.get("user_nickname", "")).strip(),
person_name=str(person.get("person_name", "")).strip(),
limit=max(limit, 5),
)
for item in matched_items:
item_id = str(item.get("id", "")).strip()
if item_id and item_id in seen_ids:
continue
if item_id:
seen_ids.add(item_id)
knowledge_items.append(item)
if not knowledge_items:
fallback_items = store.search_knowledge(person_name, limit=max(limit, 5))
for item in fallback_items:
item_id = str(item.get("id", "")).strip()
if item_id and item_id in seen_ids:
continue
if item_id:
seen_ids.add(item_id)
knowledge_items.append(item)
results: List[Dict[str, Any]] = []
for item in knowledge_items:
results.append(
{
"id": str(item.get("id", "")).strip(),
"category_id": str(item.get("category_id", "")).strip(),
"category_name": str(item.get("category_name", "")).strip(),
"content": str(item.get("content", "")).strip(),
"metadata": item.get("metadata", {}),
"created_at": item.get("created_at"),
}
)
return results

View File

@@ -0,0 +1,188 @@
"""reply 内置工具。"""
from typing import Optional
from src.chat.replyer.replyer_manager import replyer_manager
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
from src.common.logger import get_logger
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.services import send_service
from .context import BuiltinToolRuntimeContext
logger = get_logger("maisaka_builtin_reply")
def get_tool_spec() -> ToolSpec:
"""获取 reply 工具声明。"""
return ToolSpec(
name="reply",
brief_description="根据当前思考生成并发送一条可见回复。",
detailed_description=(
"参数说明:\n"
"- msg_idstring必填。要回复的目标用户消息编号。\n"
"- quoteboolean可选。当有非常明确的回复目标时以引用回复的方式发送默认 true。\n"
"- unknown_wordsarray可选。回复前可能需要查询的黑话或词条列表。"
),
parameters_schema={
"type": "object",
"properties": {
"msg_id": {
"type": "string",
"description": "要回复的目标用户消息编号。",
},
"quote": {
"type": "boolean",
"description": "当有非常明确的回复目标时,以引用回复的方式发送。",
"default": True,
},
"unknown_words": {
"type": "array",
"description": "回复前可能需要查询的黑话或词条列表。",
"items": {"type": "string"},
},
},
"required": ["msg_id"],
},
provider_name="maisaka_builtin",
provider_type="builtin",
)
async def handle_tool(
tool_ctx: BuiltinToolRuntimeContext,
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 reply 内置工具。"""
latest_thought = context.reasoning if context is not None else invocation.reasoning
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
quote_reply = bool(invocation.arguments.get("quote", True))
raw_unknown_words = invocation.arguments.get("unknown_words")
unknown_words = raw_unknown_words if isinstance(raw_unknown_words, list) else None
if not target_message_id:
return tool_ctx.build_failure_result(
invocation.tool_name,
"回复工具需要提供有效的 `msg_id` 参数。",
)
target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id)
if target_message is None:
return tool_ctx.build_failure_result(
invocation.tool_name,
f"未找到要回复的目标消息msg_id={target_message_id}",
)
logger.info(
f"{tool_ctx.runtime.log_prefix} 已触发回复工具 "
f"目标消息编号={target_message_id} 引用回复={quote_reply} 最新思考={latest_thought!r}"
)
try:
replyer = replyer_manager.get_replyer(
chat_stream=tool_ctx.runtime.chat_stream,
request_type="maisaka_replyer",
replyer_type="maisaka",
)
except Exception:
logger.exception(
f"{tool_ctx.runtime.log_prefix} 获取回复生成器时发生异常: 目标消息编号={target_message_id}"
)
return tool_ctx.build_failure_result(
invocation.tool_name,
"获取 Maisaka 回复生成器时发生异常。",
)
if replyer is None:
logger.error(f"{tool_ctx.runtime.log_prefix} 获取 Maisaka 回复生成器失败")
return tool_ctx.build_failure_result(
invocation.tool_name,
"Maisaka 回复生成器当前不可用。",
)
try:
success, reply_result = await replyer.generate_reply_with_context(
reply_reason=latest_thought,
stream_id=tool_ctx.runtime.session_id,
reply_message=target_message,
chat_history=tool_ctx.runtime._chat_history,
unknown_words=unknown_words,
log_reply=False,
)
except Exception as exc:
logger.exception(
f"{tool_ctx.runtime.log_prefix} 回复生成器执行异常: 目标消息编号={target_message_id} 异常={exc}"
)
return tool_ctx.build_failure_result(
invocation.tool_name,
"生成可见回复时发生异常。",
)
reply_text = reply_result.completion.response_text.strip() if success else ""
if not reply_text:
logger.warning(
f"{tool_ctx.runtime.log_prefix} 回复生成器返回空文本: "
f"目标消息编号={target_message_id} 错误信息={reply_result.error_message!r}"
)
return tool_ctx.build_failure_result(
invocation.tool_name,
"生成可见回复失败。",
)
reply_segments = tool_ctx.post_process_reply_text(reply_text)
combined_reply_text = "".join(reply_segments)
try:
sent = False
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
for segment in reply_segments:
render_cli_message(segment)
sent = True
else:
for index, segment in enumerate(reply_segments):
sent = await send_service.text_to_stream(
text=segment,
stream_id=tool_ctx.runtime.session_id,
set_reply=quote_reply if index == 0 else False,
reply_message=target_message if quote_reply and index == 0 else None,
selected_expressions=reply_result.selected_expression_ids or None,
typing=index > 0,
)
if not sent:
break
except Exception:
logger.exception(
f"{tool_ctx.runtime.log_prefix} 发送文字消息时发生异常,目标消息编号={target_message_id}"
)
return tool_ctx.build_failure_result(
invocation.tool_name,
"发送可见回复时发生异常。",
)
if not sent:
return tool_ctx.build_failure_result(
invocation.tool_name,
"可见回复生成成功,但发送失败。",
structured_content={
"msg_id": target_message_id,
"quote": quote_reply,
"reply_segments": reply_segments,
},
)
target_user_info = target_message.message_info.user_info
target_user_name = target_user_info.user_cardname or target_user_info.user_nickname or target_user_info.user_id
tool_ctx.append_guided_reply_to_chat_history(combined_reply_text)
return tool_ctx.build_success_result(
invocation.tool_name,
"回复已生成并发送。",
structured_content={
"msg_id": target_message_id,
"quote": quote_reply,
"reply_text": combined_reply_text,
"reply_segments": reply_segments,
"target_user_name": target_user_name,
},
)

View File

@@ -0,0 +1,106 @@
"""send_emoji 内置工具。"""
from typing import Any, Dict, Optional
from src.chat.emoji_system.maisaka_tool import send_emoji_for_maisaka
from src.common.logger import get_logger
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from .context import BuiltinToolRuntimeContext
logger = get_logger("maisaka_builtin_send_emoji")
def get_tool_spec() -> ToolSpec:
"""获取 send_emoji 工具声明。"""
return ToolSpec(
name="send_emoji",
brief_description="发送一个合适的表情包来辅助表达情绪。",
detailed_description="参数说明:\n- emotionstring可选。希望表达的情绪例如 happy、sad、angry 等。",
parameters_schema={
"type": "object",
"properties": {
"emotion": {
"type": "string",
"description": "希望表达的情绪,例如 happy、sad、angry 等。",
},
},
},
provider_name="maisaka_builtin",
provider_type="builtin",
)
async def handle_tool(
tool_ctx: BuiltinToolRuntimeContext,
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 send_emoji 内置工具。"""
del context
emotion = str(invocation.arguments.get("emotion") or "").strip()
context_texts = [
message.get_history_text()
for message in tool_ctx.runtime._chat_history[-5:]
if message.get_history_text().strip()
]
structured_result: Dict[str, Any] = {
"success": False,
"message": "",
"description": "",
"emotion": [],
"requested_emotion": emotion,
"matched_emotion": "",
}
logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具,请求情绪={emotion!r}")
try:
send_result = await send_emoji_for_maisaka(
stream_id=tool_ctx.runtime.session_id,
requested_emotion=emotion,
reasoning=tool_ctx.engine.last_reasoning_content,
context_texts=context_texts,
)
except Exception as exc:
logger.exception(f"{tool_ctx.runtime.log_prefix} 发送表情包时发生异常: {exc}")
structured_result["message"] = f"发送表情包时发生异常:{exc}"
return tool_ctx.build_failure_result(
invocation.tool_name,
structured_result["message"],
structured_content=structured_result,
)
structured_result["description"] = send_result.description
structured_result["emotion"] = list(send_result.emotions)
structured_result["matched_emotion"] = send_result.matched_emotion
structured_result["message"] = send_result.message
if send_result.success:
logger.info(
f"{tool_ctx.runtime.log_prefix} 表情包发送成功 "
f"描述={send_result.description!r} 情绪标签={send_result.emotions} "
f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}"
)
tool_ctx.append_sent_emoji_to_chat_history(
emoji_base64=send_result.emoji_base64,
success_message=send_result.message,
)
structured_result["success"] = True
return tool_ctx.build_success_result(
invocation.tool_name,
send_result.message,
structured_content=structured_result,
)
logger.warning(
f"{tool_ctx.runtime.log_prefix} 表情包发送失败 "
f"请求情绪={emotion!r} 错误信息={send_result.message}"
)
return tool_ctx.build_failure_result(
invocation.tool_name,
structured_result["message"],
structured_content=structured_result,
)

View File

@@ -0,0 +1,51 @@
"""wait 内置工具。"""
from typing import Optional
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from .context import BuiltinToolRuntimeContext
def get_tool_spec() -> ToolSpec:
"""获取 wait 工具声明。"""
return ToolSpec(
name="wait",
brief_description="暂停当前对话并等待用户新的输入。",
detailed_description="参数说明:\n- secondsinteger必填。等待的秒数。",
parameters_schema={
"type": "object",
"properties": {
"seconds": {
"type": "integer",
"description": "等待的秒数。",
},
},
"required": ["seconds"],
},
provider_name="maisaka_builtin",
provider_type="builtin",
)
async def handle_tool(
tool_ctx: BuiltinToolRuntimeContext,
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 wait 内置工具。"""
del context
seconds = invocation.arguments.get("seconds", 30)
try:
wait_seconds = int(seconds)
except (TypeError, ValueError):
wait_seconds = 30
wait_seconds = max(0, wait_seconds)
tool_ctx.runtime._enter_wait_state(seconds=wait_seconds, tool_call_id=invocation.call_id)
return tool_ctx.build_success_result(
invocation.tool_name,
f"当前对话循环进入等待状态,最长等待 {wait_seconds} 秒。",
metadata={"pause_execution": True},
)

View File

@@ -1,159 +0,0 @@
"""Maisaka 内置工具声明。"""
from copy import deepcopy
from typing import Any, Dict, List
from src.core.tooling import ToolSpec, build_tool_detailed_description
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
def _build_tool_spec(
name: str,
brief_description: str,
parameters_schema: Dict[str, Any] | None = None,
detailed_description: str = "",
) -> ToolSpec:
"""构建单个内置工具声明。
Args:
name: 工具名称。
brief_description: 简要描述。
parameters_schema: 参数 Schema。
detailed_description: 详细描述;为空时自动根据参数生成。
Returns:
ToolSpec: 构建完成的工具声明。
"""
normalized_schema = deepcopy(parameters_schema) if parameters_schema is not None else None
return ToolSpec(
name=name,
brief_description=brief_description,
detailed_description=(
detailed_description.strip()
or build_tool_detailed_description(normalized_schema)
),
parameters_schema=normalized_schema,
provider_name="maisaka_builtin",
provider_type="builtin",
)
def create_builtin_tool_specs() -> List[ToolSpec]:
"""创建 Maisaka 内置工具声明列表。
Returns:
List[ToolSpec]: 内置工具声明列表。
"""
return [
_build_tool_spec(
name="wait",
brief_description="暂停当前对话并等待用户新的输入。",
parameters_schema={
"type": "object",
"properties": {
"seconds": {
"type": "integer",
"description": "等待的秒数。",
},
},
"required": ["seconds"],
},
),
_build_tool_spec(
name="reply",
brief_description="根据当前思考生成并发送一条可见回复。",
parameters_schema={
"type": "object",
"properties": {
"msg_id": {
"type": "string",
"description": "要回复的目标用户消息编号。",
},
"quote": {
"type": "boolean",
"description": "是否以引用回复的方式发送。",
"default": True,
},
"unknown_words": {
"type": "array",
"description": "回复前可能需要查询的黑话或词条列表。",
"items": {"type": "string"},
},
},
"required": ["msg_id"],
},
),
_build_tool_spec(
name="query_jargon",
brief_description="查询当前聊天上下文中的黑话或词条含义。",
parameters_schema={
"type": "object",
"properties": {
"words": {
"type": "array",
"description": "要查询的词条列表。",
"items": {"type": "string"},
},
},
"required": ["words"],
},
),
_build_tool_spec(
name="query_person_info",
brief_description="查询某个人的档案和相关记忆信息。",
parameters_schema={
"type": "object",
"properties": {
"person_name": {
"type": "string",
"description": "人物名称、昵称或用户 ID。",
},
"limit": {
"type": "integer",
"description": "最多返回多少条匹配记录。",
"default": 3,
},
},
"required": ["person_name"],
},
),
_build_tool_spec(
name="no_reply",
brief_description="本轮不进行回复,等待其他用户的新消息。",
),
_build_tool_spec(
name="send_emoji",
brief_description="发送一个合适的表情包来辅助表达情绪。",
parameters_schema={
"type": "object",
"properties": {
"emotion": {
"type": "string",
"description": "希望表达的情绪,例如 happy、sad、angry 等。",
},
},
},
),
]
def get_builtin_tool_specs() -> List[ToolSpec]:
"""获取 Maisaka 内置工具声明。
Returns:
List[ToolSpec]: 内置工具声明列表。
"""
return create_builtin_tool_specs()
def get_builtin_tools() -> List[ToolDefinitionInput]:
"""获取兼容旧模型层的内置工具定义。
Returns:
List[ToolDefinitionInput]: 可直接传给模型层的工具定义。
"""
return [tool_spec.to_llm_definition() for tool_spec in create_builtin_tool_specs()]

View File

@@ -1,28 +1,23 @@
"""Maisaka 对话循环服务。"""
from base64 import b64decode
from dataclasses import dataclass
from datetime import datetime
from io import BytesIO
from time import perf_counter
from typing import Any, Dict, List, Optional, Sequence
from typing import Any, List, Optional, Sequence
import asyncio
import json
import random
from PIL import Image as PILImage
from pydantic import BaseModel, Field as PydanticField
from rich.console import Group, RenderableType
from rich.console import Group
from rich.panel import Panel
from rich.pretty import Pretty
from rich.text import Text
from src.cli.console import console
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.core.tooling import ToolRegistry, ToolSpec
from src.know_u.knowledge import extract_category_ids_from_result
@@ -30,11 +25,20 @@ from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
from src.plugin_runtime.hook_payloads import (
deserialize_prompt_messages,
deserialize_tool_calls,
serialize_prompt_messages,
serialize_tool_calls,
serialize_tool_definitions,
)
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.services.llm_service import LLMServiceClient
from .builtin_tools import get_builtin_tools
from .context_messages import AssistantMessage, LLMContextMessage, SessionBackedMessage
from .message_adapter import format_speaker_content
from .builtin_tool import get_builtin_tools
from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage
from .prompt_cli_renderer import PromptCLIVisualizer
@dataclass(slots=True)
@@ -44,6 +48,11 @@ class ChatResponse:
content: Optional[str]
tool_calls: List[ToolCall]
raw_message: AssistantMessage
selected_history_count: int
prompt_tokens: int
built_message_count: int
completion_tokens: int
total_tokens: int
class ToolFilterSelection(BaseModel):
@@ -56,12 +65,131 @@ class ToolFilterSelection(BaseModel):
logger = get_logger("maisaka_chat_loop")
def register_maisaka_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册 Maisaka 规划器内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="maisaka.planner.before_request",
description="在 Maisaka 向模型发起规划请求前触发,可改写消息窗口与工具定义。",
parameters_schema=build_object_schema(
{
"messages": {
"type": "array",
"description": "即将发给模型的 PromptMessage 列表。",
},
"tool_definitions": {
"type": "array",
"description": "当前候选工具定义列表。",
},
"selected_history_count": {
"type": "integer",
"description": "当前选中的上下文消息数量。",
},
"built_message_count": {
"type": "integer",
"description": "实际发送给模型的消息数量。",
},
"selection_reason": {
"type": "string",
"description": "上下文选择说明。",
},
"session_id": {
"type": "string",
"description": "当前会话 ID。",
},
},
required=[
"messages",
"tool_definitions",
"selected_history_count",
"built_message_count",
"selection_reason",
"session_id",
],
),
default_timeout_ms=6000,
allow_abort=False,
allow_kwargs_mutation=True,
),
HookSpec(
name="maisaka.planner.after_response",
description="在 Maisaka 收到模型响应后触发,可调整文本结果与工具调用列表。",
parameters_schema=build_object_schema(
{
"response": {
"type": "string",
"description": "模型返回的文本内容。",
},
"tool_calls": {
"type": "array",
"description": "模型返回的工具调用列表。",
},
"selected_history_count": {
"type": "integer",
"description": "当前选中的上下文消息数量。",
},
"built_message_count": {
"type": "integer",
"description": "实际发送给模型的消息数量。",
},
"selection_reason": {
"type": "string",
"description": "上下文选择说明。",
},
"session_id": {
"type": "string",
"description": "当前会话 ID。",
},
"prompt_tokens": {
"type": "integer",
"description": "输入 Token 数。",
},
"completion_tokens": {
"type": "integer",
"description": "输出 Token 数。",
},
"total_tokens": {
"type": "integer",
"description": "总 Token 数。",
},
},
required=[
"response",
"tool_calls",
"selected_history_count",
"built_message_count",
"selection_reason",
"session_id",
"prompt_tokens",
"completion_tokens",
"total_tokens",
],
),
default_timeout_ms=6000,
allow_abort=False,
allow_kwargs_mutation=True,
),
]
)
class MaisakaChatLoopService:
"""负责 Maisaka 主对话循环、系统提示词和终端渲染。"""
def __init__(
self,
chat_system_prompt: Optional[str] = None,
session_id: Optional[str] = None,
is_group_chat: Optional[bool] = None,
temperature: float = 0.5,
max_tokens: int = 2048,
) -> None:
@@ -69,12 +197,16 @@ class MaisakaChatLoopService:
Args:
chat_system_prompt: 可选的系统提示词。
session_id: 当前会话 ID用于匹配会话级额外提示。
is_group_chat: 当前会话是否为群聊。
temperature: 规划器温度参数。
max_tokens: 规划器最大输出长度。
"""
self._temperature = temperature
self._max_tokens = max_tokens
self._is_group_chat = is_group_chat
self._session_id = session_id or ""
self._extra_tools: List[ToolOption] = []
self._interrupt_flag: asyncio.Event | None = None
self._tool_registry: ToolRegistry | None = None
@@ -97,6 +229,35 @@ class MaisakaChatLoopService:
return self._personality_prompt
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的输入值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
def _build_personality_prompt(self) -> str:
"""构造人格提示词。"""
@@ -127,19 +288,13 @@ class MaisakaChatLoopService:
Args:
tools_section: 额外注入到提示词中的工具说明片段。
"""
if self._prompts_loaded:
return
async with self._prompt_load_lock:
if self._prompts_loaded:
return
try:
self._chat_system_prompt = load_prompt(
"maisaka_chat",
file_tools_section=tools_section,
bot_name=global_config.bot.nickname,
group_chat_attention_block=self._build_group_chat_attention_block(),
identity=self._personality_prompt,
)
except Exception:
@@ -147,6 +302,74 @@ class MaisakaChatLoopService:
self._prompts_loaded = True
def _build_group_chat_attention_block(self) -> str:
"""构建当前聊天场景下的额外注意事项块。"""
prompt_lines: List[str] = []
if self._is_group_chat is True:
if group_chat_prompt := str(global_config.chat.group_chat_prompt or "").strip():
prompt_lines.append(f"通用注意事项:\n{group_chat_prompt}")
elif self._is_group_chat is False:
if private_chat_prompt := str(global_config.chat.private_chat_prompts or "").strip():
prompt_lines.append(f"通用注意事项:\n{private_chat_prompt}")
if self._session_id:
if chat_prompt := self._get_chat_prompt_for_chat(self._session_id, self._is_group_chat).strip():
prompt_lines.append(f"当前聊天额外注意事项:\n{chat_prompt}")
if not prompt_lines:
return ""
return "在该聊天中的注意事项:\n" + "\n\n".join(prompt_lines) + "\n"
@staticmethod
def _get_chat_prompt_for_chat(chat_id: str, is_group_chat: Optional[bool]) -> str:
"""根据聊天流 ID 获取匹配的额外提示。"""
if not global_config.chat.chat_prompts:
return ""
for chat_prompt_item in global_config.chat.chat_prompts:
if hasattr(chat_prompt_item, "platform"):
platform = str(chat_prompt_item.platform or "").strip()
item_id = str(chat_prompt_item.item_id or "").strip()
rule_type = str(chat_prompt_item.rule_type or "").strip()
prompt_content = str(chat_prompt_item.prompt or "").strip()
elif isinstance(chat_prompt_item, str):
parts = chat_prompt_item.split(":", 3)
if len(parts) != 4:
continue
platform, item_id, rule_type, prompt_content = parts
platform = platform.strip()
item_id = item_id.strip()
rule_type = rule_type.strip()
prompt_content = prompt_content.strip()
else:
continue
if not platform or not item_id or not prompt_content:
continue
if rule_type == "group":
config_is_group = True
config_chat_id = SessionUtils.calculate_session_id(platform, group_id=item_id)
elif rule_type == "private":
config_is_group = False
config_chat_id = SessionUtils.calculate_session_id(platform, user_id=item_id)
else:
continue
if is_group_chat is not None and config_is_group != is_group_chat:
continue
if config_chat_id == chat_id:
logger.debug(f"匹配到 Maisaka 聊天额外提示chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
return prompt_content
return ""
def set_extra_tools(self, tools: Sequence[ToolDefinitionInput]) -> None:
"""设置额外工具定义。
@@ -468,259 +691,6 @@ class MaisakaChatLoopService:
return extract_category_ids_from_result(generation_result.response or "")
@staticmethod
def _get_role_badge_style(role: str) -> str:
"""返回终端中角色标签的样式。
Args:
role: 消息角色名称。
Returns:
str: Rich 可识别的样式字符串。
"""
if role == "system":
return "bold white on blue"
if role == "user":
return "bold black on green"
if role == "assistant":
return "bold black on yellow"
if role == "tool":
return "bold white on magenta"
return "bold white on bright_black"
@staticmethod
def _get_role_badge_label(role: str) -> str:
"""返回终端中角色标签的中文名称。
Args:
role: 消息角色名称。
Returns:
str: 用于展示的中文角色名称。
"""
if role == "system":
return "系统"
if role == "user":
return "用户"
if role == "assistant":
return "助手"
if role == "tool":
return "工具"
return "未知"
@staticmethod
def _build_terminal_image_preview(image_base64: str) -> Optional[str]:
"""构造终端图片预览字符画。
Args:
image_base64: 图片的 Base64 编码。
Returns:
Optional[str]: 生成成功时返回字符画文本,否则返回 ``None``。
"""
ascii_chars = " .:-=+*#%@"
try:
image_bytes = b64decode(image_base64)
with PILImage.open(BytesIO(image_bytes)) as image:
grayscale = image.convert("L")
width, height = grayscale.size
if width <= 0 or height <= 0:
return None
preview_width = max(8, int(global_config.maisaka.terminal_image_preview_width))
preview_height = max(1, int(height * (preview_width / width) * 0.5))
resized = grayscale.resize((preview_width, preview_height))
pixels = list(resized.tobytes())
except Exception:
return None
rows: List[str] = []
for row_index in range(preview_height):
row_pixels = pixels[row_index * preview_width : (row_index + 1) * preview_width]
row = "".join(ascii_chars[min(len(ascii_chars) - 1, pixel * len(ascii_chars) // 256)] for pixel in row_pixels)
rows.append(row)
return "\n".join(rows)
@classmethod
def _render_message_content(cls, content: Any) -> RenderableType:
"""将消息内容渲染为终端可展示对象。
Args:
content: 原始消息内容。
Returns:
RenderableType: Rich 可渲染对象。
"""
if isinstance(content, str):
return Text(content)
if isinstance(content, list):
parts: List[RenderableType] = []
for item in content:
if isinstance(item, str):
parts.append(Text(item))
continue
if isinstance(item, tuple) and len(item) == 2:
image_format, image_base64 = item
if isinstance(image_format, str) and isinstance(image_base64, str):
approx_size = max(0, len(image_base64) * 3 // 4)
size_text = f"{approx_size / 1024:.1f} KB" if approx_size >= 1024 else f"{approx_size} B"
preview_parts: List[RenderableType] = [
Text(f"图片格式 image/{image_format} {size_text}\nbase64 内容已省略", style="magenta")
]
if global_config.maisaka.terminal_image_preview:
preview_text = cls._build_terminal_image_preview(image_base64)
if preview_text:
preview_parts.append(Text(preview_text, style="white"))
parts.append(
Panel(
Group(*preview_parts),
border_style="magenta",
padding=(0, 1),
)
)
continue
if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
parts.append(Text(item["text"]))
else:
parts.append(Pretty(item, expand_all=True))
return Group(*parts) if parts else Text("")
if content is None:
return Text("")
return Pretty(content, expand_all=True)
@staticmethod
def _format_tool_call_for_display(tool_call: Any) -> Dict[str, Any]:
"""将工具调用对象格式化为易读字典。
Args:
tool_call: 原始工具调用对象或字典。
Returns:
Dict[str, Any]: 适合终端展示的工具调用字典。
"""
if isinstance(tool_call, dict):
function_info = tool_call.get("function", {})
return {
"id": tool_call.get("id"),
"name": function_info.get("name", tool_call.get("name")),
"arguments": function_info.get("arguments", tool_call.get("arguments")),
}
return {
"id": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
"name": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
"arguments": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
}
def _render_tool_call_panel(self, tool_call: Any, index: int, parent_index: int) -> Panel:
"""渲染单个工具调用面板。
Args:
tool_call: 原始工具调用对象。
index: 工具调用在当前消息中的序号。
parent_index: 所属消息的序号。
Returns:
Panel: 工具调用展示面板。
"""
title = Text.assemble(
Text(" 工具调用 ", style="bold white on magenta"),
Text(f" #{parent_index}.{index}", style="muted"),
)
return Panel(
Pretty(self._format_tool_call_for_display(tool_call), expand_all=True),
title=title,
border_style="magenta",
padding=(0, 1),
)
def _render_message_panel(self, message: Any, index: int) -> Panel:
"""渲染单条消息面板。
Args:
message: 原始消息对象或字典。
index: 消息序号。
Returns:
Panel: 终端展示面板。
"""
if isinstance(message, dict):
raw_role = message.get("role", "unknown")
content = message.get("content")
tool_call_id = message.get("tool_call_id")
else:
raw_role = getattr(message, "role", "unknown")
content = getattr(message, "content", None)
tool_call_id = getattr(message, "tool_call_id", None)
role = raw_role.value if isinstance(raw_role, RoleType) else str(raw_role)
title = Text.assemble(
Text(f" {self._get_role_badge_label(role)} ", style=self._get_role_badge_style(role)),
Text(f" #{index}", style="muted"),
)
parts: List[RenderableType] = []
if content not in (None, "", []):
parts.append(Text(" 消息 ", style="bold cyan"))
parts.append(self._render_message_content(content))
if tool_call_id:
parts.append(
Text.assemble(
Text(" 工具调用编号 ", style="bold magenta"),
Text(" "),
Text(str(tool_call_id), style="magenta"),
)
)
if not parts:
parts.append(Text("[空消息]", style="muted"))
return Panel(
Group(*parts),
title=title,
border_style="dim",
padding=(0, 1),
)
@staticmethod
def _format_token_count(token_count: int) -> str:
"""格式化 token 数量展示文本。"""
if token_count >= 10_000:
return f"{token_count / 1000:.1f}k"
return str(token_count)
@classmethod
def _build_prompt_stats_text(
cls,
*,
selected_history_count: int,
built_message_count: int,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
) -> str:
"""构造本轮 prompt 的统计信息文本。"""
return (
f"已选上下文消息数={selected_history_count} "
f"大模型消息数={built_message_count} "
f"实际输入Token={cls._format_token_count(prompt_tokens)} "
f"输出Token={cls._format_token_count(completion_tokens)} "
f"总Token={cls._format_token_count(total_tokens)}"
)
async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse:
"""执行一轮 Maisaka 规划器请求。
@@ -756,13 +726,30 @@ class MaisakaChatLoopService:
else:
all_tools = [*get_builtin_tools(), *self._extra_tools]
ordered_panels: List[Panel] = []
for index, msg in enumerate(built_messages, start=1):
ordered_panels.append(self._render_message_panel(msg, index))
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls:
for tool_call_index, tool_call in enumerate(tool_calls, start=1):
ordered_panels.append(self._render_tool_call_panel(tool_call, tool_call_index, index))
before_request_result = await self._get_runtime_manager().invoke_hook(
"maisaka.planner.before_request",
messages=serialize_prompt_messages(built_messages),
tool_definitions=serialize_tool_definitions(all_tools),
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
selection_reason=selection_reason,
session_id=self._session_id,
)
before_request_kwargs = before_request_result.kwargs
raw_messages = before_request_kwargs.get("messages")
if isinstance(raw_messages, list):
try:
built_messages = deserialize_prompt_messages(raw_messages)
except Exception as exc:
logger.warning(f"Hook maisaka.planner.before_request 返回的 messages 无法反序列化,已忽略: {exc}")
raw_tool_definitions = before_request_kwargs.get("tool_definitions")
if isinstance(raw_tool_definitions, list):
all_tools = [item for item in raw_tool_definitions if isinstance(item, dict)]
ordered_panels = PromptCLIVisualizer.build_prompt_panels(
built_messages,
image_display_mode=global_config.maisaka.terminal_image_display_mode,
)
if global_config.maisaka.show_thinking and ordered_panels:
console.print(
@@ -795,7 +782,7 @@ class MaisakaChatLoopService:
request_elapsed = perf_counter() - request_started_at
logger.info(f"规划器请求完成,耗时={request_elapsed:.3f}")
prompt_stats_text = self._build_prompt_stats_text(
prompt_stats_text = PromptCLIVisualizer.build_prompt_stats_text(
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
prompt_tokens=generation_result.prompt_tokens,
@@ -804,28 +791,63 @@ class MaisakaChatLoopService:
)
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
final_response = generation_result.response or ""
final_tool_calls = list(generation_result.tool_calls or [])
after_response_result = await self._get_runtime_manager().invoke_hook(
"maisaka.planner.after_response",
response=final_response,
tool_calls=serialize_tool_calls(final_tool_calls),
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
selection_reason=selection_reason,
session_id=self._session_id,
prompt_tokens=generation_result.prompt_tokens,
completion_tokens=generation_result.completion_tokens,
total_tokens=generation_result.total_tokens,
)
after_response_kwargs = after_response_result.kwargs
if "response" in after_response_kwargs:
final_response = str(after_response_kwargs.get("response") or "")
raw_tool_calls = after_response_kwargs.get("tool_calls")
if isinstance(raw_tool_calls, list):
try:
final_tool_calls = deserialize_tool_calls(raw_tool_calls)
except Exception as exc:
logger.warning(f"Hook maisaka.planner.after_response 返回的 tool_calls 无法反序列化,已忽略: {exc}")
prompt_tokens = self._coerce_int(after_response_kwargs.get("prompt_tokens"), generation_result.prompt_tokens)
completion_tokens = self._coerce_int(
after_response_kwargs.get("completion_tokens"),
generation_result.completion_tokens,
)
total_tokens = self._coerce_int(after_response_kwargs.get("total_tokens"), generation_result.total_tokens)
tool_call_summaries = [
{
"调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
"工具名": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
"参数": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
}
for tool_call in (generation_result.tool_calls or [])
for tool_call in final_tool_calls
]
logger.info(
f"Maisaka 规划器返回结果: 内容={generation_result.response or ''!r} "
f"Maisaka 规划器返回结果: 内容={final_response!r} "
f"工具调用={tool_call_summaries}"
)
raw_message = AssistantMessage(
content=generation_result.response or "",
content=final_response,
timestamp=datetime.now(),
tool_calls=generation_result.tool_calls or [],
tool_calls=final_tool_calls,
)
return ChatResponse(
content=generation_result.response,
tool_calls=generation_result.tool_calls or [],
content=final_response or None,
tool_calls=final_tool_calls,
raw_message=raw_message,
selected_history_count=len(selected_history),
prompt_tokens=prompt_tokens,
built_message_count=len(built_messages),
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
@staticmethod
@@ -859,6 +881,7 @@ class MaisakaChatLoopService:
selected_indices.reverse()
selected_history = [chat_history[index] for index in selected_indices]
selected_history = MaisakaChatLoopService._drop_leading_orphan_tool_results(selected_history)
return (
selected_history,
(
@@ -868,34 +891,31 @@ class MaisakaChatLoopService:
)
@staticmethod
def build_chat_context(user_text: str) -> List[LLMContextMessage]:
"""根据用户输入构造最小对话上下文。
def _drop_leading_orphan_tool_results(
selected_history: List[LLMContextMessage],
) -> List[LLMContextMessage]:
"""移除窗口前缀中缺少对应 tool_call 的工具结果消息。"""
Args:
user_text: 用户输入文本。
if not selected_history:
return selected_history
Returns:
List[LLMContextMessage]: 构造好的上下文消息列表。
"""
available_tool_call_ids = {
tool_call.call_id
for message in selected_history
if isinstance(message, AssistantMessage)
for tool_call in message.tool_calls
if tool_call.call_id
}
timestamp = datetime.now()
visible_text = format_speaker_content(
global_config.maisaka.user_name.strip() or "用户",
user_text,
timestamp,
)
planner_prefix = (
f"[时间]{timestamp.strftime('%H:%M:%S')}\n"
f"[用户]{global_config.maisaka.user_name.strip() or '用户'}\n"
"[用户群昵称]\n"
"[msg_id]\n"
"[发言内容]"
)
return [
SessionBackedMessage(
raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]),
visible_text=visible_text,
timestamp=timestamp,
source_kind="user",
)
]
first_valid_index = 0
while first_valid_index < len(selected_history):
message = selected_history[first_valid_index]
if not isinstance(message, ToolResultMessage):
break
if message.tool_call_id in available_tool_call_ids:
break
first_valid_index += 1
if first_valid_index == 0:
return selected_history
return selected_history[first_valid_index:]

View File

@@ -11,7 +11,13 @@ import base64
from PIL import Image as PILImage
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent
from src.common.data_models.message_component_data_model import (
EmojiComponent,
ImageComponent,
MessageSequence,
ReplyComponent,
TextComponent,
)
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.tool_option import ToolCall
@@ -27,6 +33,44 @@ def _guess_image_format(image_bytes: bytes) -> Optional[str]:
return None
def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent) -> bool:
"""将表情组件追加到 LLM 消息构建器。"""
image_format = _guess_image_format(component.binary_data)
if image_format and component.binary_data:
builder.add_text_content("[消息类型]表情包")
builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
return True
if component.content:
builder.add_text_content(component.content)
return True
return False
def _append_image_component(builder: MessageBuilder, component: ImageComponent) -> bool:
"""将图片组件追加到 LLM 消息构建器。"""
image_format = _guess_image_format(component.binary_data)
if image_format and component.binary_data:
builder.add_text_content("[消息类型]图片")
builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
return True
if component.content:
builder.add_text_content(component.content)
return True
return False
def _append_reply_component(builder: MessageBuilder, component: ReplyComponent) -> bool:
"""将回复组件追加到 LLM 消息构建器。"""
target_message_id = component.target_message_id.strip()
if not target_message_id:
return False
builder.add_text_content(f"[引用回复]({target_message_id})")
return True
def _build_message_from_sequence(
role: RoleType,
message_sequence: MessageSequence,
@@ -50,16 +94,17 @@ def _build_message_from_sequence(
has_content = True
continue
if isinstance(component, (EmojiComponent, ImageComponent)):
image_format = _guess_image_format(component.binary_data)
if image_format and component.binary_data:
builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8"))
has_content = True
continue
if isinstance(component, EmojiComponent):
has_content = _append_emoji_component(builder, component) or has_content
continue
if component.content:
builder.add_text_content(component.content)
has_content = True
if isinstance(component, ImageComponent):
has_content = _append_image_component(builder, component) or has_content
continue
if isinstance(component, ReplyComponent):
has_content = _append_reply_component(builder, component) or has_content
continue
if not has_content and fallback_text:
builder.add_text_content(fallback_text)

View File

@@ -5,7 +5,13 @@ from datetime import datetime
from typing import Optional
import re
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent
from src.common.data_models.message_component_data_model import (
EmojiComponent,
ImageComponent,
MessageSequence,
ReplyComponent,
TextComponent,
)
SPEAKER_PREFIX_PATTERN = re.compile(
r"^(?:(?P<timestamp>\d{2}:\d{2}:\d{2}))?(?:\[msg_id:(?P<message_id>[^\]]+)\])?\[(?P<speaker>[^\]]+)\](?P<content>.*)$",
@@ -65,5 +71,11 @@ def build_visible_text_from_sequence(message_sequence: MessageSequence) -> str:
if isinstance(component, ImageComponent):
parts.append("[图片]")
continue
if isinstance(component, ReplyComponent):
target_message_id = component.target_message_id.strip()
if target_message_id:
parts.append(f"[引用回复]({target_message_id})")
return "".join(parts)

View File

@@ -0,0 +1,109 @@
"""Maisaka 规划器消息构造工具。"""
from datetime import datetime
from typing import Optional
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from .context_messages import SessionBackedMessage
from .message_adapter import format_speaker_content
def build_planner_prefix(
*,
timestamp: datetime,
user_name: str,
group_card: str = "",
message_id: Optional[str] = None,
include_message_id: bool = True,
) -> str:
"""构造 Maisaka 规划器使用的统一消息前缀。
Args:
timestamp: 消息时间。
user_name: 展示给规划器的用户名。
group_card: 群昵称。
message_id: 消息 ID。
include_message_id: 是否输出 `msg_id` 段。
Returns:
str: 拼接完成的规划器前缀。
"""
prefix_parts = [
f"[时间]{timestamp.strftime('%H:%M:%S')}\n",
f"[用户名]{user_name}\n",
f"[用户群昵称]{group_card}\n",
]
if include_message_id:
prefix_parts.append(f"[msg_id]{message_id or ''}\n")
prefix_parts.append("[发言内容]")
return "".join(prefix_parts)
def build_planner_user_prefix_from_session_message(message: SessionMessage) -> str:
"""根据真实会话消息构造规划器前缀。
Args:
message: 原始会话消息。
Returns:
str: 规划器前缀字符串。
"""
user_info = message.message_info.user_info
user_name = user_info.user_nickname or user_info.user_id
return build_planner_prefix(
timestamp=message.timestamp,
user_name=user_name,
group_card=user_info.user_cardname or "",
message_id=message.message_id,
include_message_id=not message.is_notify and bool(message.message_id),
)
def build_session_backed_text_message(
*,
speaker_name: str,
text: str,
timestamp: datetime,
source_kind: str,
group_card: str = "",
message_id: Optional[str] = None,
include_message_id: bool = True,
) -> SessionBackedMessage:
"""构造带规划器前缀的纯文本历史消息。
Args:
speaker_name: 发言者名称。
text: 发言内容。
timestamp: 发言时间。
source_kind: 上下文来源类型。
group_card: 群昵称。
message_id: 消息 ID。
include_message_id: 是否输出 `msg_id` 段。
Returns:
SessionBackedMessage: 可直接写入历史的上下文消息。
"""
planner_prefix = build_planner_prefix(
timestamp=timestamp,
user_name=speaker_name,
group_card=group_card,
message_id=message_id,
include_message_id=include_message_id,
)
return SessionBackedMessage(
raw_message=MessageSequence([TextComponent(f"{planner_prefix}{text}")]),
visible_text=format_speaker_content(
speaker_name,
text,
timestamp,
message_id if include_message_id else None,
),
timestamp=timestamp,
message_id=message_id,
source_kind=source_kind,
)

View File

@@ -0,0 +1,306 @@
"""CLI 下的 Prompt 可视化渲染模块。"""
from __future__ import annotations
import hashlib
from base64 import b64decode
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from urllib.parse import quote
from typing import Any, Dict, List, Literal
import tempfile
from pydantic import BaseModel, Field as PydanticField
from rich.console import Group, RenderableType
from rich.pretty import Pretty
from rich.panel import Panel
from rich.text import Text
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
DATA_IMAGE_DIR = PROJECT_ROOT / "data" / "images"
class PromptImageDisplayMode(str, Enum):
"""图片在终端中的展示模式。"""
LEGACY = "legacy"
"""不新增链接,仅保留原有的元信息展示。"""
PATH_LINK = "path_link"
"""把图片落盘到临时目录并输出可点击路径。"""
class PromptImageDisplaySettings(BaseModel):
"""图片展示参数。"""
display_mode: PromptImageDisplayMode = PydanticField(default=PromptImageDisplayMode.LEGACY)
"""图片展示模式。"""
@dataclass(slots=True)
class _MessageRenderResult:
"""可渲染结果与是否有工具调用信息。"""
message_panel: Panel
tool_call_panels: List[Panel]
class PromptCLIVisualizer:
"""负责构建 CLI 下 prompt 展示所需的所有可视化组件。"""
@staticmethod
def _get_role_badge_style(role: str) -> str:
if role == "system":
return "bold white on blue"
if role == "user":
return "bold black on green"
if role == "assistant":
return "bold black on yellow"
if role == "tool":
return "bold white on magenta"
return "bold white on bright_black"
@staticmethod
def _get_role_badge_label(role: str) -> str:
if role == "system":
return "系统"
if role == "user":
return "用户"
if role == "assistant":
return "助手"
if role == "tool":
return "工具"
return "未知"
@staticmethod
def _format_token_count(token_count: int) -> str:
if token_count >= 10_000:
return f"{token_count / 1000:.1f}k"
return str(token_count)
@classmethod
def build_prompt_stats_text(
cls,
*,
selected_history_count: int,
built_message_count: int,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
) -> str:
"""构造 prompt 统计文本。"""
return (
f"上下文消息数量={selected_history_count} "
f"已构建消息数={built_message_count} "
f"实际输入Token={cls._format_token_count(prompt_tokens)} "
f"输出Token={cls._format_token_count(completion_tokens)} "
f"总Token={cls._format_token_count(total_tokens)}"
)
@staticmethod
def _normalize_image_format(image_format: str) -> str:
"""归一化图片扩展名。"""
normalized = image_format.strip().lower()
if normalized == "jpg":
return "jpeg"
return normalized
@staticmethod
def _build_image_cache_path(image_format: str, image_base64: str) -> Path:
image_format = PromptCLIVisualizer._normalize_image_format(image_format)
root = Path(tempfile.gettempdir()) / "maisaka_prompt_images"
root.mkdir(parents=True, exist_ok=True)
digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest()
return root / f"{digest}.{image_format}"
@staticmethod
def _build_file_uri(file_path: Path) -> str:
normalized = file_path.as_posix()
return f"file:///{quote(normalized, safe='/:')}"
@staticmethod
def _build_official_image_path(image_format: str, image_base64: str) -> Path | None:
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format)
try:
image_bytes = b64decode(image_base64)
except Exception:
return None
digest = hashlib.sha256(image_bytes).hexdigest()
official_path = DATA_IMAGE_DIR / f"{digest}.{normalized_format}"
if official_path.exists():
return official_path
return None
@staticmethod
def _build_image_file_link(image_format: str, image_base64: str) -> tuple[str, Path] | None:
"""优先返回正式图片路径;不存在时回退到临时缓存路径。"""
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format) or "bin"
official_path = PromptCLIVisualizer._build_official_image_path(image_format, image_base64)
if official_path is not None:
return PromptCLIVisualizer._build_file_uri(official_path), official_path
try:
image_bytes = b64decode(image_base64)
except Exception:
return None
path = PromptCLIVisualizer._build_image_cache_path(normalized_format, image_base64)
if not path.exists():
try:
path.write_bytes(image_bytes)
except Exception:
return None
return PromptCLIVisualizer._build_file_uri(path), path
@classmethod
def _render_image_item(cls, image_format: str, image_base64: str, settings: PromptImageDisplaySettings) -> Panel:
normalized_format = cls._normalize_image_format(image_format)
approx_size = max(0, len(image_base64) * 3 // 4)
size_text = f"{approx_size / 1024:.1f} KB" if approx_size >= 1024 else f"{approx_size} B"
preview_parts: List[RenderableType] = [
Text(f"图片格式 image/{normalized_format} {size_text}", style="magenta")
]
if settings.display_mode == PromptImageDisplayMode.PATH_LINK:
path_result = cls._build_image_file_link(image_format, image_base64)
if path_result is not None:
file_uri, file_path = path_result
preview_parts.append(Text.from_markup(f"\n[link={file_uri}]点击打开图片[/link]", style="cyan"))
preview_parts.append(Text(f"\n{file_path}", style="dim"))
return Panel(
Group(*preview_parts),
border_style="magenta",
padding=(0, 1),
)
@classmethod
def _render_message_content(cls, content: Any, settings: PromptImageDisplaySettings) -> RenderableType:
if isinstance(content, str):
return Text(content)
if isinstance(content, list):
parts: List[RenderableType] = []
for item in content:
if isinstance(item, str):
parts.append(Text(item))
continue
if isinstance(item, tuple) and len(item) == 2:
image_format, image_base64 = item
if isinstance(image_format, str) and isinstance(image_base64, str):
parts.append(cls._render_image_item(image_format, image_base64, settings))
continue
if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
parts.append(Text(item["text"]))
else:
parts.append(Pretty(item, expand_all=True))
return Group(*parts) if parts else Text("")
if content is None:
return Text("")
return Pretty(content, expand_all=True)
@classmethod
def format_tool_call_for_display(cls, tool_call: Any) -> Dict[str, Any]:
if isinstance(tool_call, dict):
function_info = tool_call.get("function", {})
return {
"id": tool_call.get("id"),
"name": function_info.get("name", tool_call.get("name")),
"arguments": function_info.get("arguments", tool_call.get("arguments")),
}
return {
"id": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
"name": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
"arguments": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
}
@classmethod
def _render_tool_call_panel(cls, tool_call: Any, index: int, parent_index: int) -> Panel:
title = Text.assemble(
Text(" 工具调用 ", style="bold white on magenta"),
Text(f" #{parent_index}.{index}", style="muted"),
)
return Panel(
Pretty(cls.format_tool_call_for_display(tool_call), expand_all=True),
title=title,
border_style="magenta",
padding=(0, 1),
)
@classmethod
def _render_message_panel(cls, message: Any, index: int, settings: PromptImageDisplaySettings) -> _MessageRenderResult:
if isinstance(message, dict):
raw_role = message.get("role", "unknown")
content = message.get("content")
tool_call_id = message.get("tool_call_id")
else:
raw_role = getattr(message, "role", "unknown")
content = getattr(message, "content", None)
tool_call_id = getattr(message, "tool_call_id", None)
role = raw_role.value if hasattr(raw_role, "value") else str(raw_role)
title = Text.assemble(
Text(f" {cls._get_role_badge_label(role)} ", style=cls._get_role_badge_style(role)),
Text(f" #{index}", style="muted"),
)
parts: List[RenderableType] = []
if content not in (None, "", []):
parts.append(Text(" 内容 ", style="bold cyan"))
parts.append(cls._render_message_content(content, settings))
if tool_call_id:
parts.append(
Text.assemble(
Text(" 工具调用ID ", style="bold magenta"),
Text(" "),
Text(str(tool_call_id), style="magenta"),
)
)
if not parts:
parts.append(Text("[空]", style="muted"))
message_panel = Panel(
Group(*parts),
title=title,
border_style="dim",
padding=(0, 1),
)
tool_call_panels: List[Panel] = []
tool_calls = getattr(message, "tool_calls", None)
if tool_calls:
for tool_call_index, tool_call in enumerate(tool_calls, start=1):
tool_call_panels.append(cls._render_tool_call_panel(tool_call, tool_call_index, index))
return _MessageRenderResult(message_panel=message_panel, tool_call_panels=tool_call_panels)
@classmethod
def build_prompt_panels(
cls,
messages: list[Any],
*,
image_display_mode: Literal["legacy", "path_link"],
) -> List[Panel]:
"""构建完整 prompt 可视化面板。"""
if image_display_mode not in {mode.value for mode in PromptImageDisplayMode}:
image_display_mode = PromptImageDisplayMode.LEGACY
settings = PromptImageDisplaySettings(
display_mode=PromptImageDisplayMode(image_display_mode),
)
ordered_panels: List[Panel] = []
for index, message in enumerate(messages, start=1):
message_render_result = cls._render_message_panel(message, index, settings)
ordered_panels.append(message_render_result.message_panel)
ordered_panels.extend(message_render_result.tool_call_panels)
return ordered_panels

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,10 @@ from typing import Literal, Optional
import asyncio
import time
from rich.panel import Panel
from rich.text import Text
from src.cli.console import console
from src.chat.heart_flow.heartFC_utils import CycleDetail
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
from src.chat.message_receive.message import SessionMessage
@@ -45,7 +49,10 @@ class MaisakaHeartFlowChatting:
session_name = chat_manager.get_session_name(session_id) or session_id
self.log_prefix = f"[{session_name}]"
self._chat_loop_service = MaisakaChatLoopService()
self._chat_loop_service = MaisakaChatLoopService(
session_id=session_id,
is_group_chat=self.chat_stream.is_group_session,
)
self._chat_history: list[LLMContextMessage] = []
self.history_loop: list[CycleDetail] = []
@@ -431,6 +438,40 @@ class MaisakaHeartFlowChatting:
return GroupInfo(group_id=group_info.group_id, group_name=group_info.group_name)
@staticmethod
def _format_token_count(token_count: int) -> str:
"""格式化 token 数量展示文本。"""
if token_count >= 10_000:
return f"{token_count / 1000:.1f}k"
return str(token_count)
def _render_context_usage_panel(
self,
*,
selected_history_count: int,
prompt_tokens: int,
) -> None:
"""在终端展示当前聊天流的上下文占用情况。"""
if not global_config.maisaka.show_thinking:
return
session_name = chat_manager.get_session_name(self.session_id) or self.session_id
body = "\n".join(
[
f"聊天流: {session_name}",
f"Chat ID: {self.session_id}",
f"上下文占用: {selected_history_count}条 / {self._format_token_count(prompt_tokens)}",
]
)
console.print(
Panel(
Text(body),
title="MaiSaka 上下文占用",
border_style="bright_blue",
padding=(0, 1),
)
)
def _log_cycle_started(self, cycle_detail: CycleDetail, round_index: int) -> None:
logger.info(
f"{self.log_prefix} MaiSaka 轮次开始: 循环编号={cycle_detail.cycle_id} "

View File

@@ -32,19 +32,6 @@ class ToolHandlerContext:
self.last_user_input_time: Optional[datetime] = None
async def handle_stop(tc: ToolCall, chat_history: list[LLMContextMessage]) -> None:
"""处理 stop 工具。"""
console.print("[accent]调用工具: stop()[/accent]")
chat_history.append(
ToolResultMessage(
content="当前轮次结束后将停止对话循环。",
timestamp=datetime.now(),
tool_call_id=tc.call_id,
tool_name=tc.func_name,
)
)
async def handle_wait(tc: ToolCall, chat_history: list[LLMContextMessage], ctx: ToolHandlerContext) -> str:
"""处理 wait 工具。"""
seconds = (tc.args or {}).get("seconds", 30)

View File

@@ -7,7 +7,7 @@ from typing import Dict, Optional
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolProvider, ToolSpec
from .builtin_tools import get_builtin_tool_specs
from .builtin_tool import get_builtin_tool_specs
BuiltinToolHandler = Callable[[ToolInvocation, Optional[ToolExecutionContext]], Awaitable[ToolExecutionResult]]

View File

@@ -14,7 +14,7 @@ import httpx
from src.cli.console import console
from src.core.tooling import ToolExecutionResult
from .config import MCPClientRuntimeConfig, MCPRootRuntimeConfig, MCPServerRuntimeConfig
from .config import MCPClientRuntimeConfig, MCPServerRuntimeConfig
from .hooks import MCPHostCallbacks
from .models import (
MCPPromptResult,

View File

@@ -20,5 +20,8 @@ ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
"""Runner 启动时可视为已满足的外部插件依赖版本映射JSON 对象)"""
ENV_BLOCKED_PLUGIN_REASONS = "MAIBOT_BLOCKED_PLUGIN_REASONS"
"""Runner 启动时收到的拒绝加载插件原因映射JSON 对象)"""
ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
"""Runner 启动时注入的全局配置快照JSON 对象)"""

View File

@@ -1,9 +1,11 @@
from .components import RuntimeComponentCapabilityMixin
from .core import RuntimeCoreCapabilityMixin
from .data import RuntimeDataCapabilityMixin
from .render import RuntimeRenderCapabilityMixin
__all__ = [
"RuntimeComponentCapabilityMixin",
"RuntimeCoreCapabilityMixin",
"RuntimeDataCapabilityMixin",
"RuntimeRenderCapabilityMixin",
]

View File

@@ -458,6 +458,17 @@ class RuntimeComponentCapabilityMixin:
async def _cap_component_get_plugin_info(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
"""获取指定插件的基础信息。
Args:
plugin_id: 当前调用方插件 ID。
capability: 当前能力名称。
args: 能力调用参数。
Returns:
Any: 插件基础信息响应。
"""
plugin_name: str = args.get("plugin_name", plugin_id)
try:
sv = self._get_supervisor_for_plugin(plugin_name)
@@ -473,10 +484,46 @@ class RuntimeComponentCapabilityMixin:
"description": "",
"author": "",
"enabled": True,
"default_config": reg.default_config,
"config_schema": reg.config_schema,
},
}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
async def _cap_component_get_plugin_config_schema(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
"""获取指定插件注册时上报的配置 Schema。
Args:
plugin_id: 当前调用方插件 ID。
capability: 当前能力名称。
args: 能力调用参数。
Returns:
Any: 包含配置 Schema 与默认配置的响应。
"""
plugin_name: str = args.get("plugin_name", plugin_id)
try:
sv = self._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
if sv is None:
return {"success": False, "error": f"未找到插件: {plugin_name}"}
registration = sv._registered_plugins.get(plugin_name)
if registration is None:
return {"success": False, "error": f"未找到插件: {plugin_name}"}
return {
"success": True,
"plugin_id": plugin_name,
"schema": registration.config_schema,
"default_config": registration.default_config,
}
async def _cap_component_list_loaded_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:

View File

@@ -81,6 +81,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
_register("component.get_plugin_config_schema", manager._cap_component_get_plugin_config_schema)
_register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
_register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
_register("component.enable", manager._cap_component_enable)
@@ -90,4 +91,5 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
_register("component.reload_plugin", manager._cap_component_reload_plugin)
_register("knowledge.search", manager._cap_knowledge_search)
_register("render.html2png", manager._cap_render_html2png)
logger.debug("已注册全部主程序能力实现")

View File

@@ -0,0 +1,121 @@
"""插件运行时的浏览器渲染能力。"""
from typing import Any, Dict
from src.common.logger import get_logger
from src.services.html_render_service import HtmlRenderRequest, get_html_render_service
logger = get_logger("plugin_runtime.integration")
class RuntimeRenderCapabilityMixin:
"""插件运行时的浏览器渲染能力混入。"""
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值尽量转换为整数。
Args:
value: 原始输入值。
default: 转换失败时返回的默认值。
Returns:
int: 规范化后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _coerce_float(value: Any, default: float) -> float:
"""将任意值尽量转换为浮点数。
Args:
value: 原始输入值。
default: 转换失败时返回的默认值。
Returns:
float: 规范化后的浮点结果。
"""
try:
return float(value)
except (TypeError, ValueError):
return default
@staticmethod
def _coerce_bool(value: Any, default: bool = False) -> bool:
"""将任意值转换为布尔值。
Args:
value: 原始输入值。
default: 输入为空时返回的默认值。
Returns:
bool: 规范化后的布尔结果。
"""
if value is None:
return default
if isinstance(value, str):
normalized_value = value.strip().lower()
if normalized_value in {"0", "false", "no", "off"}:
return False
if normalized_value in {"1", "true", "yes", "on"}:
return True
return bool(value)
def _build_html_render_request(self, args: Dict[str, Any]) -> HtmlRenderRequest:
"""根据 capability 调用参数构造渲染请求。
Args:
args: capability 调用参数。
Returns:
HtmlRenderRequest: 结构化后的渲染请求。
"""
viewport = args.get("viewport", {})
viewport_width = 900
viewport_height = 500
if isinstance(viewport, dict):
viewport_width = self._coerce_int(viewport.get("width"), viewport_width)
viewport_height = self._coerce_int(viewport.get("height"), viewport_height)
return HtmlRenderRequest(
html=str(args.get("html", "") or ""),
selector=str(args.get("selector", "body") or "body"),
viewport_width=viewport_width,
viewport_height=viewport_height,
device_scale_factor=self._coerce_float(args.get("device_scale_factor"), 2.0),
full_page=self._coerce_bool(args.get("full_page"), False),
omit_background=self._coerce_bool(args.get("omit_background"), False),
wait_until=str(args.get("wait_until", "load") or "load"),
wait_for_selector=str(args.get("wait_for_selector", "") or ""),
wait_for_timeout_ms=self._coerce_int(args.get("wait_for_timeout_ms"), 0),
timeout_ms=self._coerce_int(args.get("timeout_ms"), 0),
allow_network=self._coerce_bool(args.get("allow_network"), False),
)
async def _cap_render_html2png(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""将 HTML 内容渲染为 PNG 图片。
Args:
plugin_id: 调用该能力的插件 ID。
capability: 当前能力名称。
args: 能力调用参数。
Returns:
Any: 标准化后的能力返回结构。
"""
del plugin_id, capability
try:
request = self._build_html_render_request(args)
result = await get_html_render_service().render_html_to_png(request)
return {"success": True, "result": result.to_payload()}
except Exception as exc:
logger.error(f"[cap.render.html2png] 执行失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}

View File

@@ -6,6 +6,7 @@
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast
from src.common.logger import get_logger
@@ -858,5 +859,77 @@ class ComponentQueryService:
logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True)
return None
def get_plugin_default_config(self, plugin_name: str) -> Optional[dict]:
"""获取指定插件注册时上报的默认配置。
Args:
plugin_name: 插件名称。
Returns:
Optional[dict]: 默认配置字典;未找到时返回 ``None``。
"""
runtime_manager = self._get_runtime_manager()
try:
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
logger.error(f"读取插件默认配置失败: {exc}")
return None
if supervisor is None:
return None
registration = supervisor._registered_plugins.get(plugin_name)
if registration is None:
return None
return dict(registration.default_config)
def get_plugin_config_schema(self, plugin_name: str) -> Optional[dict]:
"""获取指定插件注册时上报的配置 Schema。
Args:
plugin_name: 插件名称。
Returns:
Optional[dict]: 配置 Schema未找到时返回 ``None``。
"""
runtime_manager = self._get_runtime_manager()
try:
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
logger.error(f"读取插件配置 Schema 失败: {exc}")
return None
if supervisor is None:
return None
registration = supervisor._registered_plugins.get(plugin_name)
if registration is None:
return None
return dict(registration.config_schema)
def list_hook_specs(self) -> list[dict[str, Any]]:
"""返回当前运行时公开的 Hook 规格清单。
Returns:
list[dict[str, Any]]: 可直接序列化给 WebUI 的 Hook 规格列表。
"""
runtime_manager = self._get_runtime_manager()
return [
{
"name": spec.name,
"description": spec.description,
"parameters_schema": deepcopy(spec.parameters_schema),
"default_timeout_ms": spec.default_timeout_ms,
"allow_blocking": spec.allow_blocking,
"allow_observe": spec.allow_observe,
"allow_abort": spec.allow_abort,
"allow_kwargs_mutation": spec.allow_kwargs_mutation,
}
for spec in runtime_manager.list_hook_specs()
]
component_query_service = ComponentQueryService()

View File

@@ -0,0 +1,441 @@
"""插件 Python 依赖流水线。
负责在 Host 侧统一完成以下工作:
1. 扫描插件 Manifest
2. 检测插件与主程序、插件与插件之间的 Python 依赖冲突;
3. 为可加载插件自动安装缺失的 Python 依赖;
4. 产出最终的拒绝加载列表,供运行时使用。
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import asyncio
import shutil
import subprocess
import sys
from packaging.utils import canonicalize_name
from src.common.logger import get_logger
from src.plugin_runtime.runner.manifest_validator import ManifestValidator, PluginManifest
logger = get_logger("plugin_runtime.dependency_pipeline")
@dataclass(frozen=True)
class PackageDependencyUsage:
"""记录单个插件对某个 Python 包的依赖声明。"""
package_name: str
plugin_id: str
version_spec: str
@dataclass(frozen=True)
class CombinedPackageRequirement:
"""表示一个已经合并后的 Python 包安装需求。"""
package_name: str
plugin_ids: Tuple[str, ...]
requirement_text: str
version_spec: str
@dataclass(frozen=True)
class DependencyPipelinePlan:
"""表示一次依赖分析后得到的计划。"""
blocked_plugin_reasons: Dict[str, str]
install_requirements: Tuple[CombinedPackageRequirement, ...]
@dataclass(frozen=True)
class DependencyPipelineResult:
"""表示一次依赖流水线执行后的结果。"""
blocked_plugin_reasons: Dict[str, str]
environment_changed: bool
install_requirements: Tuple[CombinedPackageRequirement, ...]
class PluginDependencyPipeline:
"""插件依赖流水线。
该类不负责插件启停,只负责对插件目录进行依赖分析,并在必要时
使用 ``uv`` 为可加载插件补齐缺失的 Python 依赖。
"""
def __init__(self, project_root: Optional[Path] = None) -> None:
"""初始化依赖流水线。
Args:
project_root: 项目根目录;留空时自动推断。
"""
self._project_root: Path = project_root or Path(__file__).resolve().parents[2]
self._manifest_validator: ManifestValidator = ManifestValidator(
project_root=self._project_root,
validate_python_package_dependencies=False,
)
async def execute(self, plugin_dirs: Iterable[Path]) -> DependencyPipelineResult:
"""执行完整的依赖分析与自动安装流程。
Args:
plugin_dirs: 需要扫描的插件根目录集合。
Returns:
DependencyPipelineResult: 最终的阻止加载结果与环境变更状态。
"""
plan = self.build_plan(plugin_dirs)
if not plan.install_requirements:
return DependencyPipelineResult(
blocked_plugin_reasons=dict(plan.blocked_plugin_reasons),
environment_changed=False,
install_requirements=plan.install_requirements,
)
install_succeeded, error_message = await self._install_requirements(plan.install_requirements)
if install_succeeded:
return DependencyPipelineResult(
blocked_plugin_reasons=dict(plan.blocked_plugin_reasons),
environment_changed=True,
install_requirements=plan.install_requirements,
)
blocked_plugin_reasons = dict(plan.blocked_plugin_reasons)
affected_plugin_ids = sorted(
{
plugin_id
for requirement in plan.install_requirements
for plugin_id in requirement.plugin_ids
}
)
for plugin_id in affected_plugin_ids:
self._append_block_reason(
blocked_plugin_reasons,
plugin_id,
f"自动安装 Python 依赖失败: {error_message}",
)
return DependencyPipelineResult(
blocked_plugin_reasons=blocked_plugin_reasons,
environment_changed=False,
install_requirements=plan.install_requirements,
)
def build_plan(self, plugin_dirs: Iterable[Path]) -> DependencyPipelinePlan:
"""构建依赖分析计划。
Args:
plugin_dirs: 需要扫描的插件根目录集合。
Returns:
DependencyPipelinePlan: 分析后的阻止加载列表与安装计划。
"""
manifests = self._collect_manifests(plugin_dirs)
blocked_plugin_reasons = self._detect_host_conflicts(manifests)
plugin_conflict_reasons = self._detect_plugin_conflicts(manifests, blocked_plugin_reasons)
for plugin_id, reason in plugin_conflict_reasons.items():
self._append_block_reason(blocked_plugin_reasons, plugin_id, reason)
install_requirements = self._build_install_requirements(manifests, blocked_plugin_reasons)
return DependencyPipelinePlan(
blocked_plugin_reasons=blocked_plugin_reasons,
install_requirements=install_requirements,
)
def _collect_manifests(self, plugin_dirs: Iterable[Path]) -> Dict[str, PluginManifest]:
"""收集所有可成功解析的插件 Manifest。
Args:
plugin_dirs: 需要扫描的插件根目录集合。
Returns:
Dict[str, PluginManifest]: 以插件 ID 为键的 Manifest 映射。
"""
manifests: Dict[str, PluginManifest] = {}
for _plugin_path, manifest in self._manifest_validator.iter_plugin_manifests(plugin_dirs):
manifests[manifest.id] = manifest
return manifests
def _detect_host_conflicts(self, manifests: Dict[str, PluginManifest]) -> Dict[str, str]:
"""检测插件与主程序依赖之间的冲突。
Args:
manifests: 当前已解析到的插件 Manifest 映射。
Returns:
Dict[str, str]: 需要被阻止加载的插件及原因。
"""
host_requirements = self._manifest_validator.load_host_dependency_requirements()
blocked_plugin_reasons: Dict[str, str] = {}
for manifest in manifests.values():
for dependency in manifest.python_package_dependencies:
package_specifier = self._manifest_validator.build_specifier_set(dependency.version_spec)
if package_specifier is None:
self._append_block_reason(
blocked_plugin_reasons,
manifest.id,
f"Python 包依赖声明无效: {dependency.name}{dependency.version_spec}",
)
continue
normalized_package_name = canonicalize_name(dependency.name)
host_requirement = host_requirements.get(normalized_package_name)
if host_requirement is None:
continue
if self._manifest_validator.requirements_may_overlap(
host_requirement.specifier,
package_specifier,
):
continue
host_specifier_text = str(host_requirement.specifier or "") or "任意版本"
self._append_block_reason(
blocked_plugin_reasons,
manifest.id,
(
f"Python 包依赖与主程序冲突: {dependency.name} 需要 "
f"{dependency.version_spec},主程序约束为 {host_specifier_text}"
),
)
return blocked_plugin_reasons
def _detect_plugin_conflicts(
self,
manifests: Dict[str, PluginManifest],
blocked_plugin_reasons: Dict[str, str],
) -> Dict[str, str]:
"""检测插件之间的 Python 依赖冲突。
Args:
manifests: 当前已解析到的插件 Manifest 映射。
blocked_plugin_reasons: 已经因为其他原因被阻止加载的插件。
Returns:
Dict[str, str]: 新增的插件冲突原因映射。
"""
blocked_by_plugin_conflicts: Dict[str, str] = {}
dependency_usages = self._collect_package_usages(manifests, blocked_plugin_reasons)
for _package_name, usages in dependency_usages.items():
display_package_name = usages[0].package_name
for index, left_usage in enumerate(usages):
for right_usage in usages[index + 1 :]:
left_specifier = self._manifest_validator.build_specifier_set(left_usage.version_spec)
right_specifier = self._manifest_validator.build_specifier_set(right_usage.version_spec)
if left_specifier is None or right_specifier is None:
continue
if self._manifest_validator.requirements_may_overlap(left_specifier, right_specifier):
continue
left_reason = (
f"Python 包依赖冲突: 与插件 {right_usage.plugin_id}{display_package_name} 上的约束不兼容 "
f"({left_usage.version_spec} vs {right_usage.version_spec})"
)
right_reason = (
f"Python 包依赖冲突: 与插件 {left_usage.plugin_id}{display_package_name} 上的约束不兼容 "
f"({right_usage.version_spec} vs {left_usage.version_spec})"
)
self._append_block_reason(blocked_by_plugin_conflicts, left_usage.plugin_id, left_reason)
self._append_block_reason(blocked_by_plugin_conflicts, right_usage.plugin_id, right_reason)
return blocked_by_plugin_conflicts
def _collect_package_usages(
self,
manifests: Dict[str, PluginManifest],
blocked_plugin_reasons: Dict[str, str],
) -> Dict[str, List[PackageDependencyUsage]]:
"""收集所有未被阻止加载插件的包依赖声明。
Args:
manifests: 当前已解析到的插件 Manifest 映射。
blocked_plugin_reasons: 已经被阻止加载的插件及原因。
Returns:
Dict[str, List[PackageDependencyUsage]]: 按规范化包名分组后的依赖声明。
"""
dependency_usages: Dict[str, List[PackageDependencyUsage]] = {}
for manifest in manifests.values():
if manifest.id in blocked_plugin_reasons:
continue
for dependency in manifest.python_package_dependencies:
normalized_package_name = canonicalize_name(dependency.name)
dependency_usages.setdefault(normalized_package_name, []).append(
PackageDependencyUsage(
package_name=dependency.name,
plugin_id=manifest.id,
version_spec=dependency.version_spec,
)
)
return dependency_usages
def _build_install_requirements(
self,
manifests: Dict[str, PluginManifest],
blocked_plugin_reasons: Dict[str, str],
) -> Tuple[CombinedPackageRequirement, ...]:
"""构建需要安装到当前环境的 Python 包需求列表。
Args:
manifests: 当前已解析到的插件 Manifest 映射。
blocked_plugin_reasons: 已经被阻止加载的插件及原因。
Returns:
Tuple[CombinedPackageRequirement, ...]: 需要安装或调整版本的依赖列表。
"""
combined_requirements: List[CombinedPackageRequirement] = []
dependency_usages = self._collect_package_usages(manifests, blocked_plugin_reasons)
for usages in dependency_usages.values():
merged_specifier_text = self._merge_specifier_texts([usage.version_spec for usage in usages])
package_name = usages[0].package_name
requirement_text = f"{package_name}{merged_specifier_text}"
installed_version = self._manifest_validator.get_installed_package_version(package_name)
if installed_version is not None and self._manifest_validator.version_matches_specifier(
installed_version,
merged_specifier_text,
):
continue
combined_requirements.append(
CombinedPackageRequirement(
package_name=package_name,
plugin_ids=tuple(sorted({usage.plugin_id for usage in usages})),
requirement_text=requirement_text,
version_spec=merged_specifier_text,
)
)
return tuple(sorted(combined_requirements, key=lambda requirement: canonicalize_name(requirement.package_name)))
@staticmethod
def _merge_specifier_texts(specifier_texts: Sequence[str]) -> str:
"""合并多个版本约束文本。
Args:
specifier_texts: 需要合并的版本约束文本序列。
Returns:
str: 合并后的版本约束文本。
"""
merged_parts: List[str] = []
for specifier_text in specifier_texts:
for part in str(specifier_text or "").split(","):
normalized_part = part.strip()
if not normalized_part or normalized_part in merged_parts:
continue
merged_parts.append(normalized_part)
return f"{','.join(merged_parts)}" if merged_parts else ""
async def _install_requirements(self, requirements: Sequence[CombinedPackageRequirement]) -> Tuple[bool, str]:
"""安装指定的 Python 包需求列表。
Args:
requirements: 需要安装的依赖列表。
Returns:
Tuple[bool, str]: 安装是否成功,以及错误摘要。
"""
requirement_texts = [requirement.requirement_text for requirement in requirements]
if not requirement_texts:
return True, ""
logger.info(f"开始自动安装插件 Python 依赖: {', '.join(requirement_texts)}")
command = self._build_install_command(requirement_texts)
try:
completed_process = await asyncio.to_thread(
subprocess.run,
command,
capture_output=True,
check=False,
cwd=self._project_root,
text=True,
)
except Exception as exc:
return False, str(exc)
if completed_process.returncode == 0:
logger.info("插件 Python 依赖自动安装完成")
return True, ""
output = self._summarize_install_error(completed_process.stdout, completed_process.stderr)
return False, output or f"命令执行失败,退出码 {completed_process.returncode}"
@staticmethod
def _build_install_command(requirement_texts: Sequence[str]) -> List[str]:
"""构造依赖安装命令。
Args:
requirement_texts: 待安装的依赖文本序列。
Returns:
List[str]: 适用于 ``subprocess.run`` 的命令参数列表。
"""
if shutil.which("uv"):
return ["uv", "pip", "install", "--python", sys.executable, *requirement_texts]
return [sys.executable, "-m", "pip", "install", *requirement_texts]
@staticmethod
def _summarize_install_error(stdout: str, stderr: str) -> str:
"""提炼安装失败输出。
Args:
stdout: 标准输出内容。
stderr: 标准错误内容。
Returns:
str: 简短的错误摘要。
"""
merged_output = "\n".join(part.strip() for part in (stderr, stdout) if part and part.strip()).strip()
if not merged_output:
return ""
lines = [line.strip() for line in merged_output.splitlines() if line.strip()]
return " | ".join(lines[-5:])
@staticmethod
def _append_block_reason(
blocked_plugin_reasons: Dict[str, str],
plugin_id: str,
reason: str,
) -> None:
"""向阻止加载映射中追加原因。
Args:
blocked_plugin_reasons: 待更新的阻止加载映射。
plugin_id: 目标插件 ID。
reason: 需要追加的原因文本。
"""
existing_reason = blocked_plugin_reasons.get(plugin_id)
if existing_reason is None:
blocked_plugin_reasons[plugin_id] = reason
return
existing_parts = [part.strip() for part in existing_reason.split("") if part.strip()]
if reason in existing_parts:
return
blocked_plugin_reasons[plugin_id] = f"{existing_reason}{reason}"

View File

@@ -0,0 +1,52 @@
"""内置命名 Hook 目录注册器。"""
from __future__ import annotations
from collections.abc import Callable
from typing import List
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
HookSpecRegistrar = Callable[[HookSpecRegistry], List[HookSpec]]
"""单个业务模块向注册中心写入 Hook 规格的注册器签名。"""
def _get_builtin_hook_spec_registrars() -> List[HookSpecRegistrar]:
"""返回当前内置 Hook 规格注册器列表。
Returns:
List[HookSpecRegistrar]: 已启用的内置 Hook 注册器列表。
"""
from src.chat.message_receive.bot import register_chat_hook_specs
from src.chat.emoji_system.emoji_manager import register_emoji_hook_specs
from src.learners.expression_learner import register_expression_hook_specs
from src.learners.jargon_miner import register_jargon_hook_specs
from src.maisaka.chat_loop_service import register_maisaka_hook_specs
from src.services.send_service import register_send_service_hook_specs
return [
register_chat_hook_specs,
register_emoji_hook_specs,
register_jargon_hook_specs,
register_expression_hook_specs,
register_send_service_hook_specs,
register_maisaka_hook_specs,
]
def register_builtin_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""向注册中心写入全部内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 本次完成注册后的全部内置 Hook 规格。
"""
registered_specs: List[HookSpec] = []
for registrar in _get_builtin_hook_spec_registrars():
registered_specs.extend(registrar(registry))
return registered_specs

View File

@@ -0,0 +1,178 @@
"""运行时 Hook 载荷序列化辅助。"""
from __future__ import annotations
from typing import Any, Dict, List, Sequence
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.llm_service_data_models import PromptMessage
from src.llm_models.payload_content.message import Message
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, normalize_tool_options
from src.plugin_runtime.host.message_utils import PluginMessageUtils
def serialize_session_message(message: SessionMessage) -> Dict[str, Any]:
"""将会话消息序列化为 Hook 可传输载荷。
Args:
message: 待序列化的会话消息。
Returns:
Dict[str, Any]: 可通过插件运行时传输的消息字典。
"""
return dict(PluginMessageUtils._session_message_to_dict(message))
def deserialize_session_message(raw_message: Any) -> SessionMessage:
"""从 Hook 载荷恢复会话消息。
Args:
raw_message: Hook 返回的消息字典。
Returns:
SessionMessage: 恢复后的会话消息对象。
Raises:
ValueError: 消息结构不合法时抛出。
"""
if not isinstance(raw_message, dict):
raise ValueError("Hook 返回的 `message` 必须是字典")
return PluginMessageUtils._build_session_message_from_dict(raw_message)
def serialize_tool_calls(tool_calls: Sequence[ToolCall] | None) -> List[Dict[str, Any]]:
"""将工具调用列表序列化为 Hook 可传输载荷。
Args:
tool_calls: 原始工具调用列表。
Returns:
List[Dict[str, Any]]: 序列化后的工具调用列表。
"""
if not tool_calls:
return []
return [
{
"id": tool_call.call_id,
"function": {
"name": tool_call.func_name,
"arguments": dict(tool_call.args or {}),
},
}
for tool_call in tool_calls
]
def deserialize_tool_calls(raw_tool_calls: Any) -> List[ToolCall]:
"""从 Hook 载荷恢复工具调用列表。
Args:
raw_tool_calls: Hook 返回的工具调用列表。
Returns:
List[ToolCall]: 恢复后的工具调用列表。
Raises:
ValueError: 结构不合法时抛出。
"""
if raw_tool_calls in (None, []):
return []
if not isinstance(raw_tool_calls, list):
raise ValueError("Hook 返回的 `tool_calls` 必须是列表")
normalized_tool_calls: List[ToolCall] = []
for raw_tool_call in raw_tool_calls:
if not isinstance(raw_tool_call, dict):
raise ValueError("Hook 返回的工具调用项必须是字典")
function_info = raw_tool_call.get("function", {})
if isinstance(function_info, dict):
function_name = function_info.get("name")
function_arguments = function_info.get("arguments")
else:
function_name = raw_tool_call.get("name")
function_arguments = raw_tool_call.get("arguments")
call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
if not isinstance(call_id, str) or not isinstance(function_name, str):
raise ValueError("Hook 返回的工具调用缺少 `id` 或函数名称")
normalized_tool_calls.append(
ToolCall(
call_id=call_id,
func_name=function_name,
args=function_arguments if isinstance(function_arguments, dict) else {},
)
)
return normalized_tool_calls
def serialize_prompt_messages(messages: Sequence[Message]) -> List[PromptMessage]:
"""将 LLM 消息列表序列化为 Hook 可传输载荷。
Args:
messages: 原始 LLM 消息列表。
Returns:
List[PromptMessage]: 序列化后的消息字典列表。
"""
serialized_messages: List[PromptMessage] = []
for message in messages:
serialized_message: PromptMessage = {
"role": message.role.value,
"content": message.content,
}
if message.tool_call_id:
serialized_message["tool_call_id"] = message.tool_call_id
if message.tool_calls:
serialized_message["tool_calls"] = serialize_tool_calls(message.tool_calls)
serialized_messages.append(serialized_message)
return serialized_messages
def deserialize_prompt_messages(raw_messages: Any) -> List[Message]:
"""从 Hook 载荷恢复 LLM 消息列表。
Args:
raw_messages: Hook 返回的消息列表。
Returns:
List[Message]: 恢复后的 LLM 消息列表。
Raises:
ValueError: 结构不合法时抛出。
"""
if not isinstance(raw_messages, list):
raise ValueError("Hook 返回的 `messages` 必须是列表")
from src.services.llm_service import _build_message_from_dict
normalized_messages: List[Message] = []
for raw_message in raw_messages:
if not isinstance(raw_message, dict):
raise ValueError("Hook 返回的消息项必须是字典")
normalized_messages.append(_build_message_from_dict(raw_message))
return normalized_messages
def serialize_tool_definitions(tool_definitions: Sequence[ToolDefinitionInput]) -> List[Dict[str, Any]]:
"""将工具定义列表序列化为 Hook 可传输载荷。
Args:
tool_definitions: 原始工具定义列表。
Returns:
List[Dict[str, Any]]: 序列化后的工具定义列表。
"""
normalized_tool_options = normalize_tool_options(list(tool_definitions))
if not normalized_tool_options:
return []
return [tool_option.to_openai_function_schema() for tool_option in normalized_tool_options]

View File

@@ -0,0 +1,31 @@
"""Hook 参数模型构造辅助。"""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, Sequence
def build_object_schema(
properties: Dict[str, Dict[str, Any]],
*,
required: Sequence[str] | None = None,
) -> Dict[str, Any]:
"""构造对象级 JSON Schema。
Args:
properties: 字段定义映射。
required: 必填字段名列表。
Returns:
Dict[str, Any]: 标准化后的对象级 Schema。
"""
schema: Dict[str, Any] = {
"type": "object",
"properties": deepcopy(properties),
}
normalized_required = [str(item).strip() for item in (required or []) if str(item).strip()]
if normalized_required:
schema["required"] = normalized_required
return schema

View File

@@ -18,9 +18,37 @@ import re
from src.common.logger import get_logger
from src.core.tooling import build_tool_detailed_description
from .hook_spec_registry import HookSpecRegistry
logger = get_logger("plugin_runtime.host.component_registry")
class ComponentRegistrationError(ValueError):
"""组件注册失败异常。"""
def __init__(
self,
message: str,
*,
component_name: str = "",
component_type: str = "",
plugin_id: str = "",
) -> None:
"""初始化组件注册失败异常。
Args:
message: 原始错误信息。
component_name: 组件名称。
component_type: 组件类型。
plugin_id: 插件 ID。
"""
self.component_name = str(component_name or "").strip()
self.component_type = str(component_type or "").strip()
self.plugin_id = str(plugin_id or "").strip()
super().__init__(message)
class ComponentTypes(str, Enum):
ACTION = "ACTION"
COMMAND = "COMMAND"
@@ -359,7 +387,14 @@ class ComponentRegistry:
供业务层查询可用组件、匹配命令、调度 action/event 等。
"""
def __init__(self) -> None:
def __init__(self, hook_spec_registry: Optional[HookSpecRegistry] = None) -> None:
"""初始化组件注册表。
Args:
hook_spec_registry: 可选的 Hook 规格注册中心;提供后会在注册
HookHandler 时执行规格校验。
"""
# 全量索引
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
@@ -370,6 +405,7 @@ class ComponentRegistry:
# 按插件索引
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
self._hook_spec_registry = hook_spec_registry
@staticmethod
def _convert_action_metadata_to_tool_metadata(
@@ -475,77 +511,211 @@ class ComponentRegistry:
type_dict.clear()
self._by_plugin.clear()
# ====== 注册 / 注销 ======
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个组件
@staticmethod
def _is_legacy_action_component(component: ComponentEntry) -> bool:
"""判断组件是否为兼容旧 Action 的 Tool 条目。
Args:
name: 组件名称不含插件id前缀
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
plugin_id: 插件id
metadata: 组件元数据
component: 待判断的组件条目。
Returns:
success (bool): 是否成功注册(失败原因通常是组件类型无效)
bool: 是否为兼容旧 Action 组件。
"""
if not isinstance(component, ToolEntry):
return False
return str(component.metadata.get("legacy_component_type", "") or "").strip().upper() == "ACTION"
def _validate_hook_handler_entry(self, component: HookHandlerEntry) -> None:
"""校验 HookHandler 是否满足已注册的 Hook 规格。
Args:
component: 待校验的 HookHandler 条目。
Raises:
ComponentRegistrationError: HookHandler 声明不合法时抛出。
"""
if self._hook_spec_registry is None:
return
hook_spec = self._hook_spec_registry.get_hook_spec(component.hook)
if hook_spec is None:
raise ComponentRegistrationError(
f"HookHandler {component.full_name} 声明了未注册的 Hook: {component.hook}",
component_name=component.name,
component_type=component.component_type.value,
plugin_id=component.plugin_id,
)
if component.is_blocking and not hook_spec.allow_blocking:
raise ComponentRegistrationError(
f"HookHandler {component.full_name} 不能注册为 blockingHook {component.hook} 不允许 blocking 处理器",
component_name=component.name,
component_type=component.component_type.value,
plugin_id=component.plugin_id,
)
if component.is_observe and not hook_spec.allow_observe:
raise ComponentRegistrationError(
f"HookHandler {component.full_name} 不能注册为 observeHook {component.hook} 不允许 observe 处理器",
component_name=component.name,
component_type=component.component_type.value,
plugin_id=component.plugin_id,
)
if component.error_policy == "abort" and not hook_spec.allow_abort:
raise ComponentRegistrationError(
f"HookHandler {component.full_name} 不能使用 error_policy=abortHook {component.hook} 不允许 abort",
component_name=component.name,
component_type=component.component_type.value,
plugin_id=component.plugin_id,
)
def _build_component_entry(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
) -> ComponentEntry:
"""根据声明构造组件条目。
Args:
name: 组件名称。
component_type: 组件类型。
plugin_id: 插件 ID。
metadata: 组件元数据。
Returns:
ComponentEntry: 已构造并完成校验的组件条目。
Raises:
ComponentRegistrationError: 组件声明不合法时抛出。
"""
try:
normalized_type = self._normalize_component_type(component_type)
normalized_metadata = dict(metadata)
if normalized_type == ComponentTypes.ACTION:
normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
comp = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
component = ToolEntry(name, ComponentTypes.TOOL.value, plugin_id, normalized_metadata)
elif normalized_type == ComponentTypes.COMMAND:
comp = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = CommandEntry(name, normalized_type.value, plugin_id, normalized_metadata)
elif normalized_type == ComponentTypes.TOOL:
comp = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = ToolEntry(name, normalized_type.value, plugin_id, normalized_metadata)
elif normalized_type == ComponentTypes.EVENT_HANDLER:
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = EventHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
elif normalized_type == ComponentTypes.HOOK_HANDLER:
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = HookHandlerEntry(name, normalized_type.value, plugin_id, normalized_metadata)
self._validate_hook_handler_entry(component)
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
component = MessageGatewayEntry(name, normalized_type.value, plugin_id, normalized_metadata)
else:
raise ValueError(f"组件类型 {component_type} 不存在")
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
return False
raise ComponentRegistrationError(
f"组件类型 {component_type} 不存在",
component_name=name,
component_type=component_type,
plugin_id=plugin_id,
)
except ComponentRegistrationError:
raise
except Exception as exc:
raise ComponentRegistrationError(
str(exc),
component_name=name,
component_type=component_type,
plugin_id=plugin_id,
) from exc
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
old_comp = self._components[comp.full_name]
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
old_list = self._by_plugin.get(old_comp.plugin_id)
if old_list is not None:
with contextlib.suppress(ValueError):
old_list.remove(old_comp)
# 从旧类型索引中移除,防止类型变更时幽灵残留
if old_type_dict := self._by_type.get(old_comp.component_type):
old_type_dict.pop(comp.full_name, None)
return component
self._components[comp.full_name] = comp
self._by_type[comp.component_type][comp.full_name] = comp
self._by_plugin.setdefault(plugin_id, []).append(comp)
def _remove_existing_component_entry(self, component: ComponentEntry) -> None:
"""移除同名旧组件条目。
Args:
component: 即将写入的新组件条目。
"""
if component.full_name not in self._components:
return
logger.warning(f"组件 {component.full_name} 已存在,覆盖")
old_component = self._components[component.full_name]
old_list = self._by_plugin.get(old_component.plugin_id)
if old_list is not None:
with contextlib.suppress(ValueError):
old_list.remove(old_component)
if old_type_dict := self._by_type.get(old_component.component_type):
old_type_dict.pop(component.full_name, None)
def _add_component_entry(self, component: ComponentEntry) -> None:
"""写入单个组件条目到全部索引。
Args:
component: 待写入的组件条目。
"""
self._remove_existing_component_entry(component)
self._components[component.full_name] = component
self._by_type[component.component_type][component.full_name] = component
self._by_plugin.setdefault(component.plugin_id, []).append(component)
# ====== 注册 / 注销 ======
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个组件。
Args:
name: 组件名称(不含插件 ID 前缀)。
component_type: 组件类型(如 ``ACTION``、``COMMAND`` 等)。
plugin_id: 插件 ID。
metadata: 组件元数据。
Returns:
bool: 注册成功时恒为 ``True``。
Raises:
ComponentRegistrationError: 组件声明不合法时抛出。
"""
component = self._build_component_entry(name, component_type, plugin_id, metadata)
self._add_component_entry(component)
return True
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册一个插件的所有组件,返回成功注册数
"""批量替换一个插件的组件集合
该方法会先完整校验所有组件声明,只有全部通过后才会替换旧组件,
从而避免插件进入半注册状态。
Args:
plugin_id (str): 插件id
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
plugin_id: 插件 ID。
components: 组件声明字典列表
Returns:
count (int): 成功注册的组件数量
int: 实际注册的组件数量
Raises:
ComponentRegistrationError: 任一组件声明不合法时抛出。
"""
count = 0
for comp_data in components:
ok = self.register_component(
name=comp_data.get("name", ""),
component_type=comp_data.get("component_type", ""),
plugin_id=plugin_id,
metadata=comp_data.get("metadata", {}),
prepared_components: List[ComponentEntry] = []
for component_data in components:
prepared_components.append(
self._build_component_entry(
name=str(component_data.get("name", "") or ""),
component_type=str(component_data.get("component_type", "") or ""),
plugin_id=plugin_id,
metadata=component_data.get("metadata", {})
if isinstance(component_data.get("metadata"), dict)
else {},
)
)
if ok:
count += 1
return count
self.remove_components_by_plugin(plugin_id)
for component in prepared_components:
self._add_component_entry(component)
return len(prepared_components)
def remove_components_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的所有组件,返回移除数量。
@@ -652,6 +822,17 @@ class ComponentRegistry:
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
raise
if comp_type == ComponentTypes.ACTION:
action_components = [
component
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
if self._is_legacy_action_component(component)
]
if enabled_only:
return [component for component in action_components if self.check_component_enabled(component, session_id)]
return action_components
type_dict = self._by_type.get(comp_type, {})
if enabled_only:
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
@@ -854,6 +1035,34 @@ class ComponentRegistry:
tools.append(comp)
return tools
def get_tools_for_llm(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""兼容旧接口,返回可供 LLM 使用的工具条目列表。
Args:
enabled_only: 是否仅返回启用的组件。
session_id: 可选的会话 ID若提供则考虑会话禁用状态。
Returns:
List[Dict[str, Any]]: 兼容旧结构的工具组件字典列表。
"""
return [
{
"name": tool.full_name,
"description": tool.description,
"parameters": (
dict(tool.parameters_raw)
if isinstance(tool.parameters_raw, dict) and tool.parameters_raw
else tool._get_parameters_schema() or {}
),
"parameters_raw": tool.parameters_raw,
"enabled": tool.enabled,
"plugin_id": tool.plugin_id,
}
for tool in self.get_tools(enabled_only=enabled_only, session_id=session_id)
if not self._is_legacy_action_component(tool)
]
# ====== 统计信息 ======
def get_stats(self) -> StatusDict:
"""获取注册统计。
@@ -863,9 +1072,21 @@ class ComponentRegistry:
"""
return StatusDict(
total=len(self._components),
action=len(self._by_type[ComponentTypes.ACTION]),
action=len(
[
component
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
if self._is_legacy_action_component(component)
]
),
command=len(self._by_type[ComponentTypes.COMMAND]),
tool=len(self._by_type[ComponentTypes.TOOL]),
tool=len(
[
component
for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
if not self._is_legacy_action_component(component)
]
),
event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]),
hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]),
message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]),

View File

@@ -26,6 +26,8 @@ import contextlib
from src.common.logger import get_logger
from src.config.config import global_config
from .hook_spec_registry import HookSpec, HookSpecRegistry
if TYPE_CHECKING:
from .component_registry import HookHandlerEntry
from .supervisor import PluginRunnerSupervisor
@@ -33,29 +35,6 @@ if TYPE_CHECKING:
logger = get_logger("plugin_runtime.host.hook_dispatcher")
@dataclass(slots=True)
class HookSpec:
"""命名 Hook 的静态规格定义。
Attributes:
name: Hook 的唯一名称。
description: Hook 描述。
default_timeout_ms: 默认超时毫秒数;为 `0` 时退回系统默认值。
allow_blocking: 是否允许注册阻塞处理器。
allow_observe: 是否允许注册观察处理器。
allow_abort: 是否允许处理器中止当前 Hook 调用。
allow_kwargs_mutation: 是否允许阻塞处理器修改 `kwargs`。
"""
name: str
description: str = ""
default_timeout_ms: int = 0
allow_blocking: bool = True
allow_observe: bool = True
allow_abort: bool = True
allow_kwargs_mutation: bool = True
@dataclass(slots=True)
class HookHandlerExecutionResult:
"""单个 HookHandler 的执行结果。
@@ -121,17 +100,19 @@ class HookDispatcher:
def __init__(
self,
supervisors_provider: Optional[Callable[[], Sequence["PluginRunnerSupervisor"]]] = None,
hook_spec_registry: Optional[HookSpecRegistry] = None,
) -> None:
"""初始化 Hook 分发器。
Args:
supervisors_provider: 可选的 Supervisor 提供器。若调用 `invoke_hook()`
时未显式传入 `supervisors`,则使用该回调获取目标 Supervisor 列表。
hook_spec_registry: 可选的 Hook 规格注册中心;留空时使用独立注册中心。
"""
self._background_tasks: Set[asyncio.Task[Any]] = set()
self._hook_specs: Dict[str, HookSpec] = {}
self._supervisors_provider = supervisors_provider
self._hook_spec_registry = hook_spec_registry or HookSpecRegistry()
async def stop(self) -> None:
"""停止分发器并取消所有未完成的观察任务。"""
@@ -148,16 +129,7 @@ class HookDispatcher:
spec: 需要注册的 Hook 规格。
"""
normalized_name = self._normalize_hook_name(spec.name)
self._hook_specs[normalized_name] = HookSpec(
name=normalized_name,
description=spec.description,
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
allow_blocking=bool(spec.allow_blocking),
allow_observe=bool(spec.allow_observe),
allow_abort=bool(spec.allow_abort),
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
)
self._hook_spec_registry.register_hook_spec(spec)
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
"""批量注册命名 Hook 规格。
@@ -180,14 +152,37 @@ class HookDispatcher:
"""
normalized_name = self._normalize_hook_name(hook_name)
if normalized_name in self._hook_specs:
return self._hook_specs[normalized_name]
registered_spec = self._hook_spec_registry.get_hook_spec(normalized_name)
if registered_spec is not None:
return registered_spec
return HookSpec(
name=normalized_name,
parameters_schema={},
default_timeout_ms=self._get_default_timeout_ms(),
)
def unregister_hook_spec(self, hook_name: str) -> bool:
"""注销指定命名 Hook 规格。
Args:
hook_name: 目标 Hook 名称。
Returns:
bool: 是否成功注销。
"""
return self._hook_spec_registry.unregister_hook_spec(hook_name)
def list_hook_specs(self) -> List[HookSpec]:
"""返回当前全部显式注册的 Hook 规格。
Returns:
List[HookSpec]: 已注册 Hook 规格列表。
"""
return self._hook_spec_registry.list_hook_specs()
async def invoke_hook(
self,
hook_name: str,

View File

@@ -0,0 +1,190 @@
"""命名 Hook 规格注册中心。"""
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence
@dataclass(slots=True)
class HookSpec:
"""命名 Hook 的静态规格定义。
Attributes:
name: Hook 的唯一名称。
description: Hook 描述。
parameters_schema: Hook 参数模型,使用对象级 JSON Schema 表示。
default_timeout_ms: 默认超时毫秒数;为 ``0`` 时退回系统默认值。
allow_blocking: 是否允许注册阻塞处理器。
allow_observe: 是否允许注册观察处理器。
allow_abort: 是否允许处理器中止当前 Hook 调用。
allow_kwargs_mutation: 是否允许阻塞处理器修改 ``kwargs``。
"""
name: str
description: str = ""
parameters_schema: Dict[str, Any] = field(default_factory=dict)
default_timeout_ms: int = 0
allow_blocking: bool = True
allow_observe: bool = True
allow_abort: bool = True
allow_kwargs_mutation: bool = True
class HookSpecRegistry:
"""命名 Hook 规格注册中心。"""
def __init__(self) -> None:
"""初始化 Hook 规格注册中心。"""
self._hook_specs: Dict[str, HookSpec] = {}
@staticmethod
def _normalize_hook_name(hook_name: str) -> str:
"""规范化 Hook 名称。
Args:
hook_name: 原始 Hook 名称。
Returns:
str: 规范化后的 Hook 名称。
Raises:
ValueError: Hook 名称为空时抛出。
"""
normalized_name = str(hook_name or "").strip()
if not normalized_name:
raise ValueError("Hook 名称不能为空")
return normalized_name
@staticmethod
def _normalize_parameters_schema(raw_schema: Any) -> Dict[str, Any]:
"""规范化 Hook 参数模型。
Args:
raw_schema: 原始参数模型。
Returns:
Dict[str, Any]: 规范化后的对象级 JSON Schema。
Raises:
ValueError: 参数模型不是合法对象级 Schema 时抛出。
"""
if raw_schema is None:
return {}
if not isinstance(raw_schema, dict):
raise ValueError("Hook 参数模型必须是字典")
if not raw_schema:
return {}
normalized_schema = deepcopy(raw_schema)
schema_type = normalized_schema.get("type")
properties = normalized_schema.get("properties")
if schema_type not in {"", None, "object"} and properties is None:
raise ValueError("Hook 参数模型必须是 object 类型或属性映射")
if schema_type in {"", None} and properties is None:
normalized_schema = {
"type": "object",
"properties": normalized_schema,
}
elif schema_type in {"", None}:
normalized_schema["type"] = "object"
if normalized_schema.get("type") != "object":
raise ValueError("Hook 参数模型必须是 object 类型")
return normalized_schema
@classmethod
def _normalize_spec(cls, spec: HookSpec) -> HookSpec:
"""规范化 Hook 规格对象。
Args:
spec: 原始 Hook 规格。
Returns:
HookSpec: 规范化后的 Hook 规格副本。
"""
return HookSpec(
name=cls._normalize_hook_name(spec.name),
description=str(spec.description or "").strip(),
parameters_schema=cls._normalize_parameters_schema(spec.parameters_schema),
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
allow_blocking=bool(spec.allow_blocking),
allow_observe=bool(spec.allow_observe),
allow_abort=bool(spec.allow_abort),
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
)
def clear(self) -> None:
"""清空全部 Hook 规格。"""
self._hook_specs.clear()
def register_hook_spec(self, spec: HookSpec) -> HookSpec:
"""注册单个 Hook 规格。
Args:
spec: 需要注册的 Hook 规格。
Returns:
HookSpec: 规范化后实际注册的 Hook 规格。
"""
normalized_spec = self._normalize_spec(spec)
self._hook_specs[normalized_spec.name] = normalized_spec
return normalized_spec
def register_hook_specs(self, specs: Sequence[HookSpec]) -> List[HookSpec]:
"""批量注册 Hook 规格。
Args:
specs: 需要注册的 Hook 规格列表。
Returns:
List[HookSpec]: 规范化后实际注册的 Hook 规格列表。
"""
return [self.register_hook_spec(spec) for spec in specs]
def unregister_hook_spec(self, hook_name: str) -> bool:
"""注销指定 Hook 规格。
Args:
hook_name: 目标 Hook 名称。
Returns:
bool: 是否成功删除。
"""
normalized_name = self._normalize_hook_name(hook_name)
return self._hook_specs.pop(normalized_name, None) is not None
def get_hook_spec(self, hook_name: str) -> Optional[HookSpec]:
"""获取指定 Hook 的显式规格。
Args:
hook_name: 目标 Hook 名称。
Returns:
Optional[HookSpec]: 已注册时返回规格副本,否则返回 ``None``。
"""
normalized_name = self._normalize_hook_name(hook_name)
spec = self._hook_specs.get(normalized_name)
return None if spec is None else self._normalize_spec(spec)
def list_hook_specs(self) -> List[HookSpec]:
"""返回当前全部 Hook 规格。
Returns:
List[HookSpec]: 按 Hook 名称升序排列的规格副本列表。
"""
return [
self._normalize_spec(spec)
for _, spec in sorted(self._hook_specs.items(), key=lambda item: item[0])
]

View File

@@ -14,6 +14,7 @@ from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, Ro
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
from src.plugin_runtime import (
ENV_BLOCKED_PLUGIN_REASONS,
ENV_EXTERNAL_PLUGIN_IDS,
ENV_GLOBAL_CONFIG_SNAPSHOT,
ENV_HOST_VERSION,
@@ -27,6 +28,8 @@ from src.plugin_runtime.protocol.envelope import (
ConfigUpdatedPayload,
Envelope,
HealthPayload,
InspectPluginConfigPayload,
InspectPluginConfigResultPayload,
MessageGatewayStateUpdatePayload,
MessageGatewayStateUpdateResultPayload,
PROTOCOL_VERSION,
@@ -39,6 +42,8 @@ from src.plugin_runtime.protocol.envelope import (
RunnerReadyPayload,
ShutdownPayload,
UnregisterPluginPayload,
ValidatePluginConfigPayload,
ValidatePluginConfigResultPayload,
)
from src.plugin_runtime.protocol.codec import MsgPackCodec
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
@@ -50,6 +55,7 @@ from .capability_service import CapabilityService
from .component_registry import ComponentRegistry
from .event_dispatcher import EventDispatcher
from .hook_dispatcher import HookDispatchResult, HookDispatcher
from .hook_spec_registry import HookSpecRegistry
from .logger_bridge import RunnerLogBridge
from .message_gateway import MessageGateway
from .rpc_server import RPCServer
@@ -59,6 +65,7 @@ if TYPE_CHECKING:
logger = get_logger("plugin_runtime.host.runner_manager")
@dataclass(slots=True)
class _MessageGatewayRuntimeState:
"""保存消息网关当前的运行时连接状态。"""
@@ -81,6 +88,7 @@ class PluginRunnerSupervisor:
self,
plugin_dirs: Optional[List[Path]] = None,
group_name: str = "third_party",
hook_spec_registry: Optional[HookSpecRegistry] = None,
socket_path: Optional[str] = None,
health_check_interval_sec: Optional[float] = None,
max_restart_attempts: Optional[int] = None,
@@ -91,6 +99,7 @@ class PluginRunnerSupervisor:
Args:
plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
group_name: 当前 Supervisor 所属运行时分组名称。
hook_spec_registry: 可选的共享 Hook 规格注册中心。
socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
health_check_interval_sec: 健康检查间隔,单位秒。
max_restart_attempts: 自动重启 Runner 的最大次数。
@@ -100,18 +109,19 @@ class PluginRunnerSupervisor:
self._group_name: str = str(group_name or "third_party").strip() or "third_party"
self._plugin_dirs: List[Path] = plugin_dirs or []
self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0
self._runner_spawn_timeout: float = (
runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0
)
self._runner_spawn_timeout: float = runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0
self._max_restart_attempts: int = max_restart_attempts or runtime_config.max_restart_attempts or 3
self._transport = create_transport_server(socket_path=socket_path)
self._authorization = AuthorizationManager()
self._capability_service = CapabilityService(self._authorization)
self._api_registry = APIRegistry()
self._component_registry = ComponentRegistry()
self._component_registry = ComponentRegistry(hook_spec_registry=hook_spec_registry)
self._event_dispatcher = EventDispatcher(self._component_registry)
self._hook_dispatcher = HookDispatcher(lambda: [self])
self._hook_dispatcher = HookDispatcher(
lambda: [self],
hook_spec_registry=hook_spec_registry,
)
self._message_gateway = MessageGateway(self._component_registry)
self._log_bridge = RunnerLogBridge()
@@ -122,6 +132,7 @@ class PluginRunnerSupervisor:
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
self._external_available_plugins: Dict[str, str] = {}
self._blocked_plugin_reasons: Dict[str, str] = {}
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -200,9 +211,19 @@ class PluginRunnerSupervisor:
Returns:
Dict[str, str]: 已注册插件版本映射,键为插件 ID值为插件版本。
"""
return {
plugin_id: registration.plugin_version
for plugin_id, registration in self._registered_plugins.items()
return {plugin_id: registration.plugin_version for plugin_id, registration in self._registered_plugins.items()}
def set_blocked_plugin_reasons(self, blocked_plugin_reasons: Dict[str, str]) -> None:
"""设置当前 Runner 启动时应拒绝加载的插件列表。
Args:
blocked_plugin_reasons: 需要拒绝加载的插件及原因映射。
"""
self._blocked_plugin_reasons = {
str(plugin_id or "").strip(): str(reason or "").strip()
for plugin_id, reason in blocked_plugin_reasons.items()
if str(plugin_id or "").strip() and str(reason or "").strip()
}
@staticmethod
@@ -550,6 +571,82 @@ class PluginRunnerSupervisor:
return bool(response.payload.get("acknowledged", False))
async def validate_plugin_config(self, plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any]:
"""请求 Runner 使用插件自身配置模型校验配置。
Args:
plugin_id: 目标插件 ID。
config_data: 待校验的配置内容。
Returns:
Dict[str, Any]: 插件模型归一化后的配置字典。
Raises:
ValueError: 插件拒绝该配置或校验失败时抛出。
"""
payload = ValidatePluginConfigPayload(config_data=config_data)
try:
response = await self._rpc_server.send_request(
"plugin.validate_config",
plugin_id=plugin_id,
payload=payload.model_dump(),
timeout_ms=10000,
)
except Exception as exc:
raise ValueError(f"插件配置校验请求失败: {exc}") from exc
if response.error:
raise ValueError(str(response.error.get("message", "插件配置校验失败")))
result = ValidatePluginConfigResultPayload.model_validate(response.payload)
if not result.success:
raise ValueError("插件配置校验失败")
return dict(result.normalized_config)
async def inspect_plugin_config(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
*,
use_provided_config: bool = False,
) -> InspectPluginConfigResultPayload:
"""请求 Runner 解析插件配置元数据。
Args:
plugin_id: 目标插件 ID。
config_data: 可选的配置内容。
use_provided_config: 是否优先使用传入配置而不是磁盘配置。
Returns:
InspectPluginConfigResultPayload: 插件配置解析结果。
Raises:
ValueError: Runner 无法解析插件或返回了错误响应时抛出。
"""
payload = InspectPluginConfigPayload(
config_data=config_data or {},
use_provided_config=use_provided_config,
)
try:
response = await self._rpc_server.send_request(
"plugin.inspect_config",
plugin_id=plugin_id,
payload=payload.model_dump(),
timeout_ms=10000,
)
except Exception as exc:
raise ValueError(f"插件配置解析请求失败: {exc}") from exc
if response.error:
raise ValueError(str(response.error.get("message", "插件配置解析失败")))
result = InspectPluginConfigResultPayload.model_validate(response.payload)
if not result.success:
raise ValueError("插件配置解析失败")
return result
def get_config_reload_subscribers(self, scope: str) -> List[str]:
"""返回订阅指定全局配置广播的插件列表。
@@ -608,6 +705,7 @@ class PluginRunnerSupervisor:
Raises:
TimeoutError: 在超时时间内 Runner 未完成初始化。
"""
async def wait_for_ready() -> RunnerReadyPayload:
"""轮询等待 Runner 上报就绪。"""
while True:
@@ -681,15 +779,25 @@ class PluginRunnerSupervisor:
component_declarations = [component.model_dump() for component in payload.components]
runtime_components, api_components = self._split_component_declarations(component_declarations)
self._component_registry.remove_components_by_plugin(payload.plugin_id)
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
try:
registered_count = self._component_registry.register_plugin_components(
payload.plugin_id,
runtime_components,
)
except Exception as exc:
logger.error(f"插件 {payload.plugin_id} 组件注册失败: {exc}")
return envelope.make_error_response(
ErrorCode.E_BAD_PAYLOAD.value,
str(exc),
details={
"plugin_id": payload.plugin_id,
"component_count": len(runtime_components),
},
)
registered_count = self._component_registry.register_plugin_components(
payload.plugin_id,
runtime_components,
)
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
self._registered_plugins[payload.plugin_id] = payload
self._message_gateway_states[payload.plugin_id] = {}
@@ -1058,7 +1166,9 @@ class PluginRunnerSupervisor:
route_key = RouteKey(platform=platform)
route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata)
account_id = route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None
account_id = (
route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None
)
scope = route_key.scope or route_scope or runtime_state.scope or gateway_entry.scope or None
return RouteKey(
platform=platform,
@@ -1208,6 +1318,7 @@ class PluginRunnerSupervisor:
global_config_snapshot = config_manager.get_global_config().model_dump(mode="json")
global_config_snapshot["model"] = config_manager.get_model_config().model_dump(mode="json")
return {
ENV_BLOCKED_PLUGIN_REASONS: json.dumps(self._blocked_plugin_reasons, ensure_ascii=False),
ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False),
ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
ENV_HOST_VERSION: PROTOCOL_VERSION,

View File

@@ -8,10 +8,25 @@
5. 提供统一的能力实现注册接口,使插件可以调用主程序功能
"""
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
)
import asyncio
import inspect
import tomlkit
@@ -23,10 +38,15 @@ from src.plugin_runtime.capabilities import (
RuntimeComponentCapabilityMixin,
RuntimeCoreCapabilityMixin,
RuntimeDataCapabilityMixin,
RuntimeRenderCapabilityMixin,
)
from src.plugin_runtime.capabilities.registry import register_capability_impls
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher, HookSpec
from src.plugin_runtime.dependency_pipeline import PluginDependencyPipeline
from src.plugin_runtime.hook_catalog import register_builtin_hook_specs
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
from src.plugin_runtime.protocol.envelope import InspectPluginConfigResultPayload
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
if TYPE_CHECKING:
@@ -50,10 +70,19 @@ _EVENT_TYPE_MAP: Dict[str, str] = {
}
@dataclass(frozen=True)
class DependencySyncState:
"""表示一次插件依赖同步后的状态。"""
blocked_changed_plugin_ids: Set[str]
environment_changed: bool
class PluginRuntimeManager(
RuntimeCoreCapabilityMixin,
RuntimeDataCapabilityMixin,
RuntimeComponentCapabilityMixin,
RuntimeRenderCapabilityMixin,
):
"""插件运行时管理器(单例)
@@ -71,10 +100,17 @@ class PluginRuntimeManager(
self._plugin_source_watcher_subscription_id: Optional[str] = None
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
self._plugin_path_cache: Dict[str, Path] = {}
self._manifest_validator: ManifestValidator = ManifestValidator()
self._manifest_validator: ManifestValidator = ManifestValidator(validate_python_package_dependencies=False)
self._plugin_dependency_pipeline: PluginDependencyPipeline = PluginDependencyPipeline()
self._blocked_plugin_reasons: Dict[str, str] = {}
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
self._config_reload_callback_registered: bool = False
self._hook_dispatcher: HookDispatcher = HookDispatcher(lambda: self.supervisors)
self._hook_spec_registry: HookSpecRegistry = HookSpecRegistry()
self._builtin_hook_specs_registered: bool = False
self._hook_dispatcher: HookDispatcher = HookDispatcher(
lambda: self.supervisors,
hook_spec_registry=self._hook_spec_registry,
)
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
"""接收 Platform IO 审核后的入站消息并送入主消息链。
@@ -109,7 +145,7 @@ class PluginRuntimeManager(
@classmethod
def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
"""扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
validator = ManifestValidator()
validator = ManifestValidator(validate_python_package_dependencies=False)
return validator.build_plugin_dependency_map(plugin_dirs)
@classmethod
@@ -142,6 +178,233 @@ class PluginRuntimeManager(
return ["third_party", "builtin"]
return ["builtin", "third_party"]
@staticmethod
def _instantiate_supervisor(supervisor_cls: Any, **kwargs: Any) -> Any:
"""兼容不同构造签名地实例化 Supervisor。
Args:
supervisor_cls: 目标 Supervisor 类。
**kwargs: 期望传入的构造参数。
Returns:
Any: 实例化后的 Supervisor。
"""
signature = inspect.signature(supervisor_cls)
accepts_var_keyword = any(
parameter.kind == inspect.Parameter.VAR_KEYWORD
for parameter in signature.parameters.values()
)
if accepts_var_keyword:
return supervisor_cls(**kwargs)
supported_kwargs = {
key: value
for key, value in kwargs.items()
if key in signature.parameters
}
return supervisor_cls(**supported_kwargs)
def _resolve_runtime_plugin_dirs(self) -> Tuple[List[Path], List[Path]]:
"""解析当前运行时应管理的插件根目录。
Returns:
Tuple[List[Path], List[Path]]: 内置插件目录列表与第三方插件目录列表。
"""
return self._get_builtin_plugin_dirs(), self._get_third_party_plugin_dirs()
@staticmethod
def _resolve_supervisor_socket_paths() -> Tuple[Optional[str], Optional[str]]:
"""解析内置与第三方 Supervisor 的 IPC 地址。
Returns:
Tuple[Optional[str], Optional[str]]: 内置 Runner 与第三方 Runner 的 socket 地址。
"""
runtime_config = config_manager.get_global_config().plugin_runtime
socket_path_base = runtime_config.ipc_socket_path or None
builtin_socket = f"{socket_path_base}-builtin" if socket_path_base else None
third_party_socket = f"{socket_path_base}-third_party" if socket_path_base else None
return builtin_socket, third_party_socket
def _apply_blocked_plugin_reasons_to_supervisors(self) -> None:
"""将当前阻止加载插件列表同步到全部 Supervisor。"""
for supervisor in self.supervisors:
set_blocked_plugin_reasons = getattr(supervisor, "set_blocked_plugin_reasons", None)
if callable(set_blocked_plugin_reasons):
set_blocked_plugin_reasons(self._blocked_plugin_reasons)
def _set_blocked_plugin_reasons(self, blocked_plugin_reasons: Dict[str, str]) -> Set[str]:
"""更新 Host 侧维护的阻止加载插件列表。
Args:
blocked_plugin_reasons: 最新的阻止加载插件及原因映射。
Returns:
Set[str]: 本次发生状态变化的插件 ID 集合。
"""
normalized_reasons = {
str(plugin_id or "").strip(): str(reason or "").strip()
for plugin_id, reason in blocked_plugin_reasons.items()
if str(plugin_id or "").strip() and str(reason or "").strip()
}
changed_plugin_ids = {
plugin_id
for plugin_id in set(self._blocked_plugin_reasons) | set(normalized_reasons)
if self._blocked_plugin_reasons.get(plugin_id) != normalized_reasons.get(plugin_id)
}
self._blocked_plugin_reasons = normalized_reasons
self._apply_blocked_plugin_reasons_to_supervisors()
return changed_plugin_ids
async def _sync_plugin_dependencies(self, plugin_dirs: Sequence[Path]) -> DependencySyncState:
"""执行插件依赖同步,并刷新阻止加载插件列表。
Args:
plugin_dirs: 当前需要参与分析的插件根目录列表。
Returns:
DependencySyncState: 同步后的环境变更状态与阻止列表变化集合。
"""
result = await self._plugin_dependency_pipeline.execute(plugin_dirs)
changed_plugin_ids = self._set_blocked_plugin_reasons(result.blocked_plugin_reasons)
return DependencySyncState(
blocked_changed_plugin_ids=changed_plugin_ids,
environment_changed=result.environment_changed,
)
def _build_supervisors(self, builtin_dirs: Sequence[Path], third_party_dirs: Sequence[Path]) -> None:
"""根据目录列表创建当前运行时所需的 Supervisor。
Args:
builtin_dirs: 内置插件目录列表。
third_party_dirs: 第三方插件目录列表。
"""
from src.plugin_runtime.host.supervisor import PluginSupervisor
builtin_socket, third_party_socket = self._resolve_supervisor_socket_paths()
self._builtin_supervisor = None
self._third_party_supervisor = None
if builtin_dirs:
builtin_supervisor = self._instantiate_supervisor(
PluginSupervisor,
plugin_dirs=list(builtin_dirs),
group_name="builtin",
hook_spec_registry=self._hook_spec_registry,
socket_path=builtin_socket,
)
self._builtin_supervisor = builtin_supervisor
self._register_capability_impls(builtin_supervisor)
if third_party_dirs:
third_party_supervisor = self._instantiate_supervisor(
PluginSupervisor,
plugin_dirs=list(third_party_dirs),
group_name="third_party",
hook_spec_registry=self._hook_spec_registry,
socket_path=third_party_socket,
)
self._third_party_supervisor = third_party_supervisor
self._register_capability_impls(third_party_supervisor)
self._apply_blocked_plugin_reasons_to_supervisors()
async def _start_supervisors(
self,
builtin_dirs: Sequence[Path],
third_party_dirs: Sequence[Path],
) -> List["PluginSupervisor"]:
"""按依赖顺序启动当前已创建的 Supervisor。
Args:
builtin_dirs: 内置插件目录列表。
third_party_dirs: 第三方插件目录列表。
Returns:
List[PluginSupervisor]: 成功启动的 Supervisor 列表。
"""
started_supervisors: List["PluginSupervisor"] = []
supervisor_groups: Dict[str, Optional["PluginSupervisor"]] = {
"builtin": self._builtin_supervisor,
"third_party": self._third_party_supervisor,
}
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
try:
for group_name in start_order:
supervisor = supervisor_groups.get(group_name)
if supervisor is None:
continue
external_plugin_versions = {
plugin_id: plugin_version
for started_supervisor in started_supervisors
for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
}
supervisor.set_external_available_plugins(external_plugin_versions)
set_blocked_plugin_reasons = getattr(supervisor, "set_blocked_plugin_reasons", None)
if callable(set_blocked_plugin_reasons):
set_blocked_plugin_reasons(self._blocked_plugin_reasons)
await supervisor.start()
started_supervisors.append(supervisor)
except Exception:
await asyncio.gather(*(supervisor.stop() for supervisor in started_supervisors), return_exceptions=True)
raise
return started_supervisors
async def _stop_supervisors(self) -> None:
"""停止当前全部 Supervisor。"""
supervisors = self.supervisors
if not supervisors:
return
await asyncio.gather(*(supervisor.stop() for supervisor in supervisors), return_exceptions=True)
self._builtin_supervisor = None
self._third_party_supervisor = None
async def _restart_supervisors(self, reason: str) -> bool:
"""重启当前全部 Supervisor。
Args:
reason: 本次重启的原因。
Returns:
bool: 是否重启成功。
"""
builtin_dirs, third_party_dirs = self._resolve_runtime_plugin_dirs()
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(builtin_dirs + third_party_dirs):
details = "; ".join(
f"{plugin_id}: {', '.join(str(path) for path in paths)}"
for plugin_id, paths in sorted(duplicate_plugin_ids.items())
)
logger.error(f"检测到重复插件 ID拒绝执行 Supervisor 重启: {details}")
return False
logger.info(f"开始重启插件运行时 Supervisor: {reason}")
await self._stop_supervisors()
self._build_supervisors(builtin_dirs, third_party_dirs)
try:
await self._start_supervisors(builtin_dirs, third_party_dirs)
except Exception as exc:
logger.error(f"重启插件运行时 Supervisor 失败: {exc}", exc_info=True)
await self._stop_supervisors()
return False
self._refresh_plugin_config_watch_subscriptions()
logger.info(f"插件运行时 Supervisor 已重启完成: {reason}")
return True
# ─── 生命周期 ─────────────────────────────────────────────
async def start(self) -> None:
@@ -155,10 +418,7 @@ class PluginRuntimeManager(
logger.info("插件运行时已在配置中禁用,跳过启动")
return
from src.plugin_runtime.host.supervisor import PluginSupervisor
builtin_dirs = self._get_builtin_plugin_dirs()
third_party_dirs = self._get_third_party_plugin_dirs()
builtin_dirs, third_party_dirs = self._resolve_runtime_plugin_dirs()
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(builtin_dirs + third_party_dirs):
details = "; ".join(
@@ -172,56 +432,19 @@ class PluginRuntimeManager(
logger.info("未找到任何插件目录,跳过插件运行时启动")
return
dependency_sync_state = await self._sync_plugin_dependencies(builtin_dirs + third_party_dirs)
if dependency_sync_state.environment_changed:
logger.info("插件依赖流水线已更新当前 Python 环境,启动时将直接加载最新环境")
self.ensure_builtin_hook_specs_registered()
platform_io_manager = get_platform_io_manager()
self._build_supervisors(builtin_dirs, third_party_dirs)
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
socket_path_base = _cfg.ipc_socket_path or None
# 当用户指定了自定义路径时,为两个 Supervisor 添加后缀以避免 UDS 冲突
builtin_socket = f"{socket_path_base}-builtin" if socket_path_base else None
third_party_socket = f"{socket_path_base}-third_party" if socket_path_base else None
# 创建两个 Supervisor各自拥有独立的 socket / Runner 子进程
if builtin_dirs:
self._builtin_supervisor = PluginSupervisor(
plugin_dirs=builtin_dirs,
group_name="builtin",
socket_path=builtin_socket,
)
self._register_capability_impls(self._builtin_supervisor)
if third_party_dirs:
self._third_party_supervisor = PluginSupervisor(
plugin_dirs=third_party_dirs,
group_name="third_party",
socket_path=third_party_socket,
)
self._register_capability_impls(self._third_party_supervisor)
started_supervisors: List[PluginSupervisor] = []
started_supervisors: List["PluginSupervisor"] = []
try:
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
await platform_io_manager.ensure_send_pipeline_ready()
supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
"builtin": self._builtin_supervisor,
"third_party": self._third_party_supervisor,
}
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
for group_name in start_order:
supervisor = supervisor_groups.get(group_name)
if supervisor is None:
continue
external_plugin_versions = {
plugin_id: plugin_version
for started_supervisor in started_supervisors
for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
}
supervisor.set_external_available_plugins(external_plugin_versions)
await supervisor.start()
started_supervisors.append(supervisor)
started_supervisors = await self._start_supervisors(builtin_dirs, third_party_dirs)
await self._start_plugin_file_watcher()
config_manager.register_reload_callback(self._config_reload_callback)
@@ -315,6 +538,7 @@ class PluginRuntimeManager(
spec: 需要注册的 Hook 规格。
"""
self.ensure_builtin_hook_specs_registered()
self._hook_dispatcher.register_hook_spec(spec)
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
@@ -324,8 +548,41 @@ class PluginRuntimeManager(
specs: 需要注册的 Hook 规格序列。
"""
self.ensure_builtin_hook_specs_registered()
self._hook_dispatcher.register_hook_specs(specs)
def unregister_hook_spec(self, hook_name: str) -> bool:
"""注销指定命名 Hook 规格。
Args:
hook_name: 目标 Hook 名称。
Returns:
bool: 是否成功注销。
"""
self.ensure_builtin_hook_specs_registered()
return self._hook_dispatcher.unregister_hook_spec(hook_name)
def list_hook_specs(self) -> List[HookSpec]:
"""返回当前全部命名 Hook 规格。
Returns:
List[HookSpec]: 当前已注册的 Hook 规格列表。
"""
self.ensure_builtin_hook_specs_registered()
return self._hook_dispatcher.list_hook_specs()
def ensure_builtin_hook_specs_registered(self) -> None:
"""确保内置 Hook 规格已经注册到共享中心表。"""
if self._builtin_hook_specs_registered:
return
register_builtin_hook_specs(self._hook_spec_registry)
self._builtin_hook_specs_registered = True
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
"""根据当前已注册插件构建全局依赖图。"""
@@ -364,9 +621,7 @@ class PluginRuntimeManager(
"""构建当前已注册插件到所属 Supervisor 的映射。"""
return {
plugin_id: supervisor
for supervisor in self.supervisors
for plugin_id in supervisor.get_loaded_plugin_ids()
plugin_id: supervisor for supervisor in self.supervisors for plugin_id in supervisor.get_loaded_plugin_ids()
}
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
@@ -411,9 +666,7 @@ class PluginRuntimeManager(
local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
local_dependency_map = {
plugin_id: {
dependency
for dependency in dependency_map.get(plugin_id, set())
if dependency in local_plugin_ids
dependency for dependency in dependency_map.get(plugin_id, set()) if dependency in local_plugin_ids
}
for plugin_id in local_plugin_ids
}
@@ -440,13 +693,26 @@ class PluginRuntimeManager(
"""
normalized_plugin_ids = [
normalized_plugin_id
for plugin_id in plugin_ids
if (normalized_plugin_id := str(plugin_id or "").strip())
normalized_plugin_id for plugin_id in plugin_ids if (normalized_plugin_id := str(plugin_id or "").strip())
]
if not normalized_plugin_ids:
return True
blocked_plugin_ids = [plugin_id for plugin_id in normalized_plugin_ids if plugin_id in self._blocked_plugin_reasons]
if blocked_plugin_ids:
logger.warning(
"以下插件当前被依赖流水线阻止加载,已拒绝重载请求: "
+ ", ".join(
f"{plugin_id} ({self._blocked_plugin_reasons[plugin_id]})"
for plugin_id in sorted(blocked_plugin_ids)
)
)
normalized_plugin_ids = [
plugin_id for plugin_id in normalized_plugin_ids if plugin_id not in self._blocked_plugin_reasons
]
if not normalized_plugin_ids:
return False
dependency_map = self._build_registered_dependency_map()
supervisor_by_plugin = self._build_registered_supervisor_map()
supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
@@ -518,9 +784,7 @@ class PluginRuntimeManager(
return False
config_payload = (
config_data
if config_data is not None
else self._load_plugin_config_for_supervisor(sv, plugin_id)
config_data if config_data is not None else self._load_plugin_config_for_supervisor(sv, plugin_id)
)
return await sv.notify_plugin_config_updated(
plugin_id=plugin_id,
@@ -529,6 +793,91 @@ class PluginRuntimeManager(
config_scope=config_scope,
)
async def validate_plugin_config(self, plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None:
"""请求运行时按插件自身配置模型校验配置。
Args:
plugin_id: 目标插件 ID。
config_data: 待校验的配置内容。
Returns:
Dict[str, Any] | None: 校验成功时返回规范化后的配置;若插件不存在、
当前不可路由或运行时不可用,则返回 ``None`` 以便调用方回退到弱推断方案。
Raises:
ValueError: 插件已加载,但配置校验失败时抛出。
"""
if not self._started:
return None
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
logger.warning(f"插件 {plugin_id} 配置校验路由失败,将回退到静态 Schema: {exc}")
return None
if supervisor is None:
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
if supervisor is None:
return None
try:
return await supervisor.validate_plugin_config(plugin_id, config_data)
except ValueError:
raise
except Exception as exc:
logger.warning(f"插件 {plugin_id} 运行时配置校验不可用,将回退到静态 Schema: {exc}")
return None
async def inspect_plugin_config(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
*,
use_provided_config: bool = False,
) -> InspectPluginConfigResultPayload | None:
"""请求运行时解析插件配置元数据。
Args:
plugin_id: 目标插件 ID。
config_data: 可选的配置内容。
use_provided_config: 是否优先使用传入的配置内容而不是磁盘配置。
Returns:
InspectPluginConfigResultPayload | None: 解析成功时返回结构化结果;若插件
当前不可路由或运行时不可用,则返回 ``None``。
Raises:
ValueError: 插件存在,但运行时明确拒绝解析请求时抛出。
"""
if not self._started:
return None
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
logger.warning(f"插件 {plugin_id} 配置解析路由失败: {exc}")
return None
if supervisor is None:
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
if supervisor is None:
return None
try:
return await supervisor.inspect_plugin_config(
plugin_id=plugin_id,
config_data=config_data,
use_provided_config=use_provided_config,
)
except ValueError:
raise
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置解析不可用: {exc}")
return None
@staticmethod
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
"""规范化配置热重载范围列表。
@@ -731,11 +1080,25 @@ class PluginRuntimeManager(
return matches[0] if matches else None
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。
Args:
plugin_id: 目标插件 ID。
reason: 加载或重载原因。
Returns:
bool: 插件最终是否处于已加载状态。
"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
return False
if normalized_plugin_id in self._blocked_plugin_reasons:
logger.warning(
f"插件 {normalized_plugin_id} 当前被依赖流水线阻止加载: "
f"{self._blocked_plugin_reasons[normalized_plugin_id]}"
)
return False
try:
registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
@@ -749,17 +1112,18 @@ class PluginRuntimeManager(
if supervisor is None:
return False
return await supervisor.reload_plugins(
reloaded = await supervisor.reload_plugins(
plugin_ids=[normalized_plugin_id],
reason=reason,
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
)
return reloaded and normalized_plugin_id in supervisor.get_loaded_plugin_ids()
@classmethod
def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
plugin_locations: Dict[str, List[Path]] = {}
validator = ManifestValidator()
validator = ManifestValidator(validate_python_package_dependencies=False)
for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs):
plugin_locations.setdefault(manifest.id, []).append(plugin_path)
@@ -869,7 +1233,9 @@ class PluginRuntimeManager(
if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
return cached_path
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(
getattr(supervisor, "_plugin_dirs", [])
):
if candidate_plugin_id != plugin_id:
continue
self._plugin_path_cache[plugin_id] = plugin_path
@@ -878,15 +1244,16 @@ class PluginRuntimeManager(
return None
def _refresh_plugin_config_watch_subscriptions(self) -> None:
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
"""按当前可识别插件集合刷新 config.toml 的单插件订阅。
当插件热重载后,插件集合或目录位置可能发生变化,因此需要重新对齐
watcher 的订阅,确保每个插件配置变更只触发对应 plugin_id。
这里不仅覆盖当前已注册插件,也覆盖已存在但暂未激活的合法插件。
"""
if self._plugin_file_watcher is None:
return
desired_plugin_paths = dict(self._iter_registered_plugin_paths())
desired_plugin_paths = dict(self._iter_watchable_plugin_paths())
self._plugin_path_cache = desired_plugin_paths.copy()
desired_config_paths = {
plugin_id: self._resolve_plugin_config_path(plugin_id, plugin_path)
@@ -909,9 +1276,7 @@ class PluginRuntimeManager(
)
self._plugin_config_watcher_subscriptions[plugin_id] = (config_path, subscription_id)
def _build_plugin_config_change_callback(
self, plugin_id: str
) -> Callable[[Sequence[FileChange]], Awaitable[None]]:
def _build_plugin_config_change_callback(self, plugin_id: str) -> Callable[[Sequence[FileChange]], Awaitable[None]]:
"""为指定插件生成配置文件变更回调。"""
async def _callback(changes: Sequence[FileChange]) -> None:
@@ -931,6 +1296,18 @@ class PluginRuntimeManager(
if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
yield plugin_id, plugin_path
def _iter_watchable_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
"""迭代应被配置监听器追踪的插件目录。
Returns:
Iterable[Tuple[str, Path]]: ``(plugin_id, plugin_path)`` 迭代器。
"""
watchable_plugin_paths = dict(self._iter_discovered_plugin_paths(self._iter_plugin_dirs()))
for plugin_id, plugin_path in self._iter_registered_plugin_paths():
watchable_plugin_paths.setdefault(plugin_id, plugin_path)
yield from watchable_plugin_paths.items()
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
@@ -958,18 +1335,43 @@ class PluginRuntimeManager(
return
if supervisor is None:
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
if supervisor is None:
return
plugin_is_loaded = plugin_id in getattr(supervisor, "_registered_plugins", {})
try:
snapshot = await supervisor.inspect_plugin_config(plugin_id)
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置文件变更解析失败: {exc}")
return
try:
config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=config_payload,
config_version="",
config_scope="self",
)
if not delivered:
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
if plugin_is_loaded and snapshot.enabled:
delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=dict(snapshot.normalized_config),
config_version="",
config_scope="self",
)
if not delivered:
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
return
if plugin_is_loaded and not snapshot.enabled:
reloaded = await self.reload_plugins_globally([plugin_id], reason="config_disabled")
if not reloaded:
logger.warning(f"插件 {plugin_id} 禁用配置已写入,但运行时卸载失败")
return
if not snapshot.enabled:
logger.info(f"插件 {plugin_id} 当前处于禁用状态,跳过自动加载")
return
loaded = await self.load_plugin_globally(plugin_id, reason="config_enabled")
if not loaded:
logger.warning(f"插件 {plugin_id} 配置文件变更后自动加载失败")
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
@@ -983,7 +1385,8 @@ class PluginRuntimeManager(
if not self._started or not changes:
return
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(list(self._iter_plugin_dirs())):
plugin_dirs = list(self._iter_plugin_dirs())
if duplicate_plugin_ids := self._find_duplicate_plugin_ids(plugin_dirs):
details = "; ".join(
f"{plugin_id}: {', '.join(str(path) for path in paths)}"
for plugin_id, paths in sorted(duplicate_plugin_ids.items())
@@ -991,21 +1394,24 @@ class PluginRuntimeManager(
logger.error(f"检测到重复插件 ID跳过本次插件热重载: {details}")
return
changed_plugin_ids: List[str] = []
changed_paths = [change.path.resolve() for change in changes]
relevant_source_changes = [
change.path.resolve()
for change in changes
if change.path.name in {"plugin.py", "_manifest.json"} or change.path.suffix == ".py"
]
if not relevant_source_changes:
return
for supervisor in self.supervisors:
for path in changed_paths:
plugin_id = self._match_plugin_id_for_supervisor(supervisor, path)
if plugin_id is None:
continue
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
if plugin_id not in changed_plugin_ids:
changed_plugin_ids.append(plugin_id)
dependency_sync_state = await self._sync_plugin_dependencies(plugin_dirs)
restart_reason = "file_watcher"
if dependency_sync_state.environment_changed:
restart_reason = "file_watcher_dependency_install"
elif dependency_sync_state.blocked_changed_plugin_ids:
restart_reason = "file_watcher_blocklist_changed"
if changed_plugin_ids:
await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
self._refresh_plugin_config_watch_subscriptions()
restarted = await self._restart_supervisors(restart_reason)
if not restarted:
logger.warning(f"插件源码变更后重启 Supervisor 失败: {restart_reason}")
@staticmethod
def _plugin_dir_matches(path: Path, plugin_dir: Path) -> bool:
@@ -1023,7 +1429,10 @@ class PluginRuntimeManager(
return plugin_id
for plugin_id, plugin_path in self._plugin_path_cache.items():
if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
if not any(
self._plugin_dir_matches(plugin_path, Path(plugin_dir))
for plugin_dir in getattr(supervisor, "_plugin_dirs", [])
):
continue
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
return plugin_id

View File

@@ -1,7 +1,7 @@
"""RPC Envelope 消息模型
"""RPC Envelope 消息模型
定义 Host 与 Runner 之间所有 RPC 消息的统一信封格式。
使用 Pydantic 进行 schema 定义与校验。
使用 Pydantic 进行 Schema 定义与校验。
"""
from enum import Enum
@@ -39,12 +39,23 @@ class ConfigReloadScope(str, Enum):
# ====== 请求 ID 生成器 ======
class RequestIdGenerator:
"""单调递增 int64 请求 ID 生成器"""
"""单调递增 int64 请求 ID 生成器"""
def __init__(self, start: int = 1) -> None:
"""初始化请求 ID 生成器。
Args:
start: 起始请求 ID。
"""
self._counter = start
async def next(self) -> int:
"""返回下一个请求 ID。
Returns:
int: 下一个可用的请求 ID。
"""
current = self._counter
self._counter += 1
return current
@@ -52,7 +63,7 @@ class RequestIdGenerator:
# ====== Envelope 模型 ======
class Envelope(BaseModel):
"""RPC 统一消息封装
"""RPC 统一消息封装
所有 Host <-> Runner 消息均封装为此格式。
序列化流程Envelope -> .model_dump() -> MsgPack encode
@@ -79,18 +90,44 @@ class Envelope(BaseModel):
"""错误信息 (仅 response)"""
def is_request(self) -> bool:
"""判断当前信封是否为请求消息。
Returns:
bool: 当前消息类型是否为 ``REQUEST``。
"""
return self.message_type == MessageType.REQUEST
def is_response(self) -> bool:
"""判断当前信封是否为响应消息。
Returns:
bool: 当前消息类型是否为 ``RESPONSE``。
"""
return self.message_type == MessageType.RESPONSE
def is_broadcast(self) -> bool:
"""判断当前信封是否为广播消息。
Returns:
bool: 当前消息类型是否为 ``BROADCAST``。
"""
return self.message_type == MessageType.BROADCAST
def make_response(
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
) -> "Envelope":
"""基于当前请求创建对应的响应信封"""
"""基于当前请求创建对应的响应信封
Args:
payload: 响应业务载荷。
error: 响应错误信息。
Returns:
Envelope: 对应的响应信封。
"""
return Envelope(
protocol_version=self.protocol_version,
request_id=self.request_id,
@@ -102,7 +139,16 @@ class Envelope(BaseModel):
)
def make_error_response(self, code: str, message: str = "", details: Optional[Dict[str, Any]] = None) -> "Envelope":
"""基于当前请求创建错误响应"""
"""基于当前请求创建错误响应
Args:
code: 错误码。
message: 错误描述。
details: 详细错误信息。
Returns:
Envelope: 错误响应信封。
"""
return self.make_response(
error={
"code": code,
@@ -141,9 +187,7 @@ class ComponentDeclaration(BaseModel):
name: str = Field(description="组件名称")
"""组件名称"""
component_type: str = Field(
description="组件类型action/command/tool/event_handler/hook_handler/message_gateway"
)
component_type: str = Field(description="组件类型action/command/tool/event_handler/hook_handler/message_gateway")
"""组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
plugin_id: str = Field(description="所属插件 ID")
"""所属插件 ID"""
@@ -170,6 +214,10 @@ class RegisterPluginPayload(BaseModel):
"""插件级依赖插件 ID 列表"""
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
"""订阅的全局配置热重载范围"""
default_config: Dict[str, Any] = Field(default_factory=dict, description="插件默认配置")
"""插件默认配置"""
config_schema: Dict[str, Any] = Field(default_factory=dict, description="插件配置 Schema")
"""插件配置 Schema"""
class BootstrapPluginPayload(BaseModel):
@@ -240,6 +288,8 @@ class RunnerReadyPayload(BaseModel):
"""已完成初始化的插件列表"""
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
"""初始化失败的插件列表"""
inactive_plugins: List[str] = Field(default_factory=list, description="当前因禁用或依赖不可用而未激活的插件列表")
"""当前因禁用或依赖不可用而未激活的插件列表"""
# ====== 配置更新 ======
@@ -256,6 +306,50 @@ class ConfigUpdatedPayload(BaseModel):
"""配置内容"""
class ValidatePluginConfigPayload(BaseModel):
"""plugin.validate_config 请求 payload。"""
config_data: Dict[str, Any] = Field(default_factory=dict, description="待校验的配置内容")
"""待校验的配置内容"""
class InspectPluginConfigPayload(BaseModel):
"""plugin.inspect_config 请求 payload。"""
config_data: Dict[str, Any] = Field(default_factory=dict, description="可选的配置内容")
"""可选的配置内容"""
use_provided_config: bool = Field(default=False, description="是否优先使用请求中携带的配置内容")
"""是否优先使用请求中携带的配置内容"""
class InspectPluginConfigResultPayload(BaseModel):
"""plugin.inspect_config 响应 payload。"""
success: bool = Field(description="是否解析成功")
"""是否解析成功"""
default_config: Dict[str, Any] = Field(default_factory=dict, description="插件默认配置")
"""插件默认配置"""
config_schema: Dict[str, Any] = Field(default_factory=dict, description="插件配置 Schema")
"""插件配置 Schema"""
normalized_config: Dict[str, Any] = Field(default_factory=dict, description="归一化后的配置内容")
"""归一化后的配置内容"""
changed: bool = Field(default=False, description="是否在归一化过程中自动补齐或修正了配置")
"""是否在归一化过程中自动补齐或修正了配置"""
enabled: bool = Field(default=True, description="插件在当前配置下是否应被视为启用")
"""插件在当前配置下是否应被视为启用"""
class ValidatePluginConfigResultPayload(BaseModel):
"""plugin.validate_config 响应 payload。"""
success: bool = Field(description="是否校验成功")
"""是否校验成功"""
normalized_config: Dict[str, Any] = Field(default_factory=dict, description="校验后的规范化配置")
"""校验后的规范化配置"""
changed: bool = Field(default=False, description="是否在校验过程中自动补齐或归一化")
"""是否在校验过程中自动补齐或归一化"""
# ====== 关停 ======
class ShutdownPayload(BaseModel):
"""plugin.shutdown / plugin.prepare_shutdown payload"""
@@ -314,6 +408,8 @@ class ReloadPluginResultPayload(BaseModel):
"""成功完成重载的插件列表"""
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
"""本次已卸载的插件列表"""
inactive_plugins: List[str] = Field(default_factory=list, description="本次处于未激活状态的插件列表")
"""本次处于未激活状态的插件列表"""
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
"""重载失败的插件及原因"""
@@ -329,6 +425,8 @@ class ReloadPluginsResultPayload(BaseModel):
"""成功完成重载的插件列表"""
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
"""本次已卸载的插件列表"""
inactive_plugins: List[str] = Field(default_factory=list, description="本次处于未激活状态的插件列表")
"""本次处于未激活状态的插件列表"""
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
"""重载失败的插件及原因"""

View File

@@ -609,6 +609,7 @@ class ManifestValidator:
host_version: str = "",
sdk_version: str = "",
project_root: Optional[Path] = None,
validate_python_package_dependencies: bool = True,
) -> None:
"""初始化 Manifest 校验器。
@@ -616,10 +617,12 @@ class ManifestValidator:
host_version: 当前 Host 版本号;留空时自动从主程序 ``pyproject.toml`` 读取。
sdk_version: 当前 SDK 版本号;留空时自动从运行环境中探测。
project_root: 项目根目录;留空时自动推断。
validate_python_package_dependencies: 是否校验 Python 包依赖与当前环境的关系。
"""
self._project_root: Path = project_root or self._resolve_project_root()
self._host_version: str = host_version or self._detect_default_host_version(self._project_root)
self._sdk_version: str = sdk_version or self._detect_default_sdk_version(self._project_root)
self._validate_python_package_dependencies: bool = validate_python_package_dependencies
self.errors: List[str] = []
self.warnings: List[str] = []
@@ -823,9 +826,10 @@ class ManifestValidator:
if not sdk_ok:
self.errors.append(f"SDK 版本不兼容: {sdk_message} (当前 SDK: {self._sdk_version})")
self._validate_python_package_dependencies(manifest)
if self._validate_python_package_dependencies:
self._validate_python_package_dependencies_against_runtime(manifest)
def _validate_python_package_dependencies(self, manifest: PluginManifest) -> None:
def _validate_python_package_dependencies_against_runtime(self, manifest: PluginManifest) -> None:
"""校验 Python 包依赖与主程序运行环境是否冲突。
Args:
@@ -865,6 +869,68 @@ class ManifestValidator:
f"主程序依赖约束为 {host_specifier or '任意版本'}"
)
def load_host_dependency_requirements(self) -> Dict[str, Requirement]:
"""读取主程序在 ``pyproject.toml`` 中声明的依赖约束。
Returns:
Dict[str, Requirement]: 以规范化包名为键的依赖约束映射。
"""
return self._load_host_dependency_requirements(self._project_root)
def get_installed_package_version(self, package_name: str) -> Optional[str]:
"""查询当前运行环境中指定包的安装版本。
Args:
package_name: 需要查询的包名。
Returns:
Optional[str]: 已安装版本号;未安装时返回 ``None``。
"""
return self._get_installed_package_version(package_name)
@staticmethod
def build_specifier_set(version_spec: str) -> Optional[SpecifierSet]:
"""将版本约束文本转换为 ``SpecifierSet``。
Args:
version_spec: 原始版本约束文本。
Returns:
Optional[SpecifierSet]: 转换成功时返回约束对象,否则返回 ``None``。
"""
return ManifestValidator._build_specifier_set(version_spec)
@staticmethod
def version_matches_specifier(version: str, version_spec: str) -> bool:
"""判断版本号是否满足给定约束。
Args:
version: 待判断的版本号。
version_spec: 版本约束表达式。
Returns:
bool: 是否满足约束。
"""
return ManifestValidator._version_matches_specifier(version, version_spec)
@classmethod
def requirements_may_overlap(cls, left: SpecifierSet, right: SpecifierSet) -> bool:
"""判断两个版本约束是否可能存在交集。
Args:
left: 左侧版本约束。
right: 右侧版本约束。
Returns:
bool: 若两者可能同时满足则返回 ``True``。
"""
return cls._requirements_may_overlap(left, right)
def _log_errors(self) -> None:
"""输出当前累计的 Manifest 校验错误。"""
for error_message in self.errors:

View File

@@ -75,6 +75,35 @@ class PluginLoader:
self._failed_plugins: Dict[str, str] = {}
self._manifest_validator = ManifestValidator(host_version=host_version)
self._compat_hook_installed = False
self._blocked_plugin_reasons: Dict[str, str] = {}
def set_blocked_plugin_reasons(self, blocked_plugin_reasons: Optional[Dict[str, str]] = None) -> None:
"""更新当前加载器持有的拒绝加载插件列表。
Args:
blocked_plugin_reasons: 需要拒绝加载的插件及原因映射。
"""
self._blocked_plugin_reasons = {
str(plugin_id or "").strip(): str(reason or "").strip()
for plugin_id, reason in (blocked_plugin_reasons or {}).items()
if str(plugin_id or "").strip() and str(reason or "").strip()
}
def get_blocked_plugin_reason(self, plugin_id: str) -> Optional[str]:
"""返回指定插件当前的拒绝加载原因。
Args:
plugin_id: 目标插件 ID。
Returns:
Optional[str]: 若插件被阻止加载则返回原因,否则返回 ``None``。
"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
return None
return self._blocked_plugin_reasons.get(normalized_plugin_id)
def discover_and_load(
self,
@@ -156,6 +185,11 @@ class PluginLoader:
return None
plugin_id = manifest.id
if blocked_reason := self.get_blocked_plugin_reason(plugin_id):
self._failed_plugins[plugin_id] = blocked_reason
logger.warning(f"插件 {plugin_id} 已被 Host 依赖流水线阻止加载: {blocked_reason}")
return None
return plugin_id, (plugin_dir, manifest, plugin_path)
def _record_duplicate_candidates(self, duplicate_candidates: Dict[str, List[Path]]) -> None:

View File

@@ -9,8 +9,10 @@
6. 转发插件的能力调用到 Host
"""
from collections.abc import Mapping
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, cast
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
import asyncio
import contextlib
@@ -23,8 +25,11 @@ import sys
import time
import tomllib
import tomlkit
from src.common.logger import get_console_handler, get_logger, initialize_logging
from src.plugin_runtime import (
ENV_BLOCKED_PLUGIN_REASONS,
ENV_EXTERNAL_PLUGIN_IDS,
ENV_HOST_VERSION,
ENV_IPC_ADDRESS,
@@ -37,6 +42,8 @@ from src.plugin_runtime.protocol.envelope import (
ConfigUpdatedPayload,
Envelope,
HealthPayload,
InspectPluginConfigPayload,
InspectPluginConfigResultPayload,
InvokePayload,
InvokeResultPayload,
RegisterPluginPayload,
@@ -46,6 +53,8 @@ from src.plugin_runtime.protocol.envelope import (
ReloadPluginsResultPayload,
RunnerReadyPayload,
UnregisterPluginPayload,
ValidatePluginConfigPayload,
ValidatePluginConfigResultPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
@@ -79,6 +88,72 @@ class _ContextAwarePlugin(Protocol):
"""
class _ConfigAwarePlugin(Protocol):
"""支持声明式插件配置能力的插件协议。"""
def normalize_plugin_config(self, config_data: Optional[Mapping[str, Any]]) -> Tuple[Dict[str, Any], bool]:
"""对插件配置进行归一化与补齐。
Args:
config_data: 原始配置数据。
Returns:
Tuple[Dict[str, Any], bool]: 归一化后的配置,以及是否发生自动变更。
"""
...
def set_plugin_config(self, config: Dict[str, Any]) -> None:
"""注入插件当前配置。
Args:
config: 当前最新插件配置。
"""
...
def get_default_config(self) -> Dict[str, Any]:
"""返回插件默认配置。
Returns:
Dict[str, Any]: 默认配置字典。
"""
...
def get_webui_config_schema(
self,
*,
plugin_id: str = "",
plugin_name: str = "",
plugin_version: str = "",
plugin_description: str = "",
plugin_author: str = "",
) -> Dict[str, Any]:
"""返回插件配置 Schema。
Args:
plugin_id: 插件 ID。
plugin_name: 插件名称。
plugin_version: 插件版本。
plugin_description: 插件描述。
plugin_author: 插件作者。
Returns:
Dict[str, Any]: WebUI 配置 Schema。
"""
...
class PluginActivationStatus(str, Enum):
"""描述插件激活结果。"""
LOADED = "loaded"
INACTIVE = "inactive"
FAILED = "failed"
def _install_shutdown_signal_handlers(
mark_runner_shutting_down: Callable[[], None],
loop: Optional[asyncio.AbstractEventLoop] = None,
@@ -122,6 +197,7 @@ class PluginRunner:
session_token: str,
plugin_dirs: List[str],
external_available_plugins: Optional[Dict[str, str]] = None,
blocked_plugin_reasons: Optional[Dict[str, str]] = None,
) -> None:
"""初始化 Runner。
@@ -130,6 +206,7 @@ class PluginRunner:
session_token: 握手用会话令牌。
plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
external_available_plugins: 视为已满足的外部依赖插件版本映射。
blocked_plugin_reasons: 需要拒绝加载的插件及原因映射。
"""
self._host_address: str = host_address
self._session_token: str = session_token
@@ -139,9 +216,15 @@ class PluginRunner:
for plugin_id, plugin_version in (external_available_plugins or {}).items()
if str(plugin_id or "").strip() and str(plugin_version or "").strip()
}
self._blocked_plugin_reasons: Dict[str, str] = {
str(plugin_id or "").strip(): str(reason or "").strip()
for plugin_id, reason in (blocked_plugin_reasons or {}).items()
if str(plugin_id or "").strip() and str(reason or "").strip()
}
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
self._loader.set_blocked_plugin_reasons(self._blocked_plugin_reasons)
self._start_time: float = time.monotonic()
self._shutting_down: bool = False
self._reload_lock: asyncio.Lock = asyncio.Lock()
@@ -174,13 +257,43 @@ class PluginRunner:
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
failed_plugins: Set[str] = set(self._loader.failed_plugins.keys())
inactive_plugins: Set[str] = set()
available_plugin_versions: Dict[str, str] = dict(self._external_available_plugins)
for meta in plugins:
ok = await self._activate_plugin(meta)
if not ok:
unsatisfied_dependencies = [
dependency.id
for dependency in meta.manifest.plugin_dependencies
if dependency.id not in available_plugin_versions
or not self._loader.manifest_validator.is_plugin_dependency_satisfied(
dependency,
available_plugin_versions[dependency.id],
)
]
if unsatisfied_dependencies:
if any(dependency_id in inactive_plugins for dependency_id in unsatisfied_dependencies):
logger.info(
f"插件 {meta.plugin_id} 依赖的插件当前未激活,跳过本次启动: {', '.join(unsatisfied_dependencies)}"
)
inactive_plugins.add(meta.plugin_id)
continue
failed_plugins.add(meta.plugin_id)
continue
successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins]
await self._notify_ready(successful_plugins, sorted(failed_plugins))
activation_status = await self._activate_plugin(meta)
if activation_status == PluginActivationStatus.LOADED:
available_plugin_versions[meta.plugin_id] = meta.version
continue
if activation_status == PluginActivationStatus.INACTIVE:
inactive_plugins.add(meta.plugin_id)
continue
failed_plugins.add(meta.plugin_id)
successful_plugins = [
meta.plugin_id
for meta in plugins
if meta.plugin_id not in failed_plugins and meta.plugin_id not in inactive_plugins
]
await self._notify_ready(successful_plugins, sorted(failed_plugins), sorted(inactive_plugins))
# 5. 等待直到收到关停信号
with contextlib.suppress(asyncio.CancelledError):
@@ -271,14 +384,11 @@ class PluginRunner:
始终绑定为当前插件实例,避免伪造其他插件身份申请能力。
"""
if plugin_id and plugin_id != bound_plugin_id:
logger.warning(
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC已强制绑定回自身身份"
)
logger.warning(f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC已强制绑定回自身身份")
normalized_method = str(method or "").strip()
if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS:
raise PermissionError(
f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: "
f"{normalized_method or '<empty>'}"
f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: {normalized_method or '<empty>'}"
)
resp = await rpc_client.send_request(
method=normalized_method,
@@ -293,17 +403,101 @@ class PluginRunner:
cast(_ContextAwarePlugin, instance)._set_context(ctx)
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None:
"""在 Runner 侧为插件实例注入当前插件配置。"""
instance = meta.instance
if not hasattr(instance, "set_plugin_config"):
return
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""在 Runner 侧为插件实例注入当前插件配置。
Args:
meta: 插件元数据。
config_data: 可选的配置数据;留空时自动从插件目录读取。
Returns:
Dict[str, Any]: 归一化后的当前插件配置。
"""
instance = meta.instance
raw_config = config_data if config_data is not None else self._load_plugin_config(meta.plugin_dir)
plugin_config, should_persist = self._normalize_plugin_config(instance, raw_config)
config_path = Path(meta.plugin_dir) / "config.toml"
default_config = self._get_plugin_default_config(instance)
should_initialize_file = not config_path.exists() and bool(default_config)
if should_persist or should_initialize_file:
self._save_plugin_config(meta.plugin_dir, plugin_config)
if hasattr(instance, "set_plugin_config"):
try:
cast(_ConfigAwarePlugin, instance).set_plugin_config(plugin_config)
except Exception as exc:
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
return plugin_config
def _normalize_plugin_config(
self,
instance: object,
config_data: Optional[Dict[str, Any]],
*,
suppress_errors: bool = True,
) -> Tuple[Dict[str, Any], bool]:
"""对插件配置做统一归一化处理。
Args:
instance: 插件实例。
config_data: 原始配置数据。
suppress_errors: 是否在归一化失败时吞掉异常并回退原始配置。
Returns:
Tuple[Dict[str, Any], bool]: 归一化后的配置,以及是否需要回写文件。
"""
normalized_config = dict(config_data or {})
if not hasattr(instance, "normalize_plugin_config"):
return normalized_config, False
plugin_config = config_data if config_data is not None else self._load_plugin_config(meta.plugin_dir, meta.plugin_id)
try:
instance.set_plugin_config(plugin_config)
return cast(_ConfigAwarePlugin, instance).normalize_plugin_config(normalized_config)
except Exception as exc:
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
if not suppress_errors:
raise
logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}")
return normalized_config, False
@staticmethod
def _is_plugin_enabled(config_data: Optional[Mapping[str, Any]]) -> bool:
"""根据配置内容判断插件是否应被视为启用。
Args:
config_data: 当前插件配置。
Returns:
bool: 插件是否启用。
"""
if not isinstance(config_data, Mapping):
return True
plugin_section = config_data.get("plugin")
if not isinstance(plugin_section, Mapping):
return True
enabled_value = plugin_section.get("enabled", True)
if isinstance(enabled_value, str):
normalized_value = enabled_value.strip().lower()
if normalized_value in {"0", "false", "no", "off"}:
return False
if normalized_value in {"1", "true", "yes", "on"}:
return True
return bool(enabled_value)
@staticmethod
def _save_plugin_config(plugin_dir: str, config_data: Dict[str, Any]) -> None:
"""将插件配置写回到 ``config.toml``。
Args:
plugin_dir: 插件目录。
config_data: 需要写回的配置字典。
"""
config_path = Path(plugin_dir) / "config.toml"
config_path.parent.mkdir(parents=True, exist_ok=True)
with config_path.open("w", encoding="utf-8") as handle:
handle.write(tomlkit.dumps(config_data))
@staticmethod
def _load_plugin_config(plugin_dir: str, plugin_id: str = "") -> Dict[str, Any]:
@@ -322,6 +516,99 @@ class PluginRunner:
return loaded if isinstance(loaded, dict) else {}
def _resolve_plugin_candidate(self, plugin_id: str) -> Tuple[Optional[PluginCandidate], Optional[str]]:
"""解析指定插件的候选目录。
Args:
plugin_id: 目标插件 ID。
Returns:
Tuple[Optional[PluginCandidate], Optional[str]]: 候选插件与错误信息。
"""
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
if plugin_id in duplicate_candidates:
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
return None, f"检测到重复插件 ID: {conflict_paths}"
candidate = candidates.get(plugin_id)
if candidate is None:
return None, f"未找到插件: {plugin_id}"
return candidate, None
def _resolve_plugin_meta_for_config_request(
self,
plugin_id: str,
) -> Tuple[Optional[PluginMeta], bool, Optional[str]]:
"""为配置相关请求解析插件元数据。
Args:
plugin_id: 目标插件 ID。
Returns:
Tuple[Optional[PluginMeta], bool, Optional[str]]: 依次为插件元数据、
是否为临时冷加载实例、以及错误信息。
"""
loaded_meta = self._loader.get_plugin(plugin_id)
if loaded_meta is not None:
return loaded_meta, False, None
candidate, error_message = self._resolve_plugin_candidate(plugin_id)
if candidate is None:
return None, False, error_message
try:
meta = self._loader.load_candidate(plugin_id, candidate)
except Exception as exc:
return None, False, str(exc)
if meta is None:
return None, False, "插件模块加载失败"
return meta, True, None
def _inspect_plugin_config(
self,
meta: PluginMeta,
*,
config_data: Optional[Dict[str, Any]] = None,
use_provided_config: bool = False,
suppress_errors: bool = True,
) -> InspectPluginConfigResultPayload:
"""解析插件代码定义的配置元数据。
Args:
meta: 插件元数据。
config_data: 可选的配置内容。
use_provided_config: 是否优先使用传入的配置内容。
suppress_errors: 是否在归一化失败时回退原始配置。
Returns:
InspectPluginConfigResultPayload: 结构化解析结果。
"""
raw_config = config_data if use_provided_config else self._load_plugin_config(meta.plugin_dir)
if use_provided_config and config_data is None:
raw_config = {}
normalized_config, changed = self._normalize_plugin_config(
meta.instance,
raw_config,
suppress_errors=suppress_errors,
)
default_config = self._get_plugin_default_config(meta.instance)
if not normalized_config and not raw_config and default_config:
normalized_config = dict(default_config)
changed = True
return InspectPluginConfigResultPayload(
success=True,
default_config=default_config,
config_schema=self._get_plugin_config_schema(meta),
normalized_config=normalized_config,
changed=changed,
enabled=self._is_plugin_enabled(normalized_config),
)
def _register_handlers(self) -> None:
"""注册 Host -> Runner 的方法处理器。"""
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
@@ -335,6 +622,8 @@ class PluginRunner:
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
self._rpc_client.register_method("plugin.inspect_config", self._handle_inspect_plugin_config)
self._rpc_client.register_method("plugin.validate_config", self._handle_validate_plugin_config)
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins)
@@ -452,6 +741,8 @@ class PluginRunner:
capabilities_required=meta.capabilities_required,
dependencies=meta.dependencies,
config_reload_subscriptions=config_reload_subscriptions,
default_config=self._get_plugin_default_config(instance),
config_schema=self._get_plugin_config_schema(meta),
)
try:
@@ -463,12 +754,62 @@ class PluginRunner:
)
if response.error:
raise RuntimeError(response.error.get("message", "插件注册失败"))
response_payload = response.payload if isinstance(response.payload, dict) else {}
if not bool(response_payload.get("accepted", True)):
raise RuntimeError(str(response_payload.get("reason", "插件注册失败")))
logger.info(f"插件 {meta.plugin_id} 注册完成")
return True
except Exception as e:
logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
return False
@staticmethod
def _get_plugin_default_config(instance: object) -> Dict[str, Any]:
"""获取插件默认配置。
Args:
instance: 插件实例。
Returns:
Dict[str, Any]: 默认配置;插件未声明时返回空字典。
"""
if not hasattr(instance, "get_default_config"):
return {}
try:
default_config = cast(_ConfigAwarePlugin, instance).get_default_config()
except Exception as exc:
logger.warning(f"读取插件默认配置失败: {exc}")
return {}
return default_config if isinstance(default_config, dict) else {}
@staticmethod
def _get_plugin_config_schema(meta: PluginMeta) -> Dict[str, Any]:
"""获取插件 WebUI 配置 Schema。
Args:
meta: 插件元数据。
Returns:
Dict[str, Any]: 插件配置 Schema插件未声明时返回空字典。
"""
instance = meta.instance
if not hasattr(instance, "get_webui_config_schema"):
return {}
try:
schema = cast(_ConfigAwarePlugin, instance).get_webui_config_schema(
plugin_id=meta.plugin_id,
plugin_name=meta.manifest.name,
plugin_version=meta.version,
plugin_description=meta.manifest.description,
plugin_author=meta.manifest.author.name,
)
except Exception as exc:
logger.warning(f"构造插件配置 Schema 失败: {exc}")
return {}
return schema if isinstance(schema, dict) else {}
async def _unregister_plugin(self, plugin_id: str, reason: str) -> None:
"""通知 Host 注销指定插件。
@@ -526,36 +867,40 @@ class PluginRunner:
except Exception as exc:
logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True)
async def _activate_plugin(self, meta: PluginMeta) -> bool:
async def _activate_plugin(self, meta: PluginMeta) -> PluginActivationStatus:
"""完成插件注入、授权、生命周期和组件注册。
Args:
meta: 待激活的插件元数据。
Returns:
bool: 是否激活成功
PluginActivationStatus: 插件激活结果
"""
self._inject_context(meta.plugin_id, meta.instance)
self._apply_plugin_config(meta)
plugin_config = self._apply_plugin_config(meta)
if not self._is_plugin_enabled(plugin_config):
logger.info(f"插件 {meta.plugin_id} 已在配置中禁用,跳过激活")
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return PluginActivationStatus.INACTIVE
if not await self._bootstrap_plugin(meta):
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
return PluginActivationStatus.FAILED
if not await self._register_plugin(meta):
await self._invoke_plugin_on_unload(meta)
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
return PluginActivationStatus.FAILED
if not await self._invoke_plugin_on_load(meta):
await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
return PluginActivationStatus.FAILED
self._loader.set_loaded_plugin(meta)
return True
return PluginActivationStatus.LOADED
async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None:
"""卸载单个插件并清理 Host/Runner 两侧状态。
@@ -632,7 +977,9 @@ class PluginRunner:
continue
dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids}
indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()}
indegree: Dict[str, int] = {
plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()
}
reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph}
for plugin_id, dependencies in dependency_graph.items():
@@ -678,9 +1025,7 @@ class PluginRunner:
for failed_plugin_id, failure_reason in failed_plugins.items():
rollback_failure = rollback_failures.get(failed_plugin_id)
if rollback_failure:
finalized_failures[failed_plugin_id] = (
f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
)
finalized_failures[failed_plugin_id] = f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
else:
finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)"
@@ -716,6 +1061,7 @@ class PluginRunner:
requested_plugin_id=plugin_id,
reloaded_plugins=batch_result.reloaded_plugins,
unloaded_plugins=batch_result.unloaded_plugins,
inactive_plugins=batch_result.inactive_plugins,
failed_plugins=batch_result.failed_plugins,
)
@@ -762,9 +1108,7 @@ class PluginRunner:
failed_plugins=failed_plugins,
)
target_plugin_ids: Set[str] = {
plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids
}
target_plugin_ids: Set[str] = {plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids}
if loaded_root_plugin_ids := reload_root_ids & loaded_plugin_ids:
target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids))
@@ -812,6 +1156,8 @@ class PluginRunner:
},
}
reloaded_plugins: List[str] = []
inactive_plugins: List[str] = []
inactive_plugin_ids: Set[str] = set()
for load_plugin_id in load_order:
if load_plugin_id in failed_plugins:
@@ -822,10 +1168,28 @@ class PluginRunner:
continue
_, manifest, _ = candidate
unsatisfied_dependency_ids = [
dependency.id
for dependency in manifest.plugin_dependencies
if dependency.id not in available_plugins
or not self._loader.manifest_validator.is_plugin_dependency_satisfied(
dependency,
available_plugins[dependency.id],
)
]
if unsatisfied_dependencies := self._loader.manifest_validator.get_unsatisfied_plugin_dependencies(
manifest,
available_plugin_versions=available_plugins,
):
if load_plugin_id not in reload_root_ids and any(
dependency_id in inactive_plugin_ids for dependency_id in unsatisfied_dependency_ids
):
logger.info(
f"插件 {load_plugin_id} 的依赖当前未激活,保留为未激活状态: {', '.join(unsatisfied_dependencies)}"
)
inactive_plugin_ids.add(load_plugin_id)
inactive_plugins.append(load_plugin_id)
continue
failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(unsatisfied_dependencies)}"
continue
@@ -835,9 +1199,13 @@ class PluginRunner:
continue
activated = await self._activate_plugin(meta)
if not activated:
if activated == PluginActivationStatus.FAILED:
failed_plugins[load_plugin_id] = "插件初始化失败"
continue
if activated == PluginActivationStatus.INACTIVE:
inactive_plugin_ids.add(load_plugin_id)
inactive_plugins.append(load_plugin_id)
continue
available_plugins[load_plugin_id] = meta.version
reloaded_plugins.append(load_plugin_id)
@@ -872,7 +1240,7 @@ class PluginRunner:
rollback_failures[rollback_plugin_id] = str(exc)
continue
if not restored:
if restored != PluginActivationStatus.LOADED:
rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
return ReloadPluginsResultPayload(
@@ -880,29 +1248,40 @@ class PluginRunner:
requested_plugin_ids=normalized_plugin_ids,
reloaded_plugins=[],
unloaded_plugins=unloaded_plugins,
inactive_plugins=[],
failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
)
requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids)
requested_plugin_success = all(
plugin_id in reloaded_plugins or plugin_id in inactive_plugins for plugin_id in reload_root_ids
)
return ReloadPluginsResultPayload(
success=requested_plugin_success and not failed_plugins,
requested_plugin_ids=normalized_plugin_ids,
reloaded_plugins=reloaded_plugins,
unloaded_plugins=unloaded_plugins,
inactive_plugins=inactive_plugins,
failed_plugins=failed_plugins,
)
async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None:
async def _notify_ready(
self,
loaded_plugins: List[str],
failed_plugins: List[str],
inactive_plugins: List[str],
) -> None:
"""通知 Host 当前 Runner 已完成插件初始化。
Args:
loaded_plugins: 成功初始化的插件列表。
failed_plugins: 初始化失败的插件列表。
inactive_plugins: 因禁用或依赖不可用而未激活的插件列表。
"""
payload = RunnerReadyPayload(
loaded_plugins=loaded_plugins,
failed_plugins=failed_plugins,
inactive_plugins=inactive_plugins,
)
await self._rpc_client.send_request(
"runner.ready",
@@ -1128,6 +1507,87 @@ class PluginRunner:
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
return envelope.make_response(payload={"acknowledged": True})
async def _handle_inspect_plugin_config(self, envelope: Envelope) -> Envelope:
"""处理插件配置元数据解析请求。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: RPC 响应信封。
"""
try:
payload = InspectPluginConfigPayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
plugin_id = envelope.plugin_id
meta, is_temporary_meta, error_message = self._resolve_plugin_meta_for_config_request(plugin_id)
if meta is None:
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
error_message or f"未找到插件: {plugin_id}",
)
try:
result = self._inspect_plugin_config(
meta,
config_data=payload.config_data,
use_provided_config=payload.use_provided_config,
suppress_errors=True,
)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
finally:
if is_temporary_meta:
self._loader.purge_plugin_modules(plugin_id, meta.plugin_dir)
return envelope.make_response(payload=result.model_dump())
async def _handle_validate_plugin_config(self, envelope: Envelope) -> Envelope:
"""处理插件配置校验请求。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: RPC 响应信封。
"""
try:
payload = ValidatePluginConfigPayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
plugin_id = envelope.plugin_id
meta, is_temporary_meta, error_message = self._resolve_plugin_meta_for_config_request(plugin_id)
if meta is None:
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
error_message or f"未找到插件: {plugin_id}",
)
try:
inspection_result = self._inspect_plugin_config(
meta,
config_data=payload.config_data,
use_provided_config=True,
suppress_errors=False,
)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
finally:
if is_temporary_meta:
self._loader.purge_plugin_modules(plugin_id, meta.plugin_dir)
result = ValidatePluginConfigResultPayload(
success=True,
normalized_config=inspection_result.normalized_config,
changed=inspection_result.changed,
)
return envelope.make_response(payload=result.model_dump())
async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope:
"""处理按插件 ID 的精确重载请求。
@@ -1189,6 +1649,7 @@ class PluginRunner:
async def _async_main() -> None:
"""异步主入口"""
blocked_plugin_reasons_raw = os.environ.get(ENV_BLOCKED_PLUGIN_REASONS, "")
host_address = os.environ.pop(ENV_IPC_ADDRESS, "")
external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "")
session_token = os.environ.pop(ENV_SESSION_TOKEN, "")
@@ -1208,14 +1669,30 @@ async def _async_main() -> None:
logger.warning("外部依赖插件版本映射格式非法,已回退为空映射")
external_plugin_ids = {}
try:
blocked_plugin_reasons = json.loads(blocked_plugin_reasons_raw) if blocked_plugin_reasons_raw else {}
except json.JSONDecodeError:
logger.warning("解析阻止加载插件原因映射失败,已回退为空映射")
blocked_plugin_reasons = {}
if not isinstance(blocked_plugin_reasons, dict):
logger.warning("阻止加载插件原因映射格式非法,已回退为空映射")
blocked_plugin_reasons = {}
runner_kwargs: Dict[str, Any] = {
"external_available_plugins": {
str(plugin_id): str(plugin_version) for plugin_id, plugin_version in external_plugin_ids.items()
}
}
if blocked_plugin_reasons:
runner_kwargs["blocked_plugin_reasons"] = {
str(plugin_id): str(reason) for plugin_id, reason in blocked_plugin_reasons.items()
}
runner = PluginRunner(
host_address,
session_token,
plugin_dirs,
external_available_plugins={
str(plugin_id): str(plugin_version)
for plugin_id, plugin_version in external_plugin_ids.items()
},
**runner_kwargs,
)
# 注册信号处理

View File

@@ -1,41 +0,0 @@
{
"manifest_version": 2,
"version": "2.0.0",
"name": "Emoji插件 (Emoji Actions)",
"description": "可以发送和管理 Emoji",
"author": {
"name": "SengokuCola",
"url": "https://github.com/MaiM-with-u"
},
"license": "GPL-v3.0-or-later",
"urls": {
"repository": "https://github.com/MaiM-with-u/maibot",
"homepage": "https://github.com/MaiM-with-u/maibot",
"documentation": "https://github.com/MaiM-with-u/maibot",
"issues": "https://github.com/MaiM-with-u/maibot/issues"
},
"host_application": {
"min_version": "1.0.0",
"max_version": "1.0.0"
},
"sdk": {
"min_version": "2.0.0",
"max_version": "2.99.99"
},
"dependencies": [],
"capabilities": [
"emoji.get_random",
"message.get_recent",
"message.build_readable",
"llm.generate",
"send.emoji",
"config.get"
],
"i18n": {
"default_locale": "zh-CN",
"supported_locales": [
"zh-CN"
]
},
"id": "builtin.emoji-plugin"
}

View File

@@ -1,129 +0,0 @@
"""Emoji 插件 — 新 SDK 版本
根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
"""
from maibot_sdk import Action, MaiBotPlugin
from maibot_sdk.types import ActivationType
import random
class EmojiPlugin(MaiBotPlugin):
"""表情包插件"""
@Action(
"emoji",
description="发送表情包辅助表达情绪",
activation_type=ActivationType.RANDOM,
activation_probability=0.3,
parallel_action=True,
action_require=[
"发送表情包辅助表达情绪",
"表达情绪时可以选择使用",
"不要连续发送,如果你已经发过[表情包],就不要选择此动作",
],
associated_types=["emoji"],
)
async def handle_emoji(self, stream_id: str = "", reasoning: str = "", chat_id: str = "", **kwargs):
"""执行表情动作"""
reason = reasoning or "表达当前情绪"
# 1. 随机获取30个表情包
sampled_emojis = await self.ctx.emoji.get_random(30)
if not sampled_emojis:
return False, "无法获取随机表情包"
# 2. 按情感分组
emotion_map: dict[str, list] = {}
for emoji in sampled_emojis:
emo = emoji.get("emotion", "")
if emo not in emotion_map:
emotion_map[emo] = []
emotion_map[emo].append(emoji)
available_emotions = list(emotion_map.keys())
if not available_emotions:
# 无情感标签,随机发送
chosen = random.choice(sampled_emojis)
await self.ctx.send.emoji(chosen["base64"], stream_id)
return True, "随机发送了表情包"
# 3. 获取最近消息作为上下文
messages_text = ""
if chat_id:
recent_messages = await self.ctx.message.get_recent(chat_id=chat_id, limit=5)
if recent_messages:
messages_text = await self.ctx.message.build_readable(
recent_messages,
timestamp_mode="normal_no_YMD",
truncate=False,
)
# 4. 构建 prompt 让 LLM 选择情感
available_emotions_str = "\n".join(available_emotions)
prompt = f"""你正在进行QQ聊天你需要根据聊天记录选出一个合适的情感标签。
请你根据以下原因和聊天记录进行选择
原因:{reason}
聊天记录:
{messages_text}
这里是可用的情感标签:
{available_emotions_str}
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
"""
# 5. 调用 LLM
llm_result = await self.ctx.llm.generate(prompt=prompt, model_name="utils")
if not llm_result or not llm_result.get("success"):
chosen = random.choice(sampled_emojis)
await self.ctx.send.emoji(chosen["base64"], stream_id)
return True, "LLM调用失败随机发送了表情包"
chosen_emotion = llm_result.get("response", "").strip().replace('"', "").replace("'", "")
# 6. 根据选择的情感匹配表情包
if chosen_emotion in emotion_map:
chosen = random.choice(emotion_map[chosen_emotion])
else:
chosen = random.choice(sampled_emojis)
# 7. 发送
send_ok = await self.ctx.send.emoji(chosen["base64"], stream_id)
if send_ok:
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
return False, "发送表情包失败"
async def on_load(self) -> None:
"""处理插件加载。"""
# 从插件配置读取 emoji_chance 来覆盖默认概率
await self.ctx.config.get("emoji.emoji_chance")
async def on_unload(self) -> None:
"""处理插件卸载。"""
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
"""处理配置热重载事件。
Args:
scope: 配置变更范围。
config_data: 最新配置数据。
version: 配置版本号。
"""
del config_data
del version
if scope == "self":
await self.ctx.config.get("emoji.emoji_chance")
def create_plugin() -> EmojiPlugin:
"""创建 Emoji 插件实例。
Returns:
EmojiPlugin: 新的 Emoji 插件实例。
"""
return EmojiPlugin()

View File

@@ -0,0 +1,937 @@
"""HTML 浏览器渲染服务。
负责在 Host 侧复用已有浏览器,并将 HTML 内容渲染为 PNG 图片。
"""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from importlib import metadata
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, cast
from urllib.parse import urlparse
import asyncio
import base64
import contextlib
import functools
import json
import os
import shutil
import sys
import time
from src.common.logger import PROJECT_ROOT, get_logger
from src.config.config import config_manager
from src.config.official_configs import PluginRuntimeRenderConfig
logger = get_logger("services.html_render_service")
_NETWORK_ALLOW_SCHEMES = frozenset({"about", "blob", "data", "file"})
_WINDOWS_BROWSER_PATHS = (
Path("C:/Program Files/Google/Chrome/Application/chrome.exe"),
Path("C:/Program Files (x86)/Google/Chrome/Application/chrome.exe"),
Path("C:/Program Files/Microsoft/Edge/Application/msedge.exe"),
Path("C:/Program Files (x86)/Microsoft/Edge/Application/msedge.exe"),
)
_MACOS_BROWSER_PATHS = (
Path("/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"),
Path("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"),
)
_UNIX_BROWSER_NAMES = (
"chromium",
"chromium-browser",
"google-chrome",
"google-chrome-stable",
"microsoft-edge",
"msedge",
)
_PLAYWRIGHT_MANAGED_BROWSER_PREFIXES = ("chromium-", "chrome-", "chrome-headless-shell-")
@dataclass(slots=True)
class HtmlRenderRequest:
"""描述一次 HTML 转 PNG 请求。"""
html: str
selector: str = "body"
viewport_width: int = 900
viewport_height: int = 500
device_scale_factor: float = 2.0
full_page: bool = False
omit_background: bool = False
wait_until: str = "load"
wait_for_selector: str = ""
wait_for_timeout_ms: int = 0
timeout_ms: int = 10000
allow_network: bool = False
@dataclass(slots=True)
class HtmlRenderResult:
"""描述一次 HTML 转 PNG 的输出结果。"""
image_base64: str
mime_type: str
width: int
height: int
render_ms: int
def to_payload(self) -> Dict[str, Any]:
"""将结果序列化为能力层返回结构。
Returns:
Dict[str, Any]: 可直接返回给插件运行时的结构化数据。
"""
return {
"image_base64": self.image_base64,
"mime_type": self.mime_type,
"width": self.width,
"height": self.height,
"render_ms": self.render_ms,
}
@dataclass(slots=True)
class ManagedBrowserRecord:
"""记录 Playwright 托管浏览器的本地状态。"""
browser_name: str
browsers_path: str
install_source: Literal["auto_download", "existing_cache"]
playwright_version: str
recorded_at: str
last_verified_at: str
def to_dict(self) -> Dict[str, str]:
"""将浏览器记录转换为可持久化字典。
Returns:
Dict[str, str]: 可写入 JSON 文件的字典结构。
"""
return {
"browser_name": self.browser_name,
"browsers_path": self.browsers_path,
"install_source": self.install_source,
"playwright_version": self.playwright_version,
"recorded_at": self.recorded_at,
"last_verified_at": self.last_verified_at,
}
@classmethod
def from_dict(cls, payload: Dict[str, Any]) -> Optional["ManagedBrowserRecord"]:
"""从字典中恢复浏览器状态记录。
Args:
payload: 原始字典数据。
Returns:
Optional[ManagedBrowserRecord]: 解析成功时返回记录对象,否则返回 ``None``。
"""
browser_name = str(payload.get("browser_name", "") or "").strip()
browsers_path = str(payload.get("browsers_path", "") or "").strip()
install_source = str(payload.get("install_source", "") or "").strip()
playwright_version = str(payload.get("playwright_version", "") or "").strip()
recorded_at = str(payload.get("recorded_at", "") or "").strip()
last_verified_at = str(payload.get("last_verified_at", "") or "").strip()
if not all([browser_name, browsers_path, install_source, playwright_version, recorded_at, last_verified_at]):
return None
if install_source not in {"auto_download", "existing_cache"}:
return None
validated_install_source = cast(Literal["auto_download", "existing_cache"], install_source)
return cls(
browser_name=browser_name,
browsers_path=browsers_path,
install_source=validated_install_source,
playwright_version=playwright_version,
recorded_at=recorded_at,
last_verified_at=last_verified_at,
)
class HTMLRenderService:
"""HTML 浏览器渲染服务。"""
def __init__(self) -> None:
"""初始化渲染服务。"""
self._browser: Any = None
self._browser_lock: asyncio.Lock = asyncio.Lock()
self._connected_via_cdp: bool = False
self._playwright: Any = None
self._render_count: int = 0
self._render_semaphore: Optional[asyncio.Semaphore] = None
self._render_semaphore_limit: int = 0
def _get_render_config(self) -> PluginRuntimeRenderConfig:
"""读取当前插件运行时的浏览器渲染配置。
Returns:
PluginRuntimeRenderConfig: 当前生效的浏览器渲染配置。
"""
return config_manager.get_global_config().plugin_runtime.render
def _get_render_semaphore(self) -> asyncio.Semaphore:
"""根据当前配置返回渲染并发信号量。
Returns:
asyncio.Semaphore: 控制并发的信号量对象。
"""
config = self._get_render_config()
limit = max(1, int(config.concurrency_limit))
if self._render_semaphore is None or self._render_semaphore_limit != limit:
self._render_semaphore = asyncio.Semaphore(limit)
self._render_semaphore_limit = limit
return self._render_semaphore
async def render_html_to_png(self, request: HtmlRenderRequest) -> HtmlRenderResult:
"""将 HTML 内容渲染为 PNG 图片。
Args:
request: 本次渲染请求。
Returns:
HtmlRenderResult: 渲染结果。
Raises:
RuntimeError: 浏览器能力被禁用、Playwright 不可用或浏览器启动失败时抛出。
ValueError: 请求参数非法时抛出。
"""
config = self._get_render_config()
if not config.enabled:
raise RuntimeError("插件运行时浏览器渲染能力已禁用")
normalized_request = self._normalize_request(request, config)
semaphore = self._get_render_semaphore()
async with semaphore:
start_time = time.perf_counter()
browser = await self._ensure_browser(config)
context: Any = None
try:
context = await browser.new_context(
device_scale_factor=normalized_request.device_scale_factor,
locale="zh-CN",
viewport={
"width": normalized_request.viewport_width,
"height": normalized_request.viewport_height,
},
)
page = await context.new_page()
await self._configure_page(page, normalized_request)
image_bytes = await self._capture_image(page, normalized_request)
width, height = self._measure_image_size(image_bytes)
self._render_count += 1
await self._maybe_restart_browser(config)
return HtmlRenderResult(
image_base64=base64.b64encode(image_bytes).decode("utf-8"),
mime_type="image/png",
width=width,
height=height,
render_ms=int((time.perf_counter() - start_time) * 1000),
)
except Exception:
await self.reset_browser(restart_playwright=False)
raise
finally:
if context is not None:
with contextlib.suppress(Exception):
await context.close()
async def reset_browser(self, restart_playwright: bool = False) -> None:
"""关闭当前缓存的浏览器实例。
Args:
restart_playwright: 是否同时关闭 Playwright 运行时。
"""
async with self._browser_lock:
await self._close_browser_unlocked(restart_playwright=restart_playwright)
async def _close_browser_unlocked(self, restart_playwright: bool = False) -> None:
"""在已持有锁的情况下关闭浏览器与 Playwright。
Args:
restart_playwright: 是否同时关闭 Playwright 运行时。
"""
if self._browser is not None:
with contextlib.suppress(Exception):
await self._browser.close()
self._browser = None
self._connected_via_cdp = False
if restart_playwright and self._playwright is not None:
with contextlib.suppress(Exception):
await self._playwright.stop()
self._playwright = None
async def _ensure_browser(self, config: PluginRuntimeRenderConfig) -> Any:
"""获取可复用的浏览器实例。
Args:
config: 当前浏览器渲染配置。
Returns:
Any: Playwright Browser 对象。
Raises:
RuntimeError: 当无法连接或启动浏览器时抛出。
"""
async with self._browser_lock:
if self._is_browser_connected(self._browser):
logger.debug("HTML 渲染服务复用进程内缓存浏览器实例")
return self._browser
await self._close_browser_unlocked(restart_playwright=False)
self._prepare_playwright_environment(config)
playwright = await self._ensure_playwright()
browser = await self._connect_to_existing_browser(playwright, config)
if browser is None:
browser = await self._launch_browser(playwright, config)
self._connected_via_cdp = False
else:
self._connected_via_cdp = True
self._browser = browser
self._bind_browser_events(browser)
return browser
async def _ensure_playwright(self) -> Any:
"""懒加载并启动 Playwright 运行时。
Returns:
Any: 已启动的 Playwright 对象。
Raises:
RuntimeError: 当前环境未安装 Playwright 时抛出。
"""
if self._playwright is not None:
return self._playwright
try:
from playwright.async_api import async_playwright
except ImportError as exc:
raise RuntimeError(
"当前环境未安装 Python Playwright请先在宿主环境安装 `playwright` 依赖。"
) from exc
self._playwright = await async_playwright().start()
return self._playwright
@staticmethod
def _is_browser_connected(browser: Any) -> bool:
"""判断浏览器对象当前是否仍然可用。
Args:
browser: 待检查的浏览器对象。
Returns:
bool: 若浏览器仍连接,则返回 ``True``。
"""
if browser is None:
return False
try:
return bool(browser.is_connected())
except Exception:
return False
async def _connect_to_existing_browser(self, playwright: Any, config: PluginRuntimeRenderConfig) -> Any:
"""优先连接外部已有的 Chromium 浏览器。
Args:
playwright: 已启动的 Playwright 对象。
config: 当前浏览器渲染配置。
Returns:
Any: 连接成功时返回 Browser否则返回 ``None``。
"""
if not config.browser_ws_endpoint.strip():
return None
try:
timeout_ms = int(config.startup_timeout_sec * 1000)
logger.info(
"HTML 渲染服务准备连接现有浏览器: "
f"endpoint={config.browser_ws_endpoint.strip()}, timeout_ms={timeout_ms}"
)
browser = await playwright.chromium.connect_over_cdp(
config.browser_ws_endpoint.strip(),
timeout=timeout_ms,
)
logger.info("HTML 渲染服务已连接到现有浏览器")
return browser
except Exception as exc:
logger.warning(f"连接现有浏览器失败,将回退为本地启动: {exc}")
return None
async def _launch_browser(self, playwright: Any, config: PluginRuntimeRenderConfig) -> Any:
"""启动本地 Chromium 浏览器。
Args:
playwright: 已启动的 Playwright 对象。
config: 当前浏览器渲染配置。
Returns:
Any: 新启动的 Browser 对象。
Raises:
RuntimeError: 浏览器启动失败时抛出。
"""
launch_options = self._build_launch_options(config)
logger.info(
"HTML 渲染服务准备启动浏览器: "
f"source={'system' if 'executable_path' in launch_options else 'managed'}, "
f"headless={bool(launch_options.get('headless'))}, "
f"timeout_ms={int(launch_options.get('timeout', 0))}"
)
try:
browser = await playwright.chromium.launch(**launch_options)
if "executable_path" in launch_options:
logger.info(f"HTML 渲染服务已启动本机浏览器: executable_path={launch_options['executable_path']}")
else:
self._update_managed_browser_record(config, install_source="existing_cache")
logger.info("HTML 渲染服务已启动 Playwright 托管浏览器")
return browser
except Exception as exc:
if self._should_auto_download_browser(exc, launch_options, config):
logger.warning(f"HTML 渲染服务未找到可用浏览器,将尝试自动下载 Chromium: {exc}")
await self._install_chromium_browser(config)
retry_browser = await playwright.chromium.launch(**launch_options)
self._update_managed_browser_record(config, install_source="auto_download")
logger.info("HTML 渲染服务已自动下载并启动 Chromium")
return retry_browser
raise RuntimeError(f"启动本地浏览器失败: {exc}") from exc
def _bind_browser_events(self, browser: Any) -> None:
"""为浏览器绑定断线回调。
Args:
browser: 需要绑定事件的浏览器对象。
"""
try:
browser.on("disconnected", self._handle_browser_disconnected)
except Exception:
return
def _handle_browser_disconnected(self, *_args: Any) -> None:
"""处理浏览器断线事件。
Args:
*_args: 浏览器断线事件透传的参数。
"""
self._browser = None
self._connected_via_cdp = False
logger.warning("HTML 渲染浏览器已断开,将在下次请求时重新建立连接")
def _build_launch_options(self, config: PluginRuntimeRenderConfig) -> Dict[str, Any]:
"""构造本地浏览器启动参数。
Args:
config: 当前浏览器渲染配置。
Returns:
Dict[str, Any]: 可直接传给 Playwright 的启动参数。
"""
launch_options: Dict[str, Any] = {
"args": list(config.launch_args),
"headless": bool(config.headless),
"timeout": int(config.startup_timeout_sec * 1000),
}
executable_path = self._resolve_executable_path(config)
if executable_path:
launch_options["executable_path"] = executable_path
return launch_options
@staticmethod
def _should_auto_download_browser(
exc: Exception,
launch_options: Dict[str, Any],
config: PluginRuntimeRenderConfig,
) -> bool:
"""判断当前启动错误是否适合自动下载 Chromium 后重试。
Args:
exc: 浏览器启动异常。
launch_options: 本次启动参数。
config: 当前浏览器渲染配置。
Returns:
bool: 若应自动下载后重试,则返回 ``True``。
"""
if "executable_path" in launch_options:
logger.debug("当前启动参数已指定本机浏览器路径,不进入自动下载分支")
return False
if not config.auto_download_chromium:
logger.warning("HTML 渲染服务未检测到可用浏览器,且已禁用自动下载 Chromium")
return False
error_text = str(exc).lower()
should_download = "executable doesn't exist" in error_text or "browser executable" in error_text
if not should_download:
logger.warning(f"浏览器启动失败,但错误不属于可自动下载恢复的类型: {exc}")
return should_download
def _resolve_executable_path(self, config: PluginRuntimeRenderConfig) -> str:
"""解析实际应使用的浏览器可执行文件路径。
Args:
config: 当前浏览器渲染配置。
Returns:
str: 命中的浏览器可执行文件路径;未命中时返回空字符串。
"""
configured_path = config.executable_path.strip()
if configured_path:
path = Path(configured_path).expanduser()
if path.exists():
logger.info(f"HTML 渲染服务使用配置指定的浏览器路径: {path}")
return str(path)
logger.warning(f"配置的浏览器路径不存在,将尝试自动探测: {configured_path}")
detected_path = self._detect_local_browser_executable()
if detected_path:
logger.info(f"HTML 渲染服务自动探测到本机浏览器: {detected_path}")
else:
logger.info("HTML 渲染服务未探测到本机浏览器,将尝试使用 Playwright 托管浏览器")
return detected_path
def _prepare_playwright_environment(self, config: PluginRuntimeRenderConfig) -> Path:
"""准备 Playwright 运行所需的共享浏览器目录环境变量。
Args:
config: 当前浏览器渲染配置。
Returns:
Path: Playwright 浏览器缓存目录。
"""
browsers_path = self._get_managed_browsers_path(config)
browsers_path.mkdir(parents=True, exist_ok=True)
os.environ["PLAYWRIGHT_BROWSERS_PATH"] = str(browsers_path)
logger.debug(f"HTML 渲染服务使用 Playwright 浏览器目录: {browsers_path}")
return browsers_path
def _get_managed_browsers_path(self, config: PluginRuntimeRenderConfig) -> Path:
"""获取 Playwright 托管浏览器目录。
Args:
config: 当前浏览器渲染配置。
Returns:
Path: 托管浏览器目录的绝对路径。
"""
configured_path = config.browser_install_root.strip()
if not configured_path:
return (PROJECT_ROOT / "data" / "playwright-browsers").resolve()
candidate_path = Path(configured_path).expanduser()
if candidate_path.is_absolute():
return candidate_path.resolve()
return (PROJECT_ROOT / candidate_path).resolve()
def _get_browser_state_path(self) -> Path:
"""获取托管浏览器状态文件路径。
Returns:
Path: 浏览器状态文件路径。
"""
return (PROJECT_ROOT / "data" / "plugin_runtime" / "html_render_browser_state.json").resolve()
def _load_managed_browser_record(self) -> Optional[ManagedBrowserRecord]:
"""读取最近一次成功使用的托管浏览器记录。
Returns:
Optional[ManagedBrowserRecord]: 解析成功时返回记录对象,否则返回 ``None``。
"""
state_path = self._get_browser_state_path()
if not state_path.exists():
return None
try:
raw_payload = json.loads(state_path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
logger.warning(f"HTML 渲染浏览器状态文件读取失败,将忽略并继续: {state_path}")
return None
if not isinstance(raw_payload, dict):
logger.warning(f"HTML 渲染浏览器状态文件格式无效,将忽略并继续: {state_path}")
return None
browser_record = ManagedBrowserRecord.from_dict(raw_payload)
if browser_record is not None:
logger.debug(
"HTML 渲染服务已加载浏览器状态记录: "
f"source={browser_record.install_source}, path={browser_record.browsers_path}, "
f"verified_at={browser_record.last_verified_at}"
)
return browser_record
def _save_managed_browser_record(self, record: ManagedBrowserRecord) -> None:
"""保存托管浏览器记录。
Args:
record: 待保存的浏览器记录。
"""
state_path = self._get_browser_state_path()
state_path.parent.mkdir(parents=True, exist_ok=True)
state_path.write_text(
json.dumps(record.to_dict(), ensure_ascii=False, indent=2),
encoding="utf-8",
)
logger.info(
"HTML 渲染服务已写入浏览器状态记录: "
f"path={state_path}, source={record.install_source}, browsers_path={record.browsers_path}"
)
def _update_managed_browser_record(
self,
config: PluginRuntimeRenderConfig,
install_source: Literal["auto_download", "existing_cache"],
) -> None:
"""更新托管 Chromium 的使用记录。
Args:
config: 当前浏览器渲染配置。
install_source: 本次记录的浏览器来源。
"""
browsers_path = self._get_managed_browsers_path(config)
if not self._has_managed_browser_artifact(browsers_path):
return
now_iso = datetime.now(timezone.utc).isoformat()
existing_record = self._load_managed_browser_record()
recorded_at = now_iso
if existing_record is not None and existing_record.browsers_path == str(browsers_path):
recorded_at = existing_record.recorded_at
self._save_managed_browser_record(
ManagedBrowserRecord(
browser_name="chromium",
browsers_path=str(browsers_path),
install_source=install_source,
playwright_version=self._get_playwright_version(),
recorded_at=recorded_at,
last_verified_at=now_iso,
)
)
logger.info(
"HTML 渲染服务已更新托管浏览器记录: "
f"source={install_source}, browsers_path={browsers_path}, last_verified_at={now_iso}"
)
async def _install_chromium_browser(self, config: PluginRuntimeRenderConfig) -> None:
"""自动下载 Playwright Chromium 浏览器。
Args:
config: 当前浏览器渲染配置。
Raises:
RuntimeError: 下载失败时抛出。
"""
browsers_path = self._prepare_playwright_environment(config)
logger.warning(
"HTML 渲染服务开始自动下载 Chromium: "
f"target_dir={browsers_path}, timeout_sec={config.download_connection_timeout_sec}"
)
env = os.environ.copy()
env["PLAYWRIGHT_BROWSERS_PATH"] = str(browsers_path)
env["PLAYWRIGHT_DOWNLOAD_CONNECTION_TIMEOUT"] = str(int(config.download_connection_timeout_sec * 1000))
process = await asyncio.create_subprocess_exec(
sys.executable,
"-m",
"playwright",
"install",
"chromium",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
stdout_bytes, stderr_bytes = await process.communicate()
if process.returncode != 0:
stderr_text = stderr_bytes.decode("utf-8", errors="ignore").strip()
stdout_text = stdout_bytes.decode("utf-8", errors="ignore").strip()
error_detail = stderr_text or stdout_text or f"退出码 {process.returncode}"
raise RuntimeError(f"自动下载 Chromium 失败: {error_detail}")
if not self._has_managed_browser_artifact(browsers_path):
raise RuntimeError("Chromium 下载完成后未检测到可用浏览器文件")
logger.info(f"HTML 渲染服务自动下载 Chromium 完成: target_dir={browsers_path}")
@staticmethod
def _get_playwright_version() -> str:
"""读取当前环境中的 Playwright 版本号。
Returns:
str: Playwright 版本字符串;读取失败时返回 ``unknown``。
"""
try:
return metadata.version("playwright")
except metadata.PackageNotFoundError:
return "unknown"
@staticmethod
def _has_managed_browser_artifact(browsers_path: Path) -> bool:
"""检查共享目录中是否存在可用的 Playwright 托管浏览器。
Args:
browsers_path: Playwright 浏览器目录。
Returns:
bool: 若检测到 Chromium/Chrome 相关浏览器文件夹,则返回 ``True``。
"""
if not browsers_path.exists():
return False
for child_path in browsers_path.iterdir():
if not child_path.is_dir():
continue
if child_path.name.startswith(_PLAYWRIGHT_MANAGED_BROWSER_PREFIXES):
return True
return False
def _detect_local_browser_executable(self) -> str:
"""自动探测当前宿主系统中的可复用浏览器路径。
Returns:
str: 命中的浏览器可执行文件路径;未命中时返回空字符串。
"""
for browser_name in _UNIX_BROWSER_NAMES:
resolved_path = shutil.which(browser_name)
if resolved_path:
return resolved_path
for candidate_path in self._get_candidate_executable_paths():
if candidate_path.exists():
return str(candidate_path)
return ""
@staticmethod
def _get_candidate_executable_paths() -> Tuple[Path, ...]:
"""返回当前平台常见浏览器路径候选集合。
Returns:
Tuple[Path, ...]: 可能存在浏览器可执行文件的路径列表。
"""
if sys.platform.startswith("win"):
return _WINDOWS_BROWSER_PATHS
if sys.platform == "darwin":
return _MACOS_BROWSER_PATHS
return ()
async def _configure_page(self, page: Any, request: HtmlRenderRequest) -> None:
"""为页面设置超时、网络策略并写入 HTML。
Args:
page: Playwright 页面对象。
request: 当前渲染请求。
"""
page.set_default_timeout(request.timeout_ms)
await page.route(
"**/*",
functools.partial(self._handle_network_route, allow_network=request.allow_network),
)
await page.set_content(
request.html,
timeout=request.timeout_ms,
wait_until=request.wait_until,
)
if request.wait_for_selector:
await page.locator(request.wait_for_selector).first.wait_for(
state="attached",
timeout=request.timeout_ms,
)
if request.wait_for_timeout_ms > 0:
await page.wait_for_timeout(request.wait_for_timeout_ms)
async def _handle_network_route(self, route: Any, allow_network: bool) -> None:
"""处理页面资源请求的网络准入策略。
Args:
route: Playwright 路由对象。
allow_network: 是否允许页面访问外部网络资源。
"""
request_url = str(route.request.url)
if allow_network or self._is_network_request_allowed(request_url):
await route.continue_()
return
await route.abort()
@staticmethod
def _is_network_request_allowed(request_url: str) -> bool:
"""判断某个资源 URL 是否属于本地安全资源。
Args:
request_url: 待判断的资源地址。
Returns:
bool: 若请求可在无网络模式下放行,则返回 ``True``。
"""
if not request_url:
return False
parsed_url = urlparse(request_url)
return parsed_url.scheme in _NETWORK_ALLOW_SCHEMES
async def _capture_image(self, page: Any, request: HtmlRenderRequest) -> bytes:
"""从页面或目标元素中截取 PNG 图片。
Args:
page: Playwright 页面对象。
request: 当前渲染请求。
Returns:
bytes: PNG 二进制内容。
Raises:
RuntimeError: 目标元素不存在或截图结果为空时抛出。
"""
if request.full_page and request.selector == "body":
image_bytes = await page.screenshot(
full_page=True,
omit_background=request.omit_background,
timeout=request.timeout_ms,
type="png",
)
else:
locator = page.locator(request.selector).first
await locator.wait_for(state="visible", timeout=request.timeout_ms)
image_bytes = await locator.screenshot(
omit_background=request.omit_background,
timeout=request.timeout_ms,
type="png",
)
if not image_bytes:
raise RuntimeError("浏览器截图结果为空")
return image_bytes
@staticmethod
def _measure_image_size(image_bytes: bytes) -> Tuple[int, int]:
"""读取 PNG 图片的真实像素尺寸。
Args:
image_bytes: PNG 图片二进制内容。
Returns:
Tuple[int, int]: 图片宽高像素值。
"""
from PIL import Image
with Image.open(BytesIO(image_bytes)) as image:
return int(image.width), int(image.height)
async def _maybe_restart_browser(self, config: PluginRuntimeRenderConfig) -> None:
"""按策略决定是否重建本地浏览器实例。
Args:
config: 当前浏览器渲染配置。
"""
restart_after = int(config.restart_after_render_count)
if restart_after <= 0 or self._connected_via_cdp:
return
if self._render_count % restart_after != 0:
return
await self.reset_browser(restart_playwright=False)
logger.info("HTML 渲染服务已按累计次数策略重建本地浏览器")
@staticmethod
def _normalize_request(
request: HtmlRenderRequest,
config: PluginRuntimeRenderConfig,
) -> HtmlRenderRequest:
"""规范化并补齐 HTML 渲染请求。
Args:
request: 原始渲染请求。
config: 当前浏览器渲染配置。
Returns:
HtmlRenderRequest: 规范化后的请求对象。
Raises:
ValueError: 请求缺少必要字段或取值非法时抛出。
"""
html = request.html.strip()
if not html:
raise ValueError("缺少必要参数 html")
selector = request.selector.strip() or "body"
wait_until = HTMLRenderService._normalize_wait_until(request.wait_until)
timeout_ms = request.timeout_ms
if timeout_ms <= 0:
timeout_ms = int(config.render_timeout_sec * 1000)
return HtmlRenderRequest(
html=html,
selector=selector,
viewport_width=max(1, int(request.viewport_width)),
viewport_height=max(1, int(request.viewport_height)),
device_scale_factor=max(1.0, float(request.device_scale_factor)),
full_page=bool(request.full_page),
omit_background=bool(request.omit_background),
wait_until=wait_until,
wait_for_selector=request.wait_for_selector.strip(),
wait_for_timeout_ms=max(0, int(request.wait_for_timeout_ms)),
timeout_ms=max(1, int(timeout_ms)),
allow_network=bool(request.allow_network),
)
@staticmethod
def _normalize_wait_until(wait_until: str) -> str:
"""规范化页面等待阶段参数。
Args:
wait_until: 原始等待阶段字符串。
Returns:
str: Playwright 支持的等待阶段值。
"""
normalized_wait_until = wait_until.strip().lower()
if normalized_wait_until in {"commit", "domcontentloaded", "load", "networkidle"}:
return normalized_wait_until
return "load"
_html_render_service: Optional[HTMLRenderService] = None
def get_html_render_service() -> HTMLRenderService:
"""获取 HTML 浏览器渲染服务单例。
Returns:
HTMLRenderService: 全局唯一的浏览器渲染服务实例。
"""
global _html_render_service
if _html_render_service is None:
_html_render_service = HTMLRenderService()
return _html_render_service

View File

@@ -40,10 +40,213 @@ from src.common.utils.utils_message import MessageUtils
from src.config.config import global_config
from src.platform_io import DeliveryBatch, get_platform_io_manager
from src.platform_io.route_key_factory import RouteKeyFactory
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
logger = get_logger("send_service")
def register_send_service_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册发送服务内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="send_service.after_build_message",
description="在出站 SessionMessage 构建完成后触发,可改写消息体或取消发送。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "待发送消息的序列化 SessionMessage。",
},
"stream_id": {
"type": "string",
"description": "目标会话 ID。",
},
"display_message": {
"type": "string",
"description": "展示层文本。",
},
"typing": {
"type": "boolean",
"description": "是否模拟打字。",
},
"set_reply": {
"type": "boolean",
"description": "是否附带引用回复。",
},
"storage_message": {
"type": "boolean",
"description": "发送成功后是否写库。",
},
"show_log": {
"type": "boolean",
"description": "是否输出发送日志。",
},
},
required=[
"message",
"stream_id",
"display_message",
"typing",
"set_reply",
"storage_message",
"show_log",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="send_service.before_send",
description="在真正调用 Platform IO 发送前触发,可改写消息或取消本次发送。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "待发送消息的序列化 SessionMessage。",
},
"typing": {
"type": "boolean",
"description": "是否模拟打字。",
},
"set_reply": {
"type": "boolean",
"description": "是否附带引用回复。",
},
"reply_message_id": {
"type": "string",
"description": "被引用消息 ID。",
},
"storage_message": {
"type": "boolean",
"description": "发送成功后是否写库。",
},
"show_log": {
"type": "boolean",
"description": "是否输出发送日志。",
},
},
required=["message", "typing", "set_reply", "storage_message", "show_log"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="send_service.after_send",
description="在发送流程结束后触发,用于观察最终发送结果。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "本次发送消息的序列化 SessionMessage。",
},
"sent": {
"type": "boolean",
"description": "本次发送是否成功。",
},
"typing": {
"type": "boolean",
"description": "是否模拟打字。",
},
"set_reply": {
"type": "boolean",
"description": "是否附带引用回复。",
},
"reply_message_id": {
"type": "string",
"description": "被引用消息 ID。",
},
"storage_message": {
"type": "boolean",
"description": "发送成功后是否写库。",
},
"show_log": {
"type": "boolean",
"description": "是否输出发送日志。",
},
},
required=["message", "sent", "typing", "set_reply", "storage_message", "show_log"],
),
default_timeout_ms=5000,
allow_abort=False,
allow_kwargs_mutation=False,
),
]
)
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _coerce_bool(value: Any, default: bool) -> bool:
"""将任意值安全转换为布尔值。
Args:
value: 待转换的值。
default: 当值为空时使用的默认值。
Returns:
bool: 转换后的布尔值。
"""
if value is None:
return default
return bool(value)
async def _invoke_send_hook(
hook_name: str,
message: SessionMessage,
**kwargs: Any,
) -> tuple[HookDispatchResult, SessionMessage]:
"""触发携带出站消息的命名 Hook。
Args:
hook_name: 目标 Hook 名称。
message: 当前待发送消息。
**kwargs: 需要附带的额外参数。
Returns:
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
"""
hook_result = await _get_runtime_manager().invoke_hook(
hook_name,
message=serialize_session_message(message),
**kwargs,
)
mutated_message = message
raw_message = hook_result.kwargs.get("message")
if raw_message is not None:
try:
mutated_message = deserialize_session_message(raw_message)
except Exception as exc:
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
return hook_result, mutated_message
def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]:
"""从目标会话继承 Platform IO 路由元数据。
@@ -484,6 +687,27 @@ async def _send_via_platform_io(
Returns:
bool: 发送成功时返回 ``True``。
"""
before_send_result, message = await _invoke_send_hook(
"send_service.before_send",
message,
typing=typing,
set_reply=set_reply,
reply_message_id=reply_message_id,
storage_message=storage_message,
show_log=show_log,
)
if before_send_result.aborted:
logger.info(f"[SendService] 消息 {message.message_id} 在发送前被 Hook 中止")
return False
before_kwargs = before_send_result.kwargs
typing = _coerce_bool(before_kwargs.get("typing"), typing)
set_reply = _coerce_bool(before_kwargs.get("set_reply"), set_reply)
storage_message = _coerce_bool(before_kwargs.get("storage_message"), storage_message)
show_log = _coerce_bool(before_kwargs.get("show_log"), show_log)
raw_reply_message_id = before_kwargs.get("reply_message_id", reply_message_id)
reply_message_id = None if raw_reply_message_id in {None, ""} else str(raw_reply_message_id)
platform_io_manager = get_platform_io_manager()
try:
await platform_io_manager.ensure_send_pipeline_ready()
@@ -515,6 +739,18 @@ async def _send_via_platform_io(
logger.debug(traceback.format_exc())
return False
sent = bool(delivery_batch.has_success)
await _invoke_send_hook(
"send_service.after_send",
message,
sent=sent,
typing=typing,
set_reply=set_reply,
reply_message_id=reply_message_id,
storage_message=storage_message,
show_log=show_log,
)
if delivery_batch.has_success:
if storage_message:
_store_sent_message(message)
@@ -622,6 +858,26 @@ async def _send_to_target(
if outbound_message is None:
return False
after_build_result, outbound_message = await _invoke_send_hook(
"send_service.after_build_message",
outbound_message,
stream_id=stream_id,
display_message=display_message,
typing=typing,
set_reply=set_reply,
storage_message=storage_message,
show_log=show_log,
)
if after_build_result.aborted:
logger.info(f"[SendService] 消息 {outbound_message.message_id} 在构建后被 Hook 中止")
return False
after_build_kwargs = after_build_result.kwargs
typing = _coerce_bool(after_build_kwargs.get("typing"), typing)
set_reply = _coerce_bool(after_build_kwargs.get("set_reply"), set_reply)
storage_message = _coerce_bool(after_build_kwargs.get("storage_message"), storage_message)
show_log = _coerce_bool(after_build_kwargs.get("show_log"), show_log)
sent = await send_session_message(
outbound_message,
typing=typing,

View File

@@ -1,6 +1,7 @@
import inspect
from typing import Any, Dict, List, get_args, get_origin
import inspect
from pydantic_core import PydanticUndefined
from src.config.config_base import ConfigBase
@@ -56,7 +57,7 @@ class ConfigSchemaGenerator:
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
return cls.generate_config_schema(annotation)
if origin in {list, tuple} and args:
if origin in {list, set, tuple} and args:
first = args[0]
if inspect.isclass(first) and issubclass(first, ConfigBase):
return cls.generate_config_schema(first)
@@ -83,7 +84,7 @@ class ConfigSchemaGenerator:
origin = get_origin(annotation)
args = get_args(annotation)
if origin is list and args:
if origin in {list, set} and args:
schema["items"] = {"type": cls._map_field_type(args[0])}
if options := cls._extract_options(annotation):
@@ -120,7 +121,7 @@ class ConfigSchemaGenerator:
origin = get_origin(annotation)
args = get_args(annotation)
if origin in {list, tuple}:
if origin in {list, set, tuple}:
return "array"
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
return "object"
@@ -133,7 +134,7 @@ class ConfigSchemaGenerator:
if annotation is str:
return "string"
if origin in {list, tuple} and args:
if origin in {list, set, tuple} and args:
return "array"
if origin in {dict}:

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.logs_ws")
router = APIRouter()
@@ -148,24 +149,9 @@ async def broadcast_log(log_data: Dict):
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
await websocket_manager.broadcast_to_topic(
domain="logs",
topic="main",
event="entry",
data={"entry": log_data},
)

View File

@@ -18,13 +18,13 @@ def get_all_routers() -> List[APIRouter]:
from src.webui.api.replier import router as replier_router
from src.webui.routers.chat import router as chat_router
from src.webui.routers.memory import compat_router as memory_compat_router
from src.webui.routers.websocket.logs import router as logs_router
from src.webui.routers.knowledge import router as knowledge_router
from src.webui.routes import router as main_router
return [
main_router,
memory_compat_router,
logs_router,
knowledge_router,
chat_router,
planner_router,
replier_router,

View File

@@ -1,7 +1,7 @@
from typing import Tuple
from .routes import router
from .support import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
from .service import WEBUI_CHAT_PLATFORM, ChatConnectionManager, chat_manager
def get_webui_chat_broadcaster() -> Tuple[ChatConnectionManager, str]:

View File

@@ -1,9 +1,8 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话。"""
import uuid
from typing import Dict, Optional
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Query
from sqlalchemy import case, func
from sqlmodel import col, select
@@ -13,16 +12,11 @@ from src.common.logger import get_logger
from src.config.config import global_config
from src.webui.dependencies import require_auth
from .support import (
from .service import (
WEBUI_CHAT_GROUP_ID,
WEBUI_CHAT_PLATFORM,
authenticate_chat_websocket,
chat_history,
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
logger = get_logger("webui.chat")
@@ -113,55 +107,6 @@ async def clear_chat_history(
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
@router.websocket("/ws")
async def websocket_chat(
websocket: WebSocket,
user_id: Optional[str] = Query(default=None),
user_name: Optional[str] = Query(default="WebUI用户"),
platform: Optional[str] = Query(default=None),
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
token: Optional[str] = Query(default=None),
) -> None:
"""WebSocket 聊天端点。"""
if not await authenticate_chat_websocket(websocket, token):
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
session_id = str(uuid.uuid4())
normalized_user_id = normalize_webui_user_id(user_id)
current_user_name = user_name or "WebUI用户"
current_virtual_config = resolve_initial_virtual_identity(platform, person_id, group_name, group_id)
await chat_manager.connect(websocket, session_id, normalized_user_id)
try:
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
)
while True:
data = await websocket.receive_json()
current_user_name, current_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=current_user_name,
normalized_user_id=normalized_user_id,
current_virtual_config=current_virtual_config,
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={normalized_user_id}")
except Exception as e:
logger.error(f"WebSocket 错误: {e}")
finally:
chat_manager.disconnect(session_id, normalized_user_id)
@router.get("/info")
async def get_chat_info() -> Dict[str, object]:
"""获取聊天室信息。"""

View File

@@ -1,10 +1,10 @@
"""WebUI 聊天路由支持逻辑"""
"""WebUI 聊天运行时服务"""
from dataclasses import dataclass
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
from fastapi import WebSocket
from pydantic import BaseModel
from sqlmodel import col, delete, select
@@ -17,8 +17,6 @@ from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
@@ -27,6 +25,8 @@ WEBUI_CHAT_PLATFORM = "webui"
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
WEBUI_USER_ID_PREFIX = "webui_user_"
AsyncMessageSender = Callable[[Dict[str, Any]], Awaitable[None]]
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置。"""
@@ -52,13 +52,42 @@ class ChatHistoryMessage(BaseModel):
is_bot: bool = False
@dataclass
class ChatSessionConnection:
"""逻辑聊天会话连接信息。"""
session_id: str
connection_id: str
client_session_id: str
user_id: str
user_name: str
active_group_id: str
virtual_config: Optional[VirtualIdentityConfig]
sender: AsyncMessageSender
class ChatHistoryManager:
"""聊天历史管理器。"""
def __init__(self, max_messages: int = 200) -> None:
"""初始化聊天历史管理器。
Args:
max_messages: 内存中允许处理的最大消息数
"""
self.max_messages = max_messages
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将内部消息对象转换为前端可消费的字典。
Args:
msg: 内部统一消息对象
group_id: 当前会话所属的群组标识
Returns:
Dict[str, Any]: 面向 WebUI 的消息字典
"""
del group_id
user_info = msg.message_info.user_info
user_id = user_info.user_id or ""
is_bot = is_bot_self(msg.platform, user_id)
@@ -74,10 +103,27 @@ class ChatHistoryManager:
}
def _resolve_session_id(self, group_id: Optional[str]) -> str:
"""根据群组标识解析聊天会话 ID。
Args:
group_id: 群组标识
Returns:
str: 内部聊天会话 ID
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取指定会话的历史消息。
Args:
limit: 最大返回条数
group_id: 群组标识
Returns:
List[Dict[str, Any]]: 历史消息列表
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -90,11 +136,19 @@ class ChatHistoryManager:
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
except Exception as exc:
logger.error(f"从数据库加载聊天记录失败: {exc}")
return []
def clear_history(self, group_id: Optional[str] = None) -> int:
"""清空指定会话的历史消息。
Args:
group_id: 群组标识
Returns:
int: 被删除的消息数量
"""
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
session_id = self._resolve_session_id(target_group_id)
try:
@@ -104,66 +158,245 @@ class ChatHistoryManager:
deleted = result.rowcount or 0
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
logger.error(f"清空聊天记录失败: {e}")
except Exception as exc:
logger.error(f"清空聊天记录失败: {exc}")
return 0
class ChatConnectionManager:
"""聊天连接管理器。"""
"""统一聊天逻辑会话管理器。"""
def __init__(self) -> None:
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {}
"""初始化聊天逻辑会话管理器。"""
self.active_connections: Dict[str, ChatSessionConnection] = {}
self.client_sessions: Dict[Tuple[str, str], str] = {}
self.connection_sessions: Dict[str, Set[str]] = {}
self.group_sessions: Dict[str, Set[str]] = {}
self.user_sessions: Dict[str, Set[str]] = {}
async def connect(self, websocket: WebSocket, session_id: str, user_id: str) -> None:
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def _bind_group(self, session_id: str, group_id: str) -> None:
"""为会话绑定群组索引。
def disconnect(self, session_id: str, user_id: str) -> None:
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.setdefault(group_id, set())
group_session_ids.add(session_id)
def _unbind_group(self, session_id: str, group_id: str) -> None:
"""移除会话与群组的索引关系。
Args:
session_id: 内部会话 ID
group_id: 群组标识
"""
group_session_ids = self.group_sessions.get(group_id)
if group_session_ids is None:
return
group_session_ids.discard(session_id)
if not group_session_ids:
del self.group_sessions[group_id]
async def connect(
self,
session_id: str,
connection_id: str,
client_session_id: str,
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
sender: AsyncMessageSender,
) -> None:
"""注册一个新的逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
user_id: 规范化后的用户 ID
user_name: 当前展示昵称
virtual_config: 当前虚拟身份配置
sender: 发送消息到前端的异步回调
"""
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
if existing_session_id is not None:
self.disconnect(existing_session_id)
active_group_id = get_current_group_id(virtual_config)
session_connection = ChatSessionConnection(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=user_id,
user_name=user_name,
active_group_id=active_group_id,
virtual_config=virtual_config,
sender=sender,
)
self.active_connections[session_id] = session_connection
self.client_sessions[(connection_id, client_session_id)] = session_id
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
self.user_sessions.setdefault(user_id, set()).add(session_id)
self._bind_group(session_id, active_group_id)
logger.info(
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
session_id,
connection_id,
client_session_id,
user_id,
active_group_id,
)
def disconnect(self, session_id: str) -> None:
"""断开一个逻辑聊天会话。
Args:
session_id: 内部逻辑会话 ID
"""
session_connection = self.active_connections.pop(session_id, None)
if session_connection is None:
return
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
self._unbind_group(session_id, session_connection.active_group_id)
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
if connection_session_ids is not None:
connection_session_ids.discard(session_id)
if not connection_session_ids:
del self.connection_sessions[session_connection.connection_id]
user_session_ids = self.user_sessions.get(session_connection.user_id)
if user_session_ids is not None:
user_session_ids.discard(session_id)
if not user_session_ids:
del self.user_sessions[session_connection.user_id]
logger.info("WebUI 聊天会话已断开: session=%s", session_id)
def disconnect_connection(self, connection_id: str) -> None:
"""断开物理连接下的全部逻辑聊天会话。
Args:
connection_id: 物理 WebSocket 连接 ID
"""
session_ids = list(self.connection_sessions.get(connection_id, set()))
for session_id in session_ids:
self.disconnect(session_id)
def get_session(self, session_id: str) -> Optional[ChatSessionConnection]:
"""获取逻辑聊天会话信息。
Args:
session_id: 内部逻辑会话 ID
Returns:
Optional[ChatSessionConnection]: 会话存在时返回对应信息
"""
return self.active_connections.get(session_id)
def get_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""根据连接 ID 和前端会话 ID 查询内部会话 ID。
Args:
connection_id: 物理 WebSocket 连接 ID
client_session_id: 前端标签页使用的会话 ID
Returns:
Optional[str]: 找到时返回内部会话 ID
"""
return self.client_sessions.get((connection_id, client_session_id))
def update_session_context(
self,
session_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> None:
"""更新会话上下文信息。
Args:
session_id: 内部逻辑会话 ID
user_name: 最新昵称
virtual_config: 最新虚拟身份配置
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
next_group_id = get_current_group_id(virtual_config)
if next_group_id != session_connection.active_group_id:
self._unbind_group(session_id, session_connection.active_group_id)
self._bind_group(session_id, next_group_id)
session_connection.active_group_id = next_group_id
session_connection.user_name = user_name
session_connection.virtual_config = virtual_config
async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
"""向指定逻辑会话发送消息。
Args:
session_id: 内部逻辑会话 ID
message: 发送消息内容
"""
session_connection = self.active_connections.get(session_id)
if session_connection is None:
return
try:
await session_connection.sender(message)
except Exception as exc:
logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
async def broadcast(self, message: Dict[str, Any]) -> None:
"""向全部逻辑聊天会话广播消息。
Args:
message: 待广播的消息内容
"""
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
"""向指定群组下的全部逻辑会话广播消息。
Args:
group_id: 群组标识
message: 待广播的消息内容
"""
for session_id in list(self.group_sessions.get(group_id, set())):
await self.send_message(session_id, message)
chat_history = ChatHistoryManager()
chat_manager = ChatConnectionManager()
def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) -> bool:
"""判断当前是否启用了虚拟身份模式。
Args:
virtual_config: 虚拟身份配置
Returns:
bool: 已启用时返回 ``True``
"""
return bool(virtual_config and virtual_config.enabled)
async def authenticate_chat_websocket(websocket: WebSocket, token: Optional[str]) -> bool:
if token and verify_ws_token(token):
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
return True
if cookie_token := websocket.cookies.get("maibot_session"):
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
return True
return False
def normalize_webui_user_id(user_id: Optional[str]) -> str:
"""标准化 WebUI 用户 ID。
Args:
user_id: 原始用户 ID
Returns:
str: 带统一前缀的用户 ID
"""
if not user_id:
return f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
if user_id.startswith(WEBUI_USER_ID_PREFIX):
@@ -172,12 +405,30 @@ def normalize_webui_user_id(user_id: Optional[str]) -> str:
def get_person_by_person_id(person_id: str) -> Optional[PersonInfo]:
"""根据人物 ID 查询人物信息。
Args:
person_id: 人物 ID
Returns:
Optional[PersonInfo]: 查到时返回人物信息
"""
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
return session.exec(statement).first()
def build_virtual_identity_config(person: PersonInfo, group_id: str, group_name: str) -> VirtualIdentityConfig:
"""根据人物信息构建虚拟身份配置。
Args:
person: 人物信息对象
group_id: 逻辑群组 ID
group_name: 逻辑群组名称
Returns:
VirtualIdentityConfig: 虚拟身份配置对象
"""
return VirtualIdentityConfig(
enabled=True,
platform=person.platform,
@@ -195,6 +446,17 @@ def resolve_initial_virtual_identity(
group_name: Optional[str],
group_id: Optional[str],
) -> Optional[VirtualIdentityConfig]:
"""根据初始参数解析虚拟身份配置。
Args:
platform: 平台名称
person_id: 人物 ID
group_name: 群组名称
group_id: 群组 ID
Returns:
Optional[VirtualIdentityConfig]: 解析成功时返回虚拟身份配置
"""
if not (platform and person_id):
return None
@@ -210,11 +472,14 @@ def resolve_initial_virtual_identity(
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {virtual_config.user_nickname} @ {virtual_config.platform}, group_id={virtual_group_id}"
"虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
virtual_config.user_nickname,
virtual_config.platform,
virtual_group_id,
)
return virtual_config
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
except Exception as exc:
logger.warning(f"通过参数配置虚拟身份失败: {exc}")
return None
@@ -224,6 +489,17 @@ def build_session_info_message(
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Dict[str, Any]:
"""构建会话信息消息。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 会话信息消息
"""
session_info_data: Dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
@@ -247,13 +523,41 @@ def build_session_info_message(
def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
"""获取当前虚拟身份对应的历史群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
Optional[str]: 虚拟身份启用时返回对应群组 ID
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.group_id
return None
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""获取当前会话的有效群组 ID。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 当前会话应使用的群组 ID
"""
return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
"""构建欢迎消息。
Args:
virtual_config: 虚拟身份配置
Returns:
str: 欢迎消息文本
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return (
@@ -264,6 +568,12 @@ def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> st
async def send_chat_error(session_id: str, content: str) -> None:
"""向指定会话发送错误消息。
Args:
session_id: 内部逻辑会话 ID
content: 错误消息内容
"""
await chat_manager.send_message(
session_id,
{
@@ -279,7 +589,17 @@ async def send_initial_chat_state(
user_id: str,
user_name: str,
virtual_config: Optional[VirtualIdentityConfig],
include_welcome: bool = True,
) -> None:
"""向新会话发送初始化状态。
Args:
session_id: 内部逻辑会话 ID
user_id: 规范化后的用户 ID
user_name: 当前昵称
virtual_config: 虚拟身份配置
include_welcome: 是否发送欢迎消息
"""
await chat_manager.send_message(
session_id,
build_session_info_message(
@@ -290,30 +610,43 @@ async def send_initial_chat_state(
),
)
if history := chat_history.get_history(50, get_active_history_group_id(virtual_config)):
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
history_group_id = get_active_history_group_id(virtual_config)
history = chat_history.get_history(50, history_group_id)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
"type": "history",
"messages": history,
"group_id": get_current_group_id(virtual_config),
},
)
if include_welcome:
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": build_welcome_message(virtual_config),
"timestamp": time.time(),
},
)
def resolve_sender_identity(
current_user_name: str,
normalized_user_id: str,
virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, str]:
"""解析当前发送者身份。
Args:
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
virtual_config: 虚拟身份配置
Returns:
Tuple[str, str]: ``(发送者昵称, 发送者用户 ID)``
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
return virtual_config.user_nickname or current_user_name, virtual_config.user_id or normalized_user_id
@@ -328,6 +661,19 @@ def create_message_data(
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> Dict[str, Any]:
"""构建发送给聊天核心的消息数据。
Args:
content: 文本内容
user_id: 用户 ID
user_name: 用户昵称
message_id: 消息 ID
is_at_bot: 是否默认艾特机器人
virtual_config: 虚拟身份配置
Returns:
Dict[str, Any]: 聊天核心可处理的消息数据
"""
if message_id is None:
message_id = str(uuid.uuid4())
@@ -389,6 +735,18 @@ async def handle_chat_message(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> str:
"""处理用户发送的聊天消息。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的消息数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
str: 处理后的最新昵称
"""
content = str(data.get("content", "")).strip()
if not content:
return current_user_name
@@ -401,11 +759,14 @@ async def handle_chat_message(
normalized_user_id=normalized_user_id,
virtual_config=current_virtual_config,
)
target_group_id = get_current_group_id(current_virtual_config)
await chat_manager.broadcast(
await chat_manager.broadcast_to_group(
target_group_id,
{
"type": "user_message",
"content": content,
"group_id": target_group_id,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
@@ -414,7 +775,7 @@ async def handle_chat_message(
"is_bot": False,
},
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
}
},
)
message_data = create_message_data(
@@ -427,22 +788,37 @@ async def handle_chat_message(
)
try:
await chat_manager.broadcast({"type": "typing", "is_typing": True})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await send_chat_error(session_id, f"处理消息时出错: {str(e)}")
except Exception as exc:
logger.error(f"处理消息时出错: {exc}")
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
finally:
await chat_manager.broadcast({"type": "typing", "is_typing": False})
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
return next_user_name
async def handle_chat_ping(session_id: str) -> None:
"""处理聊天心跳。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(session_id, {"type": "pong", "timestamp": time.time()})
async def handle_nickname_update(session_id: str, data: Dict[str, Any], current_user_name: str) -> str:
"""处理昵称更新请求。
Args:
session_id: 内部逻辑会话 ID
data: 前端提交的数据
current_user_name: 当前昵称
Returns:
str: 更新后的昵称
"""
new_name = str(data.get("user_name", "")).strip()
if not new_name:
return current_user_name
@@ -463,6 +839,16 @@ async def enable_virtual_identity(
session_prefix: str,
virtual_data: Dict[str, Any],
) -> Optional[VirtualIdentityConfig]:
"""启用虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
session_prefix: 会话前缀用于生成默认群组 ID
virtual_data: 前端提交的虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 启用成功时返回新的虚拟身份配置
"""
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await send_chat_error(session_id, "虚拟身份配置缺少必要字段: platform 和 person_id")
return None
@@ -470,16 +856,18 @@ async def enable_virtual_identity(
person_id_value = str(virtual_data.get("person_id"))
try:
person = get_person_by_person_id(person_id_value)
if not person:
if person is None:
await send_chat_error(session_id, f"找不到用户: {person_id_value}")
return None
custom_group_id = virtual_data.get("group_id")
current_group_id = (
f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
if custom_group_id
else f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
)
custom_group_id = str(virtual_data.get("group_id") or "").strip()
if custom_group_id:
current_group_id = custom_group_id
if not current_group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{current_group_id}"
else:
current_group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_prefix}"
current_virtual_config = build_virtual_identity_config(
person=person,
group_id=current_group_id,
@@ -521,13 +909,18 @@ async def enable_virtual_identity(
},
)
return current_virtual_config
except Exception as e:
logger.error(f"设置虚拟身份失败: {e}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(e)}")
except Exception as exc:
logger.error(f"设置虚拟身份失败: {exc}")
await send_chat_error(session_id, f"设置虚拟身份失败: {str(exc)}")
return None
async def disable_virtual_identity(session_id: str) -> None:
"""关闭虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID
"""
await chat_manager.send_message(
session_id,
{
@@ -560,7 +953,18 @@ async def handle_virtual_identity_update(
data: Dict[str, Any],
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Optional[VirtualIdentityConfig]:
virtual_data = cast(dict[str, Any], data.get("config", {}))
"""处理虚拟身份切换请求。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_virtual_config: 当前虚拟身份配置
Returns:
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置
"""
virtual_data = cast(Dict[str, Any], data.get("config", {}))
if virtual_data.get("enabled"):
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
return next_config if next_config is not None else current_virtual_config
@@ -577,6 +981,19 @@ async def dispatch_chat_event(
normalized_user_id: str,
current_virtual_config: Optional[VirtualIdentityConfig],
) -> Tuple[str, Optional[VirtualIdentityConfig]]:
"""分发聊天事件到对应的处理器。
Args:
session_id: 内部逻辑会话 ID
session_id_prefix: 会话前缀
data: 前端提交的数据
current_user_name: 当前昵称
normalized_user_id: 规范化后的用户 ID
current_virtual_config: 当前虚拟身份配置
Returns:
Tuple[str, Optional[VirtualIdentityConfig]]: ``(最新昵称, 最新虚拟身份配置)``
"""
event_type = data.get("type")
if event_type == "message":
next_user_name = await handle_chat_message(

View File

@@ -24,10 +24,8 @@ from src.config.official_configs import (
ChineseTypoConfig,
DebugConfig,
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
KeywordReactionConfig,
LPMMKnowledgeConfig,
MaimMessageConfig,
MemoryConfig,
MessageReceiveConfig,
@@ -109,9 +107,7 @@ async def get_config_section_schema(section_name: str):
- response_post_process: ResponsePostProcessConfig
- response_splitter: ResponseSplitterConfig
- telemetry: TelemetryConfig
- experimental: ExperimentalConfig
- maim_message: MaimMessageConfig
- lpmm_knowledge: LPMMKnowledgeConfig
- memory: MemoryConfig
- debug: DebugConfig
- voice: VoiceConfig
@@ -133,9 +129,7 @@ async def get_config_section_schema(section_name: str):
"response_post_process": ResponsePostProcessConfig,
"response_splitter": ResponseSplitterConfig,
"telemetry": TelemetryConfig,
"experimental": ExperimentalConfig,
"maim_message": MaimMessageConfig,
"lpmm_knowledge": LPMMKnowledgeConfig,
"memory": MemoryConfig,
"debug": DebugConfig,
"voice": VoiceConfig,

View File

@@ -6,11 +6,13 @@ from .catalog import router as catalog_router
from .config_routes import router as config_router
from .management import router as management_router
from .progress import get_progress_router, update_progress
from .runtime_routes import router as runtime_router
router = APIRouter(prefix="/plugins", tags=["插件管理"])
router.include_router(catalog_router)
router.include_router(management_router)
router.include_router(config_router)
router.include_router(runtime_router)
set_update_progress_callback(update_progress)

View File

@@ -1,17 +1,18 @@
import json
"""插件配置相关 WebUI 路由。"""
from pathlib import Path
from typing import Any, Dict, Optional, cast
import tomlkit
from fastapi import APIRouter, Cookie, HTTPException
from src.common.logger import get_logger
from src.plugin_runtime.protocol.envelope import InspectPluginConfigResultPayload
from src.webui.utils.toml_utils import save_toml_with_format
from .schemas import UpdatePluginConfigRequest, UpdatePluginRawConfigRequest
from .support import (
backup_file,
coerce_types,
find_plugin_instance,
find_plugin_path_by_id,
get_plugin_config_path,
normalize_dotted_keys,
@@ -39,6 +40,16 @@ def _to_builtin_data(obj: Any) -> Any:
def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> Dict[str, Any]:
"""根据当前配置内容自动推断一个兜底 Schema。
Args:
plugin_id: 插件 ID。
current_config: 当前配置对象。
Returns:
Dict[str, Any]: 可供前端渲染的兜底 Schema。
"""
schema: Dict[str, Any] = {
"plugin_id": plugin_id,
"plugin_info": {
@@ -134,33 +145,187 @@ def _build_schema_from_current_config(plugin_id: str, current_config: Any) -> Di
return schema
def _coerce_scalar_value(field_schema: Dict[str, Any], value: Any) -> Any:
"""根据字段 Schema 规范化单个字段值。
Args:
field_schema: 单个字段 Schema。
value: 当前字段值。
Returns:
Any: 规范化后的字段值。
"""
field_type = str(field_schema.get("type", "") or "").lower()
if field_type == "boolean" and isinstance(value, str):
normalized_value = value.strip().lower()
if normalized_value in {"1", "true", "yes", "on"}:
return True
if normalized_value in {"0", "false", "no", "off"}:
return False
if field_type == "integer" and isinstance(value, str):
try:
return int(value)
except ValueError:
return value
if field_type == "number" and isinstance(value, str):
try:
return float(value)
except ValueError:
return value
if field_type == "array" and isinstance(value, str):
return [item.strip() for item in value.split(",") if item.strip()]
return value
def _coerce_config_by_plugin_schema(schema: Dict[str, Any], config_data: Dict[str, Any]) -> None:
"""根据插件配置 Schema 就地规范化配置值类型。
Args:
schema: 插件配置 Schema。
config_data: 待规范化的配置字典。
"""
sections = schema.get("sections")
if not isinstance(sections, dict):
return
for section_name, section_schema in sections.items():
if not isinstance(section_schema, dict):
continue
if section_name not in config_data or not isinstance(config_data[section_name], dict):
continue
section_fields = section_schema.get("fields")
if not isinstance(section_fields, dict):
continue
section_config = cast(Dict[str, Any], config_data[section_name])
for field_name, field_schema in section_fields.items():
if field_name not in section_config or not isinstance(field_schema, dict):
continue
section_config[field_name] = _coerce_scalar_value(field_schema, section_config[field_name])
def _build_toml_document(config_data: Dict[str, Any]) -> tomlkit.TOMLDocument:
"""将普通字典转换为 TOML 文档对象。
Args:
config_data: 原始配置字典。
Returns:
tomlkit.TOMLDocument: 解析后的 TOML 文档。
"""
if not config_data:
return tomlkit.document()
return tomlkit.parse(tomlkit.dumps(config_data))
def _load_plugin_config_from_disk(plugin_path: Path) -> Dict[str, Any]:
"""从磁盘读取插件配置。
Args:
plugin_path: 插件目录。
Returns:
Dict[str, Any]: 当前配置字典;文件不存在时返回空字典。
"""
config_path = resolve_plugin_file_path(plugin_path, "config.toml")
if not config_path.exists():
return {}
with open(config_path, "r", encoding="utf-8") as file_obj:
loaded_config = tomlkit.load(file_obj).unwrap()
return loaded_config if isinstance(loaded_config, dict) else {}
async def _inspect_plugin_config_via_runtime(
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
*,
use_provided_config: bool = False,
) -> InspectPluginConfigResultPayload | None:
"""通过插件运行时解析配置元数据。
Args:
plugin_id: 插件 ID。
config_data: 可选的配置内容。
use_provided_config: 是否优先使用传入配置而不是磁盘配置。
Returns:
InspectPluginConfigResultPayload | None: 运行时可用时返回解析结果,否则返回 ``None``。
Raises:
ValueError: 插件运行时明确拒绝解析请求时抛出。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
runtime_manager = get_plugin_runtime_manager()
return await runtime_manager.inspect_plugin_config(
plugin_id,
config_data,
use_provided_config=use_provided_config,
)
async def _validate_plugin_config_via_runtime(plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None:
"""通过插件运行时对配置进行校验。
Args:
plugin_id: 插件 ID。
config_data: 待校验的配置内容。
Returns:
Dict[str, Any] | None: 校验成功时返回规范化后的配置;若运行时不可用则返回
``None``,由调用方自行回退到静态 Schema 方案。
Raises:
ValueError: 插件运行时明确判定配置非法时抛出。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
runtime_manager = get_plugin_runtime_manager()
return await runtime_manager.validate_plugin_config(plugin_id, config_data)
@router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
"""按插件 ID 返回配置 Schema。
Args:
plugin_id: 插件 ID。
maibot_session: 当前会话令牌。
Returns:
Dict[str, Any]: 包含 Schema 的响应字典。
"""
require_plugin_token(maibot_session)
logger.info(f"获取插件配置 Schema: {plugin_id}")
try:
plugin_instance = find_plugin_instance(plugin_id)
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
return {"success": True, "schema": plugin_instance.get_webui_config_schema()}
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
schema_json_path = resolve_plugin_file_path(plugin_path, "config_schema.json")
if schema_json_path.exists():
try:
with open(schema_json_path, "r", encoding="utf-8") as file_obj:
return {"success": True, "schema": json.load(file_obj)}
except Exception as e:
logger.warning(f"读取 config_schema.json 失败,回退到自动推断: {e}")
try:
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
except ValueError as exc:
logger.warning(f"插件 {plugin_id} 配置 Schema 解析失败,将回退到弱推断: {exc}")
runtime_snapshot = None
current_config: Any = {}
config_path = get_plugin_config_path(plugin_id, plugin_path)
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:
current_config = tomlkit.load(file_obj)
if runtime_snapshot is not None and runtime_snapshot.config_schema:
return {"success": True, "schema": dict(runtime_snapshot.config_schema)}
current_config: Any = (
dict(runtime_snapshot.normalized_config)
if runtime_snapshot is not None
else _load_plugin_config_from_disk(plugin_path)
)
return {"success": True, "schema": _build_schema_from_current_config(plugin_id, current_config)}
except HTTPException:
@@ -172,6 +337,16 @@ async def get_plugin_config_schema(plugin_id: str, maibot_session: Optional[str]
@router.get("/config/{plugin_id}/raw")
async def get_plugin_config_raw(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
"""获取插件原始 TOML 配置内容。
Args:
plugin_id: 插件 ID。
maibot_session: 当前会话令牌。
Returns:
Dict[str, Any]: 包含原始配置文本的响应字典。
"""
require_plugin_token(maibot_session)
logger.info(f"获取插件原始配置: {plugin_id}")
@@ -199,6 +374,17 @@ async def update_plugin_config_raw(
request: UpdatePluginRawConfigRequest,
maibot_session: Optional[str] = Cookie(None),
) -> Dict[str, Any]:
"""更新插件原始 TOML 配置内容。
Args:
plugin_id: 插件 ID。
request: 原始配置更新请求。
maibot_session: 当前会话令牌。
Returns:
Dict[str, Any]: 更新结果。
"""
require_plugin_token(maibot_session)
logger.info(f"更新插件原始配置: {plugin_id}")
@@ -232,6 +418,16 @@ async def update_plugin_config_raw(
@router.get("/config/{plugin_id}")
async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
"""获取插件配置字典。
Args:
plugin_id: 插件 ID。
maibot_session: 当前会话令牌。
Returns:
Dict[str, Any]: 当前配置响应。
"""
require_plugin_token(maibot_session)
logger.info(f"获取插件配置: {plugin_id}")
@@ -241,12 +437,24 @@ async def get_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cook
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = get_plugin_config_path(plugin_id, plugin_path)
try:
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
except ValueError as exc:
logger.warning(f"插件 {plugin_id} 配置读取失败,将回退到磁盘内容: {exc}")
runtime_snapshot = None
if runtime_snapshot is not None:
message = "配置文件不存在,已返回默认配置" if not config_path.exists() else ""
return {
"success": True,
"config": dict(runtime_snapshot.normalized_config),
"message": message,
}
if not config_path.exists():
return {"success": True, "config": {}, "message": "配置文件不存在"}
with open(config_path, "r", encoding="utf-8") as file_obj:
config = tomlkit.load(file_obj)
return {"success": True, "config": _to_builtin_data(config)}
return {"success": True, "config": _load_plugin_config_from_disk(plugin_path)}
except HTTPException:
raise
except Exception as e:
@@ -260,21 +468,40 @@ async def update_plugin_config(
request: UpdatePluginConfigRequest,
maibot_session: Optional[str] = Cookie(None),
) -> Dict[str, Any]:
"""更新插件结构化配置。
Args:
plugin_id: 插件 ID。
request: 结构化配置更新请求。
maibot_session: 当前会话令牌。
Returns:
Dict[str, Any]: 更新结果。
"""
require_plugin_token(maibot_session)
logger.info(f"更新插件配置: {plugin_id}")
try:
plugin_instance = find_plugin_instance(plugin_id)
config_data = request.config or {}
if plugin_instance and isinstance(config_data, dict):
config_data = normalize_dotted_keys(config_data)
if isinstance(plugin_instance.config_schema, dict):
coerce_types(plugin_instance.config_schema, config_data)
plugin_path = find_plugin_path_by_id(plugin_id)
if plugin_path is None:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_data = request.config or {}
if isinstance(config_data, dict):
config_data = normalize_dotted_keys(config_data)
runtime_validated_config = await _validate_plugin_config_via_runtime(plugin_id, config_data)
if isinstance(runtime_validated_config, dict):
config_data = runtime_validated_config
else:
runtime_snapshot = await _inspect_plugin_config_via_runtime(
plugin_id,
config_data,
use_provided_config=True,
)
if runtime_snapshot is not None and runtime_snapshot.config_schema:
_coerce_config_by_plugin_schema(dict(runtime_snapshot.config_schema), config_data)
config_path = get_plugin_config_path(plugin_id, plugin_path)
backup_path = backup_file(config_path, "backup")
if backup_path is not None:
@@ -284,6 +511,8 @@ async def update_plugin_config(
save_toml_with_format(config_data, str(config_path))
logger.info(f"已更新插件配置: {plugin_id}")
return {"success": True, "message": "配置已保存", "note": "配置更改将自动热更新到对应插件"}
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except HTTPException:
raise
except Exception as e:
@@ -293,6 +522,16 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset")
async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
"""重置插件配置文件。
Args:
plugin_id: 插件 ID。
maibot_session: 当前会话令牌。
Returns:
Dict[str, Any]: 重置结果。
"""
require_plugin_token(maibot_session)
logger.info(f"重置插件配置: {plugin_id}")
@@ -317,6 +556,16 @@ async def reset_plugin_config(plugin_id: str, maibot_session: Optional[str] = Co
@router.post("/config/{plugin_id}/toggle")
async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
"""切换插件启用状态。
Args:
plugin_id: 插件 ID。
maibot_session: 当前会话令牌。
Returns:
Dict[str, Any]: 切换结果。
"""
require_plugin_token(maibot_session)
logger.info(f"切换插件状态: {plugin_id}")
@@ -326,16 +575,29 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = get_plugin_config_path(plugin_id, plugin_path)
config = tomlkit.document()
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as file_obj:
config = tomlkit.load(file_obj)
try:
runtime_snapshot = await _inspect_plugin_config_via_runtime(plugin_id)
except ValueError as exc:
logger.warning(f"插件 {plugin_id} 状态切换前配置解析失败,将回退到磁盘内容: {exc}")
runtime_snapshot = None
if "plugin" not in config:
current_config = (
dict(runtime_snapshot.normalized_config)
if runtime_snapshot is not None
else _load_plugin_config_from_disk(plugin_path)
)
config = _build_toml_document(current_config)
plugin_section = config.get("plugin")
if plugin_section is None or not hasattr(plugin_section, "get"):
config["plugin"] = tomlkit.table()
plugin_config = cast(Any, config["plugin"])
current_enabled = bool(plugin_config.get("enabled", True))
current_enabled = (
bool(runtime_snapshot.enabled)
if runtime_snapshot is not None
else bool(plugin_config.get("enabled", True))
)
new_enabled = not current_enabled
plugin_config["enabled"] = new_enabled
config_path.parent.mkdir(parents=True, exist_ok=True)
@@ -347,7 +609,7 @@ async def toggle_plugin(plugin_id: str, maibot_session: Optional[str] = Cookie(N
"success": True,
"enabled": new_enabled,
"message": f"插件已{status}",
"note": "状态更改将在下次加载插件时生效",
"note": "状态更改将自动热更新到对应插件",
}
except HTTPException:
raise

View File

@@ -1,12 +1,15 @@
"""插件进度实时推送支持。"""
from typing import Any, Dict, Optional, Set
import asyncio
import json
from typing import Any, Dict, Optional, Set
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.plugin_progress")
@@ -25,25 +28,29 @@ current_progress: Dict[str, Any] = {
}
def get_current_progress() -> Dict[str, Any]:
"""获取当前插件进度快照。
Returns:
Dict[str, Any]: 当前插件进度数据副本。
"""
return current_progress.copy()
async def broadcast_progress(progress_data: Dict[str, Any]) -> None:
"""向统一连接层广播插件进度更新。
Args:
progress_data: 插件进度数据。
"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected: Set[WebSocket] = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
for websocket in disconnected:
active_connections.discard(websocket)
await websocket_manager.broadcast_to_topic(
domain="plugin_progress",
topic="main",
event="update",
data={"progress": progress_data},
)
async def update_progress(
@@ -56,6 +63,18 @@ async def update_progress(
total_plugins: int = 0,
loaded_plugins: int = 0,
) -> None:
"""更新当前插件进度并广播。
Args:
stage: 当前阶段。
progress: 当前进度百分比。
message: 进度说明消息。
operation: 当前操作类型。
error: 可选的错误信息。
plugin_id: 当前处理的插件 ID。
total_plugins: 总插件数量。
loaded_plugins: 已处理插件数量。
"""
progress_data = {
"operation": operation,
"stage": stage,
@@ -74,6 +93,12 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""旧版插件进度 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
is_authenticated = False
if token and verify_ws_token(token):
@@ -105,17 +130,22 @@ async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] =
data = await websocket.receive_text()
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
except Exception as exc:
logger.error(f"处理客户端消息时出错: {exc}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
except Exception as exc:
logger.error(f"❌ WebSocket 错误: {exc}")
active_connections.discard(websocket)
def get_progress_router() -> APIRouter:
"""获取旧版插件进度路由对象。
Returns:
APIRouter: 插件进度路由对象。
"""
return router

View File

@@ -0,0 +1,28 @@
"""插件运行时相关 WebUI 路由。"""
from typing import Optional
from fastapi import APIRouter, Cookie
from src.plugin_runtime.component_query import component_query_service
from .schemas import HookSpecListResponse, HookSpecResponse
from .support import require_plugin_token
router = APIRouter()
@router.get("/runtime/hooks", response_model=HookSpecListResponse)
async def list_runtime_hook_specs(maibot_session: Optional[str] = Cookie(None)) -> HookSpecListResponse:
"""返回当前插件运行时公开的 Hook 规格清单。
Args:
maibot_session: 当前 WebUI 会话令牌。
Returns:
HookSpecListResponse: Hook 规格列表响应。
"""
require_plugin_token(maibot_session)
hooks = [HookSpecResponse(**hook_data) for hook_data in component_query_service.list_hook_specs()]
return HookSpecListResponse(success=True, hooks=hooks)

View File

@@ -111,3 +111,19 @@ class UpdatePluginConfigRequest(BaseModel):
class UpdatePluginRawConfigRequest(BaseModel):
config: str = Field(..., description="原始 TOML 配置内容")
class HookSpecResponse(BaseModel):
name: str = Field(..., description="Hook 名称")
description: str = Field("", description="Hook 描述")
parameters_schema: Dict[str, Any] = Field(default_factory=dict, description="Hook 参数模型")
default_timeout_ms: int = Field(..., description="默认超时毫秒数")
allow_blocking: bool = Field(..., description="是否允许 blocking 处理器")
allow_observe: bool = Field(..., description="是否允许 observe 处理器")
allow_abort: bool = Field(..., description="是否允许 abort")
allow_kwargs_mutation: bool = Field(..., description="是否允许修改 kwargs")
class HookSpecListResponse(BaseModel):
success: bool = Field(..., description="是否成功")
hooks: List[HookSpecResponse] = Field(default_factory=list, description="Hook 规格列表")

View File

@@ -1,7 +1,7 @@
from .auth import router as ws_auth_router
from .logs import router as logs_router
"""WebSocket 路由包。"""
__all__ = [
"logs_router",
"ws_auth_router",
"auth",
"manager",
"unified",
]

View File

@@ -1,11 +0,0 @@
"""WebSocket 日志推送路由兼容导出。"""
from src.webui.logs_ws import active_connections, broadcast_log, load_recent_logs, router, websocket_logs
__all__ = [
"active_connections",
"broadcast_log",
"load_recent_logs",
"router",
"websocket_logs",
]

View File

@@ -0,0 +1,297 @@
"""统一 WebSocket 连接管理器。"""
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set
from fastapi import WebSocket
from src.common.logger import get_logger
logger = get_logger("webui.websocket")
@dataclass
class WebSocketConnection:
"""统一 WebSocket 连接上下文。"""
connection_id: str
websocket: WebSocket
subscriptions: Set[str] = field(default_factory=set)
chat_sessions: Dict[str, str] = field(default_factory=dict)
send_queue: "asyncio.Queue[Optional[Dict[str, Any]]]" = field(default_factory=asyncio.Queue)
sender_task: Optional["asyncio.Task[None]"] = None
class UnifiedWebSocketManager:
"""统一 WebSocket 连接管理器。"""
def __init__(self) -> None:
"""初始化统一 WebSocket 连接管理器。"""
self.connections: Dict[str, WebSocketConnection] = {}
def _build_subscription_key(self, domain: str, topic: str) -> str:
"""构建订阅索引键。
Args:
domain: 业务域名称。
topic: 主题名称。
Returns:
str: 订阅索引键。
"""
return f"{domain}:{topic}"
async def _sender_loop(self, connection: WebSocketConnection) -> None:
"""串行发送指定连接的出站消息。
Args:
connection: 目标连接上下文。
"""
try:
while True:
message = await connection.send_queue.get()
if message is None:
return
await connection.websocket.send_json(message)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.error("统一 WebSocket 发送失败: connection=%s, error=%s", connection.connection_id, exc)
async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
"""注册一个新的物理 WebSocket 连接。
Args:
connection_id: 连接 ID。
websocket: FastAPI WebSocket 对象。
Returns:
WebSocketConnection: 新建的连接上下文。
"""
await websocket.accept()
connection = WebSocketConnection(connection_id=connection_id, websocket=websocket)
connection.sender_task = asyncio.create_task(self._sender_loop(connection))
self.connections[connection_id] = connection
return connection
async def disconnect(self, connection_id: str) -> None:
"""断开并清理指定连接。
Args:
connection_id: 连接 ID。
"""
connection = self.connections.pop(connection_id, None)
if connection is None:
return
await connection.send_queue.put(None)
if connection.sender_task is not None:
try:
await connection.sender_task
except asyncio.CancelledError:
pass
except Exception as exc:
logger.debug("等待发送协程退出时出现异常: connection=%s, error=%s", connection_id, exc)
def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
"""获取指定连接上下文。
Args:
connection_id: 连接 ID。
Returns:
Optional[WebSocketConnection]: 找到时返回连接上下文。
"""
return self.connections.get(connection_id)
def register_chat_session(self, connection_id: str, client_session_id: str, session_id: str) -> None:
"""登记连接下的逻辑聊天会话。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
session_id: 内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions[client_session_id] = session_id
def unregister_chat_session(self, connection_id: str, client_session_id: str) -> None:
"""移除连接下的逻辑聊天会话登记。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.chat_sessions.pop(client_session_id, None)
def get_chat_session_id(self, connection_id: str, client_session_id: str) -> Optional[str]:
"""查询连接下的内部聊天会话 ID。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
Returns:
Optional[str]: 找到时返回内部会话 ID。
"""
connection = self.connections.get(connection_id)
if connection is None:
return None
return connection.chat_sessions.get(client_session_id)
def subscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""登记连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.add(self._build_subscription_key(domain, topic))
def unsubscribe(self, connection_id: str, domain: str, topic: str) -> None:
"""移除连接的主题订阅。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
connection.subscriptions.discard(self._build_subscription_key(domain, topic))
def is_subscribed(self, connection_id: str, domain: str, topic: str) -> bool:
"""判断连接是否订阅了指定主题。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
topic: 主题名称。
Returns:
bool: 已订阅时返回 ``True``。
"""
connection = self.connections.get(connection_id)
if connection is None:
return False
return self._build_subscription_key(domain, topic) in connection.subscriptions
async def enqueue(self, connection_id: str, message: Dict[str, Any]) -> None:
"""向指定连接的发送队列压入消息。
Args:
connection_id: 连接 ID。
message: 待发送的消息。
"""
connection = self.connections.get(connection_id)
if connection is None:
return
await connection.send_queue.put(message)
async def send_response(
self,
connection_id: str,
request_id: Optional[str],
ok: bool,
data: Optional[Dict[str, Any]] = None,
error: Optional[Dict[str, Any]] = None,
) -> None:
"""发送统一响应消息。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
ok: 请求是否成功。
data: 成功响应数据。
error: 失败响应数据。
"""
response_message: Dict[str, Any] = {
"op": "response",
"id": request_id,
"ok": ok,
}
if data is not None:
response_message["data"] = data
if error is not None:
response_message["error"] = error
await self.enqueue(connection_id, response_message)
async def send_event(
self,
connection_id: str,
domain: str,
event: str,
data: Dict[str, Any],
session: Optional[str] = None,
topic: Optional[str] = None,
) -> None:
"""发送统一事件消息。
Args:
connection_id: 连接 ID。
domain: 业务域名称。
event: 事件名称。
data: 事件数据。
session: 可选的逻辑会话 ID。
topic: 可选的主题名称。
"""
event_message: Dict[str, Any] = {
"op": "event",
"domain": domain,
"event": event,
"data": data,
}
if session is not None:
event_message["session"] = session
if topic is not None:
event_message["topic"] = topic
await self.enqueue(connection_id, event_message)
async def send_pong(self, connection_id: str, timestamp: float) -> None:
"""发送心跳响应。
Args:
connection_id: 连接 ID。
timestamp: 当前时间戳。
"""
await self.enqueue(
connection_id,
{
"op": "pong",
"ts": timestamp,
},
)
async def broadcast_to_topic(self, domain: str, topic: str, event: str, data: Dict[str, Any]) -> None:
"""向订阅指定主题的全部连接广播事件。
Args:
domain: 业务域名称。
topic: 主题名称。
event: 事件名称。
data: 事件数据。
"""
subscription_key = self._build_subscription_key(domain, topic)
for connection in list(self.connections.values()):
if subscription_key in connection.subscriptions:
await self.send_event(
connection.connection_id,
domain=domain,
event=event,
data=data,
topic=topic,
)
websocket_manager = UnifiedWebSocketManager()

View File

@@ -0,0 +1,548 @@
"""统一 WebSocket 路由。"""
from typing import Any, Dict, Optional, Set, cast
import asyncio
import time
import uuid
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.logs_ws import load_recent_logs
from src.webui.routers.chat.service import (
chat_manager,
dispatch_chat_event,
normalize_webui_user_id,
resolve_initial_virtual_identity,
send_initial_chat_state,
)
from src.webui.routers.plugin.progress import get_current_progress
from src.webui.routers.websocket.auth import verify_ws_token
from src.webui.routers.websocket.manager import websocket_manager
logger = get_logger("webui.unified_ws")
router = APIRouter()
_background_tasks: Set["asyncio.Task[None]"] = set()
def _build_error(code: str, message: str) -> Dict[str, Any]:
"""构建统一错误响应体。
Args:
code: 错误码。
message: 错误描述。
Returns:
Dict[str, Any]: 统一错误对象。
"""
return {
"code": code,
"message": message,
}
def _get_request_data(message: Dict[str, Any]) -> Dict[str, Any]:
"""从客户端消息中提取数据字段。
Args:
message: 客户端消息。
Returns:
Dict[str, Any]: 标准化后的数据字典。
"""
data = message.get("data", {})
if isinstance(data, dict):
return cast(Dict[str, Any], data)
return {}
def _track_background_task(task: "asyncio.Task[None]") -> None:
"""登记后台任务并在完成后自动清理。
Args:
task: 后台协程任务。
"""
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def authenticate_websocket_connection(websocket: WebSocket, token: Optional[str]) -> bool:
"""校验统一 WebSocket 连接的认证状态。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
Returns:
bool: 认证通过时返回 ``True``。
"""
if token and verify_ws_token(token):
logger.debug("统一 WebSocket 使用临时 token 认证成功")
return True
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
logger.debug("统一 WebSocket 使用 Cookie 认证成功")
return True
return False
async def _handle_logs_subscribe(connection_id: str, request_id: Optional[str], data: Dict[str, Any]) -> None:
"""处理日志域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
data: 订阅参数。
"""
replay_limit = int(data.get("replay", 100) or 100)
replay_limit = max(0, min(replay_limit, 500))
websocket_manager.subscribe(connection_id, domain="logs", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "logs", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="logs",
event="snapshot",
topic="main",
data={"entries": load_recent_logs(limit=replay_limit)},
)
async def _handle_plugin_progress_subscribe(connection_id: str, request_id: Optional[str]) -> None:
"""处理插件进度域订阅请求。
Args:
connection_id: 连接 ID。
request_id: 请求 ID。
"""
websocket_manager.subscribe(connection_id, domain="plugin_progress", topic="main")
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": "plugin_progress", "topic": "main"},
)
await websocket_manager.send_event(
connection_id,
domain="plugin_progress",
event="snapshot",
topic="main",
data={"progress": get_current_progress()},
)
async def _handle_subscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题订阅请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
data = _get_request_data(message)
if domain == "logs" and topic == "main":
await _handle_logs_subscribe(connection_id, request_id, data)
return
if domain == "plugin_progress" and topic == "main":
await _handle_plugin_progress_subscribe(connection_id, request_id)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_subscription", f"不支持的订阅目标: {domain}:{topic}"),
)
async def _handle_unsubscribe(connection_id: str, message: Dict[str, Any]) -> None:
"""处理主题退订请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
topic = str(message.get("topic") or "").strip()
if not domain or not topic:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("invalid_unsubscribe", "退订请求缺少 domain 或 topic"),
)
return
websocket_manager.unsubscribe(connection_id, domain=domain, topic=topic)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"domain": domain, "topic": topic},
)
async def _open_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""打开一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
if not client_session_id:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("missing_session", "聊天会话打开请求缺少 session"),
)
return
data = _get_request_data(message)
normalized_user_id = normalize_webui_user_id(cast(Optional[str], data.get("user_id")))
current_user_name = str(data.get("user_name") or "WebUI用户")
current_virtual_config = resolve_initial_virtual_identity(
platform=cast(Optional[str], data.get("platform")),
person_id=cast(Optional[str], data.get("person_id")),
group_name=cast(Optional[str], data.get("group_name")),
group_id=cast(Optional[str], data.get("group_id")),
)
restore = bool(data.get("restore"))
session_id = f"{connection_id}:{client_session_id}"
async def send_chat_event(chat_message: Dict[str, Any]) -> None:
"""将聊天消息封装为统一事件并发送。
Args:
chat_message: 聊天消息体。
"""
event_name = str(chat_message.get("type") or "message")
await websocket_manager.send_event(
connection_id,
domain="chat",
event=event_name,
session=client_session_id,
data=chat_message,
)
await chat_manager.connect(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
sender=send_chat_event,
)
websocket_manager.register_chat_session(connection_id, client_session_id, session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "session_id": session_id},
)
await send_initial_chat_state(
session_id=session_id,
user_id=normalized_user_id,
user_name=current_user_name,
virtual_config=current_virtual_config,
include_welcome=not restore,
)
async def _close_chat_session(connection_id: str, message: Dict[str, Any]) -> None:
"""关闭一个逻辑聊天会话。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
chat_manager.disconnect(session_id)
websocket_manager.unregister_chat_session(connection_id, client_session_id)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id},
)
async def _process_chat_message(connection_id: str, client_session_id: str, data: Dict[str, Any]) -> None:
"""在后台处理聊天消息事件。
Args:
connection_id: 连接 ID。
client_session_id: 前端会话 ID。
data: 客户端提交的消息数据。
"""
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
return
session_state = chat_manager.get_session(session_id)
if session_state is None:
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data=data,
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
async def _handle_chat_message_send(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天消息发送请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
payload = {
"type": "message",
"content": data.get("content", ""),
"user_name": data.get("user_name", ""),
}
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"accepted": True, "session": client_session_id},
)
_track_background_task(asyncio.create_task(_process_chat_message(connection_id, client_session_id, payload)))
async def _handle_chat_nickname_update(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天昵称更新请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
client_session_id = str(message.get("session") or "").strip()
session_id = websocket_manager.get_chat_session_id(connection_id, client_session_id)
if session_id is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
data = _get_request_data(message)
session_state = chat_manager.get_session(session_id)
if session_state is None:
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("session_not_found", f"找不到聊天会话: {client_session_id}"),
)
return
next_user_name, next_virtual_config = await dispatch_chat_event(
session_id=session_id,
session_id_prefix=session_id[:8],
data={
"type": "update_nickname",
"user_name": data.get("user_name", ""),
},
current_user_name=session_state.user_name,
normalized_user_id=session_state.user_id,
current_virtual_config=session_state.virtual_config,
)
chat_manager.update_session_context(
session_id=session_id,
user_name=next_user_name,
virtual_config=next_virtual_config,
)
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=True,
data={"session": client_session_id, "user_name": next_user_name},
)
async def _handle_chat_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理聊天域调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
method = str(message.get("method") or "").strip()
if method == "session.open":
await _open_chat_session(connection_id, message)
return
if method == "session.close":
await _close_chat_session(connection_id, message)
return
if method == "message.send":
await _handle_chat_message_send(connection_id, message)
return
if method == "session.update_nickname":
await _handle_chat_nickname_update(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_method", f"不支持的聊天方法: {method}"),
)
async def _handle_call(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一调用请求。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
request_id = cast(Optional[str], message.get("id"))
domain = str(message.get("domain") or "").strip()
if domain == "chat":
await _handle_chat_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_domain", f"不支持的调用域: {domain}"),
)
async def handle_client_message(connection_id: str, message: Dict[str, Any]) -> None:
"""处理统一 WebSocket 客户端消息。
Args:
connection_id: 连接 ID。
message: 客户端消息。
"""
operation = str(message.get("op") or "").strip()
request_id = cast(Optional[str], message.get("id"))
if operation == "ping":
await websocket_manager.send_pong(connection_id, time.time())
return
if operation == "subscribe":
await _handle_subscribe(connection_id, message)
return
if operation == "unsubscribe":
await _handle_unsubscribe(connection_id, message)
return
if operation == "call":
await _handle_call(connection_id, message)
return
await websocket_manager.send_response(
connection_id,
request_id=request_id,
ok=False,
error=_build_error("unsupported_operation", f"不支持的操作: {operation}"),
)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(None)) -> None:
"""统一 WebSocket 入口。
Args:
websocket: FastAPI WebSocket 对象。
token: 可选的一次性握手 Token。
"""
if not await authenticate_websocket_connection(websocket, token):
logger.warning("统一 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
connection_id = uuid.uuid4().hex
await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
logger.info("统一 WebSocket 客户端已连接: connection=%s", connection_id)
await websocket_manager.send_event(
connection_id,
domain="system",
event="ready",
data={"connection_id": connection_id, "timestamp": time.time()},
)
try:
while True:
raw_message = await websocket.receive_json()
if not isinstance(raw_message, dict):
await websocket_manager.send_response(
connection_id,
request_id=None,
ok=False,
error=_build_error("invalid_message", "消息必须是 JSON 对象"),
)
continue
await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
except WebSocketDisconnect:
logger.info("统一 WebSocket 客户端已断开: connection=%s", connection_id)
except Exception as exc:
logger.error(f"统一 WebSocket 处理失败: {exc}")
finally:
chat_manager.disconnect_connection(connection_id)
await websocket_manager.disconnect(connection_id)

View File

@@ -19,11 +19,11 @@ from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.memory import router as memory_router
from src.webui.routers.model import router as model_router
from src.webui.routers.person import router as person_router
from src.webui.routers.plugin import get_progress_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.system import router as system_router
from src.webui.routers.websocket.auth import router as ws_auth_router
from src.webui.routers.websocket.unified import router as unified_ws_router
logger = get_logger("webui.api")
@@ -44,8 +44,6 @@ router.include_router(jargon_router)
router.include_router(emoji_router)
# 注册插件管理路由
router.include_router(plugin_router)
# 注册插件进度 WebSocket 路由
router.include_router(get_progress_router())
# 注册系统控制路由
router.include_router(system_router)
# 注册模型列表获取路由
@@ -54,6 +52,8 @@ router.include_router(model_router)
router.include_router(memory_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
# 注册统一 WebSocket 路由
router.include_router(unified_ws_router)
class TokenVerifyRequest(BaseModel):