插件系统代码风格修复

This commit is contained in:
DrSmoothl
2026-03-13 11:07:19 +08:00
parent bcb7963d37
commit 8ac0aff479
15 changed files with 55 additions and 47 deletions

View File

@@ -31,7 +31,7 @@ class CapabilityService:
4. 执行实际操作并返回结果 4. 执行实际操作并返回结果
""" """
def __init__(self, policy_engine: PolicyEngine): def __init__(self, policy_engine: PolicyEngine) -> None:
self._policy = policy_engine self._policy = policy_engine
# capability_name -> implementation # capability_name -> implementation
self._implementations: Dict[str, CapabilityImpl] = {} self._implementations: Dict[str, CapabilityImpl] = {}

View File

@@ -11,10 +11,10 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
import re import re
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.component_registry") logger = get_logger("plugin_runtime.host.component_registry")
@@ -32,7 +32,7 @@ class RegisteredComponent:
component_type: str, component_type: str,
plugin_id: str, plugin_id: str,
metadata: Dict[str, Any], metadata: Dict[str, Any],
): ) -> None:
self.name = name self.name = name
self.full_name = f"{plugin_id}.{name}" self.full_name = f"{plugin_id}.{name}"
self.component_type = component_type self.component_type = component_type
@@ -57,7 +57,7 @@ class ComponentRegistry:
供业务层查询可用组件、匹配命令、调度 action/event 等。 供业务层查询可用组件、匹配命令、调度 action/event 等。
""" """
def __init__(self): def __init__(self) -> None:
# 全量索引 # 全量索引
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp

View File

