feat: Implement adapter runtime state management and update handling

- Added support for adapter runtime state updates in the PluginRunnerSupervisor.
- Introduced new payload classes: AdapterStateUpdatePayload and AdapterStateUpdateResultPayload for handling state updates.
- Implemented methods to bind and unbind routes based on adapter connection status.
- Enhanced the NapCat adapter to report connection state and manage runtime state.
- Added tests for adapter runtime state synchronization and database session behavior in the statistic module.
- Updated existing methods to ensure proper handling of adapter state and route bindings.
This commit is contained in:
DrSmoothl
2026-03-21 21:47:22 +08:00
parent dd20cd4992
commit 4e2e7a279e
11 changed files with 1219 additions and 79 deletions

View File

@@ -0,0 +1,355 @@
"""人物信息群名片字段兼容测试。"""
from __future__ import annotations
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
import json
import sys
import pytest
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
class _DummyLogger:
"""模拟日志记录器。"""
def debug(self, message: str) -> None:
"""记录调试日志。
Args:
message: 日志内容。
"""
del message
def info(self, message: str) -> None:
"""记录信息日志。
Args:
message: 日志内容。
"""
del message
def warning(self, message: str) -> None:
"""记录警告日志。
Args:
message: 日志内容。
"""
del message
def error(self, message: str) -> None:
"""记录错误日志。
Args:
message: 日志内容。
"""
del message
class _DummyStatement:
"""模拟 SQL 查询语句对象。"""
def where(self, condition: Any) -> "_DummyStatement":
"""附加过滤条件。
Args:
condition: 过滤条件。
Returns:
_DummyStatement: 当前语句对象。
"""
del condition
return self
def limit(self, value: int) -> "_DummyStatement":
"""限制返回条数。
Args:
value: 条数限制。
Returns:
_DummyStatement: 当前语句对象。
"""
del value
return self
class _DummyColumn:
"""模拟 SQLModel 列对象。"""
def is_not(self, value: Any) -> "_DummyColumn":
"""模拟 `IS NOT` 条件构造。
Args:
value: 比较值。
Returns:
_DummyColumn: 当前列对象。
"""
del value
return self
def __eq__(self, other: Any) -> "_DummyColumn":
"""模拟等值条件构造。
Args:
other: 比较值。
Returns:
_DummyColumn: 当前列对象。
"""
del other
return self
class _DummyResult:
"""模拟数据库查询结果。"""
def __init__(self, record: Any) -> None:
"""初始化查询结果。
Args:
record: 待返回的首条记录。
"""
self._record = record
def first(self) -> Any:
"""返回第一条记录。
Returns:
Any: 首条记录。
"""
return self._record
def all(self) -> list[Any]:
"""返回全部结果。
Returns:
list[Any]: 结果列表。
"""
if self._record is None:
return []
return self._record if isinstance(self._record, list) else [self._record]
class _DummySession:
"""模拟数据库 Session。"""
def __init__(self, record: Any) -> None:
"""初始化 Session。
Args:
record: `first()` 应返回的记录。
"""
self.record = record
self.added_records: list[Any] = []
def __enter__(self) -> "_DummySession":
"""进入上下文管理器。
Returns:
_DummySession: 当前 Session。
"""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""退出上下文管理器。
Args:
exc_type: 异常类型。
exc_val: 异常值。
exc_tb: 异常回溯。
"""
del exc_type
del exc_val
del exc_tb
def exec(self, statement: Any) -> _DummyResult:
"""执行查询。
Args:
statement: 查询语句。
Returns:
_DummyResult: 模拟结果对象。
"""
del statement
return _DummyResult(self.record)
def add(self, record: Any) -> None:
"""记录被添加的对象。
Args:
record: 被写入 Session 的对象。
"""
self.added_records.append(record)
class _DummyPersonInfoRecord:
"""模拟 `PersonInfo` ORM 模型。"""
person_id = "person_id"
person_name = "person_name"
def __init__(self, **kwargs: Any) -> None:
"""使用关键字参数初始化记录对象。
Args:
**kwargs: 字段值。
"""
for key, value in kwargs.items():
setattr(self, key, value)
def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
"""加载带依赖桩的 `person_info` 模块。
Args:
monkeypatch: Pytest monkeypatch 工具。
session: 提供给模块使用的假数据库 Session。
Returns:
ModuleType: 加载后的模块对象。
"""
logger_module = ModuleType("src.common.logger")
logger_module.get_logger = lambda name: _DummyLogger()
monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
database_module = ModuleType("src.common.database.database")
database_module.get_db_session = lambda: session
monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
database_model_module = ModuleType("src.common.database.database_model")
database_model_module.PersonInfo = _DummyPersonInfoRecord
monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
llm_module = ModuleType("src.llm_models.utils_model")
class _DummyLLMRequest:
"""模拟 LLMRequest。"""
def __init__(self, model_set: Any, request_type: str) -> None:
"""初始化假请求对象。
Args:
model_set: 模型配置。
request_type: 请求类型。
"""
del model_set
del request_type
llm_module.LLMRequest = _DummyLLMRequest
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
config_module = ModuleType("src.config.config")
config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
chat_manager_module.chat_manager = SimpleNamespace()
monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
assert spec is not None and spec.loader is not None
module = module_from_spec(spec)
monkeypatch.setitem(sys.modules, spec.name, module)
spec.loader.exec_module(module)
monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
return module
def test_parse_group_cardname_json_uses_canonical_key() -> None:
"""群名片 JSON 解析应只使用 `group_cardname` 键名。"""
parsed = parse_group_cardname_json(
json.dumps(
[
{"group_id": "1001", "group_cardname": "现行字段"},
],
ensure_ascii=False,
)
)
assert parsed is not None
assert [(item.group_id, item.group_cardname) for item in parsed] == [
("1001", "现行字段"),
]
def test_dump_group_cardname_records_uses_canonical_key() -> None:
"""群名片序列化应输出 `group_cardname` 键名。"""
dumped = dump_group_cardname_records(
[
{"group_id": "1001", "group_cardname": "群昵称"},
]
)
assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
"""同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
record = _DummyPersonInfoRecord()
session = _DummySession(record)
module = _load_person_module(monkeypatch, session)
person = module.Person.__new__(module.Person)
person.is_known = True
person.person_id = "person-1"
person.platform = "qq"
person.user_id = "10001"
person.nickname = "看番的龙"
person.person_name = "看番的龙"
person.name_reason = "测试"
person.know_times = 1
person.know_since = 1700000000.0
person.last_know = 1700000100.0
person.memory_points = ["喜好:番剧:0.8"]
person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
person.sync_to_database()
assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
assert not hasattr(record, "group_nickname")
def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
"""从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
record = _DummyPersonInfoRecord(
user_id="10001",
platform="qq",
is_known=True,
user_nickname="看番的龙",
person_name="看番的龙",
name_reason=None,
know_counts=2,
memory_points='["喜好:番剧:0.8"]',
group_cardname=json.dumps(
[
{"group_id": "20001", "group_cardname": "白泽大人"},
],
ensure_ascii=False,
),
)
session = _DummySession(record)
module = _load_person_module(monkeypatch, session)
person = module.Person.__new__(module.Person)
person.person_id = "person-1"
person.memory_points = []
person.group_cardname_list = []
person.load_from_database()
assert person.group_cardname_list == [
{"group_id": "20001", "group_cardname": "白泽大人"},
]

View File

@@ -0,0 +1,162 @@
"""适配器运行时状态同步测试。"""
from typing import Any, Dict
import pytest
from src.platform_io.manager import PlatformIOManager
from src.platform_io.types import RouteKey
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.plugin_runtime.protocol.envelope import (
AdapterDeclarationPayload,
Envelope,
MessageType,
)
def _make_request(plugin_id: str, payload: Dict[str, Any]) -> Envelope:
"""构造一个适配器状态更新 RPC 请求。
Args:
plugin_id: 目标适配器插件 ID。
payload: 请求载荷。
Returns:
Envelope: 标准 RPC 请求信封。
"""
return Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="host.update_adapter_state",
plugin_id=plugin_id,
payload=payload,
)
@pytest.mark.asyncio
async def test_adapter_runtime_state_binds_and_unbinds_routes(monkeypatch: pytest.MonkeyPatch) -> None:
"""连接建立后应绑定路由,断开后应撤销路由。"""
import src.plugin_runtime.host.supervisor as supervisor_module
platform_io_manager = PlatformIOManager()
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
supervisor = PluginSupervisor(plugin_dirs=[])
adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
await supervisor._register_adapter_driver("napcat_adapter_builtin", adapter)
response = await supervisor._handle_update_adapter_state(
_make_request(
"napcat_adapter_builtin",
{
"connected": True,
"account_id": "10001",
"scope": "",
"metadata": {},
},
)
)
assert response.error is None
assert response.payload["accepted"] is True
assert (
platform_io_manager.route_table.get_active_binding(
RouteKey(platform="qq", account_id="10001"),
exact_only=True,
).driver_id
== "adapter:napcat_adapter_builtin"
)
assert (
platform_io_manager.route_table.get_active_binding(
RouteKey(platform="qq"),
exact_only=True,
).driver_id
== "adapter:napcat_adapter_builtin"
)
response = await supervisor._handle_update_adapter_state(
_make_request(
"napcat_adapter_builtin",
{
"connected": False,
"account_id": "",
"scope": "",
"metadata": {},
},
)
)
assert response.error is None
assert response.payload["accepted"] is True
assert platform_io_manager.route_table.get_active_binding(
RouteKey(platform="qq", account_id="10001"),
exact_only=True,
) is None
assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
@pytest.mark.asyncio
async def test_platform_default_route_is_removed_when_multiple_exact_routes_exist(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""同一平台存在多个精确路由时不应保留默认平台路由。"""
import src.plugin_runtime.host.supervisor as supervisor_module
platform_io_manager = PlatformIOManager()
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
supervisor = PluginSupervisor(plugin_dirs=[])
adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
await supervisor._register_adapter_driver("adapter_a", adapter)
await supervisor._register_adapter_driver("adapter_b", adapter)
await supervisor._handle_update_adapter_state(
_make_request(
"adapter_a",
{
"connected": True,
"account_id": "10001",
"scope": "",
"metadata": {},
},
)
)
assert (
platform_io_manager.route_table.get_active_binding(
RouteKey(platform="qq"),
exact_only=True,
).driver_id
== "adapter:adapter_a"
)
await supervisor._handle_update_adapter_state(
_make_request(
"adapter_b",
{
"connected": True,
"account_id": "10002",
"scope": "",
"metadata": {},
},
)
)
assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
await supervisor._handle_update_adapter_state(
_make_request(
"adapter_b",
{
"connected": False,
"account_id": "",
"scope": "",
"metadata": {},
},
)
)
assert (
platform_io_manager.route_table.get_active_binding(
RouteKey(platform="qq"),
exact_only=True,
).driver_id
== "adapter:adapter_a"
)

View File

@@ -486,10 +486,10 @@ class TestSDK:
"timeout_ms": timeout_ms,
}
)
if method == "cap.request":
if method == "cap.call":
bootstrap_methods = [call["method"] for call in self.calls[:-1]]
assert "plugin.bootstrap" in bootstrap_methods
return SimpleNamespace(error=None, payload={"result": {"success": True}})
return SimpleNamespace(error=None, payload={"success": True})
return SimpleNamespace(error=None, payload={"accepted": True})
async def disconnect(self):
@@ -529,7 +529,7 @@ class TestSDK:
await runner.run()
methods = [call["method"] for call in runner._rpc_client.calls]
assert methods == ["plugin.bootstrap", "cap.request", "plugin.register_components", "runner.ready"]
assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"]
class TestPluginSdkUsage:

View File

@@ -0,0 +1,115 @@
"""统计模块数据库会话行为测试。"""
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime, timedelta
from types import ModuleType
from typing import Any, Callable, Iterator
import sys
import pytest
from src.chat.utils import statistic
class _DummyResult:
"""模拟 SQLModel 查询结果对象。"""
def all(self) -> list[Any]:
"""返回空结果集。
Returns:
list[Any]: 空列表。
"""
return []
class _DummySession:
"""模拟数据库 Session。"""
def exec(self, statement: Any) -> _DummyResult:
"""执行查询语句并返回空结果。
Args:
statement: 待执行的查询语句。
Returns:
_DummyResult: 空结果对象。
"""
del statement
return _DummyResult()
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
"""构造一个记录 auto_commit 参数的假会话工厂。
Args:
calls: 用于记录每次调用 auto_commit 参数的列表。
Returns:
Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
"""
@contextmanager
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
"""记录会话参数并返回假 Session。
Args:
auto_commit: 是否启用自动提交。
Yields:
Iterator[_DummySession]: 假 Session 对象。
"""
calls.append(auto_commit)
yield _DummySession()
return _fake_get_db_session
def _build_statistic_task() -> statistic.StatisticOutputTask:
"""构造一个最小可用的统计任务实例。
Returns:
statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
"""
task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
task.name_mapping = {}
return task
def _is_bot_self(platform: str, user_id: str) -> bool:
"""返回固定的非机器人身份判断结果。
Args:
platform: 平台名称。
user_id: 用户 ID。
Returns:
bool: 始终返回 ``False``。
"""
del platform
del user_id
return False
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
"""统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
calls: list[bool] = []
now = datetime.now()
task = _build_statistic_task()
monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
utils_module = ModuleType("src.chat.utils.utils")
utils_module.is_bot_self = _is_bot_self
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
statistic.StatisticOutputTask._fetch_online_time_since(now)
statistic.StatisticOutputTask._fetch_model_usage_since(now)
task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
task._collect_interval_data(now, hours=1, interval_minutes=60)
task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
assert calls == [False] * 9

View File

@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
@staticmethod
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
records = session.exec(statement).all()
return [(record.start_timestamp, record.end_timestamp) for record in records]
@staticmethod
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
records = session.exec(statement).all()
return [
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1]
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
messages = session.exec(statement).all()
for message in messages:
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
try:
action_query_start_timestamp = collect_period[-1][1]
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
actions = session.exec(statement).all()
for action in actions:
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
with get_db_session() as session:
with get_db_session(auto_commit=False) as session:
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
messages = session.exec(statement).all()
for message in messages:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Optional, List
from typing import Any, List, Mapping, Optional, Sequence
import json
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
group_cardname: str
def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
"""将单条群名片数据规范化为统一结构。
Args:
raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
Returns:
Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
"""
group_id = str(raw_item.get("group_id") or "").strip()
group_cardname = str(raw_item.get("group_cardname") or "").strip()
if not group_id or not group_cardname:
return None
return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
"""解析数据库中的群名片 JSON 字段。
Args:
group_cardname_json: 数据库存储的群名片 JSON 字符串。
Returns:
Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
Raises:
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
"""
if not group_cardname_json:
return None
raw_items = json.loads(group_cardname_json)
if not isinstance(raw_items, list):
return None
normalized_items: List[GroupCardnameInfo] = []
for raw_item in raw_items:
if not isinstance(raw_item, Mapping):
continue
if normalized_item := _normalize_group_cardname_item(raw_item):
normalized_items.append(normalized_item)
return normalized_items or None
def dump_group_cardname_records(
group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
) -> str:
"""将群名片列表序列化为数据库使用的标准 JSON 字符串。
Args:
group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
对象和包含 `group_id` / `group_cardname` 的字典。
Returns:
str: 统一使用 `group_cardname` 键名的 JSON 字符串。
"""
normalized_items: List[GroupCardnameInfo] = []
for raw_item in group_cardname_records or []:
if isinstance(raw_item, GroupCardnameInfo):
normalized_items.append(raw_item)
continue
if isinstance(raw_item, Mapping):
if normalized_item := _normalize_group_cardname_item(raw_item):
normalized_items.append(normalized_item)
return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
def __init__(
self,
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
"""最后一次被认识的时间"""
@classmethod
def from_db_instance(cls, db_record: "PersonInfo"):
nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
"""从数据库记录构造人物信息数据模型。
Args:
db_record: 数据库中的人物信息记录。
Returns:
MaiPersonInfo: 转换后的数据模型对象。
"""
group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
return cls(
is_known=db_record.is_known,
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
)
def to_db_instance(self) -> "PersonInfo":
group_cardname = (
json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
)
"""将当前数据模型转换为数据库记录对象。
Returns:
PersonInfo: 可直接写入数据库的模型实例。
"""
group_cardname = dump_group_cardname_records(self.group_cardname_list)
return PersonInfo(
is_known=self.is_known,
person_id=self.person_id,

View File

@@ -1,22 +1,24 @@
import hashlib
from datetime import datetime
from typing import Dict, Optional, Union
import asyncio
import hashlib
import json
import time
import random
import math
import random
import time
from json_repair import repair_json
from typing import Union, Optional, Dict
from datetime import datetime
from sqlmodel import col, select
from src.common.logger import get_logger
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.llm_models.utils_model import LLMRequest
logger = get_logger("person_info")
@@ -26,6 +28,32 @@ relation_selection_model = LLMRequest(
)
def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
"""将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
Args:
group_cardname_json: 数据库存储的群名片 JSON 字符串。
Returns:
list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
Raises:
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
"""
group_cardname_list = parse_group_cardname_json(group_cardname_json)
if not group_cardname_list:
return []
return [
{
"group_id": group_cardname.group_id,
"group_cardname": group_cardname.group_cardname,
}
for group_cardname in group_cardname_list
]
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
if "-" in platform:
@@ -231,7 +259,7 @@ class Person:
person.know_since = time.time()
person.last_know = time.time()
person.memory_points = []
person.group_nick_name = [] # 初始化群昵称列表
person.group_cardname_list = [] # 初始化群名片列表
# 如果是群聊,添加群昵称
if group_id and group_nick_name:
@@ -269,7 +297,7 @@ class Person:
self.platform = platform
self.nickname = global_config.bot.nickname
self.person_name = global_config.bot.nickname
self.group_nick_name: list[dict[str, str]] = []
self.group_cardname_list: list[dict[str, str]] = []
return
self.user_id = ""
@@ -308,7 +336,7 @@ class Person:
self.know_since = None
self.last_know: Optional[float] = None
self.memory_points = []
self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
# 从数据库加载数据
self.load_from_database()
@@ -408,16 +436,16 @@ class Person:
return
# 检查是否已存在该群号的记录
for item in self.group_nick_name:
for item in self.group_cardname_list:
if item.get("group_id") == group_id:
# 更新现有记录
item["group_nick_name"] = group_nick_name
item["group_cardname"] = group_nick_name
self.sync_to_database()
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
return
# 添加新记录
self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
self.sync_to_database()
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
@@ -452,20 +480,15 @@ class Person:
else:
self.memory_points = []
# 处理group_nick_name字段JSON格式的列表
# 处理 group_cardname 字段JSON 格式的列表)
if record.group_cardname:
try:
loaded_group_nick_names = json.loads(record.group_cardname)
# 确保是列表格式
if isinstance(loaded_group_nick_names, list):
self.group_nick_name = loaded_group_nick_names
else:
self.group_nick_name = []
self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
except (json.JSONDecodeError, TypeError):
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败使用默认值")
self.group_nick_name = []
self.group_cardname_list = []
else:
self.group_nick_name = []
self.group_cardname_list = []
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
else:
@@ -486,11 +509,7 @@ class Person:
if self.memory_points
else json.dumps([], ensure_ascii=False)
)
group_nickname_value = (
json.dumps(self.group_nick_name, ensure_ascii=False)
if self.group_nick_name
else json.dumps([], ensure_ascii=False)
)
group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
@@ -510,7 +529,7 @@ class Person:
record.first_known_time = first_known_time
record.last_known_time = last_known_time
record.memory_points = memory_points_value
record.group_nickname = group_nickname_value
record.group_cardname = group_cardname_value
session.add(record)
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
else:
@@ -526,7 +545,7 @@ class Person:
first_known_time=first_known_time,
last_known_time=last_known_time,
memory_points=memory_points_value,
group_nickname=group_nickname_value,
group_cardname=group_cardname_value,
)
session.add(record)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -8,13 +9,15 @@ import sys
from src.common.logger import get_logger
from src.config.config import global_config
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
from src.platform_io.routing import RouteBindingConflictError
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime.protocol.envelope import (
AdapterDeclarationPayload,
AdapterStateUpdatePayload,
AdapterStateUpdateResultPayload,
BootstrapPluginPayload,
ConfigUpdatedPayload,
Envelope,
@@ -46,6 +49,19 @@ if TYPE_CHECKING:
logger = get_logger("plugin_runtime.host.runner_manager")
_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact"
_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default"
@dataclass(slots=True)
class _AdapterRuntimeState:
"""保存适配器插件当前的运行时连接状态。"""
connected: bool = False
account_id: Optional[str] = None
scope: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
class PluginRunnerSupervisor:
"""插件 Runner 监督器。
@@ -94,6 +110,7 @@ class PluginRunnerSupervisor:
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {}
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -452,6 +469,7 @@ class PluginRunnerSupervisor:
"""注册 Host 侧内部 RPC 方法。"""
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state)
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
@@ -563,14 +581,14 @@ class PluginRunnerSupervisor:
return f"adapter:{plugin_id}"
async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
"""将适配器插件注册到 Platform IO。
"""将适配器插件驱动注册到 Platform IO。
Args:
plugin_id: 适配器插件 ID。
adapter: 经过校验的适配器声明。
Raises:
ValueError: 适配器路由冲突或驱动注册失败时抛出。
ValueError: 驱动注册失败时抛出。
"""
await self._unregister_adapter_driver(plugin_id)
@@ -588,22 +606,12 @@ class PluginRunnerSupervisor:
**adapter.metadata,
},
)
binding = RouteBinding(
route_key=driver.descriptor.route_key,
driver_id=driver.driver_id,
driver_kind=DriverKind.PLUGIN,
metadata={
"plugin_id": plugin_id,
"protocol": adapter.protocol,
},
)
try:
if platform_io_manager.is_started:
await platform_io_manager.add_driver(driver)
else:
platform_io_manager.register_driver(driver)
platform_io_manager.bind_route(binding)
except Exception:
with contextlib.suppress(Exception):
if platform_io_manager.is_started:
@@ -613,6 +621,7 @@ class PluginRunnerSupervisor:
raise
self._registered_adapters[plugin_id] = adapter
self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState()
async def _unregister_adapter_driver(self, plugin_id: str) -> None:
"""从 Platform IO 注销一个适配器驱动。
@@ -622,6 +631,9 @@ class PluginRunnerSupervisor:
"""
platform_io_manager = get_platform_io_manager()
driver_id = self._build_adapter_driver_id(plugin_id)
adapter = self._registered_adapters.get(plugin_id)
self._remove_adapter_route_bindings(plugin_id)
with contextlib.suppress(Exception):
if platform_io_manager.is_started:
@@ -629,7 +641,11 @@ class PluginRunnerSupervisor:
else:
platform_io_manager.unregister_driver(driver_id)
if adapter is not None:
self._refresh_platform_default_route(adapter.platform)
self._registered_adapters.pop(plugin_id, None)
self._adapter_runtime_states.pop(plugin_id, None)
async def _unregister_all_adapter_drivers(self) -> None:
"""注销当前 Supervisor 管理的全部适配器驱动。"""
@@ -637,6 +653,198 @@ class PluginRunnerSupervisor:
for plugin_id in plugin_ids:
await self._unregister_adapter_driver(plugin_id)
def _remove_adapter_route_bindings(self, plugin_id: str) -> None:
"""移除某个适配器驱动当前持有的全部路由绑定。
Args:
plugin_id: 适配器插件 ID。
"""
platform_io_manager = get_platform_io_manager()
platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id))
@staticmethod
def _normalize_runtime_route_value(value: str) -> Optional[str]:
"""规范化适配器运行时路由字段。
Args:
value: 待规范化的原始字符串。
Returns:
Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。
"""
normalized_value = str(value).strip()
return normalized_value or None
def _build_runtime_route_key(
self,
adapter: AdapterDeclarationPayload,
payload: AdapterStateUpdatePayload,
) -> RouteKey:
"""根据运行时状态更新构造适配器生效路由键。
Args:
adapter: 当前适配器声明。
payload: 适配器上报的运行时状态。
Returns:
RouteKey: 当前连接应接管的精确路由键。
Raises:
ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。
"""
runtime_account_id = self._normalize_runtime_route_value(payload.account_id)
runtime_scope = self._normalize_runtime_route_value(payload.scope)
if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id:
raise ValueError(
f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致"
)
if adapter.scope and runtime_scope and adapter.scope != runtime_scope:
raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致")
return RouteKey(
platform=adapter.platform,
account_id=runtime_account_id or adapter.account_id or None,
scope=runtime_scope or adapter.scope or None,
)
def _bind_runtime_exact_route(
self,
plugin_id: str,
adapter: AdapterDeclarationPayload,
route_key: RouteKey,
) -> None:
"""为适配器连接绑定精确生效路由。
Args:
plugin_id: 适配器插件 ID。
adapter: 当前适配器声明。
route_key: 当前连接对应的精确路由键。
Raises:
RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。
"""
platform_io_manager = get_platform_io_manager()
platform_io_manager.bind_route(
RouteBinding(
route_key=route_key,
driver_id=self._build_adapter_driver_id(plugin_id),
driver_kind=DriverKind.PLUGIN,
metadata={
"plugin_id": plugin_id,
"protocol": adapter.protocol,
"binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT,
},
)
)
def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]:
"""列出某个平台上由 Host 动态维护的精确适配器绑定。
Args:
platform: 目标平台名称。
Returns:
List[RouteBinding]: 当前平台上全部动态精确绑定。
"""
platform_io_manager = get_platform_io_manager()
return [
binding
for binding in platform_io_manager.route_table.list_bindings()
if binding.mode == RouteMode.ACTIVE
and binding.route_key.platform == platform
and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT
]
def _refresh_platform_default_route(self, platform: str) -> None:
"""根据当前精确绑定数量刷新平台级默认路由。
当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条
``RouteKey(platform=<platform>)`` 形式的默认路由,方便缺少账号维度的
出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1则撤销
由 Host 自动维护的默认路由,避免出现隐式歧义。
Args:
platform: 目标平台名称。
"""
platform_io_manager = get_platform_io_manager()
default_route_key = RouteKey(platform=platform)
existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True)
if existing_default_binding is not None:
binding_role = existing_default_binding.metadata.get("binding_role")
if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT:
return
platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id)
exact_bindings = self._list_runtime_exact_bindings(platform)
if len(exact_bindings) != 1:
return
exact_binding = exact_bindings[0]
if exact_binding.route_key == default_route_key:
return
platform_io_manager.bind_route(
RouteBinding(
route_key=default_route_key,
driver_id=exact_binding.driver_id,
driver_kind=exact_binding.driver_kind,
metadata={
"plugin_id": exact_binding.metadata.get("plugin_id", ""),
"protocol": exact_binding.metadata.get("protocol", ""),
"binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT,
},
),
replace=True,
)
def _apply_adapter_runtime_state(
self,
plugin_id: str,
adapter: AdapterDeclarationPayload,
payload: AdapterStateUpdatePayload,
) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]:
"""应用适配器运行时状态,并同步 Platform IO 路由。
Args:
plugin_id: 适配器插件 ID。
adapter: 当前适配器声明。
payload: 适配器上报的运行时状态。
Returns:
Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及
供 RPC 响应返回的路由键字典。
Raises:
RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。
ValueError: 当运行时路由信息不合法时抛出。
"""
if not payload.connected:
self._remove_adapter_route_bindings(plugin_id)
self._refresh_platform_default_route(adapter.platform)
runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata))
self._adapter_runtime_states[plugin_id] = runtime_state
return runtime_state, {}
route_key = self._build_runtime_route_key(adapter, payload)
self._remove_adapter_route_bindings(plugin_id)
self._bind_runtime_exact_route(plugin_id, adapter, route_key)
self._refresh_platform_default_route(adapter.platform)
runtime_state = _AdapterRuntimeState(
connected=True,
account_id=route_key.account_id,
scope=route_key.scope,
metadata=dict(payload.metadata),
)
self._adapter_runtime_states[plugin_id] = runtime_state
return runtime_state, {
"platform": route_key.platform,
"account_id": route_key.account_id,
"scope": route_key.scope,
}
@staticmethod
def _attach_inbound_route_metadata(
session_message: "SessionMessage",
@@ -706,6 +914,45 @@ class PluginRunnerSupervisor:
scope=scope,
)
async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope:
"""处理适配器插件上报的运行时状态更新。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: 状态更新处理结果。
"""
try:
payload = AdapterStateUpdatePayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
adapter = self._registered_adapters.get(envelope.plugin_id)
if adapter is None:
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态",
)
try:
runtime_state, route_key_dict = self._apply_adapter_runtime_state(
plugin_id=envelope.plugin_id,
adapter=adapter,
payload=payload,
)
except RouteBindingConflictError as exc:
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
response = AdapterStateUpdateResultPayload(
accepted=True,
connected=runtime_state.connected,
route_key=route_key_dict,
)
return envelope.make_response(payload=response.model_dump())
async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
"""处理适配器插件上报的外部入站消息。
@@ -970,6 +1217,7 @@ class PluginRunnerSupervisor:
self._component_registry.clear()
self._registered_plugins.clear()
self._registered_adapters.clear()
self._adapter_runtime_states.clear()
self._runner_ready_events = asyncio.Event()
self._runner_ready_payloads = RunnerReadyPayload()
self._rpc_server.clear_handshake_state()

View File

@@ -304,6 +304,30 @@ class AdapterDeclarationPayload(BaseModel):
"""适配器附加元数据"""
class AdapterStateUpdatePayload(BaseModel):
"""适配器运行时状态更新载荷。"""
connected: bool = Field(description="适配器当前是否已连接并准备接管路由")
"""适配器当前是否已连接并准备接管路由"""
account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id")
"""当前连接对应的账号 ID 或 self_id"""
scope: str = Field(default="", description="当前连接对应的可选路由作用域")
"""当前连接对应的可选路由作用域"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
"""可选的运行时状态元数据"""
class AdapterStateUpdateResultPayload(BaseModel):
"""适配器运行时状态更新结果载荷。"""
accepted: bool = Field(description="Host 是否接受了本次状态更新")
"""Host 是否接受了本次状态更新"""
connected: bool = Field(description="Host 记录的当前连接状态")
"""Host 记录的当前连接状态"""
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
"""当前生效的路由键"""
class ReceiveExternalMessagePayload(BaseModel):
"""适配器插件向 Host 注入外部消息的请求载荷。"""

View File

@@ -481,13 +481,14 @@ class PluginRunner:
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
if not await self._invoke_plugin_on_load(meta):
if not await self._register_plugin(meta):
await self._invoke_plugin_on_unload(meta)
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
if not await self._register_plugin(meta):
await self._invoke_plugin_on_unload(meta)
if not await self._invoke_plugin_on_load(meta):
await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False

View File

@@ -60,6 +60,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
self._connection_task: Optional[asyncio.Task[None]] = None
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
self._background_tasks: Set[asyncio.Task[Any]] = set()
self._reported_account_id: Optional[str] = None
self._reported_scope: Optional[str] = None
self._runtime_state_connected: bool = False
self._send_lock = asyncio.Lock()
self._ws: Optional[AiohttpClientWebSocketResponse] = None
@@ -170,6 +173,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
with contextlib.suppress(asyncio.CancelledError):
await connection_task
await self._report_adapter_disconnected()
self._fail_pending_actions("NapCat connection closed")
async def _cancel_background_tasks(self) -> None:
@@ -209,6 +213,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
finally:
self._ws = None
await self._report_adapter_disconnected()
self._fail_pending_actions("NapCat connection interrupted")
if not self._should_connect():
@@ -230,26 +235,39 @@ class NapCatAdapterPlugin(MaiBotPlugin):
"""
assert WSMsgType is not None
async for ws_message in ws:
if ws_message.type != WSMsgType.TEXT:
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
break
continue
bootstrap_task = asyncio.create_task(
self._bootstrap_adapter_runtime_state(),
name="napcat_adapter.bootstrap",
)
self._background_tasks.add(bootstrap_task)
bootstrap_task.add_done_callback(self._background_tasks.discard)
payload = self._parse_json_message(ws_message.data)
if payload is None:
continue
try:
async for ws_message in ws:
if ws_message.type != WSMsgType.TEXT:
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
break
continue
if echo_id := str(payload.get("echo") or "").strip():
self._resolve_pending_action(echo_id, payload)
continue
payload = self._parse_json_message(ws_message.data)
if payload is None:
continue
if str(payload.get("post_type") or "").strip() != "message":
continue
if echo_id := str(payload.get("echo") or "").strip():
self._resolve_pending_action(echo_id, payload)
continue
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
if str(payload.get("post_type") or "").strip() != "message":
continue
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
finally:
if not bootstrap_task.done():
bootstrap_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await bootstrap_task
async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
"""处理单条 NapCat 入站消息并注入 Host。
@@ -258,6 +276,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
payload: NapCat / OneBot 推送的原始事件数据。
"""
self_id = str(payload.get("self_id") or "").strip()
if self_id:
await self._report_adapter_connected(self_id)
sender = payload.get("sender", {})
if not isinstance(sender, dict):
sender = {}
@@ -570,6 +591,121 @@ class NapCatAdapterPlugin(MaiBotPlugin):
response_future.set_exception(RuntimeError(error_message))
self._pending_actions.clear()
async def _bootstrap_adapter_runtime_state(self) -> None:
"""在连接建立后主动获取账号信息并激活适配器路由。
该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()`
发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo
响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。
"""
max_attempts = 3
last_error: Optional[Exception] = None
for attempt in range(1, max_attempts + 1):
ws = self._ws
if ws is None or ws.closed:
return
try:
response = await self._call_action("get_login_info", {})
self_id = self._extract_self_id_from_login_response(response)
await self._report_adapter_connected(self_id)
return
except asyncio.CancelledError:
raise
except Exception as exc:
last_error = exc
self.ctx.logger.warning(
f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}"
)
if attempt < max_attempts:
await asyncio.sleep(1.0)
if last_error is not None:
self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}")
@staticmethod
def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str:
"""从 `get_login_info` 响应中提取当前账号 ID。
Args:
response: NapCat 返回的原始动作响应。
Returns:
str: 规范化后的 `self_id` 字符串。
Raises:
ValueError: 当响应中缺少有效账号 ID 时抛出。
"""
if str(response.get("status") or "").lower() != "ok":
raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed"))
response_data = response.get("data", {})
if not isinstance(response_data, dict):
raise ValueError("get_login_info 响应缺少 data 字段")
self_id = str(response_data.get("user_id") or "").strip()
if not self_id:
raise ValueError("get_login_info 响应缺少有效的 user_id")
return self_id
async def _report_adapter_connected(self, account_id: str) -> None:
"""向 Host 上报当前连接已就绪。
Args:
account_id: 当前 NapCat 连接对应的机器人账号 ID。
"""
normalized_account_id = str(account_id).strip()
if not normalized_account_id:
return
scope = self._get_string(self._connection_config(), "connection_id").strip()
if (
self._runtime_state_connected
and self._reported_account_id == normalized_account_id
and self._reported_scope == (scope or None)
):
return
accepted = False
try:
accepted = await self.ctx.adapter.update_runtime_state(
connected=True,
account_id=normalized_account_id,
scope=scope,
metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")},
)
except Exception as exc:
self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}")
return
if not accepted:
self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新")
return
self._runtime_state_connected = True
self._reported_account_id = normalized_account_id
self._reported_scope = scope or None
self.ctx.logger.info(
f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} "
f"scope={self._reported_scope or '*'}"
)
async def _report_adapter_disconnected(self) -> None:
"""向 Host 上报当前连接已断开,并撤销适配器路由。"""
if not self._runtime_state_connected:
self._reported_account_id = None
self._reported_scope = None
return
try:
await self.ctx.adapter.update_runtime_state(connected=False)
except Exception as exc:
self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}")
finally:
self._runtime_state_connected = False
self._reported_account_id = None
self._reported_scope = None
def _build_headers(self) -> Dict[str, str]:
"""构造连接 NapCat 所需的请求头。