From 8ac0aff4796651ecf69005e87d430472512d0d8e Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 13 Mar 2026 11:07:19 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=92=E4=BB=B6=E7=B3=BB=E7=BB=9F=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E9=A3=8E=E6=A0=BC=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/host/capability_service.py | 2 +- src/plugin_runtime/host/component_registry.py | 8 ++++---- src/plugin_runtime/host/policy_engine.py | 2 +- src/plugin_runtime/host/supervisor.py | 10 ++++++---- src/plugin_runtime/host/workflow_executor.py | 8 ++++---- src/plugin_runtime/integration.py | 13 +++++++------ src/plugin_runtime/protocol/envelope.py | 7 ++++--- src/plugin_runtime/protocol/errors.py | 7 ++++++- src/plugin_runtime/runner/log_handler.py | 3 ++- src/plugin_runtime/runner/manifest_validator.py | 2 +- src/plugin_runtime/runner/plugin_loader.py | 4 ++-- src/plugin_runtime/runner/runner_main.py | 11 ++++++----- src/plugin_runtime/transport/base.py | 17 +++++++---------- src/plugin_runtime/transport/tcp.py | 4 ++-- src/plugin_runtime/transport/uds.py | 4 ++-- 15 files changed, 55 insertions(+), 47 deletions(-) diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index 643fef95..3f3a000c 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -31,7 +31,7 @@ class CapabilityService: 4. 执行实际操作并返回结果 """ - def __init__(self, policy_engine: PolicyEngine): + def __init__(self, policy_engine: PolicyEngine) -> None: self._policy = policy_engine # capability_name -> implementation self._implementations: Dict[str, CapabilityImpl] = {} diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index d693883f..4abbd08e 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -11,10 +11,10 @@ from typing import Any, Dict, List, Optional -from src.common.logger import get_logger - import re +from src.common.logger import get_logger + logger = get_logger("plugin_runtime.host.component_registry") @@ -32,7 +32,7 @@ class RegisteredComponent: component_type: str, plugin_id: str, metadata: Dict[str, Any], - ): + ) -> None: self.name = name self.full_name = f"{plugin_id}.{name}" self.component_type = component_type @@ -57,7 +57,7 @@ class ComponentRegistry: 供业务层查询可用组件、匹配命令、调度 action/event 等。 """ - def __init__(self): + def __init__(self) -> None: # 全量索引 self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp diff --git a/src/plugin_runtime/host/policy_engine.py b/src/plugin_runtime/host/policy_engine.py index 8327e56c..9ce7ccde 100644 --- a/src/plugin_runtime/host/policy_engine.py +++ b/src/plugin_runtime/host/policy_engine.py @@ -22,7 +22,7 @@ class PolicyEngine: 管理所有插件的能力令牌,提供授权校验。 """ - def __init__(self): + def __init__(self) -> None: self._tokens: Dict[str, CapabilityToken] = {} def register_plugin( diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 6e1c22fe..e65a9f66 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -7,11 +7,11 @@ 4. 优雅关停 """ -import logging as stdlib_logging from typing import Any, Dict, List, Optional, Tuple import asyncio import contextlib +import logging as stdlib_logging import os import sys @@ -559,9 +559,12 @@ class PluginSupervisor: task.add_done_callback( lambda done_task: None if self._stderr_drain_task is not done_task - else setattr(self, "_stderr_drain_task", None) + else self._clear_stderr_drain_task() ) + def _clear_stderr_drain_task(self) -> None: + self._stderr_drain_task = None + async def _drain_runner_stderr( self, stream: asyncio.StreamReader, @@ -578,8 +581,7 @@ class PluginSupervisor: line = await stream.readline() if not line: break - message = line.decode(errors="replace").rstrip() - if message: + if message := line.decode(errors="replace").rstrip(): # 将 stderr 输出以 WARNING 级展示: # 如果 Runner 正常运行,此流应当无输出; # 有输出说明进程级错误发生,需要出现在主进程日志中 diff --git a/src/plugin_runtime/host/workflow_executor.py b/src/plugin_runtime/host/workflow_executor.py index 316f6c10..60fe7d07 100644 --- a/src/plugin_runtime/host/workflow_executor.py +++ b/src/plugin_runtime/host/workflow_executor.py @@ -54,7 +54,7 @@ class ModificationRecord: """消息修改记录""" __slots__ = ("stage", "hook_name", "timestamp", "fields_changed") - def __init__(self, stage: str, hook_name: str, fields_changed: List[str]): + def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None: self.stage = stage self.hook_name = hook_name self.timestamp = time.perf_counter() @@ -64,7 +64,7 @@ class ModificationRecord: class WorkflowContext: """Workflow 执行上下文""" - def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None): + def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None: self.trace_id = trace_id or uuid.uuid4().hex self.stream_id = stream_id self.timings: Dict[str, float] = {} @@ -92,7 +92,7 @@ class WorkflowResult: return_message: str = "", stopped_at: str = "", diagnostics: Optional[Dict[str, Any]] = None, - ): + ) -> None: self.status = status self.return_message = return_message self.stopped_at = stopped_at @@ -109,7 +109,7 @@ class WorkflowExecutor: 实现 stage-based pipeline + per-stage hook chain with priority + early return。 """ - def __init__(self, registry: ComponentRegistry): + def __init__(self, registry: ComponentRegistry) -> None: self._registry = registry async def execute( diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 8842fa31..d070ef15 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Tuple import asyncio import os +from src.chat.message_receive.chat_manager import BotChatSession from src.common.logger import get_logger from src.config.config import global_config @@ -852,14 +853,14 @@ class PluginRuntimeManager: # ═════════════════════════════════════════════════════════ @staticmethod - def _serialize_stream(stream: Any) -> Dict[str, Any]: + def _serialize_stream(stream: BotChatSession) -> Dict[str, Any]: """将 BotChatSession 序列化为可通过 RPC 传输的字典""" return { - "session_id": getattr(stream, "session_id", ""), - "platform": getattr(stream, "platform", ""), - "user_id": getattr(stream, "user_id", ""), - "group_id": getattr(stream, "group_id", ""), - "is_group_session": getattr(stream, "is_group_session", False), + "session_id": stream.session_id, + "platform": stream.platform, + "user_id": stream.user_id, + "group_id": stream.group_id, + "is_group_session": stream.is_group_session, } @staticmethod diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 41fcf764..0b276c7c 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -5,11 +5,12 @@ """ from enum import Enum -from pydantic import BaseModel, Field from typing import Any, Dict, List, Optional -import time +from pydantic import BaseModel, Field + import logging as stdlib_logging +import time # ─── 协议常量 ────────────────────────────────────────────────────── @@ -35,7 +36,7 @@ class MessageType(str, Enum): class RequestIdGenerator: """单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)""" - def __init__(self, start: int = 1): + def __init__(self, start: int = 1) -> None: self._counter = start def next(self) -> int: diff --git a/src/plugin_runtime/protocol/errors.py b/src/plugin_runtime/protocol/errors.py index 30242fdc..dcae6b8f 100644 --- a/src/plugin_runtime/protocol/errors.py +++ b/src/plugin_runtime/protocol/errors.py @@ -39,7 +39,12 @@ class ErrorCode(str, Enum): class RPCError(Exception): """RPC 调用异常""" - def __init__(self, code: ErrorCode, message: str = "", details: Optional[Dict[str, Any]] = None): + def __init__( + self, + code: ErrorCode, + message: str = "", + details: Optional[Dict[str, Any]] = None, + ) -> None: self.code = code self.message = message or code.value self.details = details or {} diff --git a/src/plugin_runtime/runner/log_handler.py b/src/plugin_runtime/runner/log_handler.py index 3151f7c1..74b4a3e9 100644 --- a/src/plugin_runtime/runner/log_handler.py +++ b/src/plugin_runtime/runner/log_handler.py @@ -24,12 +24,13 @@ Host 端将其重放到主进程的 Logger(以 plugin. 为名)中, """ from __future__ import annotations +from typing import TYPE_CHECKING, List, Optional + import asyncio import collections import contextlib import json import logging -from typing import TYPE_CHECKING, List, Optional from src.plugin_runtime.protocol.envelope import LogBatchPayload, LogEntry diff --git a/src/plugin_runtime/runner/manifest_validator.py b/src/plugin_runtime/runner/manifest_validator.py index d9368d11..bc407e6d 100644 --- a/src/plugin_runtime/runner/manifest_validator.py +++ b/src/plugin_runtime/runner/manifest_validator.py @@ -70,7 +70,7 @@ class ManifestValidator: RECOMMENDED_FIELDS = ["license", "keywords", "categories"] SUPPORTED_MANIFEST_VERSIONS = [1, 2] - def __init__(self, host_version: str = ""): + def __init__(self, host_version: str = "") -> None: self._host_version = host_version self.errors: List[str] = [] self.warnings: List[str] = [] diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 00201d16..50767363 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -30,7 +30,7 @@ class PluginMeta: plugin_dir: str, plugin_instance: Any, manifest: Dict[str, Any], - ): + ) -> None: self.plugin_id = plugin_id self.plugin_dir = plugin_dir self.instance = plugin_instance @@ -61,7 +61,7 @@ class PluginLoader: - plugin.py: 插件入口模块(导出 create_plugin 工厂函数) """ - def __init__(self, host_version: str = ""): + def __init__(self, host_version: str = "") -> None: self._loaded_plugins: Dict[str, PluginMeta] = {} self._failed_plugins: Dict[str, str] = {} self._manifest_validator = ManifestValidator(host_version=host_version) diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 24ddaf24..46e9b641 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -9,19 +9,17 @@ 6. 转发插件的能力调用到 Host """ -import logging as stdlib_logging -from typing import List, Optional +from typing import Any, List, Optional import asyncio import contextlib import inspect +import logging as stdlib_logging import os import signal import sys import time -from typing import Any - from src.common.logger import get_console_handler, get_logger, initialize_logging from src.plugin_runtime import ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( @@ -543,9 +541,12 @@ async def _async_main() -> None: runner = PluginRunner(host_address, session_token, plugin_dirs) # 注册信号处理 + def _mark_runner_shutting_down() -> None: + runner._shutting_down = True + loop = asyncio.get_event_loop() for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler(sig, lambda: setattr(runner, "_shutting_down", True)) + loop.add_signal_handler(sig, _mark_runner_shutting_down) await runner.run() diff --git a/src/plugin_runtime/transport/base.py b/src/plugin_runtime/transport/base.py index 19dc40a0..229de0b1 100644 --- a/src/plugin_runtime/transport/base.py +++ b/src/plugin_runtime/transport/base.py @@ -7,12 +7,12 @@ 分帧协议:4-byte big-endian length prefix + payload """ +import asyncio +import contextlib +import struct from abc import ABC, abstractmethod from typing import Awaitable, Callable -import asyncio -import struct - # 分帧常量 FRAME_HEADER_SIZE = 4 # 4 字节长度前缀 MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小 @@ -23,13 +23,13 @@ class ConnectionClosed(Exception): pass -class Connection(ABC): +class Connection: """单个连接的抽象 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。 """ - def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: self._reader = reader self._writer = writer self._closed = False @@ -57,19 +57,16 @@ class Connection(ABC): if length > MAX_FRAME_SIZE: raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}") # 读取 payload - payload = await self._reader.readexactly(length) - return payload + return await self._reader.readexactly(length) async def close(self) -> None: """关闭连接""" if self._closed: return self._closed = True - try: + with contextlib.suppress(Exception): self._writer.close() await self._writer.wait_closed() - except Exception: - pass @property def is_closed(self) -> bool: diff --git a/src/plugin_runtime/transport/tcp.py b/src/plugin_runtime/transport/tcp.py index acf9cf87..870d33b3 100644 --- a/src/plugin_runtime/transport/tcp.py +++ b/src/plugin_runtime/transport/tcp.py @@ -19,7 +19,7 @@ class TCPConnection(Connection): class TCPTransportServer(TransportServer): """TCP 传输服务端(回退方案)""" - def __init__(self, host: str = "127.0.0.1", port: int = 0): + def __init__(self, host: str = "127.0.0.1", port: int = 0) -> None: self._host = host self._port = port # 0 表示自动分配 self._server: Optional[asyncio.AbstractServer] = None @@ -52,7 +52,7 @@ class TCPTransportServer(TransportServer): class TCPTransportClient(TransportClient): """TCP 传输客户端""" - def __init__(self, host: str, port: int): + def __init__(self, host: str, port: int) -> None: self._host = host self._port = port diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py index 209645a5..0ea885e7 100644 --- a/src/plugin_runtime/transport/uds.py +++ b/src/plugin_runtime/transport/uds.py @@ -26,7 +26,7 @@ _UDS_PATH_MAX = 104 class UDSTransportServer(TransportServer): """UDS 传输服务端""" - def __init__(self, socket_path: Optional[str] = None): + def __init__(self, socket_path: Optional[str] = None) -> None: if socket_path is None: # 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞 import uuid @@ -80,7 +80,7 @@ class UDSTransportServer(TransportServer): class UDSTransportClient(TransportClient): """UDS 传输客户端""" - def __init__(self, socket_path: str): + def __init__(self, socket_path: str) -> None: self._socket_path = socket_path async def connect(self) -> Connection: