From 4e2e7a279e42d7f33a41bdeaf7c9b1bedd4f9953 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 21 Mar 2026 21:47:22 +0800 Subject: [PATCH] feat: Implement adapter runtime state management and update handling - Added support for adapter runtime state updates in the PluginRunnerSupervisor. - Introduced new payload classes: AdapterStateUpdatePayload and AdapterStateUpdateResultPayload for handling state updates. - Implemented methods to bind and unbind routes based on adapter connection status. - Enhanced the NapCat adapter to report connection state and manage runtime state. - Added tests for adapter runtime state synchronization and database session behavior in the statistic module. - Updated existing methods to ensure proper handling of adapter state and route bindings. --- .../test_person_info_group_cardname.py | 355 ++++++++++++++++++ pytests/test_adapter_runtime_state.py | 162 ++++++++ pytests/test_plugin_runtime.py | 6 +- pytests/utils_test/statistic_test.py | 115 ++++++ src/chat/utils/statistic.py | 12 +- .../data_models/person_info_data_model.py | 96 ++++- src/person_info/person_info.py | 79 ++-- src/plugin_runtime/host/supervisor.py | 274 +++++++++++++- src/plugin_runtime/protocol/envelope.py | 24 ++ src/plugin_runtime/runner/runner_main.py | 7 +- src/plugins/built_in/napcat_adapter/plugin.py | 168 ++++++++- 11 files changed, 1219 insertions(+), 79 deletions(-) create mode 100644 pytests/common_test/test_person_info_group_cardname.py create mode 100644 pytests/test_adapter_runtime_state.py create mode 100644 pytests/utils_test/statistic_test.py diff --git a/pytests/common_test/test_person_info_group_cardname.py b/pytests/common_test/test_person_info_group_cardname.py new file mode 100644 index 00000000..62a63f43 --- /dev/null +++ b/pytests/common_test/test_person_info_group_cardname.py @@ -0,0 +1,355 @@ +"""人物信息群名片字段兼容测试。""" + +from __future__ import annotations + +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +import json +import sys + +import pytest + +from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json + + +class _DummyLogger: + """模拟日志记录器。""" + + def debug(self, message: str) -> None: + """记录调试日志。 + + Args: + message: 日志内容。 + """ + del message + + def info(self, message: str) -> None: + """记录信息日志。 + + Args: + message: 日志内容。 + """ + del message + + def warning(self, message: str) -> None: + """记录警告日志。 + + Args: + message: 日志内容。 + """ + del message + + def error(self, message: str) -> None: + """记录错误日志。 + + Args: + message: 日志内容。 + """ + del message + + +class _DummyStatement: + """模拟 SQL 查询语句对象。""" + + def where(self, condition: Any) -> "_DummyStatement": + """附加过滤条件。 + + Args: + condition: 过滤条件。 + + Returns: + _DummyStatement: 当前语句对象。 + """ + del condition + return self + + def limit(self, value: int) -> "_DummyStatement": + """限制返回条数。 + + Args: + value: 条数限制。 + + Returns: + _DummyStatement: 当前语句对象。 + """ + del value + return self + + +class _DummyColumn: + """模拟 SQLModel 列对象。""" + + def is_not(self, value: Any) -> "_DummyColumn": + """模拟 `IS NOT` 条件构造。 + + Args: + value: 比较值。 + + Returns: + _DummyColumn: 当前列对象。 + """ + del value + return self + + def __eq__(self, other: Any) -> "_DummyColumn": + """模拟等值条件构造。 + + Args: + other: 比较值。 + + Returns: + _DummyColumn: 当前列对象。 + """ + del other + return self + + +class _DummyResult: + """模拟数据库查询结果。""" + + def __init__(self, record: Any) -> None: + """初始化查询结果。 + + Args: + record: 待返回的首条记录。 + """ + self._record = record + + def first(self) -> Any: + """返回第一条记录。 + + Returns: + Any: 首条记录。 + """ + return self._record + + def all(self) -> list[Any]: + """返回全部结果。 + + Returns: + list[Any]: 结果列表。 + """ + if self._record is None: + return [] + return self._record if isinstance(self._record, list) else [self._record] + + +class _DummySession: + """模拟数据库 Session。""" + + def __init__(self, record: Any) -> None: + """初始化 Session。 + + Args: + record: `first()` 应返回的记录。 + """ + self.record = record + self.added_records: list[Any] = [] + + def __enter__(self) -> "_DummySession": + """进入上下文管理器。 + + Returns: + _DummySession: 当前 Session。 + """ + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """退出上下文管理器。 + + Args: + exc_type: 异常类型。 + exc_val: 异常值。 + exc_tb: 异常回溯。 + """ + del exc_type + del exc_val + del exc_tb + + def exec(self, statement: Any) -> _DummyResult: + """执行查询。 + + Args: + statement: 查询语句。 + + Returns: + _DummyResult: 模拟结果对象。 + """ + del statement + return _DummyResult(self.record) + + def add(self, record: Any) -> None: + """记录被添加的对象。 + + Args: + record: 被写入 Session 的对象。 + """ + self.added_records.append(record) + + +class _DummyPersonInfoRecord: + """模拟 `PersonInfo` ORM 模型。""" + + person_id = "person_id" + person_name = "person_name" + + def __init__(self, **kwargs: Any) -> None: + """使用关键字参数初始化记录对象。 + + Args: + **kwargs: 字段值。 + """ + for key, value in kwargs.items(): + setattr(self, key, value) + + +def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType: + """加载带依赖桩的 `person_info` 模块。 + + Args: + monkeypatch: Pytest monkeypatch 工具。 + session: 提供给模块使用的假数据库 Session。 + + Returns: + ModuleType: 加载后的模块对象。 + """ + logger_module = ModuleType("src.common.logger") + logger_module.get_logger = lambda name: _DummyLogger() + monkeypatch.setitem(sys.modules, "src.common.logger", logger_module) + + database_module = ModuleType("src.common.database.database") + database_module.get_db_session = lambda: session + monkeypatch.setitem(sys.modules, "src.common.database.database", database_module) + + database_model_module = ModuleType("src.common.database.database_model") + database_model_module.PersonInfo = _DummyPersonInfoRecord + monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module) + + llm_module = ModuleType("src.llm_models.utils_model") + + class _DummyLLMRequest: + """模拟 LLMRequest。""" + + def __init__(self, model_set: Any, request_type: str) -> None: + """初始化假请求对象。 + + Args: + model_set: 模型配置。 + request_type: 请求类型。 + """ + del model_set + del request_type + + llm_module.LLMRequest = _DummyLLMRequest + monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module) + + config_module = ModuleType("src.config.config") + config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot")) + config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils")) + monkeypatch.setitem(sys.modules, "src.config.config", config_module) + + chat_manager_module = ModuleType("src.chat.message_receive.chat_manager") + chat_manager_module.chat_manager = SimpleNamespace() + monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module) + + module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py" + spec = spec_from_file_location("person_info_group_cardname_test_module", module_path) + assert spec is not None and spec.loader is not None + + module = module_from_spec(spec) + monkeypatch.setitem(sys.modules, spec.name, module) + spec.loader.exec_module(module) + + monkeypatch.setattr(module, "select", lambda *args: _DummyStatement()) + monkeypatch.setattr(module, "col", lambda field: _DummyColumn()) + return module + + +def test_parse_group_cardname_json_uses_canonical_key() -> None: + """群名片 JSON 解析应只使用 `group_cardname` 键名。""" + parsed = parse_group_cardname_json( + json.dumps( + [ + {"group_id": "1001", "group_cardname": "现行字段"}, + ], + ensure_ascii=False, + ) + ) + + assert parsed is not None + assert [(item.group_id, item.group_cardname) for item in parsed] == [ + ("1001", "现行字段"), + ] + + +def test_dump_group_cardname_records_uses_canonical_key() -> None: + """群名片序列化应输出 `group_cardname` 键名。""" + dumped = dump_group_cardname_records( + [ + {"group_id": "1001", "group_cardname": "群昵称"}, + ] + ) + + assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}] + + +def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None: + """同步人物信息时应写入数据库模型的 `group_cardname` 字段。""" + record = _DummyPersonInfoRecord() + session = _DummySession(record) + module = _load_person_module(monkeypatch, session) + + person = module.Person.__new__(module.Person) + person.is_known = True + person.person_id = "person-1" + person.platform = "qq" + person.user_id = "10001" + person.nickname = "看番的龙" + person.person_name = "看番的龙" + person.name_reason = "测试" + person.know_times = 1 + person.know_since = 1700000000.0 + person.last_know = 1700000100.0 + person.memory_points = ["喜好:番剧:0.8"] + person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}] + + person.sync_to_database() + + assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]' + assert not hasattr(record, "group_nickname") + + +def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None: + """从数据库加载人物信息时应读取标准 `group_cardname` 结构。""" + record = _DummyPersonInfoRecord( + user_id="10001", + platform="qq", + is_known=True, + user_nickname="看番的龙", + person_name="看番的龙", + name_reason=None, + know_counts=2, + memory_points='["喜好:番剧:0.8"]', + group_cardname=json.dumps( + [ + {"group_id": "20001", "group_cardname": "白泽大人"}, + ], + ensure_ascii=False, + ), + ) + session = _DummySession(record) + module = _load_person_module(monkeypatch, session) + + person = module.Person.__new__(module.Person) + person.person_id = "person-1" + person.memory_points = [] + person.group_cardname_list = [] + + person.load_from_database() + + assert person.group_cardname_list == [ + {"group_id": "20001", "group_cardname": "白泽大人"}, + ] diff --git a/pytests/test_adapter_runtime_state.py b/pytests/test_adapter_runtime_state.py new file mode 100644 index 00000000..e82f4c8c --- /dev/null +++ b/pytests/test_adapter_runtime_state.py @@ -0,0 +1,162 @@ +"""适配器运行时状态同步测试。""" + +from typing import Any, Dict + +import pytest + +from src.platform_io.manager import PlatformIOManager +from src.platform_io.types import RouteKey +from src.plugin_runtime.host.supervisor import PluginSupervisor +from src.plugin_runtime.protocol.envelope import ( + AdapterDeclarationPayload, + Envelope, + MessageType, +) + + +def _make_request(plugin_id: str, payload: Dict[str, Any]) -> Envelope: + """构造一个适配器状态更新 RPC 请求。 + + Args: + plugin_id: 目标适配器插件 ID。 + payload: 请求载荷。 + + Returns: + Envelope: 标准 RPC 请求信封。 + """ + return Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="host.update_adapter_state", + plugin_id=plugin_id, + payload=payload, + ) + + +@pytest.mark.asyncio +async def test_adapter_runtime_state_binds_and_unbinds_routes(monkeypatch: pytest.MonkeyPatch) -> None: + """连接建立后应绑定路由,断开后应撤销路由。""" + import src.plugin_runtime.host.supervisor as supervisor_module + + platform_io_manager = PlatformIOManager() + monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) + + supervisor = PluginSupervisor(plugin_dirs=[]) + adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat") + await supervisor._register_adapter_driver("napcat_adapter_builtin", adapter) + + response = await supervisor._handle_update_adapter_state( + _make_request( + "napcat_adapter_builtin", + { + "connected": True, + "account_id": "10001", + "scope": "", + "metadata": {}, + }, + ) + ) + + assert response.error is None + assert response.payload["accepted"] is True + assert ( + platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq", account_id="10001"), + exact_only=True, + ).driver_id + == "adapter:napcat_adapter_builtin" + ) + assert ( + platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq"), + exact_only=True, + ).driver_id + == "adapter:napcat_adapter_builtin" + ) + + response = await supervisor._handle_update_adapter_state( + _make_request( + "napcat_adapter_builtin", + { + "connected": False, + "account_id": "", + "scope": "", + "metadata": {}, + }, + ) + ) + + assert response.error is None + assert response.payload["accepted"] is True + assert platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq", account_id="10001"), + exact_only=True, + ) is None + assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None + + +@pytest.mark.asyncio +async def test_platform_default_route_is_removed_when_multiple_exact_routes_exist( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """同一平台存在多个精确路由时不应保留默认平台路由。""" + import src.plugin_runtime.host.supervisor as supervisor_module + + platform_io_manager = PlatformIOManager() + monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager) + + supervisor = PluginSupervisor(plugin_dirs=[]) + adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat") + await supervisor._register_adapter_driver("adapter_a", adapter) + await supervisor._register_adapter_driver("adapter_b", adapter) + + await supervisor._handle_update_adapter_state( + _make_request( + "adapter_a", + { + "connected": True, + "account_id": "10001", + "scope": "", + "metadata": {}, + }, + ) + ) + assert ( + platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq"), + exact_only=True, + ).driver_id + == "adapter:adapter_a" + ) + + await supervisor._handle_update_adapter_state( + _make_request( + "adapter_b", + { + "connected": True, + "account_id": "10002", + "scope": "", + "metadata": {}, + }, + ) + ) + assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None + + await supervisor._handle_update_adapter_state( + _make_request( + "adapter_b", + { + "connected": False, + "account_id": "", + "scope": "", + "metadata": {}, + }, + ) + ) + assert ( + platform_io_manager.route_table.get_active_binding( + RouteKey(platform="qq"), + exact_only=True, + ).driver_id + == "adapter:adapter_a" + ) diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 2c703161..5ab16c85 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -486,10 +486,10 @@ class TestSDK: "timeout_ms": timeout_ms, } ) - if method == "cap.request": + if method == "cap.call": bootstrap_methods = [call["method"] for call in self.calls[:-1]] assert "plugin.bootstrap" in bootstrap_methods - return SimpleNamespace(error=None, payload={"result": {"success": True}}) + return SimpleNamespace(error=None, payload={"success": True}) return SimpleNamespace(error=None, payload={"accepted": True}) async def disconnect(self): @@ -529,7 +529,7 @@ class TestSDK: await runner.run() methods = [call["method"] for call in runner._rpc_client.calls] - assert methods == ["plugin.bootstrap", "cap.request", "plugin.register_components", "runner.ready"] + assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"] class TestPluginSdkUsage: diff --git a/pytests/utils_test/statistic_test.py b/pytests/utils_test/statistic_test.py new file mode 100644 index 00000000..d3d8c18a --- /dev/null +++ b/pytests/utils_test/statistic_test.py @@ -0,0 +1,115 @@ +"""统计模块数据库会话行为测试。""" + +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime, timedelta +from types import ModuleType +from typing import Any, Callable, Iterator + +import sys + +import pytest + +from src.chat.utils import statistic + + +class _DummyResult: + """模拟 SQLModel 查询结果对象。""" + + def all(self) -> list[Any]: + """返回空结果集。 + + Returns: + list[Any]: 空列表。 + """ + return [] + + +class _DummySession: + """模拟数据库 Session。""" + + def exec(self, statement: Any) -> _DummyResult: + """执行查询语句并返回空结果。 + + Args: + statement: 待执行的查询语句。 + + Returns: + _DummyResult: 空结果对象。 + """ + del statement + return _DummyResult() + + +def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]: + """构造一个记录 auto_commit 参数的假会话工厂。 + + Args: + calls: 用于记录每次调用 auto_commit 参数的列表。 + + Returns: + Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。 + """ + + @contextmanager + def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]: + """记录会话参数并返回假 Session。 + + Args: + auto_commit: 是否启用自动提交。 + + Yields: + Iterator[_DummySession]: 假 Session 对象。 + """ + calls.append(auto_commit) + yield _DummySession() + + return _fake_get_db_session + + +def _build_statistic_task() -> statistic.StatisticOutputTask: + """构造一个最小可用的统计任务实例。 + + Returns: + statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。 + """ + task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask) + task.name_mapping = {} + return task + + +def _is_bot_self(platform: str, user_id: str) -> bool: + """返回固定的非机器人身份判断结果。 + + Args: + platform: 平台名称。 + user_id: 用户 ID。 + + Returns: + bool: 始终返回 ``False``。 + """ + del platform + del user_id + return False + + +def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None: + """统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。""" + calls: list[bool] = [] + now = datetime.now() + task = _build_statistic_task() + + monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls)) + + utils_module = ModuleType("src.chat.utils.utils") + utils_module.is_bot_self = _is_bot_self + monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module) + + statistic.StatisticOutputTask._fetch_online_time_since(now) + statistic.StatisticOutputTask._fetch_model_usage_since(now) + task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))]) + task._collect_interval_data(now, hours=1, interval_minutes=60) + task._collect_metrics_interval_data(now, hours=1, interval_hours=1) + + assert calls == [False] * 9 diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index ede10a41..51e5e643 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask): @staticmethod def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time) records = session.exec(statement).all() return [(record.start_timestamp, record.end_timestamp) for record in records] @staticmethod def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time) records = session.exec(statement).all() return [ @@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask): } query_start_timestamp = collect_period[-1][1] - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp) messages = session.exec(statement).all() for message in messages: @@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask): # 使用 ActionRecords 中的 reply 动作次数作为回复数基准 try: action_query_start_timestamp = collect_period[-1][1] - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp) actions = session.exec(statement).all() for action in actions: @@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask): # 查询消息记录 query_start_timestamp = start_time.timestamp() - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Messages).where(col(Messages.timestamp) >= start_time) messages = session.exec(statement).all() for message in messages: @@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask): # 查询消息记录 query_start_timestamp = start_time.timestamp() - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: statement = select(Messages).where(col(Messages.timestamp) >= start_time) messages = session.exec(statement).all() for message in messages: diff --git a/src/common/data_models/person_info_data_model.py b/src/common/data_models/person_info_data_model.py index 4cbb62d8..1b239356 100644 --- a/src/common/data_models/person_info_data_model.py +++ b/src/common/data_models/person_info_data_model.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import datetime -from typing import Optional, List +from typing import Any, List, Mapping, Optional, Sequence import json @@ -15,6 +15,76 @@ class GroupCardnameInfo: group_cardname: str +def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]: + """将单条群名片数据规范化为统一结构。 + + Args: + raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。 + + Returns: + Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。 + """ + group_id = str(raw_item.get("group_id") or "").strip() + group_cardname = str(raw_item.get("group_cardname") or "").strip() + if not group_id or not group_cardname: + return None + return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname) + + +def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]: + """解析数据库中的群名片 JSON 字段。 + + Args: + group_cardname_json: 数据库存储的群名片 JSON 字符串。 + + Returns: + Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。 + + Raises: + json.JSONDecodeError: 当 JSON 文本格式非法时抛出。 + TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。 + """ + if not group_cardname_json: + return None + + raw_items = json.loads(group_cardname_json) + if not isinstance(raw_items, list): + return None + + normalized_items: List[GroupCardnameInfo] = [] + for raw_item in raw_items: + if not isinstance(raw_item, Mapping): + continue + if normalized_item := _normalize_group_cardname_item(raw_item): + normalized_items.append(normalized_item) + + return normalized_items or None + + +def dump_group_cardname_records( + group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]], +) -> str: + """将群名片列表序列化为数据库使用的标准 JSON 字符串。 + + Args: + group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo` + 对象和包含 `group_id` / `group_cardname` 的字典。 + + Returns: + str: 统一使用 `group_cardname` 键名的 JSON 字符串。 + """ + normalized_items: List[GroupCardnameInfo] = [] + for raw_item in group_cardname_records or []: + if isinstance(raw_item, GroupCardnameInfo): + normalized_items.append(raw_item) + continue + if isinstance(raw_item, Mapping): + if normalized_item := _normalize_group_cardname_item(raw_item): + normalized_items.append(normalized_item) + + return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False) + + class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): def __init__( self, @@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): """最后一次被认识的时间""" @classmethod - def from_db_instance(cls, db_record: "PersonInfo"): - nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None - group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None + def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo": + """从数据库记录构造人物信息数据模型。 + + Args: + db_record: 数据库中的人物信息记录。 + + Returns: + MaiPersonInfo: 转换后的数据模型对象。 + """ + group_cardname_list = parse_group_cardname_json(db_record.group_cardname) memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None return cls( is_known=db_record.is_known, @@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]): ) def to_db_instance(self) -> "PersonInfo": - group_cardname = ( - json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None - ) + """将当前数据模型转换为数据库记录对象。 + + Returns: + PersonInfo: 可直接写入数据库的模型实例。 + """ + group_cardname = dump_group_cardname_records(self.group_cardname_list) return PersonInfo( is_known=self.is_known, person_id=self.person_id, diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 799f56a0..15ef0049 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -1,22 +1,24 @@ -import hashlib +from datetime import datetime +from typing import Dict, Optional, Union + import asyncio +import hashlib import json -import time -import random import math +import random +import time from json_repair import repair_json -from typing import Union, Optional, Dict -from datetime import datetime from sqlmodel import col, select -from src.common.logger import get_logger +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo -from src.llm_models.utils_model import LLMRequest +from src.common.logger import get_logger from src.config.config import global_config, model_config -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.llm_models.utils_model import LLMRequest logger = get_logger("person_info") @@ -26,6 +28,32 @@ relation_selection_model = LLMRequest( ) +def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]: + """将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。 + + Args: + group_cardname_json: 数据库存储的群名片 JSON 字符串。 + + Returns: + list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。 + + Raises: + json.JSONDecodeError: 当 JSON 文本格式非法时抛出。 + TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。 + """ + group_cardname_list = parse_group_cardname_json(group_cardname_json) + if not group_cardname_list: + return [] + + return [ + { + "group_id": group_cardname.group_id, + "group_cardname": group_cardname.group_cardname, + } + for group_cardname in group_cardname_list + ] + + def get_person_id(platform: str, user_id: Union[int, str]) -> str: """获取唯一id""" if "-" in platform: @@ -231,7 +259,7 @@ class Person: person.know_since = time.time() person.last_know = time.time() person.memory_points = [] - person.group_nick_name = [] # 初始化群昵称列表 + person.group_cardname_list = [] # 初始化群名片列表 # 如果是群聊,添加群昵称 if group_id and group_nick_name: @@ -269,7 +297,7 @@ class Person: self.platform = platform self.nickname = global_config.bot.nickname self.person_name = global_config.bot.nickname - self.group_nick_name: list[dict[str, str]] = [] + self.group_cardname_list: list[dict[str, str]] = [] return self.user_id = "" @@ -308,7 +336,7 @@ class Person: self.know_since = None self.last_know: Optional[float] = None self.memory_points = [] - self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str} + self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str} # 从数据库加载数据 self.load_from_database() @@ -408,16 +436,16 @@ class Person: return # 检查是否已存在该群号的记录 - for item in self.group_nick_name: + for item in self.group_cardname_list: if item.get("group_id") == group_id: # 更新现有记录 - item["group_nick_name"] = group_nick_name + item["group_cardname"] = group_nick_name self.sync_to_database() logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}") return # 添加新记录 - self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name}) + self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name}) self.sync_to_database() logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}") @@ -452,20 +480,15 @@ class Person: else: self.memory_points = [] - # 处理group_nick_name字段(JSON格式的列表) + # 处理 group_cardname 字段(JSON 格式的列表) if record.group_cardname: try: - loaded_group_nick_names = json.loads(record.group_cardname) - # 确保是列表格式 - if isinstance(loaded_group_nick_names, list): - self.group_nick_name = loaded_group_nick_names - else: - self.group_nick_name = [] + self.group_cardname_list = _to_group_cardname_records(record.group_cardname) except (json.JSONDecodeError, TypeError): logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值") - self.group_nick_name = [] + self.group_cardname_list = [] else: - self.group_nick_name = [] + self.group_cardname_list = [] logger.debug(f"已从数据库加载用户 {self.person_id} 的信息") else: @@ -486,11 +509,7 @@ class Person: if self.memory_points else json.dumps([], ensure_ascii=False) ) - group_nickname_value = ( - json.dumps(self.group_nick_name, ensure_ascii=False) - if self.group_nick_name - else json.dumps([], ensure_ascii=False) - ) + group_cardname_value = dump_group_cardname_records(self.group_cardname_list) first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None @@ -510,7 +529,7 @@ class Person: record.first_known_time = first_known_time record.last_known_time = last_known_time record.memory_points = memory_points_value - record.group_nickname = group_nickname_value + record.group_cardname = group_cardname_value session.add(record) logger.debug(f"已同步用户 {self.person_id} 的信息到数据库") else: @@ -526,7 +545,7 @@ class Person: first_known_time=first_known_time, last_known_time=last_known_time, memory_points=memory_points_value, - group_nickname=group_nickname_value, + group_cardname=group_cardname_value, ) session.add(record) logger.debug(f"已创建用户 {self.person_id} 的信息到数据库") diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 33091d5a..8a26af11 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -8,13 +9,15 @@ import sys from src.common.logger import get_logger from src.config.config import global_config -from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager +from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager from src.platform_io.drivers import PluginPlatformDriver from src.platform_io.route_key_factory import RouteKeyFactory from src.platform_io.routing import RouteBindingConflictError from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( AdapterDeclarationPayload, + AdapterStateUpdatePayload, + AdapterStateUpdateResultPayload, BootstrapPluginPayload, ConfigUpdatedPayload, Envelope, @@ -46,6 +49,19 @@ if TYPE_CHECKING: logger = get_logger("plugin_runtime.host.runner_manager") +_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact" +_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default" + + +@dataclass(slots=True) +class _AdapterRuntimeState: + """保存适配器插件当前的运行时连接状态。""" + + connected: bool = False + account_id: Optional[str] = None + scope: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + class PluginRunnerSupervisor: """插件 Runner 监督器。 @@ -94,6 +110,7 @@ class PluginRunnerSupervisor: self._runner_process: Optional[asyncio.subprocess.Process] = None self._registered_plugins: Dict[str, RegisterPluginPayload] = {} self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {} + self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {} self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() self._health_task: Optional[asyncio.Task[None]] = None @@ -452,6 +469,7 @@ class PluginRunnerSupervisor: """注册 Host 侧内部 RPC 方法。""" self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request) self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message) + self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state) self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin) self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin) self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin) @@ -563,14 +581,14 @@ class PluginRunnerSupervisor: return f"adapter:{plugin_id}" async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None: - """将适配器插件注册到 Platform IO。 + """将适配器插件驱动注册到 Platform IO。 Args: plugin_id: 适配器插件 ID。 adapter: 经过校验的适配器声明。 Raises: - ValueError: 适配器路由冲突或驱动注册失败时抛出。 + ValueError: 当驱动注册失败时抛出。 """ await self._unregister_adapter_driver(plugin_id) @@ -588,22 +606,12 @@ class PluginRunnerSupervisor: **adapter.metadata, }, ) - binding = RouteBinding( - route_key=driver.descriptor.route_key, - driver_id=driver.driver_id, - driver_kind=DriverKind.PLUGIN, - metadata={ - "plugin_id": plugin_id, - "protocol": adapter.protocol, - }, - ) try: if platform_io_manager.is_started: await platform_io_manager.add_driver(driver) else: platform_io_manager.register_driver(driver) - platform_io_manager.bind_route(binding) except Exception: with contextlib.suppress(Exception): if platform_io_manager.is_started: @@ -613,6 +621,7 @@ class PluginRunnerSupervisor: raise self._registered_adapters[plugin_id] = adapter + self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState() async def _unregister_adapter_driver(self, plugin_id: str) -> None: """从 Platform IO 注销一个适配器驱动。 @@ -622,6 +631,9 @@ class PluginRunnerSupervisor: """ platform_io_manager = get_platform_io_manager() driver_id = self._build_adapter_driver_id(plugin_id) + adapter = self._registered_adapters.get(plugin_id) + + self._remove_adapter_route_bindings(plugin_id) with contextlib.suppress(Exception): if platform_io_manager.is_started: @@ -629,7 +641,11 @@ class PluginRunnerSupervisor: else: platform_io_manager.unregister_driver(driver_id) + if adapter is not None: + self._refresh_platform_default_route(adapter.platform) + self._registered_adapters.pop(plugin_id, None) + self._adapter_runtime_states.pop(plugin_id, None) async def _unregister_all_adapter_drivers(self) -> None: """注销当前 Supervisor 管理的全部适配器驱动。""" @@ -637,6 +653,198 @@ class PluginRunnerSupervisor: for plugin_id in plugin_ids: await self._unregister_adapter_driver(plugin_id) + def _remove_adapter_route_bindings(self, plugin_id: str) -> None: + """移除某个适配器驱动当前持有的全部路由绑定。 + + Args: + plugin_id: 适配器插件 ID。 + """ + platform_io_manager = get_platform_io_manager() + platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id)) + + @staticmethod + def _normalize_runtime_route_value(value: str) -> Optional[str]: + """规范化适配器运行时路由字段。 + + Args: + value: 待规范化的原始字符串。 + + Returns: + Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。 + """ + normalized_value = str(value).strip() + return normalized_value or None + + def _build_runtime_route_key( + self, + adapter: AdapterDeclarationPayload, + payload: AdapterStateUpdatePayload, + ) -> RouteKey: + """根据运行时状态更新构造适配器生效路由键。 + + Args: + adapter: 当前适配器声明。 + payload: 适配器上报的运行时状态。 + + Returns: + RouteKey: 当前连接应接管的精确路由键。 + + Raises: + ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。 + """ + runtime_account_id = self._normalize_runtime_route_value(payload.account_id) + runtime_scope = self._normalize_runtime_route_value(payload.scope) + + if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id: + raise ValueError( + f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致" + ) + if adapter.scope and runtime_scope and adapter.scope != runtime_scope: + raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致") + + return RouteKey( + platform=adapter.platform, + account_id=runtime_account_id or adapter.account_id or None, + scope=runtime_scope or adapter.scope or None, + ) + + def _bind_runtime_exact_route( + self, + plugin_id: str, + adapter: AdapterDeclarationPayload, + route_key: RouteKey, + ) -> None: + """为适配器连接绑定精确生效路由。 + + Args: + plugin_id: 适配器插件 ID。 + adapter: 当前适配器声明。 + route_key: 当前连接对应的精确路由键。 + + Raises: + RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。 + """ + platform_io_manager = get_platform_io_manager() + platform_io_manager.bind_route( + RouteBinding( + route_key=route_key, + driver_id=self._build_adapter_driver_id(plugin_id), + driver_kind=DriverKind.PLUGIN, + metadata={ + "plugin_id": plugin_id, + "protocol": adapter.protocol, + "binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT, + }, + ) + ) + + def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]: + """列出某个平台上由 Host 动态维护的精确适配器绑定。 + + Args: + platform: 目标平台名称。 + + Returns: + List[RouteBinding]: 当前平台上全部动态精确绑定。 + """ + platform_io_manager = get_platform_io_manager() + return [ + binding + for binding in platform_io_manager.route_table.list_bindings() + if binding.mode == RouteMode.ACTIVE + and binding.route_key.platform == platform + and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT + ] + + def _refresh_platform_default_route(self, platform: str) -> None: + """根据当前精确绑定数量刷新平台级默认路由。 + + 当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条 + ``RouteKey(platform=)`` 形式的默认路由,方便缺少账号维度的 + 出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1,则撤销 + 由 Host 自动维护的默认路由,避免出现隐式歧义。 + + Args: + platform: 目标平台名称。 + """ + platform_io_manager = get_platform_io_manager() + default_route_key = RouteKey(platform=platform) + existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True) + + if existing_default_binding is not None: + binding_role = existing_default_binding.metadata.get("binding_role") + if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT: + return + platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id) + + exact_bindings = self._list_runtime_exact_bindings(platform) + if len(exact_bindings) != 1: + return + + exact_binding = exact_bindings[0] + if exact_binding.route_key == default_route_key: + return + + platform_io_manager.bind_route( + RouteBinding( + route_key=default_route_key, + driver_id=exact_binding.driver_id, + driver_kind=exact_binding.driver_kind, + metadata={ + "plugin_id": exact_binding.metadata.get("plugin_id", ""), + "protocol": exact_binding.metadata.get("protocol", ""), + "binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT, + }, + ), + replace=True, + ) + + def _apply_adapter_runtime_state( + self, + plugin_id: str, + adapter: AdapterDeclarationPayload, + payload: AdapterStateUpdatePayload, + ) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]: + """应用适配器运行时状态,并同步 Platform IO 路由。 + + Args: + plugin_id: 适配器插件 ID。 + adapter: 当前适配器声明。 + payload: 适配器上报的运行时状态。 + + Returns: + Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及 + 供 RPC 响应返回的路由键字典。 + + Raises: + RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。 + ValueError: 当运行时路由信息不合法时抛出。 + """ + if not payload.connected: + self._remove_adapter_route_bindings(plugin_id) + self._refresh_platform_default_route(adapter.platform) + runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata)) + self._adapter_runtime_states[plugin_id] = runtime_state + return runtime_state, {} + + route_key = self._build_runtime_route_key(adapter, payload) + self._remove_adapter_route_bindings(plugin_id) + self._bind_runtime_exact_route(plugin_id, adapter, route_key) + self._refresh_platform_default_route(adapter.platform) + + runtime_state = _AdapterRuntimeState( + connected=True, + account_id=route_key.account_id, + scope=route_key.scope, + metadata=dict(payload.metadata), + ) + self._adapter_runtime_states[plugin_id] = runtime_state + return runtime_state, { + "platform": route_key.platform, + "account_id": route_key.account_id, + "scope": route_key.scope, + } + @staticmethod def _attach_inbound_route_metadata( session_message: "SessionMessage", @@ -706,6 +914,45 @@ class PluginRunnerSupervisor: scope=scope, ) + async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope: + """处理适配器插件上报的运行时状态更新。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 状态更新处理结果。 + """ + try: + payload = AdapterStateUpdatePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + adapter = self._registered_adapters.get(envelope.plugin_id) + if adapter is None: + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态", + ) + + try: + runtime_state, route_key_dict = self._apply_adapter_runtime_state( + plugin_id=envelope.plugin_id, + adapter=adapter, + payload=payload, + ) + except RouteBindingConflictError as exc: + return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc)) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + response = AdapterStateUpdateResultPayload( + accepted=True, + connected=runtime_state.connected, + route_key=route_key_dict, + ) + return envelope.make_response(payload=response.model_dump()) + async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope: """处理适配器插件上报的外部入站消息。 @@ -970,6 +1217,7 @@ class PluginRunnerSupervisor: self._component_registry.clear() self._registered_plugins.clear() self._registered_adapters.clear() + self._adapter_runtime_states.clear() self._runner_ready_events = asyncio.Event() self._runner_ready_payloads = RunnerReadyPayload() self._rpc_server.clear_handshake_state() diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index d71e02c5..f68657fa 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -304,6 +304,30 @@ class AdapterDeclarationPayload(BaseModel): """适配器附加元数据""" +class AdapterStateUpdatePayload(BaseModel): + """适配器运行时状态更新载荷。""" + + connected: bool = Field(description="适配器当前是否已连接并准备接管路由") + """适配器当前是否已连接并准备接管路由""" + account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id") + """当前连接对应的账号 ID 或 self_id""" + scope: str = Field(default="", description="当前连接对应的可选路由作用域") + """当前连接对应的可选路由作用域""" + metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据") + """可选的运行时状态元数据""" + + +class AdapterStateUpdateResultPayload(BaseModel): + """适配器运行时状态更新结果载荷。""" + + accepted: bool = Field(description="Host 是否接受了本次状态更新") + """Host 是否接受了本次状态更新""" + connected: bool = Field(description="Host 记录的当前连接状态") + """Host 记录的当前连接状态""" + route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键") + """当前生效的路由键""" + + class ReceiveExternalMessagePayload(BaseModel): """适配器插件向 Host 注入外部消息的请求载荷。""" diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 88f92494..8078c88b 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -481,13 +481,14 @@ class PluginRunner: self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) return False - if not await self._invoke_plugin_on_load(meta): + if not await self._register_plugin(meta): + await self._invoke_plugin_on_unload(meta) await self._deactivate_plugin(meta) self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) return False - if not await self._register_plugin(meta): - await self._invoke_plugin_on_unload(meta) + if not await self._invoke_plugin_on_load(meta): + await self._unregister_plugin(meta.plugin_id, reason="on_load_failed") await self._deactivate_plugin(meta) self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) return False diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py index a481101f..c8bb837b 100644 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -60,6 +60,9 @@ class NapCatAdapterPlugin(MaiBotPlugin): self._connection_task: Optional[asyncio.Task[None]] = None self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {} self._background_tasks: Set[asyncio.Task[Any]] = set() + self._reported_account_id: Optional[str] = None + self._reported_scope: Optional[str] = None + self._runtime_state_connected: bool = False self._send_lock = asyncio.Lock() self._ws: Optional[AiohttpClientWebSocketResponse] = None @@ -170,6 +173,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): with contextlib.suppress(asyncio.CancelledError): await connection_task + await self._report_adapter_disconnected() self._fail_pending_actions("NapCat connection closed") async def _cancel_background_tasks(self) -> None: @@ -209,6 +213,7 @@ class NapCatAdapterPlugin(MaiBotPlugin): self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}") finally: self._ws = None + await self._report_adapter_disconnected() self._fail_pending_actions("NapCat connection interrupted") if not self._should_connect(): @@ -230,26 +235,39 @@ class NapCatAdapterPlugin(MaiBotPlugin): """ assert WSMsgType is not None - async for ws_message in ws: - if ws_message.type != WSMsgType.TEXT: - if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}: - break - continue + bootstrap_task = asyncio.create_task( + self._bootstrap_adapter_runtime_state(), + name="napcat_adapter.bootstrap", + ) + self._background_tasks.add(bootstrap_task) + bootstrap_task.add_done_callback(self._background_tasks.discard) - payload = self._parse_json_message(ws_message.data) - if payload is None: - continue + try: + async for ws_message in ws: + if ws_message.type != WSMsgType.TEXT: + if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}: + break + continue - if echo_id := str(payload.get("echo") or "").strip(): - self._resolve_pending_action(echo_id, payload) - continue + payload = self._parse_json_message(ws_message.data) + if payload is None: + continue - if str(payload.get("post_type") or "").strip() != "message": - continue + if echo_id := str(payload.get("echo") or "").strip(): + self._resolve_pending_action(echo_id, payload) + continue - task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound") - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) + if str(payload.get("post_type") or "").strip() != "message": + continue + + task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound") + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + finally: + if not bootstrap_task.done(): + bootstrap_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await bootstrap_task async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None: """处理单条 NapCat 入站消息并注入 Host。 @@ -258,6 +276,9 @@ class NapCatAdapterPlugin(MaiBotPlugin): payload: NapCat / OneBot 推送的原始事件数据。 """ self_id = str(payload.get("self_id") or "").strip() + if self_id: + await self._report_adapter_connected(self_id) + sender = payload.get("sender", {}) if not isinstance(sender, dict): sender = {} @@ -570,6 +591,121 @@ class NapCatAdapterPlugin(MaiBotPlugin): response_future.set_exception(RuntimeError(error_message)) self._pending_actions.clear() + async def _bootstrap_adapter_runtime_state(self) -> None: + """在连接建立后主动获取账号信息并激活适配器路由。 + + 该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()` + 发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo + 响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。 + """ + max_attempts = 3 + last_error: Optional[Exception] = None + for attempt in range(1, max_attempts + 1): + ws = self._ws + if ws is None or ws.closed: + return + + try: + response = await self._call_action("get_login_info", {}) + self_id = self._extract_self_id_from_login_response(response) + await self._report_adapter_connected(self_id) + return + except asyncio.CancelledError: + raise + except Exception as exc: + last_error = exc + self.ctx.logger.warning( + f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}" + ) + if attempt < max_attempts: + await asyncio.sleep(1.0) + + if last_error is not None: + self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}") + + @staticmethod + def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str: + """从 `get_login_info` 响应中提取当前账号 ID。 + + Args: + response: NapCat 返回的原始动作响应。 + + Returns: + str: 规范化后的 `self_id` 字符串。 + + Raises: + ValueError: 当响应中缺少有效账号 ID 时抛出。 + """ + if str(response.get("status") or "").lower() != "ok": + raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed")) + + response_data = response.get("data", {}) + if not isinstance(response_data, dict): + raise ValueError("get_login_info 响应缺少 data 字段") + + self_id = str(response_data.get("user_id") or "").strip() + if not self_id: + raise ValueError("get_login_info 响应缺少有效的 user_id") + return self_id + + async def _report_adapter_connected(self, account_id: str) -> None: + """向 Host 上报当前连接已就绪。 + + Args: + account_id: 当前 NapCat 连接对应的机器人账号 ID。 + """ + normalized_account_id = str(account_id).strip() + if not normalized_account_id: + return + + scope = self._get_string(self._connection_config(), "connection_id").strip() + if ( + self._runtime_state_connected + and self._reported_account_id == normalized_account_id + and self._reported_scope == (scope or None) + ): + return + + accepted = False + try: + accepted = await self.ctx.adapter.update_runtime_state( + connected=True, + account_id=normalized_account_id, + scope=scope, + metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")}, + ) + except Exception as exc: + self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}") + return + + if not accepted: + self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新") + return + + self._runtime_state_connected = True + self._reported_account_id = normalized_account_id + self._reported_scope = scope or None + self.ctx.logger.info( + f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} " + f"scope={self._reported_scope or '*'}" + ) + + async def _report_adapter_disconnected(self) -> None: + """向 Host 上报当前连接已断开,并撤销适配器路由。""" + if not self._runtime_state_connected: + self._reported_account_id = None + self._reported_scope = None + return + + try: + await self.ctx.adapter.update_runtime_state(connected=False) + except Exception as exc: + self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}") + finally: + self._runtime_state_connected = False + self._reported_account_id = None + self._reported_scope = None + def _build_headers(self) -> Dict[str, str]: """构造连接 NapCat 所需的请求头。