feat: Implement adapter runtime state management and update handling
- Added support for adapter runtime state updates in the PluginRunnerSupervisor. - Introduced new payload classes: AdapterStateUpdatePayload and AdapterStateUpdateResultPayload for handling state updates. - Implemented methods to bind and unbind routes based on adapter connection status. - Enhanced the NapCat adapter to report connection state and manage runtime state. - Added tests for adapter runtime state synchronization and database session behavior in the statistic module. - Updated existing methods to ensure proper handling of adapter state and route bindings.
This commit is contained in:
@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
|
||||
records = session.exec(statement).all()
|
||||
return [(record.start_timestamp, record.end_timestamp) for record in records]
|
||||
|
||||
@staticmethod
|
||||
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
|
||||
records = session.exec(statement).all()
|
||||
return [
|
||||
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1]
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
||||
try:
|
||||
action_query_start_timestamp = collect_period[-1][1]
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
|
||||
actions = session.exec(statement).all()
|
||||
for action in actions:
|
||||
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from typing import Any, List, Mapping, Optional, Sequence
|
||||
|
||||
import json
|
||||
|
||||
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
|
||||
group_cardname: str
|
||||
|
||||
|
||||
def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
|
||||
"""将单条群名片数据规范化为统一结构。
|
||||
|
||||
Args:
|
||||
raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
|
||||
|
||||
Returns:
|
||||
Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
|
||||
"""
|
||||
group_id = str(raw_item.get("group_id") or "").strip()
|
||||
group_cardname = str(raw_item.get("group_cardname") or "").strip()
|
||||
if not group_id or not group_cardname:
|
||||
return None
|
||||
return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
|
||||
|
||||
|
||||
def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
|
||||
"""解析数据库中的群名片 JSON 字段。
|
||||
|
||||
Args:
|
||||
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||
"""
|
||||
if not group_cardname_json:
|
||||
return None
|
||||
|
||||
raw_items = json.loads(group_cardname_json)
|
||||
if not isinstance(raw_items, list):
|
||||
return None
|
||||
|
||||
normalized_items: List[GroupCardnameInfo] = []
|
||||
for raw_item in raw_items:
|
||||
if not isinstance(raw_item, Mapping):
|
||||
continue
|
||||
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||
normalized_items.append(normalized_item)
|
||||
|
||||
return normalized_items or None
|
||||
|
||||
|
||||
def dump_group_cardname_records(
|
||||
group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
|
||||
) -> str:
|
||||
"""将群名片列表序列化为数据库使用的标准 JSON 字符串。
|
||||
|
||||
Args:
|
||||
group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
|
||||
对象和包含 `group_id` / `group_cardname` 的字典。
|
||||
|
||||
Returns:
|
||||
str: 统一使用 `group_cardname` 键名的 JSON 字符串。
|
||||
"""
|
||||
normalized_items: List[GroupCardnameInfo] = []
|
||||
for raw_item in group_cardname_records or []:
|
||||
if isinstance(raw_item, GroupCardnameInfo):
|
||||
normalized_items.append(raw_item)
|
||||
continue
|
||||
if isinstance(raw_item, Mapping):
|
||||
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||
normalized_items.append(normalized_item)
|
||||
|
||||
return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
|
||||
|
||||
|
||||
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
"""最后一次被认识的时间"""
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: "PersonInfo"):
|
||||
nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
|
||||
group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
|
||||
def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
|
||||
"""从数据库记录构造人物信息数据模型。
|
||||
|
||||
Args:
|
||||
db_record: 数据库中的人物信息记录。
|
||||
|
||||
Returns:
|
||||
MaiPersonInfo: 转换后的数据模型对象。
|
||||
"""
|
||||
group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
|
||||
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
|
||||
return cls(
|
||||
is_known=db_record.is_known,
|
||||
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
)
|
||||
|
||||
def to_db_instance(self) -> "PersonInfo":
|
||||
group_cardname = (
|
||||
json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
|
||||
)
|
||||
"""将当前数据模型转换为数据库记录对象。
|
||||
|
||||
Returns:
|
||||
PersonInfo: 可直接写入数据库的模型实例。
|
||||
"""
|
||||
group_cardname = dump_group_cardname_records(self.group_cardname_list)
|
||||
return PersonInfo(
|
||||
is_known=self.is_known,
|
||||
person_id=self.person_id,
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import Union, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
logger = get_logger("person_info")
|
||||
@@ -26,6 +28,32 @@ relation_selection_model = LLMRequest(
|
||||
)
|
||||
|
||||
|
||||
def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
|
||||
"""将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
|
||||
|
||||
Args:
|
||||
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||
"""
|
||||
group_cardname_list = parse_group_cardname_json(group_cardname_json)
|
||||
if not group_cardname_list:
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"group_id": group_cardname.group_id,
|
||||
"group_cardname": group_cardname.group_cardname,
|
||||
}
|
||||
for group_cardname in group_cardname_list
|
||||
]
|
||||
|
||||
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
"""获取唯一id"""
|
||||
if "-" in platform:
|
||||
@@ -231,7 +259,7 @@ class Person:
|
||||
person.know_since = time.time()
|
||||
person.last_know = time.time()
|
||||
person.memory_points = []
|
||||
person.group_nick_name = [] # 初始化群昵称列表
|
||||
person.group_cardname_list = [] # 初始化群名片列表
|
||||
|
||||
# 如果是群聊,添加群昵称
|
||||
if group_id and group_nick_name:
|
||||
@@ -269,7 +297,7 @@ class Person:
|
||||
self.platform = platform
|
||||
self.nickname = global_config.bot.nickname
|
||||
self.person_name = global_config.bot.nickname
|
||||
self.group_nick_name: list[dict[str, str]] = []
|
||||
self.group_cardname_list: list[dict[str, str]] = []
|
||||
return
|
||||
|
||||
self.user_id = ""
|
||||
@@ -308,7 +336,7 @@ class Person:
|
||||
self.know_since = None
|
||||
self.last_know: Optional[float] = None
|
||||
self.memory_points = []
|
||||
self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
|
||||
self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
|
||||
|
||||
# 从数据库加载数据
|
||||
self.load_from_database()
|
||||
@@ -408,16 +436,16 @@ class Person:
|
||||
return
|
||||
|
||||
# 检查是否已存在该群号的记录
|
||||
for item in self.group_nick_name:
|
||||
for item in self.group_cardname_list:
|
||||
if item.get("group_id") == group_id:
|
||||
# 更新现有记录
|
||||
item["group_nick_name"] = group_nick_name
|
||||
item["group_cardname"] = group_nick_name
|
||||
self.sync_to_database()
|
||||
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
|
||||
return
|
||||
|
||||
# 添加新记录
|
||||
self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
|
||||
self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
|
||||
self.sync_to_database()
|
||||
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
|
||||
|
||||
@@ -452,20 +480,15 @@ class Person:
|
||||
else:
|
||||
self.memory_points = []
|
||||
|
||||
# 处理group_nick_name字段(JSON格式的列表)
|
||||
# 处理 group_cardname 字段(JSON 格式的列表)
|
||||
if record.group_cardname:
|
||||
try:
|
||||
loaded_group_nick_names = json.loads(record.group_cardname)
|
||||
# 确保是列表格式
|
||||
if isinstance(loaded_group_nick_names, list):
|
||||
self.group_nick_name = loaded_group_nick_names
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值")
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = []
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = []
|
||||
|
||||
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
||||
else:
|
||||
@@ -486,11 +509,7 @@ class Person:
|
||||
if self.memory_points
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
group_nickname_value = (
|
||||
json.dumps(self.group_nick_name, ensure_ascii=False)
|
||||
if self.group_nick_name
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
|
||||
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
|
||||
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
|
||||
|
||||
@@ -510,7 +529,7 @@ class Person:
|
||||
record.first_known_time = first_known_time
|
||||
record.last_known_time = last_known_time
|
||||
record.memory_points = memory_points_value
|
||||
record.group_nickname = group_nickname_value
|
||||
record.group_cardname = group_cardname_value
|
||||
session.add(record)
|
||||
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
||||
else:
|
||||
@@ -526,7 +545,7 @@ class Person:
|
||||
first_known_time=first_known_time,
|
||||
last_known_time=last_known_time,
|
||||
memory_points=memory_points_value,
|
||||
group_nickname=group_nickname_value,
|
||||
group_cardname=group_cardname_value,
|
||||
)
|
||||
session.add(record)
|
||||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -8,13 +9,15 @@ import sys
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager
|
||||
from src.platform_io.drivers import PluginPlatformDriver
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
from src.platform_io.routing import RouteBindingConflictError
|
||||
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
AdapterDeclarationPayload,
|
||||
AdapterStateUpdatePayload,
|
||||
AdapterStateUpdateResultPayload,
|
||||
BootstrapPluginPayload,
|
||||
ConfigUpdatedPayload,
|
||||
Envelope,
|
||||
@@ -46,6 +49,19 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("plugin_runtime.host.runner_manager")
|
||||
|
||||
_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact"
|
||||
_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _AdapterRuntimeState:
|
||||
"""保存适配器插件当前的运行时连接状态。"""
|
||||
|
||||
connected: bool = False
|
||||
account_id: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PluginRunnerSupervisor:
|
||||
"""插件 Runner 监督器。
|
||||
@@ -94,6 +110,7 @@ class PluginRunnerSupervisor:
|
||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
|
||||
self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
|
||||
self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {}
|
||||
self._runner_ready_events: asyncio.Event = asyncio.Event()
|
||||
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
|
||||
self._health_task: Optional[asyncio.Task[None]] = None
|
||||
@@ -452,6 +469,7 @@ class PluginRunnerSupervisor:
|
||||
"""注册 Host 侧内部 RPC 方法。"""
|
||||
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
|
||||
self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
|
||||
self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state)
|
||||
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
|
||||
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
|
||||
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
|
||||
@@ -563,14 +581,14 @@ class PluginRunnerSupervisor:
|
||||
return f"adapter:{plugin_id}"
|
||||
|
||||
async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
|
||||
"""将适配器插件注册到 Platform IO。
|
||||
"""将适配器插件驱动注册到 Platform IO。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
adapter: 经过校验的适配器声明。
|
||||
|
||||
Raises:
|
||||
ValueError: 适配器路由冲突或驱动注册失败时抛出。
|
||||
ValueError: 当驱动注册失败时抛出。
|
||||
"""
|
||||
await self._unregister_adapter_driver(plugin_id)
|
||||
|
||||
@@ -588,22 +606,12 @@ class PluginRunnerSupervisor:
|
||||
**adapter.metadata,
|
||||
},
|
||||
)
|
||||
binding = RouteBinding(
|
||||
route_key=driver.descriptor.route_key,
|
||||
driver_id=driver.driver_id,
|
||||
driver_kind=DriverKind.PLUGIN,
|
||||
metadata={
|
||||
"plugin_id": plugin_id,
|
||||
"protocol": adapter.protocol,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
if platform_io_manager.is_started:
|
||||
await platform_io_manager.add_driver(driver)
|
||||
else:
|
||||
platform_io_manager.register_driver(driver)
|
||||
platform_io_manager.bind_route(binding)
|
||||
except Exception:
|
||||
with contextlib.suppress(Exception):
|
||||
if platform_io_manager.is_started:
|
||||
@@ -613,6 +621,7 @@ class PluginRunnerSupervisor:
|
||||
raise
|
||||
|
||||
self._registered_adapters[plugin_id] = adapter
|
||||
self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState()
|
||||
|
||||
async def _unregister_adapter_driver(self, plugin_id: str) -> None:
|
||||
"""从 Platform IO 注销一个适配器驱动。
|
||||
@@ -622,6 +631,9 @@ class PluginRunnerSupervisor:
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
driver_id = self._build_adapter_driver_id(plugin_id)
|
||||
adapter = self._registered_adapters.get(plugin_id)
|
||||
|
||||
self._remove_adapter_route_bindings(plugin_id)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
if platform_io_manager.is_started:
|
||||
@@ -629,7 +641,11 @@ class PluginRunnerSupervisor:
|
||||
else:
|
||||
platform_io_manager.unregister_driver(driver_id)
|
||||
|
||||
if adapter is not None:
|
||||
self._refresh_platform_default_route(adapter.platform)
|
||||
|
||||
self._registered_adapters.pop(plugin_id, None)
|
||||
self._adapter_runtime_states.pop(plugin_id, None)
|
||||
|
||||
async def _unregister_all_adapter_drivers(self) -> None:
|
||||
"""注销当前 Supervisor 管理的全部适配器驱动。"""
|
||||
@@ -637,6 +653,198 @@ class PluginRunnerSupervisor:
|
||||
for plugin_id in plugin_ids:
|
||||
await self._unregister_adapter_driver(plugin_id)
|
||||
|
||||
def _remove_adapter_route_bindings(self, plugin_id: str) -> None:
|
||||
"""移除某个适配器驱动当前持有的全部路由绑定。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id))
|
||||
|
||||
@staticmethod
|
||||
def _normalize_runtime_route_value(value: str) -> Optional[str]:
|
||||
"""规范化适配器运行时路由字段。
|
||||
|
||||
Args:
|
||||
value: 待规范化的原始字符串。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。
|
||||
"""
|
||||
normalized_value = str(value).strip()
|
||||
return normalized_value or None
|
||||
|
||||
def _build_runtime_route_key(
|
||||
self,
|
||||
adapter: AdapterDeclarationPayload,
|
||||
payload: AdapterStateUpdatePayload,
|
||||
) -> RouteKey:
|
||||
"""根据运行时状态更新构造适配器生效路由键。
|
||||
|
||||
Args:
|
||||
adapter: 当前适配器声明。
|
||||
payload: 适配器上报的运行时状态。
|
||||
|
||||
Returns:
|
||||
RouteKey: 当前连接应接管的精确路由键。
|
||||
|
||||
Raises:
|
||||
ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。
|
||||
"""
|
||||
runtime_account_id = self._normalize_runtime_route_value(payload.account_id)
|
||||
runtime_scope = self._normalize_runtime_route_value(payload.scope)
|
||||
|
||||
if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id:
|
||||
raise ValueError(
|
||||
f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致"
|
||||
)
|
||||
if adapter.scope and runtime_scope and adapter.scope != runtime_scope:
|
||||
raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致")
|
||||
|
||||
return RouteKey(
|
||||
platform=adapter.platform,
|
||||
account_id=runtime_account_id or adapter.account_id or None,
|
||||
scope=runtime_scope or adapter.scope or None,
|
||||
)
|
||||
|
||||
def _bind_runtime_exact_route(
|
||||
self,
|
||||
plugin_id: str,
|
||||
adapter: AdapterDeclarationPayload,
|
||||
route_key: RouteKey,
|
||||
) -> None:
|
||||
"""为适配器连接绑定精确生效路由。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
adapter: 当前适配器声明。
|
||||
route_key: 当前连接对应的精确路由键。
|
||||
|
||||
Raises:
|
||||
RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
platform_io_manager.bind_route(
|
||||
RouteBinding(
|
||||
route_key=route_key,
|
||||
driver_id=self._build_adapter_driver_id(plugin_id),
|
||||
driver_kind=DriverKind.PLUGIN,
|
||||
metadata={
|
||||
"plugin_id": plugin_id,
|
||||
"protocol": adapter.protocol,
|
||||
"binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]:
|
||||
"""列出某个平台上由 Host 动态维护的精确适配器绑定。
|
||||
|
||||
Args:
|
||||
platform: 目标平台名称。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 当前平台上全部动态精确绑定。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
return [
|
||||
binding
|
||||
for binding in platform_io_manager.route_table.list_bindings()
|
||||
if binding.mode == RouteMode.ACTIVE
|
||||
and binding.route_key.platform == platform
|
||||
and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT
|
||||
]
|
||||
|
||||
def _refresh_platform_default_route(self, platform: str) -> None:
|
||||
"""根据当前精确绑定数量刷新平台级默认路由。
|
||||
|
||||
当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条
|
||||
``RouteKey(platform=<platform>)`` 形式的默认路由,方便缺少账号维度的
|
||||
出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1,则撤销
|
||||
由 Host 自动维护的默认路由,避免出现隐式歧义。
|
||||
|
||||
Args:
|
||||
platform: 目标平台名称。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
default_route_key = RouteKey(platform=platform)
|
||||
existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True)
|
||||
|
||||
if existing_default_binding is not None:
|
||||
binding_role = existing_default_binding.metadata.get("binding_role")
|
||||
if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT:
|
||||
return
|
||||
platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id)
|
||||
|
||||
exact_bindings = self._list_runtime_exact_bindings(platform)
|
||||
if len(exact_bindings) != 1:
|
||||
return
|
||||
|
||||
exact_binding = exact_bindings[0]
|
||||
if exact_binding.route_key == default_route_key:
|
||||
return
|
||||
|
||||
platform_io_manager.bind_route(
|
||||
RouteBinding(
|
||||
route_key=default_route_key,
|
||||
driver_id=exact_binding.driver_id,
|
||||
driver_kind=exact_binding.driver_kind,
|
||||
metadata={
|
||||
"plugin_id": exact_binding.metadata.get("plugin_id", ""),
|
||||
"protocol": exact_binding.metadata.get("protocol", ""),
|
||||
"binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT,
|
||||
},
|
||||
),
|
||||
replace=True,
|
||||
)
|
||||
|
||||
def _apply_adapter_runtime_state(
|
||||
self,
|
||||
plugin_id: str,
|
||||
adapter: AdapterDeclarationPayload,
|
||||
payload: AdapterStateUpdatePayload,
|
||||
) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]:
|
||||
"""应用适配器运行时状态,并同步 Platform IO 路由。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
adapter: 当前适配器声明。
|
||||
payload: 适配器上报的运行时状态。
|
||||
|
||||
Returns:
|
||||
Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及
|
||||
供 RPC 响应返回的路由键字典。
|
||||
|
||||
Raises:
|
||||
RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。
|
||||
ValueError: 当运行时路由信息不合法时抛出。
|
||||
"""
|
||||
if not payload.connected:
|
||||
self._remove_adapter_route_bindings(plugin_id)
|
||||
self._refresh_platform_default_route(adapter.platform)
|
||||
runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata))
|
||||
self._adapter_runtime_states[plugin_id] = runtime_state
|
||||
return runtime_state, {}
|
||||
|
||||
route_key = self._build_runtime_route_key(adapter, payload)
|
||||
self._remove_adapter_route_bindings(plugin_id)
|
||||
self._bind_runtime_exact_route(plugin_id, adapter, route_key)
|
||||
self._refresh_platform_default_route(adapter.platform)
|
||||
|
||||
runtime_state = _AdapterRuntimeState(
|
||||
connected=True,
|
||||
account_id=route_key.account_id,
|
||||
scope=route_key.scope,
|
||||
metadata=dict(payload.metadata),
|
||||
)
|
||||
self._adapter_runtime_states[plugin_id] = runtime_state
|
||||
return runtime_state, {
|
||||
"platform": route_key.platform,
|
||||
"account_id": route_key.account_id,
|
||||
"scope": route_key.scope,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _attach_inbound_route_metadata(
|
||||
session_message: "SessionMessage",
|
||||
@@ -706,6 +914,45 @@ class PluginRunnerSupervisor:
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope:
|
||||
"""处理适配器插件上报的运行时状态更新。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: 状态更新处理结果。
|
||||
"""
|
||||
try:
|
||||
payload = AdapterStateUpdatePayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
adapter = self._registered_adapters.get(envelope.plugin_id)
|
||||
if adapter is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态",
|
||||
)
|
||||
|
||||
try:
|
||||
runtime_state, route_key_dict = self._apply_adapter_runtime_state(
|
||||
plugin_id=envelope.plugin_id,
|
||||
adapter=adapter,
|
||||
payload=payload,
|
||||
)
|
||||
except RouteBindingConflictError as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
response = AdapterStateUpdateResultPayload(
|
||||
accepted=True,
|
||||
connected=runtime_state.connected,
|
||||
route_key=route_key_dict,
|
||||
)
|
||||
return envelope.make_response(payload=response.model_dump())
|
||||
|
||||
async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
|
||||
"""处理适配器插件上报的外部入站消息。
|
||||
|
||||
@@ -970,6 +1217,7 @@ class PluginRunnerSupervisor:
|
||||
self._component_registry.clear()
|
||||
self._registered_plugins.clear()
|
||||
self._registered_adapters.clear()
|
||||
self._adapter_runtime_states.clear()
|
||||
self._runner_ready_events = asyncio.Event()
|
||||
self._runner_ready_payloads = RunnerReadyPayload()
|
||||
self._rpc_server.clear_handshake_state()
|
||||
|
||||
@@ -304,6 +304,30 @@ class AdapterDeclarationPayload(BaseModel):
|
||||
"""适配器附加元数据"""
|
||||
|
||||
|
||||
class AdapterStateUpdatePayload(BaseModel):
|
||||
"""适配器运行时状态更新载荷。"""
|
||||
|
||||
connected: bool = Field(description="适配器当前是否已连接并准备接管路由")
|
||||
"""适配器当前是否已连接并准备接管路由"""
|
||||
account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id")
|
||||
"""当前连接对应的账号 ID 或 self_id"""
|
||||
scope: str = Field(default="", description="当前连接对应的可选路由作用域")
|
||||
"""当前连接对应的可选路由作用域"""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
|
||||
"""可选的运行时状态元数据"""
|
||||
|
||||
|
||||
class AdapterStateUpdateResultPayload(BaseModel):
|
||||
"""适配器运行时状态更新结果载荷。"""
|
||||
|
||||
accepted: bool = Field(description="Host 是否接受了本次状态更新")
|
||||
"""Host 是否接受了本次状态更新"""
|
||||
connected: bool = Field(description="Host 记录的当前连接状态")
|
||||
"""Host 记录的当前连接状态"""
|
||||
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
|
||||
"""当前生效的路由键"""
|
||||
|
||||
|
||||
class ReceiveExternalMessagePayload(BaseModel):
|
||||
"""适配器插件向 Host 注入外部消息的请求载荷。"""
|
||||
|
||||
|
||||
@@ -481,13 +481,14 @@ class PluginRunner:
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
|
||||
if not await self._invoke_plugin_on_load(meta):
|
||||
if not await self._register_plugin(meta):
|
||||
await self._invoke_plugin_on_unload(meta)
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
|
||||
if not await self._register_plugin(meta):
|
||||
await self._invoke_plugin_on_unload(meta)
|
||||
if not await self._invoke_plugin_on_load(meta):
|
||||
await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
|
||||
@@ -60,6 +60,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
self._connection_task: Optional[asyncio.Task[None]] = None
|
||||
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
|
||||
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
||||
self._reported_account_id: Optional[str] = None
|
||||
self._reported_scope: Optional[str] = None
|
||||
self._runtime_state_connected: bool = False
|
||||
self._send_lock = asyncio.Lock()
|
||||
self._ws: Optional[AiohttpClientWebSocketResponse] = None
|
||||
|
||||
@@ -170,6 +173,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await connection_task
|
||||
|
||||
await self._report_adapter_disconnected()
|
||||
self._fail_pending_actions("NapCat connection closed")
|
||||
|
||||
async def _cancel_background_tasks(self) -> None:
|
||||
@@ -209,6 +213,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
|
||||
finally:
|
||||
self._ws = None
|
||||
await self._report_adapter_disconnected()
|
||||
self._fail_pending_actions("NapCat connection interrupted")
|
||||
|
||||
if not self._should_connect():
|
||||
@@ -230,26 +235,39 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
"""
|
||||
assert WSMsgType is not None
|
||||
|
||||
async for ws_message in ws:
|
||||
if ws_message.type != WSMsgType.TEXT:
|
||||
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
|
||||
break
|
||||
continue
|
||||
bootstrap_task = asyncio.create_task(
|
||||
self._bootstrap_adapter_runtime_state(),
|
||||
name="napcat_adapter.bootstrap",
|
||||
)
|
||||
self._background_tasks.add(bootstrap_task)
|
||||
bootstrap_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
payload = self._parse_json_message(ws_message.data)
|
||||
if payload is None:
|
||||
continue
|
||||
try:
|
||||
async for ws_message in ws:
|
||||
if ws_message.type != WSMsgType.TEXT:
|
||||
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
|
||||
break
|
||||
continue
|
||||
|
||||
if echo_id := str(payload.get("echo") or "").strip():
|
||||
self._resolve_pending_action(echo_id, payload)
|
||||
continue
|
||||
payload = self._parse_json_message(ws_message.data)
|
||||
if payload is None:
|
||||
continue
|
||||
|
||||
if str(payload.get("post_type") or "").strip() != "message":
|
||||
continue
|
||||
if echo_id := str(payload.get("echo") or "").strip():
|
||||
self._resolve_pending_action(echo_id, payload)
|
||||
continue
|
||||
|
||||
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
if str(payload.get("post_type") or "").strip() != "message":
|
||||
continue
|
||||
|
||||
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
finally:
|
||||
if not bootstrap_task.done():
|
||||
bootstrap_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await bootstrap_task
|
||||
|
||||
async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
|
||||
"""处理单条 NapCat 入站消息并注入 Host。
|
||||
@@ -258,6 +276,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
payload: NapCat / OneBot 推送的原始事件数据。
|
||||
"""
|
||||
self_id = str(payload.get("self_id") or "").strip()
|
||||
if self_id:
|
||||
await self._report_adapter_connected(self_id)
|
||||
|
||||
sender = payload.get("sender", {})
|
||||
if not isinstance(sender, dict):
|
||||
sender = {}
|
||||
@@ -570,6 +591,121 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
response_future.set_exception(RuntimeError(error_message))
|
||||
self._pending_actions.clear()
|
||||
|
||||
async def _bootstrap_adapter_runtime_state(self) -> None:
|
||||
"""在连接建立后主动获取账号信息并激活适配器路由。
|
||||
|
||||
该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()`
|
||||
发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo
|
||||
响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。
|
||||
"""
|
||||
max_attempts = 3
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
ws = self._ws
|
||||
if ws is None or ws.closed:
|
||||
return
|
||||
|
||||
try:
|
||||
response = await self._call_action("get_login_info", {})
|
||||
self_id = self._extract_self_id_from_login_response(response)
|
||||
await self._report_adapter_connected(self_id)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
self.ctx.logger.warning(
|
||||
f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}"
|
||||
)
|
||||
if attempt < max_attempts:
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
if last_error is not None:
|
||||
self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}")
|
||||
|
||||
@staticmethod
|
||||
def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str:
|
||||
"""从 `get_login_info` 响应中提取当前账号 ID。
|
||||
|
||||
Args:
|
||||
response: NapCat 返回的原始动作响应。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的 `self_id` 字符串。
|
||||
|
||||
Raises:
|
||||
ValueError: 当响应中缺少有效账号 ID 时抛出。
|
||||
"""
|
||||
if str(response.get("status") or "").lower() != "ok":
|
||||
raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed"))
|
||||
|
||||
response_data = response.get("data", {})
|
||||
if not isinstance(response_data, dict):
|
||||
raise ValueError("get_login_info 响应缺少 data 字段")
|
||||
|
||||
self_id = str(response_data.get("user_id") or "").strip()
|
||||
if not self_id:
|
||||
raise ValueError("get_login_info 响应缺少有效的 user_id")
|
||||
return self_id
|
||||
|
||||
async def _report_adapter_connected(self, account_id: str) -> None:
|
||||
"""向 Host 上报当前连接已就绪。
|
||||
|
||||
Args:
|
||||
account_id: 当前 NapCat 连接对应的机器人账号 ID。
|
||||
"""
|
||||
normalized_account_id = str(account_id).strip()
|
||||
if not normalized_account_id:
|
||||
return
|
||||
|
||||
scope = self._get_string(self._connection_config(), "connection_id").strip()
|
||||
if (
|
||||
self._runtime_state_connected
|
||||
and self._reported_account_id == normalized_account_id
|
||||
and self._reported_scope == (scope or None)
|
||||
):
|
||||
return
|
||||
|
||||
accepted = False
|
||||
try:
|
||||
accepted = await self.ctx.adapter.update_runtime_state(
|
||||
connected=True,
|
||||
account_id=normalized_account_id,
|
||||
scope=scope,
|
||||
metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")},
|
||||
)
|
||||
except Exception as exc:
|
||||
self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}")
|
||||
return
|
||||
|
||||
if not accepted:
|
||||
self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新")
|
||||
return
|
||||
|
||||
self._runtime_state_connected = True
|
||||
self._reported_account_id = normalized_account_id
|
||||
self._reported_scope = scope or None
|
||||
self.ctx.logger.info(
|
||||
f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} "
|
||||
f"scope={self._reported_scope or '*'}"
|
||||
)
|
||||
|
||||
async def _report_adapter_disconnected(self) -> None:
|
||||
"""向 Host 上报当前连接已断开,并撤销适配器路由。"""
|
||||
if not self._runtime_state_connected:
|
||||
self._reported_account_id = None
|
||||
self._reported_scope = None
|
||||
return
|
||||
|
||||
try:
|
||||
await self.ctx.adapter.update_runtime_state(connected=False)
|
||||
except Exception as exc:
|
||||
self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}")
|
||||
finally:
|
||||
self._runtime_state_connected = False
|
||||
self._reported_account_id = None
|
||||
self._reported_scope = None
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构造连接 NapCat 所需的请求头。
|
||||
|
||||
|
||||
Reference in New Issue
Block a user