插件系统代码风格修复
This commit is contained in:
@@ -31,7 +31,7 @@ class CapabilityService:
|
|||||||
4. 执行实际操作并返回结果
|
4. 执行实际操作并返回结果
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, policy_engine: PolicyEngine):
|
def __init__(self, policy_engine: PolicyEngine) -> None:
|
||||||
self._policy = policy_engine
|
self._policy = policy_engine
|
||||||
# capability_name -> implementation
|
# capability_name -> implementation
|
||||||
self._implementations: Dict[str, CapabilityImpl] = {}
|
self._implementations: Dict[str, CapabilityImpl] = {}
|
||||||
|
|||||||
@@ -11,10 +11,10 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("plugin_runtime.host.component_registry")
|
logger = get_logger("plugin_runtime.host.component_registry")
|
||||||
|
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ class RegisteredComponent:
|
|||||||
component_type: str,
|
component_type: str,
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
metadata: Dict[str, Any],
|
metadata: Dict[str, Any],
|
||||||
):
|
) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.full_name = f"{plugin_id}.{name}"
|
self.full_name = f"{plugin_id}.{name}"
|
||||||
self.component_type = component_type
|
self.component_type = component_type
|
||||||
@@ -57,7 +57,7 @@ class ComponentRegistry:
|
|||||||
供业务层查询可用组件、匹配命令、调度 action/event 等。
|
供业务层查询可用组件、匹配命令、调度 action/event 等。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
# 全量索引
|
# 全量索引
|
||||||
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
|
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class PolicyEngine:
|
|||||||
管理所有插件的能力令牌,提供授权校验。
|
管理所有插件的能力令牌,提供授权校验。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._tokens: Dict[str, CapabilityToken] = {}
|
self._tokens: Dict[str, CapabilityToken] = {}
|
||||||
|
|
||||||
def register_plugin(
|
def register_plugin(
|
||||||
|
|||||||
@@ -7,11 +7,11 @@
|
|||||||
4. 优雅关停
|
4. 优雅关停
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging as stdlib_logging
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import logging as stdlib_logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -559,9 +559,12 @@ class PluginSupervisor:
|
|||||||
task.add_done_callback(
|
task.add_done_callback(
|
||||||
lambda done_task: None
|
lambda done_task: None
|
||||||
if self._stderr_drain_task is not done_task
|
if self._stderr_drain_task is not done_task
|
||||||
else setattr(self, "_stderr_drain_task", None)
|
else self._clear_stderr_drain_task()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _clear_stderr_drain_task(self) -> None:
|
||||||
|
self._stderr_drain_task = None
|
||||||
|
|
||||||
async def _drain_runner_stderr(
|
async def _drain_runner_stderr(
|
||||||
self,
|
self,
|
||||||
stream: asyncio.StreamReader,
|
stream: asyncio.StreamReader,
|
||||||
@@ -578,8 +581,7 @@ class PluginSupervisor:
|
|||||||
line = await stream.readline()
|
line = await stream.readline()
|
||||||
if not line:
|
if not line:
|
||||||
break
|
break
|
||||||
message = line.decode(errors="replace").rstrip()
|
if message := line.decode(errors="replace").rstrip():
|
||||||
if message:
|
|
||||||
# 将 stderr 输出以 WARNING 级展示:
|
# 将 stderr 输出以 WARNING 级展示:
|
||||||
# 如果 Runner 正常运行,此流应当无输出;
|
# 如果 Runner 正常运行,此流应当无输出;
|
||||||
# 有输出说明进程级错误发生,需要出现在主进程日志中
|
# 有输出说明进程级错误发生,需要出现在主进程日志中
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class ModificationRecord:
|
|||||||
"""消息修改记录"""
|
"""消息修改记录"""
|
||||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||||
|
|
||||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]):
|
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
|
||||||
self.stage = stage
|
self.stage = stage
|
||||||
self.hook_name = hook_name
|
self.hook_name = hook_name
|
||||||
self.timestamp = time.perf_counter()
|
self.timestamp = time.perf_counter()
|
||||||
@@ -64,7 +64,7 @@ class ModificationRecord:
|
|||||||
class WorkflowContext:
|
class WorkflowContext:
|
||||||
"""Workflow 执行上下文"""
|
"""Workflow 执行上下文"""
|
||||||
|
|
||||||
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None):
|
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
|
||||||
self.trace_id = trace_id or uuid.uuid4().hex
|
self.trace_id = trace_id or uuid.uuid4().hex
|
||||||
self.stream_id = stream_id
|
self.stream_id = stream_id
|
||||||
self.timings: Dict[str, float] = {}
|
self.timings: Dict[str, float] = {}
|
||||||
@@ -92,7 +92,7 @@ class WorkflowResult:
|
|||||||
return_message: str = "",
|
return_message: str = "",
|
||||||
stopped_at: str = "",
|
stopped_at: str = "",
|
||||||
diagnostics: Optional[Dict[str, Any]] = None,
|
diagnostics: Optional[Dict[str, Any]] = None,
|
||||||
):
|
) -> None:
|
||||||
self.status = status
|
self.status = status
|
||||||
self.return_message = return_message
|
self.return_message = return_message
|
||||||
self.stopped_at = stopped_at
|
self.stopped_at = stopped_at
|
||||||
@@ -109,7 +109,7 @@ class WorkflowExecutor:
|
|||||||
实现 stage-based pipeline + per-stage hook chain with priority + early return。
|
实现 stage-based pipeline + per-stage hook chain with priority + early return。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, registry: ComponentRegistry):
|
def __init__(self, registry: ComponentRegistry) -> None:
|
||||||
self._registry = registry
|
self._registry = registry
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from src.chat.message_receive.chat_manager import BotChatSession
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
@@ -852,14 +853,14 @@ class PluginRuntimeManager:
|
|||||||
# ═════════════════════════════════════════════════════════
|
# ═════════════════════════════════════════════════════════
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_stream(stream: Any) -> Dict[str, Any]:
|
def _serialize_stream(stream: BotChatSession) -> Dict[str, Any]:
|
||||||
"""将 BotChatSession 序列化为可通过 RPC 传输的字典"""
|
"""将 BotChatSession 序列化为可通过 RPC 传输的字典"""
|
||||||
return {
|
return {
|
||||||
"session_id": getattr(stream, "session_id", ""),
|
"session_id": stream.session_id,
|
||||||
"platform": getattr(stream, "platform", ""),
|
"platform": stream.platform,
|
||||||
"user_id": getattr(stream, "user_id", ""),
|
"user_id": stream.user_id,
|
||||||
"group_id": getattr(stream, "group_id", ""),
|
"group_id": stream.group_id,
|
||||||
"is_group_session": getattr(stream, "is_group_session", False),
|
"is_group_session": stream.is_group_session,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -5,11 +5,12 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import time
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import logging as stdlib_logging
|
import logging as stdlib_logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
# ─── 协议常量 ──────────────────────────────────────────────────────
|
# ─── 协议常量 ──────────────────────────────────────────────────────
|
||||||
@@ -35,7 +36,7 @@ class MessageType(str, Enum):
|
|||||||
class RequestIdGenerator:
|
class RequestIdGenerator:
|
||||||
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
|
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
|
||||||
|
|
||||||
def __init__(self, start: int = 1):
|
def __init__(self, start: int = 1) -> None:
|
||||||
self._counter = start
|
self._counter = start
|
||||||
|
|
||||||
def next(self) -> int:
|
def next(self) -> int:
|
||||||
|
|||||||
@@ -39,7 +39,12 @@ class ErrorCode(str, Enum):
|
|||||||
class RPCError(Exception):
|
class RPCError(Exception):
|
||||||
"""RPC 调用异常"""
|
"""RPC 调用异常"""
|
||||||
|
|
||||||
def __init__(self, code: ErrorCode, message: str = "", details: Optional[Dict[str, Any]] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
code: ErrorCode,
|
||||||
|
message: str = "",
|
||||||
|
details: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
self.code = code
|
self.code = code
|
||||||
self.message = message or code.value
|
self.message = message or code.value
|
||||||
self.details = details or {}
|
self.details = details or {}
|
||||||
|
|||||||
@@ -24,12 +24,13 @@ Host 端将其重放到主进程的 Logger(以 plugin.<name> 为名)中,
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
|
||||||
|
|
||||||
from src.plugin_runtime.protocol.envelope import LogBatchPayload, LogEntry
|
from src.plugin_runtime.protocol.envelope import LogBatchPayload, LogEntry
|
||||||
|
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class ManifestValidator:
|
|||||||
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
|
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
|
||||||
SUPPORTED_MANIFEST_VERSIONS = [1, 2]
|
SUPPORTED_MANIFEST_VERSIONS = [1, 2]
|
||||||
|
|
||||||
def __init__(self, host_version: str = ""):
|
def __init__(self, host_version: str = "") -> None:
|
||||||
self._host_version = host_version
|
self._host_version = host_version
|
||||||
self.errors: List[str] = []
|
self.errors: List[str] = []
|
||||||
self.warnings: List[str] = []
|
self.warnings: List[str] = []
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class PluginMeta:
|
|||||||
plugin_dir: str,
|
plugin_dir: str,
|
||||||
plugin_instance: Any,
|
plugin_instance: Any,
|
||||||
manifest: Dict[str, Any],
|
manifest: Dict[str, Any],
|
||||||
):
|
) -> None:
|
||||||
self.plugin_id = plugin_id
|
self.plugin_id = plugin_id
|
||||||
self.plugin_dir = plugin_dir
|
self.plugin_dir = plugin_dir
|
||||||
self.instance = plugin_instance
|
self.instance = plugin_instance
|
||||||
@@ -61,7 +61,7 @@ class PluginLoader:
|
|||||||
- plugin.py: 插件入口模块(导出 create_plugin 工厂函数)
|
- plugin.py: 插件入口模块(导出 create_plugin 工厂函数)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, host_version: str = ""):
|
def __init__(self, host_version: str = "") -> None:
|
||||||
self._loaded_plugins: Dict[str, PluginMeta] = {}
|
self._loaded_plugins: Dict[str, PluginMeta] = {}
|
||||||
self._failed_plugins: Dict[str, str] = {}
|
self._failed_plugins: Dict[str, str] = {}
|
||||||
self._manifest_validator = ManifestValidator(host_version=host_version)
|
self._manifest_validator = ManifestValidator(host_version=host_version)
|
||||||
|
|||||||
@@ -9,19 +9,17 @@
|
|||||||
6. 转发插件的能力调用到 Host
|
6. 转发插件的能力调用到 Host
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging as stdlib_logging
|
from typing import Any, List, Optional
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging as stdlib_logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from src.common.logger import get_console_handler, get_logger, initialize_logging
|
from src.common.logger import get_console_handler, get_logger, initialize_logging
|
||||||
from src.plugin_runtime import ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
from src.plugin_runtime import ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
||||||
from src.plugin_runtime.protocol.envelope import (
|
from src.plugin_runtime.protocol.envelope import (
|
||||||
@@ -543,9 +541,12 @@ async def _async_main() -> None:
|
|||||||
runner = PluginRunner(host_address, session_token, plugin_dirs)
|
runner = PluginRunner(host_address, session_token, plugin_dirs)
|
||||||
|
|
||||||
# 注册信号处理
|
# 注册信号处理
|
||||||
|
def _mark_runner_shutting_down() -> None:
|
||||||
|
runner._shutting_down = True
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||||
loop.add_signal_handler(sig, lambda: setattr(runner, "_shutting_down", True))
|
loop.add_signal_handler(sig, _mark_runner_shutting_down)
|
||||||
|
|
||||||
await runner.run()
|
await runner.run()
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,12 @@
|
|||||||
分帧协议:4-byte big-endian length prefix + payload
|
分帧协议:4-byte big-endian length prefix + payload
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import struct
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Awaitable, Callable
|
from typing import Awaitable, Callable
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import struct
|
|
||||||
|
|
||||||
# 分帧常量
|
# 分帧常量
|
||||||
FRAME_HEADER_SIZE = 4 # 4 字节长度前缀
|
FRAME_HEADER_SIZE = 4 # 4 字节长度前缀
|
||||||
MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小
|
MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小
|
||||||
@@ -23,13 +23,13 @@ class ConnectionClosed(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Connection(ABC):
|
class Connection:
|
||||||
"""单个连接的抽象
|
"""单个连接的抽象
|
||||||
|
|
||||||
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||||
self._reader = reader
|
self._reader = reader
|
||||||
self._writer = writer
|
self._writer = writer
|
||||||
self._closed = False
|
self._closed = False
|
||||||
@@ -57,19 +57,16 @@ class Connection(ABC):
|
|||||||
if length > MAX_FRAME_SIZE:
|
if length > MAX_FRAME_SIZE:
|
||||||
raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}")
|
raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}")
|
||||||
# 读取 payload
|
# 读取 payload
|
||||||
payload = await self._reader.readexactly(length)
|
return await self._reader.readexactly(length)
|
||||||
return payload
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""关闭连接"""
|
"""关闭连接"""
|
||||||
if self._closed:
|
if self._closed:
|
||||||
return
|
return
|
||||||
self._closed = True
|
self._closed = True
|
||||||
try:
|
with contextlib.suppress(Exception):
|
||||||
self._writer.close()
|
self._writer.close()
|
||||||
await self._writer.wait_closed()
|
await self._writer.wait_closed()
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class TCPConnection(Connection):
|
|||||||
class TCPTransportServer(TransportServer):
|
class TCPTransportServer(TransportServer):
|
||||||
"""TCP 传输服务端(回退方案)"""
|
"""TCP 传输服务端(回退方案)"""
|
||||||
|
|
||||||
def __init__(self, host: str = "127.0.0.1", port: int = 0):
|
def __init__(self, host: str = "127.0.0.1", port: int = 0) -> None:
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port # 0 表示自动分配
|
self._port = port # 0 表示自动分配
|
||||||
self._server: Optional[asyncio.AbstractServer] = None
|
self._server: Optional[asyncio.AbstractServer] = None
|
||||||
@@ -52,7 +52,7 @@ class TCPTransportServer(TransportServer):
|
|||||||
class TCPTransportClient(TransportClient):
|
class TCPTransportClient(TransportClient):
|
||||||
"""TCP 传输客户端"""
|
"""TCP 传输客户端"""
|
||||||
|
|
||||||
def __init__(self, host: str, port: int):
|
def __init__(self, host: str, port: int) -> None:
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ _UDS_PATH_MAX = 104
|
|||||||
class UDSTransportServer(TransportServer):
|
class UDSTransportServer(TransportServer):
|
||||||
"""UDS 传输服务端"""
|
"""UDS 传输服务端"""
|
||||||
|
|
||||||
def __init__(self, socket_path: Optional[str] = None):
|
def __init__(self, socket_path: Optional[str] = None) -> None:
|
||||||
if socket_path is None:
|
if socket_path is None:
|
||||||
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
||||||
import uuid
|
import uuid
|
||||||
@@ -80,7 +80,7 @@ class UDSTransportServer(TransportServer):
|
|||||||
class UDSTransportClient(TransportClient):
|
class UDSTransportClient(TransportClient):
|
||||||
"""UDS 传输客户端"""
|
"""UDS 传输客户端"""
|
||||||
|
|
||||||
def __init__(self, socket_path: str):
|
def __init__(self, socket_path: str) -> None:
|
||||||
self._socket_path = socket_path
|
self._socket_path = socket_path
|
||||||
|
|
||||||
async def connect(self) -> Connection:
|
async def connect(self) -> Connection:
|
||||||
|
|||||||
Reference in New Issue
Block a user