Ruff Format

This commit is contained in:
DrSmoothl
2026-03-13 11:45:26 +08:00
parent 2a510312bc
commit a576313b22
70 changed files with 956 additions and 731 deletions

View File

@@ -13,4 +13,3 @@ ENV_SESSION_TOKEN = "MAIBOT_SESSION_TOKEN"
ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
"""Runner 需要加载的插件目录列表os.pathsep 分隔)"""

View File

@@ -67,7 +67,9 @@ class CapabilityService:
# 1. 权限校验
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
if not allowed:
error_code = ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
error_code = (
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
)
return envelope.make_error_response(
error_code.value,
reason,

View File

@@ -22,8 +22,13 @@ class RegisteredComponent:
"""已注册的组件条目"""
__slots__ = (
"name", "full_name", "component_type", "plugin_id",
"metadata", "enabled", "_compiled_pattern",
"name",
"full_name",
"component_type",
"plugin_id",
"metadata",
"enabled",
"_compiled_pattern",
)
def __init__(
@@ -165,18 +170,14 @@ class ComponentRegistry:
"""按全名查询。"""
return self._components.get(full_name)
def get_components_by_type(
self, component_type: str, *, enabled_only: bool = True
) -> List[RegisteredComponent]:
def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""按类型查询。"""
type_dict = self._by_type.get(component_type, {})
if enabled_only:
return [c for c in type_dict.values() if c.enabled]
return list(type_dict.values())
def get_components_by_plugin(
self, plugin_id: str, *, enabled_only: bool = True
) -> List[RegisteredComponent]:
def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""按插件查询。"""
comps = self._by_plugin.get(plugin_id, [])
return [c for c in comps if c.enabled] if enabled_only else list(comps)
@@ -200,9 +201,7 @@ class ComponentRegistry:
return comp, {}
return None
def get_event_handlers(
self, event_type: str, *, enabled_only: bool = True
) -> List[RegisteredComponent]:
def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""获取特定事件类型的所有 event_handler按 weight 降序排列。"""
handlers = []
for comp in self._by_type.get("event_handler", {}).values():
@@ -213,9 +212,7 @@ class ComponentRegistry:
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
return handlers
def get_workflow_steps(
self, stage: str, *, enabled_only: bool = True
) -> List[RegisteredComponent]:
def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
steps = []
for comp in self._by_type.get("workflow_step", {}).values():

View File

@@ -22,6 +22,7 @@ InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
class EventResult:
"""单个 EventHandler 的执行结果"""
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
def __init__(
@@ -107,9 +108,7 @@ class EventDispatcher:
modified_message = result.modified_message
else:
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
task = asyncio.create_task(
self._invoke_handler(invoke_fn, handler, args, event_type)
)
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)

View File

@@ -11,6 +11,7 @@ from typing import Dict, List, Optional, Set, Tuple
@dataclass
class CapabilityToken:
"""能力令牌"""
plugin_id: str
generation: int
capabilities: Set[str] = field(default_factory=set)

View File

@@ -231,9 +231,7 @@ class RPCServer:
stale_count = 0
for _req_id, future in list(self._pending_requests.items()):
if not future.done():
future.set_exception(
RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管")
)
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管"))
stale_count += 1
self._pending_requests.clear()
if stale_count:
@@ -399,9 +397,7 @@ class RPCServer:
result = await handler(envelope)
# 检查 handler 返回的信封是否包含错误信息
if result is not None and isinstance(result, Envelope) and result.error:
logger.warning(
f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}"
)
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)

View File

@@ -39,6 +39,7 @@ logger = get_logger("plugin_runtime.host.supervisor")
# ─── 日志桥 ──────────────────────────────────────────────────────
class RunnerLogBridge:
"""将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
@@ -80,9 +81,7 @@ class RunnerLogBridge:
stdlib_logging.getLogger(entry.logger_name).handle(record)
return envelope.make_response(
payload={"accepted": True, "count": len(batch.entries)}
)
return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
class PluginSupervisor:
@@ -101,8 +100,12 @@ class PluginSupervisor:
):
_cfg = global_config.plugin_runtime
self._plugin_dirs = plugin_dirs or []
self._health_interval = health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec
self._runner_spawn_timeout = runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec
self._health_interval = (
health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec
)
self._runner_spawn_timeout = (
runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec
)
# 基础设施
self._transport = create_transport_server(socket_path=socket_path)
@@ -114,6 +117,7 @@ class PluginSupervisor:
# 编解码
from src.plugin_runtime.protocol.codec import MsgPackCodec
codec = MsgPackCodec()
self._rpc_server = RPCServer(
@@ -124,7 +128,9 @@ class PluginSupervisor:
# Runner 子进程
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._runner_generation: int = 0
self._max_restart_attempts: int = max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts
self._max_restart_attempts: int = (
max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts
)
self._restart_count: int = 0
# 已注册的插件组件信息
@@ -173,6 +179,7 @@ class PluginSupervisor:
extra_args: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""分发事件到所有对应 handler 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
resp = await self.invoke_plugin(
method="plugin.emit_event",
@@ -196,6 +203,7 @@ class PluginSupervisor:
context: Optional[WorkflowContext] = None,
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
"""执行 Workflow Pipeline 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
resp = await self.invoke_plugin(
method="plugin.invoke_workflow_step",
@@ -415,7 +423,9 @@ class PluginSupervisor:
env[ENV_PLUGIN_DIRS] = os.pathsep.join(self._plugin_dirs)
self._runner_process = await asyncio.create_subprocess_exec(
sys.executable, "-m", runner_module,
sys.executable,
"-m",
runner_module,
env=env,
# stdout 不捕获Runner 的日志均通过 IPC 传㛹RunnerIPCLogHandler
stdout=None,
@@ -557,9 +567,7 @@ class PluginSupervisor:
)
self._stderr_drain_task = task
task.add_done_callback(
lambda done_task: None
if self._stderr_drain_task is not done_task
else self._clear_stderr_drain_task()
lambda done_task: None if self._stderr_drain_task is not done_task else self._clear_stderr_drain_task()
)
def _clear_stderr_drain_task(self) -> None:

View File

@@ -44,6 +44,7 @@ HOOK_CONTINUE = "continue"
HOOK_SKIP_STAGE = "skip_stage"
HOOK_ABORT = "abort"
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
# 从配置文件读取,允许用户调整
def _get_blocking_timeout() -> float:
@@ -52,6 +53,7 @@ def _get_blocking_timeout() -> float:
class ModificationRecord:
"""消息修改记录"""
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
@@ -141,9 +143,7 @@ class WorkflowExecutor:
try:
# PLAN 阶段: 先做 Command 路由
if stage == "plan" and current_message:
cmd_result = await self._route_command(
command_invoke_fn or invoke_fn, current_message, ctx
)
cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
if cmd_result is not None:
# 命令匹配成功,跳过 PLAN 阶段的 hook直接存结果进 stage_outputs
ctx.set_stage_output("plan", "command_result", cmd_result)
@@ -195,10 +195,10 @@ class WorkflowExecutor:
# 更新消息(仅 blocking hook 有权修改)
if modified:
changed_fields = _diff_keys(current_message, modified) if current_message else list(modified.keys())
ctx.modification_log.append(
ModificationRecord(stage, step.full_name, changed_fields)
changed_fields = (
_diff_keys(current_message, modified) if current_message else list(modified.keys())
)
ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
current_message = modified
if hook_result == HOOK_ABORT:
@@ -222,9 +222,7 @@ class WorkflowExecutor:
if nonblocking_steps and not skip_stage:
nb_tasks = [
asyncio.create_task(
self._invoke_step_fire_and_forget(
invoke_fn, step, stage, ctx, current_message
)
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
)
for step in nonblocking_steps
]
@@ -314,12 +312,16 @@ class WorkflowExecutor:
step_start = time.perf_counter()
try:
coro = invoke_fn(step.plugin_id, step.name, {
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
})
coro = invoke_fn(
step.plugin_id,
step.name,
{
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
},
)
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
ctx.timings[step_key] = time.perf_counter() - step_start
@@ -355,12 +357,16 @@ class WorkflowExecutor:
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
try:
coro = invoke_fn(step.plugin_id, step.name, {
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
})
coro = invoke_fn(
step.plugin_id,
step.name,
{
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
},
)
await asyncio.wait_for(coro, timeout=timeout_sec)
except asyncio.TimeoutError:
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
@@ -393,12 +399,16 @@ class WorkflowExecutor:
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
try:
return await invoke_fn(matched.plugin_id, matched.name, {
"text": plain_text,
"message": message,
"trace_id": ctx.trace_id,
"matched_groups": matched_groups,
})
return await invoke_fn(
matched.plugin_id,
matched.name,
{
"text": plain_text,
"message": message,
"trace_id": ctx.trace_id,
"matched_groups": matched_groups,
},
)
except Exception as e:
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
ctx.errors.append(f"command:{matched.full_name}: {e}")

View File

@@ -113,9 +113,7 @@ class PluginRuntimeManager:
await self._thirdparty_supervisor.start()
started_supervisors.append(self._thirdparty_supervisor)
self._started = True
logger.info(
f"插件运行时已启动 — 内置: {builtin_dirs or ''}, 第三方: {thirdparty_dirs or ''}"
)
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or ''}, 第三方: {thirdparty_dirs or ''}")
except Exception as e:
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
@@ -303,7 +301,9 @@ class PluginRuntimeManager:
cap_service.register_capability("component.get_all_plugins", self._cap_component_get_all_plugins)
cap_service.register_capability("component.get_plugin_info", self._cap_component_get_plugin_info)
cap_service.register_capability("component.list_loaded_plugins", self._cap_component_list_loaded_plugins)
cap_service.register_capability("component.list_registered_plugins", self._cap_component_list_registered_plugins)
cap_service.register_capability(
"component.list_registered_plugins", self._cap_component_list_registered_plugins
)
cap_service.register_capability("component.enable", self._cap_component_enable)
cap_service.register_capability("component.disable", self._cap_component_disable)
cap_service.register_capability("component.load_plugin", self._cap_component_load_plugin)
@@ -1232,9 +1232,7 @@ class PluginRuntimeManager:
count: int = args.get("count", 1)
try:
results = await emoji_api.get_random(count=count)
emojis = [
{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results
]
emojis = [{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results]
return {"success": True, "emojis": emojis}
except Exception as e:
logger.error(f"[cap.emoji.get_random] 执行失败: {e}", exc_info=True)
@@ -1269,9 +1267,9 @@ class PluginRuntimeManager:
try:
results = await emoji_api.get_all()
emojis = [
{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results
] if results else []
emojis = (
[{"base64": b64, "description": desc, "emotion": emo} for b64, desc, emo in results] if results else []
)
return {"success": True, "emojis": emojis}
except Exception as e:
logger.error(f"[cap.emoji.get_all] 执行失败: {e}", exc_info=True)

View File

@@ -12,20 +12,16 @@ class Codec(ABC):
"""消息编解码器基类"""
@abstractmethod
def encode_envelope(self, envelope: Envelope) -> bytes:
...
def encode_envelope(self, envelope: Envelope) -> bytes: ...
@abstractmethod
def decode_envelope(self, data: bytes) -> Envelope:
...
def decode_envelope(self, data: bytes) -> Envelope: ...
@abstractmethod
def encode(self, obj: Dict[str, Any]) -> bytes:
...
def encode(self, obj: Dict[str, Any]) -> bytes: ...
@abstractmethod
def decode(self, data: bytes) -> Dict[str, Any]:
...
def decode(self, data: bytes) -> Dict[str, Any]: ...
class MsgPackCodec(Codec):

View File

@@ -24,8 +24,10 @@ MAX_SDK_VERSION = "1.99.99"
# ─── 消息类型 ──────────────────────────────────────────────────────
class MessageType(str, Enum):
"""RPC 消息类型"""
REQUEST = "request"
RESPONSE = "response"
EVENT = "event"
@@ -33,6 +35,7 @@ class MessageType(str, Enum):
# ─── 请求 ID 生成器 ───────────────────────────────────────────────
class RequestIdGenerator:
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio"""
@@ -47,6 +50,7 @@ class RequestIdGenerator:
# ─── Envelope 模型 ─────────────────────────────────────────────────
class Envelope(BaseModel):
"""RPC 统一信封
@@ -75,7 +79,9 @@ class Envelope(BaseModel):
def is_event(self) -> bool:
return self.message_type == MessageType.EVENT
def make_response(self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None) -> "Envelope":
def make_response(
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
) -> "Envelope":
"""基于当前请求创建对应的响应信封"""
return Envelope(
protocol_version=self.protocol_version,
@@ -101,8 +107,10 @@ class Envelope(BaseModel):
# ─── 握手消息 ──────────────────────────────────────────────────────
class HelloPayload(BaseModel):
"""runner.hello 握手请求 payload"""
runner_id: str = Field(description="Runner 进程唯一标识")
sdk_version: str = Field(description="SDK 版本号")
session_token: str = Field(description="一次性会话令牌")
@@ -110,6 +118,7 @@ class HelloPayload(BaseModel):
class HelloResponsePayload(BaseModel):
"""runner.hello 握手响应 payload"""
accepted: bool = Field(description="是否接受连接")
host_version: str = Field(default="", description="Host 版本号")
assigned_generation: int = Field(default=0, description="分配的 generation 编号")
@@ -118,8 +127,10 @@ class HelloResponsePayload(BaseModel):
# ─── 组件注册消息 ──────────────────────────────────────────────────
class ComponentDeclaration(BaseModel):
"""单个组件声明"""
name: str = Field(description="组件名称")
component_type: str = Field(description="组件类型: action/command/tool/event_handler")
plugin_id: str = Field(description="所属插件 ID")
@@ -128,6 +139,7 @@ class ComponentDeclaration(BaseModel):
class RegisterComponentsPayload(BaseModel):
"""plugin.register_components 请求 payload"""
plugin_id: str = Field(description="插件 ID")
plugin_version: str = Field(default="1.0.0", description="插件版本")
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
@@ -136,36 +148,44 @@ class RegisterComponentsPayload(BaseModel):
# ─── 调用消息 ──────────────────────────────────────────────────────
class InvokePayload(BaseModel):
"""plugin.invoke_* 请求 payload"""
component_name: str = Field(description="要调用的组件名称")
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
class InvokeResultPayload(BaseModel):
"""plugin.invoke_* 响应 payload"""
success: bool = Field(description="是否成功")
result: Any = Field(default=None, description="返回值")
# ─── 能力调用消息 ──────────────────────────────────────────────────
class CapabilityRequestPayload(BaseModel):
"""cap.* 请求 payload插件 -> Host 能力调用)"""
capability: str = Field(description="能力名称,如 send.text, db.query")
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
class CapabilityResponsePayload(BaseModel):
"""cap.* 响应 payload"""
success: bool = Field(description="是否成功")
result: Any = Field(default=None, description="返回值")
# ─── 健康检查 ──────────────────────────────────────────────────────
class HealthPayload(BaseModel):
"""plugin.health 响应 payload"""
healthy: bool = Field(description="是否健康")
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
uptime_ms: int = Field(default=0, description="运行时长(ms)")
@@ -173,11 +193,13 @@ class HealthPayload(BaseModel):
# ─── 配置更新 ──────────────────────────────────────────────────────
# TODO: Host 侧尚未实现配置变更检测与推送。Runner 端的 _handle_config_updated
# 已就绪,但当前无任何调用方通过 RPC 发送 plugin.config_updated 消息。
# 需要在 Supervisor 或 CapabilityService 中监听配置文件变化并主动推送。
class ConfigUpdatedPayload(BaseModel):
"""plugin.config_updated 事件 payload"""
plugin_id: str = Field(description="插件 ID")
config_version: str = Field(description="新配置版本")
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
@@ -185,14 +207,17 @@ class ConfigUpdatedPayload(BaseModel):
# ─── 关停 ──────────────────────────────────────────────────────────
class ShutdownPayload(BaseModel):
"""plugin.shutdown / plugin.prepare_shutdown payload"""
reason: str = Field(default="normal", description="关停原因")
drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
# ─── 日志传输 ──────────────────────────────────────────────────────
class LogEntry(BaseModel):
"""单条日志记录Runner → Host 传输格式)"""
@@ -200,10 +225,7 @@ class LogEntry(BaseModel):
description="日志时间戳Unix epoch 毫秒",
)
level: int = Field(
description=(
"stdlib logging 整数级别:"
" 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"
),
description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"),
)
logger_name: str = Field(
description="Logger 名称,如 plugin.my_plugin.submodule",

View File

@@ -22,6 +22,7 @@ Host 端将其重放到主进程的 Logger以 plugin.<name> 为名)中,
- 后台刷新协程每 FLUSH_INTERVAL_SEC 秒或 FLUSH_BATCH_SIZE 条后批量发送
- IPC 发送失败时静默忽略stderr fallback 由 supervisor 的 drain task 覆盖
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
@@ -203,6 +204,7 @@ class RunnerIPCLogHandler(logging.Handler):
# IPC 连接断开时回退到 stderr避免日志静默丢失
if not self._rpc_client.is_connected:
import sys
for entry in entries:
print(
f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}",
@@ -218,6 +220,7 @@ class RunnerIPCLogHandler(logging.Handler):
)
except Exception:
import sys
for entry in entries:
print(
f"[LOG-FALLBACK] [{entry.logger_name}] {entry.message}",

View File

@@ -105,9 +105,7 @@ class ManifestValidator:
def _check_manifest_version(self, manifest: Dict[str, Any]) -> None:
mv = manifest.get("manifest_version")
if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
self.errors.append(
f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}"
)
self.errors.append(f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}")
def _check_author(self, manifest: Dict[str, Any]) -> None:
author = manifest.get("author")

View File

@@ -240,8 +240,7 @@ class PluginLoader:
instance = self._try_load_legacy_plugin(module, plugin_id)
if instance is not None:
logger.info(
f"插件 {plugin_id} v{manifest.get('version', '?')} "
f"通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk"
f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk"
)
return PluginMeta(
plugin_id=plugin_id,
@@ -261,6 +260,7 @@ class PluginLoader:
return
try:
from maibot_sdk.compat._import_hook import install_hook
install_hook()
self._compat_hook_installed = True
except ImportError:
@@ -281,11 +281,7 @@ class PluginLoader:
for attr_name in dir(module):
obj = getattr(module, attr_name, None)
if (
isinstance(obj, type)
and issubclass(obj, LegacyBasePlugin)
and obj is not LegacyBasePlugin
):
if isinstance(obj, type) and issubclass(obj, LegacyBasePlugin) and obj is not LegacyBasePlugin:
legacy_cls = obj
break
@@ -294,6 +290,7 @@ class PluginLoader:
try:
from maibot_sdk.compat.legacy_adapter import LegacyPluginAdapter
legacy_instance = legacy_cls()
return LegacyPluginAdapter(legacy_instance)
except Exception as e:

View File

@@ -37,6 +37,7 @@ def _get_sdk_version() -> str:
"""从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
try:
from importlib.metadata import version
return version("maibot-plugin-sdk")
except Exception:
return "1.0.0"

View File

@@ -133,7 +133,9 @@ class PluginRunner:
self._suspend_console_handlers()
stdlib_logging.root.addHandler(handler)
self._log_handler = handler
logger.debug("RunnerIPCLogHandler \u5df2\u5b89\u88c3\uff0c\u63d2\u4ef6\u65e5\u5fd7\u5c06\u901a\u8fc7 IPC \u8f6c\u53d1\u5230\u4e3b\u8fdb\u7a0b")
logger.debug(
"RunnerIPCLogHandler \u5df2\u5b89\u88c3\uff0c\u63d2\u4ef6\u65e5\u5fd7\u5c06\u901a\u8fc7 IPC \u8f6c\u53d1\u5230\u4e3b\u8fdb\u7a0b"
)
async def _uninstall_log_handler(self) -> None:
"""关停前从 logging.root 移除 Handler 并刷空缓冲。
@@ -291,7 +293,11 @@ class PluginRunner:
)
try:
result = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
result = (
await handler_method(**invoke.args)
if inspect.iscoroutinefunction(handler_method)
else handler_method(**invoke.args)
)
resp_payload = InvokeResultPayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except Exception as e:
@@ -332,7 +338,11 @@ class PluginRunner:
)
try:
raw = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
raw = (
await handler_method(**invoke.args)
if inspect.iscoroutinefunction(handler_method)
else handler_method(**invoke.args)
)
# 规范化返回值:将 EventHandler 返回展平到 payload 顶层
if raw is None:
@@ -341,7 +351,9 @@ class PluginRunner:
result = {
"success": True,
# 兼容 guide.md 中文档的 {"blocked": True} 写法
"continue_processing": not raw.get("blocked", False) if "blocked" in raw else raw.get("continue_processing", True),
"continue_processing": not raw.get("blocked", False)
if "blocked" in raw
else raw.get("continue_processing", True),
"modified_message": raw.get("modified_message"),
"custom_result": raw.get("custom_result"),
}
@@ -383,7 +395,11 @@ class PluginRunner:
)
try:
raw = await handler_method(**invoke.args) if inspect.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
raw = (
await handler_method(**invoke.args)
if inspect.iscoroutinefunction(handler_method)
else handler_method(**invoke.args)
)
# 规范化返回值
if isinstance(raw, str):
@@ -455,6 +471,7 @@ class PluginRunner:
# ─── sys.path 隔离 ────────────────────────────────────────
def _isolate_sys_path(plugin_dirs: List[str]) -> None:
"""清理 sys.path限制 Runner 子进程只能访问标准库、SDK 和插件目录。
@@ -504,9 +521,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
return self if self._should_block(fullname) else None
def load_module(self, fullname):
raise ImportError(
f"Runner 子进程不允许导入主程序模块: {fullname}"
)
raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}")
def _should_block(self, fullname: str) -> bool:
# 放行非 src.* 的导入、以及 "src" 本身
@@ -514,8 +529,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
return False
# 放行白名单前缀
return not any(
fullname == prefix or fullname.startswith(f"{prefix}.")
for prefix in self._ALLOWED_SRC_PREFIXES
fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES
)
sys.meta_path.insert(0, _PluginImportBlocker())
@@ -523,6 +537,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
# ─── 进程入口 ──────────────────────────────────────────────
async def _async_main() -> None:
"""异步主入口"""
host_address = os.environ.get(ENV_IPC_ADDRESS, "")

