插件系统代码风格修复
This commit is contained in:
@@ -31,7 +31,7 @@ class CapabilityService:
|
||||
4. 执行实际操作并返回结果
|
||||
"""
|
||||
|
||||
def __init__(self, policy_engine: PolicyEngine):
|
||||
def __init__(self, policy_engine: PolicyEngine) -> None:
|
||||
self._policy = policy_engine
|
||||
# capability_name -> implementation
|
||||
self._implementations: Dict[str, CapabilityImpl] = {}
|
||||
|
||||
@@ -11,10 +11,10 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_runtime.host.component_registry")
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class RegisteredComponent:
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: Dict[str, Any],
|
||||
):
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.full_name = f"{plugin_id}.{name}"
|
||||
self.component_type = component_type
|
||||
@@ -57,7 +57,7 @@ class ComponentRegistry:
|
||||
供业务层查询可用组件、匹配命令、调度 action/event 等。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# 全量索引
|
||||
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class PolicyEngine:
|
||||
管理所有插件的能力令牌,提供授权校验。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._tokens: Dict[str, CapabilityToken] = {}
|
||||
|
||||
def register_plugin(
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
4. 优雅关停
|
||||
"""
|
||||
|
||||
import logging as stdlib_logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging as stdlib_logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
@@ -559,9 +559,12 @@ class PluginSupervisor:
|
||||
task.add_done_callback(
|
||||
lambda done_task: None
|
||||
if self._stderr_drain_task is not done_task
|
||||
else setattr(self, "_stderr_drain_task", None)
|
||||
else self._clear_stderr_drain_task()
|
||||
)
|
||||
|
||||
def _clear_stderr_drain_task(self) -> None:
|
||||
self._stderr_drain_task = None
|
||||
|
||||
async def _drain_runner_stderr(
|
||||
self,
|
||||
stream: asyncio.StreamReader,
|
||||
@@ -578,8 +581,7 @@ class PluginSupervisor:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
message = line.decode(errors="replace").rstrip()
|
||||
if message:
|
||||
if message := line.decode(errors="replace").rstrip():
|
||||
# 将 stderr 输出以 WARNING 级展示:
|
||||
# 如果 Runner 正常运行,此流应当无输出;
|
||||
# 有输出说明进程级错误发生,需要出现在主进程日志中
|
||||
|
||||
@@ -54,7 +54,7 @@ class ModificationRecord:
|
||||
"""消息修改记录"""
|
||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]):
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
|
||||
self.stage = stage
|
||||
self.hook_name = hook_name
|
||||
self.timestamp = time.perf_counter()
|
||||
@@ -64,7 +64,7 @@ class ModificationRecord:
|
||||
class WorkflowContext:
|
||||
"""Workflow 执行上下文"""
|
||||
|
||||
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None):
|
||||
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
|
||||
self.trace_id = trace_id or uuid.uuid4().hex
|
||||
self.stream_id = stream_id
|
||||
self.timings: Dict[str, float] = {}
|
||||
@@ -92,7 +92,7 @@ class WorkflowResult:
|
||||
return_message: str = "",
|
||||
stopped_at: str = "",
|
||||
diagnostics: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> None:
|
||||
self.status = status
|
||||
self.return_message = return_message
|
||||
self.stopped_at = stopped_at
|
||||
@@ -109,7 +109,7 @@ class WorkflowExecutor:
|
||||
实现 stage-based pipeline + per-stage hook chain with priority + early return。
|
||||
"""
|
||||
|
||||
def __init__(self, registry: ComponentRegistry):
|
||||
def __init__(self, registry: ComponentRegistry) -> None:
|
||||
self._registry = registry
|
||||
|
||||
async def execute(
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -852,14 +853,14 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
def _serialize_stream(stream: Any) -> Dict[str, Any]:
|
||||
def _serialize_stream(stream: BotChatSession) -> Dict[str, Any]:
|
||||
"""将 BotChatSession 序列化为可通过 RPC 传输的字典"""
|
||||
return {
|
||||
"session_id": getattr(stream, "session_id", ""),
|
||||
"platform": getattr(stream, "platform", ""),
|
||||
"user_id": getattr(stream, "user_id", ""),
|
||||
"group_id": getattr(stream, "group_id", ""),
|
||||
"is_group_session": getattr(stream, "is_group_session", False),
|
||||
"session_id": stream.session_id,
|
||||
"platform": stream.platform,
|
||||
"user_id": stream.user_id,
|
||||
"group_id": stream.group_id,
|
||||
"is_group_session": stream.is_group_session,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -5,11 +5,12 @@
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import time
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import logging as stdlib_logging
|
||||
import time
|
||||
|
||||
|
||||
# ─── 协议常量 ──────────────────────────────────────────────────────
|
||||
@@ -35,7 +36,7 @@ class MessageType(str, Enum):
|
||||
class RequestIdGenerator:
|
||||
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
|
||||
|
||||
def __init__(self, start: int = 1):
|
||||
def __init__(self, start: int = 1) -> None:
|
||||
self._counter = start
|
||||
|
||||
def next(self) -> int:
|
||||
|
||||
@@ -39,7 +39,12 @@ class ErrorCode(str, Enum):
|
||||
class RPCError(Exception):
|
||||
"""RPC 调用异常"""
|
||||
|
||||
def __init__(self, code: ErrorCode, message: str = "", details: Optional[Dict[str, Any]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
code: ErrorCode,
|
||||
message: str = "",
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.code = code
|
||||
self.message = message or code.value
|
||||
self.details = details or {}
|
||||
|
||||
@@ -24,12 +24,13 @@ Host 端将其重放到主进程的 Logger(以 plugin.<name> 为名)中,
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from src.plugin_runtime.protocol.envelope import LogBatchPayload, LogEntry
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class ManifestValidator:
|
||||
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
|
||||
SUPPORTED_MANIFEST_VERSIONS = [1, 2]
|
||||
|
||||
def __init__(self, host_version: str = ""):
|
||||
def __init__(self, host_version: str = "") -> None:
|
||||
self._host_version = host_version
|
||||
self.errors: List[str] = []
|
||||
self.warnings: List[str] = []
|
||||
|
||||
@@ -30,7 +30,7 @@ class PluginMeta:
|
||||
plugin_dir: str,
|
||||
plugin_instance: Any,
|
||||
manifest: Dict[str, Any],
|
||||
):
|
||||
) -> None:
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_dir = plugin_dir
|
||||
self.instance = plugin_instance
|
||||
@@ -61,7 +61,7 @@ class PluginLoader:
|
||||
- plugin.py: 插件入口模块(导出 create_plugin 工厂函数)
|
||||
"""
|
||||
|
||||
def __init__(self, host_version: str = ""):
|
||||
def __init__(self, host_version: str = "") -> None:
|
||||
self._loaded_plugins: Dict[str, PluginMeta] = {}
|
||||
self._failed_plugins: Dict[str, str] = {}
|
||||
self._manifest_validator = ManifestValidator(host_version=host_version)
|
||||
|
||||
@@ -9,19 +9,17 @@
|
||||
6. 转发插件的能力调用到 Host
|
||||
"""
|
||||
|
||||
import logging as stdlib_logging
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging as stdlib_logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_console_handler, get_logger, initialize_logging
|
||||
from src.plugin_runtime import ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
@@ -543,9 +541,12 @@ async def _async_main() -> None:
|
||||
runner = PluginRunner(host_address, session_token, plugin_dirs)
|
||||
|
||||
# 注册信号处理
|
||||
def _mark_runner_shutting_down() -> None:
|
||||
runner._shutting_down = True
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, lambda: setattr(runner, "_shutting_down", True))
|
||||
loop.add_signal_handler(sig, _mark_runner_shutting_down)
|
||||
|
||||
await runner.run()
|
||||
|
||||
|
||||
@@ -7,12 +7,12 @@
|
||||
分帧协议:4-byte big-endian length prefix + payload
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import struct
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
|
||||
# 分帧常量
|
||||
FRAME_HEADER_SIZE = 4 # 4 字节长度前缀
|
||||
MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小
|
||||
@@ -23,13 +23,13 @@ class ConnectionClosed(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Connection(ABC):
|
||||
class Connection:
|
||||
"""单个连接的抽象
|
||||
|
||||
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
||||
"""
|
||||
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._closed = False
|
||||
@@ -57,19 +57,16 @@ class Connection(ABC):
|
||||
if length > MAX_FRAME_SIZE:
|
||||
raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}")
|
||||
# 读取 payload
|
||||
payload = await self._reader.readexactly(length)
|
||||
return payload
|
||||
return await self._reader.readexactly(length)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭连接"""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
self._writer.close()
|
||||
await self._writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
|
||||
@@ -19,7 +19,7 @@ class TCPConnection(Connection):
|
||||
class TCPTransportServer(TransportServer):
|
||||
"""TCP 传输服务端(回退方案)"""
|
||||
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 0):
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 0) -> None:
|
||||
self._host = host
|
||||
self._port = port # 0 表示自动分配
|
||||
self._server: Optional[asyncio.AbstractServer] = None
|
||||
@@ -52,7 +52,7 @@ class TCPTransportServer(TransportServer):
|
||||
class TCPTransportClient(TransportClient):
|
||||
"""TCP 传输客户端"""
|
||||
|
||||
def __init__(self, host: str, port: int):
|
||||
def __init__(self, host: str, port: int) -> None:
|
||||
self._host = host
|
||||
self._port = port
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ _UDS_PATH_MAX = 104
|
||||
class UDSTransportServer(TransportServer):
|
||||
"""UDS 传输服务端"""
|
||||
|
||||
def __init__(self, socket_path: Optional[str] = None):
|
||||
def __init__(self, socket_path: Optional[str] = None) -> None:
|
||||
if socket_path is None:
|
||||
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
||||
import uuid
|
||||
@@ -80,7 +80,7 @@ class UDSTransportServer(TransportServer):
|
||||
class UDSTransportClient(TransportClient):
|
||||
"""UDS 传输客户端"""
|
||||
|
||||
def __init__(self, socket_path: str):
|
||||
def __init__(self, socket_path: str) -> None:
|
||||
self._socket_path = socket_path
|
||||
|
||||
async def connect(self) -> Connection:
|
||||
|
||||
Reference in New Issue
Block a user