merge: sync upstream/r-dev and resolve real conflicts
This commit is contained in:
89
pytests/common_test/test_expression_auto_check_task.py
Normal file
89
pytests/common_test/test_expression_auto_check_task.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""测试表达方式自动检查任务的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_auto_check_engine")
|
||||
def expression_auto_check_engine_fixture() -> Generator:
|
||||
"""创建用于表达方式自动检查任务测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_expressions_uses_read_only_session(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
expression_auto_check_engine,
|
||||
) -> None:
|
||||
"""选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。"""
|
||||
|
||||
import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module
|
||||
|
||||
with Session(expression_auto_check_engine) as session:
|
||||
session.add(
|
||||
Expression(
|
||||
situation="表达情绪高涨或生理反应",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
auto_commit_calls: list[bool] = []
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
|
||||
auto_commit_calls.append(auto_commit)
|
||||
session = Session(expression_auto_check_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session)
|
||||
monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries))
|
||||
|
||||
task = ExpressionAutoCheckTask()
|
||||
expressions = await task._select_expressions(1)
|
||||
|
||||
assert auto_commit_calls == [False]
|
||||
assert len(expressions) == 1
|
||||
assert expressions[0].id is not None
|
||||
assert expressions[0].situation == "表达情绪高涨或生理反应"
|
||||
assert expressions[0].style == "发送💦表情符号"
|
||||
81
pytests/common_test/test_expression_learner.py
Normal file
81
pytests/common_test/test_expression_learner.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""测试表达方式学习器的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.bw_learner.expression_learner import ExpressionLearner
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_learner_engine")
|
||||
def expression_learner_engine_fixture() -> Generator:
|
||||
"""创建用于表达方式学习器测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_find_similar_expression_uses_read_only_session_and_history_content(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
expression_learner_engine,
|
||||
) -> None:
|
||||
"""查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。"""
|
||||
import src.bw_learner.expression_learner as expression_learner_module
|
||||
|
||||
with Session(expression_learner_engine) as session:
|
||||
session.add(
|
||||
Expression(
|
||||
situation="发送汗滴表情",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
session = Session(expression_learner_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
|
||||
|
||||
learner = ExpressionLearner(session_id="session-a")
|
||||
result = learner._find_similar_expression("表达情绪高涨或生理反应")
|
||||
|
||||
assert result is not None
|
||||
expression, similarity = result
|
||||
assert expression.item_id is not None
|
||||
assert expression.style == "发送💦表情符号"
|
||||
assert similarity == pytest.approx(1.0)
|
||||
78
pytests/common_test/test_expression_schema.py
Normal file
78
pytests/common_test/test_expression_schema.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""测试表达方式表结构和基础插入行为。"""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_engine")
|
||||
def expression_engine_fixture() -> Generator:
|
||||
"""创建仅用于表达方式表测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None:
|
||||
"""表达方式表在新库中应能自动分配自增主键。"""
|
||||
with Session(expression_engine) as session:
|
||||
expression = Expression(
|
||||
situation="表达情绪高涨或生理反应",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
session.add(expression)
|
||||
session.commit()
|
||||
session.refresh(expression)
|
||||
|
||||
assert expression.id is not None
|
||||
assert expression.id > 0
|
||||
|
||||
|
||||
def test_expression_insert_allows_same_situation_style(expression_engine) -> None:
|
||||
"""相同情景和风格的表达方式记录不应再被错误绑定到复合主键。"""
|
||||
with Session(expression_engine) as session:
|
||||
first_expression = Expression(
|
||||
situation="对重复行为的默契响应",
|
||||
style="持续性跟发相同内容",
|
||||
content_list='["对重复行为的默契响应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
second_expression = Expression(
|
||||
situation="对重复行为的默契响应",
|
||||
style="持续性跟发相同内容",
|
||||
content_list='["对重复行为的默契响应-变体"]',
|
||||
count=2,
|
||||
session_id="session-b",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
|
||||
session.add(first_expression)
|
||||
session.add(second_expression)
|
||||
session.commit()
|
||||
session.refresh(first_expression)
|
||||
session.refresh(second_expression)
|
||||
|
||||
assert first_expression.id is not None
|
||||
assert second_expression.id is not None
|
||||
assert first_expression.id != second_expression.id
|
||||
90
pytests/common_test/test_jargon_miner.py
Normal file
90
pytests/common_test/test_jargon_miner.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""测试黑话学习器的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
|
||||
from src.bw_learner.jargon_miner import JargonMiner
|
||||
from src.common.database.database_model import Jargon
|
||||
|
||||
|
||||
@pytest.fixture(name="jargon_miner_engine")
|
||||
def jargon_miner_engine_fixture() -> Generator:
|
||||
"""创建用于黑话学习器测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
jargon_miner_engine,
|
||||
) -> None:
|
||||
"""更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
|
||||
import src.bw_learner.jargon_miner as jargon_miner_module
|
||||
|
||||
with Session(jargon_miner_engine) as session:
|
||||
session.add(
|
||||
Jargon(
|
||||
content="VF8V4L",
|
||||
raw_content='["[1] first"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=0,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
session = Session(jargon_miner_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
|
||||
|
||||
jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
|
||||
await jargon_miner.process_extracted_entries(
|
||||
[{"content": "VF8V4L", "raw_content": {"[2] second"}}],
|
||||
)
|
||||
|
||||
with Session(jargon_miner_engine) as session:
|
||||
db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
|
||||
|
||||
assert db_jargon.count == 1
|
||||
assert db_jargon.session_id_dict == '{"session-a": 2}'
|
||||
assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
|
||||
"[1] first",
|
||||
"[2] second",
|
||||
]
|
||||
84
pytests/common_test/test_jargon_schema.py
Normal file
84
pytests/common_test/test_jargon_schema.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""测试黑话表结构和基础插入行为。"""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.common.database.database_model import Jargon
|
||||
|
||||
|
||||
@pytest.fixture(name="jargon_engine")
|
||||
def jargon_engine_fixture() -> Generator:
|
||||
"""创建仅用于黑话表测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None:
|
||||
"""黑话表在新库中应能自动分配自增主键。"""
|
||||
with Session(jargon_engine) as session:
|
||||
jargon = Jargon(
|
||||
content="VF8V4L",
|
||||
raw_content='["[1] test"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=1,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=True,
|
||||
last_inference_count=0,
|
||||
)
|
||||
session.add(jargon)
|
||||
session.commit()
|
||||
session.refresh(jargon)
|
||||
|
||||
assert jargon.id is not None
|
||||
assert jargon.id > 0
|
||||
|
||||
|
||||
def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None:
|
||||
"""黑话内容不应再被错误地绑成复合主键的一部分。"""
|
||||
with Session(jargon_engine) as session:
|
||||
first_jargon = Jargon(
|
||||
content="表情1",
|
||||
raw_content='["[1] first"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=1,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
second_jargon = Jargon(
|
||||
content="表情1",
|
||||
raw_content='["[1] second"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-b": 1}',
|
||||
count=1,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
|
||||
session.add(first_jargon)
|
||||
session.add(second_jargon)
|
||||
session.commit()
|
||||
session.refresh(first_jargon)
|
||||
session.refresh(second_jargon)
|
||||
|
||||
assert first_jargon.id is not None
|
||||
assert second_jargon.id is not None
|
||||
assert first_jargon.id != second_jargon.id
|
||||
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""人物信息群名片字段兼容测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
"""模拟日志记录器。"""
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""记录调试日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""记录信息日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""记录警告日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""记录错误日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
|
||||
class _DummyStatement:
|
||||
"""模拟 SQL 查询语句对象。"""
|
||||
|
||||
def where(self, condition: Any) -> "_DummyStatement":
|
||||
"""附加过滤条件。
|
||||
|
||||
Args:
|
||||
condition: 过滤条件。
|
||||
|
||||
Returns:
|
||||
_DummyStatement: 当前语句对象。
|
||||
"""
|
||||
del condition
|
||||
return self
|
||||
|
||||
def limit(self, value: int) -> "_DummyStatement":
|
||||
"""限制返回条数。
|
||||
|
||||
Args:
|
||||
value: 条数限制。
|
||||
|
||||
Returns:
|
||||
_DummyStatement: 当前语句对象。
|
||||
"""
|
||||
del value
|
||||
return self
|
||||
|
||||
|
||||
class _DummyColumn:
|
||||
"""模拟 SQLModel 列对象。"""
|
||||
|
||||
def is_not(self, value: Any) -> "_DummyColumn":
|
||||
"""模拟 `IS NOT` 条件构造。
|
||||
|
||||
Args:
|
||||
value: 比较值。
|
||||
|
||||
Returns:
|
||||
_DummyColumn: 当前列对象。
|
||||
"""
|
||||
del value
|
||||
return self
|
||||
|
||||
def __eq__(self, other: Any) -> "_DummyColumn":
|
||||
"""模拟等值条件构造。
|
||||
|
||||
Args:
|
||||
other: 比较值。
|
||||
|
||||
Returns:
|
||||
_DummyColumn: 当前列对象。
|
||||
"""
|
||||
del other
|
||||
return self
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
"""模拟数据库查询结果。"""
|
||||
|
||||
def __init__(self, record: Any) -> None:
|
||||
"""初始化查询结果。
|
||||
|
||||
Args:
|
||||
record: 待返回的首条记录。
|
||||
"""
|
||||
self._record = record
|
||||
|
||||
def first(self) -> Any:
|
||||
"""返回第一条记录。
|
||||
|
||||
Returns:
|
||||
Any: 首条记录。
|
||||
"""
|
||||
return self._record
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
"""返回全部结果。
|
||||
|
||||
Returns:
|
||||
list[Any]: 结果列表。
|
||||
"""
|
||||
if self._record is None:
|
||||
return []
|
||||
return self._record if isinstance(self._record, list) else [self._record]
|
||||
|
||||
|
||||
class _DummySession:
|
||||
"""模拟数据库 Session。"""
|
||||
|
||||
def __init__(self, record: Any) -> None:
|
||||
"""初始化 Session。
|
||||
|
||||
Args:
|
||||
record: `first()` 应返回的记录。
|
||||
"""
|
||||
self.record = record
|
||||
self.added_records: list[Any] = []
|
||||
|
||||
def __enter__(self) -> "_DummySession":
|
||||
"""进入上下文管理器。
|
||||
|
||||
Returns:
|
||||
_DummySession: 当前 Session。
|
||||
"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""退出上下文管理器。
|
||||
|
||||
Args:
|
||||
exc_type: 异常类型。
|
||||
exc_val: 异常值。
|
||||
exc_tb: 异常回溯。
|
||||
"""
|
||||
del exc_type
|
||||
del exc_val
|
||||
del exc_tb
|
||||
|
||||
def exec(self, statement: Any) -> _DummyResult:
|
||||
"""执行查询。
|
||||
|
||||
Args:
|
||||
statement: 查询语句。
|
||||
|
||||
Returns:
|
||||
_DummyResult: 模拟结果对象。
|
||||
"""
|
||||
del statement
|
||||
return _DummyResult(self.record)
|
||||
|
||||
def add(self, record: Any) -> None:
|
||||
"""记录被添加的对象。
|
||||
|
||||
Args:
|
||||
record: 被写入 Session 的对象。
|
||||
"""
|
||||
self.added_records.append(record)
|
||||
|
||||
|
||||
class _DummyPersonInfoRecord:
|
||||
"""模拟 `PersonInfo` ORM 模型。"""
|
||||
|
||||
person_id = "person_id"
|
||||
person_name = "person_name"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""使用关键字参数初始化记录对象。
|
||||
|
||||
Args:
|
||||
**kwargs: 字段值。
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
|
||||
"""加载带依赖桩的 `person_info` 模块。
|
||||
|
||||
Args:
|
||||
monkeypatch: Pytest monkeypatch 工具。
|
||||
session: 提供给模块使用的假数据库 Session。
|
||||
|
||||
Returns:
|
||||
ModuleType: 加载后的模块对象。
|
||||
"""
|
||||
logger_module = ModuleType("src.common.logger")
|
||||
logger_module.get_logger = lambda name: _DummyLogger()
|
||||
monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
|
||||
|
||||
database_module = ModuleType("src.common.database.database")
|
||||
database_module.get_db_session = lambda: session
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
|
||||
|
||||
database_model_module = ModuleType("src.common.database.database_model")
|
||||
database_model_module.PersonInfo = _DummyPersonInfoRecord
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
|
||||
|
||||
llm_module = ModuleType("src.llm_models.utils_model")
|
||||
|
||||
class _DummyLLMRequest:
|
||||
"""模拟 LLMRequest。"""
|
||||
|
||||
def __init__(self, model_set: Any, request_type: str) -> None:
|
||||
"""初始化假请求对象。
|
||||
|
||||
Args:
|
||||
model_set: 模型配置。
|
||||
request_type: 请求类型。
|
||||
"""
|
||||
del model_set
|
||||
del request_type
|
||||
|
||||
llm_module.LLMRequest = _DummyLLMRequest
|
||||
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
|
||||
|
||||
config_module = ModuleType("src.config.config")
|
||||
config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
|
||||
config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
|
||||
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
|
||||
|
||||
chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
|
||||
chat_manager_module.chat_manager = SimpleNamespace()
|
||||
monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
|
||||
|
||||
module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
|
||||
spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
|
||||
assert spec is not None and spec.loader is not None
|
||||
|
||||
module = module_from_spec(spec)
|
||||
monkeypatch.setitem(sys.modules, spec.name, module)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
|
||||
monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
|
||||
return module
|
||||
|
||||
|
||||
def test_parse_group_cardname_json_uses_canonical_key() -> None:
|
||||
"""群名片 JSON 解析应只使用 `group_cardname` 键名。"""
|
||||
parsed = parse_group_cardname_json(
|
||||
json.dumps(
|
||||
[
|
||||
{"group_id": "1001", "group_cardname": "现行字段"},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert parsed is not None
|
||||
assert [(item.group_id, item.group_cardname) for item in parsed] == [
|
||||
("1001", "现行字段"),
|
||||
]
|
||||
|
||||
|
||||
def test_dump_group_cardname_records_uses_canonical_key() -> None:
|
||||
"""群名片序列化应输出 `group_cardname` 键名。"""
|
||||
dumped = dump_group_cardname_records(
|
||||
[
|
||||
{"group_id": "1001", "group_cardname": "群昵称"},
|
||||
]
|
||||
)
|
||||
|
||||
assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
|
||||
|
||||
|
||||
def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
|
||||
record = _DummyPersonInfoRecord()
|
||||
session = _DummySession(record)
|
||||
module = _load_person_module(monkeypatch, session)
|
||||
|
||||
person = module.Person.__new__(module.Person)
|
||||
person.is_known = True
|
||||
person.person_id = "person-1"
|
||||
person.platform = "qq"
|
||||
person.user_id = "10001"
|
||||
person.nickname = "看番的龙"
|
||||
person.person_name = "看番的龙"
|
||||
person.name_reason = "测试"
|
||||
person.know_times = 1
|
||||
person.know_since = 1700000000.0
|
||||
person.last_know = 1700000100.0
|
||||
person.memory_points = ["喜好:番剧:0.8"]
|
||||
person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
|
||||
|
||||
person.sync_to_database()
|
||||
|
||||
assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
|
||||
assert not hasattr(record, "group_nickname")
|
||||
|
||||
|
||||
def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
|
||||
record = _DummyPersonInfoRecord(
|
||||
user_id="10001",
|
||||
platform="qq",
|
||||
is_known=True,
|
||||
user_nickname="看番的龙",
|
||||
person_name="看番的龙",
|
||||
name_reason=None,
|
||||
know_counts=2,
|
||||
memory_points='["喜好:番剧:0.8"]',
|
||||
group_cardname=json.dumps(
|
||||
[
|
||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
session = _DummySession(record)
|
||||
module = _load_person_module(monkeypatch, session)
|
||||
|
||||
person = module.Person.__new__(module.Person)
|
||||
person.person_id = "person-1"
|
||||
person.memory_points = []
|
||||
person.group_cardname_list = []
|
||||
|
||||
person.load_from_database()
|
||||
|
||||
assert person.group_cardname_list == [
|
||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||
]
|
||||
170
pytests/test_message_gateway_runtime.py
Normal file
170
pytests/test_message_gateway_runtime.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""消息网关运行时状态同步测试。"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from src.platform_io.manager import PlatformIOManager
|
||||
from src.platform_io.types import RouteKey
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
|
||||
def _make_request(method: str, plugin_id: str, payload: Dict[str, Any]) -> Envelope:
|
||||
"""构造一个 RPC 请求信封。
|
||||
|
||||
Args:
|
||||
method: RPC 方法名。
|
||||
plugin_id: 目标插件 ID。
|
||||
payload: 请求载荷。
|
||||
|
||||
Returns:
|
||||
Envelope: 标准 RPC 请求信封。
|
||||
"""
|
||||
|
||||
return Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_gateway_runtime_state_binds_send_and_receive_routes(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""消息网关就绪后应同时绑定发送表和接收表。"""
|
||||
|
||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||
|
||||
platform_io_manager = PlatformIOManager()
|
||||
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
register_response = await supervisor._handle_register_plugin(
|
||||
_make_request(
|
||||
"plugin.register_components",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"plugin_id": "napcat_plugin",
|
||||
"plugin_version": "1.0.0",
|
||||
"components": [
|
||||
{
|
||||
"name": "napcat_gateway",
|
||||
"component_type": "MESSAGE_GATEWAY",
|
||||
"plugin_id": "napcat_plugin",
|
||||
"metadata": {
|
||||
"route_type": "duplex",
|
||||
"platform": "qq",
|
||||
"protocol": "napcat",
|
||||
},
|
||||
}
|
||||
],
|
||||
"capabilities_required": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert register_response.error is None
|
||||
response = await supervisor._handle_update_message_gateway_state(
|
||||
_make_request(
|
||||
"host.update_message_gateway_state",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"gateway_name": "napcat_gateway",
|
||||
"ready": True,
|
||||
"platform": "qq",
|
||||
"account_id": "10001",
|
||||
"scope": "primary",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert response.payload["accepted"] is True
|
||||
|
||||
send_bindings = platform_io_manager.send_route_table.resolve_bindings(
|
||||
RouteKey(platform="qq", account_id="10001", scope="primary")
|
||||
)
|
||||
receive_bindings = platform_io_manager.receive_route_table.resolve_bindings(
|
||||
RouteKey(platform="qq", account_id="10001", scope="primary")
|
||||
)
|
||||
|
||||
assert [binding.driver_id for binding in send_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
|
||||
assert [binding.driver_id for binding in receive_bindings] == ["gateway:napcat_plugin:napcat_gateway"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_gateway_runtime_state_unbinds_routes_when_not_ready(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""消息网关断开后应撤销发送表和接收表中的绑定。"""
|
||||
|
||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||
|
||||
platform_io_manager = PlatformIOManager()
|
||||
monkeypatch.setattr(supervisor_module, "get_platform_io_manager", lambda: platform_io_manager)
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await supervisor._handle_register_plugin(
|
||||
_make_request(
|
||||
"plugin.register_components",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"plugin_id": "napcat_plugin",
|
||||
"plugin_version": "1.0.0",
|
||||
"components": [
|
||||
{
|
||||
"name": "napcat_gateway",
|
||||
"component_type": "MESSAGE_GATEWAY",
|
||||
"plugin_id": "napcat_plugin",
|
||||
"metadata": {
|
||||
"route_type": "duplex",
|
||||
"platform": "qq",
|
||||
"protocol": "napcat",
|
||||
},
|
||||
}
|
||||
],
|
||||
"capabilities_required": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
await supervisor._handle_update_message_gateway_state(
|
||||
_make_request(
|
||||
"host.update_message_gateway_state",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"gateway_name": "napcat_gateway",
|
||||
"ready": True,
|
||||
"platform": "qq",
|
||||
"account_id": "10001",
|
||||
"scope": "primary",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
response = await supervisor._handle_update_message_gateway_state(
|
||||
_make_request(
|
||||
"host.update_message_gateway_state",
|
||||
"napcat_plugin",
|
||||
{
|
||||
"gateway_name": "napcat_gateway",
|
||||
"ready": False,
|
||||
"platform": "qq",
|
||||
"account_id": "",
|
||||
"scope": "",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert response.payload["accepted"] is True
|
||||
assert platform_io_manager.send_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
|
||||
assert (
|
||||
platform_io_manager.receive_route_table.resolve_bindings(RouteKey(platform="qq", account_id="10001")) == []
|
||||
)
|
||||
132
pytests/test_napcat_adapter_sdk.py
Normal file
132
pytests/test_napcat_adapter_sdk.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""NapCat 插件与新 SDK 对接测试。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
|
||||
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
|
||||
|
||||
for import_path in (str(PLUGINS_ROOT), str(SDK_ROOT)):
|
||||
if import_path not in sys.path:
|
||||
sys.path.insert(0, import_path)
|
||||
|
||||
|
||||
class _FakeGatewayCapability:
|
||||
"""用于捕获消息网关状态上报的测试替身。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化测试替身。"""
|
||||
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def update_state(
|
||||
self,
|
||||
gateway_name: str,
|
||||
*,
|
||||
ready: bool,
|
||||
platform: str = "",
|
||||
account_id: str = "",
|
||||
scope: str = "",
|
||||
metadata: Dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""记录一次状态上报请求。
|
||||
|
||||
Args:
|
||||
gateway_name: 网关组件名称。
|
||||
ready: 当前是否就绪。
|
||||
platform: 平台名称。
|
||||
account_id: 账号 ID。
|
||||
scope: 路由作用域。
|
||||
metadata: 附加元数据。
|
||||
|
||||
Returns:
|
||||
bool: 始终返回 ``True``,模拟 Host 接受状态更新。
|
||||
"""
|
||||
|
||||
self.calls.append(
|
||||
{
|
||||
"gateway_name": gateway_name,
|
||||
"ready": ready,
|
||||
"platform": platform,
|
||||
"account_id": account_id,
|
||||
"scope": scope,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _load_napcat_sdk_symbols() -> tuple[Any, Any, Any, Any]:
|
||||
"""动态加载 NapCat 插件测试所需的符号。
|
||||
|
||||
Returns:
|
||||
tuple[Any, Any, Any, Any]:
|
||||
依次返回网关名常量、配置类、插件类和运行时状态管理器类。
|
||||
"""
|
||||
|
||||
constants_module = importlib.import_module("napcat_adapter.constants")
|
||||
config_module = importlib.import_module("napcat_adapter.config")
|
||||
plugin_module = importlib.import_module("napcat_adapter.plugin")
|
||||
runtime_state_module = importlib.import_module("napcat_adapter.runtime_state")
|
||||
return (
|
||||
constants_module.NAPCAT_GATEWAY_NAME,
|
||||
config_module.NapCatServerConfig,
|
||||
plugin_module.NapCatAdapterPlugin,
|
||||
runtime_state_module.NapCatRuntimeStateManager,
|
||||
)
|
||||
|
||||
|
||||
def test_napcat_plugin_collects_duplex_message_gateway() -> None:
|
||||
"""NapCat 插件应声明新的双工消息网关组件。"""
|
||||
|
||||
napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
||||
plugin = napcat_plugin_cls()
|
||||
components = plugin.get_components()
|
||||
gateway_components = [
|
||||
component
|
||||
for component in components
|
||||
if component.get("type") == "MESSAGE_GATEWAY"
|
||||
]
|
||||
|
||||
assert len(gateway_components) == 1
|
||||
gateway_component = gateway_components[0]
|
||||
assert gateway_component["name"] == napcat_gateway_name
|
||||
assert gateway_component["metadata"]["route_type"] == "duplex"
|
||||
assert gateway_component["metadata"]["platform"] == "qq"
|
||||
assert gateway_component["metadata"]["protocol"] == "napcat"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_state_reports_via_gateway_capability() -> None:
|
||||
"""NapCat 运行时状态应通过新的消息网关能力上报。"""
|
||||
|
||||
napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols()
|
||||
gateway_capability = _FakeGatewayCapability()
|
||||
runtime_state_manager = runtime_state_cls(
|
||||
gateway_capability=gateway_capability,
|
||||
logger=logging.getLogger("test.napcat_adapter"),
|
||||
gateway_name=napcat_gateway_name,
|
||||
)
|
||||
|
||||
connected = await runtime_state_manager.report_connected(
|
||||
"10001",
|
||||
napcat_server_config_cls(connection_id="primary"),
|
||||
)
|
||||
await runtime_state_manager.report_disconnected()
|
||||
|
||||
assert connected is True
|
||||
assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name
|
||||
assert gateway_capability.calls[0]["ready"] is True
|
||||
assert gateway_capability.calls[0]["platform"] == "qq"
|
||||
assert gateway_capability.calls[0]["account_id"] == "10001"
|
||||
assert gateway_capability.calls[0]["scope"] == "primary"
|
||||
assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name
|
||||
assert gateway_capability.calls[1]["ready"] is False
|
||||
assert gateway_capability.calls[1]["platform"] == "qq"
|
||||
209
pytests/test_platform_io_dedupe.py
Normal file
209
pytests/test_platform_io_dedupe.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Platform IO 入站去重策略测试。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.manager import PlatformIOManager
|
||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey
|
||||
|
||||
|
||||
def _build_envelope(
|
||||
*,
|
||||
dedupe_key: str | None = None,
|
||||
external_message_id: str | None = None,
|
||||
session_message_id: str | None = None,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
) -> InboundMessageEnvelope:
|
||||
"""构造测试用入站信封。
|
||||
|
||||
Args:
|
||||
dedupe_key: 显式去重键。
|
||||
external_message_id: 平台侧消息 ID。
|
||||
session_message_id: 规范化消息对象上的消息 ID。
|
||||
payload: 原始载荷。
|
||||
|
||||
Returns:
|
||||
InboundMessageEnvelope: 测试用入站消息信封。
|
||||
"""
|
||||
session_message = None
|
||||
if session_message_id is not None:
|
||||
session_message = SimpleNamespace(message_id=session_message_id)
|
||||
|
||||
return InboundMessageEnvelope(
|
||||
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
|
||||
driver_id="plugin.napcat",
|
||||
driver_kind=DriverKind.PLUGIN,
|
||||
dedupe_key=dedupe_key,
|
||||
external_message_id=external_message_id,
|
||||
session_message=session_message,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
class _StubPlatformIODriver(PlatformIODriver):
|
||||
"""测试用 Platform IO 驱动。"""
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""返回一个固定的成功回执。
|
||||
|
||||
Args:
|
||||
message: 待发送的消息对象。
|
||||
route_key: 本次发送使用的路由键。
|
||||
metadata: 额外发送元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 固定的成功回执。
|
||||
"""
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=str(getattr(message, "message_id", "stub-message-id")),
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
)
|
||||
|
||||
|
||||
def _build_manager() -> PlatformIOManager:
|
||||
"""构造带有最小接收路由的 Broker 管理器。
|
||||
|
||||
Returns:
|
||||
PlatformIOManager: 已注册测试驱动并绑定接收路由的 Broker。
|
||||
"""
|
||||
manager = PlatformIOManager()
|
||||
driver = _StubPlatformIODriver(
|
||||
DriverDescriptor(
|
||||
driver_id="plugin.napcat",
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform="qq",
|
||||
account_id="10001",
|
||||
scope="main",
|
||||
)
|
||||
)
|
||||
manager.register_driver(driver)
|
||||
manager.bind_receive_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq", account_id="10001", scope="main"),
|
||||
driver_id=driver.driver_id,
|
||||
driver_kind=driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
return manager
|
||||
|
||||
|
||||
class TestPlatformIODedupe:
|
||||
"""Platform IO 去重测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_inbound_dedupes_by_external_message_id(self) -> None:
|
||||
"""相同平台消息 ID 的重复入站应被抑制。"""
|
||||
manager = _build_manager()
|
||||
accepted_envelopes: List[InboundMessageEnvelope] = []
|
||||
|
||||
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
|
||||
"""记录被成功接收的入站消息。
|
||||
|
||||
Args:
|
||||
envelope: 被 Broker 接受的入站消息。
|
||||
"""
|
||||
accepted_envelopes.append(envelope)
|
||||
|
||||
manager.set_inbound_dispatcher(dispatcher)
|
||||
|
||||
first_envelope = _build_envelope(
|
||||
external_message_id="msg-1",
|
||||
payload={"message": "hello"},
|
||||
)
|
||||
second_envelope = _build_envelope(
|
||||
external_message_id="msg-1",
|
||||
payload={"message": "hello"},
|
||||
)
|
||||
|
||||
assert await manager.accept_inbound(first_envelope) is True
|
||||
assert await manager.accept_inbound(second_envelope) is False
|
||||
assert len(accepted_envelopes) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_inbound_without_stable_identity_does_not_guess_duplicate(self) -> None:
|
||||
"""缺少稳定身份时,不应仅凭 payload 内容猜测重复消息。"""
|
||||
manager = _build_manager()
|
||||
accepted_envelopes: List[InboundMessageEnvelope] = []
|
||||
|
||||
async def dispatcher(envelope: InboundMessageEnvelope) -> None:
|
||||
"""记录被成功接收的入站消息。
|
||||
|
||||
Args:
|
||||
envelope: 被 Broker 接受的入站消息。
|
||||
"""
|
||||
accepted_envelopes.append(envelope)
|
||||
|
||||
manager.set_inbound_dispatcher(dispatcher)
|
||||
|
||||
first_envelope = _build_envelope(payload={"message": "same-payload"})
|
||||
second_envelope = _build_envelope(payload={"message": "same-payload"})
|
||||
|
||||
assert await manager.accept_inbound(first_envelope) is True
|
||||
assert await manager.accept_inbound(second_envelope) is True
|
||||
assert len(accepted_envelopes) == 2
|
||||
|
||||
def test_build_inbound_dedupe_key_prefers_explicit_identity(self) -> None:
|
||||
"""去重键应只来自显式或稳定的技术身份。"""
|
||||
explicit_envelope = _build_envelope(dedupe_key="dedupe-1", external_message_id="msg-1")
|
||||
session_message_envelope = _build_envelope(session_message_id="session-1")
|
||||
payload_only_envelope = _build_envelope(payload={"message": "hello"})
|
||||
|
||||
assert PlatformIOManager._build_inbound_dedupe_key(explicit_envelope) == "plugin.napcat:dedupe-1"
|
||||
assert PlatformIOManager._build_inbound_dedupe_key(session_message_envelope) == "plugin.napcat:session-1"
|
||||
assert PlatformIOManager._build_inbound_dedupe_key(payload_only_envelope) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_fans_out_to_all_matching_routes(self) -> None:
|
||||
"""同一路由命中多条发送链路时应全部发送。"""
|
||||
|
||||
manager = PlatformIOManager()
|
||||
first_driver = _StubPlatformIODriver(
|
||||
DriverDescriptor(
|
||||
driver_id="plugin.gateway_a",
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform="qq",
|
||||
)
|
||||
)
|
||||
second_driver = _StubPlatformIODriver(
|
||||
DriverDescriptor(
|
||||
driver_id="plugin.gateway_b",
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform="qq",
|
||||
)
|
||||
)
|
||||
manager.register_driver(first_driver)
|
||||
manager.register_driver(second_driver)
|
||||
manager.bind_send_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq"),
|
||||
driver_id=first_driver.driver_id,
|
||||
driver_kind=first_driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
manager.bind_send_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq"),
|
||||
driver_id=second_driver.driver_id,
|
||||
driver_kind=second_driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
|
||||
message = SimpleNamespace(message_id="internal-msg-1")
|
||||
result = await manager.send_message(message, RouteKey(platform="qq"))
|
||||
|
||||
assert result.has_success is True
|
||||
assert [receipt.driver_id for receipt in result.sent_receipts] == [
|
||||
"plugin.gateway_a",
|
||||
"plugin.gateway_b",
|
||||
]
|
||||
178
pytests/test_platform_io_legacy_driver.py
Normal file
178
pytests/test_platform_io_legacy_driver.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Platform IO legacy driver 回归测试。"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from src.chat.utils import utils as chat_utils
|
||||
from src.chat.message_receive import uni_message_sender
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver
|
||||
from src.platform_io.manager import PlatformIOManager
|
||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey
|
||||
|
||||
|
||||
class _PluginDriver(PlatformIODriver):
|
||||
"""测试用插件发送驱动。"""
|
||||
|
||||
def __init__(self, driver_id: str, platform: str) -> None:
|
||||
"""初始化测试驱动。
|
||||
|
||||
Args:
|
||||
driver_id: 驱动 ID。
|
||||
platform: 负责的平台名称。
|
||||
"""
|
||||
super().__init__(
|
||||
DriverDescriptor(
|
||||
driver_id=driver_id,
|
||||
kind=DriverKind.PLUGIN,
|
||||
platform=platform,
|
||||
plugin_id="test.plugin",
|
||||
)
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""返回一个固定成功回执。
|
||||
|
||||
Args:
|
||||
message: 待发送消息。
|
||||
route_key: 当前路由键。
|
||||
metadata: 发送元数据。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 固定成功回执。
|
||||
"""
|
||||
del metadata
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=str(message.message_id),
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""没有显式发送路由时,应由 Platform IO 回退到 legacy driver。"""
|
||||
manager = PlatformIOManager()
|
||||
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
|
||||
|
||||
try:
|
||||
await manager.ensure_send_pipeline_ready()
|
||||
|
||||
fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
|
||||
assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"]
|
||||
|
||||
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
|
||||
await manager.add_driver(plugin_driver)
|
||||
manager.bind_send_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq"),
|
||||
driver_id=plugin_driver.driver_id,
|
||||
driver_kind=plugin_driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
|
||||
explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
|
||||
assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"]
|
||||
finally:
|
||||
await manager.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_io_broadcasts_to_plugin_and_legacy_driver(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""同一路由命中插件驱动与 legacy driver 时,应同时广播发送。"""
|
||||
|
||||
manager = PlatformIOManager()
|
||||
legacy_calls: list[dict[str, Any]] = []
|
||||
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
|
||||
|
||||
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
|
||||
"""记录 legacy driver 调用。"""
|
||||
|
||||
legacy_calls.append({"message": message, "show_log": show_log})
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(
|
||||
uni_message_sender,
|
||||
"send_prepared_message_to_platform",
|
||||
_fake_send_prepared_message_to_platform,
|
||||
)
|
||||
|
||||
try:
|
||||
await manager.ensure_send_pipeline_ready()
|
||||
|
||||
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
|
||||
await manager.add_driver(plugin_driver)
|
||||
manager.bind_send_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq"),
|
||||
driver_id=plugin_driver.driver_id,
|
||||
driver_kind=plugin_driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
|
||||
message = type("FakeMessage", (), {"message_id": "message-1"})()
|
||||
batch = await manager.send_message(
|
||||
message=message,
|
||||
route_key=RouteKey(platform="qq"),
|
||||
metadata={"show_log": False},
|
||||
)
|
||||
|
||||
assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [
|
||||
"legacy.send.qq",
|
||||
"plugin.qq.sender",
|
||||
]
|
||||
assert batch.failed_receipts == []
|
||||
assert len(legacy_calls) == 1
|
||||
assert legacy_calls[0]["message"] is message
|
||||
assert legacy_calls[0]["show_log"] is False
|
||||
finally:
|
||||
await manager.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_platform_driver_uses_prepared_universal_sender(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""legacy driver 应复用已预处理消息的旧链发送函数。"""
|
||||
calls: list[dict[str, Any]] = []
|
||||
|
||||
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
|
||||
"""记录 legacy driver 调用。"""
|
||||
calls.append({"message": message, "show_log": show_log})
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(
|
||||
uni_message_sender,
|
||||
"send_prepared_message_to_platform",
|
||||
_fake_send_prepared_message_to_platform,
|
||||
)
|
||||
|
||||
driver = LegacyPlatformDriver(
|
||||
driver_id="legacy.send.qq",
|
||||
platform="qq",
|
||||
account_id="bot-qq",
|
||||
)
|
||||
message = type("FakeMessage", (), {"message_id": "message-1"})()
|
||||
receipt = await driver.send_message(
|
||||
message=message,
|
||||
route_key=RouteKey(platform="qq"),
|
||||
metadata={"show_log": False},
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["message"] is message
|
||||
assert calls[0]["show_log"] is False
|
||||
assert receipt.status == DeliveryStatus.SENT
|
||||
assert receipt.driver_id == "legacy.send.qq"
|
||||
87
pytests/test_plugin_message_utils_runtime.py
Normal file
87
pytests/test_plugin_message_utils_runtime.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
ForwardComponent,
|
||||
ForwardNodeComponent,
|
||||
ImageComponent,
|
||||
MessageSequence,
|
||||
ReplyComponent,
|
||||
TextComponent,
|
||||
VoiceComponent,
|
||||
)
|
||||
from src.plugin_runtime.host.message_utils import PluginMessageUtils
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
def test_plugin_message_utils_preserves_binary_components_and_reply_metadata() -> None:
|
||||
message = SessionMessage(message_id="msg-1", timestamp=datetime.now(), platform="qq")
|
||||
message.message_info = MessageInfo(
|
||||
user_info=UserInfo(user_id="10001", user_nickname="tester"),
|
||||
group_info=GroupInfo(group_id="20001", group_name="group"),
|
||||
additional_config={"self_id": "999"},
|
||||
)
|
||||
message.session_id = "qq:20001:10001"
|
||||
message.processed_plain_text = "binary payload"
|
||||
message.display_message = "binary payload"
|
||||
message.raw_message = MessageSequence(
|
||||
components=[
|
||||
TextComponent("hello"),
|
||||
ImageComponent(binary_hash="", binary_data=b"image-bytes", content=""),
|
||||
VoiceComponent(binary_hash="", binary_data=b"voice-bytes", content=""),
|
||||
ReplyComponent(
|
||||
target_message_id="origin-1",
|
||||
target_message_content="origin text",
|
||||
target_message_sender_id="42",
|
||||
target_message_sender_nickname="alice",
|
||||
target_message_sender_cardname="Alice",
|
||||
),
|
||||
ForwardNodeComponent(
|
||||
forward_components=[
|
||||
ForwardComponent(
|
||||
user_nickname="bob",
|
||||
user_id="43",
|
||||
user_cardname="Bob",
|
||||
message_id="forward-1",
|
||||
content=[
|
||||
TextComponent("node-text"),
|
||||
ImageComponent(binary_hash="", binary_data=b"node-image", content=""),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
message_dict = PluginMessageUtils._session_message_to_dict(message)
|
||||
rebuilt_message = PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
|
||||
|
||||
image_component = rebuilt_message.raw_message.components[1]
|
||||
voice_component = rebuilt_message.raw_message.components[2]
|
||||
reply_component = rebuilt_message.raw_message.components[3]
|
||||
forward_component = rebuilt_message.raw_message.components[4]
|
||||
|
||||
assert isinstance(image_component, ImageComponent)
|
||||
assert image_component.binary_data == b"image-bytes"
|
||||
|
||||
assert isinstance(voice_component, VoiceComponent)
|
||||
assert voice_component.binary_data == b"voice-bytes"
|
||||
|
||||
assert isinstance(reply_component, ReplyComponent)
|
||||
assert reply_component.target_message_id == "origin-1"
|
||||
assert reply_component.target_message_content == "origin text"
|
||||
assert reply_component.target_message_sender_id == "42"
|
||||
assert reply_component.target_message_sender_nickname == "alice"
|
||||
assert reply_component.target_message_sender_cardname == "Alice"
|
||||
|
||||
assert isinstance(forward_component, ForwardNodeComponent)
|
||||
assert isinstance(forward_component.forward_components[0].content[1], ImageComponent)
|
||||
assert forward_component.forward_components[0].content[1].binary_data == b"node-image"
|
||||
File diff suppressed because it is too large
Load Diff
284
pytests/test_plugin_runtime_action_bridge.py
Normal file
284
pytests/test_plugin_runtime_action_bridge.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""核心组件查询层与插件运行时聚合测试。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import src.plugin_runtime.integration as integration_module
|
||||
|
||||
from src.core.types import ActionInfo, ToolInfo
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
|
||||
class _FakeRuntimeManager:
|
||||
"""测试用插件运行时管理器。"""
|
||||
|
||||
def __init__(self, supervisor: PluginSupervisor, plugin_id: str, plugin_config: dict[str, Any]) -> None:
|
||||
"""初始化测试用运行时管理器。
|
||||
|
||||
Args:
|
||||
supervisor: 持有测试组件的监督器。
|
||||
plugin_id: 目标插件 ID。
|
||||
plugin_config: 需要返回的插件配置。
|
||||
"""
|
||||
|
||||
self.supervisors = [supervisor]
|
||||
self._plugin_id = plugin_id
|
||||
self._plugin_config = plugin_config
|
||||
|
||||
def _get_supervisor_for_plugin(self, plugin_id: str) -> PluginSupervisor | None:
|
||||
"""按插件 ID 返回对应监督器。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
PluginSupervisor | None: 命中时返回监督器。
|
||||
"""
|
||||
|
||||
return self.supervisors[0] if plugin_id == self._plugin_id else None
|
||||
|
||||
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> dict[str, Any]:
|
||||
"""返回测试配置。
|
||||
|
||||
Args:
|
||||
supervisor: 监督器实例。
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 测试配置内容。
|
||||
"""
|
||||
|
||||
del supervisor
|
||||
if plugin_id != self._plugin_id:
|
||||
return {}
|
||||
return dict(self._plugin_config)
|
||||
|
||||
|
||||
def _install_runtime_manager(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
supervisor: PluginSupervisor,
|
||||
plugin_id: str,
|
||||
plugin_config: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""为测试安装假的运行时管理器。
|
||||
|
||||
Args:
|
||||
monkeypatch: pytest monkeypatch 对象。
|
||||
supervisor: 持有测试组件的监督器。
|
||||
plugin_id: 测试插件 ID。
|
||||
plugin_config: 可选的测试配置内容。
|
||||
"""
|
||||
|
||||
fake_manager = _FakeRuntimeManager(supervisor, plugin_id, plugin_config or {"enabled": True})
|
||||
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: fake_manager)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_core_component_registry_reads_runtime_action_and_executor(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""核心查询层应直接读取运行时 Action,并返回 RPC 执行闭包。"""
|
||||
|
||||
plugin_id = "runtime_action_bridge_plugin"
|
||||
action_name = "runtime_action_bridge_test"
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
supervisor.component_registry.register_component(
|
||||
name=action_name,
|
||||
component_type="ACTION",
|
||||
plugin_id=plugin_id,
|
||||
metadata={
|
||||
"description": "发送一个测试回复",
|
||||
"enabled": True,
|
||||
"activation_type": "keyword",
|
||||
"activation_probability": 0.25,
|
||||
"activation_keywords": ["测试", "hello"],
|
||||
"action_parameters": {"target": "目标对象"},
|
||||
"action_require": ["需要发送回复时使用"],
|
||||
"associated_types": ["text"],
|
||||
"parallel_action": True,
|
||||
},
|
||||
)
|
||||
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"enabled": True, "mode": "test"})
|
||||
|
||||
async def fake_invoke_plugin(
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟动作 RPC 调用。"""
|
||||
|
||||
captured["method"] = method
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(payload={"success": True, "result": (True, "runtime action executed")})
|
||||
|
||||
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
|
||||
|
||||
action_info = component_query_service.get_action_info(action_name)
|
||||
assert isinstance(action_info, ActionInfo)
|
||||
assert action_info.plugin_name == plugin_id
|
||||
assert action_info.description == "发送一个测试回复"
|
||||
assert action_info.activation_keywords == ["测试", "hello"]
|
||||
assert action_info.random_activation_probability == 0.25
|
||||
assert action_info.parallel_action is True
|
||||
assert action_name in component_query_service.get_default_actions()
|
||||
assert component_query_service.get_plugin_config(plugin_id) == {"enabled": True, "mode": "test"}
|
||||
|
||||
executor = component_query_service.get_action_executor(action_name)
|
||||
assert executor is not None
|
||||
|
||||
success, reason = await executor(
|
||||
action_data={"target": "MaiBot"},
|
||||
action_reasoning="当前适合使用这个动作",
|
||||
cycle_timers={"planner": 0.1},
|
||||
thinking_id="tid-1",
|
||||
chat_stream=SimpleNamespace(session_id="stream-1"),
|
||||
log_prefix="[test]",
|
||||
shutting_down=False,
|
||||
plugin_config={"enabled": True},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert reason == "runtime action executed"
|
||||
assert captured["method"] == "plugin.invoke_action"
|
||||
assert captured["plugin_id"] == plugin_id
|
||||
assert captured["component_name"] == action_name
|
||||
assert captured["args"]["stream_id"] == "stream-1"
|
||||
assert captured["args"]["chat_id"] == "stream-1"
|
||||
assert captured["args"]["reasoning"] == "当前适合使用这个动作"
|
||||
assert captured["args"]["target"] == "MaiBot"
|
||||
assert captured["args"]["action_data"] == {"target": "MaiBot"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_core_component_registry_reads_runtime_command_and_executor(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""核心查询层应直接使用运行时命令匹配与执行闭包。"""
|
||||
|
||||
plugin_id = "runtime_command_bridge_plugin"
|
||||
command_name = "runtime_command_bridge_test"
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
supervisor.component_registry.register_component(
|
||||
name=command_name,
|
||||
component_type="COMMAND",
|
||||
plugin_id=plugin_id,
|
||||
metadata={
|
||||
"description": "测试命令",
|
||||
"enabled": True,
|
||||
"command_pattern": r"^/test(?:\s+.+)?$",
|
||||
"aliases": ["/hello"],
|
||||
"intercept_message_level": 1,
|
||||
},
|
||||
)
|
||||
_install_runtime_manager(monkeypatch, supervisor, plugin_id, {"mode": "command"})
|
||||
|
||||
async def fake_invoke_plugin(
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟命令 RPC 调用。"""
|
||||
|
||||
captured["method"] = method
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(payload={"success": True, "result": (True, "command ok", True)})
|
||||
|
||||
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
|
||||
|
||||
matched = component_query_service.find_command_by_text("/test hello")
|
||||
assert matched is not None
|
||||
command_executor, matched_groups, command_info = matched
|
||||
|
||||
assert matched_groups == {}
|
||||
assert command_info.plugin_name == plugin_id
|
||||
assert command_info.command_pattern == r"^/test(?:\s+.+)?$"
|
||||
|
||||
success, response_text, intercept = await command_executor(
|
||||
message=SimpleNamespace(processed_plain_text="/test hello", session_id="stream-2"),
|
||||
plugin_config={"mode": "command"},
|
||||
matched_groups=matched_groups,
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert response_text == "command ok"
|
||||
assert intercept is True
|
||||
assert captured["method"] == "plugin.invoke_command"
|
||||
assert captured["plugin_id"] == plugin_id
|
||||
assert captured["component_name"] == command_name
|
||||
assert captured["args"]["text"] == "/test hello"
|
||||
assert captured["args"]["stream_id"] == "stream-2"
|
||||
assert captured["args"]["plugin_config"] == {"mode": "command"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_core_component_registry_reads_runtime_tools_and_executor(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""核心查询层应直接读取运行时 Tool,并返回 RPC 执行闭包。"""
|
||||
|
||||
plugin_id = "runtime_tool_bridge_plugin"
|
||||
tool_name = "runtime_tool_bridge_test"
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
|
||||
supervisor.component_registry.register_component(
|
||||
name=tool_name,
|
||||
component_type="TOOL",
|
||||
plugin_id=plugin_id,
|
||||
metadata={
|
||||
"description": "测试工具",
|
||||
"enabled": True,
|
||||
"parameters": [
|
||||
{
|
||||
"name": "query",
|
||||
"param_type": "string",
|
||||
"description": "查询词",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
_install_runtime_manager(monkeypatch, supervisor, plugin_id)
|
||||
|
||||
async def fake_invoke_plugin(
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟工具 RPC 调用。"""
|
||||
|
||||
del timeout_ms
|
||||
assert method == "plugin.invoke_tool"
|
||||
assert plugin_id == "runtime_tool_bridge_plugin"
|
||||
assert component_name == "runtime_tool_bridge_test"
|
||||
assert args == {"query": "MaiBot"}
|
||||
return SimpleNamespace(payload={"success": True, "result": {"content": "tool ok"}})
|
||||
|
||||
monkeypatch.setattr(supervisor, "invoke_plugin", fake_invoke_plugin)
|
||||
|
||||
tool_info = component_query_service.get_tool_info(tool_name)
|
||||
assert isinstance(tool_info, ToolInfo)
|
||||
assert tool_info.tool_description == "测试工具"
|
||||
assert tool_name in component_query_service.get_llm_available_tools()
|
||||
|
||||
executor = component_query_service.get_tool_executor(tool_name)
|
||||
assert executor is not None
|
||||
assert await executor({"query": "MaiBot"}) == {"content": "tool ok"}
|
||||
524
pytests/test_plugin_runtime_api.py
Normal file
524
pytests/test_plugin_runtime_api.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""插件 API 注册与调用测试。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from src.plugin_runtime.integration import PluginRuntimeManager
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
ComponentDeclaration,
|
||||
Envelope,
|
||||
MessageType,
|
||||
RegisterPluginPayload,
|
||||
UnregisterPluginPayload,
|
||||
)
|
||||
|
||||
|
||||
def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager:
|
||||
"""构造一个最小可用的插件运行时管理器。
|
||||
|
||||
Args:
|
||||
*supervisors: 需要挂载的监督器列表。
|
||||
|
||||
Returns:
|
||||
PluginRuntimeManager: 已注入监督器的运行时管理器。
|
||||
"""
|
||||
|
||||
manager = PluginRuntimeManager()
|
||||
if supervisors:
|
||||
manager._builtin_supervisor = supervisors[0]
|
||||
if len(supervisors) > 1:
|
||||
manager._third_party_supervisor = supervisors[1]
|
||||
return manager
|
||||
|
||||
|
||||
async def _register_plugin(
|
||||
supervisor: PluginSupervisor,
|
||||
plugin_id: str,
|
||||
components: List[Dict[str, Any]],
|
||||
) -> Envelope:
|
||||
"""通过 Supervisor 注册测试插件。
|
||||
|
||||
Args:
|
||||
supervisor: 目标监督器。
|
||||
plugin_id: 测试插件 ID。
|
||||
components: 组件声明列表。
|
||||
|
||||
Returns:
|
||||
Envelope: 注册响应信封。
|
||||
"""
|
||||
|
||||
payload = RegisterPluginPayload(
|
||||
plugin_id=plugin_id,
|
||||
plugin_version="1.0.0",
|
||||
components=[
|
||||
ComponentDeclaration(
|
||||
name=str(component.get("name", "") or ""),
|
||||
component_type=str(component.get("component_type", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
|
||||
)
|
||||
for component in components
|
||||
],
|
||||
)
|
||||
return await supervisor._handle_register_plugin(
|
||||
Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="plugin.register_components",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope:
|
||||
"""通过 Supervisor 注销测试插件。
|
||||
|
||||
Args:
|
||||
supervisor: 目标监督器。
|
||||
plugin_id: 测试插件 ID。
|
||||
|
||||
Returns:
|
||||
Envelope: 注销响应信封。
|
||||
"""
|
||||
|
||||
payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test")
|
||||
return await supervisor._handle_unregister_plugin(
|
||||
Envelope(
|
||||
request_id=2,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="plugin.unregister",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_plugin_syncs_dedicated_api_registry() -> None:
|
||||
"""插件注册时应将 API 同步到独立注册表,而不是通用组件表。"""
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
response = await _register_plugin(
|
||||
supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML",
|
||||
"version": "1",
|
||||
"public": True,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert response.payload["accepted"] is True
|
||||
assert response.payload["registered_components"] == 0
|
||||
assert response.payload["registered_apis"] == 1
|
||||
assert supervisor.api_registry.get_api("provider", "render_html") is not None
|
||||
assert supervisor.component_registry.get_component("provider.render_html") is None
|
||||
|
||||
unregister_response = await _unregister_plugin(supervisor, "provider")
|
||||
assert unregister_response.payload["removed_apis"] == 1
|
||||
assert supervisor.api_registry.get_api("provider", "render_html") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""公开 API 应允许其他插件通过 Host 转发调用。"""
|
||||
|
||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(
|
||||
provider_supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML",
|
||||
"version": "1",
|
||||
"public": True,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_invoke_api(
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟 API RPC 调用。"""
|
||||
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
|
||||
|
||||
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
|
||||
|
||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
||||
result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.render_html",
|
||||
"version": "1",
|
||||
"args": {"html": "<div>Hello</div>"},
|
||||
},
|
||||
)
|
||||
|
||||
assert result == {"success": True, "result": {"image": "ok"}}
|
||||
assert captured["plugin_id"] == "provider"
|
||||
assert captured["component_name"] == "render_html"
|
||||
assert captured["args"] == {"html": "<div>Hello</div>"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_call_rejects_private_api_between_plugins() -> None:
|
||||
"""未公开的 API 默认不允许跨插件调用。"""
|
||||
|
||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(
|
||||
provider_supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "secret_api",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "私有 API",
|
||||
"version": "1",
|
||||
"public": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
||||
|
||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
||||
result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.secret_api",
|
||||
"args": {},
|
||||
},
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "未公开" in str(result["error"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
|
||||
"""API 列表与组件启停应直接作用于独立 API 注册表。"""
|
||||
|
||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(
|
||||
provider_supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "public_api",
|
||||
"component_type": "API",
|
||||
"metadata": {"version": "1", "public": True},
|
||||
},
|
||||
{
|
||||
"name": "private_api",
|
||||
"component_type": "API",
|
||||
"metadata": {"version": "1", "public": False},
|
||||
},
|
||||
],
|
||||
)
|
||||
await _register_plugin(
|
||||
consumer_supervisor,
|
||||
"consumer",
|
||||
[
|
||||
{
|
||||
"name": "self_private_api",
|
||||
"component_type": "API",
|
||||
"metadata": {"version": "1", "public": False},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
||||
list_result = await manager._cap_api_list("consumer", "api.list", {})
|
||||
|
||||
assert list_result["success"] is True
|
||||
api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]}
|
||||
assert ("provider", "public_api") in api_names
|
||||
assert ("provider", "private_api") not in api_names
|
||||
assert ("consumer", "self_private_api") in api_names
|
||||
|
||||
disable_result = await manager._cap_component_disable(
|
||||
"consumer",
|
||||
"component.disable",
|
||||
{
|
||||
"name": "provider.public_api",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
},
|
||||
)
|
||||
assert disable_result["success"] is True
|
||||
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None
|
||||
|
||||
enable_result = await manager._cap_component_enable(
|
||||
"consumer",
|
||||
"component.enable",
|
||||
{
|
||||
"name": "provider.public_api",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
},
|
||||
)
|
||||
assert enable_result["success"] is True
|
||||
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_registry_supports_multiple_versions_with_distinct_handlers(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""同名 API 不同版本应可并存,并按版本路由到不同处理器。"""
|
||||
|
||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(
|
||||
provider_supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML v1",
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "handle_render_html_v1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML v2",
|
||||
"version": "2",
|
||||
"public": True,
|
||||
"handler_name": "handle_render_html_v2",
|
||||
},
|
||||
},
|
||||
],
|
||||
)
|
||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_invoke_api(
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟多版本 API 调用。"""
|
||||
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
|
||||
|
||||
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
|
||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
||||
|
||||
ambiguous_result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.render_html",
|
||||
"args": {"html": "<div>Hello</div>"},
|
||||
},
|
||||
)
|
||||
assert ambiguous_result["success"] is False
|
||||
assert "多个版本" in str(ambiguous_result["error"])
|
||||
|
||||
disable_ambiguous_result = await manager._cap_component_disable(
|
||||
"consumer",
|
||||
"component.disable",
|
||||
{
|
||||
"name": "provider.render_html",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
},
|
||||
)
|
||||
assert disable_ambiguous_result["success"] is False
|
||||
assert "多个版本" in str(disable_ambiguous_result["error"])
|
||||
|
||||
disable_v1_result = await manager._cap_component_disable(
|
||||
"consumer",
|
||||
"component.disable",
|
||||
{
|
||||
"name": "provider.render_html",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
"version": "1",
|
||||
},
|
||||
)
|
||||
assert disable_v1_result["success"] is True
|
||||
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None
|
||||
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None
|
||||
|
||||
result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.render_html",
|
||||
"version": "2",
|
||||
"args": {"html": "<div>Hello</div>"},
|
||||
},
|
||||
)
|
||||
|
||||
assert result == {"success": True, "result": {"image": "ok"}}
|
||||
assert captured["plugin_id"] == "provider"
|
||||
assert captured["component_name"] == "handle_render_html_v2"
|
||||
assert captured["args"] == {"html": "<div>Hello</div>"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_replace_dynamic_can_offline_removed_entries(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""动态 API 替换后,被移除的 API 应返回明确下线错误。"""
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(supervisor, "provider", [])
|
||||
manager = _build_manager(supervisor)
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_invoke_api(
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟动态 API 调用。"""
|
||||
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}})
|
||||
|
||||
monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api)
|
||||
|
||||
replace_result = await manager._cap_api_replace_dynamic(
|
||||
"provider",
|
||||
"api.replace_dynamic",
|
||||
{
|
||||
"apis": [
|
||||
{
|
||||
"name": "mcp.search",
|
||||
"type": "API",
|
||||
"metadata": {
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "dynamic_search",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "mcp.read",
|
||||
"type": "API",
|
||||
"metadata": {
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "dynamic_read",
|
||||
},
|
||||
},
|
||||
],
|
||||
"offline_reason": "MCP 服务器已关闭",
|
||||
},
|
||||
)
|
||||
|
||||
assert replace_result["success"] is True
|
||||
assert replace_result["count"] == 2
|
||||
list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
|
||||
assert {(item["name"], item["version"]) for item in list_result["apis"]} == {
|
||||
("mcp.read", "1"),
|
||||
("mcp.search", "1"),
|
||||
}
|
||||
|
||||
call_result = await manager._cap_api_call(
|
||||
"provider",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.mcp.search",
|
||||
"version": "1",
|
||||
"args": {"query": "hello"},
|
||||
},
|
||||
)
|
||||
assert call_result == {"success": True, "result": {"ok": True}}
|
||||
assert captured["component_name"] == "dynamic_search"
|
||||
assert captured["args"]["query"] == "hello"
|
||||
assert captured["args"]["__maibot_api_name__"] == "mcp.search"
|
||||
assert captured["args"]["__maibot_api_version__"] == "1"
|
||||
|
||||
second_replace_result = await manager._cap_api_replace_dynamic(
|
||||
"provider",
|
||||
"api.replace_dynamic",
|
||||
{
|
||||
"apis": [
|
||||
{
|
||||
"name": "mcp.read",
|
||||
"type": "API",
|
||||
"metadata": {
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "dynamic_read",
|
||||
},
|
||||
}
|
||||
],
|
||||
"offline_reason": "MCP 服务器已关闭",
|
||||
},
|
||||
)
|
||||
|
||||
assert second_replace_result["success"] is True
|
||||
assert second_replace_result["count"] == 1
|
||||
assert second_replace_result["offlined"] == 1
|
||||
|
||||
offlined_call_result = await manager._cap_api_call(
|
||||
"provider",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.mcp.search",
|
||||
"version": "1",
|
||||
"args": {},
|
||||
},
|
||||
)
|
||||
assert offlined_call_result["success"] is False
|
||||
assert "MCP 服务器已关闭" in str(offlined_call_result["error"])
|
||||
|
||||
list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
|
||||
assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == {
|
||||
("mcp.read", "1"),
|
||||
}
|
||||
154
pytests/test_send_service.py
Normal file
154
pytests/test_send_service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""发送服务回归测试。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.services import send_service
|
||||
|
||||
|
||||
class _FakePlatformIOManager:
|
||||
"""用于测试的 Platform IO 管理器假对象。"""
|
||||
|
||||
def __init__(self, delivery_batch: Any) -> None:
|
||||
"""初始化假 Platform IO 管理器。
|
||||
|
||||
Args:
|
||||
delivery_batch: 发送时返回的批量回执。
|
||||
"""
|
||||
self._delivery_batch = delivery_batch
|
||||
self.ensure_calls = 0
|
||||
self.sent_messages: List[Dict[str, Any]] = []
|
||||
|
||||
async def ensure_send_pipeline_ready(self) -> None:
|
||||
"""记录发送管线准备调用次数。"""
|
||||
self.ensure_calls += 1
|
||||
|
||||
def build_route_key_from_message(self, message: Any) -> Any:
|
||||
"""根据消息构造假的路由键。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
|
||||
Returns:
|
||||
Any: 简化后的路由键对象。
|
||||
"""
|
||||
del message
|
||||
return SimpleNamespace(platform="qq")
|
||||
|
||||
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
|
||||
"""记录发送请求并返回预设回执。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
route_key: 本次发送使用的路由键。
|
||||
metadata: 发送元数据。
|
||||
|
||||
Returns:
|
||||
Any: 预设的批量发送回执。
|
||||
"""
|
||||
self.sent_messages.append(
|
||||
{
|
||||
"message": message,
|
||||
"route_key": route_key,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return self._delivery_batch
|
||||
|
||||
|
||||
def _build_target_stream() -> BotChatSession:
|
||||
"""构造一个最小可用的目标会话对象。
|
||||
|
||||
Returns:
|
||||
BotChatSession: 测试用会话对象。
|
||||
"""
|
||||
return BotChatSession(
|
||||
session_id="test-session",
|
||||
platform="qq",
|
||||
user_id="target-user",
|
||||
group_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
|
||||
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
|
||||
|
||||
metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
|
||||
|
||||
assert metadata["platform_io_account_id"] == "bot-qq"
|
||||
assert metadata["platform_io_target_user_id"] == "target-user"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""send service 应将发送职责统一交给 Platform IO。"""
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=True,
|
||||
sent_receipts=[SimpleNamespace(driver_id="plugin.qq.sender")],
|
||||
failed_receipts=[],
|
||||
route_key=SimpleNamespace(platform="qq"),
|
||||
)
|
||||
)
|
||||
stored_messages: List[Any] = []
|
||||
|
||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
send_service.MessageUtils,
|
||||
"store_message_to_db",
|
||||
lambda message: stored_messages.append(message),
|
||||
)
|
||||
|
||||
result = await send_service.text_to_stream(text="你好", stream_id="test-session")
|
||||
|
||||
assert result is True
|
||||
assert fake_manager.ensure_calls == 1
|
||||
assert len(fake_manager.sent_messages) == 1
|
||||
assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False}
|
||||
assert len(stored_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Platform IO 批量发送全部失败时,应直接向上返回失败。"""
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=False,
|
||||
sent_receipts=[],
|
||||
failed_receipts=[
|
||||
SimpleNamespace(
|
||||
driver_id="plugin.qq.sender",
|
||||
status="failed",
|
||||
error="network error",
|
||||
)
|
||||
],
|
||||
route_key=SimpleNamespace(platform="qq"),
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
|
||||
result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
|
||||
|
||||
assert result is False
|
||||
assert fake_manager.ensure_calls == 1
|
||||
assert len(fake_manager.sent_messages) == 1
|
||||
115
pytests/utils_test/statistic_test.py
Normal file
115
pytests/utils_test/statistic_test.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""统计模块数据库会话行为测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from src.chat.utils import statistic
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
"""模拟 SQLModel 查询结果对象。"""
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
"""返回空结果集。
|
||||
|
||||
Returns:
|
||||
list[Any]: 空列表。
|
||||
"""
|
||||
return []
|
||||
|
||||
|
||||
class _DummySession:
|
||||
"""模拟数据库 Session。"""
|
||||
|
||||
def exec(self, statement: Any) -> _DummyResult:
|
||||
"""执行查询语句并返回空结果。
|
||||
|
||||
Args:
|
||||
statement: 待执行的查询语句。
|
||||
|
||||
Returns:
|
||||
_DummyResult: 空结果对象。
|
||||
"""
|
||||
del statement
|
||||
return _DummyResult()
|
||||
|
||||
|
||||
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
|
||||
"""构造一个记录 auto_commit 参数的假会话工厂。
|
||||
|
||||
Args:
|
||||
calls: 用于记录每次调用 auto_commit 参数的列表。
|
||||
|
||||
Returns:
|
||||
Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
|
||||
"""
|
||||
|
||||
@contextmanager
|
||||
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
|
||||
"""记录会话参数并返回假 Session。
|
||||
|
||||
Args:
|
||||
auto_commit: 是否启用自动提交。
|
||||
|
||||
Yields:
|
||||
Iterator[_DummySession]: 假 Session 对象。
|
||||
"""
|
||||
calls.append(auto_commit)
|
||||
yield _DummySession()
|
||||
|
||||
return _fake_get_db_session
|
||||
|
||||
|
||||
def _build_statistic_task() -> statistic.StatisticOutputTask:
|
||||
"""构造一个最小可用的统计任务实例。
|
||||
|
||||
Returns:
|
||||
statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
|
||||
"""
|
||||
task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
|
||||
task.name_mapping = {}
|
||||
return task
|
||||
|
||||
|
||||
def _is_bot_self(platform: str, user_id: str) -> bool:
|
||||
"""返回固定的非机器人身份判断结果。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
user_id: 用户 ID。
|
||||
|
||||
Returns:
|
||||
bool: 始终返回 ``False``。
|
||||
"""
|
||||
del platform
|
||||
del user_id
|
||||
return False
|
||||
|
||||
|
||||
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
|
||||
calls: list[bool] = []
|
||||
now = datetime.now()
|
||||
task = _build_statistic_task()
|
||||
|
||||
monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
|
||||
|
||||
utils_module = ModuleType("src.chat.utils.utils")
|
||||
utils_module.is_bot_self = _is_bot_self
|
||||
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
|
||||
|
||||
statistic.StatisticOutputTask._fetch_online_time_since(now)
|
||||
statistic.StatisticOutputTask._fetch_model_usage_since(now)
|
||||
task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
|
||||
task._collect_interval_data(now, hours=1, interval_minutes=60)
|
||||
task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
|
||||
|
||||
assert calls == [False] * 9
|
||||
42
pytests/utils_test/test_session_utils.py
Normal file
42
pytests/utils_test/test_session_utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.chat.message_receive.chat_manager import ChatManager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
|
||||
def test_calculate_session_id_distinguishes_account_and_scope() -> None:
|
||||
base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
||||
same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
||||
account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
|
||||
route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
|
||||
|
||||
assert base_session_id == same_base_session_id
|
||||
assert account_scoped_session_id != base_session_id
|
||||
assert route_scoped_session_id != account_scoped_session_id
|
||||
|
||||
|
||||
def test_chat_manager_register_message_uses_route_metadata() -> None:
|
||||
chat_manager = ChatManager()
|
||||
message = SimpleNamespace(
|
||||
platform="qq",
|
||||
session_id="",
|
||||
message_info=SimpleNamespace(
|
||||
user_info=SimpleNamespace(user_id="42"),
|
||||
group_info=SimpleNamespace(group_id="1000"),
|
||||
additional_config={
|
||||
"platform_io_account_id": "123",
|
||||
"platform_io_scope": "main",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
chat_manager.register_message(message)
|
||||
|
||||
assert message.session_id == SessionUtils.calculate_session_id(
|
||||
"qq",
|
||||
user_id="42",
|
||||
group_id="1000",
|
||||
account_id="123",
|
||||
scope="main",
|
||||
)
|
||||
assert chat_manager.last_messages[message.session_id] is message
|
||||
Reference in New Issue
Block a user