View File

@@ -20,6 +20,7 @@ MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小
class ConnectionClosed(Exception):
"""连接已关闭"""
pass

View File

@@ -20,10 +20,12 @@ def create_transport_server(socket_path: Optional[str] = None) -> TransportServe
"""
if sys.platform != "win32":
from .uds import UDSTransportServer
return UDSTransportServer(socket_path=socket_path)
else:
# Windows 回退到 TCP后续可改为 Named Pipe
from .tcp import TCPTransportServer
return TCPTransportServer()
@@ -39,9 +41,11 @@ def create_transport_client(address: str) -> TransportClient:
"""
if "/" in address or address.endswith(".sock"):
from .uds import UDSTransportClient
return UDSTransportClient(socket_path=address)
elif ":" in address:
from .tcp import TCPTransportClient
host, port_str = address.rsplit(":", 1)
return TCPTransportClient(host=host, port=int(port_str))
else:

View File

@@ -13,6 +13,7 @@ from .base import Connection, ConnectionHandler, TransportClient, TransportServe
class TCPConnection(Connection):
"""基于 TCP 的连接"""
pass

View File

@@ -15,6 +15,7 @@ from .base import Connection, ConnectionHandler, TransportClient, TransportServe
class UDSConnection(Connection):
"""基于 UDS 的连接"""
pass # 直接复用 Connection 基类的分帧读写
@@ -30,16 +31,17 @@ class UDSTransportServer(TransportServer):
if socket_path is None:
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
import uuid
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
socket_path = os.path.join(
tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
)
# 如果路径超出 UDS 限制,回退到更短的路径
if len(socket_path.encode()) > _UDS_PATH_MAX:
socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
if len(socket_path.encode()) > _UDS_PATH_MAX:
raise OSError(
f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}"
)
raise OSError(f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
self._socket_path = socket_path
self._server: Optional[asyncio.AbstractServer] = None