feat: Implement adapter runtime state management and update handling
- Added support for adapter runtime state updates in the PluginRunnerSupervisor. - Introduced new payload classes: AdapterStateUpdatePayload and AdapterStateUpdateResultPayload for handling state updates. - Implemented methods to bind and unbind routes based on adapter connection status. - Enhanced the NapCat adapter to report connection state and manage runtime state. - Added tests for adapter runtime state synchronization and database session behavior in the statistic module. - Updated existing methods to ensure proper handling of adapter state and route bindings.
This commit is contained in:
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""人物信息群名片字段兼容测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
"""模拟日志记录器。"""
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""记录调试日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""记录信息日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""记录警告日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""记录错误日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
|
||||
class _DummyStatement:
|
||||
"""模拟 SQL 查询语句对象。"""
|
||||
|
||||
def where(self, condition: Any) -> "_DummyStatement":
|
||||
"""附加过滤条件。
|
||||
|
||||
Args:
|
||||
condition: 过滤条件。
|
||||
|
||||
Returns:
|
||||
_DummyStatement: 当前语句对象。
|
||||
"""
|
||||
del condition
|
||||
return self
|
||||
|
||||
def limit(self, value: int) -> "_DummyStatement":
|
||||
"""限制返回条数。
|
||||
|
||||
Args:
|
||||
value: 条数限制。
|
||||
|
||||
Returns:
|
||||
_DummyStatement: 当前语句对象。
|
||||
"""
|
||||
del value
|
||||
return self
|
||||
|
||||
|
||||
class _DummyColumn:
|
||||
"""模拟 SQLModel 列对象。"""
|
||||
|
||||
def is_not(self, value: Any) -> "_DummyColumn":
|
||||
"""模拟 `IS NOT` 条件构造。
|
||||
|
||||
Args:
|
||||
value: 比较值。
|
||||
|
||||
Returns:
|
||||
_DummyColumn: 当前列对象。
|
||||
"""
|
||||
del value
|
||||
return self
|
||||
|
||||
def __eq__(self, other: Any) -> "_DummyColumn":
|
||||
"""模拟等值条件构造。
|
||||
|
||||
Args:
|
||||
other: 比较值。
|
||||
|
||||
Returns:
|
||||
_DummyColumn: 当前列对象。
|
||||
"""
|
||||
del other
|
||||
return self
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
"""模拟数据库查询结果。"""
|
||||
|
||||
def __init__(self, record: Any) -> None:
|
||||
"""初始化查询结果。
|
||||
|
||||
Args:
|
||||
record: 待返回的首条记录。
|
||||
"""
|
||||
self._record = record
|
||||
|
||||
def first(self) -> Any:
|
||||
"""返回第一条记录。
|
||||
|
||||
Returns:
|
||||
Any: 首条记录。
|
||||
"""
|
||||
return self._record
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
"""返回全部结果。
|
||||
|
||||
Returns:
|
||||
list[Any]: 结果列表。
|
||||
"""
|
||||
if self._record is None:
|
||||
return []
|
||||
return self._record if isinstance(self._record, list) else [self._record]
|
||||
|
||||
|
||||
class _DummySession:
|
||||
"""模拟数据库 Session。"""
|
||||
|
||||
def __init__(self, record: Any) -> None:
|
||||
"""初始化 Session。
|
||||
|
||||
Args:
|
||||
record: `first()` 应返回的记录。
|
||||
"""
|
||||
self.record = record
|
||||
self.added_records: list[Any] = []
|
||||
|
||||
def __enter__(self) -> "_DummySession":
|
||||
"""进入上下文管理器。
|
||||
|
||||
Returns:
|
||||
_DummySession: 当前 Session。
|
||||
"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""退出上下文管理器。
|
||||
|
||||
Args:
|
||||
exc_type: 异常类型。
|
||||
exc_val: 异常值。
|
||||
exc_tb: 异常回溯。
|
||||
"""
|
||||
del exc_type
|
||||
del exc_val
|
||||
del exc_tb
|
||||
|
||||
def exec(self, statement: Any) -> _DummyResult:
|
||||
"""执行查询。
|
||||
|
||||
Args:
|
||||
statement: 查询语句。
|
||||
|
||||
Returns:
|
||||
_DummyResult: 模拟结果对象。
|
||||
"""
|
||||
del statement
|
||||
return _DummyResult(self.record)
|
||||
|
||||
def add(self, record: Any) -> None:
|
||||
"""记录被添加的对象。
|
||||
|
||||
Args:
|
||||
record: 被写入 Session 的对象。
|
||||
"""
|
||||
self.added_records.append(record)
|
||||
|
||||
|
||||
class _DummyPersonInfoRecord:
|
||||
"""模拟 `PersonInfo` ORM 模型。"""
|
||||
|
||||
person_id = "person_id"
|
||||
person_name = "person_name"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""使用关键字参数初始化记录对象。
|
||||
|
||||
Args:
|
||||
**kwargs: 字段值。
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
|
||||
"""加载带依赖桩的 `person_info` 模块。
|
||||
|
||||
Args:
|
||||
monkeypatch: Pytest monkeypatch 工具。
|
||||
session: 提供给模块使用的假数据库 Session。
|
||||
|
||||
Returns:
|
||||
ModuleType: 加载后的模块对象。
|
||||
"""
|
||||
logger_module = ModuleType("src.common.logger")
|
||||
logger_module.get_logger = lambda name: _DummyLogger()
|
||||
monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
|
||||
|
||||
database_module = ModuleType("src.common.database.database")
|
||||
database_module.get_db_session = lambda: session
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
|
||||
|
||||
database_model_module = ModuleType("src.common.database.database_model")
|
||||
database_model_module.PersonInfo = _DummyPersonInfoRecord
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
|
||||
|
||||
llm_module = ModuleType("src.llm_models.utils_model")
|
||||
|
||||
class _DummyLLMRequest:
|
||||
"""模拟 LLMRequest。"""
|
||||
|
||||
def __init__(self, model_set: Any, request_type: str) -> None:
|
||||
"""初始化假请求对象。
|
||||
|
||||
Args:
|
||||
model_set: 模型配置。
|
||||
request_type: 请求类型。
|
||||
"""
|
||||
del model_set
|
||||
del request_type
|
||||
|
||||
llm_module.LLMRequest = _DummyLLMRequest
|
||||
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
|
||||
|
||||
config_module = ModuleType("src.config.config")
|
||||
config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
|
||||
config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
|
||||
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
|
||||
|
||||
chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
|
||||
chat_manager_module.chat_manager = SimpleNamespace()
|
||||
monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
|
||||
|
||||
module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
|
||||
spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
|
||||
assert spec is not None and spec.loader is not None
|
||||
|
||||
module = module_from_spec(spec)
|
||||
monkeypatch.setitem(sys.modules, spec.name, module)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
|
||||
monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
|
||||
return module
|
||||
|
||||
|
||||
def test_parse_group_cardname_json_uses_canonical_key() -> None:
|
||||
"""群名片 JSON 解析应只使用 `group_cardname` 键名。"""
|
||||
parsed = parse_group_cardname_json(
|
||||
json.dumps(
|
||||
[
|
||||
{"group_id": "1001", "group_cardname": "现行字段"},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert parsed is not None
|
||||
assert [(item.group_id, item.group_cardname) for item in parsed] == [
|
||||
("1001", "现行字段"),
|
||||
]
|
||||
|
||||
|
||||
def test_dump_group_cardname_records_uses_canonical_key() -> None:
|
||||
"""群名片序列化应输出 `group_cardname` 键名。"""
|
||||
dumped = dump_group_cardname_records(
|
||||
[
|
||||
{"group_id": "1001", "group_cardname": "群昵称"},
|
||||
]
|
||||
)
|
||||
|
||||
assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
|
||||
|
||||
|
||||
def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
|
||||
record = _DummyPersonInfoRecord()
|
||||
session = _DummySession(record)
|
||||
module = _load_person_module(monkeypatch, session)
|
||||
|
||||
person = module.Person.__new__(module.Person)
|
||||
person.is_known = True
|
||||
person.person_id = "person-1"
|
||||
person.platform = "qq"
|
||||
person.user_id = "10001"
|
||||
person.nickname = "看番的龙"
|
||||
person.person_name = "看番的龙"
|
||||
person.name_reason = "测试"
|
||||
person.know_times = 1
|
||||
person.know_since = 1700000000.0
|
||||
person.last_know = 1700000100.0
|
||||
person.memory_points = ["喜好:番剧:0.8"]
|
||||
person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
|
||||
|
||||
person.sync_to_database()
|
||||
|
||||
assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
|
||||
assert not hasattr(record, "group_nickname")
|
||||
|
||||
|
||||
def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
|
||||
record = _DummyPersonInfoRecord(
|
||||
user_id="10001",
|
||||
platform="qq",
|
||||
is_known=True,
|
||||
user_nickname="看番的龙",
|
||||
person_name="看番的龙",
|
||||
name_reason=None,
|
||||
know_counts=2,
|
||||
memory_points='["喜好:番剧:0.8"]',
|
||||
group_cardname=json.dumps(
|
||||
[
|
||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
session = _DummySession(record)
|
||||
module = _load_person_module(monkeypatch, session)
|
||||
|
||||
person = module.Person.__new__(module.Person)
|
||||
person.person_id = "person-1"
|
||||
person.memory_points = []
|
||||
person.group_cardname_list = []
|
||||
|
||||
person.load_from_database()
|
||||
|
||||
assert person.group_cardname_list == [
|
||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||
]
|
||||
162
pytests/test_adapter_runtime_state.py
Normal file
162
pytests/test_adapter_runtime_state.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""适配器运行时状态同步测试。"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from src.platform_io.manager import PlatformIOManager
|
||||
from src.platform_io.types import RouteKey
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
AdapterDeclarationPayload,
|
||||
Envelope,
|
||||
MessageType,
|
||||
)
|
||||
|
||||
|
||||
def _make_request(plugin_id: str, payload: Dict[str, Any]) -> Envelope:
|
||||
"""构造一个适配器状态更新 RPC 请求。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标适配器插件 ID。
|
||||
payload: 请求载荷。
|
||||
|
||||
Returns:
|
||||
Envelope: 标准 RPC 请求信封。
|
||||
"""
|
||||
return Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="host.update_adapter_state",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_runtime_state_binds_and_unbinds_routes(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""连接建立后应绑定路由,断开后应撤销路由。"""
|
||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||
|
||||
platform_io_manager = PlatformIOManager()
|
||||
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
|
||||
await supervisor._register_adapter_driver("napcat_adapter_builtin", adapter)
|
||||
|
||||
response = await supervisor._handle_update_adapter_state(
|
||||
_make_request(
|
||||
"napcat_adapter_builtin",
|
||||
{
|
||||
"connected": True,
|
||||
"account_id": "10001",
|
||||
"scope": "",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert response.payload["accepted"] is True
|
||||
assert (
|
||||
platform_io_manager.route_table.get_active_binding(
|
||||
RouteKey(platform="qq", account_id="10001"),
|
||||
exact_only=True,
|
||||
).driver_id
|
||||
== "adapter:napcat_adapter_builtin"
|
||||
)
|
||||
assert (
|
||||
platform_io_manager.route_table.get_active_binding(
|
||||
RouteKey(platform="qq"),
|
||||
exact_only=True,
|
||||
).driver_id
|
||||
== "adapter:napcat_adapter_builtin"
|
||||
)
|
||||
|
||||
response = await supervisor._handle_update_adapter_state(
|
||||
_make_request(
|
||||
"napcat_adapter_builtin",
|
||||
{
|
||||
"connected": False,
|
||||
"account_id": "",
|
||||
"scope": "",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert response.payload["accepted"] is True
|
||||
assert platform_io_manager.route_table.get_active_binding(
|
||||
RouteKey(platform="qq", account_id="10001"),
|
||||
exact_only=True,
|
||||
) is None
|
||||
assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_default_route_is_removed_when_multiple_exact_routes_exist(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""同一平台存在多个精确路由时不应保留默认平台路由。"""
|
||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||
|
||||
platform_io_manager = PlatformIOManager()
|
||||
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
|
||||
await supervisor._register_adapter_driver("adapter_a", adapter)
|
||||
await supervisor._register_adapter_driver("adapter_b", adapter)
|
||||
|
||||
await supervisor._handle_update_adapter_state(
|
||||
_make_request(
|
||||
"adapter_a",
|
||||
{
|
||||
"connected": True,
|
||||
"account_id": "10001",
|
||||
"scope": "",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
assert (
|
||||
platform_io_manager.route_table.get_active_binding(
|
||||
RouteKey(platform="qq"),
|
||||
exact_only=True,
|
||||
).driver_id
|
||||
== "adapter:adapter_a"
|
||||
)
|
||||
|
||||
await supervisor._handle_update_adapter_state(
|
||||
_make_request(
|
||||
"adapter_b",
|
||||
{
|
||||
"connected": True,
|
||||
"account_id": "10002",
|
||||
"scope": "",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
|
||||
|
||||
await supervisor._handle_update_adapter_state(
|
||||
_make_request(
|
||||
"adapter_b",
|
||||
{
|
||||
"connected": False,
|
||||
"account_id": "",
|
||||
"scope": "",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
assert (
|
||||
platform_io_manager.route_table.get_active_binding(
|
||||
RouteKey(platform="qq"),
|
||||
exact_only=True,
|
||||
).driver_id
|
||||
== "adapter:adapter_a"
|
||||
)
|
||||
@@ -486,10 +486,10 @@ class TestSDK:
|
||||
"timeout_ms": timeout_ms,
|
||||
}
|
||||
)
|
||||
if method == "cap.request":
|
||||
if method == "cap.call":
|
||||
bootstrap_methods = [call["method"] for call in self.calls[:-1]]
|
||||
assert "plugin.bootstrap" in bootstrap_methods
|
||||
return SimpleNamespace(error=None, payload={"result": {"success": True}})
|
||||
return SimpleNamespace(error=None, payload={"success": True})
|
||||
return SimpleNamespace(error=None, payload={"accepted": True})
|
||||
|
||||
async def disconnect(self):
|
||||
@@ -529,7 +529,7 @@ class TestSDK:
|
||||
await runner.run()
|
||||
|
||||
methods = [call["method"] for call in runner._rpc_client.calls]
|
||||
assert methods == ["plugin.bootstrap", "cap.request", "plugin.register_components", "runner.ready"]
|
||||
assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"]
|
||||
|
||||
|
||||
class TestPluginSdkUsage:
|
||||
|
||||
115
pytests/utils_test/statistic_test.py
Normal file
115
pytests/utils_test/statistic_test.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""统计模块数据库会话行为测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from src.chat.utils import statistic
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
"""模拟 SQLModel 查询结果对象。"""
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
"""返回空结果集。
|
||||
|
||||
Returns:
|
||||
list[Any]: 空列表。
|
||||
"""
|
||||
return []
|
||||
|
||||
|
||||
class _DummySession:
|
||||
"""模拟数据库 Session。"""
|
||||
|
||||
def exec(self, statement: Any) -> _DummyResult:
|
||||
"""执行查询语句并返回空结果。
|
||||
|
||||
Args:
|
||||
statement: 待执行的查询语句。
|
||||
|
||||
Returns:
|
||||
_DummyResult: 空结果对象。
|
||||
"""
|
||||
del statement
|
||||
return _DummyResult()
|
||||
|
||||
|
||||
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
|
||||
"""构造一个记录 auto_commit 参数的假会话工厂。
|
||||
|
||||
Args:
|
||||
calls: 用于记录每次调用 auto_commit 参数的列表。
|
||||
|
||||
Returns:
|
||||
Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
|
||||
"""
|
||||
|
||||
@contextmanager
|
||||
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
|
||||
"""记录会话参数并返回假 Session。
|
||||
|
||||
Args:
|
||||
auto_commit: 是否启用自动提交。
|
||||
|
||||
Yields:
|
||||
Iterator[_DummySession]: 假 Session 对象。
|
||||
"""
|
||||
calls.append(auto_commit)
|
||||
yield _DummySession()
|
||||
|
||||
return _fake_get_db_session
|
||||
|
||||
|
||||
def _build_statistic_task() -> statistic.StatisticOutputTask:
|
||||
"""构造一个最小可用的统计任务实例。
|
||||
|
||||
Returns:
|
||||
statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
|
||||
"""
|
||||
task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
|
||||
task.name_mapping = {}
|
||||
return task
|
||||
|
||||
|
||||
def _is_bot_self(platform: str, user_id: str) -> bool:
|
||||
"""返回固定的非机器人身份判断结果。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
user_id: 用户 ID。
|
||||
|
||||
Returns:
|
||||
bool: 始终返回 ``False``。
|
||||
"""
|
||||
del platform
|
||||
del user_id
|
||||
return False
|
||||
|
||||
|
||||
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
|
||||
calls: list[bool] = []
|
||||
now = datetime.now()
|
||||
task = _build_statistic_task()
|
||||
|
||||
monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
|
||||
|
||||
utils_module = ModuleType("src.chat.utils.utils")
|
||||
utils_module.is_bot_self = _is_bot_self
|
||||
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
|
||||
|
||||
statistic.StatisticOutputTask._fetch_online_time_since(now)
|
||||
statistic.StatisticOutputTask._fetch_model_usage_since(now)
|
||||
task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
|
||||
task._collect_interval_data(now, hours=1, interval_minutes=60)
|
||||
task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
|
||||
|
||||
assert calls == [False] * 9
|
||||
@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
|
||||
records = session.exec(statement).all()
|
||||
return [(record.start_timestamp, record.end_timestamp) for record in records]
|
||||
|
||||
@staticmethod
|
||||
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
|
||||
records = session.exec(statement).all()
|
||||
return [
|
||||
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1]
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
||||
try:
|
||||
action_query_start_timestamp = collect_period[-1][1]
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
|
||||
actions = session.exec(statement).all()
|
||||
for action in actions:
|
||||
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
with get_db_session() as session:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||
messages = session.exec(statement).all()
|
||||
for message in messages:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from typing import Any, List, Mapping, Optional, Sequence
|
||||
|
||||
import json
|
||||
|
||||
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
|
||||
group_cardname: str
|
||||
|
||||
|
||||
def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
|
||||
"""将单条群名片数据规范化为统一结构。
|
||||
|
||||
Args:
|
||||
raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
|
||||
|
||||
Returns:
|
||||
Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
|
||||
"""
|
||||
group_id = str(raw_item.get("group_id") or "").strip()
|
||||
group_cardname = str(raw_item.get("group_cardname") or "").strip()
|
||||
if not group_id or not group_cardname:
|
||||
return None
|
||||
return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
|
||||
|
||||
|
||||
def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
|
||||
"""解析数据库中的群名片 JSON 字段。
|
||||
|
||||
Args:
|
||||
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||
"""
|
||||
if not group_cardname_json:
|
||||
return None
|
||||
|
||||
raw_items = json.loads(group_cardname_json)
|
||||
if not isinstance(raw_items, list):
|
||||
return None
|
||||
|
||||
normalized_items: List[GroupCardnameInfo] = []
|
||||
for raw_item in raw_items:
|
||||
if not isinstance(raw_item, Mapping):
|
||||
continue
|
||||
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||
normalized_items.append(normalized_item)
|
||||
|
||||
return normalized_items or None
|
||||
|
||||
|
||||
def dump_group_cardname_records(
|
||||
group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
|
||||
) -> str:
|
||||
"""将群名片列表序列化为数据库使用的标准 JSON 字符串。
|
||||
|
||||
Args:
|
||||
group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
|
||||
对象和包含 `group_id` / `group_cardname` 的字典。
|
||||
|
||||
Returns:
|
||||
str: 统一使用 `group_cardname` 键名的 JSON 字符串。
|
||||
"""
|
||||
normalized_items: List[GroupCardnameInfo] = []
|
||||
for raw_item in group_cardname_records or []:
|
||||
if isinstance(raw_item, GroupCardnameInfo):
|
||||
normalized_items.append(raw_item)
|
||||
continue
|
||||
if isinstance(raw_item, Mapping):
|
||||
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||
normalized_items.append(normalized_item)
|
||||
|
||||
return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
|
||||
|
||||
|
||||
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
"""最后一次被认识的时间"""
|
||||
|
||||
@classmethod
|
||||
def from_db_instance(cls, db_record: "PersonInfo"):
|
||||
nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
|
||||
group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
|
||||
def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
|
||||
"""从数据库记录构造人物信息数据模型。
|
||||
|
||||
Args:
|
||||
db_record: 数据库中的人物信息记录。
|
||||
|
||||
Returns:
|
||||
MaiPersonInfo: 转换后的数据模型对象。
|
||||
"""
|
||||
group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
|
||||
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
|
||||
return cls(
|
||||
is_known=db_record.is_known,
|
||||
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||
)
|
||||
|
||||
def to_db_instance(self) -> "PersonInfo":
|
||||
group_cardname = (
|
||||
json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
|
||||
)
|
||||
"""将当前数据模型转换为数据库记录对象。
|
||||
|
||||
Returns:
|
||||
PersonInfo: 可直接写入数据库的模型实例。
|
||||
"""
|
||||
group_cardname = dump_group_cardname_records(self.group_cardname_list)
|
||||
return PersonInfo(
|
||||
is_known=self.is_known,
|
||||
person_id=self.person_id,
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import Union, Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
logger = get_logger("person_info")
|
||||
@@ -26,6 +28,32 @@ relation_selection_model = LLMRequest(
|
||||
)
|
||||
|
||||
|
||||
def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
|
||||
"""将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
|
||||
|
||||
Args:
|
||||
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||
|
||||
Returns:
|
||||
list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||
"""
|
||||
group_cardname_list = parse_group_cardname_json(group_cardname_json)
|
||||
if not group_cardname_list:
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"group_id": group_cardname.group_id,
|
||||
"group_cardname": group_cardname.group_cardname,
|
||||
}
|
||||
for group_cardname in group_cardname_list
|
||||
]
|
||||
|
||||
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
"""获取唯一id"""
|
||||
if "-" in platform:
|
||||
@@ -231,7 +259,7 @@ class Person:
|
||||
person.know_since = time.time()
|
||||
person.last_know = time.time()
|
||||
person.memory_points = []
|
||||
person.group_nick_name = [] # 初始化群昵称列表
|
||||
person.group_cardname_list = [] # 初始化群名片列表
|
||||
|
||||
# 如果是群聊,添加群昵称
|
||||
if group_id and group_nick_name:
|
||||
@@ -269,7 +297,7 @@ class Person:
|
||||
self.platform = platform
|
||||
self.nickname = global_config.bot.nickname
|
||||
self.person_name = global_config.bot.nickname
|
||||
self.group_nick_name: list[dict[str, str]] = []
|
||||
self.group_cardname_list: list[dict[str, str]] = []
|
||||
return
|
||||
|
||||
self.user_id = ""
|
||||
@@ -308,7 +336,7 @@ class Person:
|
||||
self.know_since = None
|
||||
self.last_know: Optional[float] = None
|
||||
self.memory_points = []
|
||||
self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
|
||||
self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
|
||||
|
||||
# 从数据库加载数据
|
||||
self.load_from_database()
|
||||
@@ -408,16 +436,16 @@ class Person:
|
||||
return
|
||||
|
||||
# 检查是否已存在该群号的记录
|
||||
for item in self.group_nick_name:
|
||||
for item in self.group_cardname_list:
|
||||
if item.get("group_id") == group_id:
|
||||
# 更新现有记录
|
||||
item["group_nick_name"] = group_nick_name
|
||||
item["group_cardname"] = group_nick_name
|
||||
self.sync_to_database()
|
||||
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
|
||||
return
|
||||
|
||||
# 添加新记录
|
||||
self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
|
||||
self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
|
||||
self.sync_to_database()
|
||||
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
|
||||
|
||||
@@ -452,20 +480,15 @@ class Person:
|
||||
else:
|
||||
self.memory_points = []
|
||||
|
||||
# 处理group_nick_name字段(JSON格式的列表)
|
||||
# 处理 group_cardname 字段(JSON 格式的列表)
|
||||
if record.group_cardname:
|
||||
try:
|
||||
loaded_group_nick_names = json.loads(record.group_cardname)
|
||||
# 确保是列表格式
|
||||
if isinstance(loaded_group_nick_names, list):
|
||||
self.group_nick_name = loaded_group_nick_names
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值")
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = []
|
||||
else:
|
||||
self.group_nick_name = []
|
||||
self.group_cardname_list = []
|
||||
|
||||
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
||||
else:
|
||||
@@ -486,11 +509,7 @@ class Person:
|
||||
if self.memory_points
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
group_nickname_value = (
|
||||
json.dumps(self.group_nick_name, ensure_ascii=False)
|
||||
if self.group_nick_name
|
||||
else json.dumps([], ensure_ascii=False)
|
||||
)
|
||||
group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
|
||||
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
|
||||
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
|
||||
|
||||
@@ -510,7 +529,7 @@ class Person:
|
||||
record.first_known_time = first_known_time
|
||||
record.last_known_time = last_known_time
|
||||
record.memory_points = memory_points_value
|
||||
record.group_nickname = group_nickname_value
|
||||
record.group_cardname = group_cardname_value
|
||||
session.add(record)
|
||||
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
||||
else:
|
||||
@@ -526,7 +545,7 @@ class Person:
|
||||
first_known_time=first_known_time,
|
||||
last_known_time=last_known_time,
|
||||
memory_points=memory_points_value,
|
||||
group_nickname=group_nickname_value,
|
||||
group_cardname=group_cardname_value,
|
||||
)
|
||||
session.add(record)
|
||||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -8,13 +9,15 @@ import sys
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager
|
||||
from src.platform_io.drivers import PluginPlatformDriver
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
from src.platform_io.routing import RouteBindingConflictError
|
||||
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
AdapterDeclarationPayload,
|
||||
AdapterStateUpdatePayload,
|
||||
AdapterStateUpdateResultPayload,
|
||||
BootstrapPluginPayload,
|
||||
ConfigUpdatedPayload,
|
||||
Envelope,
|
||||
@@ -46,6 +49,19 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("plugin_runtime.host.runner_manager")
|
||||
|
||||
_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact"
|
||||
_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _AdapterRuntimeState:
|
||||
"""保存适配器插件当前的运行时连接状态。"""
|
||||
|
||||
connected: bool = False
|
||||
account_id: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PluginRunnerSupervisor:
|
||||
"""插件 Runner 监督器。
|
||||
@@ -94,6 +110,7 @@ class PluginRunnerSupervisor:
|
||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
|
||||
self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
|
||||
self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {}
|
||||
self._runner_ready_events: asyncio.Event = asyncio.Event()
|
||||
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
|
||||
self._health_task: Optional[asyncio.Task[None]] = None
|
||||
@@ -452,6 +469,7 @@ class PluginRunnerSupervisor:
|
||||
"""注册 Host 侧内部 RPC 方法。"""
|
||||
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
|
||||
self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
|
||||
self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state)
|
||||
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
|
||||
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
|
||||
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
|
||||
@@ -563,14 +581,14 @@ class PluginRunnerSupervisor:
|
||||
return f"adapter:{plugin_id}"
|
||||
|
||||
async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
|
||||
"""将适配器插件注册到 Platform IO。
|
||||
"""将适配器插件驱动注册到 Platform IO。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
adapter: 经过校验的适配器声明。
|
||||
|
||||
Raises:
|
||||
ValueError: 适配器路由冲突或驱动注册失败时抛出。
|
||||
ValueError: 当驱动注册失败时抛出。
|
||||
"""
|
||||
await self._unregister_adapter_driver(plugin_id)
|
||||
|
||||
@@ -588,22 +606,12 @@ class PluginRunnerSupervisor:
|
||||
**adapter.metadata,
|
||||
},
|
||||
)
|
||||
binding = RouteBinding(
|
||||
route_key=driver.descriptor.route_key,
|
||||
driver_id=driver.driver_id,
|
||||
driver_kind=DriverKind.PLUGIN,
|
||||
metadata={
|
||||
"plugin_id": plugin_id,
|
||||
"protocol": adapter.protocol,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
if platform_io_manager.is_started:
|
||||
await platform_io_manager.add_driver(driver)
|
||||
else:
|
||||
platform_io_manager.register_driver(driver)
|
||||
platform_io_manager.bind_route(binding)
|
||||
except Exception:
|
||||
with contextlib.suppress(Exception):
|
||||
if platform_io_manager.is_started:
|
||||
@@ -613,6 +621,7 @@ class PluginRunnerSupervisor:
|
||||
raise
|
||||
|
||||
self._registered_adapters[plugin_id] = adapter
|
||||
self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState()
|
||||
|
||||
async def _unregister_adapter_driver(self, plugin_id: str) -> None:
|
||||
"""从 Platform IO 注销一个适配器驱动。
|
||||
@@ -622,6 +631,9 @@ class PluginRunnerSupervisor:
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
driver_id = self._build_adapter_driver_id(plugin_id)
|
||||
adapter = self._registered_adapters.get(plugin_id)
|
||||
|
||||
self._remove_adapter_route_bindings(plugin_id)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
if platform_io_manager.is_started:
|
||||
@@ -629,7 +641,11 @@ class PluginRunnerSupervisor:
|
||||
else:
|
||||
platform_io_manager.unregister_driver(driver_id)
|
||||
|
||||
if adapter is not None:
|
||||
self._refresh_platform_default_route(adapter.platform)
|
||||
|
||||
self._registered_adapters.pop(plugin_id, None)
|
||||
self._adapter_runtime_states.pop(plugin_id, None)
|
||||
|
||||
async def _unregister_all_adapter_drivers(self) -> None:
|
||||
"""注销当前 Supervisor 管理的全部适配器驱动。"""
|
||||
@@ -637,6 +653,198 @@ class PluginRunnerSupervisor:
|
||||
for plugin_id in plugin_ids:
|
||||
await self._unregister_adapter_driver(plugin_id)
|
||||
|
||||
def _remove_adapter_route_bindings(self, plugin_id: str) -> None:
|
||||
"""移除某个适配器驱动当前持有的全部路由绑定。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id))
|
||||
|
||||
@staticmethod
|
||||
def _normalize_runtime_route_value(value: str) -> Optional[str]:
|
||||
"""规范化适配器运行时路由字段。
|
||||
|
||||
Args:
|
||||
value: 待规范化的原始字符串。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。
|
||||
"""
|
||||
normalized_value = str(value).strip()
|
||||
return normalized_value or None
|
||||
|
||||
def _build_runtime_route_key(
|
||||
self,
|
||||
adapter: AdapterDeclarationPayload,
|
||||
payload: AdapterStateUpdatePayload,
|
||||
) -> RouteKey:
|
||||
"""根据运行时状态更新构造适配器生效路由键。
|
||||
|
||||
Args:
|
||||
adapter: 当前适配器声明。
|
||||
payload: 适配器上报的运行时状态。
|
||||
|
||||
Returns:
|
||||
RouteKey: 当前连接应接管的精确路由键。
|
||||
|
||||
Raises:
|
||||
ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。
|
||||
"""
|
||||
runtime_account_id = self._normalize_runtime_route_value(payload.account_id)
|
||||
runtime_scope = self._normalize_runtime_route_value(payload.scope)
|
||||
|
||||
if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id:
|
||||
raise ValueError(
|
||||
f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致"
|
||||
)
|
||||
if adapter.scope and runtime_scope and adapter.scope != runtime_scope:
|
||||
raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致")
|
||||
|
||||
return RouteKey(
|
||||
platform=adapter.platform,
|
||||
account_id=runtime_account_id or adapter.account_id or None,
|
||||
scope=runtime_scope or adapter.scope or None,
|
||||
)
|
||||
|
||||
def _bind_runtime_exact_route(
|
||||
self,
|
||||
plugin_id: str,
|
||||
adapter: AdapterDeclarationPayload,
|
||||
route_key: RouteKey,
|
||||
) -> None:
|
||||
"""为适配器连接绑定精确生效路由。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
adapter: 当前适配器声明。
|
||||
route_key: 当前连接对应的精确路由键。
|
||||
|
||||
Raises:
|
||||
RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
platform_io_manager.bind_route(
|
||||
RouteBinding(
|
||||
route_key=route_key,
|
||||
driver_id=self._build_adapter_driver_id(plugin_id),
|
||||
driver_kind=DriverKind.PLUGIN,
|
||||
metadata={
|
||||
"plugin_id": plugin_id,
|
||||
"protocol": adapter.protocol,
|
||||
"binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]:
|
||||
"""列出某个平台上由 Host 动态维护的精确适配器绑定。
|
||||
|
||||
Args:
|
||||
platform: 目标平台名称。
|
||||
|
||||
Returns:
|
||||
List[RouteBinding]: 当前平台上全部动态精确绑定。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
return [
|
||||
binding
|
||||
for binding in platform_io_manager.route_table.list_bindings()
|
||||
if binding.mode == RouteMode.ACTIVE
|
||||
and binding.route_key.platform == platform
|
||||
and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT
|
||||
]
|
||||
|
||||
def _refresh_platform_default_route(self, platform: str) -> None:
|
||||
"""根据当前精确绑定数量刷新平台级默认路由。
|
||||
|
||||
当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条
|
||||
``RouteKey(platform=<platform>)`` 形式的默认路由,方便缺少账号维度的
|
||||
出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1,则撤销
|
||||
由 Host 自动维护的默认路由,避免出现隐式歧义。
|
||||
|
||||
Args:
|
||||
platform: 目标平台名称。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
default_route_key = RouteKey(platform=platform)
|
||||
existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True)
|
||||
|
||||
if existing_default_binding is not None:
|
||||
binding_role = existing_default_binding.metadata.get("binding_role")
|
||||
if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT:
|
||||
return
|
||||
platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id)
|
||||
|
||||
exact_bindings = self._list_runtime_exact_bindings(platform)
|
||||
if len(exact_bindings) != 1:
|
||||
return
|
||||
|
||||
exact_binding = exact_bindings[0]
|
||||
if exact_binding.route_key == default_route_key:
|
||||
return
|
||||
|
||||
platform_io_manager.bind_route(
|
||||
RouteBinding(
|
||||
route_key=default_route_key,
|
||||
driver_id=exact_binding.driver_id,
|
||||
driver_kind=exact_binding.driver_kind,
|
||||
metadata={
|
||||
"plugin_id": exact_binding.metadata.get("plugin_id", ""),
|
||||
"protocol": exact_binding.metadata.get("protocol", ""),
|
||||
"binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT,
|
||||
},
|
||||
),
|
||||
replace=True,
|
||||
)
|
||||
|
||||
def _apply_adapter_runtime_state(
|
||||
self,
|
||||
plugin_id: str,
|
||||
adapter: AdapterDeclarationPayload,
|
||||
payload: AdapterStateUpdatePayload,
|
||||
) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]:
|
||||
"""应用适配器运行时状态,并同步 Platform IO 路由。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
adapter: 当前适配器声明。
|
||||
payload: 适配器上报的运行时状态。
|
||||
|
||||
Returns:
|
||||
Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及
|
||||
供 RPC 响应返回的路由键字典。
|
||||
|
||||
Raises:
|
||||
RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。
|
||||
ValueError: 当运行时路由信息不合法时抛出。
|
||||
"""
|
||||
if not payload.connected:
|
||||
self._remove_adapter_route_bindings(plugin_id)
|
||||
self._refresh_platform_default_route(adapter.platform)
|
||||
runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata))
|
||||
self._adapter_runtime_states[plugin_id] = runtime_state
|
||||
return runtime_state, {}
|
||||
|
||||
route_key = self._build_runtime_route_key(adapter, payload)
|
||||
self._remove_adapter_route_bindings(plugin_id)
|
||||
self._bind_runtime_exact_route(plugin_id, adapter, route_key)
|
||||
self._refresh_platform_default_route(adapter.platform)
|
||||
|
||||
runtime_state = _AdapterRuntimeState(
|
||||
connected=True,
|
||||
account_id=route_key.account_id,
|
||||
scope=route_key.scope,
|
||||
metadata=dict(payload.metadata),
|
||||
)
|
||||
self._adapter_runtime_states[plugin_id] = runtime_state
|
||||
return runtime_state, {
|
||||
"platform": route_key.platform,
|
||||
"account_id": route_key.account_id,
|
||||
"scope": route_key.scope,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _attach_inbound_route_metadata(
|
||||
session_message: "SessionMessage",
|
||||
@@ -706,6 +914,45 @@ class PluginRunnerSupervisor:
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope:
|
||||
"""处理适配器插件上报的运行时状态更新。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: 状态更新处理结果。
|
||||
"""
|
||||
try:
|
||||
payload = AdapterStateUpdatePayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
adapter = self._registered_adapters.get(envelope.plugin_id)
|
||||
if adapter is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态",
|
||||
)
|
||||
|
||||
try:
|
||||
runtime_state, route_key_dict = self._apply_adapter_runtime_state(
|
||||
plugin_id=envelope.plugin_id,
|
||||
adapter=adapter,
|
||||
payload=payload,
|
||||
)
|
||||
except RouteBindingConflictError as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
response = AdapterStateUpdateResultPayload(
|
||||
accepted=True,
|
||||
connected=runtime_state.connected,
|
||||
route_key=route_key_dict,
|
||||
)
|
||||
return envelope.make_response(payload=response.model_dump())
|
||||
|
||||
async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
|
||||
"""处理适配器插件上报的外部入站消息。
|
||||
|
||||
@@ -970,6 +1217,7 @@ class PluginRunnerSupervisor:
|
||||
self._component_registry.clear()
|
||||
self._registered_plugins.clear()
|
||||
self._registered_adapters.clear()
|
||||
self._adapter_runtime_states.clear()
|
||||
self._runner_ready_events = asyncio.Event()
|
||||
self._runner_ready_payloads = RunnerReadyPayload()
|
||||
self._rpc_server.clear_handshake_state()
|
||||
|
||||
@@ -304,6 +304,30 @@ class AdapterDeclarationPayload(BaseModel):
|
||||
"""适配器附加元数据"""
|
||||
|
||||
|
||||
class AdapterStateUpdatePayload(BaseModel):
|
||||
"""适配器运行时状态更新载荷。"""
|
||||
|
||||
connected: bool = Field(description="适配器当前是否已连接并准备接管路由")
|
||||
"""适配器当前是否已连接并准备接管路由"""
|
||||
account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id")
|
||||
"""当前连接对应的账号 ID 或 self_id"""
|
||||
scope: str = Field(default="", description="当前连接对应的可选路由作用域")
|
||||
"""当前连接对应的可选路由作用域"""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
|
||||
"""可选的运行时状态元数据"""
|
||||
|
||||
|
||||
class AdapterStateUpdateResultPayload(BaseModel):
|
||||
"""适配器运行时状态更新结果载荷。"""
|
||||
|
||||
accepted: bool = Field(description="Host 是否接受了本次状态更新")
|
||||
"""Host 是否接受了本次状态更新"""
|
||||
connected: bool = Field(description="Host 记录的当前连接状态")
|
||||
"""Host 记录的当前连接状态"""
|
||||
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
|
||||
"""当前生效的路由键"""
|
||||
|
||||
|
||||
class ReceiveExternalMessagePayload(BaseModel):
|
||||
"""适配器插件向 Host 注入外部消息的请求载荷。"""
|
||||
|
||||
|
||||
@@ -481,13 +481,14 @@ class PluginRunner:
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
|
||||
if not await self._invoke_plugin_on_load(meta):
|
||||
if not await self._register_plugin(meta):
|
||||
await self._invoke_plugin_on_unload(meta)
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
|
||||
if not await self._register_plugin(meta):
|
||||
await self._invoke_plugin_on_unload(meta)
|
||||
if not await self._invoke_plugin_on_load(meta):
|
||||
await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
|
||||
@@ -60,6 +60,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
self._connection_task: Optional[asyncio.Task[None]] = None
|
||||
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
|
||||
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
||||
self._reported_account_id: Optional[str] = None
|
||||
self._reported_scope: Optional[str] = None
|
||||
self._runtime_state_connected: bool = False
|
||||
self._send_lock = asyncio.Lock()
|
||||
self._ws: Optional[AiohttpClientWebSocketResponse] = None
|
||||
|
||||
@@ -170,6 +173,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await connection_task
|
||||
|
||||
await self._report_adapter_disconnected()
|
||||
self._fail_pending_actions("NapCat connection closed")
|
||||
|
||||
async def _cancel_background_tasks(self) -> None:
|
||||
@@ -209,6 +213,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
|
||||
finally:
|
||||
self._ws = None
|
||||
await self._report_adapter_disconnected()
|
||||
self._fail_pending_actions("NapCat connection interrupted")
|
||||
|
||||
if not self._should_connect():
|
||||
@@ -230,26 +235,39 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
"""
|
||||
assert WSMsgType is not None
|
||||
|
||||
async for ws_message in ws:
|
||||
if ws_message.type != WSMsgType.TEXT:
|
||||
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
|
||||
break
|
||||
continue
|
||||
bootstrap_task = asyncio.create_task(
|
||||
self._bootstrap_adapter_runtime_state(),
|
||||
name="napcat_adapter.bootstrap",
|
||||
)
|
||||
self._background_tasks.add(bootstrap_task)
|
||||
bootstrap_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
payload = self._parse_json_message(ws_message.data)
|
||||
if payload is None:
|
||||
continue
|
||||
try:
|
||||
async for ws_message in ws:
|
||||
if ws_message.type != WSMsgType.TEXT:
|
||||
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
|
||||
break
|
||||
continue
|
||||
|
||||
if echo_id := str(payload.get("echo") or "").strip():
|
||||
self._resolve_pending_action(echo_id, payload)
|
||||
continue
|
||||
payload = self._parse_json_message(ws_message.data)
|
||||
if payload is None:
|
||||
continue
|
||||
|
||||
if str(payload.get("post_type") or "").strip() != "message":
|
||||
continue
|
||||
if echo_id := str(payload.get("echo") or "").strip():
|
||||
self._resolve_pending_action(echo_id, payload)
|
||||
continue
|
||||
|
||||
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
if str(payload.get("post_type") or "").strip() != "message":
|
||||
continue
|
||||
|
||||
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
finally:
|
||||
if not bootstrap_task.done():
|
||||
bootstrap_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await bootstrap_task
|
||||
|
||||
async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
|
||||
"""处理单条 NapCat 入站消息并注入 Host。
|
||||
@@ -258,6 +276,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
payload: NapCat / OneBot 推送的原始事件数据。
|
||||
"""
|
||||
self_id = str(payload.get("self_id") or "").strip()
|
||||
if self_id:
|
||||
await self._report_adapter_connected(self_id)
|
||||
|
||||
sender = payload.get("sender", {})
|
||||
if not isinstance(sender, dict):
|
||||
sender = {}
|
||||
@@ -570,6 +591,121 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
response_future.set_exception(RuntimeError(error_message))
|
||||
self._pending_actions.clear()
|
||||
|
||||
async def _bootstrap_adapter_runtime_state(self) -> None:
|
||||
"""在连接建立后主动获取账号信息并激活适配器路由。
|
||||
|
||||
该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()`
|
||||
发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo
|
||||
响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。
|
||||
"""
|
||||
max_attempts = 3
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
ws = self._ws
|
||||
if ws is None or ws.closed:
|
||||
return
|
||||
|
||||
try:
|
||||
response = await self._call_action("get_login_info", {})
|
||||
self_id = self._extract_self_id_from_login_response(response)
|
||||
await self._report_adapter_connected(self_id)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
self.ctx.logger.warning(
|
||||
f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}"
|
||||
)
|
||||
if attempt < max_attempts:
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
if last_error is not None:
|
||||
self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}")
|
||||
|
||||
@staticmethod
|
||||
def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str:
|
||||
"""从 `get_login_info` 响应中提取当前账号 ID。
|
||||
|
||||
Args:
|
||||
response: NapCat 返回的原始动作响应。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的 `self_id` 字符串。
|
||||
|
||||
Raises:
|
||||
ValueError: 当响应中缺少有效账号 ID 时抛出。
|
||||
"""
|
||||
if str(response.get("status") or "").lower() != "ok":
|
||||
raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed"))
|
||||
|
||||
response_data = response.get("data", {})
|
||||
if not isinstance(response_data, dict):
|
||||
raise ValueError("get_login_info 响应缺少 data 字段")
|
||||
|
||||
self_id = str(response_data.get("user_id") or "").strip()
|
||||
if not self_id:
|
||||
raise ValueError("get_login_info 响应缺少有效的 user_id")
|
||||
return self_id
|
||||
|
||||
async def _report_adapter_connected(self, account_id: str) -> None:
|
||||
"""向 Host 上报当前连接已就绪。
|
||||
|
||||
Args:
|
||||
account_id: 当前 NapCat 连接对应的机器人账号 ID。
|
||||
"""
|
||||
normalized_account_id = str(account_id).strip()
|
||||
if not normalized_account_id:
|
||||
return
|
||||
|
||||
scope = self._get_string(self._connection_config(), "connection_id").strip()
|
||||
if (
|
||||
self._runtime_state_connected
|
||||
and self._reported_account_id == normalized_account_id
|
||||
and self._reported_scope == (scope or None)
|
||||
):
|
||||
return
|
||||
|
||||
accepted = False
|
||||
try:
|
||||
accepted = await self.ctx.adapter.update_runtime_state(
|
||||
connected=True,
|
||||
account_id=normalized_account_id,
|
||||
scope=scope,
|
||||
metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")},
|
||||
)
|
||||
except Exception as exc:
|
||||
self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}")
|
||||
return
|
||||
|
||||
if not accepted:
|
||||
self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新")
|
||||
return
|
||||
|
||||
self._runtime_state_connected = True
|
||||
self._reported_account_id = normalized_account_id
|
||||
self._reported_scope = scope or None
|
||||
self.ctx.logger.info(
|
||||
f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} "
|
||||
f"scope={self._reported_scope or '*'}"
|
||||
)
|
||||
|
||||
async def _report_adapter_disconnected(self) -> None:
|
||||
"""向 Host 上报当前连接已断开,并撤销适配器路由。"""
|
||||
if not self._runtime_state_connected:
|
||||
self._reported_account_id = None
|
||||
self._reported_scope = None
|
||||
return
|
||||
|
||||
try:
|
||||
await self.ctx.adapter.update_runtime_state(connected=False)
|
||||
except Exception as exc:
|
||||
self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}")
|
||||
finally:
|
||||
self._runtime_state_connected = False
|
||||
self._reported_account_id = None
|
||||
self._reported_scope = None
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构造连接 NapCat 所需的请求头。
|
||||
|
||||
|
||||
Reference in New Issue
Block a user