插件系统代码风格修复

This commit is contained in:
DrSmoothl
2026-03-13 11:07:19 +08:00
parent bcb7963d37
commit 8ac0aff479
15 changed files with 55 additions and 47 deletions

View File

@@ -31,7 +31,7 @@ class CapabilityService:
4. 执行实际操作并返回结果
"""
def __init__(self, policy_engine: PolicyEngine):
def __init__(self, policy_engine: PolicyEngine) -> None:
self._policy = policy_engine
# capability_name -> implementation
self._implementations: Dict[str, CapabilityImpl] = {}

View File

@@ -11,10 +11,10 @@
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
import re
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.component_registry")
@@ -32,7 +32,7 @@ class RegisteredComponent:
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
):
) -> None:
self.name = name
self.full_name = f"{plugin_id}.{name}"
self.component_type = component_type
@@ -57,7 +57,7 @@ class ComponentRegistry:
供业务层查询可用组件、匹配命令、调度 action/event 等。
"""
def __init__(self):
def __init__(self) -> None:
# 全量索引
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp

View File

@@ -22,7 +22,7 @@ class PolicyEngine:
管理所有插件的能力令牌,提供授权校验。
"""
def __init__(self):
def __init__(self) -> None:
self._tokens: Dict[str, CapabilityToken] = {}
def register_plugin(

View File

@@ -7,11 +7,11 @@
4. 优雅关停
"""
import logging as stdlib_logging
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import contextlib
import logging as stdlib_logging
import os
import sys
@@ -559,9 +559,12 @@ class PluginSupervisor:
task.add_done_callback(
lambda done_task: None
if self._stderr_drain_task is not done_task
else setattr(self, "_stderr_drain_task", None)
else self._clear_stderr_drain_task()
)
def _clear_stderr_drain_task(self) -> None:
self._stderr_drain_task = None
async def _drain_runner_stderr(
self,
stream: asyncio.StreamReader,
@@ -578,8 +581,7 @@ class PluginSupervisor:
line = await stream.readline()
if not line:
break
message = line.decode(errors="replace").rstrip()
if message:
if message := line.decode(errors="replace").rstrip():
# 将 stderr 输出以 WARNING 级展示:
# 如果 Runner 正常运行,此流应当无输出;
# 有输出说明进程级错误发生,需要出现在主进程日志中

View File

@@ -54,7 +54,7 @@ class ModificationRecord:
"""消息修改记录"""
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]):
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
self.stage = stage
self.hook_name = hook_name
self.timestamp = time.perf_counter()
@@ -64,7 +64,7 @@ class ModificationRecord:
class WorkflowContext:
"""Workflow 执行上下文"""
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None):
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
self.trace_id = trace_id or uuid.uuid4().hex
self.stream_id = stream_id
self.timings: Dict[str, float] = {}
@@ -92,7 +92,7 @@ class WorkflowResult:
return_message: str = "",
stopped_at: str = "",
diagnostics: Optional[Dict[str, Any]] = None,
):
) -> None:
self.status = status
self.return_message = return_message
self.stopped_at = stopped_at
@@ -109,7 +109,7 @@ class WorkflowExecutor:
实现 stage-based pipeline + per-stage hook chain with priority + early return。
"""
def __init__(self, registry: ComponentRegistry):
def __init__(self, registry: ComponentRegistry) -> None:
self._registry = registry
async def execute(

View File

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Tuple
import asyncio
import os
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.logger import get_logger
from src.config.config import global_config
@@ -852,14 +853,14 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════
@staticmethod
def _serialize_stream(stream: Any) -> Dict[str, Any]:
def _serialize_stream(stream: BotChatSession) -> Dict[str, Any]:
"""将 BotChatSession 序列化为可通过 RPC 传输的字典"""
return {
"session_id": getattr(stream, "session_id", ""),
"platform": getattr(stream, "platform", ""),
"user_id": getattr(stream, "user_id", ""),
"group_id": getattr(stream, "group_id", ""),
"is_group_session": getattr(stream, "is_group_session", False),
"session_id": stream.session_id,
"platform": stream.platform,
"user_id": stream.user_id,
"group_id": stream.group_id,
"is_group_session": stream.is_group_session,
}
@staticmethod

View File

@@ -5,11 +5,12 @@
"""
from enum import Enum
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional
import time
from pydantic import BaseModel, Field
import logging as stdlib_logging
import time
# ─── 协议常量 ──────────────────────────────────────────────────────
@@ -35,7 +36,7 @@ class MessageType(str, Enum):
class RequestIdGenerator:
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio"""
def __init__(self, start: int = 1):
def __init__(self, start: int = 1) -> None:
self._counter = start
def next(self) -> int:

View File

@@ -39,7 +39,12 @@ class ErrorCode(str, Enum):
class RPCError(Exception):
"""RPC 调用异常"""
def __init__(self, code: ErrorCode, message: str = "", details: Optional[Dict[str, Any]] = None):
def __init__(
self,
code: ErrorCode,
message: str = "",
details: Optional[Dict[str, Any]] = None,
) -> None:
self.code = code
self.message = message or code.value
self.details = details or {}

View File

@@ -24,12 +24,13 @@ Host 端将其重放到主进程的 Logger以 plugin.<name> 为名)中,
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
import asyncio
import collections
import contextlib
import json
import logging
from typing import TYPE_CHECKING, List, Optional
from src.plugin_runtime.protocol.envelope import LogBatchPayload, LogEntry

View File

@@ -70,7 +70,7 @@ class ManifestValidator:
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
SUPPORTED_MANIFEST_VERSIONS = [1, 2]
def __init__(self, host_version: str = ""):
def __init__(self, host_version: str = "") -> None:
self._host_version = host_version
self.errors: List[str] = []
self.warnings: List[str] = []

View File

@@ -30,7 +30,7 @@ class PluginMeta:
plugin_dir: str,
plugin_instance: Any,
manifest: Dict[str, Any],
):
) -> None:
self.plugin_id = plugin_id
self.plugin_dir = plugin_dir
self.instance = plugin_instance
@@ -61,7 +61,7 @@ class PluginLoader:
- plugin.py: 插件入口模块(导出 create_plugin 工厂函数)
"""
def __init__(self, host_version: str = ""):
def __init__(self, host_version: str = "") -> None:
self._loaded_plugins: Dict[str, PluginMeta] = {}
self._failed_plugins: Dict[str, str] = {}
self._manifest_validator = ManifestValidator(host_version=host_version)

View File

@@ -9,19 +9,17 @@
6. 转发插件的能力调用到 Host
"""
import logging as stdlib_logging
from typing import List, Optional
from typing import Any, List, Optional
import asyncio
import contextlib
import inspect
import logging as stdlib_logging
import os
import signal
import sys
import time
from typing import Any
from src.common.logger import get_console_handler, get_logger, initialize_logging
from src.plugin_runtime import ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
@@ -543,9 +541,12 @@ async def _async_main() -> None:
runner = PluginRunner(host_address, session_token, plugin_dirs)
# 注册信号处理
def _mark_runner_shutting_down() -> None:
runner._shutting_down = True
loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, lambda: setattr(runner, "_shutting_down", True))
loop.add_signal_handler(sig, _mark_runner_shutting_down)
await runner.run()

View File

@@ -7,12 +7,12 @@
分帧协议4-byte big-endian length prefix + payload
"""
import asyncio
import contextlib
import struct
from abc import ABC, abstractmethod
from typing import Awaitable, Callable
import asyncio
import struct
# 分帧常量
FRAME_HEADER_SIZE = 4 # 4 字节长度前缀
MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小
@@ -23,13 +23,13 @@ class ConnectionClosed(Exception):
pass
class Connection(ABC):
class Connection:
"""单个连接的抽象
封装了底层 StreamReader/StreamWriter提供分帧读写能力。
"""
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
self._reader = reader
self._writer = writer
self._closed = False
@@ -57,19 +57,16 @@ class Connection(ABC):
if length > MAX_FRAME_SIZE:
raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}")
# 读取 payload
payload = await self._reader.readexactly(length)
return payload
return await self._reader.readexactly(length)
async def close(self) -> None:
"""关闭连接"""
if self._closed:
return
self._closed = True
try:
with contextlib.suppress(Exception):
self._writer.close()
await self._writer.wait_closed()
except Exception:
pass
@property
def is_closed(self) -> bool:

View File

@@ -19,7 +19,7 @@ class TCPConnection(Connection):
class TCPTransportServer(TransportServer):
"""TCP 传输服务端(回退方案)"""
def __init__(self, host: str = "127.0.0.1", port: int = 0):
def __init__(self, host: str = "127.0.0.1", port: int = 0) -> None:
self._host = host
self._port = port # 0 表示自动分配
self._server: Optional[asyncio.AbstractServer] = None
@@ -52,7 +52,7 @@ class TCPTransportServer(TransportServer):
class TCPTransportClient(TransportClient):
"""TCP 传输客户端"""
def __init__(self, host: str, port: int):
def __init__(self, host: str, port: int) -> None:
self._host = host
self._port = port

View File

@@ -26,7 +26,7 @@ _UDS_PATH_MAX = 104
class UDSTransportServer(TransportServer):
"""UDS 传输服务端"""
def __init__(self, socket_path: Optional[str] = None):
def __init__(self, socket_path: Optional[str] = None) -> None:
if socket_path is None:
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
import uuid
@@ -80,7 +80,7 @@ class UDSTransportServer(TransportServer):
class UDSTransportClient(TransportClient):
"""UDS 传输客户端"""
def __init__(self, socket_path: str):
def __init__(self, socket_path: str) -> None:
self._socket_path = socket_path
async def connect(self) -> Connection: