Merge remote-tracking branch 'upstream/r-dev' into sync/pr-1564-upstream-20260331

# Conflicts:
#	src/chat/brain_chat/PFC/conversation.py
#	src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py
#	src/chat/knowledge/lpmm_ops.py
This commit is contained in:
A-Dawn
2026-03-31 10:43:55 +08:00
179 changed files with 21829 additions and 20118 deletions

View File

@@ -0,0 +1,924 @@
"""数据库迁移基础设施测试。"""
from pathlib import Path
from typing import List, Optional, Tuple
from sqlalchemy import text
from sqlalchemy.engine import Connection, Engine
from sqlmodel import SQLModel, create_engine
import json
import msgpack
import pytest
from src.common.database import database as database_module
from src.common.database.migrations import (
BaseSchemaVersionDetector,
BaseMigrationProgressReporter,
DatabaseSchemaSnapshot,
DatabaseMigrationBootstrapper,
DatabaseMigrationState,
DatabaseMigrationManager,
EMPTY_SCHEMA_VERSION,
LATEST_SCHEMA_VERSION,
LEGACY_V1_SCHEMA_VERSION,
MigrationExecutionContext,
MigrationPlan,
MigrationRegistry,
MigrationStep,
ResolvedSchemaVersion,
SchemaVersionResolver,
SchemaVersionSource,
SQLiteSchemaInspector,
SQLiteUserVersionStore,
build_default_migration_registry,
build_default_schema_version_resolver,
create_database_migration_bootstrapper,
)
class FixedVersionDetector(BaseSchemaVersionDetector):
    """Detector that reports a fixed version, used only in tests."""

    @property
    def name(self) -> str:
        """Return the test detector's name.

        Returns:
            str: Detector name.
        """
        return "fixed_version_detector"

    def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
        """Report a fixed version whenever the marker table is present.

        Args:
            snapshot: Snapshot of the current database schema.

        Returns:
            Optional[int]: ``2`` when the ``legacy_records`` table exists,
                otherwise ``None``.
        """
        return 2 if snapshot.has_table("legacy_records") else None
class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
    """Progress reporter that records every callback for later inspection."""

    def __init__(self) -> None:
        """Create the reporter with an empty event log."""
        self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = []

    def _record(
        self,
        kind: str,
        first: Optional[int],
        second: Optional[int],
        label: Optional[str],
    ) -> None:
        # Internal helper: every public callback funnels through here.
        self.events.append((kind, first, second, label))

    def open(self) -> None:
        """Record the open event."""
        self._record("open", None, None, None)

    def close(self) -> None:
        """Record the close event."""
        self._record("close", None, None, None)

    def start(
        self,
        total_records: int,
        total_tables: int,
        description: str = "总迁移进度",
        table_unit_name: str = "",
        record_unit_name: str = "记录",
    ) -> None:
        """Record the start event.

        Args:
            total_records: Total record count for the task.
            total_tables: Total table count for the task.
            description: Task description.
            table_unit_name: Unit label for table-level progress.
                NOTE(review): default looks like it lost a CJK char in
                transit — confirm against the real interface.
            record_unit_name: Unit label for record-level progress.
        """
        del table_unit_name, record_unit_name
        self._record("start", total_records, total_tables, description)

    def advance(
        self,
        records: int = 0,
        completed_tables: int = 0,
        item_name: Optional[str] = None,
    ) -> None:
        """Record an advance event.

        Args:
            records: Number of records advanced.
            completed_tables: Number of tables completed.
            item_name: Name of the item just finished.
        """
        self._record("advance", records, completed_tables, item_name)
def _create_sqlite_engine(database_file: Path) -> Engine:
    """Build a SQLite engine for the given test database file.

    Args:
        database_file: Path of the test database file.

    Returns:
        Engine: Freshly created SQLite engine.
    """
    database_url = f"sqlite:///{database_file}"
    # check_same_thread=False lets tests share the connection across threads.
    return create_engine(database_url, echo=False, connect_args={"check_same_thread": False})
def _create_current_schema(connection: Connection) -> None:
    """Create the latest-version database schema on the given connection.

    Args:
        connection: Active database connection.
    """
    # Importing the model module registers all tables on SQLModel.metadata
    # as a side effect; the import itself is the point (hence noqa: F401).
    import src.common.database.database_model  # noqa: F401
    SQLModel.metadata.create_all(connection)
def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
    """Create the legacy ``0.x`` schema populated with sample data.

    Creates the five legacy tables (``chat_streams``, ``messages``,
    ``action_records``, ``expression``, ``jargon``) and inserts one
    representative row into each so migration tests can verify the data
    is carried over.

    Args:
        connection: Active database connection.
    """
    # --- legacy table definitions ---
    connection.execute(
        text(
            """
            CREATE TABLE chat_streams (
                id INTEGER PRIMARY KEY,
                stream_id TEXT NOT NULL,
                create_time REAL NOT NULL,
                last_active_time REAL NOT NULL,
                platform TEXT NOT NULL,
                user_id TEXT,
                group_id TEXT,
                group_name TEXT
            )
            """
        )
    )
    connection.execute(
        text(
            """
            CREATE TABLE messages (
                id INTEGER PRIMARY KEY,
                message_id TEXT NOT NULL,
                time REAL NOT NULL,
                chat_id TEXT NOT NULL,
                chat_info_platform TEXT,
                user_id TEXT,
                user_nickname TEXT,
                chat_info_group_id TEXT,
                chat_info_group_name TEXT,
                is_mentioned INTEGER,
                is_at INTEGER,
                processed_plain_text TEXT,
                display_message TEXT,
                is_emoji INTEGER,
                is_picid INTEGER,
                is_command INTEGER,
                is_notify INTEGER,
                additional_config TEXT,
                priority_mode TEXT
            )
            """
        )
    )
    connection.execute(
        text(
            """
            CREATE TABLE action_records (
                id INTEGER PRIMARY KEY,
                action_id TEXT NOT NULL,
                time REAL NOT NULL,
                action_reasoning TEXT,
                action_name TEXT NOT NULL,
                action_data TEXT,
                action_prompt_display TEXT,
                chat_id TEXT
            )
            """
        )
    )
    connection.execute(
        text(
            """
            CREATE TABLE expression (
                id INTEGER PRIMARY KEY,
                situation TEXT NOT NULL,
                style TEXT NOT NULL,
                content_list TEXT,
                count INTEGER,
                last_active_time REAL NOT NULL,
                chat_id TEXT,
                create_date REAL,
                checked INTEGER,
                rejected INTEGER,
                modified_by TEXT
            )
            """
        )
    )
    connection.execute(
        text(
            """
            CREATE TABLE jargon (
                id INTEGER PRIMARY KEY,
                content TEXT NOT NULL,
                raw_content TEXT,
                meaning TEXT,
                chat_id TEXT,
                is_global INTEGER,
                count INTEGER,
                is_jargon INTEGER,
                last_inference_count INTEGER,
                is_complete INTEGER,
                inference_with_context TEXT,
                inference_content_only TEXT
            )
            """
        )
    )
    # --- one sample row per legacy table ---
    connection.execute(
        text(
            """
            INSERT INTO chat_streams (
                id,
                stream_id,
                create_time,
                last_active_time,
                platform,
                user_id,
                group_id,
                group_name
            ) VALUES (
                1,
                'session-1',
                1710000000.0,
                1710000300.0,
                'qq',
                'user-1',
                'group-1',
                '测试群'
            )
            """
        )
    )
    connection.execute(
        text(
            """
            INSERT INTO messages (
                id,
                message_id,
                time,
                chat_id,
                chat_info_platform,
                user_id,
                user_nickname,
                chat_info_group_id,
                chat_info_group_name,
                is_mentioned,
                is_at,
                processed_plain_text,
                display_message,
                is_emoji,
                is_picid,
                is_command,
                is_notify,
                additional_config,
                priority_mode
            ) VALUES (
                1,
                'msg-1',
                1710000010.0,
                'session-1',
                'qq',
                'user-1',
                '测试用户',
                'group-1',
                '测试群',
                1,
                0,
                '你好',
                '你好呀',
                0,
                1,
                0,
                1,
                '{"source":"legacy"}',
                'high'
            )
            """
        )
    )
    connection.execute(
        text(
            """
            INSERT INTO action_records (
                id,
                action_id,
                time,
                action_reasoning,
                action_name,
                action_data,
                action_prompt_display,
                chat_id
            ) VALUES (
                1,
                'action-1',
                1710000020.0,
                '需要调用工具',
                'search',
                '{"query":"MaiBot"}',
                '执行搜索',
                'session-1'
            )
            """
        )
    )
    connection.execute(
        text(
            """
            INSERT INTO expression (
                id,
                situation,
                style,
                content_list,
                count,
                last_active_time,
                chat_id,
                create_date,
                checked,
                rejected,
                modified_by
            ) VALUES (
                1,
                '打招呼',
                '可爱',
                '["你好呀","早上好"]',
                3,
                1710000030.0,
                'session-1',
                1710000040.0,
                1,
                0,
                'ai'
            )
            """
        )
    )
    connection.execute(
        text(
            """
            INSERT INTO jargon (
                id,
                content,
                raw_content,
                meaning,
                chat_id,
                is_global,
                count,
                is_jargon,
                last_inference_count,
                is_complete,
                inference_with_context,
                inference_content_only
            ) VALUES (
                1,
                '上分',
                '["上分"]',
                '提高排名',
                'session-1',
                0,
                5,
                1,
                2,
                1,
                '{"guess":"context"}',
                '{"guess":"content"}'
            )
            """
        )
    )
def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
    """SQLite ``user_version`` should be both readable and writable."""
    engine = _create_sqlite_engine(tmp_path / "version_store.db")
    store = SQLiteUserVersionStore()
    with engine.begin() as connection:
        # A fresh database reports user_version 0.
        assert store.read_version(connection) == 0
        store.write_version(connection, 7)
    with engine.connect() as connection:
        # The written value survives across connections.
        assert store.read_version(connection) == 7
def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
    """The inspector should expose a SQLite database's tables and columns."""
    engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
    inspector = SQLiteSchemaInspector()
    ddl = """
        CREATE TABLE legacy_records (
            id INTEGER PRIMARY KEY,
            payload TEXT NOT NULL,
            created_at TEXT
        )
    """
    with engine.begin() as connection:
        connection.execute(text(ddl))
    with engine.connect() as connection:
        snapshot = inspector.inspect(connection)
    assert snapshot.has_table("legacy_records")
    assert snapshot.has_column("legacy_records", "payload")
    assert not snapshot.has_column("legacy_records", "missing_column")
    table_schema = snapshot.get_table("legacy_records")
    assert table_schema is not None
    # Column names are reported in sorted order.
    assert table_schema.column_names() == ["created_at", "id", "payload"]
def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
    """An empty database should resolve to schema version ``0``."""
    engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
    resolver = SchemaVersionResolver()
    with engine.connect() as connection:
        resolved = resolver.resolve(connection)
    assert resolved.version == 0
    assert resolved.source == SchemaVersionSource.EMPTY_DATABASE
    assert resolved.snapshot is not None
    assert resolved.snapshot.is_empty()
def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
    """Historic databases lacking ``user_version`` should be identified via detectors."""
    engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
    resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])
    with engine.begin() as connection:
        connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))
    with engine.connect() as connection:
        resolved = resolver.resolve(connection)
    # The marker table makes FixedVersionDetector report version 2.
    assert resolved.version == 2
    assert resolved.source == SchemaVersionSource.DETECTOR
    assert resolved.detector_name == "fixed_version_detector"
def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
    """The manager should run registered steps in order and update the version."""
    engine = _create_sqlite_engine(tmp_path / "manager.db")
    executed: List[str] = []

    def migrate_0_to_1(context: MigrationExecutionContext) -> None:
        """Test migration step 0 -> 1: create the sample table."""
        executed.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
        context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))

    def migrate_1_to_2(context: MigrationExecutionContext) -> None:
        """Test migration step 1 -> 2: add the email column."""
        executed.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
        context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))

    registry = MigrationRegistry(
        steps=[
            MigrationStep(
                version_from=0,
                version_to=1,
                name="create_sample_records",
                description="创建示例表。",
                handler=migrate_0_to_1,
            ),
            MigrationStep(
                version_from=1,
                version_to=2,
                name="add_sample_email",
                description="为示例表增加邮箱字段。",
                handler=migrate_1_to_2,
            ),
        ]
    )
    manager = DatabaseMigrationManager(engine=engine, registry=registry)
    plan = manager.migrate()
    assert plan.step_count() == 2
    # target_version in the context is the plan's overall target (2).
    assert executed == ["0->2:step_0_to_1", "1->2:step_1_to_2"]
    with engine.connect() as connection:
        snapshot = SQLiteSchemaInspector().inspect(connection)
        recorded = SQLiteUserVersionStore().read_version(connection)
    assert recorded == 2
    assert snapshot.has_table("sample_records")
    assert snapshot.has_column("sample_records", "email")
def test_manager_can_report_step_progress(tmp_path: Path) -> None:
    """The manager should forward step progress through the execution context."""
    engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
    reporters: List[FakeMigrationProgressReporter] = []

    def _build_reporter() -> BaseMigrationProgressReporter:
        """Build and remember a fake progress reporter."""
        fake = FakeMigrationProgressReporter()
        reporters.append(fake)
        return fake

    def migrate_1_to_2(context: MigrationExecutionContext) -> None:
        """Test step ``1 -> 2`` exercising the progress-reporting API."""
        context.start_progress(total_tables=3, total_records=30, description="总迁移进度")
        context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions")
        context.advance_progress(records=10, completed_tables=1, item_name="mai_messages")
        context.advance_progress(records=10, completed_tables=1, item_name="tool_records")
        context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))

    # Start from version 1 so only the 1 -> 2 step is planned.
    with engine.begin() as connection:
        SQLiteUserVersionStore().write_version(connection, 1)
    registry = MigrationRegistry(
        steps=[
            MigrationStep(
                version_from=1,
                version_to=2,
                name="progress_step",
                description="测试进度上报。",
                handler=migrate_1_to_2,
            )
        ]
    )
    manager = DatabaseMigrationManager(
        engine=engine,
        registry=registry,
        progress_reporter_factory=_build_reporter,
    )
    plan = manager.migrate()
    assert plan.step_count() == 1
    assert len(reporters) == 1
    assert reporters[0].events == [
        ("open", None, None, None),
        ("start", 30, 3, "总迁移进度"),
        ("advance", 10, 1, "chat_sessions"),
        ("advance", 10, 1, "mai_messages"),
        ("advance", 10, 1, "tool_records"),
        ("close", None, None, None),
    ]
def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
    """The default resolver should recognize an up-to-date database with no version stamp."""
    engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
    resolver = build_default_schema_version_resolver()
    with engine.begin() as connection:
        _create_current_schema(connection)
    with engine.connect() as connection:
        resolved = resolver.resolve(connection)
    assert resolved.version == LATEST_SCHEMA_VERSION
    assert resolved.source == SchemaVersionSource.DETECTOR
    assert resolved.detector_name == "latest_schema_detector"
def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
    """The default resolver should recognize an unstamped legacy ``0.x`` database."""
    engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
    resolver = build_default_schema_version_resolver()
    with engine.begin() as connection:
        _create_legacy_v1_schema_with_sample_data(connection)
    with engine.connect() as connection:
        resolved = resolver.resolve(connection)
    assert resolved.version == LEGACY_V1_SCHEMA_VERSION
    assert resolved.source == SchemaVersionSource.DETECTOR
    assert resolved.detector_name == "legacy_v1_schema_detector"
def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
    """A latest-schema database without a stamp should just get ``user_version`` written."""
    engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
    bootstrapper = create_database_migration_bootstrapper(engine)
    with engine.begin() as connection:
        _create_current_schema(connection)
    state = bootstrapper.prepare_database()
    bootstrapper.finalize_database(state)
    assert not state.requires_migration()
    assert state.resolved_version.version == LATEST_SCHEMA_VERSION
    assert state.resolved_version.source == SchemaVersionSource.DETECTOR
    with engine.connect() as connection:
        recorded = SQLiteUserVersionStore().read_version(connection)
    assert recorded == LATEST_SCHEMA_VERSION
def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
    """An empty database should get the latest ``user_version`` after tables are created."""
    engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
    bootstrapper = create_database_migration_bootstrapper(engine)
    state = bootstrapper.prepare_database()
    assert not state.requires_migration()
    assert state.resolved_version.version == EMPTY_SCHEMA_VERSION
    assert state.target_version == LATEST_SCHEMA_VERSION
    # Simulate the normal startup sequence: create tables, then finalize.
    with engine.begin() as connection:
        _create_current_schema(connection)
    bootstrapper.finalize_database(state)
    with engine.connect() as connection:
        recorded = SQLiteUserVersionStore().read_version(connection)
    assert recorded == LATEST_SCHEMA_VERSION
def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
    """The bootstrapper should execute registered steps on a version-stamped database."""
    engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
    marks: List[str] = []

    def migrate_1_to_2(context: MigrationExecutionContext) -> None:
        """Bootstrapper test migration step ``1 -> 2``."""
        marks.append(f"step={context.step_name},index={context.step_index}")
        context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))

    with engine.begin() as connection:
        connection.execute(
            text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
        )
        SQLiteUserVersionStore().write_version(connection, 1)
    registry = MigrationRegistry(
        steps=[
            MigrationStep(
                version_from=1,
                version_to=2,
                name="bootstrap_add_email",
                description="为桥接器测试表增加邮箱字段。",
                handler=migrate_1_to_2,
            )
        ]
    )
    bootstrapper = DatabaseMigrationBootstrapper(
        manager=DatabaseMigrationManager(engine=engine, registry=registry),
        latest_schema_version=2,
    )
    state = bootstrapper.prepare_database()
    assert state.resolved_version.version == 2
    assert state.target_version == 2
    assert marks == ["step=bootstrap_add_email,index=1"]
    with engine.connect() as connection:
        snapshot = SQLiteSchemaInspector().inspect(connection)
        recorded = SQLiteUserVersionStore().read_version(connection)
    assert recorded == 2
    assert snapshot.has_table("bootstrap_records")
    assert snapshot.has_column("bootstrap_records", "email")
def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
    """The default bootstrapper should migrate a legacy ``0.x`` database end-to-end."""
    engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
    bootstrapper = create_database_migration_bootstrapper(engine)
    with engine.begin() as connection:
        _create_legacy_v1_schema_with_sample_data(connection)
    migration_state = bootstrapper.prepare_database()
    bootstrapper.finalize_database(migration_state)
    assert not migration_state.requires_migration()
    assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
    assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA
    with engine.connect() as connection:
        recorded_version = SQLiteUserVersionStore().read_version(connection)
        snapshot = SQLiteSchemaInspector().inspect(connection)
        # Fetch one migrated row from each target table for verification.
        message_row = connection.execute(
            text(
                """
                SELECT session_id, processed_plain_text, additional_config, raw_content
                FROM mai_messages
                WHERE message_id = 'msg-1'
                """
            )
        ).mappings().one()
        action_row = connection.execute(
            text(
                """
                SELECT session_id, action_name, action_display_prompt
                FROM action_records
                WHERE action_id = 'action-1'
                """
            )
        ).mappings().one()
        tool_row = connection.execute(
            text(
                """
                SELECT session_id, tool_name, tool_display_prompt
                FROM tool_records
                WHERE tool_id = 'action-1'
                """
            )
        ).mappings().one()
        expression_row = connection.execute(
            text(
                """
                SELECT session_id, content_list, modified_by
                FROM expressions
                WHERE id = 1
                """
            )
        ).mappings().one()
        jargon_row = connection.execute(
            text(
                """
                SELECT session_id_dict, raw_content, inference_with_content_only
                FROM jargons
                WHERE id = 1
                """
            )
        ).mappings().one()
    assert recorded_version == LATEST_SCHEMA_VERSION
    # The original messages table is kept as a renamed backup.
    assert snapshot.has_table("__legacy_v1_messages")
    assert snapshot.has_table("chat_sessions")
    assert snapshot.has_table("mai_messages")
    assert snapshot.has_table("tool_records")
    # raw_content is msgpack-encoded; JSON columns are decoded for comparison.
    unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
    additional_config = json.loads(message_row["additional_config"])
    expression_content_list = json.loads(expression_row["content_list"])
    jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
    jargon_raw_content = json.loads(jargon_row["raw_content"])
    assert message_row["session_id"] == "session-1"
    assert message_row["processed_plain_text"] == "你好"
    assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
    # priority_mode is folded into additional_config during migration.
    assert additional_config == {"priority_mode": "high", "source": "legacy"}
    assert action_row["session_id"] == "session-1"
    assert action_row["action_name"] == "search"
    assert action_row["action_display_prompt"] == "执行搜索"
    assert tool_row["session_id"] == "session-1"
    assert tool_row["tool_name"] == "search"
    assert tool_row["tool_display_prompt"] == "执行搜索"
    assert expression_row["session_id"] == "session-1"
    # modified_by values are normalized to upper case by the migration.
    assert expression_row["modified_by"] == "AI"
    assert expression_content_list == ["你好呀", "早上好"]
    assert jargon_session_id_dict == {"session-1": 5}
    assert jargon_raw_content == ["上分"]
    assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'
def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
    """The legacy migration step should advance overall progress per target table."""
    engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
    reporters: List[FakeMigrationProgressReporter] = []

    def _build_reporter() -> BaseMigrationProgressReporter:
        """Build and remember a fake progress reporter."""
        fake = FakeMigrationProgressReporter()
        reporters.append(fake)
        return fake

    with engine.begin() as connection:
        _create_legacy_v1_schema_with_sample_data(connection)
    manager = DatabaseMigrationManager(
        engine=engine,
        registry=build_default_migration_registry(),
        resolver=build_default_schema_version_resolver(),
        progress_reporter_factory=_build_reporter,
    )
    plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)
    assert plan.step_count() == 1
    assert len(reporters) == 1
    events = reporters[0].events
    assert events[0] == ("open", None, None, None)
    assert events[1] == ("start", 6, 12, "总迁移进度")
    assert events[-1] == ("close", None, None, None)
    # Six sample rows -> six record-level advances.
    assert events.count(("advance", 1, 0, None)) == 6
    assert events.count(("advance", 0, 1, "chat_sessions")) == 1
    assert events.count(("advance", 0, 1, "thinking_questions")) == 1
    assert sum(1 for event in events if event[0] == "advance") == 18
def test_initialize_database_calls_bootstrapper_before_create_all(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """The init entry point should prepare migrations first, then create tables,
    run the lightweight follow-up migration, and finalize last."""
    call_order: List[str] = []

    def _fake_prepare_database() -> DatabaseMigrationState:
        """Return a test migration state.

        Returns:
            DatabaseMigrationState: State containing no migration steps.
        """
        call_order.append("prepare_database")
        return DatabaseMigrationState(
            resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE),
            target_version=LATEST_SCHEMA_VERSION,
            plan=MigrationPlan(
                current_version=EMPTY_SCHEMA_VERSION,
                target_version=LATEST_SCHEMA_VERSION,
                steps=[],
            ),
        )

    def _fake_create_all(bind) -> None:
        """Record the table-creation call.

        Args:
            bind: Database bind object passed through (unused).
        """
        del bind
        call_order.append("create_all")

    def _fake_migrate_action_records() -> None:
        """Record the lightweight follow-up migration call."""
        call_order.append("migrate_action_records")

    def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
        """Record the finalization call.

        Args:
            migration_state: Current database migration state (unused).
        """
        del migration_state
        call_order.append("finalize_database")

    # Reset module-level state so initialize_database runs from scratch,
    # and stub out every side-effecting collaborator.
    monkeypatch.setattr(database_module, "_db_initialized", False)
    monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
    monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
    monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
    monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)
    monkeypatch.setattr(database_module, "_migrate_action_records_to_tool_records", _fake_migrate_action_records)
    database_module.initialize_database()
    assert call_order == [
        "prepare_database",
        "create_all",
        "migrate_action_records",
        "finalize_database",
    ]

View File

@@ -0,0 +1,55 @@
from datetime import datetime
from pathlib import Path
import sys

# Make the project root importable BEFORE any ``src`` import. In the
# original, this block ran AFTER the project imports, where it could
# never take effect for them.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.chat.message_receive.message import SessionMessage  # noqa: E402
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent  # noqa: E402
from src.llm_models.payload_content.tool_option import ToolCall  # noqa: E402
from src.maisaka.message_adapter import (  # noqa: E402
    build_message,
    get_message_kind,
    get_message_role,
    get_tool_call_id,
    get_tool_calls,
)
def test_build_message_returns_session_message_with_maisaka_metadata() -> None:
    """``build_message`` should yield a SessionMessage carrying maisaka metadata."""
    created_at = datetime.now()
    call = ToolCall(
        call_id="call-1",
        func_name="reply",
        args={"message_id": "msg-1"},
    )
    inner_sequence = MessageSequence(components=[TextComponent(text="内部消息内容")])
    message = build_message(
        role="assistant",
        content="展示消息内容",
        message_kind="perception",
        source="assistant",
        tool_call_id="call-1",
        tool_calls=[call],
        timestamp=created_at,
        message_id="maisaka-msg-1",
        raw_message=inner_sequence,
        display_text="展示消息内容",
    )
    assert isinstance(message, SessionMessage)
    assert message.initialized is True
    assert message.message_id == "maisaka-msg-1"
    assert message.timestamp == created_at
    assert message.processed_plain_text == "展示消息内容"
    assert message.display_message == "展示消息内容"
    # raw_message must be attached by reference, not copied.
    assert message.raw_message is inner_sequence
    assert get_message_role(message) == "assistant"
    assert get_message_kind(message) == "perception"
    assert get_tool_call_id(message) == "call-1"
    restored = get_tool_calls(message)
    assert len(restored) == 1
    assert restored[0].call_id == "call-1"
    assert restored[0].func_name == "reply"
    assert restored[0].args == {"message_id": "msg-1"}

View File

@@ -5,6 +5,7 @@
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Awaitable, Callable, Dict, List, Optional
import asyncio
import json
@@ -1831,395 +1832,445 @@ class TestMaiMessages:
assert msg.llm_response_content == "new response"
# ─── WorkflowExecutor 测试 ────────────────────────────────
class _FakeHookSupervisor:
    """Minimal supervisor stand-in used for hook-dispatch tests."""

    def __init__(
        self,
        group_name: str,
        component_registry: Any,
        handlers: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]] | Dict[str, Any]]],
        call_log: List[tuple[str, str]],
    ) -> None:
        """Initialize the fake supervisor.

        Args:
            group_name: Runtime group name.
            component_registry: Component registry instance.
            handlers: Handler map keyed by ``plugin_id.component_name``.
            call_log: List recording the invocation order.
        """
        self._group_name = group_name
        self.component_registry = component_registry
        self._handlers = handlers
        self._call_log = call_log

    @property
    def group_name(self) -> str:
        """Return this fake supervisor's group name."""
        return self._group_name

    async def invoke_plugin(
        self,
        method: str,
        plugin_id: str,
        component_name: str,
        args: Optional[Dict[str, Any]] = None,
        timeout_ms: int = 30000,
    ) -> SimpleNamespace:
        """Simulate invoking a plugin component.

        Args:
            method: RPC method name (ignored here).
            plugin_id: Target plugin id.
            component_name: Target component name.
            args: Call arguments.
            timeout_ms: Timeout; unused, kept only for interface parity.

        Returns:
            SimpleNamespace: Simplified response carrying only ``payload``.
        """
        del method
        del timeout_ms
        full_name = f"{plugin_id}.{component_name}"
        handler = self._handlers[full_name]
        self._call_log.append((plugin_id, component_name))
        result = handler(dict(args or {}))
        # Handlers may be sync or async; await coroutine results transparently.
        if asyncio.iscoroutine(result):
            result = await result
        return SimpleNamespace(payload=result)
class TestWorkflowExecutor:
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
# ─── HookDispatcher 测试 ────────────────────────────────
class TestHookDispatcher:
"""命名 Hook 分发器测试。"""
    @staticmethod
    def _import_dispatcher_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
        """Import hook-dispatch modules while masking config-triggered exits.

        Args:
            monkeypatch: pytest's monkeypatch helper.

        Returns:
            tuple[Any, Any]: The ``ComponentRegistry`` and ``HookDispatcher`` types.
        """
        # Importing these modules can trigger config initialization that
        # calls sys.exit; neutralize it so the import succeeds under test.
        monkeypatch.setattr(sys, "exit", lambda code=0: None)
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.hook_dispatcher import HookDispatcher
        return ComponentRegistry, HookDispatcher
@pytest.mark.asyncio
async def test_empty_pipeline_completes(self):
"""无任何 workflow_step 注册时pipeline 全阶段跳过,状态 completed"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_empty_hook_returns_original_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""未注册处理器时应直接返回原始参数。"""
reg = ComponentRegistry()
executor = WorkflowExecutor(reg)
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
async def mock_invoke(plugin_id, comp_name, args):
return {"hook_result": "continue"}
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor("builtin", ComponentRegistry(), {}, [])
result, final_msg, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "test"},
)
assert result.status == "completed"
assert result.return_message == "workflow completed"
assert len(ctx.timings) == 6 # 6 stages
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
assert result.hook_name == "heart_fc.cycle_start"
assert result.kwargs == {"session_id": "s-1"}
assert result.aborted is False
@pytest.mark.asyncio
async def test_blocking_hook_modifies_message(self):
"""blocking hook 可以修改消息"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_blocking_hook_modifies_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""blocking 处理器可以修改参数。"""
reg = ComponentRegistry()
reg.register_component(
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
registry = ComponentRegistry()
registry.register_component(
"upper",
"workflow_step",
"HOOK_HANDLER",
"p1",
{
"stage": "pre_process",
"priority": 10,
"blocking": True,
"hook": "heart_fc.cycle_start",
"mode": "blocking",
"order": "normal",
},
)
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
msg = args.get("message", {})
return {
"hook_result": "continue",
"modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
}
result, final_msg, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "hello"},
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor(
"builtin",
registry,
{
"p1.upper": lambda args: {
"success": True,
"action": "continue",
"modified_kwargs": {
"session_id": args["session_id"],
"text": str(args["text"]).upper(),
},
}
},
[],
)
assert result.status == "completed"
assert final_msg["plain_text"] == "HELLO"
assert len(ctx.modification_log) == 1
assert ctx.modification_log[0].stage == "pre_process"
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1", text="hello")
assert result.kwargs["session_id"] == "s-1"
assert result.kwargs["text"] == "HELLO"
assert result.aborted is False
@pytest.mark.asyncio
async def test_abort_stops_pipeline(self):
"""HookResult.ABORT 立即终止 pipeline"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_abort_stops_following_blocking_handlers(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""blocking 处理器的 abort 应阻止后续 blocking 处理器执行。"""
reg = ComponentRegistry()
reg.register_component(
"blocker",
"workflow_step",
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
registry = ComponentRegistry()
registry.register_component(
"stopper",
"HOOK_HANDLER",
"p1",
{
"stage": "pre_process",
"priority": 10,
"blocking": True,
},
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
return {"hook_result": "abort"}
result, _, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "test"},
)
assert result.status == "aborted"
assert result.stopped_at == "pre_process"
@pytest.mark.asyncio
async def test_skip_stage(self):
"""HookResult.SKIP_STAGE 跳过当前阶段剩余 hook"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
# high-priority hook 返回 skip_stage
reg.register_component(
"skipper",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 100,
"blocking": True,
},
)
# low-priority hook 不应被执行
reg.register_component(
"checker",
"workflow_step",
registry.register_component(
"after_stop",
"HOOK_HANDLER",
"p2",
{
"stage": "ingress",
"priority": 1,
"blocking": True,
},
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
)
call_log: List[tuple[str, str]] = []
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor(
"builtin",
registry,
{
"p1.stopper": lambda args: {"success": True, "action": "abort"},
"p2.after_stop": lambda args: {"success": True, "action": "continue"},
},
call_log,
)
executor = WorkflowExecutor(reg)
call_log = []
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], cycle_id="c-1")
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
if comp_name == "skipper":
return {"hook_result": "skip_stage"}
return {"hook_result": "continue"}
result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
assert result.status == "completed"
# 只有 skipper 被调用checker 被跳过
assert call_log == ["skipper"]
assert result.aborted is True
assert result.stopped_by == "p1.stopper"
assert call_log == [("p1", "stopper")]
@pytest.mark.asyncio
async def test_pre_filter(self):
"""filter 条件不匹配时跳过 hook"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_observe_handler_runs_in_background_without_mutation(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""observe 处理器应后台执行且不能影响主流程参数。"""
reg = ComponentRegistry()
reg.register_component(
"only_dm",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 10,
"blocking": True,
"filter": {"chat_type": "direct"},
},
)
executor = WorkflowExecutor(reg)
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
return {"hook_result": "continue"}
# 不匹配 filter —— hook 不应被调用
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
assert not call_log
# 匹配 filter —— hook 应被调用
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
assert call_log == ["only_dm"]
@pytest.mark.asyncio
async def test_error_policy_skip(self):
    """With error_policy=skip, a failing hook is skipped and the stage continues."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    registry = ComponentRegistry()
    # High-priority step raises but is declared skippable; the low-priority
    # step registered after it must still be executed.
    step_specs = [
        ("failer", "p1", {"stage": "ingress", "priority": 100, "blocking": True, "error_policy": "skip"}),
        ("ok_step", "p2", {"stage": "ingress", "priority": 1, "blocking": True}),
    ]
    for comp_name, plugin_id, spec in step_specs:
        registry.register_component(comp_name, "workflow_step", plugin_id, spec)
    executor = WorkflowExecutor(registry)

    invoked = []

    async def fake_invoke(plugin_id, comp_name, args):
        invoked.append(comp_name)
        if comp_name == "failer":
            raise RuntimeError("boom")
        return {"hook_result": "continue"}

    result, _, ctx = await executor.execute(fake_invoke, message={"plain_text": "test"})
    assert result.status == "completed"
    assert "failer" in invoked
    assert "ok_step" in invoked
    # The swallowed failure is still surfaced through the context's error list.
    assert any("boom" in err for err in ctx.errors)
@pytest.mark.asyncio
async def test_error_policy_abort(self):
    """With the default error_policy ("abort"), a raising hook fails the pipeline."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    registry = ComponentRegistry()
    registry.register_component(
        "failer",
        "workflow_step",
        "p1",
        # error_policy intentionally omitted — it defaults to "abort".
        {"stage": "ingress", "priority": 10, "blocking": True},
    )
    executor = WorkflowExecutor(registry)

    async def raising_invoke(plugin_id, comp_name, args):
        raise RuntimeError("fatal")

    result, _, ctx = await executor.execute(raising_invoke, message={"plain_text": "test"})
    assert result.status == "failed"
    assert result.stopped_at == "ingress"
@pytest.mark.asyncio
async def test_nonblocking_hooks_concurrent(self):
    """Non-blocking hooks run concurrently and must not mutate the message."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    registry = ComponentRegistry()
    # Register three independent non-blocking steps in the same stage.
    for idx in range(3):
        registry.register_component(
            f"nb_{idx}",
            "workflow_step",
            f"p{idx}",
            {"stage": "post_process", "priority": 0, "blocking": False},
        )
    executor = WorkflowExecutor(registry)

    invoked = []

    async def fake_invoke(plugin_id, comp_name, args):
        invoked.append(comp_name)
        # A non-blocking hook's modified_message must be ignored by the executor.
        return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}

    result, final_msg, _ = await executor.execute(fake_invoke, message={"plain_text": "original"})
    assert final_msg["plain_text"] == "original"
    # Give the background tasks a moment to finish.
    await asyncio.sleep(0.1)
    assert result.status == "completed"
@pytest.mark.asyncio
async def test_nonblocking_tasks_are_retained_until_completion(self):
"""execute 返回后non-blocking task 仍应保持强引用直到执行完成。"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component(
registry = ComponentRegistry()
registry.register_component(
"observer",
"workflow_step",
"HOOK_HANDLER",
"p1",
{
"stage": "post_process",
"priority": 0,
"blocking": False,
},
{"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
)
executor = WorkflowExecutor(reg)
started = asyncio.Event()
release = asyncio.Event()
call_log: List[tuple[str, str]] = []
async def observe_handler(args: Dict[str, Any]) -> Dict[str, Any]:
"""模拟耗时观察型处理器。"""
async def mock_invoke(plugin_id, comp_name, args):
started.set()
await release.wait()
return {"hook_result": "continue"}
return {
"success": True,
"action": "abort",
"modified_kwargs": {"session_id": "changed"},
"custom_result": args["session_id"],
}
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor(
"builtin",
registry,
{"p1.observer": observe_handler},
call_log,
)
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
await asyncio.sleep(0)
assert result.status == "completed"
assert final_msg["plain_text"] == "original"
assert result.aborted is False
assert result.kwargs["session_id"] == "s-1"
assert started.is_set()
assert len(executor._background_tasks) == 1
assert len(dispatcher._background_tasks) == 1
release.set()
await asyncio.sleep(0)
await asyncio.sleep(0)
assert not executor._background_tasks
assert call_log == [("p1", "observer")]
assert not dispatcher._background_tasks
@pytest.mark.asyncio
async def test_command_routing(self):
"""PLAN 阶段内置命令路由"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_global_order_prefers_order_slot_then_source(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""全局排序应先看 order再看内置/第三方来源。"""
reg = ComponentRegistry()
reg.register_component(
"help",
"command",
"p1",
{
"command_pattern": r"^/help",
},
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
builtin_registry = ComponentRegistry()
third_registry = ComponentRegistry()
builtin_registry.register_component(
"builtin_early",
"HOOK_HANDLER",
"b1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
builtin_registry.register_component(
"builtin_normal",
"HOOK_HANDLER",
"b1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
)
third_registry.register_component(
"third_early",
"HOOK_HANDLER",
"t1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
third_registry.register_component(
"third_normal",
"HOOK_HANDLER",
"t1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
)
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
if comp_name == "help":
return {"output": "帮助信息"}
return {"hook_result": "continue"}
call_log: List[tuple[str, str]] = []
dispatcher = HookDispatcher()
builtin_supervisor = _FakeHookSupervisor(
"builtin",
builtin_registry,
{
"b1.builtin_early": lambda args: {"success": True, "action": "continue"},
"b1.builtin_normal": lambda args: {"success": True, "action": "continue"},
},
call_log,
)
third_supervisor = _FakeHookSupervisor(
"third_party",
third_registry,
{
"t1.third_early": lambda args: {"success": True, "action": "continue"},
"t1.third_normal": lambda args: {"success": True, "action": "continue"},
},
call_log,
)
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"})
assert result.status == "completed"
assert ctx.matched_command == "p1.help"
cmd_result = ctx.get_stage_output("plan", "command_result")
assert cmd_result is not None
assert cmd_result["output"] == "帮助信息"
await dispatcher.invoke_hook(
"heart_fc.cycle_start",
[third_supervisor, builtin_supervisor],
cycle_id="c-1",
)
assert call_log == [
("b1", "builtin_early"),
("t1", "third_early"),
("b1", "builtin_normal"),
("t1", "third_normal"),
]
@pytest.mark.asyncio
async def test_stage_outputs(self):
"""stage_outputs 数据在阶段间传递"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_error_policy_abort_stops_dispatch(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""error_policy=abort 时应中止本次 Hook 调用。"""
reg = ComponentRegistry()
# ingress 阶段写入数据
reg.register_component(
"writer",
"workflow_step",
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
registry = ComponentRegistry()
registry.register_component(
"failer",
"HOOK_HANDLER",
"p1",
{
"stage": "ingress",
"priority": 10,
"blocking": True,
"hook": "heart_fc.cycle_start",
"mode": "blocking",
"order": "normal",
"error_policy": "abort",
},
)
# pre_process 阶段读取数据
reg.register_component(
"reader",
"workflow_step",
"p2",
call_log: List[tuple[str, str]] = []
async def fail_handler(args: Dict[str, Any]) -> Dict[str, Any]:
"""抛出异常以触发 abort 策略。"""
del args
raise RuntimeError("boom")
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.failer": fail_handler}, call_log)
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
assert result.aborted is True
assert result.stopped_by == "p1.failer"
assert any("boom" in error for error in result.errors)
assert call_log == [("p1", "failer")]
@pytest.mark.asyncio
async def test_timeout_respects_handler_timeout_ms(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""处理器超时应被记录为错误并继续。"""
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
registry = ComponentRegistry()
registry.register_component(
"slow",
"HOOK_HANDLER",
"p1",
{
"stage": "pre_process",
"priority": 10,
"blocking": True,
"hook": "heart_fc.cycle_start",
"mode": "blocking",
"order": "normal",
"timeout_ms": 10,
},
)
executor = WorkflowExecutor(reg)
call_log: List[tuple[str, str]] = []
async def mock_invoke(plugin_id, comp_name, args):
if comp_name == "writer":
return {
"hook_result": "continue",
"stage_output": {"parsed_intent": "greeting"},
}
if comp_name == "reader":
# 验证 stage_outputs 被传递过来
outputs = args.get("stage_outputs", {})
ingress_data = outputs.get("ingress", {})
assert ingress_data.get("parsed_intent") == "greeting"
return {"hook_result": "continue"}
return {"hook_result": "continue"}
async def slow_handler(args: Dict[str, Any]) -> Dict[str, Any]:
"""模拟超时处理器。"""
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"})
assert result.status == "completed"
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
del args
await asyncio.sleep(0.05)
return {"success": True, "action": "continue"}
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.slow": slow_handler}, call_log)
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
assert result.aborted is False
assert any("超时" in error for error in result.errors)
assert call_log == [("p1", "slow")]
class TestPluginRuntimeHookEntry:
    """Tests for the PluginRuntimeManager named-hook entry point."""

    @staticmethod
    def _import_manager_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
        """Import runtime-manager modules while suppressing config-init exits.

        Args:
            monkeypatch: pytest's monkeypatch fixture.

        Returns:
            tuple[Any, Any]: The `ComponentRegistry` and `PluginRuntimeManager` types.
        """
        # Neutralize sys.exit so config initialization triggered at import time
        # cannot abort the test run.
        monkeypatch.setattr(sys, "exit", lambda code=0: None)
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.integration import PluginRuntimeManager
        return ComponentRegistry, PluginRuntimeManager

    @pytest.mark.asyncio
    async def test_manager_invoke_hook_dispatches_across_supervisors(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Invoking a named hook via the manager should dispatch across all supervisors."""
        ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch)
        # One blocking handler in the builtin registry, one observe handler in
        # the third-party registry.
        builtin_registry = ComponentRegistry()
        builtin_registry.register_component(
            "builtin_guard",
            "HOOK_HANDLER",
            "b1",
            {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
        )
        third_registry = ComponentRegistry()
        third_registry.register_component(
            "observer",
            "HOOK_HANDLER",
            "t1",
            {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
        )
        call_log: List[tuple[str, str]] = []
        manager = PluginRuntimeManager()
        # Mark the manager as started and inject fake supervisors so that no
        # real plugin processes are spawned.
        manager._started = True
        manager._builtin_supervisor = _FakeHookSupervisor(
            "builtin",
            builtin_registry,
            {"b1.builtin_guard": lambda args: {"success": True, "action": "continue"}},
            call_log,
        )
        manager._third_party_supervisor = _FakeHookSupervisor(
            "third_party",
            third_registry,
            {"t1.observer": lambda args: {"success": True, "action": "continue"}},
            call_log,
        )
        # NOTE(review): the docstring refers to PluginRuntimeManager.invoke_hook(),
        # but the call goes through manager.invoke_dispatcher — confirm this is the
        # intended public entry point and not a merge artifact.
        result = await manager.invoke_dispatcher.invoke_hook("heart_fc.cycle_start", session_id="s-1")
        # Yield once so the observe-mode handler's background task can run.
        await asyncio.sleep(0)
        assert manager.invoke_dispatcher is manager.hook_dispatcher
        assert result.aborted is False
        assert result.kwargs["session_id"] == "s-1"
        assert ("b1", "builtin_guard") in call_log
class TestRPCServer:

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, List
import pytest
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.services import send_service
@@ -13,42 +14,18 @@ class _FakePlatformIOManager:
"""用于测试的 Platform IO 管理器假对象。"""
def __init__(self, delivery_batch: Any) -> None:
"""初始化假 Platform IO 管理器。
Args:
delivery_batch: 发送时返回的批量回执。
"""
self._delivery_batch = delivery_batch
self.ensure_calls = 0
self.sent_messages: List[Dict[str, Any]] = []
async def ensure_send_pipeline_ready(self) -> None:
"""记录发送管线准备调用次数。"""
self.ensure_calls += 1
def build_route_key_from_message(self, message: Any) -> Any:
"""根据消息构造假的路由键。
Args:
message: 待发送的内部消息对象。
Returns:
Any: 简化后的路由键对象。
"""
del message
return SimpleNamespace(platform="qq")
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
"""记录发送请求并返回预设回执。
Args:
message: 待发送的内部消息对象。
route_key: 本次发送使用的路由键。
metadata: 发送元数据。
Returns:
Any: 预设的批量发送回执。
"""
self.sent_messages.append(
{
"message": message,
@@ -59,12 +36,7 @@ class _FakePlatformIOManager:
return self._delivery_batch
def _build_target_stream() -> BotChatSession:
"""构造一个最小可用的目标会话对象。
Returns:
BotChatSession: 测试用会话对象。
"""
def _build_private_stream() -> BotChatSession:
return BotChatSession(
session_id="test-session",
platform="qq",
@@ -73,14 +45,21 @@ def _build_target_stream() -> BotChatSession:
)
def _build_group_stream() -> BotChatSession:
return BotChatSession(
session_id="group-session",
platform="qq",
user_id="target-user",
group_id="target-group",
)
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
metadata = send_service._inherit_platform_io_route_metadata(_build_private_stream())
assert metadata["platform_io_account_id"] == "bot-qq"
assert metadata["platform_io_target_user_id"] == "target-user"
@@ -88,7 +67,6 @@ def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
@pytest.mark.asyncio
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
"""send service 应将发送职责统一交给 Platform IO。"""
fake_manager = _FakePlatformIOManager(
delivery_batch=SimpleNamespace(
has_success=True,
@@ -104,7 +82,7 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
)
monkeypatch.setattr(
send_service.MessageUtils,
@@ -123,7 +101,6 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke
@pytest.mark.asyncio
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
"""Platform IO 批量发送全部失败时,应直接向上返回失败。"""
fake_manager = _FakePlatformIOManager(
delivery_batch=SimpleNamespace(
has_success=False,
@@ -144,7 +121,7 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch:
monkeypatch.setattr(
send_service._chat_manager,
"get_session_by_session_id",
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
)
result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
@@ -152,3 +129,63 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch:
assert result is False
assert fake_manager.ensure_calls == 1
assert len(fake_manager.sent_messages) == 1
@pytest.mark.asyncio
async def test_private_outbound_message_preserves_bot_sender_and_receiver_user(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A private outbound message keeps the bot as sender and the peer as receiver."""
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
    )

    outbound = send_service._build_outbound_session_message(
        message_sequence=MessageSequence(components=[TextComponent(text="你好")]),
        stream_id="test-session",
        display_message="你好",
    )
    assert outbound is not None

    info = (await outbound.to_maim_message()).message_info
    # Top-level user info is the bot account, with no group info attached.
    assert info.user_info is not None
    assert info.user_info.user_id == "bot-qq"
    assert info.group_info is None
    # The explicit sender is also the bot account.
    assert info.sender_info is not None
    assert info.sender_info.user_info is not None
    assert info.sender_info.user_info.user_id == "bot-qq"
    # The receiver is the private-chat peer.
    assert info.receiver_info is not None
    assert info.receiver_info.user_info is not None
    assert info.receiver_info.user_info.user_id == "target-user"
@pytest.mark.asyncio
async def test_group_outbound_message_preserves_bot_sender_and_target_group(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A group outbound message keeps the bot as sender and targets the group."""
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: _build_group_stream() if stream_id == "group-session" else None,
    )

    outbound = send_service._build_outbound_session_message(
        message_sequence=MessageSequence(components=[TextComponent(text="大家好")]),
        stream_id="group-session",
        display_message="大家好",
    )
    assert outbound is not None

    info = (await outbound.to_maim_message()).message_info
    # The sender identity is the bot account.
    assert info.user_info is not None
    assert info.user_info.user_id == "bot-qq"
    # The message is addressed to the target group.
    assert info.group_info is not None
    assert info.group_info.group_id == "target-group"
    assert info.receiver_info is not None
    assert info.receiver_info.group_info is not None
    assert info.receiver_info.group_info.group_id == "target-group"