feat: Implement adapter runtime state management and update handling
- Added support for adapter runtime state updates in the PluginRunnerSupervisor. - Introduced new payload classes: AdapterStateUpdatePayload and AdapterStateUpdateResultPayload for handling state updates. - Implemented methods to bind and unbind routes based on adapter connection status. - Enhanced the NapCat adapter to report connection state and manage runtime state. - Added tests for adapter runtime state synchronization and database session behavior in the statistic module. - Updated existing methods to ensure proper handling of adapter state and route bindings.
This commit is contained in:
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
"""人物信息群名片字段兼容测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from importlib.util import module_from_spec, spec_from_file_location
|
||||||
|
from pathlib import Path
|
||||||
|
from types import ModuleType, SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyLogger:
|
||||||
|
"""模拟日志记录器。"""
|
||||||
|
|
||||||
|
def debug(self, message: str) -> None:
|
||||||
|
"""记录调试日志。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 日志内容。
|
||||||
|
"""
|
||||||
|
del message
|
||||||
|
|
||||||
|
def info(self, message: str) -> None:
|
||||||
|
"""记录信息日志。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 日志内容。
|
||||||
|
"""
|
||||||
|
del message
|
||||||
|
|
||||||
|
def warning(self, message: str) -> None:
|
||||||
|
"""记录警告日志。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 日志内容。
|
||||||
|
"""
|
||||||
|
del message
|
||||||
|
|
||||||
|
def error(self, message: str) -> None:
|
||||||
|
"""记录错误日志。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 日志内容。
|
||||||
|
"""
|
||||||
|
del message
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyStatement:
|
||||||
|
"""模拟 SQL 查询语句对象。"""
|
||||||
|
|
||||||
|
def where(self, condition: Any) -> "_DummyStatement":
|
||||||
|
"""附加过滤条件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
condition: 过滤条件。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_DummyStatement: 当前语句对象。
|
||||||
|
"""
|
||||||
|
del condition
|
||||||
|
return self
|
||||||
|
|
||||||
|
def limit(self, value: int) -> "_DummyStatement":
|
||||||
|
"""限制返回条数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 条数限制。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_DummyStatement: 当前语句对象。
|
||||||
|
"""
|
||||||
|
del value
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyColumn:
|
||||||
|
"""模拟 SQLModel 列对象。"""
|
||||||
|
|
||||||
|
def is_not(self, value: Any) -> "_DummyColumn":
|
||||||
|
"""模拟 `IS NOT` 条件构造。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 比较值。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_DummyColumn: 当前列对象。
|
||||||
|
"""
|
||||||
|
del value
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> "_DummyColumn":
|
||||||
|
"""模拟等值条件构造。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other: 比较值。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_DummyColumn: 当前列对象。
|
||||||
|
"""
|
||||||
|
del other
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyResult:
|
||||||
|
"""模拟数据库查询结果。"""
|
||||||
|
|
||||||
|
def __init__(self, record: Any) -> None:
|
||||||
|
"""初始化查询结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record: 待返回的首条记录。
|
||||||
|
"""
|
||||||
|
self._record = record
|
||||||
|
|
||||||
|
def first(self) -> Any:
|
||||||
|
"""返回第一条记录。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: 首条记录。
|
||||||
|
"""
|
||||||
|
return self._record
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
"""返回全部结果。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Any]: 结果列表。
|
||||||
|
"""
|
||||||
|
if self._record is None:
|
||||||
|
return []
|
||||||
|
return self._record if isinstance(self._record, list) else [self._record]
|
||||||
|
|
||||||
|
|
||||||
|
class _DummySession:
|
||||||
|
"""模拟数据库 Session。"""
|
||||||
|
|
||||||
|
def __init__(self, record: Any) -> None:
|
||||||
|
"""初始化 Session。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record: `first()` 应返回的记录。
|
||||||
|
"""
|
||||||
|
self.record = record
|
||||||
|
self.added_records: list[Any] = []
|
||||||
|
|
||||||
|
def __enter__(self) -> "_DummySession":
|
||||||
|
"""进入上下文管理器。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_DummySession: 当前 Session。
|
||||||
|
"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||||
|
"""退出上下文管理器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_type: 异常类型。
|
||||||
|
exc_val: 异常值。
|
||||||
|
exc_tb: 异常回溯。
|
||||||
|
"""
|
||||||
|
del exc_type
|
||||||
|
del exc_val
|
||||||
|
del exc_tb
|
||||||
|
|
||||||
|
def exec(self, statement: Any) -> _DummyResult:
|
||||||
|
"""执行查询。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
statement: 查询语句。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_DummyResult: 模拟结果对象。
|
||||||
|
"""
|
||||||
|
del statement
|
||||||
|
return _DummyResult(self.record)
|
||||||
|
|
||||||
|
def add(self, record: Any) -> None:
|
||||||
|
"""记录被添加的对象。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record: 被写入 Session 的对象。
|
||||||
|
"""
|
||||||
|
self.added_records.append(record)
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyPersonInfoRecord:
|
||||||
|
"""模拟 `PersonInfo` ORM 模型。"""
|
||||||
|
|
||||||
|
person_id = "person_id"
|
||||||
|
person_name = "person_name"
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
"""使用关键字参数初始化记录对象。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: 字段值。
|
||||||
|
"""
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
|
||||||
|
"""加载带依赖桩的 `person_info` 模块。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
monkeypatch: Pytest monkeypatch 工具。
|
||||||
|
session: 提供给模块使用的假数据库 Session。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModuleType: 加载后的模块对象。
|
||||||
|
"""
|
||||||
|
logger_module = ModuleType("src.common.logger")
|
||||||
|
logger_module.get_logger = lambda name: _DummyLogger()
|
||||||
|
monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
|
||||||
|
|
||||||
|
database_module = ModuleType("src.common.database.database")
|
||||||
|
database_module.get_db_session = lambda: session
|
||||||
|
monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
|
||||||
|
|
||||||
|
database_model_module = ModuleType("src.common.database.database_model")
|
||||||
|
database_model_module.PersonInfo = _DummyPersonInfoRecord
|
||||||
|
monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
|
||||||
|
|
||||||
|
llm_module = ModuleType("src.llm_models.utils_model")
|
||||||
|
|
||||||
|
class _DummyLLMRequest:
|
||||||
|
"""模拟 LLMRequest。"""
|
||||||
|
|
||||||
|
def __init__(self, model_set: Any, request_type: str) -> None:
|
||||||
|
"""初始化假请求对象。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_set: 模型配置。
|
||||||
|
request_type: 请求类型。
|
||||||
|
"""
|
||||||
|
del model_set
|
||||||
|
del request_type
|
||||||
|
|
||||||
|
llm_module.LLMRequest = _DummyLLMRequest
|
||||||
|
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
|
||||||
|
|
||||||
|
config_module = ModuleType("src.config.config")
|
||||||
|
config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
|
||||||
|
config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
|
||||||
|
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
|
||||||
|
|
||||||
|
chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
|
||||||
|
chat_manager_module.chat_manager = SimpleNamespace()
|
||||||
|
monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
|
||||||
|
|
||||||
|
module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
|
||||||
|
spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
|
||||||
|
assert spec is not None and spec.loader is not None
|
||||||
|
|
||||||
|
module = module_from_spec(spec)
|
||||||
|
monkeypatch.setitem(sys.modules, spec.name, module)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
|
||||||
|
monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_group_cardname_json_uses_canonical_key() -> None:
|
||||||
|
"""群名片 JSON 解析应只使用 `group_cardname` 键名。"""
|
||||||
|
parsed = parse_group_cardname_json(
|
||||||
|
json.dumps(
|
||||||
|
[
|
||||||
|
{"group_id": "1001", "group_cardname": "现行字段"},
|
||||||
|
],
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert parsed is not None
|
||||||
|
assert [(item.group_id, item.group_cardname) for item in parsed] == [
|
||||||
|
("1001", "现行字段"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dump_group_cardname_records_uses_canonical_key() -> None:
|
||||||
|
"""群名片序列化应输出 `group_cardname` 键名。"""
|
||||||
|
dumped = dump_group_cardname_records(
|
||||||
|
[
|
||||||
|
{"group_id": "1001", "group_cardname": "群昵称"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
|
||||||
|
record = _DummyPersonInfoRecord()
|
||||||
|
session = _DummySession(record)
|
||||||
|
module = _load_person_module(monkeypatch, session)
|
||||||
|
|
||||||
|
person = module.Person.__new__(module.Person)
|
||||||
|
person.is_known = True
|
||||||
|
person.person_id = "person-1"
|
||||||
|
person.platform = "qq"
|
||||||
|
person.user_id = "10001"
|
||||||
|
person.nickname = "看番的龙"
|
||||||
|
person.person_name = "看番的龙"
|
||||||
|
person.name_reason = "测试"
|
||||||
|
person.know_times = 1
|
||||||
|
person.know_since = 1700000000.0
|
||||||
|
person.last_know = 1700000100.0
|
||||||
|
person.memory_points = ["喜好:番剧:0.8"]
|
||||||
|
person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
|
||||||
|
|
||||||
|
person.sync_to_database()
|
||||||
|
|
||||||
|
assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
|
||||||
|
assert not hasattr(record, "group_nickname")
|
||||||
|
|
||||||
|
|
||||||
|
def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
|
||||||
|
record = _DummyPersonInfoRecord(
|
||||||
|
user_id="10001",
|
||||||
|
platform="qq",
|
||||||
|
is_known=True,
|
||||||
|
user_nickname="看番的龙",
|
||||||
|
person_name="看番的龙",
|
||||||
|
name_reason=None,
|
||||||
|
know_counts=2,
|
||||||
|
memory_points='["喜好:番剧:0.8"]',
|
||||||
|
group_cardname=json.dumps(
|
||||||
|
[
|
||||||
|
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||||
|
],
|
||||||
|
ensure_ascii=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
session = _DummySession(record)
|
||||||
|
module = _load_person_module(monkeypatch, session)
|
||||||
|
|
||||||
|
person = module.Person.__new__(module.Person)
|
||||||
|
person.person_id = "person-1"
|
||||||
|
person.memory_points = []
|
||||||
|
person.group_cardname_list = []
|
||||||
|
|
||||||
|
person.load_from_database()
|
||||||
|
|
||||||
|
assert person.group_cardname_list == [
|
||||||
|
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||||
|
]
|
||||||
162
pytests/test_adapter_runtime_state.py
Normal file
162
pytests/test_adapter_runtime_state.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""适配器运行时状态同步测试。"""
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.platform_io.manager import PlatformIOManager
|
||||||
|
from src.platform_io.types import RouteKey
|
||||||
|
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||||
|
from src.plugin_runtime.protocol.envelope import (
|
||||||
|
AdapterDeclarationPayload,
|
||||||
|
Envelope,
|
||||||
|
MessageType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_request(plugin_id: str, payload: Dict[str, Any]) -> Envelope:
|
||||||
|
"""构造一个适配器状态更新 RPC 请求。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 目标适配器插件 ID。
|
||||||
|
payload: 请求载荷。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Envelope: 标准 RPC 请求信封。
|
||||||
|
"""
|
||||||
|
return Envelope(
|
||||||
|
request_id=1,
|
||||||
|
message_type=MessageType.REQUEST,
|
||||||
|
method="host.update_adapter_state",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_adapter_runtime_state_binds_and_unbinds_routes(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""连接建立后应绑定路由,断开后应撤销路由。"""
|
||||||
|
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||||
|
|
||||||
|
platform_io_manager = PlatformIOManager()
|
||||||
|
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
||||||
|
|
||||||
|
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||||
|
adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
|
||||||
|
await supervisor._register_adapter_driver("napcat_adapter_builtin", adapter)
|
||||||
|
|
||||||
|
response = await supervisor._handle_update_adapter_state(
|
||||||
|
_make_request(
|
||||||
|
"napcat_adapter_builtin",
|
||||||
|
{
|
||||||
|
"connected": True,
|
||||||
|
"account_id": "10001",
|
||||||
|
"scope": "",
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.error is None
|
||||||
|
assert response.payload["accepted"] is True
|
||||||
|
assert (
|
||||||
|
platform_io_manager.route_table.get_active_binding(
|
||||||
|
RouteKey(platform="qq", account_id="10001"),
|
||||||
|
exact_only=True,
|
||||||
|
).driver_id
|
||||||
|
== "adapter:napcat_adapter_builtin"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
platform_io_manager.route_table.get_active_binding(
|
||||||
|
RouteKey(platform="qq"),
|
||||||
|
exact_only=True,
|
||||||
|
).driver_id
|
||||||
|
== "adapter:napcat_adapter_builtin"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await supervisor._handle_update_adapter_state(
|
||||||
|
_make_request(
|
||||||
|
"napcat_adapter_builtin",
|
||||||
|
{
|
||||||
|
"connected": False,
|
||||||
|
"account_id": "",
|
||||||
|
"scope": "",
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.error is None
|
||||||
|
assert response.payload["accepted"] is True
|
||||||
|
assert platform_io_manager.route_table.get_active_binding(
|
||||||
|
RouteKey(platform="qq", account_id="10001"),
|
||||||
|
exact_only=True,
|
||||||
|
) is None
|
||||||
|
assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_platform_default_route_is_removed_when_multiple_exact_routes_exist(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""同一平台存在多个精确路由时不应保留默认平台路由。"""
|
||||||
|
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||||
|
|
||||||
|
platform_io_manager = PlatformIOManager()
|
||||||
|
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
||||||
|
|
||||||
|
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||||
|
adapter = AdapterDeclarationPayload(platform="qq", protocol="napcat")
|
||||||
|
await supervisor._register_adapter_driver("adapter_a", adapter)
|
||||||
|
await supervisor._register_adapter_driver("adapter_b", adapter)
|
||||||
|
|
||||||
|
await supervisor._handle_update_adapter_state(
|
||||||
|
_make_request(
|
||||||
|
"adapter_a",
|
||||||
|
{
|
||||||
|
"connected": True,
|
||||||
|
"account_id": "10001",
|
||||||
|
"scope": "",
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
platform_io_manager.route_table.get_active_binding(
|
||||||
|
RouteKey(platform="qq"),
|
||||||
|
exact_only=True,
|
||||||
|
).driver_id
|
||||||
|
== "adapter:adapter_a"
|
||||||
|
)
|
||||||
|
|
||||||
|
await supervisor._handle_update_adapter_state(
|
||||||
|
_make_request(
|
||||||
|
"adapter_b",
|
||||||
|
{
|
||||||
|
"connected": True,
|
||||||
|
"account_id": "10002",
|
||||||
|
"scope": "",
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert platform_io_manager.route_table.get_active_binding(RouteKey(platform="qq"), exact_only=True) is None
|
||||||
|
|
||||||
|
await supervisor._handle_update_adapter_state(
|
||||||
|
_make_request(
|
||||||
|
"adapter_b",
|
||||||
|
{
|
||||||
|
"connected": False,
|
||||||
|
"account_id": "",
|
||||||
|
"scope": "",
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
platform_io_manager.route_table.get_active_binding(
|
||||||
|
RouteKey(platform="qq"),
|
||||||
|
exact_only=True,
|
||||||
|
).driver_id
|
||||||
|
== "adapter:adapter_a"
|
||||||
|
)
|
||||||
@@ -486,10 +486,10 @@ class TestSDK:
|
|||||||
"timeout_ms": timeout_ms,
|
"timeout_ms": timeout_ms,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if method == "cap.request":
|
if method == "cap.call":
|
||||||
bootstrap_methods = [call["method"] for call in self.calls[:-1]]
|
bootstrap_methods = [call["method"] for call in self.calls[:-1]]
|
||||||
assert "plugin.bootstrap" in bootstrap_methods
|
assert "plugin.bootstrap" in bootstrap_methods
|
||||||
return SimpleNamespace(error=None, payload={"result": {"success": True}})
|
return SimpleNamespace(error=None, payload={"success": True})
|
||||||
return SimpleNamespace(error=None, payload={"accepted": True})
|
return SimpleNamespace(error=None, payload={"accepted": True})
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
@@ -529,7 +529,7 @@ class TestSDK:
|
|||||||
await runner.run()
|
await runner.run()
|
||||||
|
|
||||||
methods = [call["method"] for call in runner._rpc_client.calls]
|
methods = [call["method"] for call in runner._rpc_client.calls]
|
||||||
assert methods == ["plugin.bootstrap", "cap.request", "plugin.register_components", "runner.ready"]
|
assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"]
|
||||||
|
|
||||||
|
|
||||||
class TestPluginSdkUsage:
|
class TestPluginSdkUsage:
|
||||||
|
|||||||
115
pytests/utils_test/statistic_test.py
Normal file
115
pytests/utils_test/statistic_test.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""统计模块数据库会话行为测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any, Callable, Iterator
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.chat.utils import statistic
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyResult:
|
||||||
|
"""模拟 SQLModel 查询结果对象。"""
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
"""返回空结果集。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Any]: 空列表。
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class _DummySession:
|
||||||
|
"""模拟数据库 Session。"""
|
||||||
|
|
||||||
|
def exec(self, statement: Any) -> _DummyResult:
|
||||||
|
"""执行查询语句并返回空结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
statement: 待执行的查询语句。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_DummyResult: 空结果对象。
|
||||||
|
"""
|
||||||
|
del statement
|
||||||
|
return _DummyResult()
|
||||||
|
|
||||||
|
|
||||||
|
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
|
||||||
|
"""构造一个记录 auto_commit 参数的假会话工厂。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
calls: 用于记录每次调用 auto_commit 参数的列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
|
||||||
|
"""记录会话参数并返回假 Session。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auto_commit: 是否启用自动提交。
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Iterator[_DummySession]: 假 Session 对象。
|
||||||
|
"""
|
||||||
|
calls.append(auto_commit)
|
||||||
|
yield _DummySession()
|
||||||
|
|
||||||
|
return _fake_get_db_session
|
||||||
|
|
||||||
|
|
||||||
|
def _build_statistic_task() -> statistic.StatisticOutputTask:
|
||||||
|
"""构造一个最小可用的统计任务实例。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
|
||||||
|
"""
|
||||||
|
task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
|
||||||
|
task.name_mapping = {}
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
def _is_bot_self(platform: str, user_id: str) -> bool:
|
||||||
|
"""返回固定的非机器人身份判断结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台名称。
|
||||||
|
user_id: 用户 ID。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 始终返回 ``False``。
|
||||||
|
"""
|
||||||
|
del platform
|
||||||
|
del user_id
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
|
||||||
|
calls: list[bool] = []
|
||||||
|
now = datetime.now()
|
||||||
|
task = _build_statistic_task()
|
||||||
|
|
||||||
|
monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
|
||||||
|
|
||||||
|
utils_module = ModuleType("src.chat.utils.utils")
|
||||||
|
utils_module.is_bot_self = _is_bot_self
|
||||||
|
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
|
||||||
|
|
||||||
|
statistic.StatisticOutputTask._fetch_online_time_since(now)
|
||||||
|
statistic.StatisticOutputTask._fetch_model_usage_since(now)
|
||||||
|
task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
|
||||||
|
task._collect_interval_data(now, hours=1, interval_minutes=60)
|
||||||
|
task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
|
||||||
|
|
||||||
|
assert calls == [False] * 9
|
||||||
@@ -436,14 +436,14 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
|
def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
|
||||||
with get_db_session() as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
|
statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
|
||||||
records = session.exec(statement).all()
|
records = session.exec(statement).all()
|
||||||
return [(record.start_timestamp, record.end_timestamp) for record in records]
|
return [(record.start_timestamp, record.end_timestamp) for record in records]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
|
def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
|
||||||
with get_db_session() as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
|
statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
|
||||||
records = session.exec(statement).all()
|
records = session.exec(statement).all()
|
||||||
return [
|
return [
|
||||||
@@ -664,7 +664,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
}
|
}
|
||||||
|
|
||||||
query_start_timestamp = collect_period[-1][1]
|
query_start_timestamp = collect_period[-1][1]
|
||||||
with get_db_session() as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
|
statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
|
||||||
messages = session.exec(statement).all()
|
messages = session.exec(statement).all()
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@@ -713,7 +713,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
||||||
try:
|
try:
|
||||||
action_query_start_timestamp = collect_period[-1][1]
|
action_query_start_timestamp = collect_period[-1][1]
|
||||||
with get_db_session() as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
|
statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp)
|
||||||
actions = session.exec(statement).all()
|
actions = session.exec(statement).all()
|
||||||
for action in actions:
|
for action in actions:
|
||||||
@@ -1750,7 +1750,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 查询消息记录
|
# 查询消息记录
|
||||||
query_start_timestamp = start_time.timestamp()
|
query_start_timestamp = start_time.timestamp()
|
||||||
with get_db_session() as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||||
messages = session.exec(statement).all()
|
messages = session.exec(statement).all()
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@@ -2131,7 +2131,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 查询消息记录
|
# 查询消息记录
|
||||||
query_start_timestamp = start_time.timestamp()
|
query_start_timestamp = start_time.timestamp()
|
||||||
with get_db_session() as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
statement = select(Messages).where(col(Messages.timestamp) >= start_time)
|
||||||
messages = session.exec(statement).all()
|
messages = session.exec(statement).all()
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
from typing import Any, List, Mapping, Optional, Sequence
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -15,6 +15,76 @@ class GroupCardnameInfo:
|
|||||||
group_cardname: str
|
group_cardname: str
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_group_cardname_item(raw_item: Mapping[str, Any]) -> Optional[GroupCardnameInfo]:
|
||||||
|
"""将单条群名片数据规范化为统一结构。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_item: 原始群名片字典,必须包含 `group_id` 和 `group_cardname`。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[GroupCardnameInfo]: 规范化后的群名片信息;若数据不完整则返回 ``None``。
|
||||||
|
"""
|
||||||
|
group_id = str(raw_item.get("group_id") or "").strip()
|
||||||
|
group_cardname = str(raw_item.get("group_cardname") or "").strip()
|
||||||
|
if not group_id or not group_cardname:
|
||||||
|
return None
|
||||||
|
return GroupCardnameInfo(group_id=group_id, group_cardname=group_cardname)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_group_cardname_json(group_cardname_json: Optional[str]) -> Optional[List[GroupCardnameInfo]]:
|
||||||
|
"""解析数据库中的群名片 JSON 字段。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[List[GroupCardnameInfo]]: 解析并规范化后的群名片列表;若字段为空或无有效项则返回 ``None``。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||||
|
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||||
|
"""
|
||||||
|
if not group_cardname_json:
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw_items = json.loads(group_cardname_json)
|
||||||
|
if not isinstance(raw_items, list):
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized_items: List[GroupCardnameInfo] = []
|
||||||
|
for raw_item in raw_items:
|
||||||
|
if not isinstance(raw_item, Mapping):
|
||||||
|
continue
|
||||||
|
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||||
|
normalized_items.append(normalized_item)
|
||||||
|
|
||||||
|
return normalized_items or None
|
||||||
|
|
||||||
|
|
||||||
|
def dump_group_cardname_records(
|
||||||
|
group_cardname_records: Optional[Sequence[GroupCardnameInfo | Mapping[str, Any]]],
|
||||||
|
) -> str:
|
||||||
|
"""将群名片列表序列化为数据库使用的标准 JSON 字符串。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_cardname_records: 待序列化的群名片列表,支持 `GroupCardnameInfo`
|
||||||
|
对象和包含 `group_id` / `group_cardname` 的字典。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 统一使用 `group_cardname` 键名的 JSON 字符串。
|
||||||
|
"""
|
||||||
|
normalized_items: List[GroupCardnameInfo] = []
|
||||||
|
for raw_item in group_cardname_records or []:
|
||||||
|
if isinstance(raw_item, GroupCardnameInfo):
|
||||||
|
normalized_items.append(raw_item)
|
||||||
|
continue
|
||||||
|
if isinstance(raw_item, Mapping):
|
||||||
|
if normalized_item := _normalize_group_cardname_item(raw_item):
|
||||||
|
normalized_items.append(normalized_item)
|
||||||
|
|
||||||
|
return json.dumps([asdict(item) for item in normalized_items], ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -58,9 +128,16 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
|||||||
"""最后一次被认识的时间"""
|
"""最后一次被认识的时间"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db_instance(cls, db_record: "PersonInfo"):
|
def from_db_instance(cls, db_record: "PersonInfo") -> "MaiPersonInfo":
|
||||||
nickname_json = json.loads(db_record.group_cardname) if db_record.group_cardname else None
|
"""从数据库记录构造人物信息数据模型。
|
||||||
group_cardname_list = [GroupCardnameInfo(**item) for item in nickname_json] if nickname_json else None
|
|
||||||
|
Args:
|
||||||
|
db_record: 数据库中的人物信息记录。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MaiPersonInfo: 转换后的数据模型对象。
|
||||||
|
"""
|
||||||
|
group_cardname_list = parse_group_cardname_json(db_record.group_cardname)
|
||||||
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
|
memory_points = json.loads(db_record.memory_points) if db_record.memory_points else None
|
||||||
return cls(
|
return cls(
|
||||||
is_known=db_record.is_known,
|
is_known=db_record.is_known,
|
||||||
@@ -78,9 +155,12 @@ class MaiPersonInfo(BaseDatabaseDataModel[PersonInfo]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def to_db_instance(self) -> "PersonInfo":
|
def to_db_instance(self) -> "PersonInfo":
|
||||||
group_cardname = (
|
"""将当前数据模型转换为数据库记录对象。
|
||||||
json.dumps([gc.__dict__ for gc in self.group_cardname_list]) if self.group_cardname_list else None
|
|
||||||
)
|
Returns:
|
||||||
|
PersonInfo: 可直接写入数据库的模型实例。
|
||||||
|
"""
|
||||||
|
group_cardname = dump_group_cardname_records(self.group_cardname_list)
|
||||||
return PersonInfo(
|
return PersonInfo(
|
||||||
is_known=self.is_known,
|
is_known=self.is_known,
|
||||||
person_id=self.person_id,
|
person_id=self.person_id,
|
||||||
|
|||||||
@@ -1,22 +1,24 @@
|
|||||||
import hashlib
|
from datetime import datetime
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import random
|
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from typing import Union, Optional, Dict
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
|
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import PersonInfo
|
from src.common.database.database_model import PersonInfo
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("person_info")
|
logger = get_logger("person_info")
|
||||||
@@ -26,6 +28,32 @@ relation_selection_model = LLMRequest(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_group_cardname_records(group_cardname_json: Optional[str]) -> list[dict[str, str]]:
|
||||||
|
"""将数据库中的群名片 JSON 转换为 `Person` 内部使用的结构。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_cardname_json: 数据库存储的群名片 JSON 字符串。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[dict[str, str]]: 统一使用 `group_cardname` 键名的群名片列表。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
json.JSONDecodeError: 当 JSON 文本格式非法时抛出。
|
||||||
|
TypeError: 当输入值类型不符合 `json.loads()` 要求时抛出。
|
||||||
|
"""
|
||||||
|
group_cardname_list = parse_group_cardname_json(group_cardname_json)
|
||||||
|
if not group_cardname_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"group_id": group_cardname.group_id,
|
||||||
|
"group_cardname": group_cardname.group_cardname,
|
||||||
|
}
|
||||||
|
for group_cardname in group_cardname_list
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||||
"""获取唯一id"""
|
"""获取唯一id"""
|
||||||
if "-" in platform:
|
if "-" in platform:
|
||||||
@@ -231,7 +259,7 @@ class Person:
|
|||||||
person.know_since = time.time()
|
person.know_since = time.time()
|
||||||
person.last_know = time.time()
|
person.last_know = time.time()
|
||||||
person.memory_points = []
|
person.memory_points = []
|
||||||
person.group_nick_name = [] # 初始化群昵称列表
|
person.group_cardname_list = [] # 初始化群名片列表
|
||||||
|
|
||||||
# 如果是群聊,添加群昵称
|
# 如果是群聊,添加群昵称
|
||||||
if group_id and group_nick_name:
|
if group_id and group_nick_name:
|
||||||
@@ -269,7 +297,7 @@ class Person:
|
|||||||
self.platform = platform
|
self.platform = platform
|
||||||
self.nickname = global_config.bot.nickname
|
self.nickname = global_config.bot.nickname
|
||||||
self.person_name = global_config.bot.nickname
|
self.person_name = global_config.bot.nickname
|
||||||
self.group_nick_name: list[dict[str, str]] = []
|
self.group_cardname_list: list[dict[str, str]] = []
|
||||||
return
|
return
|
||||||
|
|
||||||
self.user_id = ""
|
self.user_id = ""
|
||||||
@@ -308,7 +336,7 @@ class Person:
|
|||||||
self.know_since = None
|
self.know_since = None
|
||||||
self.last_know: Optional[float] = None
|
self.last_know: Optional[float] = None
|
||||||
self.memory_points = []
|
self.memory_points = []
|
||||||
self.group_nick_name: list[dict[str, str]] = [] # 群昵称列表,存储 {"group_id": str, "group_nick_name": str}
|
self.group_cardname_list: list[dict[str, str]] = [] # 群名片列表,存储 {"group_id": str, "group_cardname": str}
|
||||||
|
|
||||||
# 从数据库加载数据
|
# 从数据库加载数据
|
||||||
self.load_from_database()
|
self.load_from_database()
|
||||||
@@ -408,16 +436,16 @@ class Person:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 检查是否已存在该群号的记录
|
# 检查是否已存在该群号的记录
|
||||||
for item in self.group_nick_name:
|
for item in self.group_cardname_list:
|
||||||
if item.get("group_id") == group_id:
|
if item.get("group_id") == group_id:
|
||||||
# 更新现有记录
|
# 更新现有记录
|
||||||
item["group_nick_name"] = group_nick_name
|
item["group_cardname"] = group_nick_name
|
||||||
self.sync_to_database()
|
self.sync_to_database()
|
||||||
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
|
logger.debug(f"更新用户 {self.person_id} 在群 {group_id} 的群昵称为 {group_nick_name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 添加新记录
|
# 添加新记录
|
||||||
self.group_nick_name.append({"group_id": group_id, "group_nick_name": group_nick_name})
|
self.group_cardname_list.append({"group_id": group_id, "group_cardname": group_nick_name})
|
||||||
self.sync_to_database()
|
self.sync_to_database()
|
||||||
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
|
logger.debug(f"添加用户 {self.person_id} 在群 {group_id} 的群昵称 {group_nick_name}")
|
||||||
|
|
||||||
@@ -452,20 +480,15 @@ class Person:
|
|||||||
else:
|
else:
|
||||||
self.memory_points = []
|
self.memory_points = []
|
||||||
|
|
||||||
# 处理group_nick_name字段(JSON格式的列表)
|
# 处理 group_cardname 字段(JSON 格式的列表)
|
||||||
if record.group_cardname:
|
if record.group_cardname:
|
||||||
try:
|
try:
|
||||||
loaded_group_nick_names = json.loads(record.group_cardname)
|
self.group_cardname_list = _to_group_cardname_records(record.group_cardname)
|
||||||
# 确保是列表格式
|
|
||||||
if isinstance(loaded_group_nick_names, list):
|
|
||||||
self.group_nick_name = loaded_group_nick_names
|
|
||||||
else:
|
|
||||||
self.group_nick_name = []
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值")
|
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值")
|
||||||
self.group_nick_name = []
|
self.group_cardname_list = []
|
||||||
else:
|
else:
|
||||||
self.group_nick_name = []
|
self.group_cardname_list = []
|
||||||
|
|
||||||
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
||||||
else:
|
else:
|
||||||
@@ -486,11 +509,7 @@ class Person:
|
|||||||
if self.memory_points
|
if self.memory_points
|
||||||
else json.dumps([], ensure_ascii=False)
|
else json.dumps([], ensure_ascii=False)
|
||||||
)
|
)
|
||||||
group_nickname_value = (
|
group_cardname_value = dump_group_cardname_records(self.group_cardname_list)
|
||||||
json.dumps(self.group_nick_name, ensure_ascii=False)
|
|
||||||
if self.group_nick_name
|
|
||||||
else json.dumps([], ensure_ascii=False)
|
|
||||||
)
|
|
||||||
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
|
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
|
||||||
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
|
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
|
||||||
|
|
||||||
@@ -510,7 +529,7 @@ class Person:
|
|||||||
record.first_known_time = first_known_time
|
record.first_known_time = first_known_time
|
||||||
record.last_known_time = last_known_time
|
record.last_known_time = last_known_time
|
||||||
record.memory_points = memory_points_value
|
record.memory_points = memory_points_value
|
||||||
record.group_nickname = group_nickname_value
|
record.group_cardname = group_cardname_value
|
||||||
session.add(record)
|
session.add(record)
|
||||||
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
||||||
else:
|
else:
|
||||||
@@ -526,7 +545,7 @@ class Person:
|
|||||||
first_known_time=first_known_time,
|
first_known_time=first_known_time,
|
||||||
last_known_time=last_known_time,
|
last_known_time=last_known_time,
|
||||||
memory_points=memory_points_value,
|
memory_points=memory_points_value,
|
||||||
group_nickname=group_nickname_value,
|
group_cardname=group_cardname_value,
|
||||||
)
|
)
|
||||||
session.add(record)
|
session.add(record)
|
||||||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -8,13 +9,15 @@ import sys
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, RouteMode, get_platform_io_manager
|
||||||
from src.platform_io.drivers import PluginPlatformDriver
|
from src.platform_io.drivers import PluginPlatformDriver
|
||||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||||
from src.platform_io.routing import RouteBindingConflictError
|
from src.platform_io.routing import RouteBindingConflictError
|
||||||
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
||||||
from src.plugin_runtime.protocol.envelope import (
|
from src.plugin_runtime.protocol.envelope import (
|
||||||
AdapterDeclarationPayload,
|
AdapterDeclarationPayload,
|
||||||
|
AdapterStateUpdatePayload,
|
||||||
|
AdapterStateUpdateResultPayload,
|
||||||
BootstrapPluginPayload,
|
BootstrapPluginPayload,
|
||||||
ConfigUpdatedPayload,
|
ConfigUpdatedPayload,
|
||||||
Envelope,
|
Envelope,
|
||||||
@@ -46,6 +49,19 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = get_logger("plugin_runtime.host.runner_manager")
|
logger = get_logger("plugin_runtime.host.runner_manager")
|
||||||
|
|
||||||
|
_ADAPTER_BINDING_ROLE_RUNTIME_EXACT = "runtime_exact"
|
||||||
|
_ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT = "platform_default"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _AdapterRuntimeState:
|
||||||
|
"""保存适配器插件当前的运行时连接状态。"""
|
||||||
|
|
||||||
|
connected: bool = False
|
||||||
|
account_id: Optional[str] = None
|
||||||
|
scope: Optional[str] = None
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class PluginRunnerSupervisor:
|
class PluginRunnerSupervisor:
|
||||||
"""插件 Runner 监督器。
|
"""插件 Runner 监督器。
|
||||||
@@ -94,6 +110,7 @@ class PluginRunnerSupervisor:
|
|||||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||||
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
|
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
|
||||||
self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
|
self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
|
||||||
|
self._adapter_runtime_states: Dict[str, _AdapterRuntimeState] = {}
|
||||||
self._runner_ready_events: asyncio.Event = asyncio.Event()
|
self._runner_ready_events: asyncio.Event = asyncio.Event()
|
||||||
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
|
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
|
||||||
self._health_task: Optional[asyncio.Task[None]] = None
|
self._health_task: Optional[asyncio.Task[None]] = None
|
||||||
@@ -452,6 +469,7 @@ class PluginRunnerSupervisor:
|
|||||||
"""注册 Host 侧内部 RPC 方法。"""
|
"""注册 Host 侧内部 RPC 方法。"""
|
||||||
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
|
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
|
||||||
self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
|
self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
|
||||||
|
self._rpc_server.register_method("host.update_adapter_state", self._handle_update_adapter_state)
|
||||||
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
|
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
|
||||||
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
|
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
|
||||||
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
|
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
|
||||||
@@ -563,14 +581,14 @@ class PluginRunnerSupervisor:
|
|||||||
return f"adapter:{plugin_id}"
|
return f"adapter:{plugin_id}"
|
||||||
|
|
||||||
async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
|
async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
|
||||||
"""将适配器插件注册到 Platform IO。
|
"""将适配器插件驱动注册到 Platform IO。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
plugin_id: 适配器插件 ID。
|
plugin_id: 适配器插件 ID。
|
||||||
adapter: 经过校验的适配器声明。
|
adapter: 经过校验的适配器声明。
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 适配器路由冲突或驱动注册失败时抛出。
|
ValueError: 当驱动注册失败时抛出。
|
||||||
"""
|
"""
|
||||||
await self._unregister_adapter_driver(plugin_id)
|
await self._unregister_adapter_driver(plugin_id)
|
||||||
|
|
||||||
@@ -588,22 +606,12 @@ class PluginRunnerSupervisor:
|
|||||||
**adapter.metadata,
|
**adapter.metadata,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
binding = RouteBinding(
|
|
||||||
route_key=driver.descriptor.route_key,
|
|
||||||
driver_id=driver.driver_id,
|
|
||||||
driver_kind=DriverKind.PLUGIN,
|
|
||||||
metadata={
|
|
||||||
"plugin_id": plugin_id,
|
|
||||||
"protocol": adapter.protocol,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if platform_io_manager.is_started:
|
if platform_io_manager.is_started:
|
||||||
await platform_io_manager.add_driver(driver)
|
await platform_io_manager.add_driver(driver)
|
||||||
else:
|
else:
|
||||||
platform_io_manager.register_driver(driver)
|
platform_io_manager.register_driver(driver)
|
||||||
platform_io_manager.bind_route(binding)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
if platform_io_manager.is_started:
|
if platform_io_manager.is_started:
|
||||||
@@ -613,6 +621,7 @@ class PluginRunnerSupervisor:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
self._registered_adapters[plugin_id] = adapter
|
self._registered_adapters[plugin_id] = adapter
|
||||||
|
self._adapter_runtime_states[plugin_id] = _AdapterRuntimeState()
|
||||||
|
|
||||||
async def _unregister_adapter_driver(self, plugin_id: str) -> None:
|
async def _unregister_adapter_driver(self, plugin_id: str) -> None:
|
||||||
"""从 Platform IO 注销一个适配器驱动。
|
"""从 Platform IO 注销一个适配器驱动。
|
||||||
@@ -622,6 +631,9 @@ class PluginRunnerSupervisor:
|
|||||||
"""
|
"""
|
||||||
platform_io_manager = get_platform_io_manager()
|
platform_io_manager = get_platform_io_manager()
|
||||||
driver_id = self._build_adapter_driver_id(plugin_id)
|
driver_id = self._build_adapter_driver_id(plugin_id)
|
||||||
|
adapter = self._registered_adapters.get(plugin_id)
|
||||||
|
|
||||||
|
self._remove_adapter_route_bindings(plugin_id)
|
||||||
|
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
if platform_io_manager.is_started:
|
if platform_io_manager.is_started:
|
||||||
@@ -629,7 +641,11 @@ class PluginRunnerSupervisor:
|
|||||||
else:
|
else:
|
||||||
platform_io_manager.unregister_driver(driver_id)
|
platform_io_manager.unregister_driver(driver_id)
|
||||||
|
|
||||||
|
if adapter is not None:
|
||||||
|
self._refresh_platform_default_route(adapter.platform)
|
||||||
|
|
||||||
self._registered_adapters.pop(plugin_id, None)
|
self._registered_adapters.pop(plugin_id, None)
|
||||||
|
self._adapter_runtime_states.pop(plugin_id, None)
|
||||||
|
|
||||||
async def _unregister_all_adapter_drivers(self) -> None:
|
async def _unregister_all_adapter_drivers(self) -> None:
|
||||||
"""注销当前 Supervisor 管理的全部适配器驱动。"""
|
"""注销当前 Supervisor 管理的全部适配器驱动。"""
|
||||||
@@ -637,6 +653,198 @@ class PluginRunnerSupervisor:
|
|||||||
for plugin_id in plugin_ids:
|
for plugin_id in plugin_ids:
|
||||||
await self._unregister_adapter_driver(plugin_id)
|
await self._unregister_adapter_driver(plugin_id)
|
||||||
|
|
||||||
|
def _remove_adapter_route_bindings(self, plugin_id: str) -> None:
|
||||||
|
"""移除某个适配器驱动当前持有的全部路由绑定。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 适配器插件 ID。
|
||||||
|
"""
|
||||||
|
platform_io_manager = get_platform_io_manager()
|
||||||
|
platform_io_manager.route_table.remove_bindings_by_driver(self._build_adapter_driver_id(plugin_id))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_runtime_route_value(value: str) -> Optional[str]:
|
||||||
|
"""规范化适配器运行时路由字段。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 待规范化的原始字符串。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: 规范化后非空则返回字符串,否则返回 ``None``。
|
||||||
|
"""
|
||||||
|
normalized_value = str(value).strip()
|
||||||
|
return normalized_value or None
|
||||||
|
|
||||||
|
def _build_runtime_route_key(
|
||||||
|
self,
|
||||||
|
adapter: AdapterDeclarationPayload,
|
||||||
|
payload: AdapterStateUpdatePayload,
|
||||||
|
) -> RouteKey:
|
||||||
|
"""根据运行时状态更新构造适配器生效路由键。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
adapter: 当前适配器声明。
|
||||||
|
payload: 适配器上报的运行时状态。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RouteKey: 当前连接应接管的精确路由键。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当静态声明与运行时上报的身份信息冲突时抛出。
|
||||||
|
"""
|
||||||
|
runtime_account_id = self._normalize_runtime_route_value(payload.account_id)
|
||||||
|
runtime_scope = self._normalize_runtime_route_value(payload.scope)
|
||||||
|
|
||||||
|
if adapter.account_id and runtime_account_id and adapter.account_id != runtime_account_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"适配器声明的 account_id={adapter.account_id} 与运行时上报的 {runtime_account_id} 不一致"
|
||||||
|
)
|
||||||
|
if adapter.scope and runtime_scope and adapter.scope != runtime_scope:
|
||||||
|
raise ValueError(f"适配器声明的 scope={adapter.scope} 与运行时上报的 {runtime_scope} 不一致")
|
||||||
|
|
||||||
|
return RouteKey(
|
||||||
|
platform=adapter.platform,
|
||||||
|
account_id=runtime_account_id or adapter.account_id or None,
|
||||||
|
scope=runtime_scope or adapter.scope or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _bind_runtime_exact_route(
|
||||||
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
adapter: AdapterDeclarationPayload,
|
||||||
|
route_key: RouteKey,
|
||||||
|
) -> None:
|
||||||
|
"""为适配器连接绑定精确生效路由。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 适配器插件 ID。
|
||||||
|
adapter: 当前适配器声明。
|
||||||
|
route_key: 当前连接对应的精确路由键。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RouteBindingConflictError: 当目标路由已被其他 active owner 占用时抛出。
|
||||||
|
"""
|
||||||
|
platform_io_manager = get_platform_io_manager()
|
||||||
|
platform_io_manager.bind_route(
|
||||||
|
RouteBinding(
|
||||||
|
route_key=route_key,
|
||||||
|
driver_id=self._build_adapter_driver_id(plugin_id),
|
||||||
|
driver_kind=DriverKind.PLUGIN,
|
||||||
|
metadata={
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"protocol": adapter.protocol,
|
||||||
|
"binding_role": _ADAPTER_BINDING_ROLE_RUNTIME_EXACT,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _list_runtime_exact_bindings(self, platform: str) -> List[RouteBinding]:
|
||||||
|
"""列出某个平台上由 Host 动态维护的精确适配器绑定。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 目标平台名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[RouteBinding]: 当前平台上全部动态精确绑定。
|
||||||
|
"""
|
||||||
|
platform_io_manager = get_platform_io_manager()
|
||||||
|
return [
|
||||||
|
binding
|
||||||
|
for binding in platform_io_manager.route_table.list_bindings()
|
||||||
|
if binding.mode == RouteMode.ACTIVE
|
||||||
|
and binding.route_key.platform == platform
|
||||||
|
and binding.metadata.get("binding_role") == _ADAPTER_BINDING_ROLE_RUNTIME_EXACT
|
||||||
|
]
|
||||||
|
|
||||||
|
def _refresh_platform_default_route(self, platform: str) -> None:
|
||||||
|
"""根据当前精确绑定数量刷新平台级默认路由。
|
||||||
|
|
||||||
|
当某个平台恰好只存在一个动态精确绑定时,会为该绑定额外创建一条
|
||||||
|
``RouteKey(platform=<platform>)`` 形式的默认路由,方便缺少账号维度的
|
||||||
|
出站消息继续找到唯一 owner。若精确绑定数量变为 0 或大于 1,则撤销
|
||||||
|
由 Host 自动维护的默认路由,避免出现隐式歧义。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 目标平台名称。
|
||||||
|
"""
|
||||||
|
platform_io_manager = get_platform_io_manager()
|
||||||
|
default_route_key = RouteKey(platform=platform)
|
||||||
|
existing_default_binding = platform_io_manager.route_table.get_active_binding(default_route_key, exact_only=True)
|
||||||
|
|
||||||
|
if existing_default_binding is not None:
|
||||||
|
binding_role = existing_default_binding.metadata.get("binding_role")
|
||||||
|
if binding_role != _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT:
|
||||||
|
return
|
||||||
|
platform_io_manager.unbind_route(default_route_key, existing_default_binding.driver_id)
|
||||||
|
|
||||||
|
exact_bindings = self._list_runtime_exact_bindings(platform)
|
||||||
|
if len(exact_bindings) != 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
exact_binding = exact_bindings[0]
|
||||||
|
if exact_binding.route_key == default_route_key:
|
||||||
|
return
|
||||||
|
|
||||||
|
platform_io_manager.bind_route(
|
||||||
|
RouteBinding(
|
||||||
|
route_key=default_route_key,
|
||||||
|
driver_id=exact_binding.driver_id,
|
||||||
|
driver_kind=exact_binding.driver_kind,
|
||||||
|
metadata={
|
||||||
|
"plugin_id": exact_binding.metadata.get("plugin_id", ""),
|
||||||
|
"protocol": exact_binding.metadata.get("protocol", ""),
|
||||||
|
"binding_role": _ADAPTER_BINDING_ROLE_PLATFORM_DEFAULT,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
replace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_adapter_runtime_state(
|
||||||
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
adapter: AdapterDeclarationPayload,
|
||||||
|
payload: AdapterStateUpdatePayload,
|
||||||
|
) -> Tuple[_AdapterRuntimeState, Dict[str, Any]]:
|
||||||
|
"""应用适配器运行时状态,并同步 Platform IO 路由。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 适配器插件 ID。
|
||||||
|
adapter: 当前适配器声明。
|
||||||
|
payload: 适配器上报的运行时状态。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[_AdapterRuntimeState, Dict[str, Any]]: 更新后的运行时状态,以及
|
||||||
|
供 RPC 响应返回的路由键字典。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RouteBindingConflictError: 当新的精确路由与其他 active owner 冲突时抛出。
|
||||||
|
ValueError: 当运行时路由信息不合法时抛出。
|
||||||
|
"""
|
||||||
|
if not payload.connected:
|
||||||
|
self._remove_adapter_route_bindings(plugin_id)
|
||||||
|
self._refresh_platform_default_route(adapter.platform)
|
||||||
|
runtime_state = _AdapterRuntimeState(connected=False, metadata=dict(payload.metadata))
|
||||||
|
self._adapter_runtime_states[plugin_id] = runtime_state
|
||||||
|
return runtime_state, {}
|
||||||
|
|
||||||
|
route_key = self._build_runtime_route_key(adapter, payload)
|
||||||
|
self._remove_adapter_route_bindings(plugin_id)
|
||||||
|
self._bind_runtime_exact_route(plugin_id, adapter, route_key)
|
||||||
|
self._refresh_platform_default_route(adapter.platform)
|
||||||
|
|
||||||
|
runtime_state = _AdapterRuntimeState(
|
||||||
|
connected=True,
|
||||||
|
account_id=route_key.account_id,
|
||||||
|
scope=route_key.scope,
|
||||||
|
metadata=dict(payload.metadata),
|
||||||
|
)
|
||||||
|
self._adapter_runtime_states[plugin_id] = runtime_state
|
||||||
|
return runtime_state, {
|
||||||
|
"platform": route_key.platform,
|
||||||
|
"account_id": route_key.account_id,
|
||||||
|
"scope": route_key.scope,
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _attach_inbound_route_metadata(
|
def _attach_inbound_route_metadata(
|
||||||
session_message: "SessionMessage",
|
session_message: "SessionMessage",
|
||||||
@@ -706,6 +914,45 @@ class PluginRunnerSupervisor:
|
|||||||
scope=scope,
|
scope=scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _handle_update_adapter_state(self, envelope: Envelope) -> Envelope:
|
||||||
|
"""处理适配器插件上报的运行时状态更新。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
envelope: RPC 请求信封。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Envelope: 状态更新处理结果。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = AdapterStateUpdatePayload.model_validate(envelope.payload)
|
||||||
|
except Exception as exc:
|
||||||
|
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||||
|
|
||||||
|
adapter = self._registered_adapters.get(envelope.plugin_id)
|
||||||
|
if adapter is None:
|
||||||
|
return envelope.make_error_response(
|
||||||
|
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||||
|
f"插件 {envelope.plugin_id} 未声明为适配器,不能更新运行时状态",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
runtime_state, route_key_dict = self._apply_adapter_runtime_state(
|
||||||
|
plugin_id=envelope.plugin_id,
|
||||||
|
adapter=adapter,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
except RouteBindingConflictError as exc:
|
||||||
|
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
|
||||||
|
except Exception as exc:
|
||||||
|
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||||
|
|
||||||
|
response = AdapterStateUpdateResultPayload(
|
||||||
|
accepted=True,
|
||||||
|
connected=runtime_state.connected,
|
||||||
|
route_key=route_key_dict,
|
||||||
|
)
|
||||||
|
return envelope.make_response(payload=response.model_dump())
|
||||||
|
|
||||||
async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
|
async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
|
||||||
"""处理适配器插件上报的外部入站消息。
|
"""处理适配器插件上报的外部入站消息。
|
||||||
|
|
||||||
@@ -970,6 +1217,7 @@ class PluginRunnerSupervisor:
|
|||||||
self._component_registry.clear()
|
self._component_registry.clear()
|
||||||
self._registered_plugins.clear()
|
self._registered_plugins.clear()
|
||||||
self._registered_adapters.clear()
|
self._registered_adapters.clear()
|
||||||
|
self._adapter_runtime_states.clear()
|
||||||
self._runner_ready_events = asyncio.Event()
|
self._runner_ready_events = asyncio.Event()
|
||||||
self._runner_ready_payloads = RunnerReadyPayload()
|
self._runner_ready_payloads = RunnerReadyPayload()
|
||||||
self._rpc_server.clear_handshake_state()
|
self._rpc_server.clear_handshake_state()
|
||||||
|
|||||||
@@ -304,6 +304,30 @@ class AdapterDeclarationPayload(BaseModel):
|
|||||||
"""适配器附加元数据"""
|
"""适配器附加元数据"""
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterStateUpdatePayload(BaseModel):
|
||||||
|
"""适配器运行时状态更新载荷。"""
|
||||||
|
|
||||||
|
connected: bool = Field(description="适配器当前是否已连接并准备接管路由")
|
||||||
|
"""适配器当前是否已连接并准备接管路由"""
|
||||||
|
account_id: str = Field(default="", description="当前连接对应的账号 ID 或 self_id")
|
||||||
|
"""当前连接对应的账号 ID 或 self_id"""
|
||||||
|
scope: str = Field(default="", description="当前连接对应的可选路由作用域")
|
||||||
|
"""当前连接对应的可选路由作用域"""
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
|
||||||
|
"""可选的运行时状态元数据"""
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterStateUpdateResultPayload(BaseModel):
|
||||||
|
"""适配器运行时状态更新结果载荷。"""
|
||||||
|
|
||||||
|
accepted: bool = Field(description="Host 是否接受了本次状态更新")
|
||||||
|
"""Host 是否接受了本次状态更新"""
|
||||||
|
connected: bool = Field(description="Host 记录的当前连接状态")
|
||||||
|
"""Host 记录的当前连接状态"""
|
||||||
|
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
|
||||||
|
"""当前生效的路由键"""
|
||||||
|
|
||||||
|
|
||||||
class ReceiveExternalMessagePayload(BaseModel):
|
class ReceiveExternalMessagePayload(BaseModel):
|
||||||
"""适配器插件向 Host 注入外部消息的请求载荷。"""
|
"""适配器插件向 Host 注入外部消息的请求载荷。"""
|
||||||
|
|
||||||
|
|||||||
@@ -481,13 +481,14 @@ class PluginRunner:
|
|||||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not await self._invoke_plugin_on_load(meta):
|
if not await self._register_plugin(meta):
|
||||||
|
await self._invoke_plugin_on_unload(meta)
|
||||||
await self._deactivate_plugin(meta)
|
await self._deactivate_plugin(meta)
|
||||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not await self._register_plugin(meta):
|
if not await self._invoke_plugin_on_load(meta):
|
||||||
await self._invoke_plugin_on_unload(meta)
|
await self._unregister_plugin(meta.plugin_id, reason="on_load_failed")
|
||||||
await self._deactivate_plugin(meta)
|
await self._deactivate_plugin(meta)
|
||||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -60,6 +60,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
|||||||
self._connection_task: Optional[asyncio.Task[None]] = None
|
self._connection_task: Optional[asyncio.Task[None]] = None
|
||||||
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
|
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
|
||||||
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
||||||
|
self._reported_account_id: Optional[str] = None
|
||||||
|
self._reported_scope: Optional[str] = None
|
||||||
|
self._runtime_state_connected: bool = False
|
||||||
self._send_lock = asyncio.Lock()
|
self._send_lock = asyncio.Lock()
|
||||||
self._ws: Optional[AiohttpClientWebSocketResponse] = None
|
self._ws: Optional[AiohttpClientWebSocketResponse] = None
|
||||||
|
|
||||||
@@ -170,6 +173,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
|||||||
with contextlib.suppress(asyncio.CancelledError):
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await connection_task
|
await connection_task
|
||||||
|
|
||||||
|
await self._report_adapter_disconnected()
|
||||||
self._fail_pending_actions("NapCat connection closed")
|
self._fail_pending_actions("NapCat connection closed")
|
||||||
|
|
||||||
async def _cancel_background_tasks(self) -> None:
|
async def _cancel_background_tasks(self) -> None:
|
||||||
@@ -209,6 +213,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
|||||||
self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
|
self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
|
||||||
finally:
|
finally:
|
||||||
self._ws = None
|
self._ws = None
|
||||||
|
await self._report_adapter_disconnected()
|
||||||
self._fail_pending_actions("NapCat connection interrupted")
|
self._fail_pending_actions("NapCat connection interrupted")
|
||||||
|
|
||||||
if not self._should_connect():
|
if not self._should_connect():
|
||||||
@@ -230,26 +235,39 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
|||||||
"""
|
"""
|
||||||
assert WSMsgType is not None
|
assert WSMsgType is not None
|
||||||
|
|
||||||
async for ws_message in ws:
|
bootstrap_task = asyncio.create_task(
|
||||||
if ws_message.type != WSMsgType.TEXT:
|
self._bootstrap_adapter_runtime_state(),
|
||||||
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
|
name="napcat_adapter.bootstrap",
|
||||||
break
|
)
|
||||||
continue
|
self._background_tasks.add(bootstrap_task)
|
||||||
|
bootstrap_task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
payload = self._parse_json_message(ws_message.data)
|
try:
|
||||||
if payload is None:
|
async for ws_message in ws:
|
||||||
continue
|
if ws_message.type != WSMsgType.TEXT:
|
||||||
|
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
if echo_id := str(payload.get("echo") or "").strip():
|
payload = self._parse_json_message(ws_message.data)
|
||||||
self._resolve_pending_action(echo_id, payload)
|
if payload is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if str(payload.get("post_type") or "").strip() != "message":
|
if echo_id := str(payload.get("echo") or "").strip():
|
||||||
continue
|
self._resolve_pending_action(echo_id, payload)
|
||||||
|
continue
|
||||||
|
|
||||||
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
|
if str(payload.get("post_type") or "").strip() != "message":
|
||||||
self._background_tasks.add(task)
|
continue
|
||||||
task.add_done_callback(self._background_tasks.discard)
|
|
||||||
|
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
finally:
|
||||||
|
if not bootstrap_task.done():
|
||||||
|
bootstrap_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await bootstrap_task
|
||||||
|
|
||||||
async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
|
async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
|
||||||
"""处理单条 NapCat 入站消息并注入 Host。
|
"""处理单条 NapCat 入站消息并注入 Host。
|
||||||
@@ -258,6 +276,9 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
|||||||
payload: NapCat / OneBot 推送的原始事件数据。
|
payload: NapCat / OneBot 推送的原始事件数据。
|
||||||
"""
|
"""
|
||||||
self_id = str(payload.get("self_id") or "").strip()
|
self_id = str(payload.get("self_id") or "").strip()
|
||||||
|
if self_id:
|
||||||
|
await self._report_adapter_connected(self_id)
|
||||||
|
|
||||||
sender = payload.get("sender", {})
|
sender = payload.get("sender", {})
|
||||||
if not isinstance(sender, dict):
|
if not isinstance(sender, dict):
|
||||||
sender = {}
|
sender = {}
|
||||||
@@ -570,6 +591,121 @@ class NapCatAdapterPlugin(MaiBotPlugin):
|
|||||||
response_future.set_exception(RuntimeError(error_message))
|
response_future.set_exception(RuntimeError(error_message))
|
||||||
self._pending_actions.clear()
|
self._pending_actions.clear()
|
||||||
|
|
||||||
|
async def _bootstrap_adapter_runtime_state(self) -> None:
|
||||||
|
"""在连接建立后主动获取账号信息并激活适配器路由。
|
||||||
|
|
||||||
|
该步骤会在 WebSocket 接收循环启动后异步执行,确保 `_call_action()`
|
||||||
|
发出的 `get_login_info` 请求能够被同一连接上的接收循环消费到 echo
|
||||||
|
响应,从而在真正收到业务消息前就完成 Host 侧 route 激活。
|
||||||
|
"""
|
||||||
|
max_attempts = 3
|
||||||
|
last_error: Optional[Exception] = None
|
||||||
|
for attempt in range(1, max_attempts + 1):
|
||||||
|
ws = self._ws
|
||||||
|
if ws is None or ws.closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._call_action("get_login_info", {})
|
||||||
|
self_id = self._extract_self_id_from_login_response(response)
|
||||||
|
await self._report_adapter_connected(self_id)
|
||||||
|
return
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
last_error = exc
|
||||||
|
self.ctx.logger.warning(
|
||||||
|
f"NapCat 适配器获取登录信息失败,第 {attempt}/{max_attempts} 次重试: {exc}"
|
||||||
|
)
|
||||||
|
if attempt < max_attempts:
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
|
||||||
|
if last_error is not None:
|
||||||
|
self.ctx.logger.error(f"NapCat 适配器未能完成路由激活,连接将保持只接收状态: {last_error}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_self_id_from_login_response(response: Dict[str, Any]) -> str:
|
||||||
|
"""从 `get_login_info` 响应中提取当前账号 ID。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: NapCat 返回的原始动作响应。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 规范化后的 `self_id` 字符串。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当响应中缺少有效账号 ID 时抛出。
|
||||||
|
"""
|
||||||
|
if str(response.get("status") or "").lower() != "ok":
|
||||||
|
raise ValueError(str(response.get("wording") or response.get("message") or "get_login_info failed"))
|
||||||
|
|
||||||
|
response_data = response.get("data", {})
|
||||||
|
if not isinstance(response_data, dict):
|
||||||
|
raise ValueError("get_login_info 响应缺少 data 字段")
|
||||||
|
|
||||||
|
self_id = str(response_data.get("user_id") or "").strip()
|
||||||
|
if not self_id:
|
||||||
|
raise ValueError("get_login_info 响应缺少有效的 user_id")
|
||||||
|
return self_id
|
||||||
|
|
||||||
|
async def _report_adapter_connected(self, account_id: str) -> None:
|
||||||
|
"""向 Host 上报当前连接已就绪。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
account_id: 当前 NapCat 连接对应的机器人账号 ID。
|
||||||
|
"""
|
||||||
|
normalized_account_id = str(account_id).strip()
|
||||||
|
if not normalized_account_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
scope = self._get_string(self._connection_config(), "connection_id").strip()
|
||||||
|
if (
|
||||||
|
self._runtime_state_connected
|
||||||
|
and self._reported_account_id == normalized_account_id
|
||||||
|
and self._reported_scope == (scope or None)
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
accepted = False
|
||||||
|
try:
|
||||||
|
accepted = await self.ctx.adapter.update_runtime_state(
|
||||||
|
connected=True,
|
||||||
|
account_id=normalized_account_id,
|
||||||
|
scope=scope,
|
||||||
|
metadata={"ws_url": self._get_string(self._connection_config(), "ws_url")},
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self.ctx.logger.warning(f"NapCat 适配器上报连接就绪状态失败: {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not accepted:
|
||||||
|
self.ctx.logger.warning("NapCat 适配器连接已建立,但 Host 未接受运行时状态更新")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._runtime_state_connected = True
|
||||||
|
self._reported_account_id = normalized_account_id
|
||||||
|
self._reported_scope = scope or None
|
||||||
|
self.ctx.logger.info(
|
||||||
|
f"NapCat 适配器已激活路由: platform=qq account_id={normalized_account_id} "
|
||||||
|
f"scope={self._reported_scope or '*'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _report_adapter_disconnected(self) -> None:
|
||||||
|
"""向 Host 上报当前连接已断开,并撤销适配器路由。"""
|
||||||
|
if not self._runtime_state_connected:
|
||||||
|
self._reported_account_id = None
|
||||||
|
self._reported_scope = None
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.ctx.adapter.update_runtime_state(connected=False)
|
||||||
|
except Exception as exc:
|
||||||
|
self.ctx.logger.warning(f"NapCat 适配器上报断开状态失败: {exc}")
|
||||||
|
finally:
|
||||||
|
self._runtime_state_connected = False
|
||||||
|
self._reported_account_id = None
|
||||||
|
self._reported_scope = None
|
||||||
|
|
||||||
def _build_headers(self) -> Dict[str, str]:
|
def _build_headers(self) -> Dict[str, str]:
|
||||||
"""构造连接 NapCat 所需的请求头。
|
"""构造连接 NapCat 所需的请求头。
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user