feat: Implement adapter runtime state management and update handling

- Added support for adapter runtime state updates in the PluginRunnerSupervisor.
- Introduced new payload classes: AdapterStateUpdatePayload and AdapterStateUpdateResultPayload for handling state updates.
- Implemented methods to bind and unbind routes based on adapter connection status.
- Enhanced the NapCat adapter to report connection state and manage runtime state.
- Added tests for adapter runtime state synchronization and database session behavior in the statistic module.
- Updated existing methods to ensure proper handling of adapter state and route bindings.
This commit is contained in:
DrSmoothl
2026-03-21 21:47:22 +08:00
parent dd20cd4992
commit 4e2e7a279e
11 changed files with 1219 additions and 79 deletions

View File

@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
@staticmethod
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
records = session.exec(statement).all()
return [(record.start_timestamp, record.end_timestamp) for record in records]
@staticmethod
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
records = session.exec(statement).all()
return [
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1]
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
messages = session.exec(statement).all()
for message in messages:
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
try:
action_query_start_timestamp = collect_period[-1][1]
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
actions = session.exec(statement).all()
for action in actions:
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Optional, List
from typing import Any, List, Mapping, Optional, Sequence
import json
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
group_cardname: str
def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
"""将单条群名片数据规范化为统一结构。
Args:
raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
Returns:
Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
"""
group_id = str(raw_item.get("group_id") or "").strip()
group_cardname = str(raw_item.get("group_cardname") or "").strip()
if not group_id or not group_cardname:
return None
return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
"""解析数据库中的群名片 JSON 字段。
Args:
group_cardname_json: 数据库存储的群名片 JSON 字符串。
Returns:
Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
Raises:
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
"""
if not group_cardname_json:
return None
raw_items = json.loads(group_cardname_json)
if not isinstance(raw_items, list):
return None
normalized_items: List[GroupCardnameInfo] = []
for raw_item in raw_items:
if not isinstance(raw_item, Mapping):
continue
if normalized_item := _normalize_group_cardname_item(raw_item):
normalized_items.append(normalized_item)
return normalized_items or None
def dump_group_cardname_records(
group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
) -> str:
"""将群名片列表序列化为数据库使用的标准 JSON 字符串。
Args:
group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
对象和包含 `group_id` / `group_cardname` 的字典。
Returns:
str: 统一使用 `group_cardname` 键名的 JSON 字符串。
"""
normalized_items: List[GroupCardnameInfo] = []
for raw_item in group_cardname_records or []:
if isinstance(raw_item, GroupCardnameInfo):
normalized_items.append(raw_item)
continue
if isinstance(raw_item, Mapping):
if normalized_item := _normalize_group_cardname_item(raw_item):
normalized_items.append(normalized_item)
return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
def __init__(
self,
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
"""最后一次被认识的时间"""
@classmethod
def from_db_instance(cls, db_record: "PersonInfo"):
nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
"""从数据库记录构造人物信息数据模型。
Args:
db_record: 数据库中的人物信息记录。
Returns:
MaiPersonInfo: 转换后的数据模型对象。
"""
group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
return cls(
is_known=db_record.is_known,
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
)
def to_db_instance(self) -> "PersonInfo":
group_cardname = (
json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
)
"""将当前数据模型转换为数据库记录对象。
Returns:
PersonInfo: 可直接写入数据库的模型实例。
"""
group_cardname = dump_group_cardname_records(self.group_cardname_list)
return PersonInfo(
is_known=self.is_known,
person_id=self.person_id,

View File

@@ -1,22 +1,24 @@
import hashlib
from datetime import datetime
from typing import Dict, Optional, Union
import asyncio
import hashlib
import json
import time
import random
import math
import random
import time
from json_repair import repair_json
from typing import Union, Optional, Dict
from datetime import datetime
from sqlmodel import col, select
from src.common.logger import get_logger
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.llm_models.utils_model import LLMRequest
logger = get_logger("person_info")
@@ -26,6 +28,32 @@ relation_selection_model = LLMRequest(
)
def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
"""将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
Args:
group_cardname_json: 数据库存储的群名片 JSON 字符串。
Returns:
list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
Raises:
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
"""
group_cardname_list = parse_group_cardname_json(group_cardname_json)
if not group_cardname_list:
return []
return [
{
"group_id": group_cardname.group_id,
"group_cardname": group_cardname.group_cardname,
}
for group_cardname in group_cardname_list
]
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
if "-" in platform:
@@ -231,7 +259,7 @@ class Person:
person.know_since = time.time()
person.last_know = time.time()
person.memory_points = []
person.group_nick_name = [] # 初始化群昵称列表
person.group_cardname_list = [] # 初始化群名片列表
# 如果是群聊,添加群昵称
if group_id and group_nick_name:
@@ -269,7 +297,7 @@ class Person:
self.platform = platform
self.nickname = global_config.bot.nickname
self.person_name = global_config.bot.nickname
self.group_nick_name: list[dict[str, str]] = []
self.group_cardname_list: list[dict[str, str]] = []
return
self.user_id = ""
@@ -308,7 +336,7 @@ class Person:
self.know_since = None
self.last_know: Optional[float] = None
self.memory_points = []
self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
# 从数据库加载数据
self.load_from_database()
@@ -408,16 +436,16 @@ class Person:
return
# 检查是否已存在该群号的记录
for item in self.group_nick_name:
for item in self.group_cardname_list:
if item.get("group_id") == group_id:
# 更新现有记录
item["group_nick_name"] = group_nick_name
item["group_cardname"] = group_nick_name
self.sync_to_database()
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
return
# 添加新记录
self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
self.sync_to_database()
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
@@ -452,20 +480,15 @@ class Person:
else:
self.memory_points = []
# 处理group_nick_name字段JSON格式的列表
# 处理 group_cardname 字段JSON 格式的列表)
if record.group_cardname:
try:
loaded_group_nick_names = json.loads(record.group_cardname)
# 确保是列表格式
if isinstance(loaded_group_nick_names, list):
self.group_nick_name = loaded_group_nick_names
else:
self.group_nick_name = []
self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
except (json.JSONDecodeError, TypeError):
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败使用默认值")
self.group_nick_name = []
self.group_cardname_list = []
else:
self.group_nick_name = []
self.group_cardname_list = []
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
else:
@@ -486,11 +509,7 @@ class Person:
if self.memory_points
else json.dumps([], ensure_ascii=False)
)
group_nickname_value = (
json.dumps(self.group_nick_name, ensure_ascii=False)
if self.group_nick_name
else json.dumps([], ensure_ascii=False)
)
group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
@@ -510,7 +529,7 @@ class Person:
record.first_known_time = first_known_time
record.last_known_time = last_known_time
record.memory_points = memory_points_value
record.group_nickname = group_nickname_value
record.group_cardname = group_cardname_value
session.add(record)
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
else:
@@ -526,7 +545,7 @@ class Person:
first_known_time=first_known_time,
last_known_time=last_known_time,
memory_points=memory_points_value,
group_nickname=group_nickname_value,
group_cardname=group_cardname_value,
)
session.add(record)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -8,13 +9,15 @@ import sys
from src.common.logger import get_logger
from src.config.config import global_config
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
from src.platform_io.routing import RouteBindingConflictError
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
AdapterDeclarationPayload,
AdapterStateUpdatePayload,
AdapterStateUpdateResultPayload,
BootstrapPluginPayload,
ConfigUpdatedPayload,
Envelope,
@@ -46,6 +49,19 @@ if TYPE_CHECKING:
logger = get_logger("plugin_runtime.host.runner_manager")
_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact"
_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default"
@dataclass(slots=True)
class _AdapterRuntimeState:
"""保存适配器插件当前的运行时连接状态。"""
connected: bool = False
account_id: Optional[str] = None
scope: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
class PluginRunnerSupervisor:
"""插件 Runner 监督器。
@@ -94,6 +110,7 @@ class PluginRunnerSupervisor:
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {}
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -452,6 +469,7 @@ class PluginRunnerSupervisor:
"""注册 Host 侧内部 RPC 方法。"""
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state)
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
@@ -563,14 +581,14 @@ class PluginRunnerSupervisor:
return f"adapter:{plugin_id}"
async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
"""将适配器插件注册到 Platform IO。
"""将适配器插件驱动注册到 Platform IO。
Args:
plugin_id: 适配器插件 ID。
adapter: 经过校验的适配器声明。
Raises:
ValueError: 适配器路由冲突或驱动注册失败时抛出。
ValueError: 驱动注册失败时抛出。
"""
await self._unregister_adapter_driver(plugin_id)
@@ -588,22 +606,12 @@ class PluginRunnerSupervisor:
**adapter.metadata,
},
)
binding = RouteBinding(
route_key=driver.descriptor.route_key,
driver_id=driver.driver_id,
driver_kind=DriverKind.PLUGIN,
metadata={
"plugin_id": plugin_id,
"protocol": adapter.protocol,
},
)
try:
if platform_io_manager.is_started:
await platform_io_manager.add_driver(driver)
else:
platform_io_manager.register_driver(driver)
platform_io_manager.bind_route(binding)
except Exception:
with contextlib.suppress(Exception):
if platform_io_manager.is_started:
@@ -613,6 +621,7 @@ class PluginRunnerSupervisor:
raise
self._registered_adapters[plugin_id] = adapter
self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState()
async def _unregister_adapter_driver(self, plugin_id: str) -> None:
"""从 Platform IO 注销一个适配器驱动。
@@ -622,6 +631,9 @@ class PluginRunnerSupervisor:
"""
platform_io_manager = get_platform_io_manager()
driver_id = self._build_adapter_driver_id(plugin_id)
adapter = self._registered_adapters.get(plugin_id)
self._remove_adapter_route_bindings(plugin_id)
with contextlib.suppress(Exception):
if platform_io_manager.is_started:
@@ -629,7 +641,11 @@ class PluginRunnerSupervisor:
else:
platform_io_manager.unregister_driver(driver_id)
if adapter is not None:
self._refresh_platform_default_route(adapter.platform)
self._registered_adapters.pop(plugin_id, None)
self._adapter_runtime_states.pop(plugin_id, None)
async def _unregister_all_adapter_drivers(self) -> None:
"""注销当前 Supervisor 管理的全部适配器驱动。"""
@@ -637,6 +653,198 @@ class PluginRunnerSupervisor:
for plugin_id in plugin_ids:
await self._unregister_adapter_driver(plugin_id)
def _remove_adapter_route_bindings(self, plugin_id: str) -> None:
"""移除某个适配器驱动当前持有的全部路由绑定。
Args:
plugin_id: 适配器插件 ID。
"""
platform_io_manager = get_platform_io_manager()
platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id))
@staticmethod
def _normalize_runtime_route_value(value: str) -> Optional[str]:
"""规范化适配器运行时路由字段。
Args:
value: 待规范化的原始字符串。
Returns:
Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。
"""
normalized_value = str(value).strip()
return normalized_value or None
def _build_runtime_route_key(
self,
adapter: AdapterDeclarationPayload,
payload: AdapterStateUpdatePayload,
) -> RouteKey:
"""根据运行时状态更新构造适配器生效路由键。
Args:
adapter: 当前适配器声明。
payload: 适配器上报的运行时状态。
Returns:
RouteKey: 当前连接应接管的精确路由键。
Raises:
ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。
"""
runtime_account_id = self._normalize_runtime_route_value(payload.account_id)
runtime_scope = self._normalize_runtime_route_value(payload.scope)
if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id:
raise ValueError(
f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致"
)
if adapter.scope and runtime_scope and adapter.scope != runtime_scope:
raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致")
return RouteKey(
platform=adapter.platform,
account_id=runtime_account_id or adapter.account_id or None,
scope=runtime_scope or adapter.scope or None,
)
def _bind_runtime_exact_route(
self,
plugin_id: str,
adapter: AdapterDeclarationPayload,
route_key: RouteKey,
) -> None:
"""为适配器连接绑定精确生效路由。
Args:
plugin_id: 适配器插件 ID。
adapter: 当前适配器声明。
route_key: 当前连接对应的精确路由键。
Raises:
RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。
"""
platform_io_manager = get_platform_io_manager()
platform_io_manager.bind_route(
RouteBinding(
route_key=route_key,
driver_id=self._build_adapter_driver_id(plugin_id),
driver_kind=DriverKind.PLUGIN,
metadata={
"plugin_id": plugin_id,
"protocol": adapter.protocol,
"binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT,
},
)
)
def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]:
"""列出某个平台上由 Host 动态维护的精确适配器绑定。
Args:
platform: 目标平台名称。
Returns:
List[RouteBinding]: 当前平台上全部动态精确绑定。
"""
platform_io_manager = get_platform_io_manager()
return [
binding
for binding in platform_io_manager.route_table.list_bindings()
if binding.mode == RouteMode.ACTIVE
and binding.route_key.platform == platform
and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT
]
def _refresh_platform_default_route(self, platform: str) -> None:
"""根据当前精确绑定数量刷新平台级默认路由。
当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条
``RouteKey(platform=<platform>)`` 形式的默认路由,方便缺少账号维度的
出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1则撤销
由 Host 自动维护的默认路由,避免出现隐式歧义。
Args:
platform: 目标平台名称。
"""
platform_io_manager = get_platform_io_manager()
default_route_key = RouteKey(platform=platform)
existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True)
if existing_default_binding is not None:
binding_role = existing_default_binding.metadata.get("binding_role")
if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT:
return
platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id)
exact_bindings = self._list_runtime_exact_bindings(platform)
if len(exact_bindings) != 1:
return
exact_binding = exact_bindings[0]
if exact_binding.route_key == default_route_key:
return
platform_io_manager.bind_route(
RouteBinding(
route_key=default_route_key,
driver_id=exact_binding.driver_id,
driver_kind=exact_binding.driver_kind,
metadata={
"plugin_id": exact_binding.metadata.get("plugin_id", ""),
"protocol": exact_binding.metadata.get("protocol", ""),
"binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT,
},
),
replace=True,
)
def _apply_adapter_runtime_state(
self,
plugin_id: str,
adapter: AdapterDeclarationPayload,
payload: AdapterStateUpdatePayload,
) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]:
"""应用适配器运行时状态,并同步 Platform IO 路由。
Args:
plugin_id: 适配器插件 ID。
adapter: 当前适配器声明。
payload: 适配器上报的运行时状态。
Returns:
Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及
供 RPC 响应返回的路由键字典。
Raises:
RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。
ValueError: 当运行时路由信息不合法时抛出。
"""
if not payload.connected:
self._remove_adapter_route_bindings(plugin_id)
self._refresh_platform_default_route(adapter.platform)
runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata))
self._adapter_runtime_states[plugin_id] = runtime_state
return runtime_state, {}
route_key = self._build_runtime_route_key(adapter, payload)
self._remove_adapter_route_bindings(plugin_id)
self._bind_runtime_exact_route(plugin_id, adapter, route_key)
self._refresh_platform_default_route(adapter.platform)
runtime_state = _AdapterRuntimeState(
connected=True,
account_id=route_key.account_id,
scope=route_key.scope,
metadata=dict(payload.metadata),
)
self._adapter_runtime_states[plugin_id] = runtime_state
return runtime_state, {
"platform": route_key.platform,
"account_id": route_key.account_id,
"scope": route_key.scope,
}
@staticmethod
def _attach_inbound_route_metadata(
session_message: "SessionMessage",
@@ -706,6 +914,45 @@ class PluginRunnerSupervisor:
scope=scope,
)
async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope:
"""处理适配器插件上报的运行时状态更新。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: 状态更新处理结果。
"""
try:
payload = AdapterStateUpdatePayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
adapter = self._registered_adapters.get(envelope.plugin_id)
if adapter is None:
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态",
)
try:
runtime_state, route_key_dict = self._apply_adapter_runtime_state(
plugin_id=envelope.plugin_id,
adapter=adapter,
payload=payload,
)
except RouteBindingConflictError as exc:
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
response = AdapterStateUpdateResultPayload(
accepted=True,
connected=runtime_state.connected,
route_key=route_key_dict,
)
return envelope.make_response(payload=response.model_dump())
async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
"""处理适配器插件上报的外部入站消息。
@@ -970,6 +1217,7 @@ class PluginRunnerSupervisor:
self._component_registry.clear()
self._registered_plugins.clear()
self._registered_adapters.clear()
self._adapter_runtime_states.clear()
self._runner_ready_events = asyncio.Event()
self._runner_ready_payloads = RunnerReadyPayload()
self._rpc_server.clear_handshake_state()

View File

@@ -304,6 +304,30 @@ class AdapterDeclarationPayload(BaseModel):
"""适配器附加元数据"""
class AdapterStateUpdatePayload(BaseModel):
"""适配器运行时状态更新载荷。"""
connected: bool = Field(description="适配器当前是否已连接并准备接管路由")
"""适配器当前是否已连接并准备接管路由"""
account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id")
"""当前连接对应的账号 ID 或 self_id"""
scope: str = Field(default="", description="当前连接对应的可选路由作用域")
"""当前连接对应的可选路由作用域"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
"""可选的运行时状态元数据"""
class AdapterStateUpdateResultPayload(BaseModel):
"""适配器运行时状态更新结果载荷。"""
accepted: bool = Field(description="Host 是否接受了本次状态更新")
"""Host 是否接受了本次状态更新"""
connected: bool = Field(description="Host 记录的当前连接状态")
"""Host 记录的当前连接状态"""
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
"""当前生效的路由键"""
class ReceiveExternalMessagePayload(BaseModel):
"""适配器插件向 Host 注入外部消息的请求载荷。"""

View File

@@ -481,13 +481,14 @@ class PluginRunner:
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
if not await self._invoke_plugin_on_load(meta):
if not await self._register_plugin(meta):
await self._invoke_plugin_on_unload(meta)
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
if not await self._register_plugin(meta):
await self._invoke_plugin_on_unload(meta)
if not await self._invoke_plugin_on_load(meta):
await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False

View File

@@ -60,6 +60,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
self._connection_task: Optional[asyncio.Task[None]] = None
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
self._background_tasks: Set[asyncio.Task[Any]] = set()
self._reported_account_id: Optional[str] = None
self._reported_scope: Optional[str] = None
self._runtime_state_connected: bool = False
self._send_lock = asyncio.Lock()
self._ws: Optional[AiohttpClientWebSocketResponse] = None
@@ -170,6 +173,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
with contextlib.suppress(asyncio.CancelledError):
await connection_task
await self._report_adapter_disconnected()
self._fail_pending_actions("NapCat connection closed")
async def _cancel_background_tasks(self) -> None:
@@ -209,6 +213,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
finally:
self._ws = None
await self._report_adapter_disconnected()
self._fail_pending_actions("NapCat connection interrupted")
if not self._should_connect():
@@ -230,26 +235,39 @@ class NapCatAdapterPlugin(MaiBotPlugin):
"""
assert WSMsgType is not None
async for ws_message in ws:
if ws_message.type != WSMsgType.TEXT:
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
break
continue
bootstrap_task = asyncio.create_task(
self._bootstrap_adapter_runtime_state(),
name="napcat_adapter.bootstrap",
)
self._background_tasks.add(bootstrap_task)
bootstrap_task.add_done_callback(self._background_tasks.discard)
payload = self._parse_json_message(ws_message.data)
if payload is None:
continue
try:
async for ws_message in ws:
if ws_message.type != WSMsgType.TEXT:
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
break
continue
if echo_id := str(payload.get("echo") or "").strip():
self._resolve_pending_action(echo_id, payload)
continue
payload = self._parse_json_message(ws_message.data)
if payload is None:
continue
if str(payload.get("post_type") or "").strip() != "message":
continue
if echo_id := str(payload.get("echo") or "").strip():
self._resolve_pending_action(echo_id, payload)
continue
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
if str(payload.get("post_type") or "").strip() != "message":
continue
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
finally:
if not bootstrap_task.done():
bootstrap_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await bootstrap_task
async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
"""处理单条 NapCat 入站消息并注入 Host。
@@ -258,6 +276,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
payload: NapCat / OneBot 推送的原始事件数据。
"""
self_id = str(payload.get("self_id") or "").strip()
if self_id:
await self._report_adapter_connected(self_id)
sender = payload.get("sender", {})
if not isinstance(sender, dict):
sender = {}
@@ -570,6 +591,121 @@ class NapCatAdapterPlugin(MaiBotPlugin):
response_future.set_exception(RuntimeError(error_message))
self._pending_actions.clear()
async def _bootstrap_adapter_runtime_state(self) -> None:
"""在连接建立后主动获取账号信息并激活适配器路由。
该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()`
发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo
响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。
"""
max_attempts = 3
last_error: Optional[Exception] = None
for attempt in range(1, max_attempts + 1):
ws = self._ws
if ws is None or ws.closed:
return
try:
response = await self._call_action("get_login_info", {})
self_id = self._extract_self_id_from_login_response(response)
await self._report_adapter_connected(self_id)
return
except asyncio.CancelledError:
raise
except Exception as exc:
last_error = exc
self.ctx.logger.warning(
f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}"
)
if attempt < max_attempts:
await asyncio.sleep(1.0)
if last_error is not None:
self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}")
@staticmethod
def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str:
"""从 `get_login_info` 响应中提取当前账号 ID。
Args:
response: NapCat 返回的原始动作响应。
Returns:
str: 规范化后的 `self_id` 字符串。
Raises:
ValueError: 当响应中缺少有效账号 ID 时抛出。
"""
if str(response.get("status") or "").lower() != "ok":
raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed"))
response_data = response.get("data", {})
if not isinstance(response_data, dict):
raise ValueError("get_login_info 响应缺少 data 字段")
self_id = str(response_data.get("user_id") or "").strip()
if not self_id:
raise ValueError("get_login_info 响应缺少有效的 user_id")
return self_id
async def _report_adapter_connected(self, account_id: str) -> None:
"""向 Host 上报当前连接已就绪。
Args:
account_id: 当前 NapCat 连接对应的机器人账号 ID。
"""
normalized_account_id = str(account_id).strip()
if not normalized_account_id:
return
scope = self._get_string(self._connection_config(), "connection_id").strip()
if (
self._runtime_state_connected
and self._reported_account_id == normalized_account_id
and self._reported_scope == (scope or None)
):
return
accepted = False
try:
accepted = await self.ctx.adapter.update_runtime_state(
connected=True,
account_id=normalized_account_id,
scope=scope,
metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")},
)
except Exception as exc:
self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}")
return
if not accepted:
self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新")
return
self._runtime_state_connected = True
self._reported_account_id = normalized_account_id
self._reported_scope = scope or None
self.ctx.logger.info(
f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} "
f"scope={self._reported_scope or '*'}"
)
async def _report_adapter_disconnected(self) -> None:
"""向 Host 上报当前连接已断开,并撤销适配器路由。"""
if not self._runtime_state_connected:
self._reported_account_id = None
self._reported_scope = None
return
try:
await self.ctx.adapter.update_runtime_state(connected=False)
except Exception as exc:
self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}")
finally:
self._runtime_state_connected = False
self._reported_account_id = None
self._reported_scope = None
def _build_headers(self) -> Dict[str, str]:
"""构造连接 NapCat 所需的请求头。