@@ -22,7 +22,7 @@ class PolicyEngine:
管理所有插件的能力令牌,提供授权校验。 管理所有插件的能力令牌,提供授权校验。
""" """
def __init__(self): def __init__(self) -> None:
self._tokens: Dict[str, CapabilityToken] = {} self._tokens: Dict[str, CapabilityToken] = {}
def register_plugin( def register_plugin(

View File

@@ -7,11 +7,11 @@
4. 优雅关停 4. 优雅关停
""" """
import logging as stdlib_logging
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import asyncio import asyncio
import contextlib import contextlib
import logging as stdlib_logging
import os import os
import sys import sys
@@ -559,9 +559,12 @@ class PluginSupervisor:
task.add_done_callback( task.add_done_callback(
lambda done_task: None lambda done_task: None
if self._stderr_drain_task is not done_task if self._stderr_drain_task is not done_task
else setattr(self, "_stderr_drain_task", None) else self._clear_stderr_drain_task()
) )
def _clear_stderr_drain_task(self) -> None:
self._stderr_drain_task = None
async def _drain_runner_stderr( async def _drain_runner_stderr(
self, self,
stream: asyncio.StreamReader, stream: asyncio.StreamReader,
@@ -578,8 +581,7 @@ class PluginSupervisor:
line = await stream.readline() line = await stream.readline()
if not line: if not line:
break break
message = line.decode(errors="replace").rstrip() if message := line.decode(errors="replace").rstrip():
if message:
# 将 stderr 输出以 WARNING 级展示: # 将 stderr 输出以 WARNING 级展示:
# 如果 Runner 正常运行,此流应当无输出; # 如果 Runner 正常运行,此流应当无输出;
# 有输出说明进程级错误发生,需要出现在主进程日志中 # 有输出说明进程级错误发生,需要出现在主进程日志中

View File

@@ -54,7 +54,7 @@ class ModificationRecord:
"""消息修改记录""" """消息修改记录"""
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed") __slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]): def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
self.stage = stage self.stage = stage
self.hook_name = hook_name self.hook_name = hook_name
self.timestamp = time.perf_counter() self.timestamp = time.perf_counter()
@@ -64,7 +64,7 @@ class ModificationRecord:
class WorkflowContext: class WorkflowContext:
"""Workflow 执行上下文""" """Workflow 执行上下文"""
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None): def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
self.trace_id = trace_id or uuid.uuid4().hex self.trace_id = trace_id or uuid.uuid4().hex
self.stream_id = stream_id self.stream_id = stream_id
self.timings: Dict[str, float] = {} self.timings: Dict[str, float] = {}
@@ -92,7 +92,7 @@ class WorkflowResult:
return_message: str = "", return_message: str = "",
stopped_at: str = "", stopped_at: str = "",
diagnostics: Optional[Dict[str, Any]] = None, diagnostics: Optional[Dict[str, Any]] = None,
): ) -> None:
self.status = status self.status = status
self.return_message = return_message self.return_message = return_message
self.stopped_at = stopped_at self.stopped_at = stopped_at
@@ -109,7 +109,7 @@ class WorkflowExecutor:
实现 stage-based pipeline + per-stage hook chain with priority + early return。 实现 stage-based pipeline + per-stage hook chain with priority + early return。
""" """
def __init__(self, registry: ComponentRegistry): def __init__(self, registry: ComponentRegistry) -> None:
self._registry = registry self._registry = registry
async def execute( async def execute(

View File

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Tuple
import asyncio import asyncio
import os import os
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -852,14 +853,14 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
def _serialize_stream(stream: Any) -> Dict[str, Any]: def _serialize_stream(stream: BotChatSession) -> Dict[str, Any]:
"""将 BotChatSession 序列化为可通过 RPC 传输的字典""" """将 BotChatSession 序列化为可通过 RPC 传输的字典"""
return { return {
"session_id": getattr(stream, "session_id", ""), "session_id": stream.session_id,
"platform": getattr(stream, "platform", ""), "platform": stream.platform,
"user_id": getattr(stream, "user_id", ""), "user_id": stream.user_id,
"group_id": getattr(stream, "group_id", ""), "group_id": stream.group_id,
"is_group_session": getattr(stream, "is_group_session", False), "is_group_session": stream.is_group_session,
} }
@staticmethod @staticmethod

View File

@@ -5,11 +5,12 @@
""" """
from enum import Enum from enum import Enum
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import time from pydantic import BaseModel, Field
import logging as stdlib_logging import logging as stdlib_logging
import time
# ─── 协议常量 ────────────────────────────────────────────────────── # ─── 协议常量 ──────────────────────────────────────────────────────
@@ -35,7 +36,7 @@ class MessageType(str, Enum):
class RequestIdGenerator: class RequestIdGenerator:
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio""" """单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio"""
def __init__(self, start: int = 1): def __init__(self, start: int = 1) -> None:
self._counter = start self._counter = start
def next(self) -> int: def next(self) -> int:

View File

@@ -39,7 +39,12 @@ class ErrorCode(str, Enum):
class RPCError(Exception): class RPCError(Exception):
"""RPC 调用异常""" """RPC 调用异常"""
def __init__(self, code: ErrorCode, message: str = "", details: Optional[Dict[str, Any]] = None): def __init__(
self,
code: ErrorCode,
message: str = "",
details: Optional[Dict[str, Any]] = None,
) -> None:
self.code = code self.code = code
self.message = message or code.value self.message = message or code.value
self.details = details or {} self.details = details or {}

View File

@@ -24,12 +24,13 @@ Host 端将其重放到主进程的 Logger以 plugin.<name> 为名)中,
""" """
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
import asyncio import asyncio
import collections import collections
import contextlib import contextlib
import json import json
import logging import logging
from typing import TYPE_CHECKING, List, Optional
from src.plugin_runtime.protocol.envelope import LogBatchPayload, LogEntry from src.plugin_runtime.protocol.envelope import LogBatchPayload, LogEntry

View File

@@ -70,7 +70,7 @@ class ManifestValidator:
RECOMMENDED_FIELDS = ["license", "keywords", "categories"] RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
SUPPORTED_MANIFEST_VERSIONS = [1, 2] SUPPORTED_MANIFEST_VERSIONS = [1, 2]
def __init__(self, host_version: str = ""): def __init__(self, host_version: str = "") -> None:
self._host_version = host_version self._host_version = host_version
self.errors: List[str] = [] self.errors: List[str] = []
self.warnings: List[str] = [] self.warnings: List[str] = []

View File

@@ -30,7 +30,7 @@ class PluginMeta:
plugin_dir: str, plugin_dir: str,
plugin_instance: Any, plugin_instance: Any,
manifest: Dict[str, Any], manifest: Dict[str, Any],
): ) -> None:
self.plugin_id = plugin_id self.plugin_id = plugin_id
self.plugin_dir = plugin_dir self.plugin_dir = plugin_dir
self.instance = plugin_instance self.instance = plugin_instance
@@ -61,7 +61,7 @@ class PluginLoader:
- plugin.py: 插件入口模块(导出 create_plugin 工厂函数) - plugin.py: 插件入口模块(导出 create_plugin 工厂函数)
""" """
def __init__(self, host_version: str = ""): def __init__(self, host_version: str = "") -> None:
self._loaded_plugins: Dict[str, PluginMeta] = {} self._loaded_plugins: Dict[str, PluginMeta] = {}
self._failed_plugins: Dict[str, str] = {} self._failed_plugins: Dict[str, str] = {}
self._manifest_validator = ManifestValidator(host_version=host_version) self._manifest_validator = ManifestValidator(host_version=host_version)

View File

@@ -9,19 +9,17 @@
6. 转发插件的能力调用到 Host 6. 转发插件的能力调用到 Host
""" """
import logging as stdlib_logging from typing import Any, List, Optional
from typing import List, Optional
import asyncio import asyncio
import contextlib import contextlib
import inspect import inspect
import logging as stdlib_logging
import os import os
import signal import signal
import sys import sys
import time import time
from typing import Any
from src.common.logger import get_console_handler, get_logger, initialize_logging from src.common.logger import get_console_handler, get_logger, initialize_logging
from src.plugin_runtime import ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime import ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import ( from src.plugin_runtime.protocol.envelope import (
@@ -543,9 +541,12 @@ async def _async_main() -> None:
runner = PluginRunner(host_address, session_token, plugin_dirs) runner = PluginRunner(host_address, session_token, plugin_dirs)
# 注册信号处理 # 注册信号处理
def _mark_runner_shutting_down() -> None:
runner._shutting_down = True
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT): for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, lambda: setattr(runner, "_shutting_down", True)) loop.add_signal_handler(sig, _mark_runner_shutting_down)
await runner.run() await runner.run()

View File

@@ -7,12 +7,12 @@
分帧协议4-byte big-endian length prefix + payload 分帧协议4-byte big-endian length prefix + payload
""" """
import asyncio
import contextlib
import struct
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Awaitable, Callable from typing import Awaitable, Callable
import asyncio
import struct
# 分帧常量 # 分帧常量
FRAME_HEADER_SIZE = 4 # 4 字节长度前缀 FRAME_HEADER_SIZE = 4 # 4 字节长度前缀
MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小 MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小
@@ -23,13 +23,13 @@ class ConnectionClosed(Exception):
pass pass
class Connection(ABC): class Connection:
"""单个连接的抽象 """单个连接的抽象
封装了底层 StreamReader/StreamWriter提供分帧读写能力。 封装了底层 StreamReader/StreamWriter提供分帧读写能力。
""" """
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
self._reader = reader self._reader = reader
self._writer = writer self._writer = writer
self._closed = False self._closed = False
@@ -57,19 +57,16 @@ class Connection(ABC):
if length > MAX_FRAME_SIZE: if length > MAX_FRAME_SIZE:
raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}") raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}")
# 读取 payload # 读取 payload
payload = await self._reader.readexactly(length) return await self._reader.readexactly(length)
return payload
async def close(self) -> None: async def close(self) -> None:
"""关闭连接""" """关闭连接"""
if self._closed: if self._closed:
return return
self._closed = True self._closed = True
try: with contextlib.suppress(Exception):
self._writer.close() self._writer.close()
await self._writer.wait_closed() await self._writer.wait_closed()
except Exception:
pass
@property @property
def is_closed(self) -> bool: def is_closed(self) -> bool:

View File

@@ -19,7 +19,7 @@ class TCPConnection(Connection):
class TCPTransportServer(TransportServer): class TCPTransportServer(TransportServer):
"""TCP 传输服务端(回退方案)""" """TCP 传输服务端(回退方案)"""
def __init__(self, host: str = "127.0.0.1", port: int = 0): def __init__(self, host: str = "127.0.0.1", port: int = 0) -> None:
self._host = host self._host = host
self._port = port # 0 表示自动分配 self._port = port # 0 表示自动分配
self._server: Optional[asyncio.AbstractServer] = None self._server: Optional[asyncio.AbstractServer] = None
@@ -52,7 +52,7 @@ class TCPTransportServer(TransportServer):
class TCPTransportClient(TransportClient): class TCPTransportClient(TransportClient):
"""TCP 传输客户端""" """TCP 传输客户端"""
def __init__(self, host: str, port: int): def __init__(self, host: str, port: int) -> None:
self._host = host self._host = host
self._port = port self._port = port

View File

@@ -26,7 +26,7 @@ _UDS_PATH_MAX = 104
class UDSTransportServer(TransportServer): class UDSTransportServer(TransportServer):
"""UDS 传输服务端""" """UDS 传输服务端"""
def __init__(self, socket_path: Optional[str] = None): def __init__(self, socket_path: Optional[str] = None) -> None:
if socket_path is None: if socket_path is None:
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞 # 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
import uuid import uuid
@@ -80,7 +80,7 @@ class UDSTransportServer(TransportServer):
class UDSTransportClient(TransportClient): class UDSTransportClient(TransportClient):
"""UDS 传输客户端""" """UDS 传输客户端"""
def __init__(self, socket_path: str): def __init__(self, socket_path: str) -> None:
self._socket_path = socket_path self._socket_path = socket_path
async def connect(self) -> Connection: async def connect(self) -> Connection: