Merge remote-tracking branch 'upstream/r-dev' into sync/pr-1564-upstream-20260331
# Conflicts: # src/chat/brain_chat/PFC/conversation.py # src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py # src/chat/knowledge/lpmm_ops.py
This commit is contained in:
924
pytests/common_test/test_database_migration_foundation.py
Normal file
924
pytests/common_test/test_database_migration_foundation.py
Normal file
@@ -0,0 +1,924 @@
|
||||
"""数据库迁移基础设施测试。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
from sqlmodel import SQLModel, create_engine
|
||||
|
||||
import json
|
||||
import msgpack
|
||||
import pytest
|
||||
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.migrations import (
|
||||
BaseSchemaVersionDetector,
|
||||
BaseMigrationProgressReporter,
|
||||
DatabaseSchemaSnapshot,
|
||||
DatabaseMigrationBootstrapper,
|
||||
DatabaseMigrationState,
|
||||
DatabaseMigrationManager,
|
||||
EMPTY_SCHEMA_VERSION,
|
||||
LATEST_SCHEMA_VERSION,
|
||||
LEGACY_V1_SCHEMA_VERSION,
|
||||
MigrationExecutionContext,
|
||||
MigrationPlan,
|
||||
MigrationRegistry,
|
||||
MigrationStep,
|
||||
ResolvedSchemaVersion,
|
||||
SchemaVersionResolver,
|
||||
SchemaVersionSource,
|
||||
SQLiteSchemaInspector,
|
||||
SQLiteUserVersionStore,
|
||||
build_default_migration_registry,
|
||||
build_default_schema_version_resolver,
|
||||
create_database_migration_bootstrapper,
|
||||
)
|
||||
|
||||
|
||||
class FixedVersionDetector(BaseSchemaVersionDetector):
|
||||
"""测试用固定版本探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""返回测试探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 探测器名称。
|
||||
"""
|
||||
return "fixed_version_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""根据测试表是否存在返回固定版本。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。
|
||||
"""
|
||||
if snapshot.has_table("legacy_records"):
|
||||
return 2
|
||||
return None
|
||||
|
||||
|
||||
class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||
"""测试用迁移进度上报器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化测试用进度上报器。"""
|
||||
self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = []
|
||||
|
||||
def open(self) -> None:
|
||||
"""记录打开事件。"""
|
||||
self.events.append(("open", None, None, None))
|
||||
|
||||
def close(self) -> None:
|
||||
"""记录关闭事件。"""
|
||||
self.events.append(("close", None, None, None))
|
||||
|
||||
def start(
|
||||
self,
|
||||
total_records: int,
|
||||
total_tables: int,
|
||||
description: str = "总迁移进度",
|
||||
table_unit_name: str = "表",
|
||||
record_unit_name: str = "记录",
|
||||
) -> None:
|
||||
"""记录启动事件。
|
||||
|
||||
Args:
|
||||
total_records: 任务记录总数。
|
||||
total_tables: 任务表总数。
|
||||
description: 任务描述。
|
||||
table_unit_name: 表级进度单位名称。
|
||||
record_unit_name: 记录级进度单位名称。
|
||||
"""
|
||||
del table_unit_name, record_unit_name
|
||||
self.events.append(("start", total_records, total_tables, description))
|
||||
|
||||
def advance(
|
||||
self,
|
||||
records: int = 0,
|
||||
completed_tables: int = 0,
|
||||
item_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""记录推进事件。
|
||||
|
||||
Args:
|
||||
records: 推进的记录数。
|
||||
completed_tables: 已完成的表数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
self.events.append(("advance", records, completed_tables, item_name))
|
||||
|
||||
|
||||
def _create_sqlite_engine(database_file: Path) -> Engine:
|
||||
"""创建测试用 SQLite 引擎。
|
||||
|
||||
Args:
|
||||
database_file: 测试数据库文件路径。
|
||||
|
||||
Returns:
|
||||
Engine: SQLite 引擎实例。
|
||||
"""
|
||||
return create_engine(
|
||||
f"sqlite:///{database_file}",
|
||||
echo=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
|
||||
def _create_current_schema(connection: Connection) -> None:
|
||||
"""创建当前最新版本的数据库结构。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
"""
|
||||
import src.common.database.database_model # noqa: F401
|
||||
|
||||
SQLModel.metadata.create_all(connection)
|
||||
|
||||
|
||||
def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
|
||||
"""创建带示例数据的旧版 ``0.x`` 数据库结构。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
"""
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE chat_streams (
|
||||
id INTEGER PRIMARY KEY,
|
||||
stream_id TEXT NOT NULL,
|
||||
create_time REAL NOT NULL,
|
||||
last_active_time REAL NOT NULL,
|
||||
platform TEXT NOT NULL,
|
||||
user_id TEXT,
|
||||
group_id TEXT,
|
||||
group_name TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE messages (
|
||||
id INTEGER PRIMARY KEY,
|
||||
message_id TEXT NOT NULL,
|
||||
time REAL NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
chat_info_platform TEXT,
|
||||
user_id TEXT,
|
||||
user_nickname TEXT,
|
||||
chat_info_group_id TEXT,
|
||||
chat_info_group_name TEXT,
|
||||
is_mentioned INTEGER,
|
||||
is_at INTEGER,
|
||||
processed_plain_text TEXT,
|
||||
display_message TEXT,
|
||||
is_emoji INTEGER,
|
||||
is_picid INTEGER,
|
||||
is_command INTEGER,
|
||||
is_notify INTEGER,
|
||||
additional_config TEXT,
|
||||
priority_mode TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE action_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
action_id TEXT NOT NULL,
|
||||
time REAL NOT NULL,
|
||||
action_reasoning TEXT,
|
||||
action_name TEXT NOT NULL,
|
||||
action_data TEXT,
|
||||
action_prompt_display TEXT,
|
||||
chat_id TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE expression (
|
||||
id INTEGER PRIMARY KEY,
|
||||
situation TEXT NOT NULL,
|
||||
style TEXT NOT NULL,
|
||||
content_list TEXT,
|
||||
count INTEGER,
|
||||
last_active_time REAL NOT NULL,
|
||||
chat_id TEXT,
|
||||
create_date REAL,
|
||||
checked INTEGER,
|
||||
rejected INTEGER,
|
||||
modified_by TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE jargon (
|
||||
id INTEGER PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
raw_content TEXT,
|
||||
meaning TEXT,
|
||||
chat_id TEXT,
|
||||
is_global INTEGER,
|
||||
count INTEGER,
|
||||
is_jargon INTEGER,
|
||||
last_inference_count INTEGER,
|
||||
is_complete INTEGER,
|
||||
inference_with_context TEXT,
|
||||
inference_content_only TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO chat_streams (
|
||||
id,
|
||||
stream_id,
|
||||
create_time,
|
||||
last_active_time,
|
||||
platform,
|
||||
user_id,
|
||||
group_id,
|
||||
group_name
|
||||
) VALUES (
|
||||
1,
|
||||
'session-1',
|
||||
1710000000.0,
|
||||
1710000300.0,
|
||||
'qq',
|
||||
'user-1',
|
||||
'group-1',
|
||||
'测试群'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO messages (
|
||||
id,
|
||||
message_id,
|
||||
time,
|
||||
chat_id,
|
||||
chat_info_platform,
|
||||
user_id,
|
||||
user_nickname,
|
||||
chat_info_group_id,
|
||||
chat_info_group_name,
|
||||
is_mentioned,
|
||||
is_at,
|
||||
processed_plain_text,
|
||||
display_message,
|
||||
is_emoji,
|
||||
is_picid,
|
||||
is_command,
|
||||
is_notify,
|
||||
additional_config,
|
||||
priority_mode
|
||||
) VALUES (
|
||||
1,
|
||||
'msg-1',
|
||||
1710000010.0,
|
||||
'session-1',
|
||||
'qq',
|
||||
'user-1',
|
||||
'测试用户',
|
||||
'group-1',
|
||||
'测试群',
|
||||
1,
|
||||
0,
|
||||
'你好',
|
||||
'你好呀',
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
'{"source":"legacy"}',
|
||||
'high'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO action_records (
|
||||
id,
|
||||
action_id,
|
||||
time,
|
||||
action_reasoning,
|
||||
action_name,
|
||||
action_data,
|
||||
action_prompt_display,
|
||||
chat_id
|
||||
) VALUES (
|
||||
1,
|
||||
'action-1',
|
||||
1710000020.0,
|
||||
'需要调用工具',
|
||||
'search',
|
||||
'{"query":"MaiBot"}',
|
||||
'执行搜索',
|
||||
'session-1'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO expression (
|
||||
id,
|
||||
situation,
|
||||
style,
|
||||
content_list,
|
||||
count,
|
||||
last_active_time,
|
||||
chat_id,
|
||||
create_date,
|
||||
checked,
|
||||
rejected,
|
||||
modified_by
|
||||
) VALUES (
|
||||
1,
|
||||
'打招呼',
|
||||
'可爱',
|
||||
'["你好呀","早上好"]',
|
||||
3,
|
||||
1710000030.0,
|
||||
'session-1',
|
||||
1710000040.0,
|
||||
1,
|
||||
0,
|
||||
'ai'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO jargon (
|
||||
id,
|
||||
content,
|
||||
raw_content,
|
||||
meaning,
|
||||
chat_id,
|
||||
is_global,
|
||||
count,
|
||||
is_jargon,
|
||||
last_inference_count,
|
||||
is_complete,
|
||||
inference_with_context,
|
||||
inference_content_only
|
||||
) VALUES (
|
||||
1,
|
||||
'上分',
|
||||
'["上分"]',
|
||||
'提高排名',
|
||||
'session-1',
|
||||
0,
|
||||
5,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
'{"guess":"context"}',
|
||||
'{"guess":"content"}'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
|
||||
"""应支持读取与写入 SQLite ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "version_store.db")
|
||||
version_store = SQLiteUserVersionStore()
|
||||
|
||||
with engine.begin() as connection:
|
||||
assert version_store.read_version(connection) == 0
|
||||
version_store.write_version(connection, 7)
|
||||
|
||||
with engine.connect() as connection:
|
||||
assert version_store.read_version(connection) == 7
|
||||
|
||||
|
||||
def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
|
||||
"""应能提取 SQLite 数据库的表与列结构。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
|
||||
inspector = SQLiteSchemaInspector()
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE legacy_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
payload TEXT NOT NULL,
|
||||
created_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
snapshot = inspector.inspect(connection)
|
||||
|
||||
assert snapshot.has_table("legacy_records")
|
||||
assert snapshot.has_column("legacy_records", "payload")
|
||||
assert not snapshot.has_column("legacy_records", "missing_column")
|
||||
table_schema = snapshot.get_table("legacy_records")
|
||||
|
||||
assert table_schema is not None
|
||||
assert table_schema.column_names() == ["created_at", "id", "payload"]
|
||||
|
||||
|
||||
def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
|
||||
"""空数据库应被解析为版本 ``0``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
|
||||
resolver = SchemaVersionResolver()
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == 0
|
||||
assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE
|
||||
assert resolved_version.snapshot is not None
|
||||
assert resolved_version.snapshot.is_empty()
|
||||
|
||||
|
||||
def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
|
||||
"""未写入 ``user_version`` 的历史库应支持通过探测器识别版本。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
|
||||
resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == 2
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "fixed_version_detector"
|
||||
|
||||
|
||||
def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
|
||||
"""迁移编排器应能按顺序执行已注册步骤并更新版本号。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "manager.db")
|
||||
executed_steps: List[str] = []
|
||||
|
||||
def migrate_0_to_1(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 0 -> 1。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
|
||||
context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 1 -> 2。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
|
||||
context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=0,
|
||||
version_to=1,
|
||||
name="create_sample_records",
|
||||
description="创建示例表。",
|
||||
handler=migrate_0_to_1,
|
||||
),
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="add_sample_email",
|
||||
description="为示例表增加邮箱字段。",
|
||||
handler=migrate_1_to_2,
|
||||
),
|
||||
]
|
||||
)
|
||||
manager = DatabaseMigrationManager(engine=engine, registry=registry)
|
||||
|
||||
migration_plan = manager.migrate()
|
||||
|
||||
assert migration_plan.step_count() == 2
|
||||
assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"]
|
||||
|
||||
with engine.connect() as connection:
|
||||
version_store = SQLiteUserVersionStore()
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
recorded_version = version_store.read_version(connection)
|
||||
|
||||
assert recorded_version == 2
|
||||
assert snapshot.has_table("sample_records")
|
||||
assert snapshot.has_column("sample_records", "email")
|
||||
|
||||
|
||||
def test_manager_can_report_step_progress(tmp_path: Path) -> None:
|
||||
"""迁移编排器应支持通过上下文上报步骤进度。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
|
||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||
|
||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||
"""构建测试用进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||
"""
|
||||
reporter = FakeMigrationProgressReporter()
|
||||
reporter_instances.append(reporter)
|
||||
return reporter
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 ``1 -> 2`` 的进度上报。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
context.start_progress(total_tables=3, total_records=30, description="总迁移进度")
|
||||
context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions")
|
||||
context.advance_progress(records=10, completed_tables=1, item_name="mai_messages")
|
||||
context.advance_progress(records=10, completed_tables=1, item_name="tool_records")
|
||||
context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))
|
||||
|
||||
with engine.begin() as connection:
|
||||
SQLiteUserVersionStore().write_version(connection, 1)
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="progress_step",
|
||||
description="测试进度上报。",
|
||||
handler=migrate_1_to_2,
|
||||
)
|
||||
]
|
||||
)
|
||||
manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=registry,
|
||||
progress_reporter_factory=_build_reporter,
|
||||
)
|
||||
|
||||
migration_plan = manager.migrate()
|
||||
|
||||
assert migration_plan.step_count() == 1
|
||||
assert len(reporter_instances) == 1
|
||||
assert reporter_instances[0].events == [
|
||||
("open", None, None, None),
|
||||
("start", 30, 3, "总迁移进度"),
|
||||
("advance", 10, 1, "chat_sessions"),
|
||||
("advance", 10, 1, "mai_messages"),
|
||||
("advance", 10, 1, "tool_records"),
|
||||
("close", None, None, None),
|
||||
]
|
||||
|
||||
|
||||
def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
|
||||
"""默认解析器应能识别未写入版本号的最新结构数据库。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
|
||||
resolver = build_default_schema_version_resolver()
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "latest_schema_detector"
|
||||
|
||||
|
||||
def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
|
||||
"""默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
|
||||
resolver = build_default_schema_version_resolver()
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "legacy_v1_schema_detector"
|
||||
|
||||
|
||||
def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
|
||||
"""已是最新结构但未写版本号的数据库应直接补写 ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
|
||||
"""空库在建表完成后应回写最新 ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION
|
||||
assert migration_state.target_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
|
||||
"""启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
|
||||
execution_marks: List[str] = []
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试桥接器迁移步骤 ``1 -> 2``。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
execution_marks.append(f"step={context.step_name},index={context.step_index}")
|
||||
context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
|
||||
)
|
||||
SQLiteUserVersionStore().write_version(connection, 1)
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="bootstrap_add_email",
|
||||
description="为桥接器测试表增加邮箱字段。",
|
||||
handler=migrate_1_to_2,
|
||||
)
|
||||
]
|
||||
)
|
||||
bootstrapper = DatabaseMigrationBootstrapper(
|
||||
manager=DatabaseMigrationManager(engine=engine, registry=registry),
|
||||
latest_schema_version=2,
|
||||
)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
|
||||
assert migration_state.resolved_version.version == 2
|
||||
assert migration_state.target_version == 2
|
||||
assert execution_marks == ["step=bootstrap_add_email,index=1"]
|
||||
|
||||
with engine.connect() as connection:
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == 2
|
||||
assert snapshot.has_table("bootstrap_records")
|
||||
assert snapshot.has_column("bootstrap_records", "email")
|
||||
|
||||
|
||||
def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
|
||||
"""默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
message_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, processed_plain_text, additional_config, raw_content
|
||||
FROM mai_messages
|
||||
WHERE message_id = 'msg-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
action_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, action_name, action_display_prompt
|
||||
FROM action_records
|
||||
WHERE action_id = 'action-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
tool_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, tool_name, tool_display_prompt
|
||||
FROM tool_records
|
||||
WHERE tool_id = 'action-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
expression_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, content_list, modified_by
|
||||
FROM expressions
|
||||
WHERE id = 1
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
jargon_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id_dict, raw_content, inference_with_content_only
|
||||
FROM jargons
|
||||
WHERE id = 1
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
assert snapshot.has_table("__legacy_v1_messages")
|
||||
assert snapshot.has_table("chat_sessions")
|
||||
assert snapshot.has_table("mai_messages")
|
||||
assert snapshot.has_table("tool_records")
|
||||
|
||||
unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
|
||||
additional_config = json.loads(message_row["additional_config"])
|
||||
expression_content_list = json.loads(expression_row["content_list"])
|
||||
jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
|
||||
jargon_raw_content = json.loads(jargon_row["raw_content"])
|
||||
|
||||
assert message_row["session_id"] == "session-1"
|
||||
assert message_row["processed_plain_text"] == "你好"
|
||||
assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
|
||||
assert additional_config == {"priority_mode": "high", "source": "legacy"}
|
||||
assert action_row["session_id"] == "session-1"
|
||||
assert action_row["action_name"] == "search"
|
||||
assert action_row["action_display_prompt"] == "执行搜索"
|
||||
assert tool_row["session_id"] == "session-1"
|
||||
assert tool_row["tool_name"] == "search"
|
||||
assert tool_row["tool_display_prompt"] == "执行搜索"
|
||||
assert expression_row["session_id"] == "session-1"
|
||||
assert expression_row["modified_by"] == "AI"
|
||||
assert expression_content_list == ["你好呀", "早上好"]
|
||||
assert jargon_session_id_dict == {"session-1": 5}
|
||||
assert jargon_raw_content == ["上分"]
|
||||
assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'
|
||||
|
||||
|
||||
def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
|
||||
"""旧版迁移步骤应按目标表数量推进总进度。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
|
||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||
|
||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||
"""构建测试用进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||
"""
|
||||
reporter = FakeMigrationProgressReporter()
|
||||
reporter_instances.append(reporter)
|
||||
return reporter
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=build_default_migration_registry(),
|
||||
resolver=build_default_schema_version_resolver(),
|
||||
progress_reporter_factory=_build_reporter,
|
||||
)
|
||||
|
||||
migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)
|
||||
|
||||
assert migration_plan.step_count() == 1
|
||||
assert len(reporter_instances) == 1
|
||||
reporter_events = reporter_instances[0].events
|
||||
|
||||
assert reporter_events[0] == ("open", None, None, None)
|
||||
assert reporter_events[1] == ("start", 6, 12, "总迁移进度")
|
||||
assert reporter_events[-1] == ("close", None, None, None)
|
||||
assert reporter_events.count(("advance", 1, 0, None)) == 6
|
||||
assert reporter_events.count(("advance", 0, 1, "chat_sessions")) == 1
|
||||
assert reporter_events.count(("advance", 0, 1, "thinking_questions")) == 1
|
||||
assert len([event for event in reporter_events if event[0] == "advance"]) == 18
|
||||
|
||||
|
||||
def test_initialize_database_calls_bootstrapper_before_create_all(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""数据库初始化入口应先准备迁移,再建表、补迁移并收尾。"""
|
||||
call_order: List[str] = []
|
||||
|
||||
def _fake_prepare_database() -> DatabaseMigrationState:
|
||||
"""返回测试用迁移状态。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 不包含迁移步骤的测试状态。
|
||||
"""
|
||||
call_order.append("prepare_database")
|
||||
return DatabaseMigrationState(
|
||||
resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE),
|
||||
target_version=LATEST_SCHEMA_VERSION,
|
||||
plan=MigrationPlan(
|
||||
current_version=EMPTY_SCHEMA_VERSION,
|
||||
target_version=LATEST_SCHEMA_VERSION,
|
||||
steps=[],
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_create_all(bind) -> None:
|
||||
"""记录建表调用。
|
||||
|
||||
Args:
|
||||
bind: 传入的数据库绑定对象。
|
||||
"""
|
||||
del bind
|
||||
call_order.append("create_all")
|
||||
|
||||
def _fake_migrate_action_records() -> None:
|
||||
"""记录轻量补迁移调用。"""
|
||||
call_order.append("migrate_action_records")
|
||||
|
||||
def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
|
||||
"""记录迁移收尾调用。
|
||||
|
||||
Args:
|
||||
migration_state: 当前数据库迁移状态。
|
||||
"""
|
||||
del migration_state
|
||||
call_order.append("finalize_database")
|
||||
|
||||
monkeypatch.setattr(database_module, "_db_initialized", False)
|
||||
monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
|
||||
monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
|
||||
monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
|
||||
monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)
|
||||
monkeypatch.setattr(database_module, "_migrate_action_records_to_tool_records", _fake_migrate_action_records)
|
||||
|
||||
database_module.initialize_database()
|
||||
|
||||
assert call_order == [
|
||||
"prepare_database",
|
||||
"create_all",
|
||||
"migrate_action_records",
|
||||
"finalize_database",
|
||||
]
|
||||
55
pytests/test_maisaka_message_adapter.py
Normal file
55
pytests/test_maisaka_message_adapter.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.maisaka.message_adapter import build_message, get_message_kind, get_message_role, get_tool_call_id, get_tool_calls
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
def test_build_message_returns_session_message_with_maisaka_metadata() -> None:
|
||||
timestamp = datetime.now()
|
||||
tool_call = ToolCall(
|
||||
call_id="call-1",
|
||||
func_name="reply",
|
||||
args={"message_id": "msg-1"},
|
||||
)
|
||||
raw_message = MessageSequence(components=[TextComponent(text="内部消息内容")])
|
||||
|
||||
message = build_message(
|
||||
role="assistant",
|
||||
content="展示消息内容",
|
||||
message_kind="perception",
|
||||
source="assistant",
|
||||
tool_call_id="call-1",
|
||||
tool_calls=[tool_call],
|
||||
timestamp=timestamp,
|
||||
message_id="maisaka-msg-1",
|
||||
raw_message=raw_message,
|
||||
display_text="展示消息内容",
|
||||
)
|
||||
|
||||
assert isinstance(message, SessionMessage)
|
||||
assert message.initialized is True
|
||||
assert message.message_id == "maisaka-msg-1"
|
||||
assert message.timestamp == timestamp
|
||||
assert message.processed_plain_text == "展示消息内容"
|
||||
assert message.display_message == "展示消息内容"
|
||||
assert message.raw_message is raw_message
|
||||
|
||||
assert get_message_role(message) == "assistant"
|
||||
assert get_message_kind(message) == "perception"
|
||||
assert get_tool_call_id(message) == "call-1"
|
||||
|
||||
restored_tool_calls = get_tool_calls(message)
|
||||
assert len(restored_tool_calls) == 1
|
||||
assert restored_tool_calls[0].call_id == "call-1"
|
||||
assert restored_tool_calls[0].func_name == "reply"
|
||||
assert restored_tool_calls[0].args == {"message_id": "msg-1"}
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
@@ -1831,395 +1832,445 @@ class TestMaiMessages:
|
||||
assert msg.llm_response_content == "new response"
|
||||
|
||||
|
||||
# ─── WorkflowExecutor 测试 ────────────────────────────────
|
||||
class _FakeHookSupervisor:
|
||||
"""用于 Hook 分发测试的简化 Supervisor。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_name: str,
|
||||
component_registry: Any,
|
||||
handlers: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]] | Dict[str, Any]]],
|
||||
call_log: List[tuple[str, str]],
|
||||
) -> None:
|
||||
"""初始化测试用 Supervisor。
|
||||
|
||||
Args:
|
||||
group_name: 运行时分组名称。
|
||||
component_registry: 组件注册表实例。
|
||||
handlers: 处理器映射,键为 `plugin_id.component_name`。
|
||||
call_log: 记录调用顺序的列表。
|
||||
"""
|
||||
|
||||
self._group_name = group_name
|
||||
self.component_registry = component_registry
|
||||
self._handlers = handlers
|
||||
self._call_log = call_log
|
||||
|
||||
@property
|
||||
def group_name(self) -> str:
|
||||
"""返回当前测试 Supervisor 的分组名称。"""
|
||||
|
||||
return self._group_name
|
||||
|
||||
async def invoke_plugin(
|
||||
self,
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> SimpleNamespace:
|
||||
"""模拟调用插件组件。
|
||||
|
||||
Args:
|
||||
method: RPC 方法名。
|
||||
plugin_id: 目标插件 ID。
|
||||
component_name: 目标组件名称。
|
||||
args: 调用参数。
|
||||
timeout_ms: 超时配置,测试中仅用于保持接口一致。
|
||||
|
||||
Returns:
|
||||
SimpleNamespace: 仅包含 `payload` 字段的简化响应对象。
|
||||
"""
|
||||
|
||||
del method
|
||||
del timeout_ms
|
||||
|
||||
full_name = f"{plugin_id}.{component_name}"
|
||||
handler = self._handlers[full_name]
|
||||
self._call_log.append((plugin_id, component_name))
|
||||
result = handler(dict(args or {}))
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
return SimpleNamespace(payload=result)
|
||||
|
||||
|
||||
class TestWorkflowExecutor:
|
||||
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
|
||||
# ─── HookDispatcher 测试 ────────────────────────────────
|
||||
|
||||
|
||||
class TestHookDispatcher:
|
||||
"""命名 Hook 分发器测试。"""
|
||||
|
||||
@staticmethod
|
||||
def _import_dispatcher_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
|
||||
"""导入 Hook 分发相关模块,并屏蔽配置初始化触发的退出。
|
||||
|
||||
Args:
|
||||
monkeypatch: pytest 的 monkeypatch 工具。
|
||||
|
||||
Returns:
|
||||
tuple[Any, Any]: `ComponentRegistry` 与 `HookDispatcher` 类型。
|
||||
"""
|
||||
|
||||
monkeypatch.setattr(sys, "exit", lambda code=0: None)
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatcher
|
||||
|
||||
return ComponentRegistry, HookDispatcher
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_pipeline_completes(self):
|
||||
"""无任何 workflow_step 注册时,pipeline 全阶段跳过,状态 completed"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_empty_hook_returns_original_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""未注册处理器时应直接返回原始参数。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
executor = WorkflowExecutor(reg)
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {"hook_result": "continue"}
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor("builtin", ComponentRegistry(), {}, [])
|
||||
|
||||
result, final_msg, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "test"},
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert result.return_message == "workflow completed"
|
||||
assert len(ctx.timings) == 6 # 6 stages
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
|
||||
|
||||
assert result.hook_name == "heart_fc.cycle_start"
|
||||
assert result.kwargs == {"session_id": "s-1"}
|
||||
assert result.aborted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocking_hook_modifies_message(self):
|
||||
"""blocking hook 可以修改消息"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_blocking_hook_modifies_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""blocking 处理器可以修改参数。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"upper",
|
||||
"workflow_step",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"hook": "heart_fc.cycle_start",
|
||||
"mode": "blocking",
|
||||
"order": "normal",
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
msg = args.get("message", {})
|
||||
return {
|
||||
"hook_result": "continue",
|
||||
"modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
|
||||
}
|
||||
|
||||
result, final_msg, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "hello"},
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
registry,
|
||||
{
|
||||
"p1.upper": lambda args: {
|
||||
"success": True,
|
||||
"action": "continue",
|
||||
"modified_kwargs": {
|
||||
"session_id": args["session_id"],
|
||||
"text": str(args["text"]).upper(),
|
||||
},
|
||||
}
|
||||
},
|
||||
[],
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert final_msg["plain_text"] == "HELLO"
|
||||
assert len(ctx.modification_log) == 1
|
||||
assert ctx.modification_log[0].stage == "pre_process"
|
||||
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1", text="hello")
|
||||
|
||||
assert result.kwargs["session_id"] == "s-1"
|
||||
assert result.kwargs["text"] == "HELLO"
|
||||
assert result.aborted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_stops_pipeline(self):
|
||||
"""HookResult.ABORT 立即终止 pipeline"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_abort_stops_following_blocking_handlers(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""blocking 处理器的 abort 应阻止后续 blocking 处理器执行。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"blocker",
|
||||
"workflow_step",
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"stopper",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
},
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {"hook_result": "abort"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "test"},
|
||||
)
|
||||
assert result.status == "aborted"
|
||||
assert result.stopped_at == "pre_process"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_stage(self):
|
||||
"""HookResult.SKIP_STAGE 跳过当前阶段剩余 hook"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
# high-priority hook 返回 skip_stage
|
||||
reg.register_component(
|
||||
"skipper",
|
||||
"workflow_step",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 100,
|
||||
"blocking": True,
|
||||
},
|
||||
)
|
||||
# low-priority hook 不应被执行
|
||||
reg.register_component(
|
||||
"checker",
|
||||
"workflow_step",
|
||||
registry.register_component(
|
||||
"after_stop",
|
||||
"HOOK_HANDLER",
|
||||
"p2",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 1,
|
||||
"blocking": True,
|
||||
},
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
|
||||
)
|
||||
call_log: List[tuple[str, str]] = []
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
registry,
|
||||
{
|
||||
"p1.stopper": lambda args: {"success": True, "action": "abort"},
|
||||
"p2.after_stop": lambda args: {"success": True, "action": "continue"},
|
||||
},
|
||||
call_log,
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], cycle_id="c-1")
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
if comp_name == "skipper":
|
||||
return {"hook_result": "skip_stage"}
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
|
||||
assert result.status == "completed"
|
||||
# 只有 skipper 被调用,checker 被跳过
|
||||
assert call_log == ["skipper"]
|
||||
assert result.aborted is True
|
||||
assert result.stopped_by == "p1.stopper"
|
||||
assert call_log == [("p1", "stopper")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_filter(self):
|
||||
"""filter 条件不匹配时跳过 hook"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_observe_handler_runs_in_background_without_mutation(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""observe 处理器应后台执行且不能影响主流程参数。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"only_dm",
|
||||
"workflow_step",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"filter": {"chat_type": "direct"},
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
# 不匹配 filter —— hook 不应被调用
|
||||
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
|
||||
assert not call_log
|
||||
|
||||
# 匹配 filter —— hook 应被调用
|
||||
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
|
||||
assert call_log == ["only_dm"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_policy_skip(self):
|
||||
"""error_policy=skip 时跳过失败的 hook 继续执行"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"failer",
|
||||
"workflow_step",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 100,
|
||||
"blocking": True,
|
||||
"error_policy": "skip",
|
||||
},
|
||||
)
|
||||
reg.register_component(
|
||||
"ok_step",
|
||||
"workflow_step",
|
||||
"p2",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 1,
|
||||
"blocking": True,
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
if comp_name == "failer":
|
||||
raise RuntimeError("boom")
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
|
||||
assert result.status == "completed"
|
||||
assert "failer" in call_log
|
||||
assert "ok_step" in call_log
|
||||
assert any("boom" in e for e in ctx.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_policy_abort(self):
|
||||
"""error_policy=abort(默认)时 pipeline 失败"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"failer",
|
||||
"workflow_step",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
# error_policy defaults to "abort"
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
raise RuntimeError("fatal")
|
||||
|
||||
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
|
||||
assert result.status == "failed"
|
||||
assert result.stopped_at == "ingress"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonblocking_hooks_concurrent(self):
|
||||
"""non-blocking hook 并发执行,不修改消息"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
for i in range(3):
|
||||
reg.register_component(
|
||||
f"nb_{i}",
|
||||
"workflow_step",
|
||||
f"p{i}",
|
||||
{
|
||||
"stage": "post_process",
|
||||
"priority": 0,
|
||||
"blocking": False,
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
|
||||
|
||||
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
|
||||
# non-blocking 的 modified_message 被忽略
|
||||
assert final_msg["plain_text"] == "original"
|
||||
# 给异步 task 时间完成
|
||||
await asyncio.sleep(0.1)
|
||||
assert result.status == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonblocking_tasks_are_retained_until_completion(self):
|
||||
"""execute 返回后,non-blocking task 仍应保持强引用直到执行完成。"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"observer",
|
||||
"workflow_step",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "post_process",
|
||||
"priority": 0,
|
||||
"blocking": False,
|
||||
},
|
||||
{"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
call_log: List[tuple[str, str]] = []
|
||||
|
||||
async def observe_handler(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""模拟耗时观察型处理器。"""
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
started.set()
|
||||
await release.wait()
|
||||
return {"hook_result": "continue"}
|
||||
return {
|
||||
"success": True,
|
||||
"action": "abort",
|
||||
"modified_kwargs": {"session_id": "changed"},
|
||||
"custom_result": args["session_id"],
|
||||
}
|
||||
|
||||
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
registry,
|
||||
{"p1.observer": observe_handler},
|
||||
call_log,
|
||||
)
|
||||
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert result.status == "completed"
|
||||
assert final_msg["plain_text"] == "original"
|
||||
assert result.aborted is False
|
||||
assert result.kwargs["session_id"] == "s-1"
|
||||
assert started.is_set()
|
||||
assert len(executor._background_tasks) == 1
|
||||
assert len(dispatcher._background_tasks) == 1
|
||||
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
assert not executor._background_tasks
|
||||
assert call_log == [("p1", "observer")]
|
||||
assert not dispatcher._background_tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_routing(self):
|
||||
"""PLAN 阶段内置命令路由"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_global_order_prefers_order_slot_then_source(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""全局排序应先看 order,再看内置/第三方来源。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component(
|
||||
"help",
|
||||
"command",
|
||||
"p1",
|
||||
{
|
||||
"command_pattern": r"^/help",
|
||||
},
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
builtin_registry = ComponentRegistry()
|
||||
third_registry = ComponentRegistry()
|
||||
builtin_registry.register_component(
|
||||
"builtin_early",
|
||||
"HOOK_HANDLER",
|
||||
"b1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
|
||||
)
|
||||
builtin_registry.register_component(
|
||||
"builtin_normal",
|
||||
"HOOK_HANDLER",
|
||||
"b1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
|
||||
)
|
||||
third_registry.register_component(
|
||||
"third_early",
|
||||
"HOOK_HANDLER",
|
||||
"t1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
|
||||
)
|
||||
third_registry.register_component(
|
||||
"third_normal",
|
||||
"HOOK_HANDLER",
|
||||
"t1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
if comp_name == "help":
|
||||
return {"output": "帮助信息"}
|
||||
return {"hook_result": "continue"}
|
||||
call_log: List[tuple[str, str]] = []
|
||||
dispatcher = HookDispatcher()
|
||||
builtin_supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
builtin_registry,
|
||||
{
|
||||
"b1.builtin_early": lambda args: {"success": True, "action": "continue"},
|
||||
"b1.builtin_normal": lambda args: {"success": True, "action": "continue"},
|
||||
},
|
||||
call_log,
|
||||
)
|
||||
third_supervisor = _FakeHookSupervisor(
|
||||
"third_party",
|
||||
third_registry,
|
||||
{
|
||||
"t1.third_early": lambda args: {"success": True, "action": "continue"},
|
||||
"t1.third_normal": lambda args: {"success": True, "action": "continue"},
|
||||
},
|
||||
call_log,
|
||||
)
|
||||
|
||||
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"})
|
||||
assert result.status == "completed"
|
||||
assert ctx.matched_command == "p1.help"
|
||||
cmd_result = ctx.get_stage_output("plan", "command_result")
|
||||
assert cmd_result is not None
|
||||
assert cmd_result["output"] == "帮助信息"
|
||||
await dispatcher.invoke_hook(
|
||||
"heart_fc.cycle_start",
|
||||
[third_supervisor, builtin_supervisor],
|
||||
cycle_id="c-1",
|
||||
)
|
||||
|
||||
assert call_log == [
|
||||
("b1", "builtin_early"),
|
||||
("t1", "third_early"),
|
||||
("b1", "builtin_normal"),
|
||||
("t1", "third_normal"),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_outputs(self):
|
||||
"""stage_outputs 数据在阶段间传递"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
async def test_error_policy_abort_stops_dispatch(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""error_policy=abort 时应中止本次 Hook 调用。"""
|
||||
|
||||
reg = ComponentRegistry()
|
||||
# ingress 阶段写入数据
|
||||
reg.register_component(
|
||||
"writer",
|
||||
"workflow_step",
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"failer",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"hook": "heart_fc.cycle_start",
|
||||
"mode": "blocking",
|
||||
"order": "normal",
|
||||
"error_policy": "abort",
|
||||
},
|
||||
)
|
||||
# pre_process 阶段读取数据
|
||||
reg.register_component(
|
||||
"reader",
|
||||
"workflow_step",
|
||||
"p2",
|
||||
call_log: List[tuple[str, str]] = []
|
||||
|
||||
async def fail_handler(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""抛出异常以触发 abort 策略。"""
|
||||
|
||||
del args
|
||||
raise RuntimeError("boom")
|
||||
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.failer": fail_handler}, call_log)
|
||||
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
|
||||
|
||||
assert result.aborted is True
|
||||
assert result.stopped_by == "p1.failer"
|
||||
assert any("boom" in error for error in result.errors)
|
||||
assert call_log == [("p1", "failer")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_respects_handler_timeout_ms(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""处理器超时应被记录为错误并继续。"""
|
||||
|
||||
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
|
||||
|
||||
registry = ComponentRegistry()
|
||||
registry.register_component(
|
||||
"slow",
|
||||
"HOOK_HANDLER",
|
||||
"p1",
|
||||
{
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"hook": "heart_fc.cycle_start",
|
||||
"mode": "blocking",
|
||||
"order": "normal",
|
||||
"timeout_ms": 10,
|
||||
},
|
||||
)
|
||||
executor = WorkflowExecutor(reg)
|
||||
call_log: List[tuple[str, str]] = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
if comp_name == "writer":
|
||||
return {
|
||||
"hook_result": "continue",
|
||||
"stage_output": {"parsed_intent": "greeting"},
|
||||
}
|
||||
if comp_name == "reader":
|
||||
# 验证 stage_outputs 被传递过来
|
||||
outputs = args.get("stage_outputs", {})
|
||||
ingress_data = outputs.get("ingress", {})
|
||||
assert ingress_data.get("parsed_intent") == "greeting"
|
||||
return {"hook_result": "continue"}
|
||||
return {"hook_result": "continue"}
|
||||
async def slow_handler(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""模拟超时处理器。"""
|
||||
|
||||
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"})
|
||||
assert result.status == "completed"
|
||||
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
|
||||
del args
|
||||
await asyncio.sleep(0.05)
|
||||
return {"success": True, "action": "continue"}
|
||||
|
||||
dispatcher = HookDispatcher()
|
||||
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.slow": slow_handler}, call_log)
|
||||
|
||||
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
|
||||
|
||||
assert result.aborted is False
|
||||
assert any("超时" in error for error in result.errors)
|
||||
assert call_log == [("p1", "slow")]
|
||||
|
||||
|
||||
class TestPluginRuntimeHookEntry:
|
||||
"""PluginRuntimeManager 命名 Hook 入口测试。"""
|
||||
|
||||
@staticmethod
|
||||
def _import_manager_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
|
||||
"""导入运行时管理器相关模块,并屏蔽配置初始化触发的退出。
|
||||
|
||||
Args:
|
||||
monkeypatch: pytest 的 monkeypatch 工具。
|
||||
|
||||
Returns:
|
||||
tuple[Any, Any]: `ComponentRegistry` 与 `PluginRuntimeManager` 类型。
|
||||
"""
|
||||
|
||||
monkeypatch.setattr(sys, "exit", lambda code=0: None)
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.integration import PluginRuntimeManager
|
||||
|
||||
return ComponentRegistry, PluginRuntimeManager
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_invoke_hook_dispatches_across_supervisors(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""PluginRuntimeManager.invoke_hook() 应调用全局 Hook 分发器。"""
|
||||
|
||||
ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch)
|
||||
|
||||
builtin_registry = ComponentRegistry()
|
||||
builtin_registry.register_component(
|
||||
"builtin_guard",
|
||||
"HOOK_HANDLER",
|
||||
"b1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
|
||||
)
|
||||
third_registry = ComponentRegistry()
|
||||
third_registry.register_component(
|
||||
"observer",
|
||||
"HOOK_HANDLER",
|
||||
"t1",
|
||||
{"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
|
||||
)
|
||||
|
||||
call_log: List[tuple[str, str]] = []
|
||||
manager = PluginRuntimeManager()
|
||||
manager._started = True
|
||||
manager._builtin_supervisor = _FakeHookSupervisor(
|
||||
"builtin",
|
||||
builtin_registry,
|
||||
{"b1.builtin_guard": lambda args: {"success": True, "action": "continue"}},
|
||||
call_log,
|
||||
)
|
||||
manager._third_party_supervisor = _FakeHookSupervisor(
|
||||
"third_party",
|
||||
third_registry,
|
||||
{"t1.observer": lambda args: {"success": True, "action": "continue"}},
|
||||
call_log,
|
||||
)
|
||||
|
||||
result = await manager.invoke_dispatcher.invoke_hook("heart_fc.cycle_start", session_id="s-1")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert manager.invoke_dispatcher is manager.hook_dispatcher
|
||||
assert result.aborted is False
|
||||
assert result.kwargs["session_id"] == "s-1"
|
||||
assert ("b1", "builtin_guard") in call_log
|
||||
|
||||
|
||||
class TestRPCServer:
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
|
||||
import pytest
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.services import send_service
|
||||
|
||||
|
||||
@@ -13,42 +14,18 @@ class _FakePlatformIOManager:
|
||||
"""用于测试的 Platform IO 管理器假对象。"""
|
||||
|
||||
def __init__(self, delivery_batch: Any) -> None:
|
||||
"""初始化假 Platform IO 管理器。
|
||||
|
||||
Args:
|
||||
delivery_batch: 发送时返回的批量回执。
|
||||
"""
|
||||
self._delivery_batch = delivery_batch
|
||||
self.ensure_calls = 0
|
||||
self.sent_messages: List[Dict[str, Any]] = []
|
||||
|
||||
async def ensure_send_pipeline_ready(self) -> None:
|
||||
"""记录发送管线准备调用次数。"""
|
||||
self.ensure_calls += 1
|
||||
|
||||
def build_route_key_from_message(self, message: Any) -> Any:
|
||||
"""根据消息构造假的路由键。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
|
||||
Returns:
|
||||
Any: 简化后的路由键对象。
|
||||
"""
|
||||
del message
|
||||
return SimpleNamespace(platform="qq")
|
||||
|
||||
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
|
||||
"""记录发送请求并返回预设回执。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部消息对象。
|
||||
route_key: 本次发送使用的路由键。
|
||||
metadata: 发送元数据。
|
||||
|
||||
Returns:
|
||||
Any: 预设的批量发送回执。
|
||||
"""
|
||||
self.sent_messages.append(
|
||||
{
|
||||
"message": message,
|
||||
@@ -59,12 +36,7 @@ class _FakePlatformIOManager:
|
||||
return self._delivery_batch
|
||||
|
||||
|
||||
def _build_target_stream() -> BotChatSession:
|
||||
"""构造一个最小可用的目标会话对象。
|
||||
|
||||
Returns:
|
||||
BotChatSession: 测试用会话对象。
|
||||
"""
|
||||
def _build_private_stream() -> BotChatSession:
|
||||
return BotChatSession(
|
||||
session_id="test-session",
|
||||
platform="qq",
|
||||
@@ -73,14 +45,21 @@ def _build_target_stream() -> BotChatSession:
|
||||
)
|
||||
|
||||
|
||||
def _build_group_stream() -> BotChatSession:
|
||||
return BotChatSession(
|
||||
session_id="group-session",
|
||||
platform="qq",
|
||||
user_id="target-user",
|
||||
group_id="target-group",
|
||||
)
|
||||
|
||||
|
||||
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
|
||||
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
|
||||
|
||||
metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
|
||||
metadata = send_service._inherit_platform_io_route_metadata(_build_private_stream())
|
||||
|
||||
assert metadata["platform_io_account_id"] == "bot-qq"
|
||||
assert metadata["platform_io_target_user_id"] == "target-user"
|
||||
@@ -88,7 +67,6 @@ def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""send service 应将发送职责统一交给 Platform IO。"""
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=True,
|
||||
@@ -104,7 +82,7 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
|
||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
send_service.MessageUtils,
|
||||
@@ -123,7 +101,6 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Platform IO 批量发送全部失败时,应直接向上返回失败。"""
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=False,
|
||||
@@ -144,7 +121,7 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch:
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_target_stream() if stream_id == "test-session" else None,
|
||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
|
||||
result = await send_service.text_to_stream(text="发送失败", stream_id="test-session")
|
||||
@@ -152,3 +129,63 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch:
|
||||
assert result is False
|
||||
assert fake_manager.ensure_calls == 1
|
||||
assert len(fake_manager.sent_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_private_outbound_message_preserves_bot_sender_and_receiver_user(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
|
||||
outbound_message = send_service._build_outbound_session_message(
|
||||
message_sequence=MessageSequence(components=[TextComponent(text="你好")]),
|
||||
stream_id="test-session",
|
||||
display_message="你好",
|
||||
)
|
||||
|
||||
assert outbound_message is not None
|
||||
maim_message = await outbound_message.to_maim_message()
|
||||
|
||||
assert maim_message.message_info.user_info is not None
|
||||
assert maim_message.message_info.user_info.user_id == "bot-qq"
|
||||
assert maim_message.message_info.group_info is None
|
||||
assert maim_message.message_info.sender_info is not None
|
||||
assert maim_message.message_info.sender_info.user_info is not None
|
||||
assert maim_message.message_info.sender_info.user_info.user_id == "bot-qq"
|
||||
assert maim_message.message_info.receiver_info is not None
|
||||
assert maim_message.message_info.receiver_info.user_info is not None
|
||||
assert maim_message.message_info.receiver_info.user_info.user_id == "target-user"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_outbound_message_preserves_bot_sender_and_target_group(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_group_stream() if stream_id == "group-session" else None,
|
||||
)
|
||||
|
||||
outbound_message = send_service._build_outbound_session_message(
|
||||
message_sequence=MessageSequence(components=[TextComponent(text="大家好")]),
|
||||
stream_id="group-session",
|
||||
display_message="大家好",
|
||||
)
|
||||
|
||||
assert outbound_message is not None
|
||||
maim_message = await outbound_message.to_maim_message()
|
||||
|
||||
assert maim_message.message_info.user_info is not None
|
||||
assert maim_message.message_info.user_info.user_id == "bot-qq"
|
||||
assert maim_message.message_info.group_info is not None
|
||||
assert maim_message.message_info.group_info.group_id == "target-group"
|
||||
assert maim_message.message_info.receiver_info is not None
|
||||
assert maim_message.message_info.receiver_info.group_info is not None
|
||||
assert maim_message.message_info.receiver_info.group_info.group_id == "target-group"
|
||||
|
||||
Reference in New Issue
Block a user